// cubecl_core/frontend/operation/assignation.rs

1use cubecl_ir::{Bitwise, ExpandElement, Operator, Scope};
2
3use crate::ir;
4use crate::{
5    frontend::{Array, SharedMemory, Tensor},
6    prelude::{CubeIndex, CubeIndexMut, CubePrimitive, CubeType},
7};
8
9pub mod cast {
10    use ir::Instruction;
11
12    use crate::prelude::ExpandElementTyped;
13
14    use self::ir::UnaryOperator;
15
16    use super::*;
17
18    pub fn expand<C: CubeType>(
19        scope: &mut Scope,
20        input: ExpandElementTyped<C>,
21        output: ExpandElementTyped<C>,
22    ) {
23        scope.register(Instruction::new(
24            Operator::Cast(UnaryOperator {
25                input: *input.expand,
26            }),
27            *output.expand,
28        ));
29    }
30}
31
32pub mod assign {
33    use ir::{Instruction, Operation};
34
35    use crate::prelude::ExpandElementTyped;
36
37    use super::*;
38
39    /// Expand the assign operation.
40    ///
41    /// If you want to assign to a manually initialized const variable, look into
42    /// [expand_no_check()].
43    pub fn expand<C: CubeType>(
44        scope: &mut Scope,
45        input: ExpandElementTyped<C>,
46        output: ExpandElementTyped<C>,
47    ) {
48        let output = *output.expand;
49        let input = *input.expand;
50
51        if output.is_immutable() {
52            panic!("Can't assign a value to a const variable. Try to use `RuntimeCell`.");
53        }
54
55        scope.register(Instruction::new(Operation::Copy(input), output));
56    }
57    /// Expand the assign operation without any check.
58    ///
59    /// You can't assign to a const variable with this [expand()].
60    pub fn expand_no_check<C: CubeType>(
61        scope: &mut Scope,
62        input: ExpandElementTyped<C>,
63        output: ExpandElementTyped<C>,
64    ) {
65        let output = *output.expand;
66        let input = *input.expand;
67
68        scope.register(Instruction::new(Operation::Copy(input), output));
69    }
70}
71
72pub mod index_assign {
73    use super::*;
74    use crate::prelude::{
75        CubeIndexMutExpand, ExpandElementTyped, Line, expand_index_assign_native,
76    };
77
78    pub fn expand<A: CubeIndexMutExpand<Output = ExpandElementTyped<V>>, V: CubePrimitive>(
79        scope: &mut Scope,
80        expand: A,
81        index: ExpandElementTyped<u32>,
82        value: ExpandElementTyped<V>,
83    ) {
84        expand.expand_index_mut(scope, index, value)
85    }
86
87    macro_rules! impl_index {
88        ($type:ident) => {
89            impl<E: CubePrimitive> CubeIndexMut for $type<E> {}
90
91            impl<E: CubePrimitive> CubeIndexMutExpand for ExpandElementTyped<$type<E>> {
92                type Output = ExpandElementTyped<E>;
93
94                fn expand_index_mut(
95                    self,
96                    scope: &mut Scope,
97                    index: ExpandElementTyped<u32>,
98                    value: Self::Output,
99                ) {
100                    expand_index_assign_native::<$type<E>>(scope, self, index, value, None, true);
101                }
102            }
103        };
104    }
105
106    impl_index!(Array);
107    impl_index!(Tensor);
108    impl_index!(SharedMemory);
109    impl_index!(Line);
110}
111
112pub mod index {
113    use super::*;
114    use crate::prelude::{CubeIndexExpand, ExpandElementTyped, Line, expand_index_native};
115
116    pub fn expand<A: CubeIndexExpand<Output = ExpandElementTyped<V>>, V: CubeType>(
117        scope: &mut Scope,
118        expand: A,
119        index: ExpandElementTyped<u32>,
120    ) -> ExpandElementTyped<V> {
121        expand.expand_index(scope, index)
122    }
123
124    pub fn expand_with<A: CubeIndexExpand<Output = ExpandElementTyped<V>>, V: CubeType>(
125        scope: &mut Scope,
126        expand: A,
127        index: ExpandElementTyped<u32>,
128    ) -> ExpandElementTyped<V> {
129        expand.expand_index(scope, index)
130    }
131
132    macro_rules! impl_index {
133        ($type:ident) => {
134            impl<E: CubePrimitive> CubeIndex for $type<E> {
135                type Output = E;
136            }
137
138            impl<E: CubePrimitive> CubeIndexExpand for ExpandElementTyped<$type<E>> {
139                type Output = ExpandElementTyped<E>;
140
141                fn expand_index(
142                    self,
143                    scope: &mut Scope,
144                    index: ExpandElementTyped<u32>,
145                ) -> Self::Output {
146                    expand_index_native(scope, self, index, None, true)
147                }
148                fn expand_index_unchecked(
149                    self,
150                    scope: &mut Scope,
151                    index: ExpandElementTyped<u32>,
152                ) -> Self::Output {
153                    expand_index_native(scope, self, index, None, false)
154                }
155            }
156        };
157    }
158
159    impl_index!(Array);
160    impl_index!(Tensor);
161    impl_index!(SharedMemory);
162    impl_index!(Line);
163}
164
165pub mod index_unchecked {
166    use super::*;
167    use crate::prelude::{CubeIndexExpand, ExpandElementTyped};
168
169    pub fn expand<A: CubeIndexExpand<Output = ExpandElementTyped<V>>, V: CubeType>(
170        scope: &mut Scope,
171        expand: A,
172        index: ExpandElementTyped<u32>,
173    ) -> ExpandElementTyped<V> {
174        expand.expand_index_unchecked(scope, index)
175    }
176}
177
178pub mod add_assign_array_op {
179    use self::ir::Arithmetic;
180    use super::*;
181    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
182
183    pub fn expand<A: CubeType + CubeIndex>(
184        scope: &mut Scope,
185        array: ExpandElementTyped<A>,
186        index: ExpandElementTyped<u32>,
187        value: ExpandElementTyped<A::Output>,
188    ) where
189        A::Output: CubeType + Sized,
190    {
191        array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Add);
192    }
193}
194
195pub mod sub_assign_array_op {
196    use self::ir::Arithmetic;
197    use super::*;
198    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
199
200    pub fn expand<A: CubeType + CubeIndex>(
201        scope: &mut Scope,
202        array: ExpandElementTyped<A>,
203        index: ExpandElementTyped<u32>,
204        value: ExpandElementTyped<A::Output>,
205    ) where
206        A::Output: CubeType + Sized,
207    {
208        array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Sub);
209    }
210}
211
212pub mod mul_assign_array_op {
213    use self::ir::Arithmetic;
214    use super::*;
215    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
216
217    pub fn expand<A: CubeType + CubeIndex>(
218        scope: &mut Scope,
219        array: ExpandElementTyped<A>,
220        index: ExpandElementTyped<u32>,
221        value: ExpandElementTyped<A::Output>,
222    ) where
223        A::Output: CubeType + Sized,
224    {
225        array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Mul);
226    }
227}
228
229pub mod div_assign_array_op {
230    use self::ir::Arithmetic;
231    use super::*;
232    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
233
234    pub fn expand<A: CubeType + CubeIndex>(
235        scope: &mut Scope,
236        array: ExpandElementTyped<A>,
237        index: ExpandElementTyped<u32>,
238        value: ExpandElementTyped<A::Output>,
239    ) where
240        A::Output: CubeType + Sized,
241    {
242        array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Div);
243    }
244}
245
246pub mod rem_assign_array_op {
247    use self::ir::Arithmetic;
248    use super::*;
249    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
250
251    pub fn expand<A: CubeType + CubeIndex>(
252        scope: &mut Scope,
253        array: ExpandElementTyped<A>,
254        index: ExpandElementTyped<u32>,
255        value: ExpandElementTyped<A::Output>,
256    ) where
257        A::Output: CubeType + Sized,
258    {
259        array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Modulo);
260    }
261}
262
263pub mod bitor_assign_array_op {
264    use super::*;
265    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
266
267    pub fn expand<A: CubeType + CubeIndex>(
268        scope: &mut Scope,
269        array: ExpandElementTyped<A>,
270        index: ExpandElementTyped<u32>,
271        value: ExpandElementTyped<A::Output>,
272    ) where
273        A::Output: CubeType + Sized,
274    {
275        array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseOr);
276    }
277}
278
279pub mod bitand_assign_array_op {
280    use super::*;
281    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
282
283    pub fn expand<A: CubeType + CubeIndex>(
284        scope: &mut Scope,
285        array: ExpandElementTyped<A>,
286        index: ExpandElementTyped<u32>,
287        value: ExpandElementTyped<A::Output>,
288    ) where
289        A::Output: CubeType + Sized,
290    {
291        array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseAnd);
292    }
293}
294
295pub mod bitxor_assign_array_op {
296    use super::*;
297    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
298
299    pub fn expand<A: CubeType + CubeIndex>(
300        scope: &mut Scope,
301        array: ExpandElementTyped<A>,
302        index: ExpandElementTyped<u32>,
303        value: ExpandElementTyped<A::Output>,
304    ) where
305        A::Output: CubeType + Sized,
306    {
307        array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseXor);
308    }
309}
310
311pub mod shl_assign_array_op {
312
313    use super::*;
314    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
315
316    pub fn expand<A: CubeType + CubeIndex>(
317        scope: &mut Scope,
318        array: ExpandElementTyped<A>,
319        index: ExpandElementTyped<u32>,
320        value: ExpandElementTyped<u32>,
321    ) where
322        A::Output: CubeType + Sized,
323    {
324        array_assign_binary_op_expand(scope, array, index, value, Bitwise::ShiftLeft);
325    }
326}
327
328pub mod shr_assign_array_op {
329
330    use super::*;
331    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
332
333    pub fn expand<A: CubeType + CubeIndex>(
334        scope: &mut Scope,
335        array: ExpandElementTyped<A>,
336        index: ExpandElementTyped<u32>,
337        value: ExpandElementTyped<u32>,
338    ) where
339        A::Output: CubeType + Sized,
340    {
341        array_assign_binary_op_expand(scope, array, index, value, Bitwise::ShiftRight);
342    }
343}
344
345pub mod add_assign_op {
346    use std::ops::AddAssign;
347
348    use self::ir::Arithmetic;
349    use crate::{
350        frontend::operation::base::assign_op_expand,
351        prelude::{CubeType, ExpandElementTyped},
352    };
353
354    use super::*;
355
356    pub fn expand<C: CubeType + AddAssign>(
357        scope: &mut Scope,
358        lhs: ExpandElementTyped<C>,
359        rhs: ExpandElementTyped<C>,
360    ) -> ExpandElementTyped<C> {
361        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
362    }
363}
364
365pub mod sub_assign_op {
366    use self::ir::Arithmetic;
367    use super::*;
368    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
369
370    pub fn expand<C: CubeType>(
371        scope: &mut Scope,
372        lhs: ExpandElementTyped<C>,
373        rhs: ExpandElementTyped<C>,
374    ) -> ExpandElement {
375        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub)
376    }
377}
378
379pub mod mul_assign_op {
380    use self::ir::Arithmetic;
381    use super::*;
382    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
383
384    pub fn expand<C: CubeType>(
385        scope: &mut Scope,
386        lhs: ExpandElementTyped<C>,
387        rhs: ExpandElementTyped<C>,
388    ) -> ExpandElement {
389        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul)
390    }
391}
392
393pub mod div_assign_op {
394    use self::ir::Arithmetic;
395    use super::*;
396    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
397
398    pub fn expand<C: CubeType>(
399        scope: &mut Scope,
400        lhs: ExpandElementTyped<C>,
401        rhs: ExpandElementTyped<C>,
402    ) -> ExpandElement {
403        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div)
404    }
405}
406
407pub mod rem_assign_op {
408    use self::ir::Arithmetic;
409    use super::*;
410    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
411
412    pub fn expand<C: CubeType>(
413        scope: &mut Scope,
414        lhs: ExpandElementTyped<C>,
415        rhs: ExpandElementTyped<C>,
416    ) -> ExpandElement {
417        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo)
418    }
419}
420
421pub mod bitor_assign_op {
422
423    use super::*;
424    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
425
426    pub fn expand<C: CubeType>(
427        scope: &mut Scope,
428        lhs: ExpandElementTyped<C>,
429        rhs: ExpandElementTyped<C>,
430    ) -> ExpandElement {
431        assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr)
432    }
433}
434
435pub mod bitand_assign_op {
436
437    use super::*;
438    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
439
440    pub fn expand<C: CubeType>(
441        scope: &mut Scope,
442        lhs: ExpandElementTyped<C>,
443        rhs: ExpandElementTyped<C>,
444    ) -> ExpandElement {
445        assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd)
446    }
447}
448
449pub mod bitxor_assign_op {
450
451    use super::*;
452    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
453
454    pub fn expand<C: CubeType>(
455        scope: &mut Scope,
456        lhs: ExpandElementTyped<C>,
457        rhs: ExpandElementTyped<C>,
458    ) -> ExpandElement {
459        assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor)
460    }
461}
462
463pub mod shl_assign_op {
464
465    use super::*;
466    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
467
468    pub fn expand<C: CubeType>(
469        scope: &mut Scope,
470        lhs: ExpandElementTyped<C>,
471        rhs: ExpandElementTyped<u32>,
472    ) -> ExpandElement {
473        assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft)
474    }
475}
476
477pub mod shr_assign_op {
478    use super::*;
479    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
480
481    pub fn expand<C: CubeType>(
482        scope: &mut Scope,
483        lhs: ExpandElementTyped<C>,
484        rhs: ExpandElementTyped<u32>,
485    ) -> ExpandElement {
486        assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight)
487    }
488}
489
490pub mod add_assign {
491    use cubecl_ir::Arithmetic;
492
493    use super::*;
494    use crate::prelude::{CubePrimitive, ExpandElementTyped, assign_op_expand};
495
496    pub fn expand<C: CubePrimitive>(
497        scope: &mut Scope,
498        lhs: ExpandElementTyped<C>,
499        rhs: ExpandElementTyped<C>,
500    ) -> ExpandElementTyped<C> {
501        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
502    }
503}