Skip to main content

cubecl_core/frontend/
list.rs

1use std::ops::{Deref, DerefMut};
2
3use super::{CubeType, ExpandElementTyped};
4use crate as cubecl;
5use crate::{prelude::*, unexpanded};
6use cubecl_ir::{LineSize, Scope};
7
8/// Type from which we can read values in cube functions.
9/// For a mutable version, see [`ListMut`].
10#[allow(clippy::len_without_is_empty)]
11#[cube(self_type = "ref", expand_base_traits = "SliceOperatorExpand<T>")]
12pub trait List<T: CubePrimitive>: SliceOperator<T> + Lined + Deref<Target = [T]> {
13    #[allow(unused)]
14    fn read(&self, index: usize) -> T {
15        unexpanded!()
16    }
17
18    #[allow(unused)]
19    fn read_unchecked(&self, index: usize) -> T {
20        unexpanded!()
21    }
22
23    #[allow(unused)]
24    fn len(&self) -> usize {
25        unexpanded!();
26    }
27}
28
29/// Type for which we can read and write values in cube functions.
30/// For an immutable version, see [List].
31#[cube(self_type = "ref", expand_base_traits = "SliceMutOperatorExpand<T>")]
32pub trait ListMut<T: CubePrimitive>:
33    List<T> + SliceMutOperator<T> + DerefMut<Target = [T]>
34{
35    #[allow(unused)]
36    fn write(&self, index: usize, value: T) {
37        unexpanded!()
38    }
39}
40
41// Automatic implementation for references to List.
42impl<'a, T: CubePrimitive, L: List<T>> List<T> for &'a L
43where
44    &'a L: CubeType<ExpandType = L::ExpandType>,
45    &'a L: Deref<Target = [T]>,
46{
47    fn read(&self, index: usize) -> T {
48        L::read(self, index)
49    }
50
51    fn __expand_read(
52        scope: &mut Scope,
53        this: Self::ExpandType,
54        index: ExpandElementTyped<usize>,
55    ) -> <T as CubeType>::ExpandType {
56        L::__expand_read(scope, this, index)
57    }
58}
59
60// Automatic implementation for mutable references to List.
61impl<'a, T: CubePrimitive, L: List<T>> List<T> for &'a mut L
62where
63    &'a mut L: CubeType<ExpandType = L::ExpandType>,
64    &'a mut L: Deref<Target = [T]>,
65{
66    fn read(&self, index: usize) -> T {
67        L::read(self, index)
68    }
69
70    fn __expand_read(
71        scope: &mut Scope,
72        this: Self::ExpandType,
73        index: ExpandElementTyped<usize>,
74    ) -> <T as CubeType>::ExpandType {
75        L::__expand_read(scope, this, index)
76    }
77}
78
79// Automatic implementation for references to ListMut.
80impl<'a, T: CubePrimitive, L: ListMut<T>> ListMut<T> for &'a L
81where
82    &'a L: CubeType<ExpandType = L::ExpandType>,
83    &'a L: DerefMut<Target = [T]>,
84{
85    fn write(&self, index: usize, value: T) {
86        L::write(self, index, value);
87    }
88
89    fn __expand_write(
90        scope: &mut Scope,
91        this: Self::ExpandType,
92        index: ExpandElementTyped<usize>,
93        value: T::ExpandType,
94    ) {
95        L::__expand_write(scope, this, index, value);
96    }
97}
98
99// Automatic implementation for mutable references to ListMut.
100impl<'a, T: CubePrimitive, L: ListMut<T>> ListMut<T> for &'a mut L
101where
102    &'a mut L: CubeType<ExpandType = L::ExpandType>,
103    &'a mut L: DerefMut<Target = [T]>,
104{
105    fn write(&self, index: usize, value: T) {
106        L::write(self, index, value);
107    }
108
109    fn __expand_write(
110        scope: &mut Scope,
111        this: Self::ExpandType,
112        index: ExpandElementTyped<usize>,
113        value: T::ExpandType,
114    ) {
115        L::__expand_write(scope, this, index, value);
116    }
117}
118
119pub trait Lined: CubeType<ExpandType: LinedExpand> {
120    fn line_size(&self) -> LineSize {
121        unexpanded!()
122    }
123    fn __expand_line_size(_scope: &mut Scope, this: Self::ExpandType) -> LineSize {
124        this.line_size()
125    }
126}
127
128pub trait LinedExpand {
129    fn line_size(&self) -> LineSize;
130    fn __expand_line_size_method(&self, _scope: &mut Scope) -> LineSize {
131        self.line_size()
132    }
133}
134
135impl<'a, L: Lined> Lined for &'a L where &'a L: CubeType<ExpandType: LinedExpand> {}
136impl<'a, L: Lined> Lined for &'a mut L where &'a mut L: CubeType<ExpandType: LinedExpand> {}