// n_circular_array 0.4.2
//
// An n-dimensional circular array.
// Documentation
use std::ops::{Deref, DerefMut, Range};

/// The strides of an `N` dimension array.
///
/// Stride `i` is the distance, in elements of the flattened (contiguous)
/// storage, between two indices that differ by one along axis `i`. When built
/// with [`Strides::new`], axis `0` is the fastest-varying axis and has a
/// stride of `1`.
///
/// `Strides` make no assumptions about the shape of an array. The caller must
/// ensure that expanded/flattened indices are valid.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Strides<const N: usize>([usize; N]);

impl<const N: usize> Strides<N> {
    /// Create `Strides` for the given `shape`.
    ///
    /// Stride `0` is `1` and every following stride is the product of the
    /// preceding stride and the preceding extent, e.g. a `[5, 4, 3]` shape
    /// yields strides `[1, 5, 20]`. The extent of the last axis never
    /// contributes to a stride, and a zero extent on any other axis makes every
    /// following stride `0`.
    pub fn new(shape: &[usize; N]) -> Self {
        let mut array = [1; N];
        for i in 1..N {
            array[i] = array[i - 1] * shape[i - 1];
        }

        Strides(array)
    }

    /// Expand an `index` into an `N` dimensional index.
    ///
    /// For indices inside the array this is the inverse of
    /// [`Strides::flatten_index`]. The last axis is not wrapped, so an `index`
    /// past the end of the array yields an out-of-range component on the last
    /// axis. A zero dimensional `Strides` always expands to `[]`.
    ///
    /// # Panics
    ///
    /// Panics (division by zero) if any stride is `0`, which
    /// [`Strides::new`] produces when a shape has a zero extent on any axis
    /// but the last.
    ///
    /// # Example
    ///
    /// ```
    /// # #[cfg(feature = "strides")] {
    /// # use std::ops::Range;
    /// # use n_circular_array::{CircularArray, CircularIndex, Strides};
    /// let mut array = CircularArray::new([3, 3], vec![
    ///     0, 1, 2,
    ///     3, 4, 5,
    ///     6, 7, 8
    /// ]);
    /// assert_eq!(array.strides().expand_index(4), [1, 1]);
    /// assert_eq!(array.strides().expand_index(5), [2, 1]);
    /// # }
    /// ```
    // NOTE(review): presumably unused in some feature configurations (the
    // examples are gated on the `strides` feature), hence the allow.
    #[allow(dead_code)]
    pub fn expand_index(&self, index: usize) -> [usize; N] {
        let mut array = [0; N];

        // `checked_sub` skips expansion for a zero dimensional array, which has
        // no axes, instead of underflowing `N - 1`.
        if let Some(last) = N.checked_sub(1) {
            // Every axis but the last is bounded by the stride of the next axis.
            for i in 0..last {
                array[i] = (index % self.0[i + 1]) / self.0[i];
            }
            // The last axis is unbounded: whatever remains lands here.
            array[last] = index / self.0[last];
        }

        array
    }

    /// Flatten an `N` dimensional `index` into a contiguous index.
    ///
    /// The result is the dot product of `index` with the strides.
    ///
    /// # Example
    ///
    /// ```
    /// # #[cfg(feature = "strides")] {
    /// # use std::ops::Range;
    /// # use n_circular_array::{CircularArray, CircularIndex, Strides};
    /// let mut array = CircularArray::new([3, 3], vec![
    ///     0, 1, 2,
    ///     3, 4, 5,
    ///     6, 7, 8
    /// ]);
    /// assert_eq!(array.strides().flatten_index([1, 1]), 4);
    /// assert_eq!(array.strides().flatten_index([2, 1]), 5);
    /// # }
    /// ```
    // NOTE(review): see `expand_index` for why dead code is allowed.
    #[allow(dead_code)]
    pub fn flatten_index(&self, index: [usize; N]) -> usize {
        index
            .iter()
            .zip(self.0.iter())
            .map(|(idx, stride)| idx * stride)
            .sum()
    }

    /// Flatten a **contiguous** `N` dimensional index range into a contiguous
    /// `Range<usize>`.
    ///
    /// The result runs from the flattened `start` of every axis range up to and
    /// including the flattened last index (`end - 1`) of every axis range. It
    /// holds exactly the elements of the `N` dimensional range only when that
    /// range is contiguous in memory. Otherwise it also covers the elements
    /// in between, as the first example below shows.
    ///
    /// If any axis range is empty (`start >= end`), the `N` dimensional range
    /// selects nothing. In that case an empty range positioned at the
    /// flattened `start` is returned.
    ///
    /// # Example
    ///
    /// ```
    /// # #[cfg(feature = "strides")] {
    /// # use std::ops::Range;
    /// # use n_circular_array::{CircularArray, CircularIndex, Strides};
    /// let mut array = CircularArray::new([3, 3], vec![
    ///     0, 1, 2,
    ///     3, 4, 5,
    ///     6, 7, 8
    /// ]);
    /// // The span from [0, 0] to [1, 1] (which also includes [2, 0]).
    /// assert_eq!(array.strides().flatten_range([0..2, 0..2]), 0..5);
    /// // A contiguous range from [1, 0] to [1, 2].
    /// assert_eq!(array.strides().flatten_range([1..2, 0..3]), 1..8);
    /// # }
    /// ```
    // NOTE(review): see `expand_index` for why dead code is allowed.
    #[allow(dead_code)]
    pub fn flatten_range(&self, index_range: [Range<usize>; N]) -> Range<usize> {
        // Flattened position of the first selected element.
        let start: usize = index_range
            .iter()
            .zip(self.0.iter())
            .map(|(range, stride)| range.start * stride)
            .sum();

        // An empty axis selects nothing at all. Checking first also keeps
        // `range.end - 1` below from underflowing on a `0..0` range.
        if index_range.iter().any(|range| range.is_empty()) {
            return start..start;
        }

        // Flattened position of the last selected element (inclusive).
        let last: usize = index_range
            .iter()
            .zip(self.0.iter())
            .map(|(range, stride)| (range.end - 1) * stride)
            .sum();

        start..last + 1
    }
}

/// Borrow the strides as a plain array, e.g. `strides[0]` or `strides.iter()`.
impl<const N: usize> Deref for Strides<N> {
    type Target = [usize; N];

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

/// Mutably borrow the strides. Modified strides are not validated, so the
/// caller is responsible for keeping them consistent with the array shape.
impl<const N: usize> DerefMut for Strides<N> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Shapes shared by the tests below: 2, 3 and 4 dimensional arrays whose
    // leading extents match, so their leading strides agree.
    const SHAPE_2D: [usize; 2] = [5, 4];
    const SHAPE_3D: [usize; 3] = [5, 4, 3];
    const SHAPE_4D: [usize; 4] = [5, 4, 3, 2];

    #[test]
    fn new() {
        assert_eq!(Strides::new(&SHAPE_2D).0, [1, 5]);
        assert_eq!(Strides::new(&SHAPE_3D).0, [1, 5, 20]);
        assert_eq!(Strides::new(&SHAPE_4D).0, [1, 5, 20, 60]);
    }

    #[test]
    fn expand_index() {
        assert_eq!(Strides::new(&SHAPE_2D).expand_index(11), [1, 2]);
        assert_eq!(Strides::new(&SHAPE_3D).expand_index(31), [1, 2, 1]);
        assert_eq!(Strides::new(&SHAPE_4D).expand_index(81), [1, 0, 1, 1]);
    }

    #[test]
    fn flatten_index() {
        assert_eq!(Strides::new(&SHAPE_2D).flatten_index([1, 2]), 11);
        assert_eq!(Strides::new(&SHAPE_3D).flatten_index([1, 2, 1]), 31);
        assert_eq!(Strides::new(&SHAPE_4D).flatten_index([1, 0, 1, 1]), 81);
    }

    #[test]
    fn flatten_range() {
        assert_eq!(
            Strides::new(&SHAPE_2D).flatten_range([1..3, 2..3]),
            11..13
        );
        assert_eq!(
            Strides::new(&SHAPE_3D).flatten_range([1..3, 2..3, 1..2]),
            31..33
        );
        assert_eq!(
            Strides::new(&SHAPE_4D).flatten_range([1..3, 0..1, 1..2, 1..2]),
            81..83
        );
    }

    /// Every in-bounds flat index expands to an in-bounds `N` dimensional
    /// index and flattens back to itself.
    #[test]
    fn expand_flatten_round_trip() {
        let strides = Strides::new(&SHAPE_4D);
        let len: usize = SHAPE_4D.iter().product();

        for index in 0..len {
            let expanded = strides.expand_index(index);
            assert!(expanded
                .iter()
                .zip(SHAPE_4D.iter())
                .all(|(component, extent)| component < extent));
            assert_eq!(strides.flatten_index(expanded), index);
        }
    }

    /// A range covering every axis in full flattens to the whole array.
    #[test]
    fn flatten_full_range() {
        let strides = Strides::new(&SHAPE_3D);
        assert_eq!(strides.flatten_range([0..5, 0..4, 0..3]), 0..60);
    }
}