use linfa::{
dataset::{DatasetBase, Records, WithLapack, WithoutLapack},
traits::*,
Float,
};
#[cfg(not(feature = "blas"))]
use linfa_linalg::{eigh::*, svd::*};
use ndarray::{Array, Array1, Array2, ArrayBase, Axis, Data, Ix2};
#[cfg(feature = "blas")]
use ndarray_linalg::{eigh::Eigh, solveh::UPLO, svd::SVD};
use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};
use ndarray_stats::QuantileExt;
use rand_xoshiro::Xoshiro256Plus;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use crate::error::{FastIcaError, Result};
use crate::hyperparams::FastIcaValidParams;
impl<F: Float, D: Data<Elem = F>, T> Fit<ArrayBase<D, Ix2>, T, FastIcaError>
for FastIcaValidParams<F>
{
type Object = FastIca<F>;
fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<Self::Object> {
let x = &dataset.records;
let (nsamples, nfeatures) = (x.nsamples(), x.nfeatures());
if dataset.nsamples() == 0 {
return Err(FastIcaError::NotEnoughSamples);
}
let ncomponents = self
.ncomponents()
.unwrap_or_else(|| nsamples.min(nfeatures));
if ncomponents > nsamples.min(nfeatures) {
return Err(FastIcaError::InvalidValue(format!(
"ncomponents cannot be greater than the min({nsamples}, {nfeatures}), got {ncomponents}"
)));
}
let xmean = x.mean_axis(Axis(0)).unwrap();
let mut xcentered = x - &xmean.view().insert_axis(Axis(0));
xcentered = xcentered.reversed_axes();
let xcentered = xcentered.with_lapack();
let k = match xcentered.svd(true, false)? {
(Some(u), s, _) => {
let s = s.mapv(F::Lapack::cast);
(u.slice_move(s![.., ..nsamples.min(nfeatures)]) / s)
.t()
.slice(s![..ncomponents, ..])
.to_owned()
}
_ => return Err(FastIcaError::SvdDecomposition),
};
let mut xwhitened = k.dot(&xcentered).without_lapack();
let k = k.without_lapack();
let nsamples_sqrt = F::cast(nsamples).sqrt();
xwhitened.mapv_inplace(|x| x * nsamples_sqrt);
let w: Array2<f64>;
if let Some(seed) = self.random_state() {
let mut rng = Xoshiro256Plus::seed_from_u64(*seed as u64);
w = Array::random_using((ncomponents, ncomponents), Uniform::new(0., 1.), &mut rng);
} else {
w = Array::random((ncomponents, ncomponents), Uniform::new(0., 1.));
}
let mut w = w.mapv(F::cast);
w = self.ica_parallel(&xwhitened, &w)?;
let components = w.dot(&k);
Ok(FastIca {
mean: xmean,
components,
})
}
}
impl<F: Float> FastIcaValidParams<F> {
fn ica_parallel(&self, x: &Array2<F>, w: &Array2<F>) -> Result<Array2<F>> {
let mut w = Self::sym_decorrelation(w)?;
let p = x.ncols() as f64;
for _ in 0..self.max_iter() {
let (gwtx, g_wtx) = self.gfunc().exec(&w.dot(x))?;
let lhs = gwtx.dot(&x.t()).mapv(|x| x / F::cast(p));
let rhs = &w * &g_wtx.insert_axis(Axis(1));
let wnew = Self::sym_decorrelation(&(lhs - rhs))?;
let lim = *wnew
.outer_iter()
.zip(w.outer_iter())
.map(|(a, b)| a.dot(&b))
.collect::<Array1<F>>()
.mapv(|x| x.abs())
.mapv(|x| x - F::cast(1.))
.mapv(|x| x.abs())
.max()
.unwrap();
w = wnew;
if lim < F::cast(self.tol()) {
break;
}
}
Ok(w)
}
fn sym_decorrelation(w: &Array2<F>) -> Result<Array2<F>> {
#[cfg(feature = "blas")]
let (eig_val, eig_vec) = w.dot(&w.t()).with_lapack().eigh(UPLO::Upper)?;
#[cfg(not(feature = "blas"))]
let (eig_val, eig_vec) = w.dot(&w.t()).with_lapack().eigh()?;
let eig_val = eig_val.mapv(F::cast);
let eig_vec = eig_vec.without_lapack();
let tmp = &eig_vec
* &(eig_val.mapv(|x| x.sqrt()).mapv(|x| {
let lower_bound = F::cast(1e-7);
if x < lower_bound {
return lower_bound.recip();
}
x.recip()
}))
.insert_axis(Axis(0));
Ok(tmp.dot(&eig_vec.t()).dot(w))
}
}
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq)]
pub struct FastIca<F> {
mean: Array1<F>,
components: Array2<F>,
}
impl<F: Float> PredictInplace<Array2<F>, Array2<F>> for FastIca<F> {
fn predict_inplace(&self, x: &Array2<F>, y: &mut Array2<F>) {
assert_eq!(
y.shape(),
&[x.nrows(), self.components.nrows()],
"The number of data points must match the number of output targets."
);
let xcentered = x - &self.mean.view().insert_axis(Axis(0));
*y = xcentered.dot(&self.components.t());
}
fn default_target(&self, x: &Array2<F>) -> Array2<F> {
Array2::zeros((x.nrows(), self.components.nrows()))
}
}
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq, Copy)]
pub enum GFunc {
Logcosh(f64),
Exp,
Cube,
}
impl GFunc {
fn exec<A: Float>(&self, x: &Array2<A>) -> Result<(Array2<A>, Array1<A>)> {
match self {
Self::Cube => Ok(Self::cube(x)),
Self::Exp => Ok(Self::exp(x)),
Self::Logcosh(alpha) => Self::logcosh(x, *alpha),
}
}
fn cube<A: Float>(x: &Array2<A>) -> (Array2<A>, Array1<A>) {
(
x.mapv(|x| x.powi(3)),
x.mapv(|x| A::cast(3.) * x.powi(2))
.mean_axis(Axis(1))
.unwrap(),
)
}
fn exp<A: Float>(x: &Array2<A>) -> (Array2<A>, Array1<A>) {
let exp = x.mapv(|x| (-x.powi(2) / A::cast(2.)).exp());
(
x * &exp,
(x.mapv(|x| A::cast(1.) - x.powi(2)) * &exp)
.mean_axis(Axis(1))
.unwrap(),
)
}
fn logcosh<A: Float>(x: &Array2<A>, alpha: f64) -> Result<(Array2<A>, Array1<A>)> {
if !(1.0..=2.0).contains(&alpha) {
return Err(FastIcaError::InvalidValue(format!(
"alpha must be between 1 and 2 inclusive, got {alpha}"
)));
}
let alpha = A::cast(alpha);
let gx = x.mapv(|x| (x * alpha).tanh());
let g_x = gx.mapv(|x| alpha * (A::cast(1.) - x.powi(2)));
Ok((gx, g_x.mean_axis(Axis(1)).unwrap()))
}
}
#[cfg(test)]
mod tests {
use super::*;
use linfa::traits::{Fit, Predict};
use crate::hyperparams::{FastIcaParams, FastIcaValidParams};
use ndarray_rand::rand_distr::StudentT;
#[test]
fn autotraits() {
fn has_autotraits<T: Send + Sync + Sized + Unpin>() {}
has_autotraits::<FastIca<f64>>();
has_autotraits::<GFunc>();
has_autotraits::<FastIcaParams<f64>>();
has_autotraits::<FastIcaValidParams<f64>>();
has_autotraits::<FastIcaError>();
}
#[test]
fn test_ncomponents_err() {
let input = DatasetBase::from(Array::random((4, 4), Uniform::new(0.0, 1.0)));
let ica = FastIca::params().ncomponents(100);
let ica = ica.fit(&input);
assert!(ica.is_err());
}
#[test]
fn test_logcosh_alpha_err() {
let input = DatasetBase::from(Array::random((4, 4), Uniform::new(0.0, 1.0)));
let ica = FastIca::params().gfunc(GFunc::Logcosh(10.));
let ica = ica.fit(&input);
assert!(ica.is_err());
}
macro_rules! fast_ica_tests {
($($name:ident: $gfunc:expr,)*) => {
paste::item! {
$(
#[test]
fn [<test_fast_ica_$name>]() {
test_fast_ica($gfunc);
}
)*
}
}
}
fast_ica_tests! {
exp: GFunc::Exp, cube: GFunc::Cube, logcosh: GFunc::Logcosh(1.0),
}
fn test_fast_ica(gfunc: GFunc) {
let nsamples = 1000;
let center_and_norm = |s: &mut Array2<f64>| {
let mean = s.mean_axis(Axis(0)).unwrap();
*s -= &mean.insert_axis(Axis(0));
let std = s.std_axis(Axis(0), 0.);
*s /= &std.insert_axis(Axis(0));
};
let mut source1 = Array::linspace(0., 100., nsamples);
source1.mapv_inplace(|x| {
let tmp = 2. * f64::sin(x);
if tmp > 0. {
return 0.;
}
-1.
});
let mut rng = Xoshiro256Plus::seed_from_u64(42);
let source2 = Array::random_using((nsamples, 1), StudentT::new(1.0).unwrap(), &mut rng);
let mut sources = concatenate![Axis(1), source1.insert_axis(Axis(1)), source2];
center_and_norm(&mut sources);
let phi: f64 = 0.6;
let mixing = array![[phi.cos(), phi.sin()], [phi.sin(), -phi.cos()]];
sources = mixing.dot(&sources.t());
center_and_norm(&mut sources);
sources = sources.reversed_axes();
let ica = FastIca::params()
.ncomponents(2)
.gfunc(gfunc)
.random_state(42);
let sources_dataset = DatasetBase::from(sources.view());
let ica = ica.fit(&sources_dataset).unwrap();
let mut output = ica.predict(&sources);
center_and_norm(&mut output);
assert_eq!(output.shape(), &[1000, 2]);
let s1 = sources.column(0);
let s2 = sources.column(1);
let mut s1_ = output.column(0);
let mut s2_ = output.column(1);
if s1_.dot(&s2).abs() > s1_.dot(&s1).abs() {
s1_ = output.column(1);
s2_ = output.column(0);
}
let similarity1 = s1.dot(&s1_).abs() / (nsamples as f64);
let similarity2 = s2.dot(&s2_).abs() / (nsamples as f64);
assert!(similarity1.max(similarity2) > 0.9);
}
}