redstone_ml/ndarray/
index_impl.rs

1use crate::dtype::RawDataType;
2use crate::{NdArray, StridedMemory};
3use std::ops::{Index, IndexMut};
4
5impl<T: RawDataType, const D: usize> Index<[usize; D]> for NdArray<'_, T> {
6    type Output = T;
7
8    fn index(&self, index: [usize; D]) -> &Self::Output {
9        assert_eq!(D, self.ndims(), "[] index must equal number of array dimensions!");
10
11        let i: usize = index.iter().zip(self.stride.iter())
12            .map(|(idx, stride)| idx * stride)
13            .sum();
14
15        assert!(i < self.len, "[] index out of bounds!");
16        unsafe { self.ptr.add(i).as_ref() }
17    }
18}
19
20impl<T: RawDataType, const D: usize> IndexMut<[usize; D]> for NdArray<'_, T> {
21    fn index_mut(&mut self, index: [usize; D]) -> &mut Self::Output {
22        assert!(D <= self.ndims(), "[] index must be equal number of array dimensions!");
23
24        let i: usize = index.iter().zip(self.stride.iter())
25                            .map(|(idx, stride)| idx * stride)
26                            .sum();
27
28        assert!(i < self.len, "[] index out of bounds!");
29        unsafe { self.ptr.add(i).as_mut() }
30    }
31}
32
33impl<T: RawDataType> Index<usize> for NdArray<'_, T> {
34    type Output = T;
35
36    fn index(&self, index: usize) -> &Self::Output {
37        &self[[index]]
38    }
39}
40
41impl<T: RawDataType> IndexMut<usize> for NdArray<'_, T> {
42    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
43        &mut self[[index]]
44    }
45}