use candle_core::{DType, Device, Tensor};
use crate::error::Result;
/// Result of a truncated PCA: the leading principal directions of a data
/// matrix together with their eigenvalues and explained-variance ratios.
#[derive(Debug, Clone)]
pub struct PcaResult {
    /// `(k, d)` tensor; row `i` is the `i`-th principal component, normalized
    /// to unit length in the original feature space.
    pub components: Tensor,
    /// One eigenvalue of the centered Gram matrix per extracted component,
    /// in extraction order (power iteration with deflation, so typically
    /// largest first).
    pub eigenvalues: Vec<f32>,
    /// `eigenvalues[i] / trace(K)` — the fraction of total variance captured
    /// by component `i` (0.0 when the trace is not positive).
    pub explained_variance_ratio: Vec<f32>,
}
/// Computes the top-`k` principal components of `matrix` (shape `(n, d)`)
/// by power iteration with deflation on the centered Gram matrix.
///
/// Each component is a unit-length vector in the `d`-dimensional feature
/// space; `n_iter` is the number of power-iteration steps per component.
/// The start vectors are random, so results are approximate and not
/// bit-deterministic across calls.
///
/// # Errors
/// Propagates any tensor-backend error (shape mismatch, device failure, ...).
pub fn pca_top_k(matrix: &Tensor, k: usize, n_iter: usize) -> Result<PcaResult> {
    let device = matrix.device();
    let (n, _d) = matrix.dims2()?;
    // Center the data: subtract the per-column mean from every row.
    #[allow(clippy::as_conversions, clippy::cast_precision_loss)]
    let mean = (matrix.sum(0)? / (n as f64))?;
    let centered = matrix.broadcast_sub(&mean)?;
    let centered_t = centered.t()?.contiguous()?;
    // Gram matrix K = C C^T (n x n); it shares nonzero eigenvalues with the
    // covariance-shaped C^T C, which is what makes the Gram trick work.
    // NOTE: the original line read `matmul(¢ered_t)` — an HTML-entity
    // corruption of `&centered_t` that does not compile; restored here.
    let k_original = centered.matmul(&centered_t)?;
    // A shallow `clone()` suffices: candle tensors are immutable and the
    // deflation loop rebinds `k_mat` rather than mutating storage in place,
    // so the deep `copy()` used before was a wasted O(n^2) allocation.
    let mut k_mat = k_original.clone();
    let trace = trace_2d(&k_original, n)?;
    let mut eigenvalues = Vec::with_capacity(k);
    let mut components = Vec::with_capacity(k);
    for _ in 0..k {
        // Dominant eigenvector of the (deflated) Gram matrix.
        let v = power_iterate(&k_mat, n, n_iter, device)?;
        // Rayleigh quotient lambda = v^T K v (v has unit length).
        let kv = k_mat.matmul(&v.unsqueeze(1)?)?;
        let lambda_t = v.unsqueeze(0)?.matmul(&kv)?;
        let lambda: f32 = lambda_t
            .squeeze(0)?
            .squeeze(0)?
            .to_dtype(DType::F32)?
            .to_scalar()?;
        // Map the Gram-space eigenvector back to feature space (w = C^T v)
        // and renormalize to unit length.
        let w = centered_t.matmul(&v.unsqueeze(1)?)?.squeeze(1)?;
        let w_norm = w.sqr()?.sum_all()?.sqrt()?;
        let w_unit = w.broadcast_div(&w_norm)?;
        // Deflate: K <- K - lambda v v^T removes the component just found so
        // the next iteration converges to the next-largest eigenpair.
        let vvt = v.unsqueeze(1)?.matmul(&v.unsqueeze(0)?)?;
        let lambda_f64 = f64::from(lambda);
        k_mat = (k_mat - (vvt * lambda_f64)?)?;
        eigenvalues.push(lambda);
        components.push(w_unit);
    }
    // trace(K) equals the sum of all eigenvalues of K, so lambda / trace is
    // the fraction of total variance explained by each component.
    let explained_variance_ratio: Vec<f32> = eigenvalues
        .iter()
        .map(|&lam| if trace > 0.0 { lam / trace } else { 0.0 })
        .collect();
    let comp_refs: Vec<&Tensor> = components.iter().collect();
    let stacked = Tensor::stack(&comp_refs, 0)?;
    Ok(PcaResult {
        components: stacked,
        eigenvalues,
        explained_variance_ratio,
    })
}
/// Runs `n_iter` rounds of power iteration on the `(n, n)` matrix `mat`,
/// returning an approximately dominant eigenvector of unit L2 norm.
///
/// The start vector is drawn from a standard normal, so output varies
/// between calls; callers trade accuracy for time via `n_iter`.
///
/// # Errors
/// Propagates any tensor-backend error.
fn power_iterate(mat: &Tensor, n: usize, n_iter: usize, device: &Device) -> Result<Tensor> {
    // Cast the start vector to `mat`'s dtype: the previous code always
    // sampled F32, which made the matmul below fail with a dtype mismatch
    // whenever the input matrix was F64 (or any non-F32 float type).
    let mut v = Tensor::randn(0.0_f32, 1.0, (n,), device)?.to_dtype(mat.dtype())?;
    let v_norm = v.sqr()?.sum_all()?.sqrt()?;
    v = v.broadcast_div(&v_norm)?;
    for _ in 0..n_iter {
        // v <- K v / ||K v||: repeated application amplifies the dominant
        // eigendirection; renormalizing each step avoids overflow/underflow.
        let kv = mat.matmul(&v.unsqueeze(1)?)?.squeeze(1)?;
        let norm = kv.sqr()?.sum_all()?.sqrt()?;
        v = kv.broadcast_div(&norm)?;
    }
    Ok(v)
}
/// Sums the main diagonal of the `(n, n)` tensor `mat`, returned as `f32`.
///
/// Pulls one scalar per diagonal element back to the host, so cost scales
/// with `n` device round-trips — fine for the small Gram matrices used here.
///
/// # Errors
/// Propagates any tensor-backend error (e.g. an out-of-range `narrow`).
fn trace_2d(mat: &Tensor, n: usize) -> Result<f32> {
    (0..n).try_fold(0.0_f32, |acc, i| {
        // 1x1 view of element (i, i), squeezed down and read as a scalar.
        let diag_elem = mat.narrow(0, i, 1)?.narrow(1, i, 1)?;
        let value: f32 = diag_elem
            .squeeze(0)?
            .squeeze(0)?
            .to_dtype(DType::F32)?
            .to_scalar()?;
        Ok(acc + value)
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: samples that vary almost entirely along the x-axis should
    /// put nearly all variance on PC1, with PC1 pointing along (±1, 0).
    #[test]
    fn pca_smoke() -> Result<()> {
        let samples = Tensor::new(
            &[
                [1.0_f32, 0.0],
                [2.0, 0.1],
                [3.0, -0.1],
                [4.0, 0.05],
                [5.0, -0.05],
            ],
            &Device::Cpu,
        )?;
        let pca = pca_top_k(&samples, 2, 50)?;

        // Dominant component should carry essentially all the variance.
        let pc1_ratio = pca.explained_variance_ratio[0];
        assert!(
            pc1_ratio > 0.99,
            "PC1 variance ratio {:.4} should be > 0.99",
            pc1_ratio,
        );

        // Its direction should be (nearly) the x-axis, up to sign.
        let first_pc: Vec<f32> = pca.components.get(0)?.to_vec1()?;
        assert!(
            first_pc[0].abs() > 0.99,
            "PC1[0] = {:.4}, expected close to ±1.0",
            first_pc[0],
        );

        // Ratios over all n=2 possible components should sum to ~1.
        let total: f32 = pca.explained_variance_ratio.iter().sum();
        assert!(
            (total - 1.0).abs() < 0.01,
            "Total variance {total:.4} should be ~1.0",
        );
        Ok(())
    }
}