cubecl_core/frontend/container/
slice.rs

1use super::Line;
2use crate::{
3    frontend::{
4        Array, CubePrimitive, CubeType, ExpandElementTyped, Init, SharedMemory, SizedContainer,
5        Tensor, indexation::Index,
6    },
7    ir::{Instruction, Scope},
8    prelude::{CubeDebug, List, ListExpand, ListMut, ListMutExpand, index, index_assign},
9    unexpanded,
10};
11use cubecl_common::tf32;
12use cubecl_ir::{ExpandElement, Operator};
13use std::marker::PhantomData;
14
/// An immutable view over a contiguous list of elements.
///
/// The struct is a zero-sized marker: its only field is a `PhantomData` recording the
/// element type `E`.
///
/// # Safety
///
/// Data can't be deallocated while a kernel is executing, so a read-only view is safe.
#[derive(Clone, Copy)]
pub struct Slice<E> {
    _e: PhantomData<E>,
}
24
/// A mutable (read-write) view over a contiguous list of elements.
///
/// The struct is a zero-sized marker: its only field is a `PhantomData` recording the
/// element type `E`.
///
/// # Safety
///
/// Any unit may access the underlying data during kernel execution, so this view can never
/// be considered safe.
#[derive(Clone, Copy)]
pub struct SliceMut<E> {
    _e: PhantomData<E>,
}
34
35#[allow(unused)]
36mod metadata {
37    use core::num::NonZero;
38
39    use cubecl_ir::{Elem, FloatKind, Item, NonSemantic};
40
41    use crate::prelude::cube_comment;
42
43    use super::*;
44
45    impl<E> Slice<E> {
46        /// Get the length of the slice.
47        #[allow(clippy::len_without_is_empty)]
48        pub fn len(&self) -> u32 {
49            unexpanded!()
50        }
51
52        /// Returns the same slice, but with lines of length 1.
53        pub fn into_lined(&self) -> Slice<Line<E>>
54        where
55            E: CubePrimitive,
56        {
57            unexpanded!()
58        }
59        /// Try to cast the slice to the given type and panic if the type isn't the same.
60        ///
61        /// This function should only be used to satisfy the Rust type system, when two generic
62        /// types are supposed to be the same.
63        pub fn try_cast_unchecked<T>(&self) -> Slice<T>
64        where
65            E: CubePrimitive,
66            T: CubePrimitive,
67        {
68            unexpanded!()
69        }
70    }
71
72    impl<E: CubePrimitive> Slice<Line<E>> {
73        /// Return a new Slice with updated line_size. This doesn't copy or move the data,
74        /// it simply reinterpret how they are loaded and stored in memory.
75        ///
76        /// # Warning
77        ///
78        /// Currently, this only work with `cube(launch_unchecked)` and is not supported on wgpu.
79        pub fn with_line_size(&self, line_size: u32) -> Slice<Line<E>> {
80            unexpanded!()
81        }
82    }
83
84    impl<E> SliceMut<E> {
85        /// Get the length of the slice.
86        #[allow(clippy::len_without_is_empty)]
87        pub fn len(&self) -> u32 {
88            unexpanded!()
89        }
90
91        /// Returns the same slice, but with lines of length 1.
92        pub fn into_lined(self) -> SliceMut<Line<E>>
93        where
94            E: CubePrimitive,
95        {
96            unexpanded!()
97        }
98
99        /// Try to cast the slice to the given type and panic if the type isn't the same.
100        ///
101        /// This function should only be used to satisfy the Rust type system, when two generic
102        /// types are supposed to be the same.
103        pub fn try_cast_unchecked<T>(&self) -> SliceMut<T>
104        where
105            E: CubePrimitive,
106            T: CubePrimitive,
107        {
108            unexpanded!()
109        }
110    }
111
112    impl<E: CubePrimitive> SliceMut<Line<E>> {
113        /// Return a new SliceMut with updated line_size. This doesn't copy or move the data,
114        /// it simply reinterpret how they are loaded and stored in memory.
115        ///
116        /// # Warning
117        ///
118        /// Currently, this only work with `cube(launch_unchecked)` and is not supported on wgpu.
119        pub fn with_line_size(&self, line_size: u32) -> SliceMut<Line<E>> {
120            unexpanded!()
121        }
122    }
123
124    impl<E: CubePrimitive> ExpandElementTyped<Slice<E>> {
125        // Expand method of [len](Slice::len).
126        pub fn __expand_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
127            let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
128            elem.__expand_len_method(scope)
129        }
130
131        /// Expand method of [len](Slice::into_lined).
132        pub fn __expand_into_lined_method(
133            self,
134            _scope: &mut Scope,
135        ) -> ExpandElementTyped<Slice<Line<E>>>
136        where
137            E: CubePrimitive,
138        {
139            self.expand.into()
140        }
141
142        /// Expand method of [try_cast_unchecked](Slice::try_cast_unchecked).
143        pub fn __expand_try_cast_unchecked_method<T>(
144            self,
145            scope: &mut Scope,
146        ) -> ExpandElementTyped<Slice<T>>
147        where
148            E: CubePrimitive,
149            T: CubePrimitive,
150        {
151            if T::as_elem(scope) != E::as_elem(scope) && !is_tf32::<E, T>(scope) {
152                let elems = [T::as_elem(scope), E::as_elem(scope)];
153                let is_flex32_cast = elems.contains(&Elem::Float(FloatKind::F32))
154                    && elems.contains(&Elem::Float(FloatKind::Flex32));
155
156                if !is_flex32_cast {
157                    panic!(
158                        "Try cast unchecked should only be used to satisfy the rust type system."
159                    )
160                }
161            }
162
163            self.expand.into()
164        }
165
166        pub fn __expand_clone_method(self, _scope: &mut Scope) -> ExpandElementTyped<Slice<Line<E>>>
167        where
168            E: CubePrimitive,
169        {
170            self.expand.into()
171        }
172    }
173
174    impl<E: CubePrimitive> ExpandElementTyped<Slice<Line<E>>> {
175        /// Expand method of [with_line_size](Slice::with_line_size).
176        pub fn __expand_with_line_size_method(
177            self,
178            scope: &mut Scope,
179            line_size: u32,
180        ) -> ExpandElementTyped<Slice<Line<E>>>
181        where
182            E: CubePrimitive,
183        {
184            let input = self.clone().into_variable();
185            let mut item = input.item;
186
187            if line_size as u8 == item.vectorization.unwrap_or(NonZero::new(1).unwrap()).get() {
188                return self;
189            }
190
191            item.vectorization = NonZero::new(line_size as u8);
192            let out = scope.create_slice(item);
193
194            scope.register(Instruction::new(
195                Operator::ReinterpretSlice(cubecl_ir::ReinterpretSliceOperator {
196                    input,
197                    line_size,
198                }),
199                *out,
200            ));
201
202            out.into()
203        }
204    }
205
206    impl<E: CubePrimitive> ExpandElementTyped<SliceMut<E>> {
207        // Expand method of [len](SliceMut::len).
208        pub fn __expand_len_method(self, scope: &mut Scope) -> ExpandElementTyped<u32> {
209            let elem: ExpandElementTyped<Array<u32>> = self.expand.into();
210            elem.__expand_len_method(scope)
211        }
212
213        /// Expand method of [len](SliceMut::into_lined).
214        pub fn __expand_into_lined_method(
215            self,
216            _scope: &mut Scope,
217        ) -> ExpandElementTyped<SliceMut<Line<E>>>
218        where
219            E: CubePrimitive,
220        {
221            self.expand.into()
222        }
223
224        /// Expand method of [try_cast_unchecked](Slice::try_cast_unchecked).
225        pub fn __expand_try_cast_unchecked_method<T>(
226            self,
227            scope: &mut Scope,
228        ) -> ExpandElementTyped<SliceMut<T>>
229        where
230            E: CubePrimitive,
231            T: CubePrimitive,
232        {
233            if T::as_elem(scope) != E::as_elem(scope) && !is_tf32::<E, T>(scope) {
234                panic!("Try cast unchecked should only be used to satisfy the rust type system.")
235            }
236
237            self.expand.into()
238        }
239    }
240
241    impl<E: CubePrimitive> ExpandElementTyped<SliceMut<Line<E>>> {
242        /// Expand method of [with_line_size](SliceMut::with_line_size).
243        pub fn __expand_with_line_size_method(
244            self,
245            scope: &mut Scope,
246            line_size: u32,
247        ) -> ExpandElementTyped<SliceMut<Line<E>>>
248        where
249            E: CubePrimitive,
250        {
251            let input = self.clone().into_variable();
252            let mut item = input.item;
253
254            if line_size as u8 == item.vectorization.unwrap_or(NonZero::new(1).unwrap()).get() {
255                return self;
256            }
257
258            item.vectorization = NonZero::new(line_size as u8);
259            let out = scope.create_slice(item);
260
261            scope.register(Instruction::new(
262                Operator::ReinterpretSlice(cubecl_ir::ReinterpretSliceOperator {
263                    input,
264                    line_size,
265                }),
266                *out,
267            ));
268            out.into()
269        }
270    }
271}
272
273pub(crate) fn is_tf32<C: CubePrimitive, T: CubePrimitive>(scope: &mut Scope) -> bool {
274    let ty_c = C::as_elem(scope);
275    let ty_t = T::as_elem(scope);
276    let ty_f32 = f32::as_elem(scope);
277    let ty_tf32 = tf32::as_elem(scope);
278
279    (ty_c == ty_f32 && ty_t == ty_tf32) || (ty_c == ty_tf32 && ty_t == ty_f32)
280}
281
282/// Module that contains the implementation details of the index functions.
283mod indexation {
284    use cubecl_ir::{BinaryOperator, Instruction, Operator};
285
286    use crate::prelude::{CubeIndex, CubeIndexMut};
287
288    use super::*;
289
290    impl<E: CubePrimitive> Slice<E> {
291        /// Perform an unchecked index into the array
292        ///
293        /// # Safety
294        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
295        /// always in bounds
296        pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
297        where
298            Self: CubeIndex<I>,
299        {
300            unexpanded!()
301        }
302    }
303
304    impl<E: CubePrimitive> SliceMut<E> {
305        /// Perform an unchecked index into the array
306        ///
307        /// # Safety
308        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
309        /// always in bounds
310        pub unsafe fn index_unchecked<I: Index>(&self, _i: I) -> &E
311        where
312            Self: CubeIndex<I>,
313        {
314            unexpanded!()
315        }
316
317        /// Perform an unchecked index assignment into the array
318        ///
319        /// # Safety
320        /// Out of bounds indexing causes undefined behaviour and may segfault. Ensure index is
321        /// always in bounds
322        pub unsafe fn index_assign_unchecked<I: Index>(&mut self, _i: I, _value: E)
323        where
324            Self: CubeIndexMut<I>,
325        {
326            unexpanded!()
327        }
328    }
329
330    impl<E: CubePrimitive> ExpandElementTyped<Slice<E>> {
331        pub fn __expand_index_unchecked_method(
332            self,
333            scope: &mut Scope,
334            i: ExpandElementTyped<u32>,
335        ) -> ExpandElementTyped<E> {
336            let out = scope.create_local(self.expand.item);
337            scope.register(Instruction::new(
338                Operator::UncheckedIndex(BinaryOperator {
339                    lhs: *self.expand,
340                    rhs: i.expand.consume(),
341                }),
342                *out,
343            ));
344            out.into()
345        }
346    }
347
348    impl<E: CubePrimitive> ExpandElementTyped<SliceMut<E>> {
349        pub fn __expand_index_unchecked_method(
350            self,
351            scope: &mut Scope,
352            i: ExpandElementTyped<u32>,
353        ) -> ExpandElementTyped<E> {
354            let out = scope.create_local(self.expand.item);
355            scope.register(Instruction::new(
356                Operator::UncheckedIndex(BinaryOperator {
357                    lhs: *self.expand,
358                    rhs: i.expand.consume(),
359                }),
360                *out,
361            ));
362            out.into()
363        }
364
365        pub fn __expand_index_assign_unchecked_method(
366            self,
367            scope: &mut Scope,
368            i: ExpandElementTyped<u32>,
369            value: ExpandElementTyped<E>,
370        ) {
371            scope.register(Instruction::new(
372                Operator::UncheckedIndexAssign(BinaryOperator {
373                    lhs: i.expand.consume(),
374                    rhs: value.expand.consume(),
375                }),
376                *self.expand,
377            ));
378        }
379    }
380}
381
382impl<E: CubeType> CubeType for Slice<E> {
383    type ExpandType = ExpandElementTyped<Slice<E>>;
384}
385
386impl<C: CubeType> Init for ExpandElementTyped<Slice<C>> {
387    fn init(self, _scope: &mut Scope) -> Self {
388        // The type can't be deeply cloned/copied.
389        self
390    }
391}
392
393impl<E: CubeType> CubeType for SliceMut<E> {
394    type ExpandType = ExpandElementTyped<SliceMut<E>>;
395}
396
397impl<E: CubeType> CubeType for &mut SliceMut<E> {
398    type ExpandType = ExpandElementTyped<SliceMut<E>>;
399}
400
401impl<C: CubeType> Init for ExpandElementTyped<SliceMut<C>> {
402    fn init(self, _scope: &mut Scope) -> Self {
403        // The type can't be deeply cloned/copied.
404        self
405    }
406}
407
408impl<C: CubeType<ExpandType = ExpandElementTyped<C>>> SizedContainer for Slice<C> {
409    type Item = C;
410}
411
412impl<T: CubeType> Iterator for Slice<T> {
413    type Item = T;
414
415    fn next(&mut self) -> Option<Self::Item> {
416        unexpanded!()
417    }
418}
419
420pub trait SliceOperator<E: CubeType>: CubeType<ExpandType = Self::Expand> {
421    type Expand: SliceOperatorExpand<E>;
422
423    /// Return a read-only view of all elements comprise between the `start` and `end` indices.
424    /// In `checked` mode, if the `end` index is out-of-bound, it is replaced by
425    /// the length of `self`.
426    #[allow(unused_variables)]
427    fn slice<Start: Index, End: Index>(&self, start: Start, end: End) -> Slice<E> {
428        unexpanded!()
429    }
430    /// Expand function of [SliceOperator::slice].
431    fn __expand_slice(
432        scope: &mut Scope,
433        expand: Self::Expand,
434        start: ExpandElementTyped<u32>,
435        end: ExpandElementTyped<u32>,
436    ) -> ExpandElementTyped<Slice<E>> {
437        expand.__expand_slice_method(scope, start, end)
438    }
439
440    /// Return a read-write view of all elements comprise between the `start` and `end` indices.
441    /// In `checked` mode, if the `end` index is out-of-bound, it is replaced by
442    /// the length of `self`.
443    #[allow(unused_variables)]
444    fn slice_mut<Start: Index, End: Index>(&mut self, start: Start, end: End) -> SliceMut<E> {
445        unexpanded!()
446    }
447
448    /// Expand function of [SliceOperator::slice_mut].
449    fn __expand_slice_mut(
450        scope: &mut Scope,
451        expand: Self::Expand,
452        start: ExpandElementTyped<u32>,
453        end: ExpandElementTyped<u32>,
454    ) -> ExpandElementTyped<SliceMut<E>> {
455        expand.__expand_slice_mut_method(scope, start, end)
456    }
457
458    /// Reinterprete the current type as a read-only slice.
459    #[allow(unused_variables)]
460    fn to_slice(&self) -> Slice<E> {
461        unexpanded!()
462    }
463
464    /// Expand function of [SliceOperator::to_slice].
465    fn __expand_to_slice(scope: &mut Scope, expand: Self::Expand) -> ExpandElementTyped<Slice<E>> {
466        expand.__expand_to_slice_method(scope)
467    }
468
469    /// Reinterprete the current type as a read-write slice.
470    #[allow(unused_variables, clippy::wrong_self_convention)]
471    fn to_slice_mut(&mut self) -> SliceMut<E> {
472        unexpanded!()
473    }
474
475    /// Expand function of [SliceOperator::to_slice_mut].
476    fn __expand_to_slice_mut(
477        scope: &mut Scope,
478        expand: Self::Expand,
479    ) -> ExpandElementTyped<SliceMut<E>> {
480        expand.__expand_to_slice_mut_method(scope)
481    }
482}
483
484pub trait SliceOperatorExpand<E: CubeType>: Into<ExpandElement> + Clone + Init + CubeDebug {
485    fn slice_base<Start: Index, End: Index>(
486        &self,
487        scope: &mut Scope,
488        start: Start,
489        end: End,
490    ) -> ExpandElement;
491
492    fn __expand_slice_method(
493        &self,
494        scope: &mut Scope,
495        start: ExpandElementTyped<u32>,
496        end: ExpandElementTyped<u32>,
497    ) -> ExpandElementTyped<Slice<E>> {
498        ExpandElementTyped::new(self.slice_base(scope, start, end))
499    }
500
501    fn __expand_slice_mut_method(
502        &self,
503        scope: &mut Scope,
504        start: ExpandElementTyped<u32>,
505        end: ExpandElementTyped<u32>,
506    ) -> ExpandElementTyped<SliceMut<E>> {
507        ExpandElementTyped::new(self.slice_base(scope, start, end))
508    }
509
510    fn __expand_to_slice_method(&self, _scope: &mut Scope) -> ExpandElementTyped<Slice<E>> {
511        let expand = self.clone().into();
512        ExpandElementTyped::new(expand)
513    }
514
515    fn __expand_to_slice_mut_method(&self, _scope: &mut Scope) -> ExpandElementTyped<SliceMut<E>> {
516        let expand = self.clone().into();
517        ExpandElementTyped::new(expand)
518    }
519}
520
521macro_rules! slice_op {
522    ($type:ident) => {
523        impl<E: CubePrimitive> SliceOperator<E> for $type<E> {
524            type Expand = ExpandElementTyped<$type<E>>;
525        }
526
527        impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<$type<E>> {
528            fn slice_base<Start: Index, End: Index>(
529                &self,
530                scope: &mut Scope,
531                start: Start,
532                end: End,
533            ) -> ExpandElement {
534                slice_expand(scope, self.clone(), start, end)
535            }
536        }
537    };
538    (slice $type:ident) => {
539        impl<E: CubePrimitive> SliceOperator<E> for $type<E> {
540            type Expand = ExpandElementTyped<$type<E>>;
541        }
542
543        impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<$type<E>> {
544            fn slice_base<Start: Index, End: Index>(
545                &self,
546                scope: &mut Scope,
547                start: Start,
548                end: End,
549            ) -> ExpandElement {
550                slice_expand(scope, self.clone(), start, end)
551            }
552        }
553    };
554}
555
556slice_op!(Array);
557slice_op!(Tensor);
558slice_op!(SharedMemory);
559slice_op!(slice Slice);
560slice_op!(slice SliceMut);
561
562pub fn slice_expand<I: Into<ExpandElement>, S1: Index, S2: Index>(
563    scope: &mut Scope,
564    input: I,
565    start: S1,
566    end: S2, // Todo use it to get the length.
567) -> ExpandElement {
568    let input = input.into();
569    let out = scope.create_slice(input.item);
570
571    scope.register(Instruction::new(
572        Operator::Slice(cubecl_ir::SliceOperator {
573            input: *input,
574            start: start.value(),
575            end: end.value(),
576        }),
577        *out,
578    ));
579
580    out
581}
582
583impl<T: CubePrimitive> List<T> for Slice<T> {
584    fn __expand_read(
585        scope: &mut Scope,
586        this: ExpandElementTyped<Slice<T>>,
587        idx: ExpandElementTyped<u32>,
588    ) -> ExpandElementTyped<T> {
589        index::expand(scope, this, idx)
590    }
591}
592
593impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<Slice<T>> {
594    fn __expand_read_method(
595        self,
596        scope: &mut Scope,
597        idx: ExpandElementTyped<u32>,
598    ) -> ExpandElementTyped<T> {
599        index::expand(scope, self, idx)
600    }
601}
602
603impl<T: CubePrimitive> List<T> for SliceMut<T> {
604    fn __expand_read(
605        scope: &mut Scope,
606        this: ExpandElementTyped<SliceMut<T>>,
607        idx: ExpandElementTyped<u32>,
608    ) -> ExpandElementTyped<T> {
609        index::expand(scope, this, idx)
610    }
611}
612
613impl<T: CubePrimitive> ListExpand<T> for ExpandElementTyped<SliceMut<T>> {
614    fn __expand_read_method(
615        self,
616        scope: &mut Scope,
617        idx: ExpandElementTyped<u32>,
618    ) -> ExpandElementTyped<T> {
619        index::expand(scope, self, idx)
620    }
621}
622
623impl<T: CubePrimitive> ListMut<T> for SliceMut<T> {
624    fn __expand_write(
625        scope: &mut Scope,
626        this: ExpandElementTyped<SliceMut<T>>,
627        idx: ExpandElementTyped<u32>,
628        value: ExpandElementTyped<T>,
629    ) {
630        index_assign::expand(scope, this, idx, value);
631    }
632}
633
634impl<T: CubePrimitive> ListMutExpand<T> for ExpandElementTyped<SliceMut<T>> {
635    fn __expand_write_method(
636        self,
637        scope: &mut Scope,
638        idx: ExpandElementTyped<u32>,
639        value: ExpandElementTyped<T>,
640    ) {
641        index_assign::expand(scope, self, idx, value);
642    }
643}