// candle_mi/util/pca.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! PCA via power iteration with deflation.
4//!
5//! Provides [`pca_top_k`] for computing the top principal components of a
6//! data matrix using only candle tensor ops (`matmul`, arithmetic,
7//! normalization). This runs transparently on CPU or GPU with zero
8//! host↔device transfers.
9//!
10//! The algorithm works on the **kernel matrix** (`X @ X^T`), which is
11//! efficient when the number of samples `n` is much smaller than the
12//! number of features `d` (e.g., 150 × 2304 for the character-count
13//! helix experiment).
14
15use candle_core::{DType, Device, Tensor};
16
17use crate::error::Result;
18
19// ---------------------------------------------------------------------------
20// Result type
21// ---------------------------------------------------------------------------
22
/// Result of a PCA decomposition via power iteration.
#[derive(Debug, Clone)]
pub struct PcaResult {
    /// Principal component directions, shape `[k, n_features]`.
    ///
    /// Each row is a unit-length direction in feature space, ordered by
    /// decreasing eigenvalue.
    pub components: Tensor,

    /// Eigenvalues of the kernel matrix, one per component.
    ///
    /// Ordered to match the rows of `components` (decreasing).
    pub eigenvalues: Vec<f32>,

    /// Fraction of total variance explained by each component.
    ///
    /// Values sum to at most 1.0. The total variance is the trace of the
    /// (centered) kernel matrix. Same ordering as `eigenvalues`.
    pub explained_variance_ratio: Vec<f32>,
}
41
42// ---------------------------------------------------------------------------
43// Public API
44// ---------------------------------------------------------------------------
45
46/// Compute the top `k` principal components of `matrix` via power iteration
47/// with deflation on the kernel matrix.
48///
49/// # Shapes
50///
51/// - `matrix`: `[n_samples, n_features]` — each row is one observation
52/// - returns [`PcaResult`] with `components` of shape `[k, n_features]`
53///
54/// # Algorithm
55///
56/// 1. Center the matrix (subtract column means).
57/// 2. Build the kernel `K = X @ X^T` — shape `[n, n]`.
58/// 3. For each of `k` components: power-iterate on `K`, extract the
59/// eigenvalue, recover the PC direction in feature space via
60/// `w = X^T @ v / ‖X^T @ v‖`, then deflate `K`.
61/// 4. Explained variance ratios are `λ_i / trace(K_original)`.
62///
63/// # Errors
64///
65/// Returns [`MIError::Model`](crate::MIError::Model) if any tensor operation fails (shape
66/// mismatch, device error, etc.).
67pub fn pca_top_k(matrix: &Tensor, k: usize, n_iter: usize) -> Result<PcaResult> {
68 let device = matrix.device();
69 let (n, _d) = matrix.dims2()?;
70
71 // 1. Center: subtract column means
72 // CAST: usize → f64, n is small (≤ 150 for helix); exact in f64 mantissa
73 #[allow(clippy::as_conversions, clippy::cast_precision_loss)]
74 let mean = (matrix.sum(0)? / (n as f64))?; // [d]
75 let centered = matrix.broadcast_sub(&mean)?; // [n, d]
76
77 // 2. Kernel matrix K = X @ X^T → [n, n]
78 // CONTIGUOUS: transpose produces non-unit strides; matmul requires contiguous layout
79 let centered_t = centered.t()?.contiguous()?; // [d, n]
80 let k_original = centered.matmul(¢ered_t)?; // [n, n]
81 let mut k_mat = k_original.copy()?;
82
83 // Total variance = trace(K) — sum of diagonal elements
84 let trace = trace_2d(&k_original, n)?;
85
86 let mut eigenvalues = Vec::with_capacity(k);
87 let mut components = Vec::with_capacity(k); // in feature space [d]
88
89 // 3. Power iteration + deflation
90 for _ in 0..k {
91 let v = power_iterate(&k_mat, n, n_iter, device)?; // [n]
92
93 // Eigenvalue: λ = v^T K v
94 let kv = k_mat.matmul(&v.unsqueeze(1)?)?; // [n, 1]
95 let lambda_t = v.unsqueeze(0)?.matmul(&kv)?; // [1, 1]
96 // PROMOTE: extract eigenvalue as F32 for precision
97 let lambda: f32 = lambda_t
98 .squeeze(0)?
99 .squeeze(0)?
100 .to_dtype(DType::F32)?
101 .to_scalar()?;
102
103 // Recover PC direction in feature space: w = X^T @ v, normalise
104 let w = centered_t.matmul(&v.unsqueeze(1)?)?.squeeze(1)?; // [d]
105 let w_norm = w.sqr()?.sum_all()?.sqrt()?;
106 let w_unit = w.broadcast_div(&w_norm)?; // [d]
107
108 // Deflate: K ← K − λ v v^T
109 let vvt = v.unsqueeze(1)?.matmul(&v.unsqueeze(0)?)?; // [n, n]
110 let lambda_f64 = f64::from(lambda);
111 k_mat = (k_mat - (vvt * lambda_f64)?)?;
112
113 eigenvalues.push(lambda);
114 components.push(w_unit);
115 }
116
117 // 4. Explained variance ratios
118 let explained_variance_ratio: Vec<f32> = eigenvalues
119 .iter()
120 .map(|&lam| if trace > 0.0 { lam / trace } else { 0.0 })
121 .collect();
122
123 // Stack components: [k, d]
124 let comp_refs: Vec<&Tensor> = components.iter().collect();
125 let stacked = Tensor::stack(&comp_refs, 0)?;
126
127 Ok(PcaResult {
128 components: stacked,
129 eigenvalues,
130 explained_variance_ratio,
131 })
132}
133
134// ---------------------------------------------------------------------------
135// Helpers
136// ---------------------------------------------------------------------------
137
138/// Run `n_iter` rounds of power iteration on `mat` to find the dominant
139/// eigenvector. Returns a unit-length vector of shape `[n]`.
140fn power_iterate(mat: &Tensor, n: usize, n_iter: usize, device: &Device) -> Result<Tensor> {
141 // Initialise with a random unit vector
142 let mut v = Tensor::randn(0.0_f32, 1.0, (n,), device)?;
143 let v_norm = v.sqr()?.sum_all()?.sqrt()?;
144 v = v.broadcast_div(&v_norm)?;
145
146 for _ in 0..n_iter {
147 // v ← K v / ‖K v‖
148 let kv = mat.matmul(&v.unsqueeze(1)?)?.squeeze(1)?; // [n]
149 let norm = kv.sqr()?.sum_all()?.sqrt()?;
150 v = kv.broadcast_div(&norm)?;
151 }
152
153 Ok(v)
154}
155
156/// Compute the trace of a 2-D square tensor as an `f32` by extracting
157/// diagonal elements one by one.
158///
159/// Candle 0.9 has no `diagonal()` method, so we narrow each row to
160/// a single element and sum.
161fn trace_2d(mat: &Tensor, n: usize) -> Result<f32> {
162 let mut sum = 0.0_f32;
163 for i in 0..n {
164 // Extract element [i, i]: narrow row i, then narrow col i
165 // PROMOTE: extract as F32 for accumulation precision
166 let val: f32 = mat
167 .narrow(0, i, 1)?
168 .narrow(1, i, 1)?
169 .squeeze(0)?
170 .squeeze(0)?
171 .to_dtype(DType::F32)?
172 .to_scalar()?;
173 sum += val;
174 }
175 Ok(sum)
176}
177
178// ---------------------------------------------------------------------------
179// Tests
180// ---------------------------------------------------------------------------
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: PCA on a tiny matrix recovers the dominant direction.
    #[test]
    fn pca_smoke() -> Result<()> {
        // Five samples spread along the first feature axis, with tiny
        // jitter in the second — PC1 must align with ±(1, 0) and capture
        // nearly all the variance.
        let points = [
            [1.0_f32, 0.0],
            [2.0, 0.1],
            [3.0, -0.1],
            [4.0, 0.05],
            [5.0, -0.05],
        ];
        let data = Tensor::new(&points, &Device::Cpu)?;

        let result = pca_top_k(&data, 2, 50)?;
        let ratios = &result.explained_variance_ratio;

        // The first component dominates.
        assert!(
            ratios[0] > 0.99,
            "PC1 variance ratio {:.4} should be > 0.99",
            ratios[0],
        );

        // PC1 points (up to sign) along the first feature axis.
        let pc1: Vec<f32> = result.components.get(0)?.to_vec1()?;
        assert!(
            pc1[0].abs() > 0.99,
            "PC1[0] = {:.4}, expected close to ±1.0",
            pc1[0],
        );

        // With k == n_features, the ratios account for essentially all
        // of the variance.
        let total: f32 = ratios.iter().sum();
        assert!(
            (total - 1.0).abs() < 0.01,
            "Total variance {total:.4} should be ~1.0",
        );

        Ok(())
    }
}
227}