// s2gpp 1.0.2
//
// Algorithm for Highly Efficient Detection of Correlation Anomalies in Multivariate Time Series
// Documentation
use anyhow::{Error, Result};
use ndarray::{arr1, stack, Array, Array1, ArrayBase, ArrayView, Axis, Data, Dim};

/// Extension trait for gathering several positions of an array at once.
///
/// `A` is the element type and `D` the dimension type of the array that
/// [`IndexArr::get_multiple`] returns.
pub trait IndexArr<A, D> {
    /// Builds a new owned `Array<A, D>` from the entries selected by `indices`.
    ///
    /// `axis` names the axis the indices refer to; how it is interpreted and
    /// which conditions yield an `Err` is decided by each implementation.
    fn get_multiple(&self, indices: Array1<usize>, axis: Axis) -> Result<Array<A, D>>;
}

impl<A, S> IndexArr<A, Dim<[usize; 2]>> for ArrayBase<S, Dim<[usize; 2]>>
where
    A: Copy,
    S: Data<Elem = A>,
{
    fn get_multiple(
        &self,
        indices: Array1<usize>,
        axis: Axis,
    ) -> Result<Array<A, Dim<[usize; 2]>>> {
        let indexed_vec: Vec<ArrayView<_, _>> = indices
            .to_vec()
            .into_iter()
            .map(|index| self.index_axis(axis, index))
            .collect();
        Ok(stack(axis, indexed_vec.as_slice())?)
    }
}

impl<A, S> IndexArr<A, Dim<[usize; 1]>> for ArrayBase<S, Dim<[usize; 1]>>
where
    A: Clone,
    S: Data<Elem = A>,
{
    fn get_multiple(
        &self,
        indices: Array1<usize>,
        _axis: Axis,
    ) -> Result<Array<A, Dim<[usize; 1]>>> {
        let indexed_vec = indices
            .to_vec()
            .into_iter()
            .map(|index| {
                self.get(index)
                    .ok_or_else(|| Error::msg(format!("Index {} out of bounds", index)))
            })
            .map(|x| x.map(|x| (*x).clone()))
            .collect::<Result<Vec<A>, _>>()?;
        Ok(arr1(indexed_vec.as_slice()))
    }
}

#[cfg(test)]
mod tests {
    use super::IndexArr;
    use ndarray::{arr1, arr2, Axis};

    /// Selecting two positions of a vector yields those two elements in order.
    #[test]
    fn array1_get_multiple() {
        let source = arr1(&[2., 4., 8., 16.]);
        let picked = source.get_multiple(arr1(&[1, 3]), Axis(0)).unwrap();
        assert_eq!(picked, arr1(&[4., 16.]));
    }

    /// Selecting along axis 0 of a matrix keeps the chosen rows.
    #[test]
    fn array2_get_multiple_axis0() {
        let source = arr2(&[[2.], [4.], [8.], [16.]]);
        let picked = source.get_multiple(arr1(&[1, 3]), Axis(0)).unwrap();
        assert_eq!(picked, arr2(&[[4.], [16.]]));
    }

    /// Selecting along axis 1 of a matrix keeps the chosen columns.
    #[test]
    fn array2_get_multiple_axis1() {
        let source = arr2(&[[2., 4., 8., 16.], [2., 4., 8., 16.]]);
        let picked = source.get_multiple(arr1(&[0, 2]), Axis(1)).unwrap();
        assert_eq!(picked, arr2(&[[2., 8.], [2., 8.]]));
    }
}