use linfa::{
dataset::{DatasetBase, Targets},
traits::*,
Float,
};
use ndarray::{Array, Array1, Array2, ArrayBase, Axis, Data, Ix2};
use ndarray_linalg::{eigh::Eigh, lapack::UPLO, svd::SVD, Lapack};
use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};
use ndarray_stats::QuantileExt;
use rand_isaac::Isaac64Rng;
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use crate::error::{FastIcaError, Result};
/// Configuration of the FastICA estimator; the values stored here are read by `fit`.
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Debug)]
pub struct FastIca<F: Float> {
/// Number of components to extract; when `None`, `fit` falls back to
/// `min(nsamples, nfeatures)` of the dataset.
ncomponents: Option<usize>,
/// Non-linearity evaluated (together with its mean derivative) on every iteration.
gfunc: GFunc,
/// Upper bound on the number of fixed-point iterations.
max_iter: usize,
/// Convergence threshold: iteration stops once the largest value of
/// `| |<w_new_i, w_i>| - 1 |` over the rows of the unmixing matrix drops below it.
tol: F,
/// Seed for the random initial unmixing matrix; when `None` the matrix is drawn
/// with `Array::random` instead of a seeded `Isaac64Rng`.
random_state: Option<usize>,
}
impl<F: Float> Default for FastIca<F> {
fn default() -> Self {
Self::new()
}
}
/// Constructor and consuming builder-style setters for [`FastIca`].
impl<F: Float> FastIca<F> {
    /// Creates an estimator with the default configuration:
    /// no explicit component count, `GFunc::Logcosh(1.)`, at most 200 iterations,
    /// a tolerance of `1e-4` and no fixed random seed.
    pub fn new() -> Self {
        Self {
            random_state: None,
            tol: F::from(1e-4).unwrap(),
            max_iter: 200,
            gfunc: GFunc::Logcosh(1.),
            ncomponents: None,
        }
    }

    /// Sets the number of components to extract and returns the updated estimator.
    pub fn ncomponents(self, ncomponents: usize) -> Self {
        Self {
            ncomponents: Some(ncomponents),
            ..self
        }
    }

    /// Sets the non-linearity used by the iteration and returns the updated estimator.
    pub fn gfunc(self, gfunc: GFunc) -> Self {
        Self { gfunc, ..self }
    }

    /// Sets the maximum number of iterations and returns the updated estimator.
    pub fn max_iter(self, max_iter: usize) -> Self {
        Self { max_iter, ..self }
    }

    /// Sets the convergence tolerance and returns the updated estimator.
    pub fn tol(self, tol: F) -> Self {
        Self { tol, ..self }
    }

    /// Sets the seed used for the initial unmixing matrix and returns the updated estimator.
    pub fn random_state(self, random_state: usize) -> Self {
        Self {
            random_state: Some(random_state),
            ..self
        }
    }
}
impl<'a, F: Float + Lapack, D: Data<Elem = F>, T: Targets> Fit<'a, ArrayBase<D, Ix2>, T>
    for FastIca<F>
{
    type Object = Result<FittedFastIca<F>>;

    /// Fits the model to `dataset.records`, a `(nsamples, nfeatures)` matrix.
    ///
    /// Steps: centre the records, whiten them with an SVD-derived matrix, draw a
    /// random `(ncomponents, ncomponents)` starting matrix (seeded when
    /// `random_state` is set), refine it with `ica_parallel`, and store the
    /// product of the refined matrix and the whitening matrix as the components.
    ///
    /// # Errors
    ///
    /// * `FastIcaError::InvalidValue` when the requested number of components is
    ///   larger than `min(nsamples, nfeatures)` or when it is zero (which also
    ///   covers datasets without samples or without features).
    /// * `FastIcaError::SvdDecomposition` when the SVD does not yield `U`.
    /// * Any error propagated from the SVD itself or from `ica_parallel`.
    fn fit(&self, dataset: &DatasetBase<ArrayBase<D, Ix2>, T>) -> Result<FittedFastIca<F>> {
        let x = &dataset.records;
        let (nsamples, nfeatures) = (x.nrows(), x.ncols());
        let max_components = nsamples.min(nfeatures);
        let ncomponents = self.ncomponents.unwrap_or(max_components);

        if ncomponents > max_components {
            return Err(FastIcaError::InvalidValue(format!(
                "ncomponents cannot be greater than the min({}, {}), got {}",
                nsamples, nfeatures, ncomponents
            )));
        }
        // Zero components means an empty dataset or an explicit `ncomponents(0)`.
        // Without this guard the mean below would be unwrapped on an empty axis
        // (panic) or the decomposition would run on empty matrices.
        if ncomponents == 0 {
            return Err(FastIcaError::InvalidValue(format!(
                "ncomponents must be at least 1, got 0 (dataset has {} samples and {} features)",
                nsamples, nfeatures
            )));
        }

        // Both dimensions are non-zero here, so the mean along the sample axis exists.
        let xmean = x.mean_axis(Axis(0)).unwrap();
        // Centre every feature, then work with a (nfeatures, nsamples) layout.
        let xcentered = (x - &xmean.view().insert_axis(Axis(0))).reversed_axes();

        // Whitening matrix: the leading left singular vectors, each scaled by the
        // inverse of its singular value, restricted to `ncomponents` rows.
        let k = match xcentered.svd(true, false)? {
            (Some(u), s, _) => {
                let s = s.mapv(|x| F::from(x).unwrap());
                (u.slice(s![.., ..max_components]).to_owned() / s)
                    .t()
                    .slice(s![..ncomponents, ..])
                    .to_owned()
            }
            _ => return Err(FastIcaError::SvdDecomposition),
        };

        let mut xwhitened = k.dot(&xcentered);
        let nsamples_sqrt = F::from((nsamples as f64).sqrt()).unwrap();
        xwhitened.mapv_inplace(|x| x * nsamples_sqrt);

        // Random starting point for the unmixing matrix, reproducible when seeded.
        let w_init: Array2<f64> = match self.random_state {
            Some(seed) => {
                let mut rng = Isaac64Rng::seed_from_u64(seed as u64);
                Array::random_using((ncomponents, ncomponents), Uniform::new(0., 1.), &mut rng)
            }
            None => Array::random((ncomponents, ncomponents), Uniform::new(0., 1.)),
        };
        let w_init = w_init.mapv(|x| F::from(x).unwrap());

        let w = self.ica_parallel(&xwhitened, &w_init)?;
        let components = w.dot(&k);

        Ok(FittedFastIca {
            mean: xmean,
            components,
        })
    }
}
impl<F: Float + Lapack> FastIca<F> {
fn ica_parallel(&self, x: &Array2<F>, w: &Array2<F>) -> Result<Array2<F>> {
let mut w = Self::sym_decorrelation(&w)?;
let p = x.ncols() as f64;
for _ in 0..self.max_iter {
let (gwtx, g_wtx) = self.gfunc.exec(&w.dot(x))?;
let lhs = gwtx.dot(&x.t()).mapv(|x| x / F::from(p).unwrap());
let rhs = &w * &g_wtx.insert_axis(Axis(1));
let wnew = Self::sym_decorrelation(&(lhs - rhs))?;
let lim = *wnew
.outer_iter()
.zip(w.outer_iter())
.map(|(a, b)| a.dot(&b))
.collect::<Array1<F>>()
.mapv(num_traits::Float::abs)
.mapv(|x| x - F::from(1.).unwrap())
.mapv(num_traits::Float::abs)
.max()
.unwrap();
w = wnew;
if lim < F::from(self.tol).unwrap() {
break;
}
}
Ok(w)
}
fn sym_decorrelation(w: &Array2<F>) -> Result<Array2<F>> {
let (eig_val, eig_vec) = w.dot(&w.t()).eigh(UPLO::Upper)?;
let eig_val = eig_val.mapv(|x| F::from(x).unwrap());
let tmp = &eig_vec
* &(eig_val.mapv(num_traits::Float::sqrt).mapv(|x| {
let lower_bound = F::from(1e-7).unwrap();
if x < lower_bound {
return num_traits::Float::recip(lower_bound);
}
num_traits::Float::recip(x)
}))
.insert_axis(Axis(0));
Ok(tmp.dot(&eig_vec.t()).dot(w))
}
}
/// Result of fitting [`FastIca`]; used by `predict` to transform new records.
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Debug)]
pub struct FittedFastIca<F> {
/// Per-feature mean of the training records; subtracted from the input in `predict`.
mean: Array1<F>,
/// Matrix whose transpose is applied to the centred input in `predict`;
/// `fit` stores the product of the refined unmixing matrix and the whitening matrix here.
components: Array2<F>,
}
impl<F: Float> Predict<&Array2<F>, Array2<F>> for FittedFastIca<F> {
    /// Transforms `x` by subtracting the stored mean from every row and
    /// multiplying the centred records with the transposed component matrix.
    fn predict(&self, x: &Array2<F>) -> Array2<F> {
        // View the mean as a single row so it broadcasts over all records.
        let mean_row = self.mean.view().insert_axis(Axis(0));
        let components_t = self.components.t();
        (x - &mean_row).dot(&components_t)
    }
}
/// Non-linearity used by the FastICA iteration.
///
/// The type is a small plain value, so it implements `Clone`, `Copy` and
/// `PartialEq` in addition to `Debug`; this lets callers reuse and compare
/// a chosen function without rebuilding it.
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum GFunc {
    /// `tanh(alpha * x)`; the wrapped value is `alpha`, which must lie in `[1, 2]`
    /// (checked when the function is evaluated).
    Logcosh(f64),
    /// Exponential-based non-linearity.
    Exp,
    /// Cubic non-linearity `x^3`.
    Cube,
}
/// Evaluation of the non-linearities.
///
/// Every function returns a pair: the non-linearity applied element-wise to `x`,
/// and the mean of its derivative along `Axis(1)` (one value per row of `x`).
impl GFunc {
    /// Dispatches to the implementation selected by `self`.
    ///
    /// # Errors
    ///
    /// Returns `FastIcaError::InvalidValue` for `Logcosh` when `alpha` is outside `[1, 2]`.
    fn exec<A: Float>(&self, x: &Array2<A>) -> Result<(Array2<A>, Array1<A>)> {
        match self {
            Self::Cube => Ok(Self::cube(x)),
            Self::Exp => Ok(Self::exp(x)),
            Self::Logcosh(alpha) => Self::logcosh(x, *alpha),
        }
    }

    /// `g(x) = x^3`, `g'(x) = 3 x^2`.
    fn cube<A: Float>(x: &Array2<A>) -> (Array2<A>, Array1<A>) {
        (
            x.mapv(|x| x.powi(3)),
            x.mapv(|x| A::from(3.).unwrap() * x.powi(2))
                .mean_axis(Axis(1))
                .unwrap(),
        )
    }

    /// `g(x) = x * exp(-x^2 / 2)`, `g'(x) = (1 - x^2) * exp(-x^2 / 2)`.
    fn exp<A: Float>(x: &Array2<A>) -> (Array2<A>, Array1<A>) {
        let one = A::from(1.).unwrap();
        let two = A::from(2.).unwrap();
        // Gaussian factor shared by the function and its derivative. The
        // exponential itself must be applied; `-x^2 / 2` alone is only its argument.
        let gauss = x.mapv(|v| (-v.powi(2) / two).exp());
        (
            x * &gauss,
            (x.mapv(|v| one - v.powi(2)) * &gauss)
                .mean_axis(Axis(1))
                .unwrap(),
        )
    }

    /// `g(x) = tanh(alpha * x)`, `g'(x) = alpha * (1 - tanh^2(alpha * x))`.
    ///
    /// # Errors
    ///
    /// Returns `FastIcaError::InvalidValue` when `alpha` is not within `[1, 2]`.
    fn logcosh<A: Float>(x: &Array2<A>, alpha: f64) -> Result<(Array2<A>, Array1<A>)> {
        if !(1.0..=2.0).contains(&alpha) {
            return Err(FastIcaError::InvalidValue(format!(
                "alpha must be between 1 and 2 inclusive, got {}",
                alpha
            )));
        }
        let alpha = A::from(alpha).unwrap();
        let gx = x.mapv(|x| (x * alpha).tanh());
        let g_x = gx.mapv(|x| alpha * (A::from(1.).unwrap() - x.powi(2)));
        Ok((gx, g_x.mean_axis(Axis(1)).unwrap()))
    }
}
// Unit tests: parameter validation plus one signal-separation check per non-linearity.
#[cfg(test)]
mod tests {
use super::*;
use linfa::traits::{Fit, Predict};
use ndarray_rand::rand_distr::StudentT;
// Requesting more components than min(nsamples, nfeatures) must make `fit` fail.
#[test]
fn test_ncomponents_err() {
let input = DatasetBase::from(Array::random((4, 4), Uniform::new(0.0, 1.0)));
let ica = FastIca::new().ncomponents(100);
let ica = ica.fit(&input);
assert!(ica.is_err());
}
// An `alpha` outside [1, 2] for `Logcosh` must make `fit` fail.
#[test]
fn test_logcosh_alpha_err() {
let input = DatasetBase::from(Array::random((4, 4), Uniform::new(0.0, 1.0)));
let ica = FastIca::new().gfunc(GFunc::Logcosh(10.));
let ica = ica.fit(&input);
assert!(ica.is_err());
}
// Generates one `#[test]` named `test_fast_ica_<name>` per `name: gfunc` pair,
// each of which calls `test_fast_ica` with the given non-linearity.
macro_rules! fast_ica_tests {
($($name:ident: $gfunc:expr,)*) => {
paste::item! {
$(
#[test]
fn [<test_fast_ica_$name>]() {
test_fast_ica($gfunc);
}
)*
}
}
}
fast_ica_tests! {
exp: GFunc::Exp, cube: GFunc::Cube, logcosh: GFunc::Logcosh(1.0),
}
// Builds two synthetic signals, mixes them with a fixed 2x2 matrix, fits the model
// with a fixed seed and checks that the output correlates with the reference columns.
fn test_fast_ica(gfunc: GFunc) {
let nsamples = 1000;
// Subtracts the mean and divides by the standard deviation, both taken along Axis(0).
let center_and_norm = |s: &mut Array2<f64>| {
let mean = s.mean_axis(Axis(0)).unwrap();
*s -= &mean.insert_axis(Axis(0));
let std = s.std_axis(Axis(0), 0.);
*s /= &std.insert_axis(Axis(0));
};
// First signal: a two-level (0 / -1) wave derived from the sign of a sine.
let mut source1 = Array::linspace(0., 100., nsamples);
source1.mapv_inplace(|x| {
let tmp = 2. * f64::sin(x);
if tmp > 0. {
return 0.;
}
-1.
});
// Second signal: seeded Student-t samples, so the test is reproducible.
let mut rng = Isaac64Rng::seed_from_u64(42);
let source2 = Array::random_using((nsamples, 1), StudentT::new(1.0).unwrap(), &mut rng);
let mut sources = stack![Axis(1), source1.insert_axis(Axis(1)), source2];
center_and_norm(&mut sources);
let phi: f64 = 0.6;
let mixing = array![[phi.cos(), phi.sin()], [phi.sin(), -phi.cos()]];
// The product below has shape (2, nsamples).
sources = mixing.dot(&sources.t());
// NOTE(review): this call runs on the (2, nsamples) layout, so the Axis(0) statistics
// are taken across the two mixed signals for each sample rather than across samples;
// confirm that this is intended (the first call normalised per signal).
center_and_norm(&mut sources);
sources = sources.reversed_axes();
let ica = FastIca::new().ncomponents(2).gfunc(gfunc).random_state(42);
let sources_dataset = DatasetBase::from(sources.view());
let ica = ica.fit(&sources_dataset).unwrap();
let mut output = ica.predict(&sources);
center_and_norm(&mut output);
assert_eq!(output.shape(), &[1000, 2]);
// NOTE(review): `sources` holds the mixed signals at this point, so the similarity
// below is measured against the model input, not the original unmixed signals; verify.
let s1 = sources.column(0);
let s2 = sources.column(1);
let mut s1_ = output.column(0);
let mut s2_ = output.column(1);
// The components may come out in either order; swap them if the cross match is stronger.
if s1_.dot(&s2).abs() > s1_.dot(&s1).abs() {
s1_ = output.column(1);
s2_ = output.column(0);
}
let similarity1 = s1.dot(&s1_).abs() / (nsamples as f64);
let similarity2 = s2.dot(&s2_).abs() / (nsamples as f64);
assert!(similarity1.max(similarity2) > 0.9);
}
}