cubecl_core/frontend/
list.rs

1use super::{CubeType, ExpandElementTyped};
2use crate::unexpanded;
3use cubecl_ir::Scope;
4
5/// Type from which we can read values in cube functions.
6/// For a mutable version, see [ListMut].
7pub trait List<T: CubeType>: CubeType<ExpandType: ListExpand<T>> {
8    #[allow(unused)]
9    fn read(&self, index: u32) -> T {
10        unexpanded!()
11    }
12
13    #[allow(unused)]
14    fn read_unchecked(&self, index: u32) -> T {
15        unexpanded!()
16    }
17
18    fn __expand_read(
19        scope: &mut Scope,
20        this: Self::ExpandType,
21        index: ExpandElementTyped<u32>,
22    ) -> T::ExpandType {
23        this.__expand_read_method(scope, index)
24    }
25
26    fn __expand_read_unchecked(
27        scope: &mut Scope,
28        this: Self::ExpandType,
29        index: ExpandElementTyped<u32>,
30    ) -> T::ExpandType {
31        this.__expand_read_unchecked_method(scope, index)
32    }
33}
34
35/// Expand version of [CubeRead].
36pub trait ListExpand<T: CubeType> {
37    fn __expand_read_method(
38        &self,
39        scope: &mut Scope,
40        index: ExpandElementTyped<u32>,
41    ) -> T::ExpandType;
42    fn __expand_read_unchecked_method(
43        &self,
44        scope: &mut Scope,
45        index: ExpandElementTyped<u32>,
46    ) -> T::ExpandType;
47}
48
49/// Type for which we can read and write values in cube functions.
50/// For an immutable version, see [List].
51pub trait ListMut<T: CubeType>: CubeType<ExpandType: ListMutExpand<T>> + List<T> {
52    #[allow(unused)]
53    fn write(&self, index: u32, value: T) {
54        unexpanded!()
55    }
56
57    fn __expand_write(
58        scope: &mut Scope,
59        this: Self::ExpandType,
60        index: ExpandElementTyped<u32>,
61        value: T::ExpandType,
62    ) {
63        this.__expand_write_method(scope, index, value)
64    }
65}
66
67/// Expand version of [CubeWrite].
68pub trait ListMutExpand<T: CubeType>: ListExpand<T> {
69    fn __expand_write_method(
70        &self,
71        scope: &mut Scope,
72        index: ExpandElementTyped<u32>,
73        value: T::ExpandType,
74    );
75}
76
77// Automatic implementation for mutable references to List.
78impl<'a, T: CubeType, L: List<T>> List<T> for &'a L
79where
80    &'a L: CubeType<ExpandType = L::ExpandType>,
81{
82    fn read(&self, index: u32) -> T {
83        L::read(self, index)
84    }
85
86    fn __expand_read(
87        scope: &mut Scope,
88        this: Self::ExpandType,
89        index: ExpandElementTyped<u32>,
90    ) -> <T as CubeType>::ExpandType {
91        L::__expand_read(scope, this, index)
92    }
93}
94
95// Automatic implementation for mutable references to List.
96impl<'a, T: CubeType, L: List<T>> List<T> for &'a mut L
97where
98    &'a mut L: CubeType<ExpandType = L::ExpandType>,
99{
100    fn read(&self, index: u32) -> T {
101        L::read(self, index)
102    }
103
104    fn __expand_read(
105        scope: &mut Scope,
106        this: Self::ExpandType,
107        index: ExpandElementTyped<u32>,
108    ) -> <T as CubeType>::ExpandType {
109        L::__expand_read(scope, this, index)
110    }
111}
112
113// Automatic implementation for references to ListMut.
114impl<'a, T: CubeType, L: ListMut<T>> ListMut<T> for &'a L
115where
116    &'a L: CubeType<ExpandType = L::ExpandType>,
117{
118    fn write(&self, index: u32, value: T) {
119        L::write(self, index, value);
120    }
121
122    fn __expand_write(
123        scope: &mut Scope,
124        this: Self::ExpandType,
125        index: ExpandElementTyped<u32>,
126        value: T::ExpandType,
127    ) {
128        L::__expand_write(scope, this, index, value);
129    }
130}
131
132// Automatic implementation for references to ListMut.
133impl<'a, T: CubeType, L: ListMut<T>> ListMut<T> for &'a mut L
134where
135    &'a mut L: CubeType<ExpandType = L::ExpandType>,
136{
137    fn write(&self, index: u32, value: T) {
138        L::write(self, index, value);
139    }
140
141    fn __expand_write(
142        scope: &mut Scope,
143        this: Self::ExpandType,
144        index: ExpandElementTyped<u32>,
145        value: T::ExpandType,
146    ) {
147        L::__expand_write(scope, this, index, value);
148    }
149}