nd_array/ndarray/array/access.rs

1use std::{
2    borrow::Cow,
3    ops::{Index, Range},
4};
5
6use crate::Array;
7
8impl<'a, T: Clone, const D: usize> Array<'a, T, D> {
9    pub fn slice(&'a self, slice: &[Range<usize>; D]) -> Array<'a, T, D> {
10        let mut shape = self.shape.clone();
11        let strides = self.strides.clone();
12        let mut idx_maps = self.idx_maps.clone();
13
14        slice.iter().enumerate().for_each(|(axis, range)| {
15            if range.end > self.shape[axis] {
16                panic!(
17                    "Range: [{},{}) is out of bounds for axis: {}",
18                    range.start, range.end, axis
19                )
20            }
21        });
22
23        for axis in 0..D {
24            idx_maps[axis].append_b((slice[axis].start) as isize);
25            shape[axis] = slice[axis].end - slice[axis].start;
26        }
27
28        Array {
29            vec: Cow::from(&*self.vec),
30            shape,
31            strides,
32            idx_maps,
33        }
34    }
35
36    pub fn get(&self, indices: [usize; D]) -> Option<&T> {
37        if indices
38            .iter()
39            .enumerate()
40            .any(|(axis, idx)| *idx >= self.shape[axis])
41        {
42            return None;
43        }
44
45        let index = indices
46            .iter()
47            .enumerate()
48            .fold(0, |acc, (axis, axis_index)| {
49                acc + self.idx_maps[axis].map(*axis_index) * self.strides[axis]
50            });
51
52        self.vec.get(index)
53    }
54}
55
56impl<'a, T: Clone, const D: usize> Index<[usize; D]> for Array<'a, T, D> {
57    type Output = T;
58
59    fn index(&self, indices: [usize; D]) -> &Self::Output {
60        if indices
61            .iter()
62            .enumerate()
63            .any(|(axis, idx)| *idx >= self.shape[axis])
64        {
65            panic!("Index out of bound");
66        }
67
68        let index = indices
69            .iter()
70            .enumerate()
71            .fold(0, |acc, (axis, axis_index)| {
72                acc + self.idx_maps[axis].map(*axis_index) * self.strides[axis]
73            });
74
75        &self.vec[index]
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn index_array() {
        // 2-D array:
        // 1 2 3
        // 4 5 6
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);

        assert_eq!(array[[0, 0]], 1);
        assert_eq!(array[[0, 1]], 2);
        assert_eq!(array[[0, 2]], 3);
        assert_eq!(array[[1, 0]], 4);
        assert_eq!(array[[1, 1]], 5);
        assert_eq!(array[[1, 2]], 6);
    }

    #[test]
    fn get_array() {
        // 2-D array:
        // 1 2 3
        // 4 5 6
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);

        assert_eq!(array.get([0, 0]), Some(&1));
        assert_eq!(array.get([1, 2]), Some(&6));
        // One past the end on either axis is rejected by the shape check
        // rather than aliasing into the next row of the buffer.
        assert_eq!(array.get([2, 0]), None);
        assert_eq!(array.get([0, 3]), None);
    }

    #[test]
    #[should_panic(expected = "Index out of bound")]
    fn index_out_of_bounds_panics() {
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);
        let _value = array[[0, 3]];
    }

    #[test]
    #[should_panic(expected = "out of bounds for axis: 1")]
    fn slice_out_of_bounds_panics() {
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);
        let _ = array.slice(&[0..2, 0..4]);
    }

    #[test]
    fn slice_then_get() {
        // 1 2 3
        // 4 5 6
        let array = Array::init(vec![1, 2, 3, 4, 5, 6], [2, 3]);

        // Drop the first column:
        // 2 3
        // 5 6
        let slice = array.slice(&[0..2, 1..3]);

        assert_eq!(slice[[0, 0]], 2);
        assert_eq!(slice[[0, 1]], 3);
        assert_eq!(slice[[1, 0]], 5);
        assert_eq!(slice[[1, 1]], 6);
        // Indices are checked against the slice's shape, not the parent's.
        assert_eq!(slice.get([0, 2]), None);
    }

    #[test]
    fn slicing() {
        // 2-D array:
        // 1   2  3  4
        // 5   6  7  8
        // 9  10 11 12
        // 13 14 15 16
        let array = Array::init(
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
            [4, 4],
        );

        // flip the array across axis=0
        // 13 14 15 16
        // 9  10 11 12
        // 5  6  7  8
        // 1  2  3  4
        let flipped = array.flip(0);

        // slice the center of the array
        // 10 11
        // 6  7
        let slice = flipped.slice(&[1..3, 1..3]);

        // 11 10
        // 7  6
        let flip_of_slice = slice.flip(1);

        assert_eq!(
            flip_of_slice.flat().copied().collect::<Vec<usize>>(),
            vec![11, 10, 7, 6]
        );
    }
}