// redstone_ml/ndarray/slice.rs

1use crate::axis::Axis;
2use crate::dtype::RawDataType;
3use crate::index::Indexer;
4use crate::iterator::collapse_contiguous::has_uniform_stride;
5use crate::ndarray::flags::NdArrayFlags;
6use crate::{AxisType, NdArray, StridedMemory};
7
8pub(super) fn update_flags_with_contiguity(mut flags: NdArrayFlags, shape: &[usize], stride: &[usize]) -> NdArrayFlags {
9    match has_uniform_stride(shape, stride) {
10        None => {
11            flags -= NdArrayFlags::UniformStride;
12            flags -= NdArrayFlags::Contiguous;
13        }
14        Some(stride) => {
15            flags |= NdArrayFlags::UniformStride;
16
17            if stride <= 1 {
18                flags |= NdArrayFlags::Contiguous;
19            } else {
20                flags -= NdArrayFlags::Contiguous;
21            }
22        }
23    }
24
25    flags
26}
27
/// Returns how many buffer elements a view with the given `shape` and
/// `stride` spans: the distance from its first to its last element, inclusive.
///
/// For non-empty views this is `1 + sum(stride[i] * (shape[i] - 1))`.
/// A zero-dimensional view (empty `shape`) holds one element and spans `1`.
/// A view with any zero-length axis holds no elements and spans `0`.
///
/// `shape` and `stride` are expected to have the same length; extra entries
/// in the longer slice are ignored.
fn calculate_strided_buffer_length(shape: &[usize], stride: &[usize]) -> usize {
    // An empty axis means the view has no elements. Return early, because
    // `axis_length - 1` below would underflow for a length of 0.
    if shape.contains(&0) {
        return 0;
    }

    // Offset of the last element relative to the first, plus one for the
    // first element itself.
    shape
        .iter()
        .zip(stride)
        .map(|(&axis_length, &axis_stride)| axis_stride * (axis_length - 1))
        .sum::<usize>()
        + 1
}
39
40
41impl<'a, T: RawDataType> NdArray<'a, T> {
42    pub fn slice_along<S: Indexer>(&'a self, axis: Axis, index: S) -> NdArray<'a, T>
43    {
44        let axis = axis.as_absolute(self.ndims());
45
46        let mut new_shape = self.shape.clone();
47        let mut new_stride = self.stride.clone();
48
49        if index.collapse_dimension() {
50            new_shape.remove(axis);
51            new_stride.remove(axis);
52        } else {
53            new_shape[axis] = index.indexed_length(new_shape[axis]);
54        }
55
56        let offset = self.stride[axis] * index.index_of_first_element();
57
58        let len = calculate_strided_buffer_length(&new_shape, &new_stride);
59        
60        let mut flags = update_flags_with_contiguity(self.flags, &new_shape, &new_stride);
61        flags -= NdArrayFlags::Owned;
62        flags -= NdArrayFlags::UserCreated;
63
64        NdArray {
65            ptr: unsafe { self.ptr.add(offset) },
66            len,
67            capacity: len,
68
69            shape: new_shape,
70            stride: new_stride,
71            flags,
72
73            _marker: self._marker,
74        }
75    }
76
77    pub fn slice<S, I>(&'a self, index: I) -> NdArray<'a, T>
78    where
79        S: Indexer,
80        I: IntoIterator<Item=S>,
81    {
82        let ndims = self.ndims();
83        let mut offset = 0;
84        let mut axis = 0;
85
86        let mut new_shape = Vec::with_capacity(ndims);
87        let mut new_stride = Vec::with_capacity(ndims);
88
89        for idx in index {
90            if !idx.collapse_dimension() {
91                let new_length = idx.indexed_length(self.shape[axis]);
92                new_shape.push(new_length);
93                new_stride.push(self.stride[axis]);
94            }
95
96            offset += self.stride[axis] * idx.index_of_first_element();
97            axis += 1;
98        }
99
100        for j in axis..ndims {
101            new_shape.push(self.shape[j]);
102            new_stride.push(self.stride[j]);
103        }
104
105        let len = calculate_strided_buffer_length(&new_shape, &new_stride);
106        let mut flags = update_flags_with_contiguity(self.flags, &new_shape, &new_stride);
107        flags -= NdArrayFlags::Owned;
108        flags -= NdArrayFlags::UserCreated;
109
110        NdArray {
111            ptr: unsafe { self.ptr.add(offset) },
112            len,
113            capacity: 0,
114
115            shape: new_shape,
116            stride: new_stride,
117            flags,
118
119            _marker: self._marker,
120        }
121    }
122}