// candle_mi/util/pca.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! PCA via power iteration with deflation.
4//!
5//! Provides [`pca_top_k`] for computing the top principal components of a
6//! data matrix using only candle tensor ops (`matmul`, arithmetic,
//! normalization).  This runs transparently on CPU or GPU; the only
//! host↔device traffic is reading back scalar eigenvalue/trace values.
9//!
10//! The algorithm works on the **kernel matrix** (`X @ X^T`), which is
11//! efficient when the number of samples `n` is much smaller than the
12//! number of features `d` (e.g., 150 × 2304 for the character-count
13//! helix experiment).
14
15use candle_core::{DType, Device, Tensor};
16
17use crate::error::Result;
18
19// ---------------------------------------------------------------------------
20// Result type
21// ---------------------------------------------------------------------------
22
/// Result of a PCA decomposition via power iteration.
#[derive(Debug, Clone)]
pub struct PcaResult {
    /// Principal component directions, shape `[k, n_features]`.
    ///
    /// Each row is a unit-length direction in feature space, ordered by
    /// decreasing eigenvalue (row 0 is the dominant component).
    pub components: Tensor,

    /// Eigenvalues of the centered kernel matrix `X @ X^T`, one per
    /// component, in the same order as the rows of `components`.
    pub eigenvalues: Vec<f32>,

    /// Fraction of total variance explained by each component
    /// (`λ_i / trace(K)`).
    ///
    /// Values sum to at most 1.0.  The total variance is the trace of the
    /// (centered) kernel matrix.
    pub explained_variance_ratio: Vec<f32>,
}
41
42// ---------------------------------------------------------------------------
43// Public API
44// ---------------------------------------------------------------------------
45
46/// Compute the top `k` principal components of `matrix` via power iteration
47/// with deflation on the kernel matrix.
48///
49/// # Shapes
50///
51/// - `matrix`: `[n_samples, n_features]` — each row is one observation
52/// - returns [`PcaResult`] with `components` of shape `[k, n_features]`
53///
54/// # Algorithm
55///
56/// 1. Center the matrix (subtract column means).
57/// 2. Build the kernel `K = X @ X^T` — shape `[n, n]`.
58/// 3. For each of `k` components: power-iterate on `K`, extract the
59///    eigenvalue, recover the PC direction in feature space via
60///    `w = X^T @ v / ‖X^T @ v‖`, then deflate `K`.
61/// 4. Explained variance ratios are `λ_i / trace(K_original)`.
62///
63/// # Errors
64///
65/// Returns [`MIError::Model`](crate::MIError::Model) if any tensor operation fails (shape
66/// mismatch, device error, etc.).
67pub fn pca_top_k(matrix: &Tensor, k: usize, n_iter: usize) -> Result<PcaResult> {
68    let device = matrix.device();
69    let (n, _d) = matrix.dims2()?;
70
71    // 1. Center: subtract column means
72    // CAST: usize → f64, n is small (≤ 150 for helix); exact in f64 mantissa
73    #[allow(clippy::as_conversions, clippy::cast_precision_loss)]
74    let mean = (matrix.sum(0)? / (n as f64))?; // [d]
75    let centered = matrix.broadcast_sub(&mean)?; // [n, d]
76
77    // 2. Kernel matrix K = X @ X^T  →  [n, n]
78    // CONTIGUOUS: transpose produces non-unit strides; matmul requires contiguous layout
79    let centered_t = centered.t()?.contiguous()?; // [d, n]
80    let k_original = centered.matmul(&centered_t)?; // [n, n]
81    let mut k_mat = k_original.copy()?;
82
83    // Total variance = trace(K) — sum of diagonal elements
84    let trace = trace_2d(&k_original, n)?;
85
86    let mut eigenvalues = Vec::with_capacity(k);
87    let mut components = Vec::with_capacity(k); // in feature space [d]
88
89    // 3. Power iteration + deflation
90    for _ in 0..k {
91        let v = power_iterate(&k_mat, n, n_iter, device)?; // [n]
92
93        // Eigenvalue: λ = v^T K v
94        let kv = k_mat.matmul(&v.unsqueeze(1)?)?; // [n, 1]
95        let lambda_t = v.unsqueeze(0)?.matmul(&kv)?; // [1, 1]
96        // PROMOTE: extract eigenvalue as F32 for precision
97        let lambda: f32 = lambda_t
98            .squeeze(0)?
99            .squeeze(0)?
100            .to_dtype(DType::F32)?
101            .to_scalar()?;
102
103        // Recover PC direction in feature space: w = X^T @ v, normalise
104        let w = centered_t.matmul(&v.unsqueeze(1)?)?.squeeze(1)?; // [d]
105        let w_norm = w.sqr()?.sum_all()?.sqrt()?;
106        let w_unit = w.broadcast_div(&w_norm)?; // [d]
107
108        // Deflate: K ← K − λ v v^T
109        let vvt = v.unsqueeze(1)?.matmul(&v.unsqueeze(0)?)?; // [n, n]
110        let lambda_f64 = f64::from(lambda);
111        k_mat = (k_mat - (vvt * lambda_f64)?)?;
112
113        eigenvalues.push(lambda);
114        components.push(w_unit);
115    }
116
117    // 4. Explained variance ratios
118    let explained_variance_ratio: Vec<f32> = eigenvalues
119        .iter()
120        .map(|&lam| if trace > 0.0 { lam / trace } else { 0.0 })
121        .collect();
122
123    // Stack components: [k, d]
124    let comp_refs: Vec<&Tensor> = components.iter().collect();
125    let stacked = Tensor::stack(&comp_refs, 0)?;
126
127    Ok(PcaResult {
128        components: stacked,
129        eigenvalues,
130        explained_variance_ratio,
131    })
132}
133
134// ---------------------------------------------------------------------------
135// Helpers
136// ---------------------------------------------------------------------------
137
138/// Run `n_iter` rounds of power iteration on `mat` to find the dominant
139/// eigenvector.  Returns a unit-length vector of shape `[n]`.
140fn power_iterate(mat: &Tensor, n: usize, n_iter: usize, device: &Device) -> Result<Tensor> {
141    // Initialise with a random unit vector
142    let mut v = Tensor::randn(0.0_f32, 1.0, (n,), device)?;
143    let v_norm = v.sqr()?.sum_all()?.sqrt()?;
144    v = v.broadcast_div(&v_norm)?;
145
146    for _ in 0..n_iter {
147        // v ← K v / ‖K v‖
148        let kv = mat.matmul(&v.unsqueeze(1)?)?.squeeze(1)?; // [n]
149        let norm = kv.sqr()?.sum_all()?.sqrt()?;
150        v = kv.broadcast_div(&norm)?;
151    }
152
153    Ok(v)
154}
155
156/// Compute the trace of a 2-D square tensor as an `f32` by extracting
157/// diagonal elements one by one.
158///
159/// Candle 0.9 has no `diagonal()` method, so we narrow each row to
160/// a single element and sum.
161fn trace_2d(mat: &Tensor, n: usize) -> Result<f32> {
162    let mut sum = 0.0_f32;
163    for i in 0..n {
164        // Extract element [i, i]: narrow row i, then narrow col i
165        // PROMOTE: extract as F32 for accumulation precision
166        let val: f32 = mat
167            .narrow(0, i, 1)?
168            .narrow(1, i, 1)?
169            .squeeze(0)?
170            .squeeze(0)?
171            .to_dtype(DType::F32)?
172            .to_scalar()?;
173        sum += val;
174    }
175    Ok(sum)
176}
177
178// ---------------------------------------------------------------------------
179// Tests
180// ---------------------------------------------------------------------------
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: PCA on a tiny matrix recovers the dominant direction.
    #[test]
    fn pca_smoke() -> Result<()> {
        // Five points spread along axis 0, with only tiny jitter in axis 1,
        // so PC1 should align with (±1, 0) and carry almost all variance.
        let points = [
            [1.0_f32, 0.0],
            [2.0, 0.1],
            [3.0, -0.1],
            [4.0, 0.05],
            [5.0, -0.05],
        ];
        let data = Tensor::new(&points, &Device::Cpu)?;

        let result = pca_top_k(&data, 2, 50)?;
        let ratios = &result.explained_variance_ratio;

        // PC1 should capture most variance (> 99%)
        assert!(
            ratios[0] > 0.99,
            "PC1 variance ratio {:.4} should be > 0.99",
            ratios[0],
        );

        // PC1 direction should be close to (1, 0) or (-1, 0)
        let pc1: Vec<f32> = result.components.get(0)?.to_vec1()?;
        assert!(
            pc1[0].abs() > 0.99,
            "PC1[0] = {:.4}, expected close to ±1.0",
            pc1[0],
        );

        // Two components should sum to ~1.0
        let total: f32 = ratios.iter().sum();
        assert!(
            (total - 1.0).abs() < 0.01,
            "Total variance {total:.4} should be ~1.0",
        );

        Ok(())
    }
}