cubecl_core/frontend/container/slice/
operator.rs1use super::{ReadOnly, ReadWrite, Slice, SliceExpand, SliceOriginExpand, SliceVisibility};
2use crate as cubecl;
3use crate::{ir::Scope, prelude::*, unexpanded};
4use cubecl_common::tf32;
5use cubecl_ir::ExpandElement;
6
7pub(crate) fn is_tf32<C: CubePrimitive, T: CubePrimitive>(scope: &mut Scope) -> bool {
8 let ty_c = C::as_type(scope);
9 let ty_t = T::as_type(scope);
10 let ty_f32 = f32::as_type(scope);
11 let ty_tf32 = tf32::as_type(scope);
12
13 (ty_c == ty_f32 && ty_t == ty_tf32) || (ty_c == ty_tf32 && ty_t == ty_f32)
14}
15
16impl<E: CubePrimitive> SliceOperator<E> for SharedMemory<E> {}
17impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<SharedMemory<E>> {
18 fn __expand_slice_method(
19 &self,
20 scope: &mut Scope,
21 start: ExpandElementTyped<u32>,
22 end: ExpandElementTyped<u32>,
23 ) -> SliceExpand<E, ReadOnly> {
24 Slice::__expand_new(
25 scope,
26 SliceOriginExpand::SharedMemory(self.clone()),
27 start,
28 end,
29 )
30 }
31
32 fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
33 let len = expand_length_native(scope, *self.expand);
34
35 Slice::__expand_new(
36 scope,
37 SliceOriginExpand::SharedMemory(self.clone()),
38 0u32.into(),
39 ExpandElement::Plain(len).into(),
40 )
41 }
42}
43
44impl<E: CubePrimitive> SliceMutOperator<E> for SharedMemory<E> {}
45impl<E: CubePrimitive> SliceMutOperatorExpand<E> for ExpandElementTyped<SharedMemory<E>> {
46 fn __expand_slice_mut_method(
47 &self,
48 scope: &mut Scope,
49 start: ExpandElementTyped<u32>,
50 end: ExpandElementTyped<u32>,
51 ) -> SliceExpand<E, ReadWrite> {
52 Slice::__expand_new(
53 scope,
54 SliceOriginExpand::SharedMemory(self.clone()),
55 start,
56 end,
57 )
58 }
59
60 fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
61 let len = expand_length_native(scope, *self.expand);
62
63 Slice::__expand_new(
64 scope,
65 SliceOriginExpand::SharedMemory(self.clone()),
66 0u32.into(),
67 ExpandElement::Plain(len).into(),
68 )
69 }
70}
71
72impl<E: CubePrimitive> SliceOperator<E> for Tensor<E> {}
73impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<Tensor<E>> {
74 fn __expand_slice_method(
75 &self,
76 scope: &mut Scope,
77 start: ExpandElementTyped<u32>,
78 end: ExpandElementTyped<u32>,
79 ) -> SliceExpand<E, ReadOnly> {
80 Slice::__expand_new(scope, SliceOriginExpand::Tensor(self.clone()), start, end)
81 }
82
83 fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
84 let len = self.clone().__expand_len_method(scope);
85 Slice::__expand_new(
86 scope,
87 SliceOriginExpand::Tensor(self.clone()),
88 0u32.into(),
89 len,
90 )
91 }
92}
93
94impl<E: CubePrimitive> SliceMutOperator<E> for Tensor<E> {}
95impl<E: CubePrimitive> SliceMutOperatorExpand<E> for ExpandElementTyped<Tensor<E>> {
96 fn __expand_slice_mut_method(
97 &self,
98 scope: &mut Scope,
99 start: ExpandElementTyped<u32>,
100 end: ExpandElementTyped<u32>,
101 ) -> SliceExpand<E, ReadWrite> {
102 Slice::__expand_new(scope, SliceOriginExpand::Tensor(self.clone()), start, end)
103 }
104
105 fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
106 let len = self.clone().__expand_len_method(scope);
107 Slice::__expand_new(
108 scope,
109 SliceOriginExpand::Tensor(self.clone()),
110 0u32.into(),
111 len,
112 )
113 }
114}
115
116impl<E: CubePrimitive> SliceOperator<E> for Array<E> {}
117impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<Array<E>> {
118 fn __expand_slice_method(
119 &self,
120 scope: &mut Scope,
121 start: ExpandElementTyped<u32>,
122 end: ExpandElementTyped<u32>,
123 ) -> SliceExpand<E, ReadOnly> {
124 Slice::__expand_new(scope, SliceOriginExpand::Array(self.clone()), start, end)
125 }
126
127 fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
128 let len = self.clone().__expand_len_method(scope);
129 Slice::__expand_new(
130 scope,
131 SliceOriginExpand::Array(self.clone()),
132 0u32.into(),
133 len,
134 )
135 }
136}
137
138impl<E: CubePrimitive> SliceMutOperator<E> for Array<E> {}
139impl<E: CubePrimitive> SliceMutOperatorExpand<E> for ExpandElementTyped<Array<E>> {
140 fn __expand_slice_mut_method(
141 &self,
142 scope: &mut Scope,
143 start: ExpandElementTyped<u32>,
144 end: ExpandElementTyped<u32>,
145 ) -> SliceExpand<E, ReadWrite> {
146 Slice::__expand_new(scope, SliceOriginExpand::Array(self.clone()), start, end)
147 }
148
149 fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
150 let len = self.clone().__expand_len_method(scope);
151 Slice::__expand_new(
152 scope,
153 SliceOriginExpand::Array(self.clone()),
154 0u32.into(),
155 len,
156 )
157 }
158}
159
160impl<E: CubePrimitive, IO: SliceVisibility> SliceOperator<E> for Slice<E, IO> {}
161impl<E: CubePrimitive, IO: SliceVisibility> SliceOperatorExpand<E> for SliceExpand<E, IO> {
162 fn __expand_slice_method(
163 &self,
164 scope: &mut Scope,
165 start: ExpandElementTyped<u32>,
166 end: ExpandElementTyped<u32>,
167 ) -> SliceExpand<E, ReadOnly> {
168 let length = crate::frontend::sub::expand(scope, end, start.clone());
169 let offset = crate::frontend::add::expand(scope, start, self.offset.clone());
170
171 SliceExpand {
172 origin: self.origin.clone(),
173 io: std::marker::PhantomData,
174 offset,
175 length,
176 line_size: None,
177 }
178 }
179
180 fn __expand_to_slice_method(&self, _scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
181 SliceExpand {
182 origin: self.origin.clone(),
183 io: std::marker::PhantomData,
184 offset: self.offset.clone(),
185 length: self.length.clone(),
186 line_size: self.line_size,
187 }
188 }
189}
190
191impl<E: CubePrimitive> SliceMutOperator<E> for Slice<E, ReadWrite> {}
192impl<E: CubePrimitive> SliceMutOperatorExpand<E> for SliceExpand<E, ReadWrite> {
193 fn __expand_slice_mut_method(
194 &self,
195 scope: &mut Scope,
196 start: ExpandElementTyped<u32>,
197 end: ExpandElementTyped<u32>,
198 ) -> SliceExpand<E, ReadWrite> {
199 let length = crate::frontend::sub::expand(scope, end, start.clone());
200 let offset = crate::frontend::add::expand(scope, start, self.offset.clone());
201
202 SliceExpand {
203 origin: self.origin.clone(),
204 io: std::marker::PhantomData,
205 offset,
206 length,
207 line_size: None,
208 }
209 }
210
211 fn __expand_to_slice_mut_method(&self, _scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
212 SliceExpand {
213 origin: self.origin.clone(),
214 io: std::marker::PhantomData,
215 offset: self.offset.clone(),
216 length: self.length.clone(),
217 line_size: self.line_size,
218 }
219 }
220}
221
/// Operations that produce a read-only [`Slice`] view over a container.
///
/// The method bodies here are placeholders (`unexpanded!()`). The behavior
/// lives in the matching `__expand_*_method` functions of
/// `SliceOperatorExpand`, implemented in this file for shared memory, tensors,
/// arrays and slices. Presumably `#[cube]` routes calls to those — confirm
/// against the macro.
#[cube(self_type = "ref")]
pub trait SliceOperator<E: CubePrimitive> {
    /// Returns a read-only view of the elements in `start..end` (`end` is
    /// exclusive: the expanded length is `end - start`).
    // Parameters are unused because the body is only a placeholder.
    #[allow(unused_variables)]
    fn slice(&self, start: u32, end: u32) -> Slice<E, ReadOnly> {
        unexpanded!()
    }

    /// Returns a read-only view of the whole container.
    // Parameters are unused because the body is only a placeholder.
    #[allow(unused_variables)]
    fn to_slice(&self) -> Slice<E, ReadOnly> {
        unexpanded!()
    }
}
238
/// Operations that produce a read-write [`Slice`] view over a container.
///
/// The method bodies here are placeholders (`unexpanded!()`). The behavior
/// lives in the matching `__expand_*_method` functions of
/// `SliceMutOperatorExpand`, implemented in this file for shared memory,
/// tensors, arrays and read-write slices. Presumably `#[cube]` routes calls to
/// those — confirm against the macro.
#[cube(self_type = "ref")]
pub trait SliceMutOperator<E: CubePrimitive> {
    /// Returns a read-write view of the elements in `start..end` (`end` is
    /// exclusive: the expanded length is `end - start`).
    // Parameters are unused because the body is only a placeholder.
    #[allow(unused_variables)]
    fn slice_mut(&mut self, start: u32, end: u32) -> Slice<E, ReadWrite> {
        unexpanded!()
    }

    /// Returns a read-write view of the whole container.
    // Parameters are unused because the body is only a placeholder.
    #[allow(unused_variables)]
    fn to_slice_mut(&mut self) -> Slice<E, ReadWrite> {
        unexpanded!()
    }
}
255
256impl<'a, T: CubePrimitive, L: SliceOperator<T>> SliceOperator<T> for &'a L where
258 &'a L: CubeType<ExpandType = L::ExpandType>
259{
260}
261
262impl<'a, T: CubePrimitive, L: SliceOperator<T>> SliceOperator<T> for &'a mut L where
264 &'a mut L: CubeType<ExpandType = L::ExpandType>
265{
266}
267
268impl<'a, T: CubePrimitive, L: SliceMutOperator<T>> SliceMutOperator<T> for &'a L where
270 &'a L: CubeType<ExpandType = L::ExpandType>
271{
272}
273
274impl<'a, T: CubePrimitive, L: SliceMutOperator<T>> SliceMutOperator<T> for &'a mut L where
276 &'a mut L: CubeType<ExpandType = L::ExpandType>
277{
278}