rlst/dense/array/
slice.rs

1//! Array slicing.
2
3use crate::dense::{
4    layout::{convert_1d_nd_from_shape, convert_nd_raw},
5    number_types::{IsGreaterByOne, IsGreaterZero, NumberType},
6    types::RlstBase,
7};
8
9use super::{
10    empty_chunk, Array, ChunkedAccess, RawAccess, RawAccessMut, Shape, Stride,
11    UnsafeRandomAccessByRef, UnsafeRandomAccessByValue, UnsafeRandomAccessMut,
12};
13
14/// Generic structure to store Array slices.
15pub struct ArraySlice<
16    Item: RlstBase,
17    ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item> + Shape<ADIM>,
18    const ADIM: usize,
19    const NDIM: usize,
20> where
21    NumberType<ADIM>: IsGreaterByOne<NDIM>,
22    NumberType<NDIM>: IsGreaterZero,
23{
24    arr: Array<Item, ArrayImpl, ADIM>,
25    // The first entry is the axis, the second is the index in the axis.
26    slice: [usize; 2],
27    mask: [usize; NDIM],
28}
29
30// Implementation of ArraySlice
31
32impl<
33        Item: RlstBase,
34        ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item> + Shape<ADIM>,
35        const ADIM: usize,
36        const NDIM: usize,
37    > ArraySlice<Item, ArrayImpl, ADIM, NDIM>
38where
39    NumberType<ADIM>: IsGreaterByOne<NDIM>,
40    NumberType<NDIM>: IsGreaterZero,
41{
42    /// Create new array slice
43    pub fn new(arr: Array<Item, ArrayImpl, ADIM>, slice: [usize; 2]) -> Self {
44        // The mask is zero for all entries before the sliced out one and
45        // one for all entries after.
46        let mut mask = [1; NDIM];
47        assert!(
48            slice[0] < ADIM,
49            "Axis {} out of bounds. Array has {} axes.",
50            slice[0],
51            ADIM
52        );
53        assert!(
54            slice[1] < arr.shape()[slice[0]],
55            "Index {} in axis {} out of bounds. Dimension of axis is {}.",
56            slice[1],
57            slice[0],
58            arr.shape()[slice[0]]
59        );
60        mask.iter_mut().take(slice[0]).for_each(|val| *val = 0);
61        Self { arr, slice, mask }
62    }
63}
64
65impl<
66        Item: RlstBase,
67        ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item> + Shape<ADIM>,
68        const ADIM: usize,
69        const NDIM: usize,
70    > UnsafeRandomAccessByValue<NDIM> for ArraySlice<Item, ArrayImpl, ADIM, NDIM>
71where
72    NumberType<ADIM>: IsGreaterByOne<NDIM>,
73    NumberType<NDIM>: IsGreaterZero,
74{
75    type Item = Item;
76
77    unsafe fn get_value_unchecked(&self, multi_index: [usize; NDIM]) -> Self::Item {
78        let mut orig_index = multi_index_to_orig(multi_index, self.mask);
79        orig_index[self.slice[0]] = self.slice[1];
80        self.arr.get_value_unchecked(orig_index)
81    }
82}
83
84impl<
85        Item: RlstBase,
86        ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item>
87            + Shape<ADIM>
88            + UnsafeRandomAccessByRef<ADIM, Item = Item>,
89        const ADIM: usize,
90        const NDIM: usize,
91    > UnsafeRandomAccessByRef<NDIM> for ArraySlice<Item, ArrayImpl, ADIM, NDIM>
92where
93    NumberType<ADIM>: IsGreaterByOne<NDIM>,
94    NumberType<NDIM>: IsGreaterZero,
95{
96    type Item = Item;
97
98    unsafe fn get_unchecked(&self, multi_index: [usize; NDIM]) -> &Self::Item {
99        let mut orig_index = multi_index_to_orig(multi_index, self.mask);
100        orig_index[self.slice[0]] = self.slice[1];
101        self.arr.get_unchecked(orig_index)
102    }
103}
104
105impl<
106        Item: RlstBase,
107        ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item> + Shape<ADIM>,
108        const ADIM: usize,
109        const NDIM: usize,
110    > Shape<NDIM> for ArraySlice<Item, ArrayImpl, ADIM, NDIM>
111where
112    NumberType<ADIM>: IsGreaterByOne<NDIM>,
113    NumberType<NDIM>: IsGreaterZero,
114{
115    fn shape(&self) -> [usize; NDIM] {
116        let mut result = [0; NDIM];
117        let orig_shape = self.arr.shape();
118
119        for (index, value) in result.iter_mut().enumerate() {
120            *value = orig_shape[index + self.mask[index]];
121        }
122
123        result
124    }
125}
126
127impl<
128        Item: RlstBase,
129        ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item>
130            + Shape<ADIM>
131            + RawAccess<Item = Item>
132            + Stride<ADIM>,
133        const ADIM: usize,
134        const NDIM: usize,
135    > RawAccess for ArraySlice<Item, ArrayImpl, ADIM, NDIM>
136where
137    NumberType<ADIM>: IsGreaterByOne<NDIM>,
138    NumberType<NDIM>: IsGreaterZero,
139{
140    type Item = Item;
141    fn data(&self) -> &[Self::Item] {
142        assert!(!self.is_empty());
143        let (start_raw, end_raw) =
144            compute_raw_range(self.slice, self.arr.stride(), self.arr.shape());
145
146        &self.arr.data()[start_raw..end_raw]
147    }
148
149    fn buff_ptr(&self) -> *const Self::Item {
150        self.arr.buff_ptr()
151    }
152
153    fn offset(&self) -> usize {
154        let mut orig_index = [0; ADIM];
155        orig_index[self.slice[0]] = self.slice[1];
156        self.arr.offset() + convert_nd_raw(orig_index, self.arr.stride())
157    }
158}
159
160impl<
161        Item: RlstBase,
162        ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item> + Shape<ADIM> + Stride<ADIM>,
163        const ADIM: usize,
164        const NDIM: usize,
165    > Stride<NDIM> for ArraySlice<Item, ArrayImpl, ADIM, NDIM>
166where
167    NumberType<ADIM>: IsGreaterByOne<NDIM>,
168    NumberType<NDIM>: IsGreaterZero,
169{
170    fn stride(&self) -> [usize; NDIM] {
171        let mut result = [0; NDIM];
172        let orig_stride: [usize; ADIM] = self.arr.stride();
173
174        for (index, value) in result.iter_mut().enumerate() {
175            *value = orig_stride[index + self.mask[index]];
176        }
177
178        result
179    }
180}
181
182impl<
183        Item: RlstBase,
184        ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item> + Shape<ADIM>,
185        const ADIM: usize,
186    > Array<Item, ArrayImpl, ADIM>
187{
188    /// Create a slice from a given array.
189    ///
190    /// Consider an array `arr` with shape `[a0, a1, a2, a3, ...]`. The function call
191    /// `arr.slice(2, 3)` returns a one dimension smaller array indexed by `[a0, a1, 3, a3, ...]`.
192    /// Hence, the dimension `2` has been fixed to always have the value `3.`
193    ///
194    /// # Examples
195    ///
196    /// If `arr` is a matrix then the first column of the matrix is obtained from
197    /// `arr.slice(1, 0)`, while the third row of the matrix is obtained from
198    /// `arr.slice(0, 2)`.
199    pub fn slice<const NDIM: usize>(
200        self,
201        axis: usize,
202        index: usize,
203    ) -> Array<Item, ArraySlice<Item, ArrayImpl, ADIM, NDIM>, NDIM>
204    where
205        NumberType<ADIM>: IsGreaterByOne<NDIM>,
206        NumberType<NDIM>: IsGreaterZero,
207    {
208        Array::new(ArraySlice::new(self, [axis, index]))
209    }
210}
211
212impl<
213        Item: RlstBase,
214        ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item> + Shape<ADIM> + Stride<ADIM>,
215        const ADIM: usize,
216        const NDIM: usize,
217        const N: usize,
218    > ChunkedAccess<N> for ArraySlice<Item, ArrayImpl, ADIM, NDIM>
219where
220    NumberType<ADIM>: IsGreaterByOne<NDIM>,
221    NumberType<NDIM>: IsGreaterZero,
222{
223    type Item = Item;
224
225    #[inline]
226    fn get_chunk(
227        &self,
228        chunk_index: usize,
229    ) -> Option<crate::dense::types::DataChunk<Self::Item, N>> {
230        let nelements = self.shape().iter().product();
231        if let Some(mut chunk) = empty_chunk(chunk_index, nelements) {
232            for count in 0..chunk.valid_entries {
233                unsafe {
234                    chunk.data[count] = self.get_value_unchecked(convert_1d_nd_from_shape(
235                        chunk.start_index + count,
236                        self.shape(),
237                    ))
238                }
239            }
240            Some(chunk)
241        } else {
242            None
243        }
244    }
245}
246
247impl<
248        Item: RlstBase,
249        ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item>
250            + Shape<ADIM>
251            + UnsafeRandomAccessMut<ADIM, Item = Item>,
252        const ADIM: usize,
253        const NDIM: usize,
254    > UnsafeRandomAccessMut<NDIM> for ArraySlice<Item, ArrayImpl, ADIM, NDIM>
255where
256    NumberType<ADIM>: IsGreaterByOne<NDIM>,
257    NumberType<NDIM>: IsGreaterZero,
258{
259    type Item = Item;
260
261    unsafe fn get_unchecked_mut(&mut self, multi_index: [usize; NDIM]) -> &mut Self::Item {
262        let mut orig_index = multi_index_to_orig(multi_index, self.mask);
263        orig_index[self.slice[0]] = self.slice[1];
264        self.arr.get_unchecked_mut(orig_index)
265    }
266}
267
268impl<
269        Item: RlstBase,
270        ArrayImpl: UnsafeRandomAccessByValue<ADIM, Item = Item>
271            + Shape<ADIM>
272            + RawAccessMut<Item = Item>
273            + Stride<ADIM>
274            + UnsafeRandomAccessMut<ADIM, Item = Item>,
275        const ADIM: usize,
276        const NDIM: usize,
277    > RawAccessMut for ArraySlice<Item, ArrayImpl, ADIM, NDIM>
278where
279    NumberType<ADIM>: IsGreaterByOne<NDIM>,
280    NumberType<NDIM>: IsGreaterZero,
281{
282    fn data_mut(&mut self) -> &mut [Self::Item] {
283        assert!(!self.is_empty());
284        let (start_raw, end_raw) =
285            compute_raw_range(self.slice, self.arr.stride(), self.arr.shape());
286        &mut self.arr.data_mut()[start_raw..end_raw]
287    }
288
289    fn buff_ptr_mut(&mut self) -> *mut Self::Item {
290        self.arr.buff_ptr_mut()
291    }
292}
293
294// ////////////////////
295
296fn multi_index_to_orig<const ADIM: usize, const NDIM: usize>(
297    multi_index: [usize; NDIM],
298    mask: [usize; NDIM],
299) -> [usize; ADIM]
300where
301    NumberType<ADIM>: IsGreaterByOne<NDIM>,
302    NumberType<NDIM>: IsGreaterZero,
303{
304    let mut orig = [0; ADIM];
305    for (index, &value) in multi_index.iter().enumerate() {
306        orig[index + mask[index]] = value;
307    }
308    orig
309}
310
311fn compute_raw_range<const NDIM: usize>(
312    slice: [usize; 2],
313    stride: [usize; NDIM],
314    shape: [usize; NDIM],
315) -> (usize, usize) {
316    let mut start_multi_index = [0; NDIM];
317    start_multi_index[slice[0]] = slice[1];
318    let mut end_multi_index = shape;
319    for (index, value) in end_multi_index.iter_mut().enumerate() {
320        if index == slice[0] {
321            *value = slice[1]
322        } else {
323            // We started with the shape. Reduce
324            // each value of the shape by 1 to get last
325            // index in that dimension.
326            assert!(*value > 0);
327            *value -= 1;
328        }
329    }
330    (
331        convert_nd_raw(start_multi_index, stride),
332        1 + convert_nd_raw(end_multi_index, stride),
333    )
334}