cubecl_core/frontend/container/slice/operator.rs

1use super::{ReadOnly, ReadWrite, Slice, SliceExpand, SliceOriginExpand, SliceVisibility};
2use crate::{
3    frontend::{Array, CubePrimitive, CubeType, ExpandElementTyped, IntoMut, indexation::Index},
4    ir::Scope,
5    prelude::{CubeDebug, SharedMemory, Tensor, expand_length_native},
6    unexpanded,
7};
8use cubecl_common::tf32;
9use cubecl_ir::ExpandElement;
10
11pub(crate) fn is_tf32<C: CubePrimitive, T: CubePrimitive>(scope: &mut Scope) -> bool {
12    let ty_c = C::as_elem(scope);
13    let ty_t = T::as_elem(scope);
14    let ty_f32 = f32::as_elem(scope);
15    let ty_tf32 = tf32::as_elem(scope);
16
17    (ty_c == ty_f32 && ty_t == ty_tf32) || (ty_c == ty_tf32 && ty_t == ty_f32)
18}
19
20impl<E: CubePrimitive> SliceOperator<E> for SharedMemory<E> {}
21impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<SharedMemory<E>> {
22    fn __expand_slice_method(
23        &self,
24        scope: &mut Scope,
25        start: ExpandElementTyped<u32>,
26        end: ExpandElementTyped<u32>,
27    ) -> SliceExpand<E, ReadOnly> {
28        Slice::__expand_new(
29            scope,
30            SliceOriginExpand::SharedMemory(self.clone()),
31            start,
32            end,
33        )
34    }
35
36    fn __expand_slice_mut_method(
37        &self,
38        scope: &mut Scope,
39        start: ExpandElementTyped<u32>,
40        end: ExpandElementTyped<u32>,
41    ) -> SliceExpand<E, ReadWrite> {
42        Slice::__expand_new(
43            scope,
44            SliceOriginExpand::SharedMemory(self.clone()),
45            start,
46            end,
47        )
48    }
49
50    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
51        let len = expand_length_native(scope, *self.expand);
52
53        Slice::__expand_new(
54            scope,
55            SliceOriginExpand::SharedMemory(self.clone()),
56            0u32.into(),
57            ExpandElement::Plain(len).into(),
58        )
59    }
60
61    fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
62        let len = expand_length_native(scope, *self.expand);
63
64        Slice::__expand_new(
65            scope,
66            SliceOriginExpand::SharedMemory(self.clone()),
67            0u32.into(),
68            ExpandElement::Plain(len).into(),
69        )
70    }
71}
72
73impl<E: CubePrimitive> SliceOperator<E> for Tensor<E> {}
74impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<Tensor<E>> {
75    fn __expand_slice_method(
76        &self,
77        scope: &mut Scope,
78        start: ExpandElementTyped<u32>,
79        end: ExpandElementTyped<u32>,
80    ) -> SliceExpand<E, ReadOnly> {
81        Slice::__expand_new(scope, SliceOriginExpand::Tensor(self.clone()), start, end)
82    }
83
84    fn __expand_slice_mut_method(
85        &self,
86        scope: &mut Scope,
87        start: ExpandElementTyped<u32>,
88        end: ExpandElementTyped<u32>,
89    ) -> SliceExpand<E, ReadWrite> {
90        Slice::__expand_new(scope, SliceOriginExpand::Tensor(self.clone()), start, end)
91    }
92
93    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
94        let len = self.clone().__expand_len_method(scope);
95        Slice::__expand_new(
96            scope,
97            SliceOriginExpand::Tensor(self.clone()),
98            0u32.into(),
99            len,
100        )
101    }
102
103    fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
104        let len = self.clone().__expand_len_method(scope);
105        Slice::__expand_new(
106            scope,
107            SliceOriginExpand::Tensor(self.clone()),
108            0u32.into(),
109            len,
110        )
111    }
112}
113
114impl<E: CubePrimitive> SliceOperator<E> for Array<E> {}
115impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<Array<E>> {
116    fn __expand_slice_method(
117        &self,
118        scope: &mut Scope,
119        start: ExpandElementTyped<u32>,
120        end: ExpandElementTyped<u32>,
121    ) -> SliceExpand<E, ReadOnly> {
122        Slice::__expand_new(scope, SliceOriginExpand::Array(self.clone()), start, end)
123    }
124
125    fn __expand_slice_mut_method(
126        &self,
127        scope: &mut Scope,
128        start: ExpandElementTyped<u32>,
129        end: ExpandElementTyped<u32>,
130    ) -> SliceExpand<E, ReadWrite> {
131        Slice::__expand_new(scope, SliceOriginExpand::Array(self.clone()), start, end)
132    }
133
134    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
135        let len = self.clone().__expand_len_method(scope);
136        Slice::__expand_new(
137            scope,
138            SliceOriginExpand::Array(self.clone()),
139            0u32.into(),
140            len,
141        )
142    }
143
144    fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
145        let len = self.clone().__expand_len_method(scope);
146        Slice::__expand_new(
147            scope,
148            SliceOriginExpand::Array(self.clone()),
149            0u32.into(),
150            len,
151        )
152    }
153}
154
155impl<E: CubePrimitive, IO: SliceVisibility> SliceOperator<E> for Slice<E, IO> {}
156impl<E: CubePrimitive, IO: SliceVisibility> SliceOperatorExpand<E> for SliceExpand<E, IO> {
157    fn __expand_slice_method(
158        &self,
159        scope: &mut Scope,
160        start: ExpandElementTyped<u32>,
161        end: ExpandElementTyped<u32>,
162    ) -> SliceExpand<E, ReadOnly> {
163        let length = crate::frontend::sub::expand(scope, end, start.clone());
164        let offset = crate::frontend::add::expand(scope, start, self.offset.clone());
165
166        SliceExpand {
167            origin: self.origin.clone(),
168            io: std::marker::PhantomData,
169            offset,
170            length,
171            line_size: None,
172        }
173    }
174
175    fn __expand_slice_mut_method(
176        &self,
177        scope: &mut Scope,
178        start: ExpandElementTyped<u32>,
179        end: ExpandElementTyped<u32>,
180    ) -> SliceExpand<E, ReadWrite> {
181        let length = crate::frontend::sub::expand(scope, end, start.clone());
182        let offset = crate::frontend::add::expand(scope, start, self.offset.clone());
183
184        SliceExpand {
185            origin: self.origin.clone(),
186            io: std::marker::PhantomData,
187            offset,
188            length,
189            line_size: None,
190        }
191    }
192
193    fn __expand_to_slice_method(&self, _scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
194        SliceExpand {
195            origin: self.origin.clone(),
196            io: std::marker::PhantomData,
197            offset: self.offset.clone(),
198            length: self.length.clone(),
199            line_size: self.line_size,
200        }
201    }
202
203    fn __expand_to_slice_mut_method(&self, _scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
204        SliceExpand {
205            origin: self.origin.clone(),
206            io: std::marker::PhantomData,
207            offset: self.offset.clone(),
208            length: self.length.clone(),
209            line_size: self.line_size,
210        }
211    }
212}
213
214pub trait SliceOperator<E: CubePrimitive>: CubeType<ExpandType: SliceOperatorExpand<E>> {
215    /// Return a read-only view of all elements comprise between the `start` and `end` indices.
216    /// In `checked` mode, if the `end` index is out-of-bound, it is replaced by
217    /// the length of `self`.
218    #[allow(unused_variables)]
219    fn slice<Start: Index, End: Index>(&self, start: Start, end: End) -> Slice<E, ReadOnly> {
220        unexpanded!()
221    }
222    /// Expand function of [SliceOperator::slice].
223    fn __expand_slice(
224        scope: &mut Scope,
225        expand: Self::ExpandType,
226        start: ExpandElementTyped<u32>,
227        end: ExpandElementTyped<u32>,
228    ) -> SliceExpand<E, ReadOnly> {
229        expand.__expand_slice_method(scope, start, end)
230    }
231
232    /// Return a read-write view of all elements comprise between the `start` and `end` indices.
233    /// In `checked` mode, if the `end` index is out-of-bound, it is replaced by
234    /// the length of `self`.
235    #[allow(unused_variables)]
236    fn slice_mut<Start: Index, End: Index>(
237        &mut self,
238        start: Start,
239        end: End,
240    ) -> Slice<E, ReadWrite> {
241        unexpanded!()
242    }
243
244    /// Expand function of [SliceOperator::slice_mut].
245    fn __expand_slice_mut(
246        scope: &mut Scope,
247        expand: Self::ExpandType,
248        start: ExpandElementTyped<u32>,
249        end: ExpandElementTyped<u32>,
250    ) -> SliceExpand<E, ReadWrite> {
251        expand.__expand_slice_mut_method(scope, start, end)
252    }
253
254    /// Reinterprete the current type as a read-only slice.
255    #[allow(unused_variables)]
256    fn to_slice(&self) -> Slice<E, ReadOnly> {
257        unexpanded!()
258    }
259
260    /// Expand function of [SliceOperator::to_slice].
261    fn __expand_to_slice(scope: &mut Scope, expand: Self::ExpandType) -> SliceExpand<E, ReadOnly> {
262        expand.__expand_to_slice_method(scope)
263    }
264
265    /// Reinterprete the current type as a read-write slice.
266    #[allow(unused_variables, clippy::wrong_self_convention)]
267    fn to_slice_mut(&mut self) -> Slice<E, ReadWrite> {
268        unexpanded!()
269    }
270
271    /// Expand function of [SliceOperator::to_slice_mut].
272    fn __expand_to_slice_mut(
273        scope: &mut Scope,
274        expand: Self::ExpandType,
275    ) -> SliceExpand<E, ReadWrite> {
276        expand.__expand_to_slice_mut_method(scope)
277    }
278}
279
280pub trait SliceOperatorExpand<E: CubePrimitive>: Clone + IntoMut + CubeDebug {
281    fn __expand_slice_method(
282        &self,
283        scope: &mut Scope,
284        start: ExpandElementTyped<u32>,
285        end: ExpandElementTyped<u32>,
286    ) -> SliceExpand<E, ReadOnly>;
287
288    fn __expand_slice_mut_method(
289        &self,
290        scope: &mut Scope,
291        start: ExpandElementTyped<u32>,
292        end: ExpandElementTyped<u32>,
293    ) -> SliceExpand<E, ReadWrite>;
294
295    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly>;
296    fn __expand_to_slice_mut_method(&self, _scope: &mut Scope) -> SliceExpand<E, ReadWrite>;
297}