Skip to main content

cubecl_core/frontend/container/slice/
base.rs

1use alloc::boxed::Box;
2use core::{
3    marker::PhantomData,
4    ops::{Deref, DerefMut},
5};
6
7use crate::{self as cubecl, unexpanded};
8use cubecl::prelude::*;
9use cubecl_ir::{Branch, ElemType, ExpandElement, FloatKind, LineSize, RangeLoop, Type, Variable};
10use cubecl_macros::intrinsic;
11
/// Marker type for read-only access; it is the default `IO` parameter of [`Slice`].
#[derive(Copy, Clone)]
pub struct ReadOnly;
/// Marker type for read-write access, used as the `IO` parameter of [`Slice`].
#[derive(Copy, Clone)]
pub struct ReadWrite;
16
17/// A read-only contiguous list of elements
18///
19/// # Safety
20///
21/// Since data can't be deallocated during kernel execution, this is safe.
22#[derive(Clone, Copy)]
23pub struct Slice<E: CubePrimitive, IO: SliceVisibility = ReadOnly> {
24    _e: PhantomData<E>,
25    _io: PhantomData<IO>,
26    _offset: PhantomData<usize>,
27    length: usize,
28}
29
/// The container a slice is taken from.
///
/// The `CubeType` derive provides the expand-side counterpart that the rest of this file
/// refers to as `SliceOriginExpand`.
#[derive(CubeType)]
pub enum SliceOrigin<E: CubePrimitive> {
    // The slice is backed by a tensor.
    Tensor(Tensor<E>),
    // The slice is backed by an array.
    Array(Array<E>),
    // The slice is backed by shared memory.
    SharedMemory(SharedMemory<E>),
}
36
37impl<E: CubePrimitive> SliceOriginExpand<E> {
38    pub fn line_size(&self) -> LineSize {
39        match self {
40            SliceOriginExpand::Tensor(t) => t.line_size(),
41            SliceOriginExpand::Array(t) => t.line_size(),
42            SliceOriginExpand::SharedMemory(t) => t.line_size(),
43        }
44    }
45}
46
47impl<E: CubePrimitive, IO: SliceVisibility> Iterator for Slice<E, IO> {
48    type Item = E;
49
50    fn next(&mut self) -> Option<Self::Item> {
51        unexpanded!()
52    }
53}
54
55pub trait SliceVisibility: Clone + Copy + Send + Sync + 'static {}
56
57impl SliceVisibility for ReadOnly {}
58
59impl SliceVisibility for ReadWrite {}
60
61pub struct SliceExpand<E: CubePrimitive, IO: SliceVisibility> {
62    pub(crate) origin: SliceOriginExpand<E>,
63    pub(crate) io: PhantomData<IO>,
64    pub(crate) offset: ExpandElementTyped<usize>,
65    pub(crate) length: ExpandElementTyped<usize>,
66    pub(crate) line_size: Option<LineSize>,
67}
68
69impl<E: CubePrimitive, IO: SliceVisibility> SliceExpand<E, IO> {
70    pub fn __to_raw_parts(&self) -> (Variable, Variable) {
71        let expand = match self.origin.clone() {
72            SliceOriginExpand::Tensor(expand) => expand.expand,
73            SliceOriginExpand::Array(expand) => expand.expand,
74            SliceOriginExpand::SharedMemory(expand) => expand.expand,
75        };
76
77        (*expand, *self.offset.expand)
78    }
79}
80
81#[cube]
82impl<E: CubePrimitive, IO: SliceVisibility> Slice<Line<E>, IO> {
83    /// Reinterprets how items are loaded and stored in memory.slicebase
84    ///
85    /// # Warning
86    ///
87    /// Currently, this only work with `cube(launch_unchecked)` and is not supported on wgpu.
88    #[allow(unused_variables)]
89    pub fn with_line_size(&self, #[comptime] line_size: LineSize) -> Slice<Line<E>, IO> {
90        intrinsic!(|scope| {
91            let (input, offset) = self.__to_raw_parts();
92            let mut item = input.ty;
93
94            if line_size == item.line_size() {
95                return self;
96            }
97
98            let current = input.ty.line_size();
99            let mut out = self.clone();
100
101            if current < line_size {
102                let ratio = line_size / current;
103                let length = cubecl::frontend::div::expand(scope, self.length, ratio.into());
104                let offset = cubecl::frontend::div::expand(scope, self.offset, ratio.into());
105                out.length = length;
106                out.offset = offset;
107            } else {
108                let ratio = current / line_size;
109                let length = cubecl::frontend::mul::expand(scope, self.length, ratio.into());
110                let offset = cubecl::frontend::mul::expand(scope, self.offset, ratio.into());
111                out.length = length;
112                out.offset = offset;
113            }
114
115            out.line_size = Some(line_size);
116            out
117        })
118    }
119}
120
121#[cube]
122impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
123    /// Returns the same slice, but with lines of length 1.
124    pub fn into_lined(&self) -> Slice<Line<E>, IO> {
125        intrinsic!(|_scope| {
126            SliceExpand::<Line<E>, IO> {
127                origin: self.origin.cast_unchecked(),
128                io: self.io.clone(),
129                offset: self.offset.clone(),
130                length: self.length.clone(),
131                line_size: None,
132            }
133        })
134    }
135    /// Downcast the slice to the given type and panic if the type isn't the same.
136    ///
137    /// This function should only be used to satisfy the Rust type system, when two generic
138    /// types are supposed to be the same.
139    pub fn downcast<T: CubePrimitive>(&self) -> Slice<T, IO> {
140        intrinsic!(|scope| {
141            if T::as_type(scope) != E::as_type(scope) && !is_tf32::<E, T>(scope) {
142                let elems = [T::as_type(scope).elem_type(), E::as_type(scope).elem_type()];
143                let is_flex32_cast = elems.contains(&ElemType::Float(FloatKind::F32))
144                    && elems.contains(&ElemType::Float(FloatKind::Flex32));
145
146                if !is_flex32_cast {
147                    panic!("Downcast should only be used to satisfy the Rust type system.")
148                }
149            }
150
151            SliceExpand::<T, IO> {
152                origin: self.origin.cast_unchecked(),
153                io: self.io.clone(),
154                offset: self.offset.clone(),
155                length: self.length.clone(),
156                line_size: self.line_size.clone(),
157            }
158        })
159    }
160}
161
162#[cube]
163impl<E: CubePrimitive> Slice<E, ReadOnly> {
164    pub fn as_mut_unchecked(&self) -> Slice<E, ReadWrite> {
165        intrinsic!(|scope| {
166            SliceExpand::<E, ReadWrite> {
167                origin: self.origin,
168                io: PhantomData,
169                offset: self.offset.clone(),
170                length: self.length.clone(),
171                line_size: self.line_size.clone(),
172            }
173        })
174    }
175}
176
177impl<E: CubePrimitive> SliceOriginExpand<E> {
178    fn cast_unchecked<T: CubePrimitive>(self) -> SliceOriginExpand<T> {
179        match self {
180            SliceOriginExpand::Tensor(expand) => {
181                SliceOriginExpand::<T>::Tensor(expand.expand.into())
182            }
183            SliceOriginExpand::Array(expand) => SliceOriginExpand::<T>::Array(expand.expand.into()),
184            SliceOriginExpand::SharedMemory(expand) => {
185                SliceOriginExpand::<T>::SharedMemory(expand.expand.into())
186            }
187        }
188    }
189}
190
191impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
192    pub fn new(_origin: SliceOrigin<E>, _offset: usize, _length: usize) -> Self {
193        unexpanded!()
194    }
195    pub fn __expand_new(
196        scope: &mut Scope,
197        origin: SliceOriginExpand<E>,
198        start: ExpandElementTyped<usize>,
199        end: ExpandElementTyped<usize>,
200    ) -> SliceExpand<E, IO> {
201        Self::__expand_new_expand(scope, origin, start, end)
202    }
203    pub fn __expand_new_expand(
204        scope: &mut Scope,
205        origin: SliceOriginExpand<E>,
206        start: ExpandElementTyped<usize>,
207        end: ExpandElementTyped<usize>,
208    ) -> SliceExpand<E, IO> {
209        let length = cubecl::frontend::sub::expand(scope, end, start.clone());
210
211        SliceExpand::<E, IO> {
212            origin,
213            io: PhantomData,
214            offset: start,
215            length,
216            line_size: None,
217        }
218    }
219}
220
#[cube]
impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
    /// Get the length of the slice.
    ///
    /// This is the value of the `length` field.
    pub fn len(&self) -> usize {
        self.length
    }
    /// Returns true if the slice is empty.
    ///
    /// That is, when the `length` field equals zero.
    pub fn is_empty(&self) -> bool {
        self.length == 0
    }
}
232
233impl<E: CubePrimitive, IO: SliceVisibility> CubeType for Slice<E, IO> {
234    type ExpandType = SliceExpand<E, IO>;
235}
236
237impl<E: CubePrimitive, IO: SliceVisibility> CubeType for &Slice<E, IO> {
238    type ExpandType = SliceExpand<E, IO>;
239}
240
241impl<E: CubePrimitive, IO: SliceVisibility> CubeType for &mut Slice<E, IO> {
242    type ExpandType = SliceExpand<E, IO>;
243}
244
245impl<E: CubePrimitive, IO: SliceVisibility> IntoMut for SliceExpand<E, IO> {
246    fn into_mut(self, _scope: &mut cubecl_ir::Scope) -> Self {
247        self
248    }
249}
250
251impl<E: CubePrimitive, IO: SliceVisibility> CubeDebug for SliceExpand<E, IO> {}
252impl<E: CubePrimitive, IO: SliceVisibility> Clone for SliceExpand<E, IO> {
253    fn clone(&self) -> Self {
254        Self {
255            origin: self.origin.clone(),
256            offset: self.offset.clone(),
257            length: self.length.clone(),
258            line_size: self.line_size,
259            io: PhantomData,
260        }
261    }
262}
263
264// TODO: Fix
265impl<E: CubePrimitive> SizedContainer for Slice<E, ReadOnly> {
266    type Item = E;
267}
268
269impl<E: CubePrimitive> Iterable<E> for SliceExpand<E, ReadOnly> {
270    fn expand(
271        self,
272        scope: &mut Scope,
273        mut body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
274    ) {
275        let index_ty = Type::new(u32::as_type(scope));
276        let len: ExpandElement = self.length.clone().into();
277
278        let mut child = scope.child();
279        let i = child.create_local_restricted(index_ty);
280
281        let index = i.clone().into();
282        let item = index::expand(&mut child, self, index);
283        body(&mut child, item);
284
285        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
286            i: *i,
287            start: 0usize.into(),
288            end: *len,
289            step: None,
290            inclusive: false,
291            scope: child,
292        })));
293    }
294
295    fn expand_unroll(
296        self,
297        _scope: &mut Scope,
298        _body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
299    ) {
300        unimplemented!("Can't unroll slice iterator")
301    }
302}
303impl<E: CubePrimitive, IO: SliceVisibility> CubeIndex for Slice<E, IO> {
304    type Output = E;
305    type Idx = usize;
306
307    fn expand_index(
308        scope: &mut Scope,
309        array: Self::ExpandType,
310        index: ExpandElementTyped<usize>,
311    ) -> <Self::Output as CubeType>::ExpandType {
312        array.__expand_read_method(scope, index)
313    }
314}
315
316impl<E: CubePrimitive, IO: SliceVisibility> CubeIndexExpand for SliceExpand<E, IO> {
317    type Output = E::ExpandType;
318    type Idx = ExpandElementTyped<usize>;
319
320    fn expand_index(self, scope: &mut Scope, index: ExpandElementTyped<usize>) -> Self::Output {
321        self.__expand_read_method(scope, index)
322    }
323    fn expand_index_unchecked(
324        self,
325        scope: &mut Scope,
326        index: ExpandElementTyped<usize>,
327    ) -> Self::Output {
328        self.__expand_read_unchecked_method(scope, index)
329    }
330}
331
332impl<E: CubePrimitive, IO: SliceVisibility> List<E> for Slice<E, IO> {}
333impl<E: CubePrimitive, IO: SliceVisibility> ListExpand<E> for SliceExpand<E, IO> {
334    fn __expand_read_method(
335        &self,
336        scope: &mut cubecl_ir::Scope,
337        index: ExpandElementTyped<usize>,
338    ) -> <E as CubeType>::ExpandType {
339        read_offset::expand::<E>(
340            scope,
341            self.origin.clone(),
342            self.offset.clone(),
343            index,
344            self.line_size,
345            true,
346        )
347    }
348    fn __expand_read_unchecked_method(
349        &self,
350        scope: &mut cubecl_ir::Scope,
351        index: ExpandElementTyped<usize>,
352    ) -> <E as CubeType>::ExpandType {
353        read_offset::expand::<E>(
354            scope,
355            self.origin.clone(),
356            self.offset.clone(),
357            index,
358            self.line_size,
359            false,
360        )
361    }
362
363    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
364        Self::__expand_len(scope, self.clone())
365    }
366}
367
368impl<T: CubePrimitive, IO: SliceVisibility> Deref for Slice<T, IO> {
369    type Target = [T];
370
371    fn deref(&self) -> &Self::Target {
372        unexpanded!()
373    }
374}
375
376impl<T: CubePrimitive> DerefMut for Slice<T, ReadWrite> {
377    fn deref_mut(&mut self) -> &mut Self::Target {
378        unexpanded!()
379    }
380}
381
382impl<E: CubePrimitive, IO: SliceVisibility> Lined for Slice<E, IO> {}
383impl<E: CubePrimitive, IO: SliceVisibility> LinedExpand for SliceExpand<E, IO> {
384    fn line_size(&self) -> LineSize {
385        self.line_size.unwrap_or_else(|| self.origin.line_size())
386    }
387}
388
389impl<E: CubePrimitive> CubeIndexMut for Slice<E, ReadWrite> {
390    fn expand_index_mut(
391        scope: &mut Scope,
392        array: Self::ExpandType,
393        index: ExpandElementTyped<usize>,
394        value: ExpandElementTyped<E>,
395    ) {
396        array.__expand_write_method(scope, index, value)
397    }
398}
399
400impl<E: CubePrimitive> CubeIndexMutExpand for SliceExpand<E, ReadWrite> {
401    fn expand_index_mut(
402        self,
403        scope: &mut Scope,
404        index: ExpandElementTyped<usize>,
405        value: Self::Output,
406    ) {
407        self.__expand_write_method(scope, index, value)
408    }
409}
410
411impl<E: CubePrimitive> ListMut<E> for Slice<E, ReadWrite> {}
412impl<E: CubePrimitive> ListMutExpand<E> for SliceExpand<E, ReadWrite> {
413    fn __expand_write_method(
414        &self,
415        scope: &mut cubecl_ir::Scope,
416        index: ExpandElementTyped<usize>,
417        value: ExpandElementTyped<E>,
418    ) {
419        write_offset::expand::<E>(
420            scope,
421            self.origin.clone(),
422            self.offset.clone(),
423            index,
424            value,
425            self.line_size,
426        )
427    }
428}
429
430mod read_offset {
431    use super::*;
432
433    pub fn expand<E: CubePrimitive>(
434        scope: &mut cubecl::prelude::Scope,
435        origin: SliceOriginExpand<E>,
436        offset: <usize as cubecl::prelude::CubeType>::ExpandType,
437        index: <usize as cubecl::prelude::CubeType>::ExpandType,
438        line_size: Option<LineSize>,
439        checked: bool,
440    ) -> <E as cubecl::prelude::CubeType>::ExpandType {
441        let index = cubecl::frontend::add::expand(scope, offset, index);
442
443        match origin {
444            SliceOriginExpand::Tensor(expand) => {
445                expand_index_native::<Tensor<E>>(scope, expand, index, line_size, checked)
446            }
447            SliceOriginExpand::Array(expand) => {
448                expand_index_native::<Array<E>>(scope, expand, index, line_size, checked)
449            }
450            SliceOriginExpand::SharedMemory(expand) => {
451                expand_index_native::<SharedMemory<E>>(scope, expand, index, line_size, checked)
452            }
453        }
454    }
455}
456
457mod write_offset {
458    use super::*;
459
460    pub fn expand<E: CubePrimitive>(
461        scope: &mut cubecl::prelude::Scope,
462        origin: SliceOriginExpand<E>,
463        offset: <usize as cubecl::prelude::CubeType>::ExpandType,
464        index: <usize as cubecl::prelude::CubeType>::ExpandType,
465        value: <E as cubecl::prelude::CubeType>::ExpandType,
466        line_size: Option<LineSize>,
467    ) {
468        let index = cubecl::frontend::add::expand(scope, offset, index);
469
470        match origin {
471            SliceOriginExpand::Tensor(expand) => expand_index_assign_native::<Tensor<E>>(
472                scope, expand, index, value, line_size, true,
473            ),
474            SliceOriginExpand::Array(expand) => expand_index_assign_native::<Array<E>>(
475                scope, expand, index, value, line_size, false,
476            ),
477            SliceOriginExpand::SharedMemory(expand) => {
478                expand_index_assign_native::<SharedMemory<E>>(
479                    scope, expand, index, value, line_size, false,
480                )
481            }
482        }
483    }
484}