cubecl_core/frontend/
list.rs1use super::{CubeType, ExpandElementTyped};
2use crate::unexpanded;
3use cubecl_ir::Scope;
4
5pub trait List<T: CubeType>: CubeType<ExpandType: ListExpand<T>> {
8 #[allow(unused)]
9 fn read(&self, index: u32) -> T {
10 unexpanded!()
11 }
12
13 #[allow(unused)]
14 fn read_unchecked(&self, index: u32) -> T {
15 unexpanded!()
16 }
17
18 fn __expand_read(
19 scope: &mut Scope,
20 this: Self::ExpandType,
21 index: ExpandElementTyped<u32>,
22 ) -> T::ExpandType {
23 this.__expand_read_method(scope, index)
24 }
25
26 fn __expand_read_unchecked(
27 scope: &mut Scope,
28 this: Self::ExpandType,
29 index: ExpandElementTyped<u32>,
30 ) -> T::ExpandType {
31 this.__expand_read_unchecked_method(scope, index)
32 }
33}
34
35pub trait ListExpand<T: CubeType> {
37 fn __expand_read_method(
38 &self,
39 scope: &mut Scope,
40 index: ExpandElementTyped<u32>,
41 ) -> T::ExpandType;
42 fn __expand_read_unchecked_method(
43 &self,
44 scope: &mut Scope,
45 index: ExpandElementTyped<u32>,
46 ) -> T::ExpandType;
47}
48
49pub trait ListMut<T: CubeType>: CubeType<ExpandType: ListMutExpand<T>> + List<T> {
52 #[allow(unused)]
53 fn write(&self, index: u32, value: T) {
54 unexpanded!()
55 }
56
57 fn __expand_write(
58 scope: &mut Scope,
59 this: Self::ExpandType,
60 index: ExpandElementTyped<u32>,
61 value: T::ExpandType,
62 ) {
63 this.__expand_write_method(scope, index, value)
64 }
65}
66
67pub trait ListMutExpand<T: CubeType>: ListExpand<T> {
69 fn __expand_write_method(
70 &self,
71 scope: &mut Scope,
72 index: ExpandElementTyped<u32>,
73 value: T::ExpandType,
74 );
75}
76
77impl<'a, T: CubeType, L: List<T>> List<T> for &'a L
79where
80 &'a L: CubeType<ExpandType = L::ExpandType>,
81{
82 fn read(&self, index: u32) -> T {
83 L::read(self, index)
84 }
85
86 fn __expand_read(
87 scope: &mut Scope,
88 this: Self::ExpandType,
89 index: ExpandElementTyped<u32>,
90 ) -> <T as CubeType>::ExpandType {
91 L::__expand_read(scope, this, index)
92 }
93}
94
95impl<'a, T: CubeType, L: List<T>> List<T> for &'a mut L
97where
98 &'a mut L: CubeType<ExpandType = L::ExpandType>,
99{
100 fn read(&self, index: u32) -> T {
101 L::read(self, index)
102 }
103
104 fn __expand_read(
105 scope: &mut Scope,
106 this: Self::ExpandType,
107 index: ExpandElementTyped<u32>,
108 ) -> <T as CubeType>::ExpandType {
109 L::__expand_read(scope, this, index)
110 }
111}
112
113impl<'a, T: CubeType, L: ListMut<T>> ListMut<T> for &'a L
115where
116 &'a L: CubeType<ExpandType = L::ExpandType>,
117{
118 fn write(&self, index: u32, value: T) {
119 L::write(self, index, value);
120 }
121
122 fn __expand_write(
123 scope: &mut Scope,
124 this: Self::ExpandType,
125 index: ExpandElementTyped<u32>,
126 value: T::ExpandType,
127 ) {
128 L::__expand_write(scope, this, index, value);
129 }
130}
131
132impl<'a, T: CubeType, L: ListMut<T>> ListMut<T> for &'a mut L
134where
135 &'a mut L: CubeType<ExpandType = L::ExpandType>,
136{
137 fn write(&self, index: u32, value: T) {
138 L::write(self, index, value);
139 }
140
141 fn __expand_write(
142 scope: &mut Scope,
143 this: Self::ExpandType,
144 index: ExpandElementTyped<u32>,
145 value: T::ExpandType,
146 ) {
147 L::__expand_write(scope, this, index, value);
148 }
149}