// cubecl_core/frontend/container/slice.rs

1use std::marker::PhantomData;
2
3use crate::{
4    frontend::{indexation::Index, Tensor},
5    ir::{self, Operator},
6    prelude::{CubeContext, IntoRuntime},
7    unexpanded,
8};
9use crate::{
10    frontend::{
11        Array, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, SharedMemory,
12        SizedContainer,
13    },
14    ir::Instruction,
15};
16
17use super::Line;
18
/// An immutable view over a contiguous run of elements.
///
/// # Safety
///
/// Data can't be deallocated while a kernel is executing, so holding this view is safe.
#[derive(Clone)]
pub struct Slice<T> {
    // Zero-sized marker tying the view to its element type; no data is stored here.
    _e: PhantomData<T>,
}
28
/// A mutable (read-write) view over a contiguous run of elements.
///
/// # Safety
///
/// Any unit may access the underlying data during kernel execution, so this view can never be
/// considered safe.
pub struct SliceMut<T> {
    // Zero-sized marker tying the view to its element type; no data is stored here.
    _e: PhantomData<T>,
}
37
38mod metadata {
39    use super::*;
40
41    impl<E> Slice<E> {
42        /// Get the length of the slice.
43        #[allow(clippy::len_without_is_empty)]
44        pub fn len(&self) -> u32 {
45            unexpanded!()
46        }
47
48        /// Returns the same slice, but with lines of length 1.
49        pub fn to_aligned(&self) -> Slice<Line<E>>
50        where
51            E: CubePrimitive,
52        {
53            unexpanded!()
54        }
55    }
56
57    impl<E> SliceMut<E> {
58        /// Get the length of the slice.
59        #[allow(clippy::len_without_is_empty)]
60        pub fn len(&self) -> u32 {
61            unexpanded!()
62        }
63
64        /// Returns the same slice, but with lines of length 1.
65        pub fn into_aligned(self) -> SliceMut<Line<E>>
66        where
67            E: CubePrimitive,
68        {
69            unexpanded!()
70        }
71    }
72
73    impl<C: CubeType> ExpandElementTyped<Slice<C>> {
74        // Expand method of [len](Slice::len).
75        pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped<u32> {
76            let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
77            elem.__expand_len_method(context)
78        }
79
80        // Expand method of [len](Slice::to_aligned).
81        pub fn __expand_to_aligned_method(
82            self,
83            _context: &mut CubeContext,
84        ) -> ExpandElementTyped<Slice<Line<C>>>
85        where
86            C: CubePrimitive,
87        {
88            self.expand.into()
89        }
90
91        // Expand method of [clone](Clone::clone).
92        pub fn __expand_clone_method(
93            self,
94            _context: &mut CubeContext,
95        ) -> ExpandElementTyped<Slice<Line<C>>>
96        where
97            C: CubePrimitive,
98        {
99            self.expand.into()
100        }
101    }
102
103    impl<C: CubeType> ExpandElementTyped<SliceMut<C>> {
104        // Expand method of [len](SliceMut::len).
105        pub fn __expand_len_method(self, context: &mut CubeContext) -> ExpandElementTyped<u32> {
106            let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
107            elem.__expand_len_method(context)
108        }
109
110        // Expand method of [len](SliceMut::into_aligned).
111        pub fn __expand_into_aligned_method(
112            self,
113            _context: &mut CubeContext,
114        ) -> ExpandElementTyped<SliceMut<Line<C>>>
115        where
116            C: CubePrimitive,
117        {
118            self.expand.into()
119        }
120    }
121}
122
123/// Module that contains the implementation details of the index functions.
124mod indexation {
125    use ir::Instruction;
126
127    use crate::{
128        ir::{BinaryOperator, Operator},
129        prelude::{CubeIndex, CubeIndexMut},
130    };
131
132    use super::*;
133
134    impl<E: CubePrimitive> Slice<E> {
135        /// Perform an unchecked index into the array
136        ///
137        /// # Safety
138        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
139        /// always in bounds
140        pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
141        where
142            Self: CubeIndex<I>,
143        {
144            unexpanded!()
145        }
146    }
147
148    impl<E: CubePrimitive> SliceMut<E> {
149        /// Perform an unchecked index into the array
150        ///
151        /// # Safety
152        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
153        /// always in bounds
154        pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
155        where
156            Self: CubeIndex<I>,
157        {
158            unexpanded!()
159        }
160
161        /// Perform an unchecked index assignment into the array
162        ///
163        /// # Safety
164        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
165        /// always in bounds
166        pub unsafe fn index_assign_unchecked<I: Index>(&mut self, _i: I, _value: E)
167        where
168            Self: CubeIndexMut<I>,
169        {
170            unexpanded!()
171        }
172    }
173
174    impl<E: CubePrimitive> ExpandElementTyped<Slice<E>> {
175        pub fn __expand_index_unchecked_method(
176            self,
177            context: &mut CubeContext,
178            i: ExpandElementTyped<u32>,
179        ) -> ExpandElementTyped<E> {
180            let out = context.create_local(self.expand.item);
181            context.register(Instruction::new(
182                Operator::UncheckedIndex(BinaryOperator {
183                    lhs: *self.expand,
184                    rhs: i.expand.consume(),
185                }),
186                *out,
187            ));
188            out.into()
189        }
190    }
191
192    impl<E: CubePrimitive> ExpandElementTyped<SliceMut<E>> {
193        pub fn __expand_index_unchecked_method(
194            self,
195            context: &mut CubeContext,
196            i: ExpandElementTyped<u32>,
197        ) -> ExpandElementTyped<E> {
198            let out = context.create_local(self.expand.item);
199            context.register(Instruction::new(
200                Operator::UncheckedIndex(BinaryOperator {
201                    lhs: *self.expand,
202                    rhs: i.expand.consume(),
203                }),
204                *out,
205            ));
206            out.into()
207        }
208
209        pub fn __expand_index_assign_unchecked_method(
210            self,
211            context: &mut CubeContext,
212            i: ExpandElementTyped<u32>,
213            value: ExpandElementTyped<E>,
214        ) {
215            context.register(Instruction::new(
216                Operator::UncheckedIndexAssign(BinaryOperator {
217                    lhs: i.expand.consume(),
218                    rhs: value.expand.consume(),
219                }),
220                *self.expand,
221            ));
222        }
223    }
224}
225
226impl<E: CubeType> CubeType for Slice<E> {
227    type ExpandType = ExpandElementTyped<Slice<E>>;
228}
229
230impl<C: CubeType> Init for ExpandElementTyped<Slice<C>> {
231    fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
232        // The type can't be deeply cloned/copied.
233        self
234    }
235}
236
237impl<E: CubeType> CubeType for SliceMut<E> {
238    type ExpandType = ExpandElementTyped<SliceMut<E>>;
239}
240
241impl<E: CubeType> CubeType for &mut SliceMut<E> {
242    type ExpandType = ExpandElementTyped<SliceMut<E>>;
243}
244
245impl<C: CubeType> Init for ExpandElementTyped<SliceMut<C>> {
246    fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
247        // The type can't be deeply cloned/copied.
248        self
249    }
250}
251
252impl<C: CubeType<ExpandType = ExpandElementTyped<C>>> SizedContainer for Slice<C> {
253    type Item = C;
254}
255
256impl<T: CubeType> Iterator for Slice<T> {
257    type Item = T;
258
259    fn next(&mut self) -> Option<Self::Item> {
260        unexpanded!()
261    }
262}
263
264pub trait SliceOperator<E: CubeType>: CubeType<ExpandType = Self::Expand> {
265    type Expand: SliceOperatorExpand<E>;
266
267    /// Return a read-only view of all elements comprise between the `start` and `end` indices.
268    /// In `checked` mode, if the `end` index is out-of-bound, it is replaced by
269    /// the length of `self`.
270    #[allow(unused_variables)]
271    fn slice<Start: Index, End: Index>(&self, start: Start, end: End) -> Slice<E> {
272        unexpanded!()
273    }
274    /// Expand function of [SliceOperator::slice].
275    fn __expand_slice(
276        context: &mut CubeContext,
277        expand: Self::Expand,
278        start: ExpandElementTyped<u32>,
279        end: ExpandElementTyped<u32>,
280    ) -> ExpandElementTyped<Slice<E>> {
281        expand.__expand_slice_method(context, start, end)
282    }
283
284    /// Return a read-write view of all elements comprise between the `start` and `end` indices.
285    /// In `checked` mode, if the `end` index is out-of-bound, it is replaced by
286    /// the length of `self`.
287    #[allow(unused_variables)]
288    fn slice_mut<Start: Index, End: Index>(&mut self, start: Start, end: End) -> SliceMut<E> {
289        unexpanded!()
290    }
291
292    /// Expand function of [SliceOperator::slice_mut].
293    fn __expand_slice_mut(
294        context: &mut CubeContext,
295        expand: Self::Expand,
296        start: ExpandElementTyped<u32>,
297        end: ExpandElementTyped<u32>,
298    ) -> ExpandElementTyped<SliceMut<E>> {
299        expand.__expand_slice_mut_method(context, start, end)
300    }
301
302    /// Reinterprete the current type as a read-only slice.
303    #[allow(unused_variables)]
304    fn to_slice(&self) -> Slice<E> {
305        unexpanded!()
306    }
307
308    /// Expand function of [SliceOperator::to_slice].
309    fn __expand_to_slice(
310        context: &mut CubeContext,
311        expand: Self::Expand,
312    ) -> ExpandElementTyped<Slice<E>> {
313        expand.__expand_to_slice_method(context)
314    }
315
316    /// Reinterprete the current type as a read-write slice.
317    #[allow(unused_variables)]
318    fn to_slice_mut(&mut self) -> SliceMut<E> {
319        unexpanded!()
320    }
321
322    /// Expand function of [SliceOperator::to_slice_mut].
323    fn __expand_to_slice_mut(
324        context: &mut CubeContext,
325        expand: Self::Expand,
326    ) -> ExpandElementTyped<SliceMut<E>> {
327        expand.__expand_to_slice_mut_method(context)
328    }
329}
330
331pub trait SliceOperatorExpand<E: CubeType>: Into<ExpandElement> + Clone {
332    fn slice_base<Start: Index, End: Index>(
333        &self,
334        context: &mut CubeContext,
335        start: Start,
336        end: End,
337    ) -> ExpandElement;
338
339    fn __expand_slice_method(
340        &self,
341        context: &mut CubeContext,
342        start: ExpandElementTyped<u32>,
343        end: ExpandElementTyped<u32>,
344    ) -> ExpandElementTyped<Slice<E>> {
345        ExpandElementTyped::new(self.slice_base(context, start, end))
346    }
347
348    fn __expand_slice_mut_method(
349        &self,
350        context: &mut CubeContext,
351        start: ExpandElementTyped<u32>,
352        end: ExpandElementTyped<u32>,
353    ) -> ExpandElementTyped<SliceMut<E>> {
354        ExpandElementTyped::new(self.slice_base(context, start, end))
355    }
356
357    fn __expand_to_slice_method(&self, _context: &mut CubeContext) -> ExpandElementTyped<Slice<E>> {
358        let expand = self.clone().into();
359        ExpandElementTyped::new(expand)
360    }
361
362    fn __expand_to_slice_mut_method(
363        &self,
364        _context: &mut CubeContext,
365    ) -> ExpandElementTyped<SliceMut<E>> {
366        let expand = self.clone().into();
367        ExpandElementTyped::new(expand)
368    }
369}
370
371macro_rules! slice_op {
372    ($type:ident) => {
373        impl<E: CubePrimitive> SliceOperator<E> for $type<E> {
374            type Expand = ExpandElementTyped<$type<E>>;
375        }
376
377        impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<$type<E>> {
378            fn slice_base<Start: Index, End: Index>(
379                &self,
380                context: &mut CubeContext,
381                start: Start,
382                end: End,
383            ) -> ExpandElement {
384                slice_expand(context, self.clone(), start, end)
385            }
386        }
387    };
388    (slice $type:ident) => {
389        impl<E: CubePrimitive> SliceOperator<E> for $type<E> {
390            type Expand = ExpandElementTyped<$type<E>>;
391        }
392
393        impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<$type<E>> {
394            fn slice_base<Start: Index, End: Index>(
395                &self,
396                context: &mut CubeContext,
397                start: Start,
398                end: End,
399            ) -> ExpandElement {
400                slice_expand(context, self.clone(), start, end)
401            }
402        }
403    };
404}
405
406slice_op!(Array);
407slice_op!(Tensor);
408slice_op!(SharedMemory);
409slice_op!(slice Slice);
410slice_op!(slice SliceMut);
411
412pub fn slice_expand<I: Into<ExpandElement>, S1: Index, S2: Index>(
413    context: &mut CubeContext,
414    input: I,
415    start: S1,
416    end: S2, // Todo use it to get the length.
417) -> ExpandElement {
418    let input = input.into();
419    let out = context.create_slice(input.item);
420
421    context.register(Instruction::new(
422        Operator::Slice(ir::SliceOperator {
423            input: *input,
424            start: start.value(),
425            end: end.value(),
426        }),
427        *out,
428    ));
429
430    out
431}
432
433impl<E: CubePrimitive> IntoRuntime for Slice<E> {
434    fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
435        unimplemented!("Array can't exist at compile time")
436    }
437}
438
439impl<E: CubePrimitive> IntoRuntime for SliceMut<E> {
440    fn __expand_runtime_method(self, _context: &mut CubeContext) -> Self::ExpandType {
441        unimplemented!("Array can't exist at compile time")
442    }
443}