dsalgo 0.3.10

A package for data structures and algorithms.
Documentation
/// A dense N-dimensional array whose rank and extents are chosen at run time.
///
/// Elements are kept in one flat `Vec`; `strides` assigns stride 1 to the last
/// axis, so the last axis varies fastest (row-major order).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]

pub struct DynamicTensor<T> {
    /// Extent of each axis, outermost axis first.
    shape: Vec<usize>,
    /// Flat element storage; `new` allocates one slot per element, i.e. the
    /// product of the extents in `shape`.
    data: Vec<T>,
}

impl<T> DynamicTensor<T> {
    /// Returns the extent of every axis, outermost axis first.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Returns the rank of the tensor, i.e. the number of axes.
    pub(crate) fn dim(&self) -> usize {
        self.shape.len()
    }

    /// Returns the row-major stride of every axis.
    ///
    /// The last axis has stride 1 and each earlier axis has the stride of the
    /// next axis multiplied by that next axis's extent. A rank-0 tensor yields
    /// an empty vector.
    pub(crate) fn strides(&self) -> Vec<usize> {
        let rank = self.dim();

        // Start from all ones: the innermost stride is then already correct
        // and `shape` does not have to be cloned just to size the vector.
        let mut strides = vec![1; rank];

        for axis in (1..rank).rev() {
            strides[axis - 1] = strides[axis] * self.shape[axis];
        }

        strides
    }

    /// Returns the number of elements a tensor of the given `shape` holds.
    ///
    /// This is the product of all extents; an empty `shape` (rank 0) gives 1
    /// and any zero extent gives 0.
    ///
    /// # Panics
    ///
    /// Panics if the product does not fit in `usize`. Checking explicitly
    /// keeps release builds from wrapping to a size that disagrees with the
    /// shape.
    pub(crate) fn compute_size(shape: &[usize]) -> usize {
        // A zero extent makes the tensor empty regardless of the other
        // extents, so it must not be reported as an overflow.
        if shape.contains(&0) {
            return 0;
        }

        shape
            .iter()
            .try_fold(1usize, |acc, &extent| acc.checked_mul(extent))
            .expect("tensor size overflows usize")
    }

    /// Returns the number of elements currently stored.
    pub fn size(&self) -> usize {
        self.data.len()
    }
}

impl<T: Default> DynamicTensor<T> {
    /// Creates a tensor of the given `shape` with every element set to
    /// `T::default()`.
    ///
    /// The element count is taken from `compute_size`, so an empty `shape`
    /// produces a tensor holding exactly one element.
    pub fn new(shape: &[usize]) -> Self {
        let len = Self::compute_size(shape);

        let mut data = Vec::with_capacity(len);
        data.resize_with(len, T::default);

        Self {
            shape: shape.to_vec(),
            data,
        }
    }
}

impl<T> DynamicTensor<T> {
    /// Converts a multi-dimensional `index` into an offset into the flat
    /// row-major storage.
    ///
    /// `index` must hold exactly one coordinate per axis, and every coordinate
    /// must be smaller than the extent of its axis.
    ///
    /// # Panics
    ///
    /// Panics if `index.len()` differs from the tensor's rank, or if any
    /// coordinate is out of range for its axis. The per-axis check matters:
    /// without it an out-of-range coordinate can still produce an in-range
    /// offset and silently address a different element.
    fn flatten_index(&self, index: &[usize]) -> usize {
        assert_eq!(index.len(), self.dim());

        // Horner-style accumulation yields the same offset as summing
        // `stride * coordinate`, without allocating a strides vector per call.
        let mut flat = 0;

        for (axis, (&coord, &extent)) in index.iter().zip(&self.shape).enumerate() {
            assert!(
                coord < extent,
                "index {} is out of bounds for axis {} with extent {}",
                coord,
                axis,
                extent
            );

            flat = flat * extent + coord;
        }

        flat
    }
}

impl<T> std::ops::Index<&[usize]> for DynamicTensor<T> {
    type Output = T;

    /// Borrows the element addressed by `index`, one coordinate per axis.
    ///
    /// # Panics
    ///
    /// Panics if `flatten_index` panics for `index`, or if the offset it
    /// returns lies past the end of the storage.
    fn index(&self, index: &[usize]) -> &Self::Output {
        let offset = self.flatten_index(index);

        &self.data[offset]
    }
}

impl<T> std::ops::IndexMut<&[usize]> for DynamicTensor<T> {
    /// Mutably borrows the element addressed by `index`, one coordinate per
    /// axis.
    ///
    /// # Panics
    ///
    /// Panics if `flatten_index` panics for `index`, or if the offset it
    /// returns lies past the end of the storage.
    fn index_mut(&mut self, index: &[usize]) -> &mut Self::Output {
        // Resolve the offset first so the shared borrow of `self` has ended
        // before the storage is borrowed mutably.
        let offset = self.flatten_index(index);

        &mut self.data[offset]
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: writes and reads back one element of a rank-3 tensor, then
    /// builds a rank-0 tensor; both are printed with their `Debug` output.
    #[test]
    fn test() {
        let mut tensor = DynamicTensor::<i64>::new(&[1, 2, 3]);
        tensor[&[0, 0, 0]] = 1;
        assert_eq!(tensor[&[0, 0, 0]], 1);
        println!("{:?}", tensor);

        // An empty shape describes a rank-0 tensor.
        let scalar = DynamicTensor::<i64>::new(&[]);
        println!("{:?}", scalar);
    }
}