1use super::{ReadOnly, ReadWrite, Slice, SliceExpand, SliceOriginExpand, SliceVisibility};
2use crate::{
3 frontend::{Array, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, indexation::Index},
4 ir::Scope,
5 prelude::{CubeDebug, SharedMemory, Tensor, expand_length_native},
6 unexpanded,
7};
8use cubecl_common::tf32;
9use cubecl_ir::ExpandElement;
10
11pub(crate) fn is_tf32<C: CubePrimitive, T: CubePrimitive>(scope: &mut Scope) -> bool {
12 let ty_c = C::as_elem(scope);
13 let ty_t = T::as_elem(scope);
14 let ty_f32 = f32::as_elem(scope);
15 let ty_tf32 = tf32::as_elem(scope);
16
17 (ty_c == ty_f32 && ty_t == ty_tf32) || (ty_c == ty_tf32 && ty_t == ty_f32)
18}
19
20impl<E: CubePrimitive> SliceOperator<E> for SharedMemory<E> {}
21impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<SharedMemory<E>> {
22 fn __expand_slice_method(
23 &self,
24 scope: &mut Scope,
25 start: ExpandElementTyped<u32>,
26 end: ExpandElementTyped<u32>,
27 ) -> SliceExpand<E, ReadOnly> {
28 Slice::__expand_new(
29 scope,
30 SliceOriginExpand::SharedMemory(self.clone()),
31 start,
32 end,
33 )
34 }
35
36 fn __expand_slice_mut_method(
37 &self,
38 scope: &mut Scope,
39 start: ExpandElementTyped<u32>,
40 end: ExpandElementTyped<u32>,
41 ) -> SliceExpand<E, ReadWrite> {
42 Slice::__expand_new(
43 scope,
44 SliceOriginExpand::SharedMemory(self.clone()),
45 start,
46 end,
47 )
48 }
49
50 fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
51 let len = expand_length_native(scope, *self.expand);
52
53 Slice::__expand_new(
54 scope,
55 SliceOriginExpand::SharedMemory(self.clone()),
56 0u32.into(),
57 ExpandElement::Plain(len).into(),
58 )
59 }
60
61 fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
62 let len = expand_length_native(scope, *self.expand);
63
64 Slice::__expand_new(
65 scope,
66 SliceOriginExpand::SharedMemory(self.clone()),
67 0u32.into(),
68 ExpandElement::Plain(len).into(),
69 )
70 }
71}
72
73impl<E: CubePrimitive> SliceOperator<E> for Tensor<E> {}
74impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<Tensor<E>> {
75 fn __expand_slice_method(
76 &self,
77 scope: &mut Scope,
78 start: ExpandElementTyped<u32>,
79 end: ExpandElementTyped<u32>,
80 ) -> SliceExpand<E, ReadOnly> {
81 Slice::__expand_new(scope, SliceOriginExpand::Tensor(self.clone()), start, end)
82 }
83
84 fn __expand_slice_mut_method(
85 &self,
86 scope: &mut Scope,
87 start: ExpandElementTyped<u32>,
88 end: ExpandElementTyped<u32>,
89 ) -> SliceExpand<E, ReadWrite> {
90 Slice::__expand_new(scope, SliceOriginExpand::Tensor(self.clone()), start, end)
91 }
92
93 fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
94 let len = self.clone().__expand_len_method(scope);
95 Slice::__expand_new(
96 scope,
97 SliceOriginExpand::Tensor(self.clone()),
98 0u32.into(),
99 len,
100 )
101 }
102
103 fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
104 let len = self.clone().__expand_len_method(scope);
105 Slice::__expand_new(
106 scope,
107 SliceOriginExpand::Tensor(self.clone()),
108 0u32.into(),
109 len,
110 )
111 }
112}
113
114impl<E: CubePrimitive> SliceOperator<E> for Array<E> {}
115impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<Array<E>> {
116 fn __expand_slice_method(
117 &self,
118 scope: &mut Scope,
119 start: ExpandElementTyped<u32>,
120 end: ExpandElementTyped<u32>,
121 ) -> SliceExpand<E, ReadOnly> {
122 Slice::__expand_new(scope, SliceOriginExpand::Array(self.clone()), start, end)
123 }
124
125 fn __expand_slice_mut_method(
126 &self,
127 scope: &mut Scope,
128 start: ExpandElementTyped<u32>,
129 end: ExpandElementTyped<u32>,
130 ) -> SliceExpand<E, ReadWrite> {
131 Slice::__expand_new(scope, SliceOriginExpand::Array(self.clone()), start, end)
132 }
133
134 fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
135 let len = self.clone().__expand_len_method(scope);
136 Slice::__expand_new(
137 scope,
138 SliceOriginExpand::Array(self.clone()),
139 0u32.into(),
140 len,
141 )
142 }
143
144 fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
145 let len = self.clone().__expand_len_method(scope);
146 Slice::__expand_new(
147 scope,
148 SliceOriginExpand::Array(self.clone()),
149 0u32.into(),
150 len,
151 )
152 }
153}
154
155impl<E: CubePrimitive, IO: SliceVisibility> SliceOperator<E> for Slice<E, IO> {}
156impl<E: CubePrimitive, IO: SliceVisibility> SliceOperatorExpand<E> for SliceExpand<E, IO> {
157 fn __expand_slice_method(
158 &self,
159 scope: &mut Scope,
160 start: ExpandElementTyped<u32>,
161 end: ExpandElementTyped<u32>,
162 ) -> SliceExpand<E, ReadOnly> {
163 let length = crate::frontend::sub::expand(scope, end, start.clone());
164 let offset = crate::frontend::add::expand(scope, start, self.offset.clone());
165
166 SliceExpand {
167 origin: self.origin.clone(),
168 io: std::marker::PhantomData,
169 offset,
170 length,
171 line_size: None,
172 }
173 }
174
175 fn __expand_slice_mut_method(
176 &self,
177 scope: &mut Scope,
178 start: ExpandElementTyped<u32>,
179 end: ExpandElementTyped<u32>,
180 ) -> SliceExpand<E, ReadWrite> {
181 let length = crate::frontend::sub::expand(scope, end, start.clone());
182 let offset = crate::frontend::add::expand(scope, start, self.offset.clone());
183
184 SliceExpand {
185 origin: self.origin.clone(),
186 io: std::marker::PhantomData,
187 offset,
188 length,
189 line_size: None,
190 }
191 }
192
193 fn __expand_to_slice_method(&self, _scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
194 SliceExpand {
195 origin: self.origin.clone(),
196 io: std::marker::PhantomData,
197 offset: self.offset.clone(),
198 length: self.length.clone(),
199 line_size: self.line_size,
200 }
201 }
202
203 fn __expand_to_slice_mut_method(&self, _scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
204 SliceExpand {
205 origin: self.origin.clone(),
206 io: std::marker::PhantomData,
207 offset: self.offset.clone(),
208 length: self.length.clone(),
209 line_size: self.line_size,
210 }
211 }
212}
213
214pub trait SliceOperator<E: CubePrimitive>: CubeType<ExpandType: SliceOperatorExpand<E>> {
215 #[allow(unused_variables)]
219 fn slice<Start: Index, End: Index>(&self, start: Start, end: End) -> Slice<E, ReadOnly> {
220 unexpanded!()
221 }
222 fn __expand_slice(
224 scope: &mut Scope,
225 expand: Self::ExpandType,
226 start: ExpandElementTyped<u32>,
227 end: ExpandElementTyped<u32>,
228 ) -> SliceExpand<E, ReadOnly> {
229 expand.__expand_slice_method(scope, start, end)
230 }
231
232 #[allow(unused_variables)]
236 fn slice_mut<Start: Index, End: Index>(
237 &mut self,
238 start: Start,
239 end: End,
240 ) -> Slice<E, ReadWrite> {
241 unexpanded!()
242 }
243
244 fn __expand_slice_mut(
246 scope: &mut Scope,
247 expand: Self::ExpandType,
248 start: ExpandElementTyped<u32>,
249 end: ExpandElementTyped<u32>,
250 ) -> SliceExpand<E, ReadWrite> {
251 expand.__expand_slice_mut_method(scope, start, end)
252 }
253
254 #[allow(unused_variables)]
256 fn to_slice(&self) -> Slice<E, ReadOnly> {
257 unexpanded!()
258 }
259
260 fn __expand_to_slice(scope: &mut Scope, expand: Self::ExpandType) -> SliceExpand<E, ReadOnly> {
262 expand.__expand_to_slice_method(scope)
263 }
264
265 #[allow(unused_variables, clippy::wrong_self_convention)]
267 fn to_slice_mut(&mut self) -> Slice<E, ReadWrite> {
268 unexpanded!()
269 }
270
271 fn __expand_to_slice_mut(
273 scope: &mut Scope,
274 expand: Self::ExpandType,
275 ) -> SliceExpand<E, ReadWrite> {
276 expand.__expand_to_slice_mut_method(scope)
277 }
278}
279
280pub trait SliceOperatorExpand<E: CubePrimitive>: Clone + IntoMut + CubeDebug {
281 fn __expand_slice_method(
282 &self,
283 scope: &mut Scope,
284 start: ExpandElementTyped<u32>,
285 end: ExpandElementTyped<u32>,
286 ) -> SliceExpand<E, ReadOnly>;
287
288 fn __expand_slice_mut_method(
289 &self,
290 scope: &mut Scope,
291 start: ExpandElementTyped<u32>,
292 end: ExpandElementTyped<u32>,
293 ) -> SliceExpand<E, ReadWrite>;
294
295 fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly>;
296 fn __expand_to_slice_mut_method(&self, _scope: &mut Scope) -> SliceExpand<E, ReadWrite>;
297}