// anofox_ml_preprocessing/fast_ica.rs
use anofox_ml_core::{FitUnsupervised, Result, RustMlError, Transform};
use faer::linalg::solvers::Svd;
use faer::Mat;
use ndarray::{Array1, Array2, Axis};
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

/// Configuration for deflationary FastICA with the `logcosh` contrast.
///
/// Construct with [`FastIca::new`], adjust with the `with_*` setters, then
/// call `fit` to obtain a fitted model.
#[derive(Debug, Clone)]
pub struct FastIca {
    /// Requested number of independent components. `fit` extracts at most
    /// `min(n_components, n_features, n_samples)` of them.
    pub n_components: usize,
    /// Maximum number of fixed-point iterations spent on each component.
    pub max_iter: usize,
    /// A component is considered converged once
    /// `|1 - |<w_new, w_old>||` drops below this value.
    pub tol: f64,
    /// Seed for the RNG that draws the initial unmixing vectors.
    pub seed: u64,
}

impl FastIca {
    /// Creates a configuration with `max_iter = 200`, `tol = 1e-4` and `seed = 0`.
    pub fn new(n_components: usize) -> Self {
        const DEFAULT_MAX_ITER: usize = 200;
        const DEFAULT_TOL: f64 = 1e-4;
        const DEFAULT_SEED: u64 = 0;
        Self {
            n_components,
            max_iter: DEFAULT_MAX_ITER,
            tol: DEFAULT_TOL,
            seed: DEFAULT_SEED,
        }
    }

    /// Returns the configuration with the per-component iteration cap replaced by `m`.
    pub fn with_max_iter(self, m: usize) -> Self {
        Self { max_iter: m, ..self }
    }

    /// Returns the configuration with the convergence tolerance replaced by `t`.
    pub fn with_tol(self, t: f64) -> Self {
        Self { tol: t, ..self }
    }

    /// Returns the configuration with the RNG seed replaced by `s`.
    pub fn with_seed(self, s: u64) -> Self {
        Self { seed: s, ..self }
    }
}

51#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
52pub struct FittedFastIca {
53 pub w: Array2<f64>,
55 pub whitening: Array2<f64>,
58 pub mean: Array1<f64>,
60 pub n_features: usize,
61}
62
/// `logcosh` contrast for FastICA.
///
/// Returns `(g(u), g'(u))`, where `g(u) = tanh(u)` is the derivative of
/// `log cosh u` and `g'(u) = 1 - tanh²(u)`.
fn g_logcosh(u: f64) -> (f64, f64) {
    let tanh_u = u.tanh();
    let derivative = 1.0 - tanh_u * tanh_u;
    (tanh_u, derivative)
}

69impl FitUnsupervised<f64> for FastIca {
70 type Fitted = FittedFastIca;
71
72 fn fit(&self, x: &Array2<f64>) -> Result<Self::Fitted> {
73 let n = x.nrows();
74 let d = x.ncols();
75 let k = self.n_components.min(d.min(n));
76 if k == 0 {
77 return Err(RustMlError::InvalidParameter("n_components >= 1".into()));
78 }
79 if n < 2 {
80 return Err(RustMlError::EmptyInput("need at least 2 samples".into()));
81 }
82
83 let mut mean = Array1::<f64>::zeros(d);
85 for j in 0..d {
86 mean[j] = x.column(j).sum() / n as f64;
87 }
88 let mut xc = x.clone();
89 for j in 0..d {
90 for i in 0..n {
91 xc[[i, j]] -= mean[j];
92 }
93 }
94
95 let xm = Mat::from_fn(n, d, |i, j| xc[[i, j]]);
97 let svd = Svd::new(xm.as_ref())
98 .map_err(|e| RustMlError::InvalidParameter(format!("SVD failed: {e:?}")))?;
99 let s = svd.S();
100 let v = svd.V();
101 let scale = (n as f64 - 1.0).sqrt();
102 let mut k_white = Array2::<f64>::zeros((d, k));
103 for c in 0..k {
104 let sigma = s.column_vector()[c].max(1e-12);
105 for j in 0..d {
106 k_white[[j, c]] = v[(j, c)] * scale / sigma;
107 }
108 }
109 let x1 = xc.dot(&k_white);
111
112 let mut rng = StdRng::seed_from_u64(self.seed);
114 let mut w = Array2::<f64>::zeros((k, k));
115 for comp in 0..k {
116 let mut wi: Array1<f64> = Array1::from_shape_fn(k, |_| rng.gen::<f64>() * 2.0 - 1.0);
118 let nrm = wi.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-12);
120 wi.mapv_inplace(|v| v / nrm);
121
122 for _ in 0..self.max_iter {
123 let mut u = vec![0.0_f64; n];
125 for i in 0..n {
126 let mut s = 0.0;
127 for c in 0..k {
128 s += x1[[i, c]] * wi[c];
129 }
130 u[i] = s;
131 }
132 let mut gu = vec![0.0_f64; n];
134 let mut g_prime_mean = 0.0_f64;
135 for i in 0..n {
136 let (g, gp) = g_logcosh(u[i]);
137 gu[i] = g;
138 g_prime_mean += gp;
139 }
140 g_prime_mean /= n as f64;
141 let mut new_wi = Array1::<f64>::zeros(k);
143 for c in 0..k {
144 let mut s = 0.0;
145 for i in 0..n {
146 s += x1[[i, c]] * gu[i];
147 }
148 new_wi[c] = s / n as f64 - g_prime_mean * wi[c];
149 }
150 for prev in 0..comp {
152 let mut dot = 0.0;
153 for c in 0..k {
154 dot += new_wi[c] * w[[prev, c]];
155 }
156 for c in 0..k {
157 new_wi[c] -= dot * w[[prev, c]];
158 }
159 }
160 let nrm = new_wi.iter().map(|v| v * v).sum::<f64>().sqrt().max(1e-12);
162 new_wi.mapv_inplace(|v| v / nrm);
163
164 let mut dot = 0.0;
166 for c in 0..k {
167 dot += new_wi[c] * wi[c];
168 }
169 let conv = (1.0 - dot.abs()).abs();
170 wi = new_wi;
171 if conv < self.tol {
172 break;
173 }
174 }
175 for c in 0..k {
176 w[[comp, c]] = wi[c];
177 }
178 }
179
180 Ok(FittedFastIca {
181 w,
182 whitening: k_white,
183 mean,
184 n_features: d,
185 })
186 }
187}
188
189impl Transform<f64> for FittedFastIca {
190 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
192 if x.ncols() != self.n_features {
193 return Err(RustMlError::ShapeMismatch(format!(
194 "expected {} features, got {}",
195 self.n_features,
196 x.ncols()
197 )));
198 }
199 let mut xc = x.clone();
200 for j in 0..self.n_features {
201 for i in 0..x.nrows() {
202 xc[[i, j]] -= self.mean[j];
203 }
204 }
205 let x_white = xc.dot(&self.whitening);
206 Ok(x_white.dot(&self.w.t()))
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Mixes two independent sources: a sine and a square wave.
    ///
    /// The square wave uses frequency 3.0 so it changes sign many times over
    /// t ∈ [0, 0.1·(n-1)]. The previous `(t * 0.3)` stayed below π on this
    /// range and was the constant 1.0, which made the mixed data rank-1
    /// after centring.
    fn mixed_signals(n: usize) -> Array2<f64> {
        let mut s = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            let t = i as f64 * 0.1;
            s[[i, 0]] = t.sin();
            s[[i, 1]] = (t * 3.0).sin().signum();
        }
        let a = array![[1.0_f64, 0.5], [0.5, 1.0]];
        s.dot(&a)
    }

    #[test]
    fn test_fast_ica_runs() {
        let n = 100;
        let x = mixed_signals(n);
        let fitted = FastIca::new(2).with_seed(1).fit(&x).unwrap();
        let recovered = fitted.transform(&x).unwrap();
        assert_eq!(recovered.shape(), &[n, 2]);
        assert!(recovered.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_recovered_sources_are_white() {
        // Whitening followed by an orthonormal unmixing matrix means the
        // recovered training sources have zero mean and identity covariance.
        let n = 100;
        let x = mixed_signals(n);
        let fitted = FastIca::new(2).with_seed(1).fit(&x).unwrap();
        let y = fitted.transform(&x).unwrap();

        for j in 0..2 {
            let m = y.column(j).sum() / n as f64;
            assert!(m.abs() < 1e-8, "column {j} mean = {m}");
        }
        let cov = y.t().dot(&y) / (n as f64 - 1.0);
        let wwt = fitted.w.dot(&fitted.w.t());
        for i in 0..2 {
            for j in 0..2 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((cov[[i, j]] - expected).abs() < 1e-8, "cov = {cov:?}");
                assert!((wwt[[i, j]] - expected).abs() < 1e-8, "w wᵀ = {wwt:?}");
            }
        }
    }

    #[test]
    fn test_components_clamped_to_features() {
        let x = mixed_signals(50);
        let fitted = FastIca::new(5).fit(&x).unwrap();
        assert_eq!(fitted.w.dim(), (2, 2));
        assert_eq!(fitted.whitening.dim(), (2, 2));
        assert_eq!(fitted.mean.len(), 2);
        assert_eq!(fitted.n_features, 2);
    }

    #[test]
    fn test_same_seed_is_deterministic() {
        let x = mixed_signals(60);
        let a = FastIca::new(2).with_seed(3).fit(&x).unwrap();
        let b = FastIca::new(2).with_seed(3).fit(&x).unwrap();
        assert_eq!(a.w, b.w);
    }

    #[test]
    fn test_invalid_inputs_are_rejected() {
        assert!(FastIca::new(2).fit(&array![[1.0_f64, 2.0]]).is_err());
        let x = mixed_signals(20);
        assert!(FastIca::new(0).fit(&x).is_err());
        let fitted = FastIca::new(2).fit(&x).unwrap();
        assert!(fitted.transform(&Array2::<f64>::zeros((3, 3))).is_err());
    }
}