// oxirs_embed/dimensionality_reducer.rs
1//! # Dimensionality Reducer
2//!
3//! PCA-based dimensionality reduction for embedding vectors.
4//! Implements power-iteration PCA for top-k eigenvector extraction,
5//! plus a thin Truncated SVD wrapper.
6//!
7//! ## Algorithm
8//!
9//! 1. Center the data (subtract column means).
10//! 2. Compute the covariance matrix.
11//! 3. Use power iteration with deflation to extract the top-k eigenvectors.
12//! 4. Project data onto the component subspace.
13
14use std::fmt;
15
16// ─────────────────────────────────────────────
17// Error
18// ─────────────────────────────────────────────
19
/// Errors from dimensionality reduction operations.
#[derive(Clone, Debug, PartialEq)]
pub enum ReductionError {
    /// The input matrix was empty: no rows, or rows with zero features.
    InsufficientData,
    /// `n_components` was 0 or larger than the number of features.
    TooManyComponents,
    /// Input width differs between the fitted model and the transform data.
    DimensionMismatch(String),
}
30
31impl fmt::Display for ReductionError {
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        match self {
34            Self::InsufficientData => write!(f, "insufficient data for PCA"),
35            Self::TooManyComponents => write!(f, "requested more components than features"),
36            Self::DimensionMismatch(msg) => write!(f, "dimension mismatch: {msg}"),
37        }
38    }
39}
40
41impl std::error::Error for ReductionError {}
42
43// ─────────────────────────────────────────────
44// Result type
45// ─────────────────────────────────────────────
46
/// Output of a `fit_transform` call.
#[derive(Clone, Debug)]
pub struct ReductionResult {
    /// Input rows projected into the reduced space (one inner `Vec` per sample).
    pub reduced: Vec<Vec<f64>>,
    /// Per-component share of the data's total variance.
    pub explained_variance_ratio: Vec<f64>,
    /// Sum of `explained_variance_ratio`.
    pub total_variance_explained: f64,
}
57
58// ─────────────────────────────────────────────
59// PCA reducer
60// ─────────────────────────────────────────────
61
/// PCA dimensionality reducer using power-iteration eigenvector extraction.
#[derive(Clone, Debug)]
pub struct PcaReducer {
    /// Principal components; row `k` is the k-th eigenvector
    /// (shape: n_components × n_features).
    pub components: Vec<Vec<f64>>,
    /// Per-feature mean subtracted before projecting.
    pub mean: Vec<f64>,
    /// Eigenvalue (variance along the component) for each retained component.
    pub explained_variance: Vec<f64>,
    /// How many principal components were kept.
    pub n_components: usize,
}
74
75impl PcaReducer {
76    /// Fit PCA on `data` and retain `n_components` principal components.
77    ///
78    /// Uses 50 power-iteration steps per component.
79    pub fn fit(data: &[Vec<f64>], n_components: usize) -> Result<Self, ReductionError> {
80        if data.is_empty() {
81            return Err(ReductionError::InsufficientData);
82        }
83        let n_features = data[0].len();
84        if n_features == 0 {
85            return Err(ReductionError::InsufficientData);
86        }
87        if n_components > n_features {
88            return Err(ReductionError::TooManyComponents);
89        }
90        if n_components == 0 {
91            return Err(ReductionError::TooManyComponents);
92        }
93
94        let (centered, mean) = center_data(data);
95        let covariance = compute_covariance(&centered);
96
97        let mut components: Vec<Vec<f64>> = Vec::with_capacity(n_components);
98        let mut explained_variance: Vec<f64> = Vec::with_capacity(n_components);
99
100        // Deflating copy of the covariance matrix.
101        let mut cov = covariance;
102
103        for _ in 0..n_components {
104            let (eigenvec, eigenval) = power_iteration(&cov, 50);
105            // Deflate: cov ← cov − λ * v * vᵀ
106            deflate(&mut cov, &eigenvec, eigenval);
107            explained_variance.push(eigenval);
108            components.push(eigenvec);
109        }
110
111        Ok(Self {
112            components,
113            mean,
114            explained_variance,
115            n_components,
116        })
117    }
118
119    /// Project `data` onto the PCA component space.
120    pub fn transform(&self, data: &[Vec<f64>]) -> Result<Vec<Vec<f64>>, ReductionError> {
121        if data.is_empty() {
122            return Ok(vec![]);
123        }
124        let n_features = data[0].len();
125        if n_features != self.mean.len() {
126            return Err(ReductionError::DimensionMismatch(format!(
127                "expected {} features, got {}",
128                self.mean.len(),
129                n_features
130            )));
131        }
132
133        let mut result = Vec::with_capacity(data.len());
134        for row in data {
135            let centered: Vec<f64> = row
136                .iter()
137                .zip(self.mean.iter())
138                .map(|(x, m)| x - m)
139                .collect();
140            let projected = self
141                .components
142                .iter()
143                .map(|comp| dot_product(&centered, comp))
144                .collect();
145            result.push(projected);
146        }
147        Ok(result)
148    }
149
150    /// Fit PCA and immediately project the training data.
151    pub fn fit_transform(
152        data: &[Vec<f64>],
153        n_components: usize,
154    ) -> Result<ReductionResult, ReductionError> {
155        let reducer = Self::fit(data, n_components)?;
156        let reduced = reducer.transform(data)?;
157
158        let total_var: f64 = reducer
159            .explained_variance
160            .iter()
161            .sum::<f64>()
162            .max(f64::EPSILON);
163        let explained_variance_ratio: Vec<f64> = reducer
164            .explained_variance
165            .iter()
166            .map(|&v| v / total_var)
167            .collect();
168        let total_variance_explained: f64 = explained_variance_ratio.iter().sum();
169
170        Ok(ReductionResult {
171            reduced,
172            explained_variance_ratio,
173            total_variance_explained,
174        })
175    }
176
177    /// Approximately reconstruct original-space vectors from reduced-space vectors.
178    ///
179    /// This is a lossy reconstruction: `x̂ = mean + W * z`
180    pub fn inverse_transform(&self, reduced: &[Vec<f64>]) -> Vec<Vec<f64>> {
181        let n_features = self.mean.len();
182        let mut result = Vec::with_capacity(reduced.len());
183
184        for row in reduced {
185            let mut rec = self.mean.clone();
186            for (k, &coeff) in row.iter().enumerate() {
187                if let Some(comp) = self.components.get(k) {
188                    for (f, &w) in comp.iter().enumerate() {
189                        rec[f] += coeff * w;
190                    }
191                }
192            }
193            // rec should have n_features entries
194            let _ = n_features;
195            result.push(rec);
196        }
197        result
198    }
199}
200
201// ─────────────────────────────────────────────
202// Truncated SVD (thin wrapper)
203// ─────────────────────────────────────────────
204
/// Truncated SVD fitted on the raw (uncentered) data matrix.
#[derive(Clone, Debug)]
pub struct TruncatedSvd {
    /// Right singular vectors, one per row (n_components × n_features).
    pub components: Vec<Vec<f64>>,
    /// Singular values of the retained components.
    pub singular_values: Vec<f64>,
    /// How many components were kept.
    pub n_components: usize,
}
215
216impl TruncatedSvd {
217    /// Compute a truncated SVD via power iteration on XᵀX.
218    pub fn fit(data: &[Vec<f64>], n_components: usize) -> Result<Self, ReductionError> {
219        if data.is_empty() {
220            return Err(ReductionError::InsufficientData);
221        }
222        let n_features = data[0].len();
223        if n_components > n_features {
224            return Err(ReductionError::TooManyComponents);
225        }
226        if n_components == 0 {
227            return Err(ReductionError::TooManyComponents);
228        }
229
230        // Build XᵀX (n_features × n_features)
231        let xt_x = compute_gram(data);
232
233        let mut components: Vec<Vec<f64>> = Vec::with_capacity(n_components);
234        let mut singular_values: Vec<f64> = Vec::with_capacity(n_components);
235        let mut gram = xt_x;
236
237        for _ in 0..n_components {
238            let (v, lambda) = power_iteration(&gram, 50);
239            deflate(&mut gram, &v, lambda);
240            singular_values.push(lambda.sqrt().max(0.0));
241            components.push(v);
242        }
243
244        Ok(Self {
245            components,
246            singular_values,
247            n_components,
248        })
249    }
250}
251
252// ─────────────────────────────────────────────
253// Internal helpers
254// ─────────────────────────────────────────────
255
/// Center the data by subtracting column means.
/// Returns `(centered_data, column_means)`; both are empty for empty input.
pub fn center_data(data: &[Vec<f64>]) -> (Vec<Vec<f64>>, Vec<f64>) {
    if data.is_empty() {
        return (Vec::new(), Vec::new());
    }
    let n = data.len() as f64;
    let width = data[0].len();

    // Accumulate column sums, then scale each down to a mean.
    let mut means = vec![0.0f64; width];
    for row in data {
        for (j, &value) in row.iter().enumerate() {
            means[j] += value;
        }
    }
    for m in means.iter_mut() {
        *m /= n;
    }

    // Shift every row so each column averages to zero.
    let shifted: Vec<Vec<f64>> = data
        .iter()
        .map(|row| {
            row.iter()
                .zip(means.iter())
                .map(|(x, m)| x - m)
                .collect()
        })
        .collect();
    (shifted, means)
}
279
/// Inner product of two slices, truncated to the shorter length.
pub fn dot_product(a: &[f64], b: &[f64]) -> f64 {
    let mut acc = 0.0;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
284
/// Scale `v` to unit L2 length in place; (near-)zero vectors are left untouched.
pub fn normalize(v: &mut [f64]) {
    let length = v.iter().map(|x| x * x).sum::<f64>().sqrt();
    if length <= f64::EPSILON {
        return;
    }
    for x in v.iter_mut() {
        *x /= length;
    }
}
294
/// Multiply matrix `mat` (m × n) by vector `vec` (length n), yielding a length-m vector.
pub fn mat_vec_mul(mat: &[Vec<f64>], vec: &[f64]) -> Vec<f64> {
    let mut out = Vec::with_capacity(mat.len());
    for row in mat {
        // Row-by-vector inner product.
        let mut acc = 0.0;
        for (a, b) in row.iter().zip(vec.iter()) {
            acc += a * b;
        }
        out.push(acc);
    }
    out
}
299
/// Covariance matrix (n_features × n_features) of already-centered data,
/// using the 1/(n−1) normaliser (falling back to 1 for a single sample).
fn compute_covariance(centered: &[Vec<f64>]) -> Vec<Vec<f64>> {
    if centered.is_empty() {
        return Vec::new();
    }
    let dim = centered[0].len();
    let scale = (centered.len().saturating_sub(1).max(1)) as f64;

    // Accumulate the upper triangle of Σ rowᵀ·row …
    let mut cov = vec![vec![0.0f64; dim]; dim];
    for sample in centered {
        for i in 0..dim {
            let si = sample[i];
            for j in i..dim {
                cov[i][j] += si * sample[j];
            }
        }
    }
    // … then normalise and mirror into the lower triangle (symmetry).
    for i in 0..dim {
        for j in i..dim {
            let value = cov[i][j] / scale;
            cov[i][j] = value;
            cov[j][i] = value;
        }
    }
    cov
}
328
/// Gram matrix XᵀX / n (n_features × n_features) of the raw data, used by the
/// truncated SVD. Note the 1/n scaling applied on top of the plain XᵀX.
fn compute_gram(data: &[Vec<f64>]) -> Vec<Vec<f64>> {
    if data.is_empty() {
        return Vec::new();
    }
    let dim = data[0].len();
    let scale = data.len() as f64;

    // Accumulate the upper triangle of Σ rowᵀ·row …
    let mut gram = vec![vec![0.0f64; dim]; dim];
    for sample in data {
        for i in 0..dim {
            let si = sample[i];
            for j in i..dim {
                gram[i][j] += si * sample[j];
            }
        }
    }
    // … then scale and mirror into the lower triangle (symmetry).
    for i in 0..dim {
        for j in i..dim {
            let value = gram[i][j] / scale;
            gram[i][j] = value;
            gram[j][i] = value;
        }
    }
    gram
}
356
357/// Power iteration: find the dominant eigenvector of a symmetric matrix.
358/// Returns `(eigenvector, eigenvalue)`.
359fn power_iteration(mat: &[Vec<f64>], iterations: usize) -> (Vec<f64>, f64) {
360    let d = mat.len();
361    if d == 0 {
362        return (vec![], 0.0);
363    }
364    // Initialise with a seeded non-zero vector.
365    let mut v: Vec<f64> = (0..d).map(|i| (i as f64 + 1.0).recip()).collect();
366    normalize(&mut v);
367
368    for _ in 0..iterations {
369        let mut w = mat_vec_mul(mat, &v);
370        let eigenval_est = dot_product(&v, &w);
371        if eigenval_est.abs() < f64::EPSILON {
372            break;
373        }
374        normalize(&mut w);
375        v = w;
376    }
377
378    // Rayleigh quotient for eigenvalue.
379    let av = mat_vec_mul(mat, &v);
380    let eigenvalue = dot_product(&v, &av);
381
382    (v, eigenvalue.max(0.0))
383}
384
/// Rank-one downdate of a symmetric matrix: `mat ← mat − λ * v * vᵀ`.
fn deflate(mat: &mut [Vec<f64>], v: &[f64], lambda: f64) {
    let size = mat.len();
    for i in 0..size {
        // Hoist the row-invariant factor λ·v[i].
        let scaled = lambda * v[i];
        for j in 0..size {
            mat[i][j] -= scaled * v[j];
        }
    }
}
394
395// ─────────────────────────────────────────────
396// Tests
397// ─────────────────────────────────────────────
398
/// Unit tests for the PCA reducer, the truncated SVD, and the numeric helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // ── helpers ────────────────────────────────────────────────────────

    /// Simple 2-D dataset with strong first principal axis along (1, 1) / √2.
    fn axis_data() -> Vec<Vec<f64>> {
        // Points exactly on the line y = x, so the first principal component
        // captures essentially all of the variance.
        vec![
            vec![1.0, 1.0],
            vec![2.0, 2.0],
            vec![3.0, 3.0],
            vec![4.0, 4.0],
            vec![5.0, 5.0],
        ]
    }

    /// Absolute-tolerance float comparison used throughout these tests.
    fn near(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    // ── center_data ────────────────────────────────────────────────────

    #[test]
    fn test_center_data_zero_mean() {
        let data = vec![vec![1.0, 3.0], vec![3.0, 7.0]];
        let (centered, mean) = center_data(&data);
        assert!(near(mean[0], 2.0, 1e-9));
        assert!(near(mean[1], 5.0, 1e-9));
        assert!(near(centered[0][0], -1.0, 1e-9));
        assert!(near(centered[1][0], 1.0, 1e-9));
    }

    #[test]
    fn test_center_data_empty() {
        let (centered, mean) = center_data(&[]);
        assert!(centered.is_empty());
        assert!(mean.is_empty());
    }

    // ── dot_product ────────────────────────────────────────────────────

    #[test]
    fn test_dot_product_basic() {
        // 1·4 + 2·5 + 3·6 = 32
        assert!(near(
            dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]),
            32.0,
            1e-9
        ));
    }

    #[test]
    fn test_dot_product_orthogonal() {
        assert!(near(dot_product(&[1.0, 0.0], &[0.0, 1.0]), 0.0, 1e-9));
    }

    // ── normalize ─────────────────────────────────────────────────────

    #[test]
    fn test_normalize_unit() {
        // 3-4-5 triangle: result should have unit L2 norm.
        let mut v = vec![3.0, 4.0];
        normalize(&mut v);
        let norm = dot_product(&v, &v).sqrt();
        assert!(near(norm, 1.0, 1e-9));
    }

    #[test]
    fn test_normalize_zero_vector_noop() {
        let mut v = vec![0.0, 0.0];
        normalize(&mut v);
        assert!(near(v[0], 0.0, 1e-9));
    }

    // ── mat_vec_mul ────────────────────────────────────────────────────

    #[test]
    fn test_mat_vec_mul_identity() {
        // Identity matrix must return the input vector unchanged.
        let eye = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let v = vec![3.0, 5.0];
        let result = mat_vec_mul(&eye, &v);
        assert!(near(result[0], 3.0, 1e-9));
        assert!(near(result[1], 5.0, 1e-9));
    }

    // ── PcaReducer::fit ────────────────────────────────────────────────

    #[test]
    fn test_fit_success() {
        let data = axis_data();
        let pca = PcaReducer::fit(&data, 1).expect("should fit");
        assert_eq!(pca.n_components, 1);
        assert_eq!(pca.components.len(), 1);
        assert_eq!(pca.mean.len(), 2);
    }

    #[test]
    fn test_fit_empty_data_error() {
        assert!(matches!(
            PcaReducer::fit(&[], 1),
            Err(ReductionError::InsufficientData)
        ));
    }

    #[test]
    fn test_fit_too_many_components_error() {
        // 2 features, 3 components requested.
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        assert!(matches!(
            PcaReducer::fit(&data, 3),
            Err(ReductionError::TooManyComponents)
        ));
    }

    #[test]
    fn test_fit_zero_components_error() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        assert!(matches!(
            PcaReducer::fit(&data, 0),
            Err(ReductionError::TooManyComponents)
        ));
    }

    // ── PcaReducer::transform ──────────────────────────────────────────

    #[test]
    fn test_transform_correct_output_shape() {
        let data = axis_data();
        let pca = PcaReducer::fit(&data, 1).expect("fit");
        let reduced = pca.transform(&data).expect("transform");
        assert_eq!(reduced.len(), data.len());
        assert_eq!(reduced[0].len(), 1);
    }

    #[test]
    fn test_transform_empty_input() {
        let data = axis_data();
        let pca = PcaReducer::fit(&data, 1).expect("fit");
        let reduced = pca.transform(&[]).expect("transform");
        assert!(reduced.is_empty());
    }

    #[test]
    fn test_transform_dimension_mismatch() {
        let data = axis_data();
        let pca = PcaReducer::fit(&data, 1).expect("fit");
        let bad_data = vec![vec![1.0, 2.0, 3.0]]; // 3 features, expected 2
        assert!(matches!(
            pca.transform(&bad_data),
            Err(ReductionError::DimensionMismatch(_))
        ));
    }

    // ── PcaReducer::fit_transform ──────────────────────────────────────

    #[test]
    fn test_fit_transform_variance_ratio_sums_to_one() {
        // Rank-1 dataset (every column is a multiple of the first), so the
        // retained components account for (nearly) all of the variance.
        let data: Vec<Vec<f64>> = (0..10)
            .map(|i| vec![i as f64, (i * 2) as f64, (i * 3) as f64])
            .collect();
        let result = PcaReducer::fit_transform(&data, 2).expect("fit_transform");
        assert!(
            near(result.total_variance_explained, 1.0, 0.05),
            "total_variance_explained={}",
            result.total_variance_explained
        );
    }

    #[test]
    fn test_fit_transform_shape() {
        let data = axis_data();
        let result = PcaReducer::fit_transform(&data, 1).expect("fit_transform");
        assert_eq!(result.reduced.len(), data.len());
        assert_eq!(result.reduced[0].len(), 1);
        assert_eq!(result.explained_variance_ratio.len(), 1);
    }

    #[test]
    fn test_fit_transform_explained_variance_first_component() {
        // For data along y=x, first PC should capture nearly all variance.
        let data = axis_data();
        let result = PcaReducer::fit_transform(&data, 1).expect("fit_transform");
        assert!(
            result.total_variance_explained > 0.95,
            "expected high explained variance, got {}",
            result.total_variance_explained
        );
    }

    // ── PcaReducer::inverse_transform ─────────────────────────────────

    #[test]
    fn test_inverse_transform_shape() {
        let data = axis_data();
        let pca = PcaReducer::fit(&data, 1).expect("fit");
        let reduced = pca.transform(&data).expect("transform");
        let reconstructed = pca.inverse_transform(&reduced);
        assert_eq!(reconstructed.len(), data.len());
        assert_eq!(reconstructed[0].len(), data[0].len());
    }

    #[test]
    fn test_inverse_transform_approximate_reconstruction() {
        // For data along y=x, reconstruction should be close.
        let data = axis_data();
        let pca = PcaReducer::fit(&data, 1).expect("fit");
        let reduced = pca.transform(&data).expect("transform");
        let reconstructed = pca.inverse_transform(&reduced);
        // Each reconstructed point should be close to original
        // (L2 distance per point, since the data is essentially rank-1).
        for (orig, rec) in data.iter().zip(reconstructed.iter()) {
            let err: f64 = orig
                .iter()
                .zip(rec.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum::<f64>()
                .sqrt();
            assert!(err < 0.5, "reconstruction error too large: {err}");
        }
    }

    #[test]
    fn test_inverse_transform_empty() {
        let data = axis_data();
        let pca = PcaReducer::fit(&data, 1).expect("fit");
        let reconstructed = pca.inverse_transform(&[]);
        assert!(reconstructed.is_empty());
    }

    // ── TruncatedSvd ───────────────────────────────────────────────────

    #[test]
    fn test_truncated_svd_fit_success() {
        let data: Vec<Vec<f64>> = (0..10).map(|i| vec![i as f64, (i * 2) as f64]).collect();
        let svd = TruncatedSvd::fit(&data, 1).expect("svd fit");
        assert_eq!(svd.n_components, 1);
        assert_eq!(svd.singular_values.len(), 1);
        assert!(svd.singular_values[0] >= 0.0);
    }

    #[test]
    fn test_truncated_svd_too_many_components() {
        let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        assert!(matches!(
            TruncatedSvd::fit(&data, 3),
            Err(ReductionError::TooManyComponents)
        ));
    }

    #[test]
    fn test_truncated_svd_empty_data() {
        assert!(matches!(
            TruncatedSvd::fit(&[], 1),
            Err(ReductionError::InsufficientData)
        ));
    }

    // ── error display ─────────────────────────────────────────────────

    #[test]
    fn test_error_display() {
        // Each variant must render a non-empty, informative message.
        let e = ReductionError::DimensionMismatch("test".to_string());
        assert!(e.to_string().contains("test"));
        let e2 = ReductionError::InsufficientData;
        assert!(!e2.to_string().is_empty());
        let e3 = ReductionError::TooManyComponents;
        assert!(!e3.to_string().is_empty());
    }

    // ── 3D data ────────────────────────────────────────────────────────

    #[test]
    fn test_3d_to_2d_reduction() {
        let data: Vec<Vec<f64>> = (0..20)
            .map(|i| {
                let x = i as f64;
                vec![x, x * 2.0, x * 0.5 + 1.0]
            })
            .collect();
        let pca = PcaReducer::fit(&data, 2).expect("fit");
        let reduced = pca.transform(&data).expect("transform");
        assert_eq!(reduced[0].len(), 2);
    }
}