kizzasi_io/signal/separation.rs

//! Blind source separation algorithms
//!
//! This module provides classical techniques for separating mixed signals:
//! - FastICA (Independent Component Analysis)
//! - NMF (Non-negative Matrix Factorization)
//! - PCA (Principal Component Analysis)
//! - Temporal decorrelation methods

use crate::error::{IoError, IoResult};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::thread_rng;

/// FastICA algorithm for Independent Component Analysis
///
/// Separates mixed signals into statistically independent components
/// using higher-order statistics (non-Gaussianity).
pub struct FastICA {
    /// Number of components to extract
    n_components: usize,
    /// Maximum iterations
    max_iter: usize,
    /// Convergence tolerance
    tolerance: f32,
    /// Nonlinearity function type
    nonlinearity: Nonlinearity,
}

/// Nonlinearity function for ICA
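///
/// Each variant supplies g, the derivative of the ICA contrast function G,
/// together with g' as used by the Newton step (matching
/// `apply_nonlinearity` below):
/// - LogCosh: g(u) = tanh(alpha * u), g'(u) = alpha * (1 - tanh^2(alpha * u))
/// - Exp:     g(u) = u * exp(-u^2 / 2), g'(u) = (1 - u^2) * exp(-u^2 / 2)
/// - Cube:    g(u) = u^3, g'(u) = 3 * u^2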
#[derive(Debug, Clone, Copy)]
pub enum Nonlinearity {
    /// Logcosh (default, robust)
    LogCosh,
    /// Exponential (fast convergence)
    Exp,
    /// Cubic (simple, less robust)
    Cube,
}

impl FastICA {
    /// Create a new FastICA analyzer
    ///
    /// # Arguments
    /// * `n_components` - Number of independent components to extract
    /// * `max_iter` - Maximum iterations (default 200)
    /// * `tolerance` - Convergence tolerance (default 1e-4)
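    ///
    /// # Examples
    ///
    /// A construction sketch (the `kizzasi_io` import path is assumed from
    /// this file's location, hence the `ignore` attribute):
    ///
    /// ```ignore
    /// use kizzasi_io::signal::separation::{FastICA, Nonlinearity};
    ///
    /// // Defaults: 200 iterations, 1e-4 tolerance, LogCosh nonlinearity
    /// let ica = FastICA::new(2, None, None).with_nonlinearity(Nonlinearity::Exp);
    /// ```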
    pub fn new(n_components: usize, max_iter: Option<usize>, tolerance: Option<f32>) -> Self {
        Self {
            n_components,
            max_iter: max_iter.unwrap_or(200),
            tolerance: tolerance.unwrap_or(1e-4),
            nonlinearity: Nonlinearity::LogCosh,
        }
    }

    /// Set nonlinearity function
    pub fn with_nonlinearity(mut self, nonlinearity: Nonlinearity) -> Self {
        self.nonlinearity = nonlinearity;
        self
    }

    /// Fit and transform mixed signals to independent components
    ///
    /// # Arguments
    /// * `mixed` - Mixed signals (n_samples × n_signals)
    ///
    /// # Returns
    /// * Independent components (n_samples × n_components)
    /// * Unmixing matrix (n_components × n_signals)
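    ///
    /// # Examples
    ///
    /// A minimal usage sketch (the `kizzasi_io` import path is assumed from
    /// this file's location, hence the `ignore` attribute):
    ///
    /// ```ignore
    /// use kizzasi_io::signal::separation::FastICA;
    /// use scirs2_core::ndarray::arr2;
    ///
    /// let mixed = arr2(&[[1.0, 0.4], [0.2, 1.3], [0.9, 0.1], [0.4, 1.1]]);
    /// let (sources, unmixing) = FastICA::new(2, None, None)
    ///     .fit_transform(&mixed)
    ///     .expect("separation succeeds");
    /// assert_eq!(sources.dim(), (4, 2));
    /// assert_eq!(unmixing.dim(), (2, 2));
    /// ```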
    pub fn fit_transform(&self, mixed: &Array2<f32>) -> IoResult<(Array2<f32>, Array2<f32>)> {
        let (n_samples, n_signals) = mixed.dim();

        if n_samples < 2 || n_signals < 2 {
            return Err(IoError::SignalError(
                "Need at least 2 samples and 2 signals".into(),
            ));
        }

        if self.n_components > n_signals {
            return Err(IoError::SignalError(
                "n_components cannot exceed n_signals".into(),
            ));
        }

        // 1. Center the data
        let mean = mixed
            .mean_axis(Axis(0))
            .expect("Mean axis computation must succeed");
        let centered = mixed - &mean.view().insert_axis(Axis(0));

        // 2. Whiten the data using PCA
        let (whitened, whitening_matrix) = Self::whiten(&centered)?;

        // 3. FastICA algorithm
        let unmixing = self.fastica_core(&whitened)?;

        // 4. Compute sources
        let sources = whitened.dot(&unmixing.t());

        // 5. Compute full unmixing matrix
        let full_unmixing = unmixing.dot(&whitening_matrix);

        Ok((sources, full_unmixing))
    }

    /// Core FastICA algorithm
    fn fastica_core(&self, whitened: &Array2<f32>) -> IoResult<Array2<f32>> {
        let n_components = self.n_components;
        let mut rng = thread_rng();

        // Initialize unmixing matrix randomly
        let mut w = Array2::from_shape_fn((n_components, whitened.ncols()), |_| {
            rng.gen_range(-1.0..1.0)
        });

        // Orthogonalize
        Self::gram_schmidt(&mut w);

        // Iterate until convergence
        for _iter in 0..self.max_iter {
            let w_old = w.clone();

            // Update each component
            for i in 0..n_components {
                let mut w_i = w.row(i).to_owned();

                // Compute E{x g(w^T x)} and E{g'(w^T x)}
                let wx = whitened.dot(&w_i);
                let (g_wx, gp_wx) = self.apply_nonlinearity(&wx);

                let eg = whitened.t().dot(&g_wx) / whitened.nrows() as f32;
                let egp = gp_wx.mean().unwrap();

                // Newton update: w = E{x g(w^T x)} - E{g'(w^T x)} w
                w_i = eg - &w_i * egp;

                // Orthogonalize against previously extracted components
                // (Gram-Schmidt deflation)
                for j in 0..i {
                    let w_j = w.row(j).to_owned();
                    let dot = w_i.dot(&w_j);
                    w_i = w_i - &w_j * dot;
                }

                // Normalize
                let norm = w_i.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 1e-10 {
                    w_i /= norm;
                }

                // Write the finished component back into the unmixing matrix
                for (j, val) in w_i.iter().enumerate() {
                    w[[i, j]] = *val;
                }
            }

            // Check convergence: compare each row with its previous value,
            // ignoring sign flips (w and -w describe the same component)
            let mut max_diff = 0.0f32;
            for i in 0..n_components {
                let dot: f32 = (0..w.ncols()).map(|j| w[[i, j]] * w_old[[i, j]]).sum();
                max_diff = max_diff.max((1.0 - dot.abs()).abs());
            }

            if max_diff < self.tolerance {
                break;
            }
        }

        Ok(w)
    }

    /// Apply nonlinearity function and its derivative
    fn apply_nonlinearity(&self, x: &Array1<f32>) -> (Array1<f32>, Array1<f32>) {
        match self.nonlinearity {
            Nonlinearity::LogCosh => {
                let alpha = 1.0;
                let g = x.mapv(|v| (alpha * v).tanh());
                let gp = x.mapv(|v| alpha * (1.0 - (alpha * v).tanh().powi(2)));
                (g, gp)
            }
            Nonlinearity::Exp => {
                let g = x.mapv(|v| v * (-v * v / 2.0).exp());
                let gp = x.mapv(|v| (1.0 - v * v) * (-v * v / 2.0).exp());
                (g, gp)
            }
            Nonlinearity::Cube => {
                let g = x.mapv(|v| v.powi(3));
                let gp = x.mapv(|v| 3.0 * v * v);
                (g, gp)
            }
        }
    }

    /// Whiten data using PCA
    fn whiten(data: &Array2<f32>) -> IoResult<(Array2<f32>, Array2<f32>)> {
        let n_samples = data.nrows();

        // Compute covariance matrix
        let cov = data.t().dot(data) / n_samples as f32;

        // Simplified eigenvalue decomposition (for small matrices).
        // In production, use a proper eigendecomposition from a linalg crate.
        let (eigenvalues, eigenvectors) = Self::simple_eig(&cov)?;

        // Compute whitening matrix: W = D^{-1/2} E^T
        let mut whitening = eigenvectors.t().to_owned();
        for i in 0..eigenvalues.len() {
            let scale = 1.0 / (eigenvalues[i].max(1e-10).sqrt());
            for j in 0..whitening.ncols() {
                whitening[[i, j]] *= scale;
            }
        }

        // Whiten data
        let whitened = data.dot(&whitening.t());

        Ok((whitened, whitening))
    }

    /// Simplified eigenvalue decomposition (power iteration)
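    ///
    /// Assumes a symmetric input (e.g. a covariance matrix): the rank-one
    /// deflation step relies on the eigenvectors being orthogonal, which is
    /// guaranteed for symmetric matrices.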
    fn simple_eig(matrix: &Array2<f32>) -> IoResult<(Vec<f32>, Array2<f32>)> {
        let n = matrix.nrows();
        let mut eigenvalues = Vec::new();
        let mut eigenvectors = Array2::zeros((n, n));
        let mut remaining = matrix.clone();

        for k in 0..n {
            // Power iteration to find the largest eigenvalue/eigenvector
            let mut v = Array1::from_shape_fn(n, |_| thread_rng().gen_range(-1.0..1.0));
            v = &v / v.iter().map(|x| x * x).sum::<f32>().sqrt();

            for _ in 0..100 {
                let v_new = remaining.dot(&v);
                let norm = v_new.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm < 1e-10 {
                    break;
                }
                v = &v_new / norm;
            }

            // Compute eigenvalue (Rayleigh quotient; v is unit length)
            let av = remaining.dot(&v);
            let eigenvalue = av.dot(&v);
            eigenvalues.push(eigenvalue);

            // Store eigenvector
            for i in 0..n {
                eigenvectors[[i, k]] = v[i];
            }

            // Deflate matrix
            let vv = v
                .clone()
                .insert_axis(Axis(1))
                .dot(&v.clone().insert_axis(Axis(0)));
            remaining = &remaining - &(&vv * eigenvalue);
        }

        Ok((eigenvalues, eigenvectors))
    }

    /// Gram-Schmidt orthogonalization
    fn gram_schmidt(matrix: &mut Array2<f32>) {
        let n_rows = matrix.nrows();

        for i in 0..n_rows {
            // Orthogonalize against previous rows
            for j in 0..i {
                let dot: f32 = (0..matrix.ncols())
                    .map(|k| matrix[[i, k]] * matrix[[j, k]])
                    .sum();

                for k in 0..matrix.ncols() {
                    matrix[[i, k]] -= dot * matrix[[j, k]];
                }
            }

            // Normalize
            let norm: f32 = (0..matrix.ncols())
                .map(|k| matrix[[i, k]] * matrix[[i, k]])
                .sum::<f32>()
                .sqrt();

            if norm > 1e-10 {
                for k in 0..matrix.ncols() {
                    matrix[[i, k]] /= norm;
                }
            }
        }
    }
}

/// Non-negative Matrix Factorization
///
/// Factorizes a non-negative matrix into two non-negative matrices
/// using multiplicative update rules.
pub struct NMF {
    /// Number of components
    n_components: usize,
    /// Maximum iterations
    max_iter: usize,
    /// Convergence tolerance
    tolerance: f32,
}

impl NMF {
    /// Create a new NMF analyzer
    ///
    /// # Arguments
    /// * `n_components` - Number of components (rank)
    /// * `max_iter` - Maximum iterations (default 200)
    /// * `tolerance` - Convergence tolerance (default 1e-4)
    pub fn new(n_components: usize, max_iter: Option<usize>, tolerance: Option<f32>) -> Self {
        Self {
            n_components,
            max_iter: max_iter.unwrap_or(200),
            tolerance: tolerance.unwrap_or(1e-4),
        }
    }

    /// Factorize matrix V ≈ W H
    ///
    /// # Arguments
    /// * `v` - Non-negative matrix (n_samples × n_features)
    ///
    /// # Returns
    /// * W - Basis matrix (n_samples × n_components)
    /// * H - Coefficient matrix (n_components × n_features)
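    ///
    /// The update loop below implements the classic multiplicative rules
    /// (Lee &amp; Seung) for the Frobenius objective ||V - W H||_F.
    ///
    /// # Examples
    ///
    /// A minimal sketch (the `kizzasi_io` import path is assumed from this
    /// file's location, hence the `ignore` attribute):
    ///
    /// ```ignore
    /// use kizzasi_io::signal::separation::NMF;
    /// use scirs2_core::ndarray::arr2;
    ///
    /// // A rank-1 matrix factorizes well with a single component
    /// let v = arr2(&[[1.0, 2.0], [2.0, 4.0]]);
    /// let nmf = NMF::new(1, Some(100), None);
    /// let (w, h) = nmf.fit_transform(&v).expect("factorization succeeds");
    /// assert_eq!((w.dim(), h.dim()), ((2, 1), (1, 2)));
    /// ```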
    pub fn fit_transform(&self, v: &Array2<f32>) -> IoResult<(Array2<f32>, Array2<f32>)> {
        let (n_samples, n_features) = v.dim();

        if n_samples < 1 || n_features < 1 {
            return Err(IoError::SignalError("Empty matrix".into()));
        }

        // Check non-negativity
        if v.iter().any(|&x| x < 0.0) {
            return Err(IoError::SignalError("Matrix must be non-negative".into()));
        }

        // Initialize W and H randomly
        let mut rng = thread_rng();
        let mut w =
            Array2::from_shape_fn((n_samples, self.n_components), |_| rng.gen_range(0.0..1.0));
        let mut h =
            Array2::from_shape_fn((self.n_components, n_features), |_| rng.gen_range(0.0..1.0));

        let eps = 1e-10;
        let mut prev_error = f32::MAX;

        // Multiplicative update rules
        for _iter in 0..self.max_iter {
            // Update H: H = H .* (W^T V) ./ (W^T W H + eps)
            let wt_v = w.t().dot(v);
            let wt_w_h = w.t().dot(&w).dot(&h);

            for i in 0..h.nrows() {
                for j in 0..h.ncols() {
                    h[[i, j]] *= wt_v[[i, j]] / (wt_w_h[[i, j]] + eps);
                }
            }

            // Update W: W = W .* (V H^T) ./ (W H H^T + eps)
            let v_ht = v.dot(&h.t());
            let w_h_ht = w.dot(&h).dot(&h.t());

            for i in 0..w.nrows() {
                for j in 0..w.ncols() {
                    w[[i, j]] *= v_ht[[i, j]] / (w_h_ht[[i, j]] + eps);
                }
            }

            // Compute reconstruction error (Frobenius norm)
            let wh = w.dot(&h);
            let error: f32 = v
                .iter()
                .zip(wh.iter())
                .map(|(&a, &b)| (a - b).powi(2))
                .sum::<f32>()
                .sqrt();

            // Check convergence
            if (prev_error - error).abs() < self.tolerance {
                break;
            }
            prev_error = error;
        }

        Ok((w, h))
    }

    /// Transform new data using fitted model
    ///
    /// # Arguments
    /// * `v` - New data matrix
    /// * `h` - Previously fitted coefficient matrix
    pub fn transform(&self, v: &Array2<f32>, h: &Array2<f32>) -> IoResult<Array2<f32>> {
        let (n_samples, _n_features) = v.dim();
        let mut rng = thread_rng();

        // Initialize W randomly
        let mut w =
            Array2::from_shape_fn((n_samples, self.n_components), |_| rng.gen_range(0.0..1.0));

        let eps = 1e-10;

        // Fix H, update only W
        for _ in 0..self.max_iter {
            let v_ht = v.dot(&h.t());
            let w_h_ht = w.dot(h).dot(&h.t());

            for i in 0..w.nrows() {
                for j in 0..w.ncols() {
                    w[[i, j]] *= v_ht[[i, j]] / (w_h_ht[[i, j]] + eps);
                }
            }
        }

        Ok(w)
    }
}

/// Principal Component Analysis
///
/// Reduces dimensionality by projecting onto principal components
/// (directions of maximum variance).
pub struct PCA {
    /// Number of components to keep
    n_components: usize,
}

impl PCA {
    /// Create a new PCA analyzer
    pub fn new(n_components: usize) -> Self {
        Self { n_components }
    }

    /// Fit and transform data to principal components
    ///
    /// # Arguments
    /// * `data` - Input data (n_samples × n_features)
    ///
    /// # Returns
    /// * Transformed data (n_samples × n_components)
    /// * Principal components (n_components × n_features)
    /// * Explained variance
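    ///
    /// # Examples
    ///
    /// A minimal sketch (the `kizzasi_io` import path is assumed from this
    /// file's location, hence the `ignore` attribute):
    ///
    /// ```ignore
    /// use kizzasi_io::signal::separation::PCA;
    /// use scirs2_core::ndarray::arr2;
    ///
    /// let data = arr2(&[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]);
    /// let (transformed, components, explained_var) =
    ///     PCA::new(1).fit_transform(&data).expect("PCA succeeds");
    /// assert_eq!(transformed.dim(), (3, 1));
    /// assert_eq!(components.dim(), (1, 2));
    /// assert_eq!(explained_var.len(), 1);
    /// ```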
    pub fn fit_transform(
        &self,
        data: &Array2<f32>,
    ) -> IoResult<(Array2<f32>, Array2<f32>, Vec<f32>)> {
        let (n_samples, n_features) = data.dim();

        if self.n_components > n_features {
            return Err(IoError::SignalError(
                "n_components cannot exceed n_features".into(),
            ));
        }

        // Center the data
        let mean = data
            .mean_axis(Axis(0))
            .expect("Mean axis computation must succeed");
        let centered = data - &mean.view().insert_axis(Axis(0));

        // Compute covariance matrix (population normalization, 1/n)
        let cov = centered.t().dot(&centered) / n_samples as f32;

        // Eigendecomposition
        let (eigenvalues, eigenvectors) = FastICA::simple_eig(&cov)?;

        // Sort by eigenvalue (descending)
        let mut indices: Vec<usize> = (0..eigenvalues.len()).collect();
        indices.sort_by(|&a, &b| {
            eigenvalues[b]
                .partial_cmp(&eigenvalues[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        // Select top n_components
        let mut components = Array2::zeros((self.n_components, n_features));
        let mut explained_var = Vec::new();

        for i in 0..self.n_components {
            let idx = indices[i];
            explained_var.push(eigenvalues[idx]);

            for j in 0..n_features {
                components[[i, j]] = eigenvectors[[j, idx]];
            }
        }

        // Transform data
        let transformed = centered.dot(&components.t());

        Ok((transformed, components, explained_var))
    }

    /// Transform new data using fitted components
    ///
    /// Note: this recenters with the mean of `data` itself; a stateful
    /// implementation would store and reuse the training mean.
    pub fn transform(&self, data: &Array2<f32>, components: &Array2<f32>) -> Array2<f32> {
        // Center using this batch's own mean (see note above)
        let mean = data
            .mean_axis(Axis(0))
            .expect("Mean axis computation must succeed");
        let centered = data - &mean.view().insert_axis(Axis(0));

        // Project onto components
        centered.dot(&components.t())
    }

    /// Inverse transform (reconstruct from components)
    pub fn inverse_transform(
        &self,
        transformed: &Array2<f32>,
        components: &Array2<f32>,
    ) -> Array2<f32> {
        transformed.dot(components)
    }
}

/// Temporal decorrelation for source separation
pub struct TemporalDecorrelation {
    /// Time delay for decorrelation
    tau: usize,
}

impl TemporalDecorrelation {
    /// Create a new temporal decorrelation analyzer
    pub fn new(tau: usize) -> Self {
        Self { tau }
    }

    /// Separate sources using temporal structure
    ///
    /// Exploits temporal correlations in source signals
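    ///
    /// # Examples
    ///
    /// A minimal sketch (the `kizzasi_io` import path is assumed from this
    /// file's location, hence the `ignore` attribute):
    ///
    /// ```ignore
    /// use kizzasi_io::signal::separation::TemporalDecorrelation;
    /// use scirs2_core::ndarray::arr2;
    ///
    /// let mixed = arr2(&[[1.0, 2.0], [2.0, 1.0], [3.0, 2.0], [2.0, 3.0]]);
    /// let separated = TemporalDecorrelation::new(1)
    ///     .separate(&mixed)
    ///     .expect("separation succeeds");
    /// assert_eq!(separated.dim(), (4, 2));
    /// ```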
    pub fn separate(&self, mixed: &Array2<f32>) -> IoResult<Array2<f32>> {
        let (n_samples, n_channels) = mixed.dim();

        if n_samples <= self.tau {
            return Err(IoError::SignalError("Insufficient samples".into()));
        }

        // Compute time-delayed covariance
        let mut cov_delay = Array2::zeros((n_channels, n_channels));

        for i in 0..(n_samples - self.tau) {
            for j in 0..n_channels {
                for k in 0..n_channels {
                    cov_delay[[j, k]] += mixed[[i, j]] * mixed[[i + self.tau, k]];
                }
            }
        }

        cov_delay /= (n_samples - self.tau) as f32;

        // Symmetrize: the delayed covariance is not symmetric in general,
        // and simple_eig assumes a symmetric input
        let cov_sym = (&cov_delay + &cov_delay.t()) / 2.0;

        // Eigendecomposition of the symmetrized time-delayed covariance
        let (_, eigenvectors) = FastICA::simple_eig(&cov_sym)?;

        // Separate sources
        let separated = mixed.dot(&eigenvectors);

        Ok(separated)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr2;

    #[test]
    fn test_fastica_basic() {
        // Create simple mixed signals
        let mixed = arr2(&[[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0]]);

        let ica = FastICA::new(2, None, None);
        let result = ica.fit_transform(&mixed);

        assert!(result.is_ok());
        let (sources, unmixing) = result.unwrap();
        assert_eq!(sources.dim(), (4, 2));
        assert_eq!(unmixing.dim(), (2, 2));
    }
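
    // Guard test: fit_transform should reject requests for more components
    // than there are mixed signals.
    #[test]
    fn test_fastica_too_many_components() {
        let mixed = arr2(&[[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]]);

        let ica = FastICA::new(3, None, None);
        assert!(ica.fit_transform(&mixed).is_err());
    }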

    #[test]
    fn test_nmf_basic() {
        // Create non-negative matrix
        let v = arr2(&[[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]);

        let nmf = NMF::new(2, Some(50), None);
        let result = nmf.fit_transform(&v);

        assert!(result.is_ok());
        let (w, h) = result.unwrap();
        assert_eq!(w.dim(), (3, 2));
        assert_eq!(h.dim(), (2, 3));

        // Check non-negativity
        assert!(w.iter().all(|&x| x >= 0.0));
        assert!(h.iter().all(|&x| x >= 0.0));
    }

    #[test]
    fn test_pca_basic() {
        let data = arr2(&[
            [1.0, 2.0, 3.0],
            [2.0, 3.0, 4.0],
            [3.0, 4.0, 5.0],
            [4.0, 5.0, 6.0],
        ]);

        let pca = PCA::new(2);
        let result = pca.fit_transform(&data);

        assert!(result.is_ok());
        let (transformed, components, explained_var) = result.unwrap();
        assert_eq!(transformed.dim(), (4, 2));
        assert_eq!(components.dim(), (2, 3));
        assert_eq!(explained_var.len(), 2);
    }
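
    // Components are selected in descending eigenvalue order, so the
    // reported variances should be non-increasing.
    #[test]
    fn test_pca_explained_variance_sorted() {
        let data = arr2(&[
            [1.0, 10.0, 3.0],
            [2.0, 30.0, 4.0],
            [3.0, 20.0, 5.0],
            [4.0, 40.0, 6.0],
        ]);

        let pca = PCA::new(3);
        let (_, _, explained_var) = pca.fit_transform(&data).unwrap();
        assert!(explained_var.windows(2).all(|pair| pair[0] >= pair[1]));
    }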

    #[test]
    fn test_temporal_decorrelation() {
        let mixed = arr2(&[[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0]]);

        let td = TemporalDecorrelation::new(1);
        let result = td.separate(&mixed);

        assert!(result.is_ok());
        let separated = result.unwrap();
        assert_eq!(separated.dim(), (5, 2));
    }
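
    // Guard test: a delay of tau requires strictly more than tau samples.
    #[test]
    fn test_temporal_decorrelation_insufficient_samples() {
        let mixed = arr2(&[[1.0, 2.0], [2.0, 3.0]]);

        let td = TemporalDecorrelation::new(2);
        assert!(td.separate(&mixed).is_err());
    }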

    #[test]
    fn test_nmf_negative_input() {
        let v = arr2(&[[1.0, -2.0], [2.0, 3.0]]);

        let nmf = NMF::new(2, None, None);
        let result = nmf.fit_transform(&v);

        assert!(result.is_err());
    }

    #[test]
    fn test_pca_reconstruction() {
        let data = arr2(&[[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]);

        let pca = PCA::new(2);
        let (transformed, components, _) = pca.fit_transform(&data).unwrap();

        let reconstructed = pca.inverse_transform(&transformed, &components);
        assert_eq!(reconstructed.nrows(), 3);
    }
}