Skip to main content

cubecl_core/frontend/container/slice/
base.rs

1use std::{
2    marker::PhantomData,
3    ops::{Deref, DerefMut},
4};
5
6use crate::{self as cubecl, unexpanded};
7use cubecl::prelude::*;
8use cubecl_ir::{Branch, ElemType, ExpandElement, FloatKind, LineSize, RangeLoop, Type, Variable};
9use cubecl_macros::intrinsic;
10
/// Visibility marker for a [`Slice`] that only exposes read access.
///
/// This is the default `IO` parameter of [`Slice`].
#[derive(Clone, Copy)]
pub struct ReadOnly;
/// Visibility marker for a [`Slice`] that exposes both read and write access.
///
/// The mutable indexing traits in this file are only implemented for slices
/// carrying this marker.
#[derive(Clone, Copy)]
pub struct ReadWrite;
15
16/// A read-only contiguous list of elements
17///
18/// # Safety
19///
20/// Since data can't be deallocated during kernel execution, this is safe.
21#[derive(Clone, Copy)]
22pub struct Slice<E: CubePrimitive, IO: SliceVisibility = ReadOnly> {
23    _e: PhantomData<E>,
24    _io: PhantomData<IO>,
25    _offset: PhantomData<usize>,
26    length: usize,
27}
28
29#[derive(CubeType)]
30pub enum SliceOrigin<E: CubePrimitive> {
31    Tensor(Tensor<E>),
32    Array(Array<E>),
33    SharedMemory(SharedMemory<E>),
34}
35
36impl<E: CubePrimitive> SliceOriginExpand<E> {
37    pub fn line_size(&self) -> LineSize {
38        match self {
39            SliceOriginExpand::Tensor(t) => t.line_size(),
40            SliceOriginExpand::Array(t) => t.line_size(),
41            SliceOriginExpand::SharedMemory(t) => t.line_size(),
42        }
43    }
44}
45
46impl<E: CubePrimitive, IO: SliceVisibility> Iterator for Slice<E, IO> {
47    type Item = E;
48
49    fn next(&mut self) -> Option<Self::Item> {
50        unexpanded!()
51    }
52}
53
54pub trait SliceVisibility: Clone + Copy + Send + Sync + 'static {}
55
56impl SliceVisibility for ReadOnly {}
57
58impl SliceVisibility for ReadWrite {}
59
60pub struct SliceExpand<E: CubePrimitive, IO: SliceVisibility> {
61    pub(crate) origin: SliceOriginExpand<E>,
62    pub(crate) io: PhantomData<IO>,
63    pub(crate) offset: ExpandElementTyped<usize>,
64    pub(crate) length: ExpandElementTyped<usize>,
65    pub(crate) line_size: Option<LineSize>,
66}
67
68impl<E: CubePrimitive, IO: SliceVisibility> SliceExpand<E, IO> {
69    pub fn __to_raw_parts(&self) -> (Variable, Variable) {
70        let expand = match self.origin.clone() {
71            SliceOriginExpand::Tensor(expand) => expand.expand,
72            SliceOriginExpand::Array(expand) => expand.expand,
73            SliceOriginExpand::SharedMemory(expand) => expand.expand,
74        };
75
76        (*expand, *self.offset.expand)
77    }
78}
79
80#[cube]
81impl<E: CubePrimitive, IO: SliceVisibility> Slice<Line<E>, IO> {
82    /// Reinterprets how items are loaded and stored in memory.slicebase
83    ///
84    /// # Warning
85    ///
86    /// Currently, this only work with `cube(launch_unchecked)` and is not supported on wgpu.
87    #[allow(unused_variables)]
88    pub fn with_line_size(&self, #[comptime] line_size: LineSize) -> Slice<Line<E>, IO> {
89        intrinsic!(|scope| {
90            let (input, offset) = self.__to_raw_parts();
91            let mut item = input.ty;
92
93            if line_size == item.line_size() {
94                return self;
95            }
96
97            let current = input.ty.line_size();
98            let mut out = self.clone();
99
100            if current < line_size {
101                let ratio = line_size / current;
102                let length = cubecl::frontend::div::expand(scope, self.length, ratio.into());
103                let offset = cubecl::frontend::div::expand(scope, self.offset, ratio.into());
104                out.length = length;
105                out.offset = offset;
106            } else {
107                let ratio = current / line_size;
108                let length = cubecl::frontend::mul::expand(scope, self.length, ratio.into());
109                let offset = cubecl::frontend::mul::expand(scope, self.offset, ratio.into());
110                out.length = length;
111                out.offset = offset;
112            }
113
114            out.line_size = Some(line_size);
115            out
116        })
117    }
118}
119
120#[cube]
121impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
122    /// Returns the same slice, but with lines of length 1.
123    pub fn into_lined(&self) -> Slice<Line<E>, IO> {
124        intrinsic!(|_scope| {
125            SliceExpand::<Line<E>, IO> {
126                origin: self.origin.cast_unchecked(),
127                io: self.io.clone(),
128                offset: self.offset.clone(),
129                length: self.length.clone(),
130                line_size: None,
131            }
132        })
133    }
134    /// Downcast the slice to the given type and panic if the type isn't the same.
135    ///
136    /// This function should only be used to satisfy the Rust type system, when two generic
137    /// types are supposed to be the same.
138    pub fn downcast<T: CubePrimitive>(&self) -> Slice<T, IO> {
139        intrinsic!(|scope| {
140            if T::as_type(scope) != E::as_type(scope) && !is_tf32::<E, T>(scope) {
141                let elems = [T::as_type(scope).elem_type(), E::as_type(scope).elem_type()];
142                let is_flex32_cast = elems.contains(&ElemType::Float(FloatKind::F32))
143                    && elems.contains(&ElemType::Float(FloatKind::Flex32));
144
145                if !is_flex32_cast {
146                    panic!("Downcast should only be used to satisfy the Rust type system.")
147                }
148            }
149
150            SliceExpand::<T, IO> {
151                origin: self.origin.cast_unchecked(),
152                io: self.io.clone(),
153                offset: self.offset.clone(),
154                length: self.length.clone(),
155                line_size: self.line_size.clone(),
156            }
157        })
158    }
159}
160
161#[cube]
162impl<E: CubePrimitive> Slice<E, ReadOnly> {
163    pub fn as_mut_unchecked(&self) -> Slice<E, ReadWrite> {
164        intrinsic!(|scope| {
165            SliceExpand::<E, ReadWrite> {
166                origin: self.origin,
167                io: PhantomData,
168                offset: self.offset.clone(),
169                length: self.length.clone(),
170                line_size: self.line_size.clone(),
171            }
172        })
173    }
174}
175
176impl<E: CubePrimitive> SliceOriginExpand<E> {
177    fn cast_unchecked<T: CubePrimitive>(self) -> SliceOriginExpand<T> {
178        match self {
179            SliceOriginExpand::Tensor(expand) => {
180                SliceOriginExpand::<T>::Tensor(expand.expand.into())
181            }
182            SliceOriginExpand::Array(expand) => SliceOriginExpand::<T>::Array(expand.expand.into()),
183            SliceOriginExpand::SharedMemory(expand) => {
184                SliceOriginExpand::<T>::SharedMemory(expand.expand.into())
185            }
186        }
187    }
188}
189
190impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
191    pub fn new(_origin: SliceOrigin<E>, _offset: usize, _length: usize) -> Self {
192        unexpanded!()
193    }
194    pub fn __expand_new(
195        scope: &mut Scope,
196        origin: SliceOriginExpand<E>,
197        start: ExpandElementTyped<usize>,
198        end: ExpandElementTyped<usize>,
199    ) -> SliceExpand<E, IO> {
200        Self::__expand_new_expand(scope, origin, start, end)
201    }
202    pub fn __expand_new_expand(
203        scope: &mut Scope,
204        origin: SliceOriginExpand<E>,
205        start: ExpandElementTyped<usize>,
206        end: ExpandElementTyped<usize>,
207    ) -> SliceExpand<E, IO> {
208        let length = cubecl::frontend::sub::expand(scope, end, start.clone());
209
210        SliceExpand::<E, IO> {
211            origin,
212            io: PhantomData,
213            offset: start,
214            length,
215            line_size: None,
216        }
217    }
218}
219
220#[cube]
221impl<E: CubePrimitive, IO: SliceVisibility> Slice<E, IO> {
222    /// Get the length of the slice.
223    pub fn len(&self) -> usize {
224        self.length
225    }
226    /// Returns true if the slice is empty.
227    pub fn is_empty(&self) -> bool {
228        self.length == 0
229    }
230}
231
232impl<E: CubePrimitive, IO: SliceVisibility> CubeType for Slice<E, IO> {
233    type ExpandType = SliceExpand<E, IO>;
234}
235
236impl<E: CubePrimitive, IO: SliceVisibility> CubeType for &Slice<E, IO> {
237    type ExpandType = SliceExpand<E, IO>;
238}
239
240impl<E: CubePrimitive, IO: SliceVisibility> CubeType for &mut Slice<E, IO> {
241    type ExpandType = SliceExpand<E, IO>;
242}
243
244impl<E: CubePrimitive, IO: SliceVisibility> IntoMut for SliceExpand<E, IO> {
245    fn into_mut(self, _scope: &mut cubecl_ir::Scope) -> Self {
246        self
247    }
248}
249
250impl<E: CubePrimitive, IO: SliceVisibility> CubeDebug for SliceExpand<E, IO> {}
251impl<E: CubePrimitive, IO: SliceVisibility> Clone for SliceExpand<E, IO> {
252    fn clone(&self) -> Self {
253        Self {
254            origin: self.origin.clone(),
255            offset: self.offset.clone(),
256            length: self.length.clone(),
257            line_size: self.line_size,
258            io: PhantomData,
259        }
260    }
261}
262
263// TODO: Fix
264impl<E: CubePrimitive> SizedContainer for Slice<E, ReadOnly> {
265    type Item = E;
266}
267
268impl<E: CubePrimitive> Iterable<E> for SliceExpand<E, ReadOnly> {
269    fn expand(
270        self,
271        scope: &mut Scope,
272        mut body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
273    ) {
274        let index_ty = Type::new(u32::as_type(scope));
275        let len: ExpandElement = self.length.clone().into();
276
277        let mut child = scope.child();
278        let i = child.create_local_restricted(index_ty);
279
280        let index = i.clone().into();
281        let item = index::expand(&mut child, self, index);
282        body(&mut child, item);
283
284        scope.register(Branch::RangeLoop(Box::new(RangeLoop {
285            i: *i,
286            start: 0usize.into(),
287            end: *len,
288            step: None,
289            inclusive: false,
290            scope: child,
291        })));
292    }
293
294    fn expand_unroll(
295        self,
296        _scope: &mut Scope,
297        _body: impl FnMut(&mut Scope, <E as CubeType>::ExpandType),
298    ) {
299        unimplemented!("Can't unroll slice iterator")
300    }
301}
302impl<E: CubePrimitive, IO: SliceVisibility> CubeIndex for Slice<E, IO> {
303    type Output = E;
304    type Idx = usize;
305
306    fn expand_index(
307        scope: &mut Scope,
308        array: Self::ExpandType,
309        index: ExpandElementTyped<usize>,
310    ) -> <Self::Output as CubeType>::ExpandType {
311        array.__expand_read_method(scope, index)
312    }
313}
314
315impl<E: CubePrimitive, IO: SliceVisibility> CubeIndexExpand for SliceExpand<E, IO> {
316    type Output = E::ExpandType;
317    type Idx = ExpandElementTyped<usize>;
318
319    fn expand_index(self, scope: &mut Scope, index: ExpandElementTyped<usize>) -> Self::Output {
320        self.__expand_read_method(scope, index)
321    }
322    fn expand_index_unchecked(
323        self,
324        scope: &mut Scope,
325        index: ExpandElementTyped<usize>,
326    ) -> Self::Output {
327        self.__expand_read_unchecked_method(scope, index)
328    }
329}
330
331impl<E: CubePrimitive, IO: SliceVisibility> List<E> for Slice<E, IO> {}
332impl<E: CubePrimitive, IO: SliceVisibility> ListExpand<E> for SliceExpand<E, IO> {
333    fn __expand_read_method(
334        &self,
335        scope: &mut cubecl_ir::Scope,
336        index: ExpandElementTyped<usize>,
337    ) -> <E as CubeType>::ExpandType {
338        read_offset::expand::<E>(
339            scope,
340            self.origin.clone(),
341            self.offset.clone(),
342            index,
343            self.line_size,
344            true,
345        )
346    }
347    fn __expand_read_unchecked_method(
348        &self,
349        scope: &mut cubecl_ir::Scope,
350        index: ExpandElementTyped<usize>,
351    ) -> <E as CubeType>::ExpandType {
352        read_offset::expand::<E>(
353            scope,
354            self.origin.clone(),
355            self.offset.clone(),
356            index,
357            self.line_size,
358            false,
359        )
360    }
361
362    fn __expand_len_method(&self, scope: &mut Scope) -> ExpandElementTyped<usize> {
363        Self::__expand_len(scope, self.clone())
364    }
365}
366
367impl<T: CubePrimitive, IO: SliceVisibility> Deref for Slice<T, IO> {
368    type Target = [T];
369
370    fn deref(&self) -> &Self::Target {
371        unexpanded!()
372    }
373}
374
375impl<T: CubePrimitive> DerefMut for Slice<T, ReadWrite> {
376    fn deref_mut(&mut self) -> &mut Self::Target {
377        unexpanded!()
378    }
379}
380
381impl<E: CubePrimitive, IO: SliceVisibility> Lined for Slice<E, IO> {}
382impl<E: CubePrimitive, IO: SliceVisibility> LinedExpand for SliceExpand<E, IO> {
383    fn line_size(&self) -> LineSize {
384        self.line_size.unwrap_or_else(|| self.origin.line_size())
385    }
386}
387
388impl<E: CubePrimitive> CubeIndexMut for Slice<E, ReadWrite> {
389    fn expand_index_mut(
390        scope: &mut Scope,
391        array: Self::ExpandType,
392        index: ExpandElementTyped<usize>,
393        value: ExpandElementTyped<E>,
394    ) {
395        array.__expand_write_method(scope, index, value)
396    }
397}
398
399impl<E: CubePrimitive> CubeIndexMutExpand for SliceExpand<E, ReadWrite> {
400    fn expand_index_mut(
401        self,
402        scope: &mut Scope,
403        index: ExpandElementTyped<usize>,
404        value: Self::Output,
405    ) {
406        self.__expand_write_method(scope, index, value)
407    }
408}
409
410impl<E: CubePrimitive> ListMut<E> for Slice<E, ReadWrite> {}
411impl<E: CubePrimitive> ListMutExpand<E> for SliceExpand<E, ReadWrite> {
412    fn __expand_write_method(
413        &self,
414        scope: &mut cubecl_ir::Scope,
415        index: ExpandElementTyped<usize>,
416        value: ExpandElementTyped<E>,
417    ) {
418        write_offset::expand::<E>(
419            scope,
420            self.origin.clone(),
421            self.offset.clone(),
422            index,
423            value,
424            self.line_size,
425        )
426    }
427}
428
429mod read_offset {
430    use super::*;
431
432    pub fn expand<E: CubePrimitive>(
433        scope: &mut cubecl::prelude::Scope,
434        origin: SliceOriginExpand<E>,
435        offset: <usize as cubecl::prelude::CubeType>::ExpandType,
436        index: <usize as cubecl::prelude::CubeType>::ExpandType,
437        line_size: Option<LineSize>,
438        checked: bool,
439    ) -> <E as cubecl::prelude::CubeType>::ExpandType {
440        let index = cubecl::frontend::add::expand(scope, offset, index);
441
442        match origin {
443            SliceOriginExpand::Tensor(expand) => {
444                expand_index_native::<Tensor<E>>(scope, expand, index, line_size, checked)
445            }
446            SliceOriginExpand::Array(expand) => {
447                expand_index_native::<Array<E>>(scope, expand, index, line_size, checked)
448            }
449            SliceOriginExpand::SharedMemory(expand) => {
450                expand_index_native::<SharedMemory<E>>(scope, expand, index, line_size, checked)
451            }
452        }
453    }
454}
455
456mod write_offset {
457    use super::*;
458
459    pub fn expand<E: CubePrimitive>(
460        scope: &mut cubecl::prelude::Scope,
461        origin: SliceOriginExpand<E>,
462        offset: <usize as cubecl::prelude::CubeType>::ExpandType,
463        index: <usize as cubecl::prelude::CubeType>::ExpandType,
464        value: <E as cubecl::prelude::CubeType>::ExpandType,
465        line_size: Option<LineSize>,
466    ) {
467        let index = cubecl::frontend::add::expand(scope, offset, index);
468
469        match origin {
470            SliceOriginExpand::Tensor(expand) => expand_index_assign_native::<Tensor<E>>(
471                scope, expand, index, value, line_size, true,
472            ),
473            SliceOriginExpand::Array(expand) => expand_index_assign_native::<Array<E>>(
474                scope, expand, index, value, line_size, false,
475            ),
476            SliceOriginExpand::SharedMemory(expand) => {
477                expand_index_assign_native::<SharedMemory<E>>(
478                    scope, expand, index, value, line_size, false,
479                )
480            }
481        }
482    }
483}