1use std::ops::Index;
4use std::slice;
5use crate::addressing::{Order, address};
6use crate::errors::ShapeError;
7
8#[derive(Debug, Clone)]
9pub struct NdSlice<'s, T, const N: usize> {
33 slice: &'s [T],
34 shape: [usize; N],
35 order: Order,
36}
37
38type ConstructionResult<'s, T, const N: usize> = Result<NdSlice<'s, T, N>, ShapeError<'s, T, N>>;
39
40impl<'s, T, const N: usize> NdSlice<'s, T, N> {
41 pub fn new(slice: &'s [T], shape: [usize; N], order: Order) -> ConstructionResult<'s, T, N> {
43 if slice.len() == shape.iter().fold(1, |acc, &x| acc * x) {
44 Ok(Self { slice, shape, order })
45 } else {
46 Err(ShapeError::new(slice, shape))
47 }
48 }
49
50 pub unsafe fn from_ptr(ptr: *const T, len: usize, shape: [usize; N], order: Order) -> ConstructionResult<'s, T, N> {
52 NdSlice::new(slice::from_raw_parts(ptr, len), shape, order)
53 }
54
55 pub fn new_row_ordered(slice: &'s [T], shape: [usize; N]) -> ConstructionResult<'s, T, N> {
57 NdSlice::new(slice, shape, Order::RowMajor)
58 }
59
60 pub unsafe fn row_ordered_from_ptr(ptr: *const T, len: usize, shape: [usize; N]) -> ConstructionResult<'s, T, N> {
62 NdSlice::from_ptr(ptr, len, shape, Order::RowMajor)
63 }
64
65 pub fn col_ordered(slice: &'s [T], shape: [usize; N]) -> ConstructionResult<'s, T, N> {
67 NdSlice::new(slice, shape, Order::ColumnMajor)
68 }
69
70 pub unsafe fn col_ordered_from_ptr(ptr: *const T, len: usize, shape: [usize; N]) -> ConstructionResult<'s, T, N> {
72 NdSlice::from_ptr(ptr, len, shape, Order::ColumnMajor)
73 }
74}
75
76impl<T, const N: usize> Index<[usize; N]> for NdSlice<'_, T, N> {
77 type Output = T;
78
79 fn index(&self, index: [usize; N]) -> &Self::Output {
80 &self.slice[address(&self.order, &self.shape, &index)]
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::*;

    /// Eight elements: enough to exercise 1-, 2- and 3-dimensional shapes.
    const ARR: [usize; 8] = [1, 2, 3, 4, 5, 6, 7, 8];

    /// Panics unless `ARR` can be viewed row-major with `shape`.
    fn new_check<const N: usize>(shape: [usize; N]) {
        NdSlice::new_row_ordered(&ARR, shape).unwrap();
    }

    /// Asserts that viewing `ARR` row-major with `shape` is rejected.
    fn new_err_check<const N: usize>(shape: [usize; N]) {
        assert!(
            NdSlice::new_row_ordered(&ARR, shape).is_err(),
            "shape {:?} should be rejected for {} elements",
            shape,
            ARR.len()
        );
    }

    #[test]
    fn new() {
        new_check([8]);
        new_check([4, 2]);
        new_check([2, 2, 2]);
    }

    #[test]
    fn new_err() {
        // Each case is asserted on its own: a single `#[should_panic]` test would stop at
        // the first panic and never exercise the remaining shapes.
        new_err_check([9]);
        new_err_check([4, 3]);
        new_err_check([2, 2, 3]);
    }

    #[test]
    fn new_from_ptr() {
        let ptr = ARR.as_ptr();
        // SAFETY: `ptr` comes from the `ARR` constant, which is valid, aligned, initialized
        // and never mutated, and `8` is exactly its length.
        unsafe {
            NdSlice::row_ordered_from_ptr(ptr, 8, [8]).unwrap();
            NdSlice::row_ordered_from_ptr(ptr, 8, [4, 2]).unwrap();
            NdSlice::row_ordered_from_ptr(ptr, 8, [2, 2, 2]).unwrap();
        }
    }

    #[test]
    fn index_test() {
        // The same 2x3 matrix [[1, 2, 3], [4, 5, 6]] stored in both layouts.
        const ROW_MAJOR: [usize; 6] = [1, 2, 3, 4, 5, 6];
        const COL_MAJOR: [usize; 6] = [1, 4, 2, 5, 3, 6];

        let rm = NdSlice::new_row_ordered(&ROW_MAJOR, [2, 3]).unwrap();
        let cm = NdSlice::col_ordered(&COL_MAJOR, [2, 3]).unwrap();

        for row in 0..2 {
            for col in 0..3 {
                let expected = row * 3 + col + 1;
                assert_eq!(rm[[row, col]], expected, "row-major at [{}, {}]", row, col);
                assert_eq!(cm[[row, col]], expected, "column-major at [{}, {}]", row, col);
            }
        }
    }
}