Skip to main content

ferrolearn_decomp/
fast_ica.rs

1//! Fast Independent Component Analysis (FastICA).
2//!
3//! FastICA separates a multivariate signal into additive independent components
4//! by maximising non-Gaussianity (negentropy approximation).
5//!
6//! # Algorithm
7//!
8//! 1. **Centre**: subtract the mean of each feature.
9//! 2. **Whiten** (PCA whitening): decorrelate and scale the data so that each
10//!    component has unit variance.
11//! 3. **FastICA iteration**: for each unmixing direction `w`, iterate:
12//!    ```text
13//!    w' = E[X g(w^T X)] - E[g'(w^T X)] w
14//!    w' = w' / ||w'||
15//!    ```
16//!    until convergence, using a chosen nonlinearity `g`.
17//! 4. Two variants are supported: `Parallel` (update all directions
18//!    simultaneously) and `Deflation` (extract one at a time via Gram-Schmidt).
19//!
20//! # Non-linearities
21//!
22//! - [`NonLinearity::LogCosh`]: `g(u) = tanh(u)`.
23//! - [`NonLinearity::Exp`]: `g(u) = u exp(-u²/2)`.
24//! - [`NonLinearity::Cube`]: `g(u) = u³`.
25//!
26//! # Examples
27//!
28//! ```
29//! use ferrolearn_decomp::fast_ica::{FastICA, Algorithm, NonLinearity};
30//! use ferrolearn_core::traits::{Fit, Transform};
31//! use ndarray::Array2;
32//!
33//! let ica = FastICA::new(2)
34//!     .with_algorithm(Algorithm::Deflation)
35//!     .with_fun(NonLinearity::LogCosh)
36//!     .with_random_state(0);
37//!
38//! let x = Array2::from_shape_vec(
39//!     (6, 2),
40//!     vec![1.0, 0.0, 0.0, 1.0, -1.0, 0.0, 0.0, -1.0, 1.0, 1.0, -1.0, -1.0],
41//! ).unwrap();
42//! let fitted = ica.fit(&x, &()).unwrap();
43//! let sources = fitted.transform(&x).unwrap();
44//! assert_eq!(sources.ncols(), 2);
45//! ```
46
47use ferrolearn_core::error::FerroError;
48use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
49use ferrolearn_core::traits::{Fit, Transform};
50use ndarray::{Array1, Array2};
51use num_traits::Float;
52use rand::SeedableRng;
53use rand_distr::{Distribution, StandardNormal};
54
55// ---------------------------------------------------------------------------
56// Configuration enums
57// ---------------------------------------------------------------------------
58
/// FastICA iteration strategy.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Algorithm {
    /// Update all unmixing directions simultaneously, re-orthogonalising the
    /// whole unmixing matrix after every sweep (symmetric decorrelation).
    Parallel,
    /// Extract one unmixing direction at a time, orthogonalising each new
    /// direction against the previously found ones (Gram-Schmidt).
    Deflation,
}
67
/// Non-linearity function used to approximate negentropy.
///
/// The choice trades off robustness against speed; `LogCosh` is the usual
/// general-purpose default (see [`FastICA::new`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NonLinearity {
    /// `g(u) = tanh(u)`, `g'(u) = 1 - tanh²(u)`.
    LogCosh,
    /// `g(u) = u exp(-u²/2)`, `g'(u) = (1 - u²) exp(-u²/2)`.
    Exp,
    /// `g(u) = u³`, `g'(u) = 3u²`.
    Cube,
}
78
79// ---------------------------------------------------------------------------
80// FastICA (unfitted)
81// ---------------------------------------------------------------------------
82
/// FastICA configuration.
///
/// Calling [`Fit::fit`] whitens the data and runs the FastICA algorithm,
/// returning a [`FittedFastICA`].
///
/// Construct with [`FastICA::new`] and customise via the `with_*` builder
/// methods.
///
/// # Type Parameters
///
/// - `F`: The floating-point scalar type.
#[derive(Debug, Clone)]
pub struct FastICA<F> {
    /// Number of independent components to extract.
    n_components: usize,
    /// Iteration strategy.
    algorithm: Algorithm,
    /// Non-linearity function.
    fun: NonLinearity,
    /// Maximum number of iterations (per component for `Deflation`).
    max_iter: usize,
    /// Convergence tolerance (stored as `f64`, converted to `F` during fit).
    tol: f64,
    /// Optional random seed; when `None`, fitting falls back to a fixed
    /// internal seed so results are still deterministic.
    random_state: Option<u64>,
    /// Zero-sized marker tying this configuration to the scalar type `F`.
    _marker: std::marker::PhantomData<F>,
}
107
108impl<F: Float + Send + Sync + 'static> FastICA<F> {
109    /// Create a new `FastICA` that extracts `n_components` independent components.
110    ///
111    /// Defaults: `algorithm = Parallel`, `fun = LogCosh`, `max_iter = 200`,
112    /// `tol = 1e-4`, no fixed random seed.
113    #[must_use]
114    pub fn new(n_components: usize) -> Self {
115        Self {
116            n_components,
117            algorithm: Algorithm::Parallel,
118            fun: NonLinearity::LogCosh,
119            max_iter: 200,
120            tol: 1e-4,
121            random_state: None,
122            _marker: std::marker::PhantomData,
123        }
124    }
125
126    /// Set the iteration strategy.
127    #[must_use]
128    pub fn with_algorithm(mut self, algorithm: Algorithm) -> Self {
129        self.algorithm = algorithm;
130        self
131    }
132
133    /// Set the non-linearity function.
134    #[must_use]
135    pub fn with_fun(mut self, fun: NonLinearity) -> Self {
136        self.fun = fun;
137        self
138    }
139
140    /// Set the maximum number of iterations.
141    #[must_use]
142    pub fn with_max_iter(mut self, max_iter: usize) -> Self {
143        self.max_iter = max_iter;
144        self
145    }
146
147    /// Set the convergence tolerance.
148    #[must_use]
149    pub fn with_tol(mut self, tol: f64) -> Self {
150        self.tol = tol;
151        self
152    }
153
154    /// Set the random seed for reproducibility.
155    #[must_use]
156    pub fn with_random_state(mut self, seed: u64) -> Self {
157        self.random_state = Some(seed);
158        self
159    }
160
161    /// Return the number of components.
162    #[must_use]
163    pub fn n_components(&self) -> usize {
164        self.n_components
165    }
166}
167
impl<F: Float + Send + Sync + 'static> Default for FastICA<F> {
    /// Default configuration: a single independent component with the
    /// defaults documented on [`FastICA::new`].
    fn default() -> Self {
        Self::new(1)
    }
}
173
174// ---------------------------------------------------------------------------
175// FittedFastICA
176// ---------------------------------------------------------------------------
177
/// A fitted FastICA model.
///
/// Created by calling [`Fit::fit`] on a [`FastICA`].
/// Implements [`Transform<Array2<F>>`] to unmix new signals: the data is
/// centred with `mean`, projected by `whitening`, then unmixed by
/// `components`.
#[derive(Debug, Clone)]
pub struct FittedFastICA<F> {
    /// Unmixing matrix (applied after whitening), shape `(n_components, n_components_white)`.
    ///
    /// To recover sources from whitened data: `S = unmixing @ X_white`.
    components: Array2<F>,

    /// Mixing matrix (pseudo-inverse of the unmixing), shape `(n_features, n_components)`.
    mixing: Array2<F>,

    /// Per-feature mean, shape `(n_features,)`.
    mean: Array1<F>,

    /// Whitening matrix, shape `(n_components, n_features)`.
    whitening: Array2<F>,

    /// Number of iterations performed.
    n_iter: usize,

    /// Number of features seen during fitting; validated by `transform`.
    n_features: usize,
}
204
impl<F: Float + Send + Sync + 'static> FittedFastICA<F> {
    /// Unmixing matrix applied to whitened data, shape `(n_components, n_components)`.
    #[must_use]
    pub fn components(&self) -> &Array2<F> {
        &self.components
    }

    /// Mixing matrix (approximate pseudo-inverse of unmixing + whitening),
    /// shape `(n_features, n_components)`.
    #[must_use]
    pub fn mixing(&self) -> &Array2<F> {
        &self.mixing
    }

    /// Per-feature mean learned during fitting, shape `(n_features,)`.
    #[must_use]
    pub fn mean(&self) -> &Array1<F> {
        &self.mean
    }

    /// Number of FastICA iterations performed during fitting.
    #[must_use]
    pub fn n_iter(&self) -> usize {
        self.n_iter
    }
}
230
231// ---------------------------------------------------------------------------
232// Internal helpers
233// ---------------------------------------------------------------------------
234
235/// Apply the non-linearity `g` and its derivative `g'` element-wise.
236///
237/// Returns `(g_vals, g_prime_vals)`.
238fn apply_nonlinearity<F: Float>(u: &Array1<F>, fun: NonLinearity) -> (Array1<F>, Array1<F>) {
239    let n = u.len();
240    let mut g_vals = Array1::<F>::zeros(n);
241    let mut gp_vals = Array1::<F>::zeros(n);
242    let half = F::from(0.5).unwrap();
243    for i in 0..n {
244        let ui = u[i];
245        match fun {
246            NonLinearity::LogCosh => {
247                // g(u) = tanh(u)
248                // Use the formula: tanh(x) = (e^2x - 1)/(e^2x + 1)
249                let t = if ui > F::from(20.0).unwrap() {
250                    F::one()
251                } else if ui < F::from(-20.0).unwrap() {
252                    -F::one()
253                } else {
254                    let e2 = (ui * F::from(2.0).unwrap()).exp();
255                    (e2 - F::one()) / (e2 + F::one())
256                };
257                g_vals[i] = t;
258                gp_vals[i] = F::one() - t * t;
259            }
260            NonLinearity::Exp => {
261                // g(u) = u exp(-u²/2)
262                let neg_u2_half = -(ui * ui) * half;
263                let exp_v = neg_u2_half.exp();
264                g_vals[i] = ui * exp_v;
265                gp_vals[i] = (F::one() - ui * ui) * exp_v;
266            }
267            NonLinearity::Cube => {
268                // g(u) = u³
269                g_vals[i] = ui * ui * ui;
270                gp_vals[i] = F::from(3.0).unwrap() * ui * ui;
271            }
272        }
273    }
274    (g_vals, gp_vals)
275}
276
277/// Compute `g` and mean of `g'` for all samples.
278///
279/// `x_white_w`: the projections `W_row @ X_white`, shape `(n_samples,)`.
280/// Returns `(mean_g_prime, g_vals)` where `g_vals` has shape `(n_samples,)`.
281fn ica_step_values<F: Float>(projections: &Array1<F>, fun: NonLinearity) -> (F, Array1<F>) {
282    let (g_vals, gp_vals) = apply_nonlinearity(projections, fun);
283    let n_f = F::from(projections.len()).unwrap();
284    let mean_gp = gp_vals.iter().copied().fold(F::zero(), |a, b| a + b) / n_f;
285    (mean_gp, g_vals)
286}
287
288/// Gram-Schmidt orthogonalisation of `W` (row vectors).
289fn gs_orthogonalise<F: Float>(w: &mut Array2<F>, col: usize) {
290    let k = col;
291    // w[k] -= sum_{j<k} (w[k] . w[j]) w[j]
292    for j in 0..k {
293        let dot = (0..w.ncols())
294            .map(|d| w[[k, d]] * w[[j, d]])
295            .fold(F::zero(), |a, b| a + b);
296        for d in 0..w.ncols() {
297            let wd = w[[j, d]];
298            w[[k, d]] = w[[k, d]] - dot * wd;
299        }
300    }
301    // Normalise.
302    let norm = (0..w.ncols())
303        .map(|d| w[[k, d]] * w[[k, d]])
304        .fold(F::zero(), |a, b| a + b)
305        .sqrt();
306    if norm > F::from(1e-15).unwrap() {
307        for d in 0..w.ncols() {
308            w[[k, d]] = w[[k, d]] / norm;
309        }
310    }
311}
312
/// Symmetric orthogonalisation: W ← (W W^T)^{-1/2} W.
///
/// Decorrelates all rows of `W` at once (used by the `Parallel` variant and
/// for the random initial matrix). The inverse square root is built from the
/// eigendecomposition of S = W W^T.
///
/// # Errors
///
/// Propagates any error from [`jacobi_eigen_small`].
fn sym_orthogonalise<F: Float + Send + Sync + 'static>(
    w: &mut Array2<F>,
) -> Result<(), FerroError> {
    let k = w.nrows();
    // Compute S = W W^T (k × k).
    let mut s = Array2::<F>::zeros((k, k));
    for i in 0..k {
        for j in 0..k {
            let dot = (0..w.ncols())
                .map(|d| w[[i, d]] * w[[j, d]])
                .fold(F::zero(), |a, b| a + b);
            s[[i, j]] = dot;
        }
    }
    // Eigendecompose S = V D V^T.
    let max_iter = k * k * 100 + 1000;
    let (eigenvalues, eigenvectors) = jacobi_eigen_small(&s, max_iter)?;
    // W_new = V D^{-1/2} V^T W
    // = Σ_i (1/sqrt(d_i)) (v_i v_i^T) W
    let mut w_new = Array2::<F>::zeros((k, w.ncols()));
    let eps = F::from(1e-10).unwrap();
    for i in 0..k {
        let d = eigenvalues[i];
        // Near-zero eigenvalues are left unscaled (scale = 1) rather than
        // divided by ~0; that direction passes through unchanged.
        let scale = if d > eps {
            F::one() / d.sqrt()
        } else {
            F::one()
        };
        // v_i is column i of eigenvectors.
        // outer product: v_i v_i^T W = v_i (v_i^T W)
        // (v_i^T W) is a row vector of shape (1, n_comp).
        let mut vi_t_w = Array1::<F>::zeros(w.ncols());
        for d_idx in 0..k {
            for col in 0..w.ncols() {
                vi_t_w[col] = vi_t_w[col] + eigenvectors[[d_idx, i]] * w[[d_idx, col]];
            }
        }
        // Accumulate scale · v_i · (v_i^T W) into the result.
        for row in 0..k {
            for col in 0..w.ncols() {
                w_new[[row, col]] =
                    w_new[[row, col]] + scale * eigenvectors[[row, i]] * vi_t_w[col];
            }
        }
    }
    *w = w_new;
    Ok(())
}
361
/// Jacobi eigendecomposition for a small k×k symmetric matrix.
///
/// Classical Jacobi iteration: repeatedly annihilate the largest
/// off-diagonal element with a Givens rotation. Returns
/// `(eigenvalues, eigenvectors)` where eigenvalue `i` corresponds to
/// eigenvector *column* `i`; the pairs are not sorted.
///
/// # Errors
///
/// Currently never returns `Err`; the `Result` is kept for interface
/// stability. If `max_iter` rotations do not reach tolerance, the best
/// estimate so far is returned.
fn jacobi_eigen_small<F: Float + Send + Sync + 'static>(
    a: &Array2<F>,
    max_iter: usize,
) -> Result<(Array1<F>, Array2<F>), FerroError> {
    let n = a.nrows();
    let mut mat = a.to_owned();
    // Accumulated rotations start from the identity.
    let mut v = Array2::<F>::zeros((n, n));
    for i in 0..n {
        v[[i, i]] = F::one();
    }
    let tol = F::from(1e-12).unwrap_or(F::epsilon());
    let two = F::from(2.0).unwrap();
    for _ in 0..max_iter {
        // Pivot search: largest off-diagonal element (p, q).
        let mut max_off = F::zero();
        let mut p = 0;
        let mut q = 1;
        for i in 0..n {
            for j in (i + 1)..n {
                let val = mat[[i, j]].abs();
                if val > max_off {
                    max_off = val;
                    p = i;
                    q = j;
                }
            }
        }
        // Converged: the matrix is numerically diagonal.
        if max_off < tol {
            let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
            return Ok((eigenvalues, v));
        }
        let app = mat[[p, p]];
        let aqq = mat[[q, q]];
        let apq = mat[[p, q]];
        // Rotation angle that zeroes mat[p, q]. The tau/t formulation is the
        // cancellation-safe form; π/4 is the equal-diagonal limit (which
        // annihilates the pivot regardless of the sign of apq).
        let theta = if (app - aqq).abs() < tol {
            F::from(std::f64::consts::FRAC_PI_4).unwrap_or(F::one())
        } else {
            let tau = (aqq - app) / (two * apq);
            let t = if tau >= F::zero() {
                F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
            } else {
                -F::one() / (tau.abs() + (F::one() + tau * tau).sqrt())
            };
            t.atan()
        };
        let c = theta.cos();
        let s = theta.sin();
        // Apply the similarity transform J^T M J on rows/columns p and q.
        let mut new_mat = mat.clone();
        for i in 0..n {
            if i != p && i != q {
                let mip = mat[[i, p]];
                let miq = mat[[i, q]];
                new_mat[[i, p]] = c * mip - s * miq;
                new_mat[[p, i]] = new_mat[[i, p]];
                new_mat[[i, q]] = s * mip + c * miq;
                new_mat[[q, i]] = new_mat[[i, q]];
            }
        }
        new_mat[[p, p]] = c * c * app - two * s * c * apq + s * s * aqq;
        new_mat[[q, q]] = s * s * app + two * s * c * apq + c * c * aqq;
        // The pivot is zeroed exactly by construction.
        new_mat[[p, q]] = F::zero();
        new_mat[[q, p]] = F::zero();
        mat = new_mat;
        // Accumulate the rotation into the eigenvector matrix.
        for i in 0..n {
            let vip = v[[i, p]];
            let viq = v[[i, q]];
            v[[i, p]] = c * vip - s * viq;
            v[[i, q]] = s * vip + c * viq;
        }
    }
    // Didn't fully converge, but return best estimate.
    let eigenvalues = Array1::from_shape_fn(n, |i| mat[[i, i]]);
    Ok((eigenvalues, v))
}
436
437// ---------------------------------------------------------------------------
438// Fit
439// ---------------------------------------------------------------------------
440
impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for FastICA<F> {
    type Fitted = FittedFastICA<F>;
    type Error = FerroError;

    /// Fit FastICA to data.
    ///
    /// Pipeline: validate → centre → PCA-whiten to `n_components`
    /// dimensions → run the FastICA fixed-point iteration (`Parallel` or
    /// `Deflation`) to find the square unmixing matrix `W`.
    ///
    /// # Errors
    ///
    /// - [`FerroError::InvalidParameter`] if `n_components` is zero or
    ///   exceeds `n_features`.
    /// - [`FerroError::InsufficientSamples`] if fewer than 2 samples are provided.
    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedFastICA<F>, FerroError> {
        let (n_samples, n_features) = x.dim();

        // --- Parameter validation --------------------------------------------
        if self.n_components == 0 {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: "must be at least 1".into(),
            });
        }
        if self.n_components > n_features {
            return Err(FerroError::InvalidParameter {
                name: "n_components".into(),
                reason: format!(
                    "n_components ({}) exceeds n_features ({})",
                    self.n_components, n_features
                ),
            });
        }
        if n_samples < 2 {
            return Err(FerroError::InsufficientSamples {
                required: 2,
                actual: n_samples,
                context: "FastICA requires at least 2 samples".into(),
            });
        }

        let k = self.n_components;
        let n_f = F::from(n_samples).unwrap();

        // --- Step 1: Centre --------------------------------------------------
        // The per-feature mean is stored on the fitted model so `transform`
        // can centre new data the same way.
        let mut mean = Array1::<F>::zeros(n_features);
        for j in 0..n_features {
            let s = x.column(j).iter().copied().fold(F::zero(), |a, b| a + b);
            mean[j] = s / n_f;
        }
        let mut xc = x.to_owned();
        for mut row in xc.rows_mut() {
            for (v, &m) in row.iter_mut().zip(mean.iter()) {
                *v = *v - m;
            }
        }

        // --- Step 2: Whiten (PCA) -------------------------------------------
        // Covariance matrix C = X_c^T X_c / n  (n_features × n_features)
        // NOTE(review): divides by n, not n - 1 (population covariance); this
        // only rescales the whitening and does not change the ICA directions.
        let cov = xc.t().dot(&xc).mapv(|v| v / n_f);

        // Eigendecompose C.
        let max_jacobi = n_features * n_features * 100 + 1000;
        let (eigenvalues, eigenvectors) = jacobi_eigen_small(&cov, max_jacobi)?;

        // Sort descending.
        let mut indices: Vec<usize> = (0..n_features).collect();
        indices.sort_by(|&a, &b| {
            eigenvalues[b]
                .partial_cmp(&eigenvalues[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Build whitening matrix K: k × n_features.
        // K[i, :] = eigenvectors[:, indices[i]] / sqrt(eigenvalues[indices[i]])
        // Rows belonging to (near-)zero eigenvalues are zeroed out entirely.
        let eps = F::from(1e-10).unwrap();
        let mut whitening = Array2::<F>::zeros((k, n_features));
        for i in 0..k {
            let idx = indices[i];
            let ev = eigenvalues[idx];
            let scale = if ev > eps {
                F::one() / ev.sqrt()
            } else {
                F::zero()
            };
            for j in 0..n_features {
                whitening[[i, j]] = eigenvectors[[j, idx]] * scale;
            }
        }

        // Whitened data X_w = K @ X_c^T  (k × n_samples), then transpose to n × k.
        let x_white_t = whitening.dot(&xc.t()); // k × n
        let x_white = x_white_t.t().to_owned(); // n × k

        // --- Step 3: FastICA -------------------------------------------------
        // A fixed fallback seed keeps fitting deterministic even without an
        // explicit `random_state`.
        let seed = self.random_state.unwrap_or(42);
        let mut rng = rand_xoshiro::Xoshiro256PlusPlus::seed_from_u64(seed);
        let std_normal = StandardNormal;

        // Initialise W as a k × k random matrix (rows are unmixing directions).
        let mut w = Array2::<F>::zeros((k, k));
        for i in 0..k {
            for j in 0..k {
                let v: f64 = std_normal.sample(&mut rng);
                w[[i, j]] = F::from(v).unwrap();
            }
        }
        // Orthogonalise initial W.
        sym_orthogonalise(&mut w)?;

        let tol_f = F::from(self.tol).unwrap();
        let mut n_iter = 0usize;

        match self.algorithm {
            Algorithm::Parallel => {
                for iter in 0..self.max_iter {
                    let mut w_new = Array2::<F>::zeros((k, k));
                    // For each component i, update using all samples.
                    for i in 0..k {
                        // Projection: u = X_w @ w[i]  (n_samples,)
                        let w_row: Array1<F> = w.row(i).to_owned();
                        let u: Array1<F> = x_white.dot(&w_row);
                        let (mean_gp, g_vals) = ica_step_values(&u, self.fun);
                        // Fixed-point update:
                        // w_new[i] = (1/n) X_w^T g(u) - mean_g' * w[i]
                        // X_w^T g(u) = sum_t x_w[t] g(u[t])  (k-vector)
                        let mut xw_t_g = Array1::<F>::zeros(k);
                        for t in 0..n_samples {
                            for d in 0..k {
                                xw_t_g[d] = xw_t_g[d] + x_white[[t, d]] * g_vals[t];
                            }
                        }
                        for d in 0..k {
                            xw_t_g[d] = xw_t_g[d] / n_f;
                        }
                        for d in 0..k {
                            w_new[[i, d]] = xw_t_g[d] - mean_gp * w_row[d];
                        }
                    }
                    // Symmetric orthogonalisation.
                    sym_orthogonalise(&mut w_new)?;

                    // Convergence: max |1 - |w_new[i] . w[i]||
                    // (rows may flip sign between iterations, hence the abs).
                    let mut max_change = F::zero();
                    for i in 0..k {
                        let dot: F = (0..k)
                            .map(|d| w_new[[i, d]] * w[[i, d]])
                            .fold(F::zero(), |a, b| a + b);
                        let change = (F::one() - dot.abs()).abs();
                        if change > max_change {
                            max_change = change;
                        }
                    }
                    w = w_new;
                    n_iter = iter + 1;
                    if max_change < tol_f {
                        break;
                    }
                }
            }
            Algorithm::Deflation => {
                for i in 0..k {
                    for iter in 0..self.max_iter {
                        // Projection: u = X_w @ w[i]  (n_samples,)
                        let w_row: Array1<F> = w.row(i).to_owned();
                        let u: Array1<F> = x_white.dot(&w_row);
                        let (mean_gp, g_vals) = ica_step_values(&u, self.fun);
                        // w_new = (1/n) X_w^T g(u) - mean_g' * w[i]
                        let mut w_new_row = Array1::<F>::zeros(k);
                        for t in 0..n_samples {
                            for d in 0..k {
                                w_new_row[d] = w_new_row[d] + x_white[[t, d]] * g_vals[t];
                            }
                        }
                        for d in 0..k {
                            w_new_row[d] = w_new_row[d] / n_f - mean_gp * w_row[d];
                        }
                        // Gram-Schmidt orthogonalisation against the already
                        // extracted rows 0..i.
                        for j in 0..i {
                            let dot: F = (0..k)
                                .map(|d| w_new_row[d] * w[[j, d]])
                                .fold(F::zero(), |a, b| a + b);
                            for d in 0..k {
                                let wd = w[[j, d]];
                                w_new_row[d] = w_new_row[d] - dot * wd;
                            }
                        }
                        // Normalise.
                        let norm = w_new_row
                            .iter()
                            .copied()
                            .map(|v| v * v)
                            .fold(F::zero(), |a, b| a + b)
                            .sqrt();
                        if norm > F::from(1e-15).unwrap() {
                            w_new_row.mapv_inplace(|v| v / norm);
                        }
                        // Convergence: |1 - |w_new . w_old||
                        let dot: F = (0..k)
                            .map(|d| w_new_row[d] * w_row[d])
                            .fold(F::zero(), |a, b| a + b);
                        let change = (F::one() - dot.abs()).abs();
                        for d in 0..k {
                            w[[i, d]] = w_new_row[d];
                        }
                        // n_iter reflects the iteration count of the most
                        // recently processed component, not a total.
                        n_iter = iter + 1;
                        if change < tol_f {
                            break;
                        }
                    }
                    // Gram-Schmidt after finalising component i (defensive:
                    // the row is re-orthogonalised and re-normalised once more).
                    gs_orthogonalise(&mut w, i);
                }
            }
        }

        // --- Mixing matrix ---------------------------------------------------
        // The full unmixing pipeline is: s = W @ K @ (x - mean)
        // where K is the whitening matrix (k × n_features), W is k × k.
        // The mixing matrix M satisfies s ≈ W K x_c, so x_c ≈ K^T W^T s (Moore-Penrose pseudo-inverse).
        // mixing = K^T W^T  (n_features × k)
        // NOTE(review): K^T equals the pseudo-inverse of K only when K has
        // orthonormal rows; whitening rows are scaled by 1/sqrt(eigenvalue),
        // so `mixing` differs from the exact pseudo-inverse by a
        // per-component scale — confirm intended semantics.
        let mixing = whitening.t().dot(&w.t()); // n_features × k

        Ok(FittedFastICA {
            components: w,
            mixing,
            mean,
            whitening,
            n_iter,
            n_features,
        })
    }
}
669
670// ---------------------------------------------------------------------------
671// Transform
672// ---------------------------------------------------------------------------
673
674impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedFastICA<F> {
675    type Output = Array2<F>;
676    type Error = FerroError;
677
678    /// Unmix new signals: `S = (W @ K @ (X - mean)^T)^T`.
679    ///
680    /// Returns an array of shape `(n_samples, n_components)`.
681    ///
682    /// # Errors
683    ///
684    /// Returns [`FerroError::ShapeMismatch`] if the number of features does not
685    /// match the model.
686    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
687        if x.ncols() != self.n_features {
688            return Err(FerroError::ShapeMismatch {
689                expected: vec![x.nrows(), self.n_features],
690                actual: vec![x.nrows(), x.ncols()],
691                context: "FittedFastICA::transform".into(),
692            });
693        }
694        // Centre.
695        let mut xc = x.to_owned();
696        for mut row in xc.rows_mut() {
697            for (v, &m) in row.iter_mut().zip(self.mean.iter()) {
698                *v = *v - m;
699            }
700        }
701        // Whiten: X_w = K @ X_c^T  (k × n), transpose to n × k.
702        let x_white = self.whitening.dot(&xc.t()).t().to_owned(); // n × k
703        // Unmix: S = (W @ X_w^T)^T = X_w @ W^T  (n × k)
704        let sources = x_white.dot(&self.components.t());
705        Ok(sources)
706    }
707}
708
709// ---------------------------------------------------------------------------
710// Pipeline integration (generic)
711// ---------------------------------------------------------------------------
712
713impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for FastICA<F> {
714    /// Fit using the pipeline interface (ignores `y`).
715    ///
716    /// # Errors
717    ///
718    /// Propagates errors from [`Fit::fit`].
719    fn fit_pipeline(
720        &self,
721        x: &Array2<F>,
722        _y: &Array1<F>,
723    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
724        let fitted = self.fit(x, &())?;
725        Ok(Box::new(fitted))
726    }
727}
728
impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F> for FittedFastICA<F> {
    /// Transform via the pipeline interface.
    ///
    /// Thin delegation to [`Transform::transform`].
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Transform::transform`].
    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
        self.transform(x)
    }
}
739
740// ---------------------------------------------------------------------------
741// Tests
742// ---------------------------------------------------------------------------
743
744#[cfg(test)]
745mod tests {
746    use super::*;
747    use approx::assert_abs_diff_eq;
748    use ndarray::Array2;
749
    /// Build a small deterministic two-channel dataset: two smooth sources
    /// combined by a fixed 2×2 mixing matrix.
    fn mixed_signals() -> Array2<f64> {
        // Two synthetic source signals, then mixed.
        let n = 50;
        let mut x = Array2::<f64>::zeros((n, 2));
        for i in 0..n {
            let t = i as f64 * 0.2;
            // source 1: sine wave, source 2: slower cosine
            let s1 = t.sin();
            let s2 = (t * 0.5).cos();
            // mixing matrix
            x[[i, 0]] = 0.5 * s1 + 0.5 * s2;
            x[[i, 1]] = 0.2 * s1 + 0.8 * s2;
        }
        x
    }
765
    /// Fitting well-formed data succeeds and yields a 2×2 unmixing matrix.
    #[test]
    fn test_ica_fit_returns_fitted() {
        let ica = FastICA::<f64>::new(2).with_random_state(0);
        let x = mixed_signals();
        let fitted = ica.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().dim(), (2, 2));
    }
773
    /// Transform returns one row per sample and one column per component.
    #[test]
    fn test_ica_transform_shape() {
        let ica = FastICA::<f64>::new(2).with_random_state(0);
        let x = mixed_signals();
        let fitted = ica.fit(&x, &()).unwrap();
        let sources = fitted.transform(&x).unwrap();
        assert_eq!(sources.dim(), (50, 2));
    }
782
    /// The `Parallel` variant fits without error.
    #[test]
    fn test_ica_parallel_algorithm() {
        let ica = FastICA::<f64>::new(2)
            .with_algorithm(Algorithm::Parallel)
            .with_random_state(1);
        let x = mixed_signals();
        let fitted = ica.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().nrows(), 2);
    }
792
    /// The `Deflation` variant fits without error.
    #[test]
    fn test_ica_deflation_algorithm() {
        let ica = FastICA::<f64>::new(2)
            .with_algorithm(Algorithm::Deflation)
            .with_random_state(2);
        let x = mixed_signals();
        let fitted = ica.fit(&x, &()).unwrap();
        assert_eq!(fitted.components().nrows(), 2);
    }
802
    /// Fit + transform succeed with the `LogCosh` non-linearity.
    #[test]
    fn test_ica_logcosh() {
        let ica = FastICA::<f64>::new(2)
            .with_fun(NonLinearity::LogCosh)
            .with_random_state(3);
        let x = mixed_signals();
        let fitted = ica.fit(&x, &()).unwrap();
        let s = fitted.transform(&x).unwrap();
        assert_eq!(s.ncols(), 2);
    }
813
    /// Fit + transform succeed with the `Exp` non-linearity.
    #[test]
    fn test_ica_exp() {
        let ica = FastICA::<f64>::new(2)
            .with_fun(NonLinearity::Exp)
            .with_random_state(4);
        let x = mixed_signals();
        let fitted = ica.fit(&x, &()).unwrap();
        let s = fitted.transform(&x).unwrap();
        assert_eq!(s.ncols(), 2);
    }
824
    /// Fit + transform succeed with the `Cube` non-linearity.
    #[test]
    fn test_ica_cube() {
        let ica = FastICA::<f64>::new(2)
            .with_fun(NonLinearity::Cube)
            .with_random_state(5);
        let x = mixed_signals();
        let fitted = ica.fit(&x, &()).unwrap();
        let s = fitted.transform(&x).unwrap();
        assert_eq!(s.ncols(), 2);
    }
835
    /// At least one FastICA iteration is always recorded.
    #[test]
    fn test_ica_n_iter_positive() {
        let ica = FastICA::<f64>::new(2).with_random_state(0);
        let x = mixed_signals();
        let fitted = ica.fit(&x, &()).unwrap();
        assert!(fitted.n_iter() >= 1);
    }
843
    /// The mixing matrix has shape (n_features, n_components).
    #[test]
    fn test_ica_mixing_shape() {
        let ica = FastICA::<f64>::new(2).with_random_state(0);
        let x = mixed_signals();
        let fitted = ica.fit(&x, &()).unwrap();
        assert_eq!(fitted.mixing().dim(), (2, 2));
    }
851
    /// The learned mean has one entry per feature.
    #[test]
    fn test_ica_mean_shape() {
        let ica = FastICA::<f64>::new(2).with_random_state(0);
        let x = mixed_signals();
        let fitted = ica.fit(&x, &()).unwrap();
        assert_eq!(fitted.mean().len(), 2);
    }
859
    /// Transforming data with the wrong feature count is rejected.
    #[test]
    fn test_ica_transform_shape_mismatch() {
        let ica = FastICA::<f64>::new(2).with_random_state(0);
        let x = mixed_signals();
        let fitted = ica.fit(&x, &()).unwrap();
        let x_bad = Array2::<f64>::zeros((3, 5));
        assert!(fitted.transform(&x_bad).is_err());
    }
868
    /// `n_components = 0` is an invalid parameter.
    #[test]
    fn test_ica_error_zero_components() {
        let ica = FastICA::<f64>::new(0);
        let x = mixed_signals();
        assert!(ica.fit(&x, &()).is_err());
    }
875
    /// Requesting more components than features is rejected.
    #[test]
    fn test_ica_error_too_many_components() {
        let ica = FastICA::<f64>::new(10); // n_features = 2
        let x = mixed_signals();
        assert!(ica.fit(&x, &()).is_err());
    }
882
    /// Fewer than two samples is rejected.
    #[test]
    fn test_ica_error_insufficient_samples() {
        let ica = FastICA::<f64>::new(1);
        let x = Array2::<f64>::zeros((1, 2));
        assert!(ica.fit(&x, &()).is_err());
    }
889
    /// Extracting a single component produces a single source column.
    #[test]
    fn test_ica_single_component() {
        let ica = FastICA::<f64>::new(1).with_random_state(0);
        let x = mixed_signals();
        let fitted = ica.fit(&x, &()).unwrap();
        let s = fitted.transform(&x).unwrap();
        assert_eq!(s.dim(), (50, 1));
    }
898
    /// Recovered sources carry signal (not a degenerate all-zero output).
    #[test]
    fn test_ica_sources_not_all_zero() {
        let ica = FastICA::<f64>::new(2).with_random_state(0);
        let x = mixed_signals();
        let fitted = ica.fit(&x, &()).unwrap();
        let s = fitted.transform(&x).unwrap();
        let total: f64 = s.iter().map(|v| v.abs()).sum();
        assert!(total > 0.0);
    }
908
    /// Identical seeds must produce (numerically) identical components.
    #[test]
    fn test_ica_reproducible_with_seed() {
        let ica1 = FastICA::<f64>::new(2).with_random_state(7);
        let ica2 = FastICA::<f64>::new(2).with_random_state(7);
        let x = mixed_signals();
        let f1 = ica1.fit(&x, &()).unwrap();
        let f2 = ica2.fit(&x, &()).unwrap();
        for (a, b) in f1.components().iter().zip(f2.components().iter()) {
            assert_abs_diff_eq!(a, b, epsilon = 1e-12);
        }
    }
920
    /// The pipeline fit/transform wrappers round-trip like the direct API.
    #[test]
    fn test_ica_pipeline_transformer() {
        use ferrolearn_core::pipeline::PipelineTransformer;
        let ica = FastICA::<f64>::new(2).with_random_state(0);
        let x = mixed_signals();
        let y = Array1::<f64>::zeros(50);
        let fitted = ica.fit_pipeline(&x, &y).unwrap();
        let out = fitted.transform_pipeline(&x).unwrap();
        assert_eq!(out.ncols(), 2);
    }
931
    /// The `n_components` getter echoes the constructor argument.
    #[test]
    fn test_ica_n_components_getter() {
        let ica = FastICA::<f64>::new(3);
        assert_eq!(ica.n_components(), 3);
    }
937
    /// All three non-linearities satisfy g(0) = 0.
    #[test]
    fn test_ica_nonlinearity_values() {
        // Check g(0) = 0 for all non-linearities.
        let u = Array1::from_vec(vec![0.0f64]);
        let (g_lc, _) = apply_nonlinearity(&u, NonLinearity::LogCosh);
        let (g_exp, _) = apply_nonlinearity(&u, NonLinearity::Exp);
        let (g_cube, _) = apply_nonlinearity(&u, NonLinearity::Cube);
        assert_abs_diff_eq!(g_lc[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(g_exp[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(g_cube[0], 0.0, epsilon = 1e-10);
    }
949}