Skip to main content

cubecl_core/frontend/container/slice/
base.rs

1use alloc::boxed::Box;
2use core::{
3    marker::PhantomData,
4    ops::{Deref, DerefMut},
5};
6
7use crate::{self as cubecl, unexpanded};
8use cubecl::prelude::*;
9use cubecl_ir::{Branch, ElemType, FloatKind, ManagedVariable, RangeLoop, Variable, VectorSize};
10use cubecl_macros::intrinsic;
11
/// Marker type selecting read-only access for a [`Slice`].
///
/// This is the default visibility parameter of [`Slice`].
#[derive(Clone, Copy)]
pub struct ReadOnly;
/// Marker type selecting read-write access for a [`Slice`].
///
/// In this file, the mutable traits (`DerefMut`, `CubeIndexMut`, `ListMut`) are only
/// implemented for slices using this marker.
#[derive(Clone, Copy)]
pub struct ReadWrite;
16
/// A contiguous list of elements.
///
/// The `IO` marker selects the access mode. It defaults to [`ReadOnly`]; use
/// [`ReadWrite`] for a slice that can also be written to.
///
/// # Safety
///
/// Since data can't be deallocated during kernel execution, this is safe.
#[derive(Clone, Copy)]
pub struct Slice<E: CubePrimitive, IO: SliceVisibility = ReadOnly> {
    // The three `PhantomData` fields are type-level markers only; they store no value.
    _e: PhantomData<E>,
    _io: PhantomData<IO>,
    _offset: PhantomData<usize>,
    length: usize,
}
29
/// The container a slice is taken from.
///
/// `#[derive(CubeType)]` is what provides the `SliceOriginExpand` counterpart used
/// throughout this file.
#[derive(CubeType)]
pub enum SliceOrigin<E: CubePrimitive> {
    /// The slice refers to a [`Tensor`].
    Tensor(Tensor<E>),
    /// The slice refers to an [`Array`].
    Array(Array<E>),
    /// The slice refers to a [`SharedMemory`].
    SharedMemory(SharedMemory<E>),
}
36
37impl<E: CubePrimitive> SliceOriginExpand<E> {
38    pub fn vector_size(&self) -> VectorSize {
39        match self {
40            SliceOriginExpand::Tensor(t) => t.vector_size(),
41            SliceOriginExpand::Array(t) => t.vector_size(),
42            SliceOriginExpand::SharedMemory(t) => t.vector_size(),
43        }
44    }
45}
46
// Lets a `Slice` be written in iterator position in source code. The body is only the
// `unexpanded!()` placeholder; the actual loop is produced by the `Iterable`
// implementation on `SliceExpand`.
impl<E: CubePrimitive, IO: SliceVisibility> Iterator for Slice<E, IO> {
    type Item = E;

    fn next(&mut self) -> Option<Self::Item> {
        unexpanded!()
    }
}
54
/// Marker trait for the access mode of a [`Slice`].
///
/// It has no methods; [`ReadOnly`] and [`ReadWrite`] are the implementors defined here.
pub trait SliceVisibility: Clone + Copy + Send + Sync + 'static {}

impl SliceVisibility for ReadOnly {}

impl SliceVisibility for ReadWrite {}
60
/// Expansion-time representation of a [`Slice`].
pub struct SliceExpand<E: CubePrimitive, IO: SliceVisibility> {
    /// The container the slice points into.
    pub(crate) origin: SliceOriginExpand<E>,
    /// Type-level marker for the access mode; stores no value.
    pub(crate) io: PhantomData<IO>,
    /// Added to every index before the origin is read or written.
    pub(crate) offset: NativeExpand<usize>,
    /// Length of the slice.
    pub(crate) length: NativeExpand<usize>,
    /// Explicit vector size override; when `None`, the origin's vector size is used.
    pub(crate) vector_size: Option<VectorSize>,
}
68
69impl<E: CubePrimitive, IO: SliceVisibility> SliceExpand<E, IO> {
70    pub fn __to_raw_parts(&self) -> (Variable, Variable) {
71        let expand = match self.origin.clone() {
72            SliceOriginExpand::Tensor(expand) => expand.expand,
73            SliceOriginExpand::Array(expand) => expand.expand,
74            SliceOriginExpand::SharedMemory(expand) => expand.expand,
75        };
76
77        (*expand, *self.offset.expand)
78    }
79}
80
81#[cube]
82impl<E: Scalar, N: Size, IO: SliceVisibility> Slice<Vector<E, N>, IO> {
83    /// Reinterprets how items are loaded and stored in memory.slicebase
84    ///
85    /// # Warning
86    ///
87    /// Currently, this only work with `cube(launch_unchecked)` and is not supported on wgpu.
88    #[allow(unused_variables)]
89    pub fn with_vector_size<N2: Size>(&self) -> Slice<Vector<E, N2>, IO> {
90        intrinsic!(|scope| {
91            let vector_size = N2::__expand_value(scope);
92            let (input, offset) = self.__to_raw_parts();
93            let mut item = input.ty;
94
95            let current = input.ty.vector_size();
96            let mut out = self
97                .clone()
98                .__expand_downcast_unchecked_method::<Vector<E, N2>>(scope);
99
100            if vector_size == item.vector_size() {
101                return out;
102            }
103
104            if current < vector_size {
105                let ratio = vector_size / current;
106                let length = cubecl::frontend::div::expand(scope, self.length, ratio.into());
107                let offset = cubecl::frontend::div::expand(scope, self.offset, ratio.into());
108                out.length = length;
109                out.offset = offset;
110            } else {
111                let ratio = current / vector_size;
112                let length = cubecl::frontend::mul::expand(scope, self.length, ratio.into());
113                let offset = cubecl::frontend::mul::expand(scope, self.offset, ratio.into());
114                out.length = length;
115                out.offset = offset;
116            }
117
118            out.vector_size = Some(vector_size);
119            out
120        })
121    }
122}
123
124#[cube]
125impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
126    /// Returns the same slice, but with the type reinterpreted as `Vector`.
127    /// Preserves existing vector size of the primitive.
128    pub fn into_vectorized(&self) -> Slice<Vector<E::Scalar, E::Size>, IO> {
129        intrinsic!(|scope| {
130            SliceExpand::<Vector<E::Scalar, E::Size>, IO> {
131                origin: self.origin.cast_unchecked(),
132                io: self.io.clone(),
133                offset: self.offset.clone(),
134                length: self.length.clone(),
135                vector_size: self.vector_size,
136            }
137        })
138    }
139    /// Downcast the slice to the given type and panic if the type isn't the same.
140    ///
141    /// This function should only be used to satisfy the Rust type system, when two generic
142    /// types are supposed to be the same.
143    pub fn downcast<T: CubePrimitive>(&self) -> Slice<T, IO> {
144        intrinsic!(|scope| {
145            if T::as_type(scope) != E::as_type(scope) && !is_tf32::<E, T>(scope) {
146                let elems = [T::as_type(scope).elem_type(), E::as_type(scope).elem_type()];
147                let is_flex32_cast = elems.contains(&ElemType::Float(FloatKind::F32))
148                    && elems.contains(&ElemType::Float(FloatKind::Flex32));
149
150                if !is_flex32_cast {
151                    panic!("Downcast should only be used to satisfy the Rust type system.")
152                }
153            }
154
155            unsafe { self.__expand_downcast_unchecked_method(scope) }
156        })
157    }
158
159    /// Unsafely downcast the slice to the given type and panic if the type isn't the same.
160    ///
161    /// # Safety
162    /// This function converts unsafely, and should only be used for temporary storage with a dummy
163    /// type (i.e. `ReinterpretSlice`)
164    pub unsafe fn downcast_unchecked<T: CubePrimitive>(&self) -> Slice<T, IO> {
165        intrinsic!(|scope| {
166            SliceExpand::<T, IO> {
167                origin: self.origin.cast_unchecked(),
168                io: self.io.clone(),
169                offset: self.offset.clone(),
170                length: self.length.clone(),
171                vector_size: self.vector_size.clone(),
172            }
173        })
174    }
175}
176
177#[cube]
178impl<E: CubePrimitive> Slice<E, ReadOnly> {
179    pub fn as_mut_unchecked(&self) -> Slice<E, ReadWrite> {
180        intrinsic!(|scope| {
181            SliceExpand::<E, ReadWrite> {
182                origin: self.origin,
183                io: PhantomData,
184                offset: self.offset.clone(),
185                length: self.length.clone(),
186                vector_size: self.vector_size.clone(),
187            }
188        })
189    }
190}
191
192impl<E: CubePrimitive> SliceOriginExpand<E> {
193    fn cast_unchecked<T: CubePrimitive>(self) -> SliceOriginExpand<T> {
194        match self {
195            SliceOriginExpand::Tensor(expand) => {
196                SliceOriginExpand::<T>::Tensor(expand.expand.into())
197            }
198            SliceOriginExpand::Array(expand) => SliceOriginExpand::<T>::Array(expand.expand.into()),
199            SliceOriginExpand::SharedMemory(expand) => {
200                SliceOriginExpand::<T>::SharedMemory(expand.expand.into())
201            }
202        }
203    }
204}
205
206impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
207    pub fn new(_origin: SliceOrigin<E>, _offset: usize, _length: usize) -> Self {
208        unexpanded!()
209    }
210    pub fn __expand_new(
211        scope: &mut Scope,
212        origin: SliceOriginExpand<E>,
213        start: NativeExpand<usize>,
214        end: NativeExpand<usize>,
215    ) -> SliceExpand<E, IO> {
216        Self::__expand_new_expand(scope, origin, start, end)
217    }
218    pub fn __expand_new_expand(
219        scope: &mut Scope,
220        origin: SliceOriginExpand<E>,
221        start: NativeExpand<usize>,
222        end: NativeExpand<usize>,
223    ) -> SliceExpand<E, IO> {
224        let length = cubecl::frontend::sub::expand(scope, end, start.clone());
225
226        SliceExpand::<E, IO> {
227            origin,
228            io: PhantomData,
229            offset: start,
230            length,
231            vector_size: None,
232        }
233    }
234}
235
#[cube]
impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
    /// Returns the length of the slice.
    pub fn len(&self) -> usize {
        self.length
    }
    /// Returns true if the slice is empty, i.e. its length is zero.
    pub fn is_empty(&self) -> bool {
        self.length == 0
    }
}
247
// An owned slice expands to `SliceExpand`.
impl<E: CubePrimitive, IO: SliceVisibility> CubeType for Slice<E, IO> {
    type ExpandType = SliceExpand<E, IO>;
}
251
// A shared reference to a slice uses the same expand type as the owned slice.
impl<E: CubePrimitive, IO: SliceVisibility> CubeType for &Slice<E, IO> {
    type ExpandType = SliceExpand<E, IO>;
}
255
// A mutable reference to a slice uses the same expand type as the owned slice.
impl<E: CubePrimitive, IO: SliceVisibility> CubeType for &mut Slice<E, IO> {
    type ExpandType = SliceExpand<E, IO>;
}
259
impl<E: CubePrimitive, IO: SliceVisibility> IntoMut for SliceExpand<E, IO> {
    /// Returns the expanded slice unchanged; the scope is not used.
    fn into_mut(self, _scope: &mut cubecl_ir::Scope) -> Self {
        self
    }
}
265
// Empty impl: no `CubeDebug` item is overridden for expanded slices.
impl<E: CubePrimitive, IO: SliceVisibility> CubeDebug for SliceExpand<E, IO> {}
267impl<E: CubePrimitive, IO: SliceVisibility> Clone for SliceExpand<E, IO> {
268    fn clone(&self) -> Self {
269        Self {
270            origin: self.origin.clone(),
271            offset: self.offset.clone(),
272            length: self.length.clone(),
273            vector_size: self.vector_size,
274            io: PhantomData,
275        }
276    }
277}
278
// TODO: Fix
// Only the read-only slice implements `SizedContainer`; its item type is the element type.
// NOTE(review): the TODO above does not say what needs fixing — clarify or remove.
impl<E: CubePrimitive> SizedContainer for Slice<E, ReadOnly> {
    type Item = E;
}
283
impl<E: CubePrimitive> Iterable<E> for SliceExpand<E, ReadOnly> {
    /// Expands iteration over the slice into a single `RangeLoop` branch.
    ///
    /// `body` is called once, during expansion, inside a child scope. It receives the
    /// item obtained by indexing the slice with the loop variable. The registered loop
    /// starts at `0`, ends (exclusively) at the slice length, and has no explicit step.
    fn expand(
        self,
        scope: &mut Scope,
        mut body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
    ) {
        // NOTE(review): the loop variable is typed as `u32`, while the loop start and the
        // slice length are `usize` values — confirm this mismatch is intended.
        let index_ty = u32::as_type(scope);
        let len: ManagedVariable = self.length.clone().into();

        // The loop body is built in its own child scope, which the `RangeLoop` then owns.
        let mut child = scope.child();
        let i = child.create_local_restricted(index_ty);

        let index = i.clone().into();
        let item = index::expand(&mut child, self, index);
        body(&mut child, item);

        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
            i: *i,
            start: 0usize.into(),
            end: *len,
            step: None,
            inclusive: false,
            scope: child,
        })));
    }

    /// Unrolled iteration is not supported for slices.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn expand_unroll(
        self,
        _scope: &mut Scope,
        _body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
    ) {
        unimplemented!("Can't unroll slice iterator")
    }
}
// Indexing a slice with a `usize` yields an element of type `E`.
impl<E: CubePrimitive, IO: SliceVisibility> CubeIndex for Slice<E, IO> {
    type Output = E;
    type Idx = usize;

    /// Expands `array[index]` by delegating to the slice's read method.
    fn expand_index(
        scope: &mut Scope,
        array: Self::ExpandType,
        index: NativeExpand<usize>,
    ) -> <Self::Output as CubeType>::ExpandType {
        array.__expand_read_method(scope, index)
    }
}
330
impl<E: CubePrimitive, IO: SliceVisibility> CubeIndexExpand for SliceExpand<E, IO> {
    type Output = E::ExpandType;
    type Idx = NativeExpand<usize>;

    /// Expands an indexed read through `__expand_read_method`.
    fn expand_index(self, scope: &mut Scope, index: NativeExpand<usize>) -> Self::Output {
        self.__expand_read_method(scope, index)
    }
    /// Expands an indexed read through `__expand_read_unchecked_method`.
    fn expand_index_unchecked(self, scope: &mut Scope, index: NativeExpand<usize>) -> Self::Output {
        self.__expand_read_unchecked_method(scope, index)
    }
}
342
// Empty impl: the list behaviour is provided by `ListExpand for SliceExpand`.
impl<E: CubePrimitive, IO: SliceVisibility> List<E> for Slice<E, IO> {}
344impl<E: CubePrimitive, IO: SliceVisibility> ListExpand<E> for SliceExpand<E, IO> {
345    fn __expand_read_method(
346        &self,
347        scope: &mut cubecl_ir::Scope,
348        index: NativeExpand<usize>,
349    ) -> <E as CubeType>::ExpandType {
350        read_offset::expand::<E>(
351            scope,
352            self.origin.clone(),
353            self.offset.clone(),
354            index,
355            self.vector_size,
356            true,
357        )
358    }
359    fn __expand_read_unchecked_method(
360        &self,
361        scope: &mut cubecl_ir::Scope,
362        index: NativeExpand<usize>,
363    ) -> <E as CubeType>::ExpandType {
364        read_offset::expand::<E>(
365            scope,
366            self.origin.clone(),
367            self.offset.clone(),
368            index,
369            self.vector_size,
370            false,
371        )
372    }
373
374    fn __expand_len_method(&self, scope: &mut Scope) -> NativeExpand<usize> {
375        Self::__expand_len(scope, self.clone())
376    }
377}
378
// Gives `Slice` the surface of a native `[T]` in source code. The body is only the
// `unexpanded!()` placeholder.
impl<T: CubePrimitive, IO: SliceVisibility> Deref for Slice<T, IO> {
    type Target = [T];

    fn deref(&self) -> &Self::Target {
        unexpanded!()
    }
}
386
// Mutable dereference is only available on read-write slices. The body is only the
// `unexpanded!()` placeholder.
impl<T: CubePrimitive> DerefMut for Slice<T, ReadWrite> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        unexpanded!()
    }
}
392
// Empty impl: the vector size lookup lives in `VectorizedExpand for SliceExpand`.
impl<E: CubePrimitive, IO: SliceVisibility> Vectorized for Slice<E, IO> {}
394impl<E: CubePrimitive, IO: SliceVisibility> VectorizedExpand for SliceExpand<E, IO> {
395    fn vector_size(&self) -> VectorSize {
396        self.vector_size
397            .unwrap_or_else(|| self.origin.vector_size())
398    }
399}
400
// Indexed assignment is only available on read-write slices.
impl<E: CubePrimitive> CubeIndexMut for Slice<E, ReadWrite> {
    /// Expands `array[index] = value` by delegating to the slice's write method.
    fn expand_index_mut(
        scope: &mut Scope,
        array: Self::ExpandType,
        index: NativeExpand<usize>,
        value: NativeExpand<E>,
    ) {
        array.__expand_write_method(scope, index, value)
    }
}
411
impl<E: CubePrimitive> CubeIndexMutExpand for SliceExpand<E, ReadWrite> {
    /// Expands an indexed write through `__expand_write_method`.
    fn expand_index_mut(self, scope: &mut Scope, index: NativeExpand<usize>, value: Self::Output) {
        self.__expand_write_method(scope, index, value)
    }
}
417
// Empty impl: the write behaviour is provided by `ListMutExpand for SliceExpand`.
impl<E: CubePrimitive> ListMut<E> for Slice<E, ReadWrite> {}
419impl<E: CubePrimitive> ListMutExpand<E> for SliceExpand<E, ReadWrite> {
420    fn __expand_write_method(
421        &self,
422        scope: &mut cubecl_ir::Scope,
423        index: NativeExpand<usize>,
424        value: NativeExpand<E>,
425    ) {
426        write_offset::expand::<E>(
427            scope,
428            self.origin.clone(),
429            self.offset.clone(),
430            index,
431            value,
432            self.vector_size,
433        )
434    }
435}
436
437mod read_offset {
438    use super::*;
439
440    pub fn expand<E: CubePrimitive>(
441        scope: &mut cubecl::prelude::Scope,
442        origin: SliceOriginExpand<E>,
443        offset: <usize as cubecl::prelude::CubeType>::ExpandType,
444        index: <usize as cubecl::prelude::CubeType>::ExpandType,
445        vector_size: Option<VectorSize>,
446        checked: bool,
447    ) -> <E as cubecl::prelude::CubeType>::ExpandType {
448        let index = cubecl::frontend::add::expand(scope, offset, index);
449
450        match origin {
451            SliceOriginExpand::Tensor(expand) => {
452                expand_index_native::<Tensor<E>>(scope, expand, index, vector_size, checked)
453            }
454            SliceOriginExpand::Array(expand) => {
455                expand_index_native::<Array<E>>(scope, expand, index, vector_size, checked)
456            }
457            SliceOriginExpand::SharedMemory(expand) => {
458                expand_index_native::<SharedMemory<E>>(scope, expand, index, vector_size, checked)
459            }
460        }
461    }
462}
463
/// Expansion helper for writing one item of a slice.
mod write_offset {
    use super::*;

    /// Writes `value` into `origin` at position `offset + index`.
    ///
    /// `vector_size` is forwarded untouched to `expand_index_assign_native`, specialised
    /// on the container type of `origin`.
    pub fn expand<E: CubePrimitive>(
        scope: &mut cubecl::prelude::Scope,
        origin: SliceOriginExpand<E>,
        offset: <usize as cubecl::prelude::CubeType>::ExpandType,
        index: <usize as cubecl::prelude::CubeType>::ExpandType,
        value: <E as cubecl::prelude::CubeType>::ExpandType,
        vector_size: Option<VectorSize>,
    ) {
        // Position in the origin container, not in the slice.
        let index = cubecl::frontend::add::expand(scope, offset, index);

        // NOTE(review): the trailing boolean is `true` for `Tensor` but `false` for `Array`
        // and `SharedMemory`. Presumably it is a bounds-check flag like `checked` in
        // `read_offset` — confirm the asymmetry is intended.
        match origin {
            SliceOriginExpand::Tensor(expand) => expand_index_assign_native::<Tensor<E>>(
                scope,
                expand,
                index,
                value,
                vector_size,
                true,
            ),
            SliceOriginExpand::Array(expand) => expand_index_assign_native::<Array<E>>(
                scope,
                expand,
                index,
                value,
                vector_size,
                false,
            ),
            SliceOriginExpand::SharedMemory(expand) => {
                expand_index_assign_native::<SharedMemory<E>>(
                    scope,
                    expand,
                    index,
                    value,
                    vector_size,
                    false,
                )
            }
        }
    }
}