cubecl_core/frontend/container/slice/
operator.rs

1use super::{ReadOnly, ReadWrite, Slice, SliceExpand, SliceOriginExpand, SliceVisibility};
2use crate as cubecl;
3use crate::{ir::Scope, prelude::*, unexpanded};
4use cubecl_common::tf32;
5use cubecl_ir::ExpandElement;
6
7pub(crate) fn is_tf32<C: CubePrimitive, T: CubePrimitive>(scope: &mut Scope) -> bool {
8    let ty_c = C::as_type(scope);
9    let ty_t = T::as_type(scope);
10    let ty_f32 = f32::as_type(scope);
11    let ty_tf32 = tf32::as_type(scope);
12
13    (ty_c == ty_f32 && ty_t == ty_tf32) || (ty_c == ty_tf32 && ty_t == ty_f32)
14}
15
16impl<E: CubePrimitive> SliceOperator<E> for SharedMemory<E> {}
17impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<SharedMemory<E>> {
18    fn __expand_slice_method(
19        &self,
20        scope: &mut Scope,
21        start: ExpandElementTyped<u32>,
22        end: ExpandElementTyped<u32>,
23    ) -> SliceExpand<E, ReadOnly> {
24        Slice::__expand_new(
25            scope,
26            SliceOriginExpand::SharedMemory(self.clone()),
27            start,
28            end,
29        )
30    }
31
32    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
33        let len = expand_length_native(scope, *self.expand);
34
35        Slice::__expand_new(
36            scope,
37            SliceOriginExpand::SharedMemory(self.clone()),
38            0u32.into(),
39            ExpandElement::Plain(len).into(),
40        )
41    }
42}
43
44impl<E: CubePrimitive> SliceMutOperator<E> for SharedMemory<E> {}
45impl<E: CubePrimitive> SliceMutOperatorExpand<E> for ExpandElementTyped<SharedMemory<E>> {
46    fn __expand_slice_mut_method(
47        &self,
48        scope: &mut Scope,
49        start: ExpandElementTyped<u32>,
50        end: ExpandElementTyped<u32>,
51    ) -> SliceExpand<E, ReadWrite> {
52        Slice::__expand_new(
53            scope,
54            SliceOriginExpand::SharedMemory(self.clone()),
55            start,
56            end,
57        )
58    }
59
60    fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
61        let len = expand_length_native(scope, *self.expand);
62
63        Slice::__expand_new(
64            scope,
65            SliceOriginExpand::SharedMemory(self.clone()),
66            0u32.into(),
67            ExpandElement::Plain(len).into(),
68        )
69    }
70}
71
72impl<E: CubePrimitive> SliceOperator<E> for Tensor<E> {}
73impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<Tensor<E>> {
74    fn __expand_slice_method(
75        &self,
76        scope: &mut Scope,
77        start: ExpandElementTyped<u32>,
78        end: ExpandElementTyped<u32>,
79    ) -> SliceExpand<E, ReadOnly> {
80        Slice::__expand_new(scope, SliceOriginExpand::Tensor(self.clone()), start, end)
81    }
82
83    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
84        let len = self.clone().__expand_len_method(scope);
85        Slice::__expand_new(
86            scope,
87            SliceOriginExpand::Tensor(self.clone()),
88            0u32.into(),
89            len,
90        )
91    }
92}
93
94impl<E: CubePrimitive> SliceMutOperator<E> for Tensor<E> {}
95impl<E: CubePrimitive> SliceMutOperatorExpand<E> for ExpandElementTyped<Tensor<E>> {
96    fn __expand_slice_mut_method(
97        &self,
98        scope: &mut Scope,
99        start: ExpandElementTyped<u32>,
100        end: ExpandElementTyped<u32>,
101    ) -> SliceExpand<E, ReadWrite> {
102        Slice::__expand_new(scope, SliceOriginExpand::Tensor(self.clone()), start, end)
103    }
104
105    fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
106        let len = self.clone().__expand_len_method(scope);
107        Slice::__expand_new(
108            scope,
109            SliceOriginExpand::Tensor(self.clone()),
110            0u32.into(),
111            len,
112        )
113    }
114}
115
116impl<E: CubePrimitive> SliceOperator<E> for Array<E> {}
117impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<Array<E>> {
118    fn __expand_slice_method(
119        &self,
120        scope: &mut Scope,
121        start: ExpandElementTyped<u32>,
122        end: ExpandElementTyped<u32>,
123    ) -> SliceExpand<E, ReadOnly> {
124        Slice::__expand_new(scope, SliceOriginExpand::Array(self.clone()), start, end)
125    }
126
127    fn __expand_to_slice_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
128        let len = self.clone().__expand_len_method(scope);
129        Slice::__expand_new(
130            scope,
131            SliceOriginExpand::Array(self.clone()),
132            0u32.into(),
133            len,
134        )
135    }
136}
137
138impl<E: CubePrimitive> SliceMutOperator<E> for Array<E> {}
139impl<E: CubePrimitive> SliceMutOperatorExpand<E> for ExpandElementTyped<Array<E>> {
140    fn __expand_slice_mut_method(
141        &self,
142        scope: &mut Scope,
143        start: ExpandElementTyped<u32>,
144        end: ExpandElementTyped<u32>,
145    ) -> SliceExpand<E, ReadWrite> {
146        Slice::__expand_new(scope, SliceOriginExpand::Array(self.clone()), start, end)
147    }
148
149    fn __expand_to_slice_mut_method(&self, scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
150        let len = self.clone().__expand_len_method(scope);
151        Slice::__expand_new(
152            scope,
153            SliceOriginExpand::Array(self.clone()),
154            0u32.into(),
155            len,
156        )
157    }
158}
159
160impl<E: CubePrimitive, IO: SliceVisibility> SliceOperator<E> for Slice<E, IO> {}
161impl<E: CubePrimitive, IO: SliceVisibility> SliceOperatorExpand<E> for SliceExpand<E, IO> {
162    fn __expand_slice_method(
163        &self,
164        scope: &mut Scope,
165        start: ExpandElementTyped<u32>,
166        end: ExpandElementTyped<u32>,
167    ) -> SliceExpand<E, ReadOnly> {
168        let length = crate::frontend::sub::expand(scope, end, start.clone());
169        let offset = crate::frontend::add::expand(scope, start, self.offset.clone());
170
171        SliceExpand {
172            origin: self.origin.clone(),
173            io: std::marker::PhantomData,
174            offset,
175            length,
176            line_size: None,
177        }
178    }
179
180    fn __expand_to_slice_method(&self, _scope: &mut Scope) -> SliceExpand<E, ReadOnly> {
181        SliceExpand {
182            origin: self.origin.clone(),
183            io: std::marker::PhantomData,
184            offset: self.offset.clone(),
185            length: self.length.clone(),
186            line_size: self.line_size,
187        }
188    }
189}
190
191impl<E: CubePrimitive> SliceMutOperator<E> for Slice<E, ReadWrite> {}
192impl<E: CubePrimitive> SliceMutOperatorExpand<E> for SliceExpand<E, ReadWrite> {
193    fn __expand_slice_mut_method(
194        &self,
195        scope: &mut Scope,
196        start: ExpandElementTyped<u32>,
197        end: ExpandElementTyped<u32>,
198    ) -> SliceExpand<E, ReadWrite> {
199        let length = crate::frontend::sub::expand(scope, end, start.clone());
200        let offset = crate::frontend::add::expand(scope, start, self.offset.clone());
201
202        SliceExpand {
203            origin: self.origin.clone(),
204            io: std::marker::PhantomData,
205            offset,
206            length,
207            line_size: None,
208        }
209    }
210
211    fn __expand_to_slice_mut_method(&self, _scope: &mut Scope) -> SliceExpand<E, ReadWrite> {
212        SliceExpand {
213            origin: self.origin.clone(),
214            io: std::marker::PhantomData,
215            offset: self.offset.clone(),
216            length: self.length.clone(),
217            line_size: self.line_size,
218        }
219    }
220}
221
/// Read-only slicing operations.
///
/// The default bodies are `unexpanded!()` placeholders. NOTE(review): the
/// expansion-side counterpart (`SliceOperatorExpand`) is presumably generated
/// from this trait by `#[cube]` — confirm against the macro.
#[cube(self_type = "ref")]
pub trait SliceOperator<E: CubePrimitive> {
    /// Returns a read-only view of all elements between the `start` and `end` indices.
    /// In `checked` mode, if the `end` index is out of bounds, it is replaced by
    /// the length of `self`.
    #[allow(unused_variables)]
    fn slice(&self, start: u32, end: u32) -> Slice<E, ReadOnly> {
        unexpanded!()
    }

    /// Reinterprets the current type as a read-only slice.
    #[allow(unused_variables)]
    fn to_slice(&self) -> Slice<E, ReadOnly> {
        unexpanded!()
    }
}
238
/// Read-write slicing operations.
///
/// The default bodies are `unexpanded!()` placeholders. NOTE(review): the
/// expansion-side counterpart (`SliceMutOperatorExpand`) is presumably generated
/// from this trait by `#[cube]` — confirm against the macro.
#[cube(self_type = "ref")]
pub trait SliceMutOperator<E: CubePrimitive> {
    /// Returns a read-write view of all elements between the `start` and `end` indices.
    /// In `checked` mode, if the `end` index is out of bounds, it is replaced by
    /// the length of `self`.
    #[allow(unused_variables)]
    fn slice_mut(&mut self, start: u32, end: u32) -> Slice<E, ReadWrite> {
        unexpanded!()
    }

    /// Reinterprets the current type as a read-write slice.
    #[allow(unused_variables)]
    fn to_slice_mut(&mut self) -> Slice<E, ReadWrite> {
        unexpanded!()
    }
}
255
256// Automatic implementation for references to SliceOperator.
257impl<'a, T: CubePrimitive, L: SliceOperator<T>> SliceOperator<T> for &'a L where
258    &'a L: CubeType<ExpandType = L::ExpandType>
259{
260}
261
262// Automatic implementation for mutable references to SliceOperator.
263impl<'a, T: CubePrimitive, L: SliceOperator<T>> SliceOperator<T> for &'a mut L where
264    &'a mut L: CubeType<ExpandType = L::ExpandType>
265{
266}
267
268// Automatic implementation for references to SliceMutOperator.
269impl<'a, T: CubePrimitive, L: SliceMutOperator<T>> SliceMutOperator<T> for &'a L where
270    &'a L: CubeType<ExpandType = L::ExpandType>
271{
272}
273
274// Automatic implementation for mutable references to SliceMutOperator.
275impl<'a, T: CubePrimitive, L: SliceMutOperator<T>> SliceMutOperator<T> for &'a mut L where
276    &'a mut L: CubeType<ExpandType = L::ExpandType>
277{
278}