cubecl_core/frontend/container/slice/
base.rs

1use std::{marker::PhantomData, num::NonZero};
2
3use crate::{self as cubecl, unexpanded};
4use cubecl::prelude::*;
5use cubecl_ir::{Branch, Elem, ExpandElement, FloatKind, Item, RangeLoop, Variable};
6use cubecl_macros::intrinsic;
7
/// Access-mode marker for a [`Slice`] that is only read: in this module only the reading
/// traits (`CubeIndex`, `List`, `Iterable`, ...) are implemented for `Slice<E, ReadOnly>`.
#[derive(Clone)]
pub struct ReadOnly;
/// Access-mode marker for a [`Slice`] that can be read and written: in addition to the
/// reading traits, `CubeIndexMut`/`ListMut` are implemented for `Slice<E, ReadWrite>`.
#[derive(Clone)]
pub struct ReadWrite;
12
/// A contiguous list of elements viewed inside a tensor, array or shared memory.
///
/// Whether the slice may be written is encoded by `IO`: [`ReadOnly`] (the default) or
/// [`ReadWrite`].
///
/// # Safety
///
/// Since data can't be deallocated during kernel execution, this is safe.
#[derive(Clone, Copy)]
pub struct Slice<E: CubePrimitive, IO: SliceVisibility = ReadOnly> {
    _e: PhantomData<E>,
    _io: PhantomData<IO>,
    // Type-level placeholder only; never read in this file (the real offset lives in
    // `SliceExpand::offset`).
    _offset: PhantomData<u32>,
    length: u32,
}
25
/// The storage a [`Slice`] is a view into.
///
/// `SliceOriginExpand`, used throughout this file during expansion, is not defined here;
/// presumably it is generated by this `#[derive(CubeType)]` — confirm.
#[derive(CubeType)]
pub enum SliceOrigin<E: CubePrimitive> {
    /// A slice of a tensor.
    Tensor(Tensor<E>),
    /// A slice of an array.
    Array(Array<E>),
    /// A slice of shared memory.
    SharedMemory(SharedMemory<E>),
}
32
/// Lets `for x in slice` type-check in Rust code. `next` itself only calls `unexpanded!()`;
/// the actual loop is presumably produced by the `Iterable` implementation of `SliceExpand`
/// during expansion — confirm.
impl<E: CubePrimitive, IO: SliceVisibility> Iterator for Slice<E, IO> {
    type Item = E;

    fn next(&mut self) -> Option<Self::Item> {
        unexpanded!()
    }
}
40
/// Marker trait for the access mode of a [`Slice`]; implemented here by [`ReadOnly`] and
/// [`ReadWrite`].
pub trait SliceVisibility {}

impl SliceVisibility for ReadOnly {}

impl SliceVisibility for ReadWrite {}
46
/// Expand-time representation of a [`Slice`].
pub struct SliceExpand<E: CubePrimitive, IO: SliceVisibility> {
    /// Storage the slice views.
    pub(crate) origin: SliceOriginExpand<E>,
    /// Access-mode marker.
    pub(crate) io: PhantomData<IO>,
    /// Added to every index before `origin` is accessed.
    pub(crate) offset: ExpandElementTyped<u32>,
    /// Number of elements covered by the slice (rescaled by `with_line_size`).
    pub(crate) length: ExpandElementTyped<u32>,
    /// Line size set by `with_line_size`; `None` otherwise. Forwarded to every read/write.
    pub(crate) line_size: Option<u32>,
}
54
55impl<E: CubePrimitive, IO: SliceVisibility> SliceExpand<E, IO> {
56    pub fn __to_raw_parts(&self) -> (Variable, Variable) {
57        let expand = match self.origin.clone() {
58            SliceOriginExpand::Tensor(expand) => expand.expand,
59            SliceOriginExpand::Array(expand) => expand.expand,
60            SliceOriginExpand::SharedMemory(expand) => expand.expand,
61        };
62
63        (*expand, *self.offset.expand)
64    }
65}
66
67#[cube]
68impl<E: CubePrimitive, IO: SliceVisibility> Slice<Line<E>, IO> {
69    /// Reinterprets how items are loaded and stored in memory.slicebase
70    ///
71    /// # Warning
72    ///
73    /// Currently, this only work with `cube(launch_unchecked)` and is not supported on wgpu.
74    #[allow(unused_variables)]
75    pub fn with_line_size(&self, #[comptime] line_size: u32) -> Slice<Line<E>, IO> {
76        intrinsic!(|scope| {
77            let (input, offset) = self.__to_raw_parts();
78            let mut item = input.item;
79
80            if line_size as u8 == item.vectorization.unwrap_or(NonZero::new(1).unwrap()).get() {
81                return self;
82            }
83
84            let current = input.item.vectorization.map(|a| a.get()).unwrap_or(1) as u32;
85            let mut out = self.clone();
86
87            if current < line_size {
88                let ratio = line_size / current;
89                let length = cubecl::frontend::div::expand(scope, self.length, ratio.into());
90                let offset = cubecl::frontend::div::expand(scope, self.offset, ratio.into());
91                out.length = length;
92                out.offset = offset;
93            } else {
94                let ratio = current / line_size;
95                let length = cubecl::frontend::mul::expand(scope, self.length, ratio.into());
96                let offset = cubecl::frontend::mul::expand(scope, self.offset, ratio.into());
97                out.length = length;
98                out.offset = offset;
99            }
100
101            out.line_size = Some(line_size);
102            out
103        })
104    }
105}
106
107#[cube]
108impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
109    /// Returns the same slice, but with lines of length 1.
110    pub fn into_lined(&self) -> Slice<Line<E>, IO> {
111        intrinsic!(|_scope| {
112            SliceExpand::<Line<E>, IO> {
113                origin: self.origin.cast_unchecked(),
114                io: self.io.clone(),
115                offset: self.offset.clone(),
116                length: self.length.clone(),
117                line_size: None,
118            }
119        })
120    }
121    /// Returns the same slice, but with lines of length 1.
122    /// Try to cast the slice to the given type and panic if the type isn't the same.
123    ///
124    /// This function should only be used to satisfy the Rust type system, when two generic
125    /// types are supposed to be the same.
126    pub fn try_cast_unchecked<T: CubePrimitive>(&self) -> Slice<T, IO> {
127        intrinsic!(|scope| {
128            if T::as_elem(scope) != E::as_elem(scope) && !is_tf32::<E, T>(scope) {
129                let elems = [T::as_elem(scope), E::as_elem(scope)];
130                let is_flex32_cast = elems.contains(&Elem::Float(FloatKind::F32))
131                    && elems.contains(&Elem::Float(FloatKind::Flex32));
132
133                if !is_flex32_cast {
134                    panic!(
135                        "Try cast unchecked should only be used to satisfy the rust type system."
136                    )
137                }
138            }
139
140            SliceExpand::<T, IO> {
141                origin: self.origin.cast_unchecked(),
142                io: self.io.clone(),
143                offset: self.offset.clone(),
144                length: self.length.clone(),
145                line_size: self.line_size.clone(),
146            }
147        })
148    }
149}
150
151impl<E: CubePrimitive> SliceOriginExpand<E> {
152    fn cast_unchecked<T: CubePrimitive>(self) -> SliceOriginExpand<T> {
153        match self {
154            SliceOriginExpand::Tensor(expand) => {
155                SliceOriginExpand::<T>::Tensor(expand.expand.into())
156            }
157            SliceOriginExpand::Array(expand) => SliceOriginExpand::<T>::Array(expand.expand.into()),
158            SliceOriginExpand::SharedMemory(expand) => {
159                SliceOriginExpand::<T>::SharedMemory(expand.expand.into())
160            }
161        }
162    }
163}
164
165impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
166    pub fn new(_origin: SliceOrigin<E>, _offset: u32, _length: u32) -> Self {
167        unexpanded!()
168    }
169    pub fn __expand_new(
170        scope: &mut Scope,
171        origin: SliceOriginExpand<E>,
172        start: ExpandElementTyped<u32>,
173        end: ExpandElementTyped<u32>,
174    ) -> SliceExpand<E, IO> {
175        Self::__expand_new_expand(scope, origin, start, end)
176    }
177    pub fn __expand_new_expand(
178        scope: &mut Scope,
179        origin: SliceOriginExpand<E>,
180        start: ExpandElementTyped<u32>,
181        end: ExpandElementTyped<u32>,
182    ) -> SliceExpand<E, IO> {
183        let length = cubecl::frontend::sub::expand(scope, end, start.clone());
184
185        SliceExpand::<E, IO> {
186            origin,
187            io: PhantomData,
188            offset: start,
189            length,
190            line_size: None,
191        }
192    }
193}
194
#[cube]
impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
    /// Get the length of the slice, i.e. the number of elements it covers (in units of the
    /// current line size).
    pub fn len(&self) -> u32 {
        self.length
    }
    /// Returns true if the slice is empty (its length equals 0).
    pub fn is_empty(&self) -> bool {
        self.length == 0
    }
}
206
/// During expansion, a `Slice` is represented by a [`SliceExpand`].
impl<E: CubePrimitive, IO: SliceVisibility> CubeType for Slice<E, IO> {
    type ExpandType = SliceExpand<E, IO>;
}
210
/// Converting a slice to its mutable form is a no-op: `self` is returned unchanged and the
/// scope is not touched.
impl<E: CubePrimitive, IO: SliceVisibility> IntoMut for SliceExpand<E, IO> {
    fn into_mut(self, _scope: &mut cubecl_ir::Scope) -> Self {
        self
    }
}
216
217impl<E: CubePrimitive, IO: SliceVisibility> CubeDebug for SliceExpand<E, IO> {}
218impl<E: CubePrimitive, IO: SliceVisibility> Clone for SliceExpand<E, IO> {
219    fn clone(&self) -> Self {
220        Self {
221            origin: self.origin.clone(),
222            offset: self.offset.clone(),
223            length: self.length.clone(),
224            line_size: self.line_size,
225            io: PhantomData,
226        }
227    }
228}
229
// TODO: Fix
// NOTE(review): the TODO above gives no detail. Presumably it refers to this impl existing
// only for `ReadOnly` slices — confirm with the author.
/// Exposes `E` as the item type of a read-only slice.
impl<E: CubePrimitive> SizedContainer for Slice<E, ReadOnly> {
    type Item = E;
}
234
/// Expansion of `for item in slice` over a read-only slice.
impl<E: CubePrimitive> Iterable<E> for SliceExpand<E, ReadOnly> {
    /// Emits a `RangeLoop` over `0..length` (exclusive end, no explicit step).
    ///
    /// The body is generated in a child scope: the element at the loop index is read with
    /// `index::expand` and passed to `body`.
    fn expand(
        self,
        scope: &mut Scope,
        mut body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
    ) {
        let index_ty = Item::new(u32::as_elem(scope));
        // Loop bound, taken before `self` is moved into `index::expand` below.
        let len: ExpandElement = self.length.clone().into();

        let mut child = scope.child();
        // Loop counter, created in the child (loop body) scope.
        let i = child.create_local_restricted(index_ty);

        let index = i.clone().into();
        let item = index::expand(&mut child, self, index);
        body(&mut child, item);

        // The loop is registered in the parent scope once its body has been generated.
        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
            i: *i,
            start: 0u32.into(),
            end: *len,
            step: None,
            inclusive: false,
            scope: child,
        })));
    }

    /// Unrolling is not supported for slices: always panics via `unimplemented!`.
    fn expand_unroll(
        self,
        _scope: &mut Scope,
        _body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
    ) {
        unimplemented!("Can't unroll slice iterator")
    }
}
269impl<E: CubePrimitive> CubeIndex for Slice<E, ReadOnly> {
270    type Output = E;
271
272    fn expand_index(
273        scope: &mut Scope,
274        array: Self::ExpandType,
275        index: ExpandElementTyped<u32>,
276    ) -> <Self::Output as CubeType>::ExpandType {
277        array.__expand_read_method(scope, index)
278    }
279}
280
281impl<E: CubePrimitive> CubeIndexExpand for SliceExpand<E, ReadOnly> {
282    type Output = E::ExpandType;
283
284    fn expand_index(self, scope: &mut Scope, index: ExpandElementTyped<u32>) -> Self::Output {
285        self.__expand_read_method(scope, index)
286    }
287    fn expand_index_unchecked(
288        self,
289        scope: &mut Scope,
290        index: ExpandElementTyped<u32>,
291    ) -> Self::Output {
292        self.__expand_read_unchecked_method(scope, index)
293    }
294}
295
296impl<E: CubePrimitive> List<E> for Slice<E, ReadOnly> {}
297impl<E: CubePrimitive> ListExpand<E> for SliceExpand<E, ReadOnly> {
298    fn __expand_read_method(
299        &self,
300        scope: &mut cubecl_ir::Scope,
301        index: ExpandElementTyped<u32>,
302    ) -> <E as CubeType>::ExpandType {
303        read_offset::expand::<E>(
304            scope,
305            self.origin.clone(),
306            self.offset.clone(),
307            index,
308            self.line_size,
309            true,
310        )
311    }
312    fn __expand_read_unchecked_method(
313        &self,
314        scope: &mut cubecl_ir::Scope,
315        index: ExpandElementTyped<u32>,
316    ) -> <E as CubeType>::ExpandType {
317        read_offset::expand::<E>(
318            scope,
319            self.origin.clone(),
320            self.offset.clone(),
321            index,
322            self.line_size,
323            false,
324        )
325    }
326}
327
328impl<E: CubePrimitive> CubeIndex for Slice<E, ReadWrite> {
329    type Output = E;
330
331    fn expand_index(
332        scope: &mut Scope,
333        array: Self::ExpandType,
334        index: ExpandElementTyped<u32>,
335    ) -> <Self::Output as CubeType>::ExpandType {
336        array.__expand_read_method(scope, index)
337    }
338}
339
340impl<E: CubePrimitive> CubeIndexExpand for SliceExpand<E, ReadWrite> {
341    type Output = E::ExpandType;
342
343    fn expand_index(self, scope: &mut Scope, index: ExpandElementTyped<u32>) -> Self::Output {
344        self.__expand_read_method(scope, index)
345    }
346    fn expand_index_unchecked(
347        self,
348        scope: &mut Scope,
349        index: ExpandElementTyped<u32>,
350    ) -> Self::Output {
351        self.__expand_read_unchecked_method(scope, index)
352    }
353}
354
355impl<E: CubePrimitive> List<E> for Slice<E, ReadWrite> {}
356impl<E: CubePrimitive> ListExpand<E> for SliceExpand<E, ReadWrite> {
357    fn __expand_read_method(
358        &self,
359        scope: &mut cubecl_ir::Scope,
360        index: ExpandElementTyped<u32>,
361    ) -> <E as CubeType>::ExpandType {
362        read_offset::expand::<E>(
363            scope,
364            self.origin.clone(),
365            self.offset.clone(),
366            index,
367            self.line_size,
368            true,
369        )
370    }
371    fn __expand_read_unchecked_method(
372        &self,
373        scope: &mut cubecl_ir::Scope,
374        index: ExpandElementTyped<u32>,
375    ) -> <E as CubeType>::ExpandType {
376        read_offset::expand::<E>(
377            scope,
378            self.origin.clone(),
379            self.offset.clone(),
380            index,
381            self.line_size,
382            false,
383        )
384    }
385}
386
387impl<E: CubePrimitive> CubeIndexMut for Slice<E, ReadWrite> {
388    fn expand_index_mut(
389        scope: &mut Scope,
390        array: Self::ExpandType,
391        index: ExpandElementTyped<u32>,
392        value: ExpandElementTyped<E>,
393    ) {
394        array.__expand_write_method(scope, index, value)
395    }
396}
397
398impl<E: CubePrimitive> CubeIndexMutExpand for SliceExpand<E, ReadWrite> {
399    type Output = E::ExpandType;
400
401    fn expand_index_mut(
402        self,
403        scope: &mut Scope,
404        index: ExpandElementTyped<u32>,
405        value: Self::Output,
406    ) {
407        self.__expand_write_method(scope, index, value)
408    }
409}
410
411impl<E: CubePrimitive> ListMut<E> for Slice<E, ReadWrite> {}
412impl<E: CubePrimitive> ListMutExpand<E> for SliceExpand<E, ReadWrite> {
413    fn __expand_write_method(
414        &self,
415        scope: &mut cubecl_ir::Scope,
416        index: ExpandElementTyped<u32>,
417        value: ExpandElementTyped<E>,
418    ) {
419        write_offset::expand::<E>(
420            scope,
421            self.origin.clone(),
422            self.offset.clone(),
423            index,
424            value,
425            self.line_size,
426        )
427    }
428}
429
430mod read_offset {
431    use super::*;
432
433    pub fn expand<E: CubePrimitive>(
434        scope: &mut cubecl::prelude::Scope,
435        origin: SliceOriginExpand<E>,
436        offset: <u32 as cubecl::prelude::CubeType>::ExpandType,
437        index: <u32 as cubecl::prelude::CubeType>::ExpandType,
438        line_size: Option<u32>,
439        checked: bool,
440    ) -> <E as cubecl::prelude::CubeType>::ExpandType {
441        let index = cubecl::frontend::add::expand(scope, offset, index);
442
443        match origin {
444            SliceOriginExpand::Tensor(expand) => {
445                expand_index_native::<Tensor<E>>(scope, expand, index, line_size, checked)
446            }
447            SliceOriginExpand::Array(expand) => {
448                expand_index_native::<Array<E>>(scope, expand, index, line_size, checked)
449            }
450            SliceOriginExpand::SharedMemory(expand) => {
451                expand_index_native::<SharedMemory<E>>(scope, expand, index, line_size, checked)
452            }
453        }
454    }
455}
456
/// Expand-time helper that writes a value at `offset + index` of a slice origin.
mod write_offset {
    use super::*;

    /// Emits `offset + index`, then assigns `value` at that position of `origin` through
    /// `expand_index_assign_native`, forwarding `line_size`.
    ///
    /// NOTE(review): the trailing `bool` passed to `expand_index_assign_native` is hard-coded
    /// to `true` for tensors and `false` for arrays and shared memory. By contrast,
    /// `read_offset::expand` forwards a caller-chosen `checked` flag for every origin.
    /// Presumably this argument is the same kind of checked-access flag. Confirm whether the
    /// asymmetry is intentional.
    pub fn expand<E: CubePrimitive>(
        scope: &mut cubecl::prelude::Scope,
        origin: SliceOriginExpand<E>,
        offset: <u32 as cubecl::prelude::CubeType>::ExpandType,
        index: <u32 as cubecl::prelude::CubeType>::ExpandType,
        value: <E as cubecl::prelude::CubeType>::ExpandType,
        line_size: Option<u32>,
    ) {
        let index = cubecl::frontend::add::expand(scope, offset, index);

        match origin {
            SliceOriginExpand::Tensor(expand) => expand_index_assign_native::<Tensor<E>>(
                scope, expand, index, value, line_size, true,
            ),
            SliceOriginExpand::Array(expand) => expand_index_assign_native::<Array<E>>(
                scope, expand, index, value, line_size, false,
            ),
            SliceOriginExpand::SharedMemory(expand) => {
                expand_index_assign_native::<SharedMemory<E>>(
                    scope, expand, index, value, line_size, false,
                )
            }
        }
    }
}