//! Probabilistic PCA (Tipping & Bishop) fitted via EM, with projection,
//! reconstruction, log-likelihood evaluation, and missing-value imputation.
use scirs2_core::ndarray::{Array1, Array2, Axis};
// NOTE(review): `solve_linear_system` is imported but not referenced in this
// file — confirm whether another part of the module needs it.
use scirs2_linalg::solve_linear_system;
use crate::error::{Result, TransformError};
/// Numerical tolerance: pivots with magnitude below this are treated as zero
/// and variances are clamped to at least this value.
const EPS: f64 = 1e-12;
/// Invert a small square matrix by Gauss-Jordan elimination with partial
/// pivoting on the augmented matrix `[A | I]`.
///
/// # Errors
/// `InvalidInput` if `mat` is not square; `ComputationError` if the best
/// available pivot is smaller than `EPS` (numerically singular).
fn invert_small(mat: &Array2<f64>) -> Result<Array2<f64>> {
    let n = mat.nrows();
    if n != mat.ncols() {
        return Err(TransformError::InvalidInput(
            "invert_small: matrix must be square".into(),
        ));
    }
    // Assemble the augmented matrix [A | I].
    let mut work = Array2::<f64>::zeros((n, 2 * n));
    for r in 0..n {
        for c in 0..n {
            work[[r, c]] = mat[[r, c]];
        }
        work[[r, n + r]] = 1.0;
    }
    for c in 0..n {
        // Partial pivoting: choose the row with the largest |entry| in column c.
        let mut best_row = c;
        let mut best_mag = work[[c, c]].abs();
        for r in (c + 1)..n {
            let mag = work[[r, c]].abs();
            if mag > best_mag {
                best_mag = mag;
                best_row = r;
            }
        }
        if best_mag < EPS {
            return Err(TransformError::ComputationError(
                "invert_small: singular matrix".into(),
            ));
        }
        if best_row != c {
            // Swap the pivot row into place, element by element.
            for j in 0..(2 * n) {
                let held = work[[c, j]];
                work[[c, j]] = work[[best_row, j]];
                work[[best_row, j]] = held;
            }
        }
        // Normalize the pivot row so the pivot becomes 1.
        let pivot = work[[c, c]];
        for j in 0..(2 * n) {
            work[[c, j]] /= pivot;
        }
        // Eliminate column c from every other row.
        for r in 0..n {
            if r == c {
                continue;
            }
            let scale = work[[r, c]];
            for j in 0..(2 * n) {
                let delta = work[[c, j]] * scale;
                work[[r, j]] -= delta;
            }
        }
    }
    // The right half of the reduced augmented matrix is A^{-1}.
    let mut result = Array2::<f64>::zeros((n, n));
    for r in 0..n {
        for c in 0..n {
            result[[r, c]] = work[[r, n + c]];
        }
    }
    Ok(result)
}
/// Compute `A^T * B` without materializing the transpose.
///
/// `a` has shape (p, q) and `b` has shape (p, r); the result is (q, r).
/// Panics if the row counts of `a` and `b` disagree.
fn mat_mul_at_b(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
    let rows = a.nrows();
    let (cols_a, cols_b) = (a.ncols(), b.ncols());
    assert_eq!(b.nrows(), rows);
    let mut product = Array2::<f64>::zeros((cols_a, cols_b));
    for r in 0..cols_a {
        for c in 0..cols_b {
            // Dot product of column r of A with column c of B.
            let mut acc = 0.0;
            for t in 0..rows {
                acc += a[[t, r]] * b[[t, c]];
            }
            product[[r, c]] = acc;
        }
    }
    product
}
/// Dense matrix product `A * B` using an i-k-j loop order so the innermost
/// loop walks rows of `b` contiguously.
///
/// `a` has shape (m, k) and `b` has shape (k, n); the result is (m, n).
/// Panics if the inner dimensions disagree.
fn mat_mul(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
    let (rows, inner) = (a.nrows(), a.ncols());
    let cols = b.ncols();
    assert_eq!(b.nrows(), inner);
    let mut product = Array2::<f64>::zeros((rows, cols));
    for r in 0..rows {
        for t in 0..inner {
            // Hoist the invariant left-hand factor out of the inner loop.
            let lhs = a[[r, t]];
            for c in 0..cols {
                product[[r, c]] += lhs * b[[t, c]];
            }
        }
    }
    product
}
/// A fitted probabilistic PCA model.
///
/// Parameterizes the Gaussian latent-variable model
/// `x = W z + mean + eps` with `z ~ N(0, I_q)` and `eps ~ N(0, sigma2 * I_p)`.
#[derive(Debug, Clone)]
pub struct PPCAModel {
    // Factor loading matrix, shape (p features, q latent components).
    pub w: Array2<f64>,
    // Isotropic noise variance sigma^2.
    pub sigma2: f64,
    // Per-feature mean of the training data, length p.
    pub mean: Array1<f64>,
    // Number of EM iterations actually run while fitting.
    pub n_iter: usize,
    // Log-likelihood of the training data at the final parameters.
    pub log_likelihood: f64,
}
impl PPCAModel {
fn m_matrix(&self) -> Array2<f64> {
let q = self.w.ncols();
let wt_w = mat_mul_at_b(&self.w, &self.w);
let mut m = wt_w;
for i in 0..q {
m[[i, i]] += self.sigma2;
}
m
}
pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
let (n, p) = (x.nrows(), x.ncols());
if p != self.w.nrows() {
return Err(TransformError::DimensionMismatch(format!(
"Expected {p} features, got {p}"
)));
}
let m = self.m_matrix();
let m_inv = invert_small(&m)?;
let posterior_scale = mat_mul(&m_inv, &mat_mul_at_b(&self.w, &Array2::eye(p)));
let mut xc = x.to_owned();
for i in 0..n {
for j in 0..p {
xc[[i, j]] -= self.mean[j];
}
}
let wt = posterior_scale; let mut z = Array2::<f64>::zeros((n, wt.nrows()));
for i in 0..n {
for j in 0..wt.nrows() {
let mut s = 0.0;
for l in 0..p {
s += xc[[i, l]] * wt[[j, l]];
}
z[[i, j]] = s;
}
}
Ok(z)
}
pub fn inverse_transform(&self, z: &Array2<f64>) -> Result<Array2<f64>> {
let (n, q) = (z.nrows(), z.ncols());
if q != self.w.ncols() {
return Err(TransformError::DimensionMismatch(format!(
"Expected {q} latent dims, got {q}"
)));
}
let p = self.w.nrows();
let mut x_hat = Array2::<f64>::zeros((n, p));
for i in 0..n {
for j in 0..p {
let mut s = 0.0;
for l in 0..q {
s += z[[i, l]] * self.w[[j, l]];
}
x_hat[[i, j]] = s + self.mean[j];
}
}
Ok(x_hat)
}
pub fn log_likelihood(&self, x: &Array2<f64>) -> Result<f64> {
let (n, p) = (x.nrows(), x.ncols());
if p != self.w.nrows() {
return Err(TransformError::DimensionMismatch(
"Feature dimension mismatch".into(),
));
}
let q = self.w.ncols();
let sigma2 = self.sigma2.max(EPS);
let m = self.m_matrix();
let m_inv = invert_small(&m)?;
let log_det_m = log_det_small(&m)?;
let log_det_c = (p - q) as f64 * sigma2.ln() + log_det_m;
let wm_inv = mat_mul(&self.w, &m_inv);
let log2pi = (2.0 * std::f64::consts::PI).ln();
let const_term = -0.5 * (p as f64 * log2pi + log_det_c);
let mut total_ll = const_term * n as f64;
for i in 0..n {
let mut diff = Array1::<f64>::zeros(p);
for j in 0..p {
diff[j] = x[[i, j]] - self.mean[j];
}
let term1: f64 = diff.iter().map(|v| v * v).sum::<f64>() / sigma2;
let mut tmp = vec![0.0f64; q];
for l in 0..q {
for j in 0..p {
tmp[l] += wm_inv[[j, l]] * diff[j];
}
}
let term2: f64 = tmp.iter().zip(tmp.iter()).map(|(a, b)| a * b).sum::<f64>() / sigma2 / sigma2;
total_ll -= 0.5 * (term1 - term2);
}
Ok(total_ll)
}
pub fn impute_missing(
&self,
x: &Array2<f64>,
missing_mask: &Array2<bool>,
) -> Result<Array2<f64>> {
let (n, p) = (x.nrows(), x.ncols());
if p != self.w.nrows() {
return Err(TransformError::DimensionMismatch(
"Feature dimension mismatch".into(),
));
}
if missing_mask.shape() != [n, p] {
return Err(TransformError::DimensionMismatch(
"missing_mask shape must match X".into(),
));
}
let m = self.m_matrix();
let m_inv = invert_small(&m)?;
let q = self.w.ncols();
let mut x_imputed = x.to_owned();
for i in 0..n {
let obs: Vec<usize> = (0..p).filter(|&j| !missing_mask[[i, j]]).collect();
let miss: Vec<usize> = (0..p).filter(|&j| missing_mask[[i, j]]).collect();
if miss.is_empty() {
continue;
}
if obs.is_empty() {
for &j in &miss {
x_imputed[[i, j]] = self.mean[j];
}
continue;
}
let n_obs = obs.len();
let mut w_obs = Array2::<f64>::zeros((n_obs, q));
let mut x_obs_c = Array1::<f64>::zeros(n_obs);
for (ii, &j) in obs.iter().enumerate() {
for l in 0..q {
w_obs[[ii, l]] = self.w[[j, l]];
}
x_obs_c[ii] = x[[i, j]] - self.mean[j];
}
let wot_wo = mat_mul_at_b(&w_obs, &w_obs);
let mut m_obs = wot_wo;
for l in 0..q {
m_obs[[l, l]] += self.sigma2;
}
let m_obs_inv = invert_small(&m_obs)?;
let mut ez = vec![0.0f64; q];
for l in 0..q {
for (ii, _) in obs.iter().enumerate() {
ez[l] += m_obs_inv[[l, 0]] * 0.0; }
}
let mut wot_xc = vec![0.0f64; q];
for l in 0..q {
for (ii, _) in obs.iter().enumerate() {
wot_xc[l] += w_obs[[ii, l]] * x_obs_c[ii];
}
}
let mut ez = vec![0.0f64; q];
for l in 0..q {
for l2 in 0..q {
ez[l] += m_obs_inv[[l, l2]] * wot_xc[l2];
}
}
for &j in &miss {
let mut val = self.mean[j];
for l in 0..q {
val += self.w[[j, l]] * ez[l];
}
x_imputed[[i, j]] = val;
}
}
Ok(x_imputed)
}
}
/// Log of the absolute value of the determinant of a small square matrix,
/// computed via LU decomposition with partial pivoting.
///
/// Returns `Ok(f64::NEG_INFINITY)` when the matrix is numerically singular.
/// For the symmetric positive-definite `M` used in this module the
/// determinant is positive, so `ln|det| == ln(det)`.
fn log_det_small(mat: &Array2<f64>) -> Result<f64> {
    let k = mat.nrows();
    let mut a = mat.to_owned();
    let mut log_det = 0.0f64;
    for col in 0..k {
        // Partial pivoting: largest |entry| at or below the diagonal.
        let mut max_val = a[[col, col]].abs();
        let mut max_row = col;
        for row in (col + 1)..k {
            if a[[row, col]].abs() > max_val {
                max_val = a[[row, col]].abs();
                max_row = row;
            }
        }
        if max_val < EPS {
            return Ok(f64::NEG_INFINITY);
        }
        if max_row != col {
            // Row swaps flip the determinant's sign, but only ln|det| is
            // returned, so the sign is irrelevant. (A `sign` accumulator
            // that was written but never read has been removed.)
            for j in 0..k {
                let tmp = a[[col, j]];
                a[[col, j]] = a[[max_row, j]];
                a[[max_row, j]] = tmp;
            }
        }
        log_det += a[[col, col]].abs().ln();
        let pivot = a[[col, col]];
        for row in (col + 1)..k {
            let factor = a[[row, col]] / pivot;
            for j in col..k {
                let v = a[[col, j]] * factor;
                a[[row, j]] -= v;
            }
        }
    }
    Ok(log_det)
}
/// Configuration for fitting a PPCA model with `fit_em`.
#[derive(Debug, Clone)]
pub struct PPCAConfig {
    // Number of latent components q; must satisfy 1 <= q < p.
    pub n_components: usize,
    // Maximum number of EM iterations.
    pub max_iter: usize,
    // Convergence threshold on the change in log-likelihood.
    pub tol: f64,
    // RNG seed. NOTE(review): currently not consumed by `fit_em`, which
    // draws from the global RNG — confirm whether seeding is intended.
    pub seed: u64,
}
impl Default for PPCAConfig {
fn default() -> Self {
Self {
n_components: 2,
max_iter: 200,
tol: 1e-6,
seed: 42,
}
}
}
/// Fit a PPCA model to `x` (rows = samples, columns = features) with the
/// EM algorithm of Tipping & Bishop.
///
/// # Errors
/// `InvalidInput` if there are fewer than 2 samples or `n_components` is
/// not in `1..p`; `ComputationError` if the mean cannot be computed or a
/// matrix inversion fails during the iterations.
pub fn fit_em(x: &Array2<f64>, config: &PPCAConfig) -> Result<PPCAModel> {
    let (n, p) = (x.nrows(), x.ncols());
    let q = config.n_components;
    if n < 2 {
        return Err(TransformError::InvalidInput(
            "PPCA requires at least 2 samples".into(),
        ));
    }
    if q == 0 || q >= p {
        return Err(TransformError::InvalidInput(format!(
            "n_components must be in 1..{p}, got {q}"
        )));
    }
    let mean = x.mean_axis(Axis(0)).ok_or_else(|| {
        TransformError::ComputationError("Failed to compute mean".into())
    })?;
    // Center the data and form the sample covariance S = Xc^T Xc / n.
    let mut xc = x.to_owned();
    for i in 0..n {
        for j in 0..p {
            xc[[i, j]] -= mean[j];
        }
    }
    let mut s = mat_mul_at_b(&xc, &xc);
    for i in 0..p {
        for j in 0..p {
            s[[i, j]] /= n as f64;
        }
    }
    // Random initialization of W.
    // NOTE(review): `config.seed` is not used here — the global RNG is
    // drawn from, so fits are not reproducible. Confirm whether the
    // scirs2 RNG can be seeded and wire `config.seed` through.
    let mut rng = scirs2_core::random::rng();
    let mut w = Array2::<f64>::zeros((p, q));
    for i in 0..p {
        for j in 0..q {
            w[[i, j]] = rng.gen_range(-0.1..0.1);
        }
    }
    let trace_s: f64 = (0..p).map(|i| s[[i, i]]).sum::<f64>();
    // Start sigma^2 at half the average feature variance, clamped away from
    // zero. Bug fix: the clamp was previously applied to the divisor
    // (`trace_s / (2p).max(EPS)`), which never clamps sigma^2 itself and
    // allowed a zero initial variance on all-zero data.
    let mut sigma2 = (trace_s / (2.0 * p as f64)).max(EPS);
    let mut prev_ll = f64::NEG_INFINITY;
    let mut final_iter = 0usize;
    for iter in 0..config.max_iter {
        // E-step quantities: M = W^T W + sigma^2 I and its inverse.
        let mut m = mat_mul_at_b(&w, &w);
        for i in 0..q {
            m[[i, i]] += sigma2;
        }
        let m_inv = invert_small(&m)?;
        let sw = mat_mul(&s, &w);
        let sw_m_inv = mat_mul(&sw, &m_inv);
        // Summed E[z z^T] statistic: sigma^2 M^{-1} + M^{-1} W^T S W M^{-1}.
        let wt_s = mat_mul_at_b(&w, &s);
        let wt_s_w = mat_mul(&wt_s, &w);
        let m_inv_wt_s_w = mat_mul(&m_inv, &wt_s_w);
        let mut ezzt = mat_mul(&m_inv_wt_s_w, &m_inv);
        for i in 0..q {
            for j in 0..q {
                ezzt[[i, j]] += sigma2 * m_inv[[i, j]];
            }
        }
        // M-step: W_new = S W M^{-1} (E[z z^T])^{-1}.
        let ezzt_inv = invert_small(&ezzt)?;
        let w_new = mat_mul(&sw_m_inv, &ezzt_inv);
        // Bug fix: a dead `sw_m_inv_wt_new` product was removed here. It
        // called `mat_mul_at_b(&w_new, &Array2::eye(q))`, whose shape
        // assertion (eye(q) has q rows, w_new has p) panicked whenever
        // q != p — i.e. on every valid configuration.
        // tr(S W M^{-1} W_new^T), needed for the sigma^2 update.
        let trace_sw: f64 = (0..p)
            .map(|i| {
                let mut row_sum = 0.0;
                for l in 0..q {
                    row_sum += sw_m_inv[[i, l]] * w_new[[i, l]];
                }
                row_sum
            })
            .sum();
        // M-step: sigma^2_new = tr(S - S W M^{-1} W_new^T) / p, kept positive.
        let sigma2_new = ((trace_s - trace_sw) / p as f64).max(EPS);
        w = w_new;
        sigma2 = sigma2_new;
        // The likelihood is expensive, so convergence is only checked every
        // 5 iterations and on the final iteration.
        if iter % 5 == 0 || iter == config.max_iter - 1 {
            let tmp_model = PPCAModel {
                w: w.clone(),
                sigma2,
                mean: mean.clone(),
                n_iter: iter + 1,
                log_likelihood: 0.0,
            };
            let ll = tmp_model.log_likelihood(x).unwrap_or(f64::NEG_INFINITY);
            let delta = (ll - prev_ll).abs();
            if iter > 0 && delta < config.tol {
                final_iter = iter + 1;
                prev_ll = ll;
                break;
            }
            prev_ll = ll;
        }
        final_iter = iter + 1;
    }
    Ok(PPCAModel {
        w,
        sigma2,
        mean,
        n_iter: final_iter,
        log_likelihood: prev_ll,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;
    /// Generate `n` samples in `p` dimensions lying near a random
    /// `q`-dimensional subspace, with uniform noise of amplitude `noise`.
    /// NOTE(review): the global RNG is unseeded, so the data differ per run.
    fn make_low_rank_data(n: usize, p: usize, q: usize, noise: f64) -> Array2<f64> {
        let mut rng = scirs2_core::random::rng();
        let mut z = Array2::<f64>::zeros((n, q));
        for i in 0..n {
            for j in 0..q {
                z[[i, j]] = rng.gen_range(-2.0..2.0);
            }
        }
        let mut w = Array2::<f64>::zeros((p, q));
        for i in 0..p {
            for j in 0..q {
                w[[i, j]] = rng.gen_range(-1.0..1.0);
            }
        }
        // x = z * W^T + noise, computed directly.
        // Bug fix: this previously used
        // `mat_mul(&z, &mat_mul_at_b(&w, &Array2::eye(q)))`; the inner call
        // asserts its arguments share a row count (w has p rows, eye(q) has
        // q), so the helper panicked for every p != q and every test aborted.
        let mut x = Array2::<f64>::zeros((n, p));
        for i in 0..n {
            for j in 0..p {
                let mut s = 0.0;
                for l in 0..q {
                    s += z[[i, l]] * w[[j, l]];
                }
                x[[i, j]] = s + noise * rng.gen_range(-1.0..1.0);
            }
        }
        x
    }
    /// Fitting on clean low-rank data produces sane shapes and a positive
    /// noise variance.
    #[test]
    fn test_ppca_fit_basic() {
        let x = make_low_rank_data(50, 10, 2, 0.1);
        let config = PPCAConfig {
            n_components: 2,
            max_iter: 50,
            tol: 1e-4,
            seed: 0,
        };
        let model = fit_em(&x, &config).expect("PPCA fit failed");
        assert_eq!(model.w.shape(), &[10, 2]);
        assert!(model.sigma2 > 0.0);
        assert!(model.log_likelihood.is_finite() || model.log_likelihood == f64::NEG_INFINITY);
    }
    /// transform followed by inverse_transform approximately reconstructs
    /// low-noise, low-rank data.
    #[test]
    fn test_ppca_transform_inverse() {
        let x = make_low_rank_data(40, 8, 2, 0.05);
        let config = PPCAConfig {
            n_components: 2,
            max_iter: 50,
            tol: 1e-4,
            seed: 1,
        };
        let model = fit_em(&x, &config).expect("fit failed");
        let z = model.transform(&x).expect("transform failed");
        assert_eq!(z.shape(), &[40, 2]);
        let x_hat = model.inverse_transform(&z).expect("inverse_transform failed");
        assert_eq!(x_hat.shape(), &[40, 8]);
        // Mean squared reconstruction error over all entries.
        let err: f64 = x
            .iter()
            .zip(x_hat.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            / (40.0 * 8.0);
        assert!(err < 5.0, "Reconstruction error {err} too large");
    }
    /// Imputation fills missing cells with finite values and leaves every
    /// observed cell untouched.
    #[test]
    fn test_ppca_impute_missing() {
        let x = make_low_rank_data(30, 6, 2, 0.1);
        let config = PPCAConfig {
            n_components: 2,
            max_iter: 30,
            tol: 1e-4,
            seed: 2,
        };
        let model = fit_em(&x, &config).expect("fit failed");
        let mut missing = Array2::<bool>::from_elem((30, 6), false);
        missing[[0, 0]] = true;
        missing[[1, 2]] = true;
        let x_imp = model.impute_missing(&x, &missing).expect("impute failed");
        assert!(x_imp[[0, 0]].is_finite());
        assert!(x_imp[[1, 2]].is_finite());
        for i in 0..30 {
            for j in 0..6 {
                if !missing[[i, j]] {
                    assert_eq!(x_imp[[i, j]], x[[i, j]]);
                }
            }
        }
    }
    /// The log-likelihood of the training data is a well-defined number.
    #[test]
    fn test_ppca_log_likelihood() {
        let x = make_low_rank_data(30, 6, 2, 0.2);
        let config = PPCAConfig {
            n_components: 2,
            max_iter: 30,
            tol: 1e-4,
            seed: 3,
        };
        let model = fit_em(&x, &config).expect("fit failed");
        let ll = model.log_likelihood(&x).expect("ll failed");
        assert!(ll.is_finite() || ll == f64::NEG_INFINITY);
    }
    /// `fit_em` rejects out-of-range component counts.
    #[test]
    fn test_ppca_invalid_inputs() {
        let x = Array2::<f64>::zeros((10, 5));
        let bad_config = PPCAConfig {
            n_components: 5,
            ..Default::default()
        };
        assert!(fit_em(&x, &bad_config).is_err());
        let bad_config2 = PPCAConfig {
            n_components: 0,
            ..Default::default()
        };
        assert!(fit_em(&x, &bad_config2).is_err());
    }
}