redstone_ml/ndarray/
slice.rs1use crate::axis::Axis;
2use crate::dtype::RawDataType;
3use crate::index::Indexer;
4use crate::iterator::collapse_contiguous::has_uniform_stride;
5use crate::ndarray::flags::NdArrayFlags;
6use crate::{AxisType, NdArray, StridedMemory};
7
8pub(super) fn update_flags_with_contiguity(mut flags: NdArrayFlags, shape: &[usize], stride: &[usize]) -> NdArrayFlags {
9 match has_uniform_stride(shape, stride) {
10 None => {
11 flags -= NdArrayFlags::UniformStride;
12 flags -= NdArrayFlags::Contiguous;
13 }
14 Some(stride) => {
15 flags |= NdArrayFlags::UniformStride;
16
17 if stride <= 1 {
18 flags |= NdArrayFlags::Contiguous;
19 } else {
20 flags -= NdArrayFlags::Contiguous;
21 }
22 }
23 }
24
25 flags
26}
27
/// Returns the number of buffer elements spanned by a strided view, i.e. the
/// offset of the last addressable element plus one.
///
/// `shape` and `stride` are paired axis by axis; if their lengths differ, the
/// surplus entries of the longer slice are ignored when computing the span.
///
/// A zero-dimensional view (`shape` empty) spans exactly one element. A view
/// with any zero-length axis holds no elements and spans `0`; without this
/// guard `axis_length - 1` would underflow (panic in debug builds, wrap to a
/// meaningless length in release builds).
fn calculate_strided_buffer_length(shape: &[usize], stride: &[usize]) -> usize {
    // An empty axis makes the whole view empty, regardless of the other axes.
    if shape.contains(&0) {
        return 0;
    }

    // Offset of the last element: on each axis we step `length - 1` times.
    let last_element_offset: usize = shape
        .iter()
        .zip(stride)
        .map(|(&axis_length, &axis_stride)| axis_stride * (axis_length - 1))
        .sum();

    last_element_offset + 1
}
39
40
41impl<'a, T: RawDataType> NdArray<'a, T> {
42 pub fn slice_along<S: Indexer>(&'a self, axis: Axis, index: S) -> NdArray<'a, T>
43 {
44 let axis = axis.as_absolute(self.ndims());
45
46 let mut new_shape = self.shape.clone();
47 let mut new_stride = self.stride.clone();
48
49 if index.collapse_dimension() {
50 new_shape.remove(axis);
51 new_stride.remove(axis);
52 } else {
53 new_shape[axis] = index.indexed_length(new_shape[axis]);
54 }
55
56 let offset = self.stride[axis] * index.index_of_first_element();
57
58 let len = calculate_strided_buffer_length(&new_shape, &new_stride);
59
60 let mut flags = update_flags_with_contiguity(self.flags, &new_shape, &new_stride);
61 flags -= NdArrayFlags::Owned;
62 flags -= NdArrayFlags::UserCreated;
63
64 NdArray {
65 ptr: unsafe { self.ptr.add(offset) },
66 len,
67 capacity: len,
68
69 shape: new_shape,
70 stride: new_stride,
71 flags,
72
73 _marker: self._marker,
74 }
75 }
76
77 pub fn slice<S, I>(&'a self, index: I) -> NdArray<'a, T>
78 where
79 S: Indexer,
80 I: IntoIterator<Item=S>,
81 {
82 let ndims = self.ndims();
83 let mut offset = 0;
84 let mut axis = 0;
85
86 let mut new_shape = Vec::with_capacity(ndims);
87 let mut new_stride = Vec::with_capacity(ndims);
88
89 for idx in index {
90 if !idx.collapse_dimension() {
91 let new_length = idx.indexed_length(self.shape[axis]);
92 new_shape.push(new_length);
93 new_stride.push(self.stride[axis]);
94 }
95
96 offset += self.stride[axis] * idx.index_of_first_element();
97 axis += 1;
98 }
99
100 for j in axis..ndims {
101 new_shape.push(self.shape[j]);
102 new_stride.push(self.stride[j]);
103 }
104
105 let len = calculate_strided_buffer_length(&new_shape, &new_stride);
106 let mut flags = update_flags_with_contiguity(self.flags, &new_shape, &new_stride);
107 flags -= NdArrayFlags::Owned;
108 flags -= NdArrayFlags::UserCreated;
109
110 NdArray {
111 ptr: unsafe { self.ptr.add(offset) },
112 len,
113 capacity: 0,
114
115 shape: new_shape,
116 stride: new_stride,
117 flags,
118
119 _marker: self._marker,
120 }
121 }
122}