use crate::error::{Result, TransformError};
use scirs2_core::ndarray::{Array1, Array2, Axis};
/// Small positive constant shared by the routines in this file.
///
/// It is the pivot-magnitude threshold below which `invert_small` reports a singular matrix,
/// and it is also used as a floor/offset to avoid dividing by zero or taking `ln(0)`.
const EPS: f64 = 1e-12;
/// Inverts a square matrix with Gauss–Jordan elimination and partial (row) pivoting.
///
/// The routine reduces the augmented system `[mat | I]` until the left half is the identity;
/// the right half is then the inverse.
///
/// # Errors
///
/// * `TransformError::InvalidInput` when `mat` is not square.
/// * `TransformError::ComputationError` when the largest available pivot in some column has
///   magnitude below `EPS`.
fn invert_small(mat: &Array2<f64>) -> Result<Array2<f64>> {
    let dim = mat.nrows();
    if mat.ncols() != dim {
        return Err(TransformError::InvalidInput("Matrix must be square".into()));
    }
    let width = 2 * dim;

    // Augmented matrix: `mat` on the left, the identity on the right.
    let mut work = Array2::<f64>::from_shape_fn((dim, width), |(r, c)| {
        if c < dim {
            mat[[r, c]]
        } else if c - dim == r {
            1.0
        } else {
            0.0
        }
    });

    for pivot in 0..dim {
        // Pick the row at or below the diagonal with the strictly largest entry in this column.
        let mut best_row = pivot;
        let mut best_mag = work[[pivot, pivot]].abs();
        for r in (pivot + 1)..dim {
            let mag = work[[r, pivot]].abs();
            if mag > best_mag {
                best_mag = mag;
                best_row = r;
            }
        }
        if best_mag < EPS {
            return Err(TransformError::ComputationError("Singular matrix".into()));
        }

        if best_row != pivot {
            for c in 0..width {
                work.swap([pivot, c], [best_row, c]);
            }
        }

        // Scale the pivot row so that the pivot entry becomes 1.
        let scale = work[[pivot, pivot]];
        for c in 0..width {
            work[[pivot, c]] /= scale;
        }

        // Clear the pivot column in every other row.
        for r in (0..dim).filter(|&r| r != pivot) {
            let factor = work[[r, pivot]];
            for c in 0..width {
                let delta = work[[pivot, c]] * factor;
                work[[r, c]] -= delta;
            }
        }
    }

    // The right half of the reduced system is the inverse.
    Ok(Array2::<f64>::from_shape_fn((dim, dim), |(r, c)| {
        work[[r, dim + c]]
    }))
}
/// Computes `aᵀ · b` with plain loops.
///
/// `a` is `m × k`, `b` is `m × n`; the result is `k × n`.
///
/// # Panics
///
/// Panics if `a` and `b` do not have the same number of rows.
fn mm_atb(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
    let shared = a.nrows();
    let out_rows = a.ncols();
    let out_cols = b.ncols();
    assert_eq!(b.nrows(), shared);

    let mut out = Array2::<f64>::zeros((out_rows, out_cols));
    // Accumulate one shared row at a time; every output entry still receives its
    // contributions in increasing order of the shared index.
    for l in 0..shared {
        for i in 0..out_rows {
            let a_li = a[[l, i]];
            for j in 0..out_cols {
                out[[i, j]] += a_li * b[[l, j]];
            }
        }
    }
    out
}
/// Computes the matrix product `a · b` with plain loops.
///
/// `a` is `m × k`, `b` is `k × n`; the result is `m × n`.
///
/// # Panics
///
/// Panics if the number of columns of `a` differs from the number of rows of `b`.
fn mm(a: &Array2<f64>, b: &Array2<f64>) -> Array2<f64> {
    let (out_rows, inner) = (a.nrows(), a.ncols());
    let out_cols = b.ncols();
    assert_eq!(b.nrows(), inner);

    let mut out = Array2::<f64>::zeros((out_rows, out_cols));
    // `indexed_iter` walks `a` in row-major index order, i.e. row `i` outer, inner index `l`.
    for ((i, l), &a_il) in a.indexed_iter() {
        for j in 0..out_cols {
            out[[i, j]] += a_il * b[[l, j]];
        }
    }
    out
}
/// Settings for [`BayesianPCA`].
#[derive(Debug, Clone)]
pub struct BayesianPCAConfig {
    /// Number of latent components (columns of the loading matrix). `fit_vb` rejects values
    /// that are zero or not strictly smaller than the number of features.
    pub n_components: usize,
    /// Upper bound on the number of update iterations performed by `fit_vb`.
    pub max_iter: usize,
    /// `fit_vb` stops once the absolute change of the monitored ELBO value between two
    /// consecutive iterations drops below this value.
    pub tol: f64,
    /// Constant added to `p / 2` in the numerator of the per-component `alpha` update
    /// (presumably the shape of a Gamma prior — confirm against the model derivation).
    pub a0: f64,
    /// Constant added to half the expected squared column norm in the denominator of the
    /// `alpha` update (presumably the matching Gamma rate).
    pub b0: f64,
    /// Constant added to `n * p / 2` in the numerator of the `tau` update.
    pub c0: f64,
    /// Constant added to half the expected reconstruction error in the denominator of the
    /// `tau` update.
    pub d0: f64,
    /// Components whose `alpha` is strictly below this value count as active in
    /// `effective_rank`, `active_components` and pruning.
    pub prune_threshold: f64,
}

impl Default for BayesianPCAConfig {
    /// Ten components, at most 200 iterations, tolerance `1e-5`, all four prior constants at
    /// `1e-3`, and a pruning threshold of `1e3`.
    fn default() -> Self {
        let prior = 1e-3;
        BayesianPCAConfig {
            n_components: 10,
            max_iter: 200,
            tol: 1e-5,
            a0: prior,
            b0: prior,
            c0: prior,
            d0: prior,
            prune_threshold: 1e3,
        }
    }
}
/// Bayesian PCA model fitted by the iterative updates in `fit_vb`.
///
/// All `Option` fields are `None` until `fit_vb` has completed successfully.
#[derive(Debug, Clone)]
pub struct BayesianPCA {
// Settings supplied at construction time.
config: BayesianPCAConfig,
/// Loading matrix of shape `(n_features, n_components)` produced by `fit_vb`.
pub w_mean: Option<Array2<f64>>,
/// Per-component `alpha` values (length `n_components`); compared against
/// `prune_threshold` to decide which components are active.
pub alpha_mean: Option<Array1<f64>>,
/// Scalar `tau` estimated by `fit_vb` (presumably the observation-noise precision).
pub tau_mean: Option<f64>,
/// Column means of the training data, subtracted from inputs before projecting.
pub data_mean: Option<Array1<f64>>,
/// ELBO value recorded after each iteration of the last `fit_vb` call.
pub elbo_history: Vec<f64>,
/// Number of iterations performed by the last `fit_vb` call.
pub n_iter: usize,
}
impl BayesianPCA {
pub fn new(config: BayesianPCAConfig) -> Self {
Self {
config,
w_mean: None,
alpha_mean: None,
tau_mean: None,
data_mean: None,
elbo_history: Vec::new(),
n_iter: 0,
}
}
/// Fits the model to `x` (rows are samples, columns are features) with iterative updates of
/// the latent moments, the loading matrix `W`, the per-component `alpha` and the scalar `tau`.
///
/// The loop runs for at most `config.max_iter` iterations and stops early once the absolute
/// change of the monitored ELBO value falls below `config.tol`. On success the fitted
/// quantities, the ELBO history and the iteration count are stored on `self`; on error `self`
/// is left untouched.
///
/// # Errors
///
/// * `TransformError::InvalidInput` if `x` has fewer than two rows, or if
///   `config.n_components` is zero or not strictly smaller than the number of columns.
/// * `TransformError::ComputationError` if the column mean cannot be computed or one of the
///   small matrix inversions hits a (numerically) singular matrix.
pub fn fit_vb(&mut self, x: &Array2<f64>) -> Result<()> {
    let (n, p) = (x.nrows(), x.ncols());
    let q = self.config.n_components;
    if n < 2 {
        return Err(TransformError::InvalidInput(
            "BayesianPCA requires at least 2 samples".into(),
        ));
    }
    if q == 0 || q >= p {
        return Err(TransformError::InvalidInput(format!(
            "n_components must be in 1..{p}, got {q}"
        )));
    }
    let n_f = n as f64;
    let p_f = p as f64;

    // Centre the data on its column means.
    let mu = x.mean_axis(Axis(0)).ok_or_else(|| {
        TransformError::ComputationError("Failed to compute mean".into())
    })?;
    let mut xc = x.to_owned();
    for i in 0..n {
        for j in 0..p {
            xc[[i, j]] -= mu[j];
        }
    }

    // Only the trace of the sample covariance is needed, so sum the squared centred entries
    // column by column instead of forming the full p x p matrix.
    let trace_s: f64 = (0..p)
        .map(|j| {
            let mut col_ss = 0.0_f64;
            for i in 0..n {
                col_ss += xc[[i, j]] * xc[[i, j]];
            }
            col_ss / n_f
        })
        .sum();

    // Small random initial loadings break the symmetry between components.
    let mut rng = scirs2_core::random::rng();
    let mut w = Array2::<f64>::zeros((p, q));
    for i in 0..p {
        for j in 0..q {
            w[[i, j]] = 0.01 * (scirs2_core::random::Rng::gen_range(&mut rng, -1.0..1.0_f64));
        }
    }
    let mut alpha = Array1::<f64>::from_elem(q, 1.0_f64);
    let mut tau = p_f / (trace_s + EPS);

    // These numerators do not change between iterations.
    let a_post = self.config.a0 + 0.5 * p_f;
    let c_post = self.config.c0 + 0.5 * n_f * p_f;

    let mut elbo_history: Vec<f64> = Vec::new();
    let mut prev_elbo = f64::NEG_INFINITY;
    let mut final_iter = 0usize;

    for iter in 0..self.config.max_iter {
        // Latent update: sigma_z = (I + tau * W^T W)^-1, z_mean = tau * Xc * W * sigma_z.
        let mut sigma_z_inv = mm_atb(&w, &w);
        for i in 0..q {
            for j in 0..q {
                sigma_z_inv[[i, j]] *= tau;
            }
            sigma_z_inv[[i, i]] += 1.0;
        }
        let sigma_z = invert_small(&sigma_z_inv)?;

        let xc_w = mm(&xc, &w);
        let mut z_mean = mm(&xc_w, &sigma_z);
        for i in 0..n {
            for j in 0..q {
                z_mean[[i, j]] *= tau;
            }
        }

        // Accumulated second moment: Z^T Z + n * sigma_z.
        let mut e_zzt = mm_atb(&z_mean, &z_mean);
        for i in 0..q {
            for j in 0..q {
                e_zzt[[i, j]] += n_f * sigma_z[[i, j]];
            }
        }

        // Loading update: W = Xc^T Z * (E[ZZ^T] + diag(alpha) / tau)^-1.
        let xt_z = mm_atb(&xc, &z_mean);
        let mut lhs = e_zzt.clone();
        for k in 0..q {
            lhs[[k, k]] += alpha[k] / (tau + EPS);
        }
        // The single inverse serves both the W update and the covariance term below.
        // NOTE(review): this is the inverse of E[ZZ^T] + diag(alpha)/tau, which differs from
        // (diag(alpha) + tau * E[ZZ^T])^-1 by a factor of tau -- confirm that using it
        // unscaled as the W covariance is intended.
        let sigma_w = invert_small(&lhs)?;
        w = mm(&xt_z, &sigma_w);

        // alpha update from the expected squared norm of each loading column.
        for k in 0..q {
            let w_k_norm_sq: f64 = (0..p).map(|i| w[[i, k]] * w[[i, k]]).sum();
            let e_w_k_sq = w_k_norm_sq + p_f * sigma_w[[k, k]];
            let b_k = self.config.b0 + 0.5 * e_w_k_sq;
            alpha[k] = a_post / b_k.max(EPS);
        }

        // tau update from the expected reconstruction error.
        let trace_wt_xt_z: f64 = (0..p)
            .map(|i| (0..q).map(|j| w[[i, j]] * xt_z[[i, j]]).sum::<f64>())
            .sum();
        let wt_w_now = mm_atb(&w, &w);
        let trace_ezztww: f64 = (0..q)
            .map(|i| (0..q).map(|j| e_zzt[[i, j]] * wt_w_now[[j, i]]).sum::<f64>())
            .sum();
        let d_new = self.config.d0
            + 0.5 * (n_f * trace_s - 2.0 * trace_wt_xt_z + trace_ezztww);
        tau = c_post / d_new.max(EPS);

        let elbo = self.compute_elbo(
            n, p, tau, &alpha, &w, trace_s, &e_zzt, trace_wt_xt_z, &sigma_z, &sigma_w,
        );
        elbo_history.push(elbo);
        final_iter = iter + 1;

        // The very first iteration has nothing to compare against.
        let converged = iter > 0 && (elbo - prev_elbo).abs() < self.config.tol;
        prev_elbo = elbo;
        if converged {
            break;
        }
    }

    self.w_mean = Some(w);
    self.alpha_mean = Some(alpha);
    self.tau_mean = Some(tau);
    self.data_mean = Some(mu);
    self.elbo_history = elbo_history;
    self.n_iter = final_iter;
    Ok(())
}
/// Counts the components whose `alpha` is strictly below `config.prune_threshold`.
///
/// Returns `0` when the model has not been fitted.
pub fn effective_rank(&self) -> usize {
    let limit = self.config.prune_threshold;
    self.alpha_mean
        .as_ref()
        .map_or(0, |alpha| alpha.iter().filter(|&&a| a < limit).count())
}
/// Lists, in increasing order, the indices of the components whose `alpha` is strictly below
/// `config.prune_threshold`.
///
/// Returns an empty vector when the model has not been fitted.
pub fn active_components(&self) -> Vec<usize> {
    match self.alpha_mean.as_ref() {
        None => Vec::new(),
        Some(alpha) => {
            let limit = self.config.prune_threshold;
            alpha
                .iter()
                .enumerate()
                .filter_map(|(idx, &a)| if a < limit { Some(idx) } else { None })
                .collect()
        }
    }
}
/// Returns a copy of the loading matrix restricted to the active components.
///
/// The columns keep their original relative order. Returns `None` when the model has not
/// been fitted or when no component is active.
pub fn prune_irrelevant_components(&self) -> Option<Array2<f64>> {
    let w = self.w_mean.as_ref()?;
    let kept = self.active_components();
    if kept.is_empty() {
        return None;
    }
    // Column `slot` of the result is column `kept[slot]` of the full loading matrix.
    Some(Array2::<f64>::from_shape_fn(
        (w.nrows(), kept.len()),
        |(row, slot)| w[[row, kept[slot]]],
    ))
}
/// Projects `x` onto the latent space of the fitted model.
///
/// The input is centred with the stored training mean and mapped through
/// `tau * Xc * W * (I + tau * WᵀW)⁻¹`, where `W` holds only the active components (or all
/// components when none is active). The result has one row per input row and one column per
/// retained component.
///
/// # Errors
///
/// * `TransformError::NotFitted` if the model has not been fitted.
/// * `TransformError::DimensionMismatch` if `x` has a different number of columns than the
///   loading matrix has rows.
/// * `TransformError::ComputationError` if the small matrix inversion fails.
pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
    let (w_full, mu, tau) = match (self.w_mean.as_ref(), self.data_mean.as_ref(), self.tau_mean)
    {
        (Some(w), Some(mu), Some(tau)) => (w, mu, tau),
        _ => return Err(TransformError::NotFitted("BayesianPCA not fitted".into())),
    };

    let loadings = match self.prune_irrelevant_components() {
        Some(pruned) => pruned,
        None => w_full.clone(),
    };

    let n_features = x.ncols();
    if loadings.nrows() != n_features {
        return Err(TransformError::DimensionMismatch("Feature dim mismatch".into()));
    }
    let q_eff = loadings.ncols();

    // Centre the input with the training mean.
    let mut centered = x.to_owned();
    for ((_, col), value) in centered.indexed_iter_mut() {
        *value -= mu[col];
    }

    // Latent precision: I + tau * W^T W.
    let mut precision = mm_atb(&loadings, &loadings);
    precision.mapv_inplace(|v| v * tau);
    for d in 0..q_eff {
        precision[[d, d]] += 1.0;
    }
    let sigma_z = invert_small(&precision)?;

    let projected = mm(&centered, &loadings);
    let mut z = mm(&projected, &sigma_z);
    z.mapv_inplace(|v| v * tau);
    Ok(z)
}
/// Evaluates the scalar objective that `fit_vb` records each iteration and uses for its
/// convergence test.
///
/// The value is assembled from three pieces:
/// * a Gaussian log-likelihood term built from `tau` and the expected reconstruction error
///   `n * trace_s - 2 * trace_wt_xt_z + tr(e_zzt * WᵀW)`;
/// * a latent term `-n/2 * (log_det(sigma_z) + q)`, which is subtracted;
/// * a per-component term `1/2 * (p * ln(alpha_k) - alpha_k * (‖w_k‖² + p * sigma_w[k, k]))`.
///
/// `tau` and each `alpha_k` are floored at `EPS` before taking logarithms.
// The argument list mirrors the quantities already computed by the caller.
#[allow(clippy::too_many_arguments)]
fn compute_elbo(
    &self,
    n: usize,
    p: usize,
    tau: f64,
    alpha: &Array1<f64>,
    w: &Array2<f64>,
    trace_s: f64,
    e_zzt: &Array2<f64>,
    trace_wt_xt_z: f64,
    sigma_z: &Array2<f64>,
    sigma_w: &Array2<f64>,
) -> f64 {
    let q = w.ncols();
    let n_f = n as f64;
    let p_f = p as f64;
    let tau_pos = tau.max(EPS);

    // tr(e_zzt * W^T W)
    let gram_w = mm_atb(w, w);
    let trace_ezztww: f64 = (0..q)
        .map(|r| (0..q).map(|c| e_zzt[[r, c]] * gram_w[[c, r]]).sum::<f64>())
        .sum();

    // Expected squared reconstruction error and the likelihood term built from it.
    let e_recon = n_f * trace_s - 2.0 * trace_wt_xt_z + trace_ezztww;
    let log_norm = tau_pos.ln() - std::f64::consts::LN_2 - std::f64::consts::PI.ln();
    let ll_term = 0.5 * n_f * p_f * log_norm - 0.5 * tau_pos * e_recon;

    // Latent term.
    let kl_z = -0.5 * n_f * (log_det_small_safe(sigma_z) + q as f64);

    // Per-component term; at most `q` entries of `alpha` are consumed.
    let alpha_reg: f64 = alpha
        .iter()
        .take(q)
        .enumerate()
        .map(|(k, &a_k)| {
            let col_sq: f64 = (0..p).map(|i| w[[i, k]] * w[[i, k]]).sum();
            0.5 * (p_f * a_k.max(EPS).ln() - a_k * (col_sq + p_f * sigma_w[[k, k]]))
        })
        .sum();

    ll_term - kl_z + alpha_reg
}
}
fn log_det_small_safe(a: &Array2<f64>) -> f64 {
let k = a.nrows().min(a.ncols());
(0..k).map(|i| a[[i, i]].abs().max(EPS).ln()).sum()
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds an `n × p` data matrix `Z · Wᵀ + noise`, where `Z` (`n × q`) is uniform in
    /// `[-2, 2)`, `W` (`p × q`) is uniform in `[-1, 1)` and the additive noise is uniform in
    /// `[-noise, noise)`.
    fn make_low_rank(n: usize, p: usize, q: usize, noise: f64) -> Array2<f64> {
        let mut rng = scirs2_core::random::rng();

        // Latent scores.
        let mut scores = Array2::<f64>::zeros((n, q));
        for i in 0..n {
            for j in 0..q {
                scores[[i, j]] = scirs2_core::random::Rng::gen_range(&mut rng, -2.0..2.0_f64);
            }
        }

        // Loadings, drawn feature by feature and stored directly in transposed (q x p) form.
        let mut loadings_t = Array2::<f64>::zeros((q, p));
        for i in 0..p {
            for j in 0..q {
                loadings_t[[j, i]] =
                    scirs2_core::random::Rng::gen_range(&mut rng, -1.0..1.0_f64);
            }
        }

        let mut data = mm(&scores, &loadings_t);
        for i in 0..n {
            for j in 0..p {
                let jitter = scirs2_core::random::Rng::gen_range(&mut rng, -1.0..1.0_f64);
                data[[i, j]] += noise * jitter;
            }
        }
        data
    }

    /// Fitting stores a loading matrix of the requested shape plus finite, positive
    /// hyper-parameter estimates.
    #[test]
    fn test_bayesian_pca_fit() {
        let data = make_low_rank(50, 10, 2, 0.1);
        let mut model = BayesianPCA::new(BayesianPCAConfig {
            n_components: 5,
            max_iter: 50,
            tol: 1e-4,
            ..Default::default()
        });
        model.fit_vb(&data).expect("BayesianPCA fit failed");

        let loadings = model.w_mean.as_ref().expect("w_mean missing");
        assert_eq!(loadings.shape(), &[10, 5]);

        let tau = model.tau_mean.expect("tau_mean should be set after fit");
        assert!(tau > 0.0);

        let alpha = model
            .alpha_mean
            .as_ref()
            .expect("alpha_mean should be set after fit");
        assert!(alpha.iter().all(|v| v.is_finite()));
    }

    /// The effective rank is at least one and never exceeds the configured component count.
    #[test]
    fn test_bayesian_pca_effective_rank() {
        let data = make_low_rank(60, 12, 2, 0.05);
        let mut model = BayesianPCA::new(BayesianPCAConfig {
            n_components: 8,
            max_iter: 100,
            tol: 1e-5,
            prune_threshold: 1e3,
            ..Default::default()
        });
        model.fit_vb(&data).expect("fit failed");

        let rank = model.effective_rank();
        assert!(rank > 0 && rank <= 8, "effective_rank={rank}");
    }

    /// A pruned loading matrix, when one is returned, keeps every feature row and at most the
    /// configured number of columns.
    #[test]
    fn test_bayesian_pca_prune() {
        let data = make_low_rank(40, 8, 2, 0.1);
        let mut model = BayesianPCA::new(BayesianPCAConfig {
            n_components: 6,
            max_iter: 80,
            tol: 1e-5,
            prune_threshold: 100.0,
            ..Default::default()
        });
        model.fit_vb(&data).expect("fit failed");

        if let Some(pruned) = model.prune_irrelevant_components() {
            assert!(pruned.ncols() <= 6);
            assert_eq!(pruned.nrows(), 8);
        }
    }

    /// Transforming the training data yields one finite latent row per sample.
    #[test]
    fn test_bayesian_pca_transform() {
        let data = make_low_rank(40, 10, 2, 0.1);
        let mut model = BayesianPCA::new(BayesianPCAConfig {
            n_components: 4,
            max_iter: 50,
            tol: 1e-4,
            ..Default::default()
        });
        model.fit_vb(&data).expect("fit failed");

        let latent = model.transform(&data).expect("transform failed");
        assert_eq!(latent.nrows(), 40);
        assert!(latent.iter().all(|v| v.is_finite()));
    }

    /// Requesting as many components as there are features is rejected.
    #[test]
    fn test_bayesian_pca_invalid() {
        let data = Array2::<f64>::zeros((10, 5));
        let mut model = BayesianPCA::new(BayesianPCAConfig {
            n_components: 5,
            ..Default::default()
        });
        assert!(model.fit_vb(&data).is_err());
    }
}