Skip to main content

anofox_ml_preprocessing/
fast_ica.rs

//! FastICA — fixed-point Independent Component Analysis with deflation.
//!
//! Mirrors `sklearn.decomposition.FastICA` with `algorithm='deflation'` and
//! `fun='logcosh'`. Standard pipeline:
//!
//! 1. Centre X and whiten it via PCA, so that the whitened data `X_white`
//!    has identity sample covariance.
//! 2. For each component, iterate the fixed-point update
//!    `w ← E[x g(wᵀx)] - E[g'(wᵀx)] w` over the whitened samples `x`.
//!    After each update, orthogonalise against previously extracted
//!    components and normalise to unit length.
//! 3. With samples stored as rows, the sources are
//!    `S = X_white Wᵀ = (X - mean) K Wᵀ`, where `K` is the
//!    `(n_features × n_components)` whitening matrix. The unmixing matrix in
//!    the original feature space is therefore `K Wᵀ`.
12
13use anofox_ml_core::{FitUnsupervised, Result, RustMlError, Transform};
14use faer::linalg::solvers::Svd;
15use faer::Mat;
16use ndarray::{Array1, Array2};
17use rand::rngs::StdRng;
18use rand::{Rng, SeedableRng};
19
/// Configuration for FastICA using the deflation algorithm and the logcosh
/// non-linearity.
///
/// Create one with [`FastIca::new`], optionally adjust it with the `with_*`
/// builder methods, then call `fit` to obtain a [`FittedFastIca`].
#[derive(Debug, Clone, PartialEq)]
pub struct FastIca {
    /// Requested number of independent components.
    ///
    /// `fit` extracts `min(n_components, n_features, n_samples)` components.
    pub n_components: usize,
    /// Maximum number of fixed-point iterations per component (default 200).
    pub max_iter: usize,
    /// Convergence tolerance on `|1 - |⟨w_new, w_old⟩||` (default `1e-4`).
    pub tol: f64,
    /// Seed for the RNG that draws each component's random starting vector
    /// (default 0).
    pub seed: u64,
}
27
28impl FastIca {
29    pub fn new(n_components: usize) -> Self {
30        Self {
31            n_components,
32            max_iter: 200,
33            tol: 1e-4,
34            seed: 0,
35        }
36    }
37    pub fn with_max_iter(mut self, m: usize) -> Self {
38        self.max_iter = m;
39        self
40    }
41    pub fn with_tol(mut self, t: f64) -> Self {
42        self.tol = t;
43        self
44    }
45    pub fn with_seed(mut self, s: u64) -> Self {
46        self.seed = s;
47        self
48    }
49}
50
51#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
52pub struct FittedFastIca {
53    /// Unmixing matrix from whitened space, shape (n_components, n_components).
54    pub w: Array2<f64>,
55    /// Whitening matrix `K` such that `X_centered @ K` is whitened, shape
56    /// (n_features, n_components).
57    pub whitening: Array2<f64>,
58    /// Per-feature mean used for centring.
59    pub mean: Array1<f64>,
60    pub n_features: usize,
61}
62
/// Derivative pair for the logcosh contrast `G(u) = log cosh(u)`.
///
/// Returns `(g(u), g'(u))` with `g(u) = tanh(u)` and `g'(u) = 1 - tanh(u)²`.
fn g_logcosh(u: f64) -> (f64, f64) {
    let tanh_u = u.tanh();
    let derivative = 1.0 - tanh_u * tanh_u;
    (tanh_u, derivative)
}
68
69impl FitUnsupervised<f64> for FastIca {
70    type Fitted = FittedFastIca;
71
72    fn fit(&self, x: &Array2<f64>) -> Result<Self::Fitted> {
73        let n = x.nrows();
74        let d = x.ncols();
75        let k = self.n_components.min(d.min(n));
76        if k == 0 {
77            return Err(RustMlError::InvalidParameter("n_components >= 1".into()));
78        }
79        if n < 2 {
80            return Err(RustMlError::EmptyInput("need at least 2 samples".into()));
81        }
82
83        // 1. Centre.
84        let mut mean = Array1::<f64>::zeros(d);
85        for j in 0..d {
86            mean[j] = x.column(j).sum() / n as f64;
87        }
88        let mut xc = x.clone();
89        for j in 0..d {
90            for i in 0..n {
91                xc[[i, j]] -= mean[j];
92            }
93        }
94
95        // 2. Whiten via SVD: X_centered = U Σ Vᵀ. Whitening matrix K = V Σ⁻¹ √(n-1).
96        let xm = Mat::from_fn(n, d, |i, j| xc[[i, j]]);
97        let svd = Svd::new(xm.as_ref())
98            .map_err(|e| RustMlError::InvalidParameter(format!("SVD failed: {e:?}")))?;
99        let s = svd.S();
100        let v = svd.V();
101        let scale = (n as f64 - 1.0).sqrt();
102        let mut k_white = Array2::<f64>::zeros((d, k));
103        for c in 0..k {
104            let sigma = s.column_vector()[c].max(1e-12);
105            for j in 0..d {
106                k_white[[j, c]] = v[(j, c)] * scale / sigma;
107            }
108        }
109        // Whitened data: X1 = X_centered @ K, shape (n, k).
110        let x1 = xc.dot(&k_white);
111
112        // 3. Deflation extraction.
113        let mut rng = StdRng::seed_from_u64(self.seed);
114        let mut w = Array2::<f64>::zeros((k, k));
115        for comp in 0..k {
116            // Random init.
117            let mut wi: Array1<f64> = Array1::from_shape_fn(k, |_| rng.gen::<f64>() * 2.0 - 1.0);
118            // Normalize.
119            let nrm = wi.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-12);
120            wi.mapv_inplace(|v| v / nrm);
121
122            for _ in 0..self.max_iter {
123                // Compute u = X1 @ wi (length n).
124                let mut u = vec![0.0_f64; n];
125                for i in 0..n {
126                    let mut s = 0.0;
127                    for c in 0..k {
128                        s += x1[[i, c]] * wi[c];
129                    }
130                    u[i] = s;
131                }
132                // g(u) and g'(u).
133                let mut gu = vec![0.0_f64; n];
134                let mut g_prime_mean = 0.0_f64;
135                for i in 0..n {
136                    let (g, gp) = g_logcosh(u[i]);
137                    gu[i] = g;
138                    g_prime_mean += gp;
139                }
140                g_prime_mean /= n as f64;
141                // New wi: E[X1 g(wᵀ X1)] - E[g'(...)] w.
142                let mut new_wi = Array1::<f64>::zeros(k);
143                for c in 0..k {
144                    let mut s = 0.0;
145                    for i in 0..n {
146                        s += x1[[i, c]] * gu[i];
147                    }
148                    new_wi[c] = s / n as f64 - g_prime_mean * wi[c];
149                }
150                // Deflate: orthogonalise against previously-extracted components.
151                for prev in 0..comp {
152                    let mut dot = 0.0;
153                    for c in 0..k {
154                        dot += new_wi[c] * w[[prev, c]];
155                    }
156                    for c in 0..k {
157                        new_wi[c] -= dot * w[[prev, c]];
158                    }
159                }
160                // Normalise.
161                let nrm = new_wi.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-12);
162                new_wi.mapv_inplace(|v| v / nrm);
163
164                // Convergence: |1 - |<w_new, w_old>||.
165                let mut dot = 0.0;
166                for c in 0..k {
167                    dot += new_wi[c] * wi[c];
168                }
169                let conv = (1.0 - dot.abs()).abs();
170                wi = new_wi;
171                if conv < self.tol {
172                    break;
173                }
174            }
175            for c in 0..k {
176                w[[comp, c]] = wi[c];
177            }
178        }
179
180        Ok(FittedFastIca {
181            w,
182            whitening: k_white,
183            mean,
184            n_features: d,
185        })
186    }
187}
188
189impl Transform<f64> for FittedFastIca {
190    /// Returns the recovered source signals `S = (X - mean) · K · Wᵀ`.
191    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
192        if x.ncols() != self.n_features {
193            return Err(RustMlError::ShapeMismatch(format!(
194                "expected {} features, got {}",
195                self.n_features,
196                x.ncols()
197            )));
198        }
199        let mut xc = x.clone();
200        for j in 0..self.n_features {
201            for i in 0..x.nrows() {
202                xc[[i, j]] -= self.mean[j];
203            }
204        }
205        let x_white = xc.dot(&self.whitening);
206        Ok(x_white.dot(&self.w.t()))
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Mixes a sine wave and a square wave through a fixed, invertible 2×2
    /// matrix.
    ///
    /// Both sources change shape several times within the sampled window,
    /// so the centred mixture has full rank.
    fn mixed_signals(n: usize) -> Array2<f64> {
        let mut s = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            let t = i as f64 * 0.1;
            s[[i, 0]] = t.sin(); // sine
            // Period 2π/0.7 ≈ 9 time units, so the square wave really does
            // switch sign inside the window. A frequency of 0.3 never did
            // over t ∈ [0, 9.9].
            s[[i, 1]] = (t * 0.7).sin().signum(); // square
        }
        let a = array![[1.0_f64, 0.5], [0.5, 1.0]];
        s.dot(&a)
    }

    #[test]
    fn test_fast_ica_runs() {
        let n = 200;
        let x = mixed_signals(n);
        let fitted = FastIca::new(2).with_seed(1).fit(&x).unwrap();
        let recovered = fitted.transform(&x).unwrap();
        assert_eq!(recovered.shape(), &[n, 2]);
        assert!(recovered.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_fast_ica_sources_are_white_and_w_is_orthonormal() {
        let x = mixed_signals(200);
        let fitted = FastIca::new(2).with_seed(1).fit(&x).unwrap();
        // Deflation keeps the rows of W unit-norm and mutually orthogonal.
        let wwt = fitted.w.dot(&fitted.w.t());
        // The whitened data has identity covariance and W is orthonormal,
        // so the recovered sources must also have identity covariance.
        let s = fitted.transform(&x).unwrap();
        let cov = s.t().dot(&s) / (s.nrows() as f64 - 1.0);
        for a in 0..2 {
            for b in 0..2 {
                let expected = if a == b { 1.0 } else { 0.0 };
                assert!(
                    (wwt[[a, b]] - expected).abs() < 1e-8,
                    "W Wᵀ[{a},{b}] = {}",
                    wwt[[a, b]]
                );
                assert!(
                    (cov[[a, b]] - expected).abs() < 1e-6,
                    "cov[{a},{b}] = {}",
                    cov[[a, b]]
                );
            }
        }
    }

    #[test]
    fn test_fast_ica_rejects_invalid_input() {
        let x = mixed_signals(200);
        assert!(FastIca::new(0).fit(&x).is_err());
        assert!(FastIca::new(2).fit(&Array2::<f64>::zeros((1, 2))).is_err());
        let fitted = FastIca::new(2).fit(&x).unwrap();
        assert!(fitted.transform(&Array2::<f64>::zeros((3, 3))).is_err());
    }
}