redstone_ml/ndarray/
index_impl.rs1use crate::dtype::RawDataType;
2use crate::{NdArray, StridedMemory};
3use std::ops::{Index, IndexMut};
4
5impl<T: RawDataType, const D: usize> Index<[usize; D]> for NdArray<'_, T> {
6 type Output = T;
7
8 fn index(&self, index: [usize; D]) -> &Self::Output {
9 assert_eq!(D, self.ndims(), "[] index must equal number of array dimensions!");
10
11 let i: usize = index.iter().zip(self.stride.iter())
12 .map(|(idx, stride)| idx * stride)
13 .sum();
14
15 assert!(i < self.len, "[] index out of bounds!");
16 unsafe { self.ptr.add(i).as_ref() }
17 }
18}
19
20impl<T: RawDataType, const D: usize> IndexMut<[usize; D]> for NdArray<'_, T> {
21 fn index_mut(&mut self, index: [usize; D]) -> &mut Self::Output {
22 assert!(D <= self.ndims(), "[] index must be equal number of array dimensions!");
23
24 let i: usize = index.iter().zip(self.stride.iter())
25 .map(|(idx, stride)| idx * stride)
26 .sum();
27
28 assert!(i < self.len, "[] index out of bounds!");
29 unsafe { self.ptr.add(i).as_mut() }
30 }
31}
32
33impl<T: RawDataType> Index<usize> for NdArray<'_, T> {
34 type Output = T;
35
36 fn index(&self, index: usize) -> &Self::Output {
37 &self[[index]]
38 }
39}
40
41impl<T: RawDataType> IndexMut<usize> for NdArray<'_, T> {
42 fn index_mut(&mut self, index: usize) -> &mut Self::Output {
43 &mut self[[index]]
44 }
45}