nd_array/ndarray/array/
access.rs1use std::{
2 borrow::Cow,
3 ops::{Index, Range},
4};
5
6use crate::Array;
7
8impl<'a, T: Clone, const D: usize> Array<'a, T, D> {
9 pub fn slice(&'a self, slice: &[Range<usize>; D]) -> Array<'a, T, D> {
10 let mut shape = self.shape.clone();
11 let strides = self.strides.clone();
12 let mut idx_maps = self.idx_maps.clone();
13
14 slice.iter().enumerate().for_each(|(axis, range)| {
15 if range.end > self.shape[axis] {
16 panic!(
17 "Range: [{},{}) is out of bounds for axis: {}",
18 range.start, range.end, axis
19 )
20 }
21 });
22
23 for axis in 0..D {
24 idx_maps[axis].append_b((slice[axis].start) as isize);
25 shape[axis] = slice[axis].end - slice[axis].start;
26 }
27
28 Array {
29 vec: Cow::from(&*self.vec),
30 shape,
31 strides,
32 idx_maps,
33 }
34 }
35
36 pub fn get(&self, indices: [usize; D]) -> Option<&T> {
37 if indices
38 .iter()
39 .enumerate()
40 .any(|(axis, idx)| *idx >= self.shape[axis])
41 {
42 return None;
43 }
44
45 let index = indices
46 .iter()
47 .enumerate()
48 .fold(0, |acc, (axis, axis_index)| {
49 acc + self.idx_maps[axis].map(*axis_index) * self.strides[axis]
50 });
51
52 self.vec.get(index)
53 }
54}
55
56impl<'a, T: Clone, const D: usize> Index<[usize; D]> for Array<'a, T, D> {
57 type Output = T;
58
59 fn index(&self, indices: [usize; D]) -> &Self::Output {
60 if indices
61 .iter()
62 .enumerate()
63 .any(|(axis, idx)| *idx >= self.shape[axis])
64 {
65 panic!("Index out of bound");
66 }
67
68 let index = indices
69 .iter()
70 .enumerate()
71 .fold(0, |acc, (axis, axis_index)| {
72 acc + self.idx_maps[axis].map(*axis_index) * self.strides[axis]
73 });
74
75 &self.vec[index]
76 }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    // Row-major 2x3 array: every element is reachable through `Index`.
    #[test]
    fn index_array() {
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);

        assert_eq!(array[[0, 0]], 1);
        assert_eq!(array[[0, 1]], 2);
        assert_eq!(array[[0, 2]], 3);
        assert_eq!(array[[1, 0]], 4);
        assert_eq!(array[[1, 1]], 5);
        assert_eq!(array[[1, 2]], 6);
    }

    // `get` mirrors `Index` for valid indices and returns `None` past any axis end.
    #[test]
    fn get_returns_none_out_of_bounds() {
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);

        assert_eq!(array.get([1, 2]), Some(&6));
        assert_eq!(array.get([2, 0]), None);
        assert_eq!(array.get([0, 3]), None);
    }

    // `Index` panics where `get` would return `None`.
    #[test]
    #[should_panic(expected = "Index out of bound")]
    fn index_panics_out_of_bounds() {
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);
        let _value = array[[2, 0]];
    }

    // A range ending past the axis length is rejected.
    #[test]
    #[should_panic(expected = "is out of bounds for axis: 1")]
    fn slice_panics_past_axis_end() {
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);
        let _ = array.slice(&[0..2, 0..4]);
    }

    // Flip, slice and flip again: each operation acts on the indices the
    // previous result presents.
    #[test]
    fn slicing() {
        let array = Array::init(
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
            [4, 4],
        );

        let flipped = array.flip(0);

        let slice = flipped.slice(&[1..3, 1..3]);

        let flip_of_slice = slice.flip(1);

        assert_eq!(
            flip_of_slice.flat().copied().collect::<Vec<usize>>(),
            vec![11, 10, 7, 6]
        );
    }
}