nd_slice/
nd_slice.rs

//! `NdSlice` wraps `&[T]` to represent an immutable n-dimensional array.
2
3use std::ops::Index;
4use std::slice;
5use crate::addressing::{Order, address};
6use crate::errors::ShapeError;
7
8#[derive(Debug, Clone)]
9/// `NdSlice` wraps `&[T]` to represent an immutable n-dimensional array
10///
11/// ```
12/// # use nd_slice::{NdSlice, Order};
13/// let arr = [1, 2, 3, 4, 5, 6];
14/// let n = NdSlice::new(&arr, [2, 3], Order::RowMajor).unwrap();
15/// assert_eq!(n[[0, 0]], 1);
16/// assert_eq!(n[[1, 2]], 6);
17///
18/// let arr = [1, 2, 3, 4, 5, 6, 7, 8];
19/// let n = NdSlice::new(&arr, [2, 2, 2], Order::RowMajor).unwrap();
20/// assert_eq!(n[[0, 0, 0]], 1);
21/// assert_eq!(n[[1, 1, 1]], 8);
22/// ```
23///
24/// If the slice doesn't have enough elements to represent an array of that shape, it will
25/// return an `Err(ShapeError)`.
26///
27/// ```should_panic
28/// # use nd_slice::{NdSlice, Order};
29/// let n = NdSlice::new(&[1, 2, 3, 4, 5, 6], [2, 2], Order::RowMajor).unwrap(); // more elements
30/// let n = NdSlice::new(&[1, 2, 3, 4, 5, 6], [2, 4], Order::RowMajor).unwrap(); // less elements
31/// ```
32pub struct NdSlice<'s, T, const N: usize> {
33    slice: &'s [T],
34    shape: [usize; N],
35    order: Order,
36}
37
/// Result of the fallible `NdSlice` constructors: the view on success, or a
/// `ShapeError` built from the rejected slice and shape on failure.
type ConstructionResult<'s, T, const N: usize> = Result<NdSlice<'s, T, N>, ShapeError<'s, T, N>>;
39
40impl<'s, T, const N: usize> NdSlice<'s, T, N> {
41    /// Creates a new `NdSlice` with the specified ordering from a given slice and the expected shape
42    pub fn new(slice: &'s [T], shape: [usize; N], order: Order) -> ConstructionResult<'s, T, N> {
43        if slice.len() == shape.iter().fold(1, |acc, &x| acc * x) {
44            Ok(Self { slice, shape, order })
45        } else {
46            Err(ShapeError::new(slice, shape))
47        }
48    }
49
50    /// Creates a new `NdSlice` with the specified ordering from a raw pointer, it's length and the expected shape
51    pub unsafe fn from_ptr(ptr: *const T, len: usize, shape: [usize; N], order: Order) -> ConstructionResult<'s, T, N> {
52        NdSlice::new(slice::from_raw_parts(ptr, len), shape, order)
53    }
54
55    /// Creates a new `NdSlice` with row-major ordering from a given slice and the expected shape
56    pub fn new_row_ordered(slice: &'s [T], shape: [usize; N]) -> ConstructionResult<'s, T, N> {
57        NdSlice::new(slice, shape, Order::RowMajor)
58    }
59
60    /// Creates a new `NdSlice` with row-major ordering from a raw pointer, it's length and the expected shape
61    pub unsafe fn row_ordered_from_ptr(ptr: *const T, len: usize, shape: [usize; N]) -> ConstructionResult<'s, T, N> {
62        NdSlice::from_ptr(ptr, len, shape, Order::RowMajor)
63    }
64
65    /// Creates a new `NdSlice` with column-major ordering from a given slice and the expected shape
66    pub fn col_ordered(slice: &'s [T], shape: [usize; N]) -> ConstructionResult<'s, T, N> {
67        NdSlice::new(slice, shape, Order::ColumnMajor)
68    }
69
70    /// Creates a new `NdSlice` with column-major ordering from a raw pointer, it's length and the expected shape.
71    pub unsafe fn col_ordered_from_ptr(ptr: *const T, len: usize, shape: [usize; N]) -> ConstructionResult<'s, T, N> {
72        NdSlice::from_ptr(ptr, len, shape, Order::ColumnMajor)
73    }
74}
75
76impl<T, const N: usize> Index<[usize; N]> for NdSlice<'_, T, N> {
77    type Output = T;
78
79    fn index(&self, index: [usize; N]) -> &Self::Output {
80        &self.slice[address(&self.order, &self.shape, &index)]
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    const ARR: [usize; 8] = [1, 2, 3, 4, 5, 6, 7, 8];

    /// Asserts that `ARR` can be viewed (row-major) with the given shape.
    fn assert_accepted<const N: usize>(shape: [usize; N]) {
        assert!(
            NdSlice::new_row_ordered(&ARR, shape).is_ok(),
            "shape {:?} should be accepted for {} elements",
            shape,
            ARR.len()
        );
    }

    /// Asserts that viewing `ARR` (row-major) with the given shape is rejected.
    fn assert_rejected<const N: usize>(shape: [usize; N]) {
        assert!(
            NdSlice::new_row_ordered(&ARR, shape).is_err(),
            "shape {:?} should be rejected for {} elements",
            shape,
            ARR.len()
        );
    }

    #[test]
    fn new() {
        assert_accepted([8]);
        assert_accepted([4, 2]);
        assert_accepted([2, 2, 2]);
    }

    #[test]
    fn new_err() {
        // Every case is checked individually. With `#[should_panic]`, the first failing
        // `unwrap` would end the test and the remaining shapes would never be exercised.
        // The shape needs more elements than the slice has:
        assert_rejected([9]);
        assert_rejected([4, 3]);
        assert_rejected([2, 2, 3]);
        // The slice has more elements than the shape holds:
        assert_rejected([7]);
        assert_rejected([2, 3]);
    }

    #[test]
    fn new_from_ptr() {
        // Copy the const into a local so the pointer refers to storage that provably
        // outlives every `NdSlice` below, rather than to a temporary.
        let arr = ARR;
        let ptr = arr.as_ptr();
        let len = arr.len();
        // SAFETY: `ptr` points to `len` initialized `usize`s in the live local `arr`,
        // which is neither mutated nor dropped while the views exist.
        unsafe {
            NdSlice::row_ordered_from_ptr(ptr, len, [8]).unwrap();
            NdSlice::row_ordered_from_ptr(ptr, len, [4, 2]).unwrap();
            NdSlice::row_ordered_from_ptr(ptr, len, [2, 2, 2]).unwrap();
            NdSlice::col_ordered_from_ptr(ptr, len, [2, 4]).unwrap();
        }
    }

    #[test]
    fn index_test() {
        let rm = NdSlice::new_row_ordered(&[1, 2, 3, 4, 5, 6], [2, 3]).unwrap();
        let cm = NdSlice::col_ordered(&[1, 4, 2, 5, 3, 6], [2, 3]).unwrap();

        // Both layouts store the same logical 2x3 matrix, numbered 1..=6 in reading order.
        let mut expected = 1;
        for row in 0..2 {
            for col in 0..3 {
                assert_eq!(rm[[row, col]], expected, "row-major at [{}, {}]", row, col);
                assert_eq!(cm[[row, col]], expected, "column-major at [{}, {}]", row, col);
                expected += 1;
            }
        }
    }
}