1use crate::dense::{
4 layout::{convert_1d_nd_from_shape, convert_nd_raw},
5 number_types::{IsGreaterByOne, IsGreaterZero, NumberType},
6 types::RlstBase,
7};
8
9use super::{
10 empty_chunk, Array, ChunkedAccess, RawAccess, RawAccessMut, Shape, Stride,
11 UnsafeRandomAccessByRef, UnsafeRandomAccessByValue, UnsafeRandomAccessMut,
12};
13
14pub struct ArraySlice<
16 Item: RlstBase,
17 ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item> + Shape<ADIM>,
18 const ADIM: usize,
19 const NDIM: usize,
20> where
21 NumberType<ADIM>: IsGreaterByOne<NDIM>,
22 NumberType<NDIM>: IsGreaterZero,
23{
24 arr: Array<Item, ArrayImpl, ADIM>,
25 slice: [usize; 2],
27 mask: [usize; NDIM],
28}
29
30impl<
33 Item: RlstBase,
34 ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item> + Shape<ADIM>,
35 const ADIM: usize,
36 const NDIM: usize,
37 > ArraySlice<Item, ArrayImpl, ADIM, NDIM>
38where
39 NumberType<ADIM>: IsGreaterByOne<NDIM>,
40 NumberType<NDIM>: IsGreaterZero,
41{
42 pub fn new(arr: Array<Item, ArrayImpl, ADIM>, slice: [usize; 2]) -> Self {
44 let mut mask = [1; NDIM];
47 assert!(
48 slice[0] < ADIM,
49 "Axis {} out of bounds. Array has {} axes.",
50 slice[0],
51 ADIM
52 );
53 assert!(
54 slice[1] < arr.shape()[slice[0]],
55 "Index {} in axis {} out of bounds. Dimension of axis is {}.",
56 slice[1],
57 slice[0],
58 arr.shape()[slice[0]]
59 );
60 mask.iter_mut().take(slice[0]).for_each(|val| *val = 0);
61 Self { arr, slice, mask }
62 }
63}
64
65impl<
66 Item: RlstBase,
67 ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item> + Shape<ADIM>,
68 const ADIM: usize,
69 const NDIM: usize,
70 > UnsafeRandomAccessByValue<NDIM> for ArraySlice<Item, ArrayImpl, ADIM, NDIM>
71where
72 NumberType<ADIM>: IsGreaterByOne<NDIM>,
73 NumberType<NDIM>: IsGreaterZero,
74{
75 type Item = Item;
76
77 unsafe fn get_value_unchecked(&self, multi_index: [usize; NDIM]) -> Self::Item {
78 let mut orig_index = multi_index_to_orig(multi_index, self.mask);
79 orig_index[self.slice[0]] = self.slice[1];
80 self.arr.get_value_unchecked(orig_index)
81 }
82}
83
84impl<
85 Item: RlstBase,
86 ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item>
87 + Shape<ADIM>
88 + UnsafeRandomAccessByRef<ADIM, Item = Item>,
89 const ADIM: usize,
90 const NDIM: usize,
91 > UnsafeRandomAccessByRef<NDIM> for ArraySlice<Item, ArrayImpl, ADIM, NDIM>
92where
93 NumberType<ADIM>: IsGreaterByOne<NDIM>,
94 NumberType<NDIM>: IsGreaterZero,
95{
96 type Item = Item;
97
98 unsafe fn get_unchecked(&self, multi_index: [usize; NDIM]) -> &Self::Item {
99 let mut orig_index = multi_index_to_orig(multi_index, self.mask);
100 orig_index[self.slice[0]] = self.slice[1];
101 self.arr.get_unchecked(orig_index)
102 }
103}
104
105impl<
106 Item: RlstBase,
107 ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item> + Shape<ADIM>,
108 const ADIM: usize,
109 const NDIM: usize,
110 > Shape<NDIM> for ArraySlice<Item, ArrayImpl, ADIM, NDIM>
111where
112 NumberType<ADIM>: IsGreaterByOne<NDIM>,
113 NumberType<NDIM>: IsGreaterZero,
114{
115 fn shape(&self) -> [usize; NDIM] {
116 let mut result = [0; NDIM];
117 let orig_shape = self.arr.shape();
118
119 for (index, value) in result.iter_mut().enumerate() {
120 *value = orig_shape[index + self.mask[index]];
121 }
122
123 result
124 }
125}
126
127impl<
128 Item: RlstBase,
129 ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item>
130 + Shape<ADIM>
131 + RawAccess<Item = Item>
132 + Stride<ADIM>,
133 const ADIM: usize,
134 const NDIM: usize,
135 > RawAccess for ArraySlice<Item, ArrayImpl, ADIM, NDIM>
136where
137 NumberType<ADIM>: IsGreaterByOne<NDIM>,
138 NumberType<NDIM>: IsGreaterZero,
139{
140 type Item = Item;
141 fn data(&self) -> &[Self::Item] {
142 assert!(!self.is_empty());
143 let (start_raw, end_raw) =
144 compute_raw_range(self.slice, self.arr.stride(), self.arr.shape());
145
146 &self.arr.data()[start_raw..end_raw]
147 }
148
149 fn buff_ptr(&self) -> *const Self::Item {
150 self.arr.buff_ptr()
151 }
152
153 fn offset(&self) -> usize {
154 let mut orig_index = [0; ADIM];
155 orig_index[self.slice[0]] = self.slice[1];
156 self.arr.offset() + convert_nd_raw(orig_index, self.arr.stride())
157 }
158}
159
160impl<
161 Item: RlstBase,
162 ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item> + Shape<ADIM> + Stride<ADIM>,
163 const ADIM: usize,
164 const NDIM: usize,
165 > Stride<NDIM> for ArraySlice<Item, ArrayImpl, ADIM, NDIM>
166where
167 NumberType<ADIM>: IsGreaterByOne<NDIM>,
168 NumberType<NDIM>: IsGreaterZero,
169{
170 fn stride(&self) -> [usize; NDIM] {
171 let mut result = [0; NDIM];
172 let orig_stride: [usize; ADIM] = self.arr.stride();
173
174 for (index, value) in result.iter_mut().enumerate() {
175 *value = orig_stride[index + self.mask[index]];
176 }
177
178 result
179 }
180}
181
182impl<
183 Item: RlstBase,
184 ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item> + Shape<ADIM>,
185 const ADIM: usize,
186 > Array<Item, ArrayImpl, ADIM>
187{
188 pub fn slice<const NDIM: usize>(
200 self,
201 axis: usize,
202 index: usize,
203 ) -> Array<Item, ArraySlice<Item, ArrayImpl, ADIM, NDIM>, NDIM>
204 where
205 NumberType<ADIM>: IsGreaterByOne<NDIM>,
206 NumberType<NDIM>: IsGreaterZero,
207 {
208 Array::new(ArraySlice::new(self, [axis, index]))
209 }
210}
211
212impl<
213 Item: RlstBase,
214 ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item> + Shape<ADIM> + Stride<ADIM>,
215 const ADIM: usize,
216 const NDIM: usize,
217 const N: usize,
218 > ChunkedAccess<N> for ArraySlice<Item, ArrayImpl, ADIM, NDIM>
219where
220 NumberType<ADIM>: IsGreaterByOne<NDIM>,
221 NumberType<NDIM>: IsGreaterZero,
222{
223 type Item = Item;
224
225 #[inline]
226 fn get_chunk(
227 &self,
228 chunk_index: usize,
229 ) -> Option<crate::dense::types::DataChunk<Self::Item, N>> {
230 let nelements = self.shape().iter().product();
231 if let Some(mut chunk) = empty_chunk(chunk_index, nelements) {
232 for count in 0..chunk.valid_entries {
233 unsafe {
234 chunk.data[count] = self.get_value_unchecked(convert_1d_nd_from_shape(
235 chunk.start_index + count,
236 self.shape(),
237 ))
238 }
239 }
240 Some(chunk)
241 } else {
242 None
243 }
244 }
245}
246
247impl<
248 Item: RlstBase,
249 ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item>
250 + Shape<ADIM>
251 + UnsafeRandomAccessMut<ADIM, Item = Item>,
252 const ADIM: usize,
253 const NDIM: usize,
254 > UnsafeRandomAccessMut<NDIM> for ArraySlice<Item, ArrayImpl, ADIM, NDIM>
255where
256 NumberType<ADIM>: IsGreaterByOne<NDIM>,
257 NumberType<NDIM>: IsGreaterZero,
258{
259 type Item = Item;
260
261 unsafe fn get_unchecked_mut(&mut self, multi_index: [usize; NDIM]) -> &mut Self::Item {
262 let mut orig_index = multi_index_to_orig(multi_index, self.mask);
263 orig_index[self.slice[0]] = self.slice[1];
264 self.arr.get_unchecked_mut(orig_index)
265 }
266}
267
268impl<
269 Item: RlstBase,
270 ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item>
271 + Shape<ADIM>
272 + RawAccessMut<Item = Item>
273 + Stride<ADIM>
274 + UnsafeRandomAccessMut<ADIM, Item = Item>,
275 const ADIM: usize,
276 const NDIM: usize,
277 > RawAccessMut for ArraySlice<Item, ArrayImpl, ADIM, NDIM>
278where
279 NumberType<ADIM>: IsGreaterByOne<NDIM>,
280 NumberType<NDIM>: IsGreaterZero,
281{
282 fn data_mut(&mut self) -> &mut [Self::Item] {
283 assert!(!self.is_empty());
284 let (start_raw, end_raw) =
285 compute_raw_range(self.slice, self.arr.stride(), self.arr.shape());
286 &mut self.arr.data_mut()[start_raw..end_raw]
287 }
288
289 fn buff_ptr_mut(&mut self) -> *mut Self::Item {
290 self.arr.buff_ptr_mut()
291 }
292}
293
294fn multi_index_to_orig<const ADIM: usize, const NDIM: usize>(
297 multi_index: [usize; NDIM],
298 mask: [usize; NDIM],
299) -> [usize; ADIM]
300where
301 NumberType<ADIM>: IsGreaterByOne<NDIM>,
302 NumberType<NDIM>: IsGreaterZero,
303{
304 let mut orig = [0; ADIM];
305 for (index, &value) in multi_index.iter().enumerate() {
306 orig[index + mask[index]] = value;
307 }
308 orig
309}
310
311fn compute_raw_range<const NDIM: usize>(
312 slice: [usize; 2],
313 stride: [usize; NDIM],
314 shape: [usize; NDIM],
315) -> (usize, usize) {
316 let mut start_multi_index = [0; NDIM];
317 start_multi_index[slice[0]] = slice[1];
318 let mut end_multi_index = shape;
319 for (index, value) in end_multi_index.iter_mut().enumerate() {
320 if index == slice[0] {
321 *value = slice[1]
322 } else {
323 assert!(*value > 0);
327 *value -= 1;
328 }
329 }
330 (
331 convert_nd_raw(start_multi_index, stride),
332 1 + convert_nd_raw(end_multi_index, stride),
333 )
334}