s2gpp 1.0.2

Algorithm for Highly Efficient Detection of Correlation Anomalies in Multivariate Time Series
Documentation
use anyhow::{Error, Result};
use ndarray::{Array, ArrayBase, Axis, Data, Dim};
use num_traits::Float;
use std::fmt::Debug;
use std::ops::Div;

/// Sum-based normalisation for `ndarray` arrays with element type `A` and
/// dimensionality `D`.
///
/// Despite the name, the implementations in this file apply no exponential:
/// they divide the elements by a sum of elements, so the normalised values
/// add up to one whenever that sum is non-zero.
pub(crate) trait Softmax<A: Copy + Float + Debug, D> {
    /// Returns a new array normalised over the whole array.
    fn softmax(&self) -> Result<Array<A, D>>;

    /// Returns a new array normalised separately along `axis`.
    fn softmax_axis(&self, axis: Axis) -> Result<Array<A, D>>;
}

impl<A, S> Softmax<A, Dim<[usize; 1]>> for ArrayBase<S, Dim<[usize; 1]>>
where
    A: Copy + Float + Debug,
    S: Data<Elem = A>,
{
    fn softmax(&self) -> Result<Array<A, Dim<[usize; 1]>>> {
        let max = self.sum();
        Ok(self.mapv(|x| x.div(max)))
    }

    fn softmax_axis(&self, axis: Axis) -> Result<Array<A, Dim<[usize; 1]>>> {
        let max = self.sum_axis(axis).insert_axis(axis);
        let shape = self.shape();
        Ok(self.div(
            &max.broadcast([shape[0]])
                .ok_or_else(|| Error::msg("Could not broadcast max"))?,
        ))
    }
}

impl<A, S> Softmax<A, Dim<[usize; 2]>> for ArrayBase<S, Dim<[usize; 2]>>
where
    A: Copy + Float + Debug,
    S: Data<Elem = A>,
{
    fn softmax(&self) -> Result<Array<A, Dim<[usize; 2]>>> {
        let max = self.sum();
        Ok(self.mapv(|x| x.div(max)))
    }

    fn softmax_axis(&self, axis: Axis) -> Result<Array<A, Dim<[usize; 2]>>> {
        let max = self.sum_axis(axis).insert_axis(axis);
        let shape = self.shape();
        Ok(self.div(
            &max.broadcast([shape[0], shape[1]])
                .ok_or_else(|| Error::msg("Could not broadcast max"))?,
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::Softmax;
    use ndarray::{arr1, arr2, Axis};

    /// Whole-array normalisation of a vector: each entry is divided by the
    /// total (1 + 2 + 1 = 4).
    #[test]
    fn correct_result_1d() {
        let input = arr1(&[1., 2., 1.]);
        let normalized = input.softmax().unwrap();
        assert_eq!(normalized, arr1(&[0.25, 0.5, 0.25]));
    }

    /// Per-axis normalisation of a single-row matrix along `Axis(1)`: the row
    /// is divided by its own sum.
    #[test]
    fn correct_result_2d() {
        let input = arr2(&[[1., 2., 1.]]);
        let normalized = input.softmax_axis(Axis(1)).unwrap();
        assert_eq!(normalized, arr2(&[[0.25, 0.5, 0.25]]));
    }
}