// cubecl_core/frontend/operation/assignation.rs

1use cubecl_ir::{Bitwise, ExpandElement, Operator, Scope};
2use half::{bf16, f16};
3
4use crate::{
5    frontend::{Array, SharedMemory, Tensor},
6    prelude::{CubeIndex, CubeIndexMut, CubeType},
7};
8use crate::{ir, prelude::Index};
9
10pub mod cast {
11    use ir::Instruction;
12
13    use crate::prelude::ExpandElementTyped;
14
15    use self::ir::UnaryOperator;
16
17    use super::*;
18
19    pub fn expand<C: CubeType>(
20        scope: &mut Scope,
21        input: ExpandElementTyped<C>,
22        output: ExpandElementTyped<C>,
23    ) {
24        scope.register(Instruction::new(
25            Operator::Cast(UnaryOperator {
26                input: *input.expand,
27            }),
28            *output.expand,
29        ));
30    }
31}
32
33pub mod assign {
34    use ir::{Instruction, Operation};
35
36    use crate::prelude::ExpandElementTyped;
37
38    use super::*;
39
40    pub fn expand<C: CubeType>(
41        scope: &mut Scope,
42        input: ExpandElementTyped<C>,
43        output: ExpandElementTyped<C>,
44    ) {
45        scope.register(Instruction::new(
46            Operation::Copy(*input.expand),
47            *output.expand,
48        ));
49    }
50}
51
52pub mod index_assign {
53    use ir::{Instruction, UIntKind, VariableKind};
54
55    use crate::{
56        flex32,
57        frontend::CubeType,
58        prelude::{ExpandElementTyped, SliceMut},
59        tf32,
60    };
61
62    use self::ir::{BinaryOperator, Variable};
63
64    use super::*;
65
66    pub fn expand<A: CubeType + CubeIndex<u32>>(
67        scope: &mut Scope,
68        array: ExpandElementTyped<A>,
69        index: ExpandElementTyped<u32>,
70        value: ExpandElementTyped<A::Output>,
71    ) where
72        A::Output: CubeType + Sized,
73    {
74        let index: Variable = index.expand.into();
75        let index = match index.kind {
76            VariableKind::ConstantScalar(value) => {
77                Variable::constant(ir::ConstantScalarValue::UInt(value.as_u64(), UIntKind::U32))
78            }
79            _ => index,
80        };
81        scope.register(Instruction::new(
82            Operator::IndexAssign(BinaryOperator {
83                lhs: index,
84                rhs: value.expand.into(),
85            }),
86            array.expand.into(),
87        ));
88    }
89
90    macro_rules! impl_index {
91        ($type:ident) => {
92            impl<E: CubeType, I: Index> CubeIndexMut<I> for $type<E> {}
93        };
94    }
95    macro_rules! impl_index_vec {
96        ($($type:ident),*) => {
97            $(
98                impl<I: Index> CubeIndexMut<I> for $type {}
99            )*
100        };
101    }
102
103    impl_index!(Array);
104    impl_index!(Tensor);
105    impl_index!(SharedMemory);
106    impl_index_vec!(
107        i64, i32, i16, i8, f16, bf16, flex32, tf32, f32, f64, u64, u32, u16, u8
108    );
109
110    impl<E: CubeType, I: Index> CubeIndexMut<I> for SliceMut<E> {}
111}
112
113pub mod index {
114    use ir::{UIntKind, VariableKind};
115
116    use crate::{
117        flex32,
118        frontend::{
119            CubeType,
120            operation::base::{binary_expand, binary_expand_no_vec},
121        },
122        prelude::{ExpandElementTyped, Slice, SliceMut},
123        tf32,
124    };
125
126    use self::ir::Variable;
127
128    use super::*;
129
130    pub fn expand<A: CubeType + CubeIndex<ExpandElementTyped<u32>>>(
131        scope: &mut Scope,
132        array: ExpandElementTyped<A>,
133        index: ExpandElementTyped<u32>,
134    ) -> ExpandElementTyped<A::Output>
135    where
136        A::Output: CubeType + Sized,
137    {
138        let index: ExpandElement = index.into();
139        let index_var: Variable = *index;
140        let index = match index_var.kind {
141            VariableKind::ConstantScalar(value) => ExpandElement::Plain(Variable::constant(
142                ir::ConstantScalarValue::UInt(value.as_u64(), UIntKind::U32),
143            )),
144            _ => index,
145        };
146        let array: ExpandElement = array.into();
147        let var: Variable = *array;
148        let var = match var.kind {
149            VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
150                binary_expand_no_vec(scope, array, index, Operator::Index)
151            }
152            _ => binary_expand(scope, array, index, Operator::Index),
153        };
154
155        ExpandElementTyped::new(var)
156    }
157
158    macro_rules! impl_index {
159        ($type:ident) => {
160            impl<E: CubeType, I: Index> CubeIndex<I> for $type<E> {
161                type Output = E;
162            }
163        };
164    }
165    macro_rules! impl_index_vec {
166        ($($type:ident),*) => {
167            $(
168                impl<I: Index> CubeIndex<I> for $type {
169                    type Output = Self;
170                }
171            )*
172        };
173    }
174
175    impl_index!(Array);
176    impl_index!(Tensor);
177    impl_index!(SharedMemory);
178    impl_index_vec!(
179        i64, i32, i16, i8, f16, flex32, tf32, bf16, f32, f64, u64, u32, u16, u8
180    );
181
182    impl<E: CubeType, I: Index> CubeIndex<I> for Slice<E> {
183        type Output = E;
184    }
185
186    impl<E: CubeType, I: Index> CubeIndex<I> for SliceMut<E> {
187        type Output = E;
188    }
189}
190
191pub mod add_assign_array_op {
192    use self::ir::Arithmetic;
193    use super::*;
194    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
195
196    pub fn expand<A: CubeType + CubeIndex<u32>>(
197        scope: &mut Scope,
198        array: ExpandElementTyped<A>,
199        index: ExpandElementTyped<u32>,
200        value: ExpandElementTyped<A::Output>,
201    ) where
202        A::Output: CubeType + Sized,
203    {
204        array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Add);
205    }
206}
207
208pub mod sub_assign_array_op {
209    use self::ir::Arithmetic;
210    use super::*;
211    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
212
213    pub fn expand<A: CubeType + CubeIndex<u32>>(
214        scope: &mut Scope,
215        array: ExpandElementTyped<A>,
216        index: ExpandElementTyped<u32>,
217        value: ExpandElementTyped<A::Output>,
218    ) where
219        A::Output: CubeType + Sized,
220    {
221        array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Sub);
222    }
223}
224
225pub mod mul_assign_array_op {
226    use self::ir::Arithmetic;
227    use super::*;
228    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
229
230    pub fn expand<A: CubeType + CubeIndex<u32>>(
231        scope: &mut Scope,
232        array: ExpandElementTyped<A>,
233        index: ExpandElementTyped<u32>,
234        value: ExpandElementTyped<A::Output>,
235    ) where
236        A::Output: CubeType + Sized,
237    {
238        array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Mul);
239    }
240}
241
242pub mod div_assign_array_op {
243    use self::ir::Arithmetic;
244    use super::*;
245    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
246
247    pub fn expand<A: CubeType + CubeIndex<u32>>(
248        scope: &mut Scope,
249        array: ExpandElementTyped<A>,
250        index: ExpandElementTyped<u32>,
251        value: ExpandElementTyped<A::Output>,
252    ) where
253        A::Output: CubeType + Sized,
254    {
255        array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Div);
256    }
257}
258
259pub mod rem_assign_array_op {
260    use self::ir::Arithmetic;
261    use super::*;
262    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
263
264    pub fn expand<A: CubeType + CubeIndex<u32>>(
265        scope: &mut Scope,
266        array: ExpandElementTyped<A>,
267        index: ExpandElementTyped<u32>,
268        value: ExpandElementTyped<A::Output>,
269    ) where
270        A::Output: CubeType + Sized,
271    {
272        array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Modulo);
273    }
274}
275
276pub mod bitor_assign_array_op {
277    use super::*;
278    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
279
280    pub fn expand<A: CubeType + CubeIndex<u32>>(
281        scope: &mut Scope,
282        array: ExpandElementTyped<A>,
283        index: ExpandElementTyped<u32>,
284        value: ExpandElementTyped<A::Output>,
285    ) where
286        A::Output: CubeType + Sized,
287    {
288        array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseOr);
289    }
290}
291
292pub mod bitand_assign_array_op {
293    use super::*;
294    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
295
296    pub fn expand<A: CubeType + CubeIndex<u32>>(
297        scope: &mut Scope,
298        array: ExpandElementTyped<A>,
299        index: ExpandElementTyped<u32>,
300        value: ExpandElementTyped<A::Output>,
301    ) where
302        A::Output: CubeType + Sized,
303    {
304        array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseAnd);
305    }
306}
307
308pub mod bitxor_assign_array_op {
309    use super::*;
310    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
311
312    pub fn expand<A: CubeType + CubeIndex<u32>>(
313        scope: &mut Scope,
314        array: ExpandElementTyped<A>,
315        index: ExpandElementTyped<u32>,
316        value: ExpandElementTyped<A::Output>,
317    ) where
318        A::Output: CubeType + Sized,
319    {
320        array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseXor);
321    }
322}
323
324pub mod shl_assign_array_op {
325
326    use super::*;
327    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
328
329    pub fn expand<A: CubeType + CubeIndex<u32>>(
330        scope: &mut Scope,
331        array: ExpandElementTyped<A>,
332        index: ExpandElementTyped<u32>,
333        value: ExpandElementTyped<u32>,
334    ) where
335        A::Output: CubeType + Sized,
336    {
337        array_assign_binary_op_expand(scope, array, index, value, Bitwise::ShiftLeft);
338    }
339}
340
341pub mod shr_assign_array_op {
342
343    use super::*;
344    use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
345
346    pub fn expand<A: CubeType + CubeIndex<u32>>(
347        scope: &mut Scope,
348        array: ExpandElementTyped<A>,
349        index: ExpandElementTyped<u32>,
350        value: ExpandElementTyped<u32>,
351    ) where
352        A::Output: CubeType + Sized,
353    {
354        array_assign_binary_op_expand(scope, array, index, value, Bitwise::ShiftRight);
355    }
356}
357
358pub mod add_assign_op {
359    use std::ops::AddAssign;
360
361    use self::ir::Arithmetic;
362    use crate::{
363        frontend::operation::base::assign_op_expand,
364        prelude::{CubeType, ExpandElementTyped},
365    };
366
367    use super::*;
368
369    pub fn expand<C: CubeType + AddAssign>(
370        scope: &mut Scope,
371        lhs: ExpandElementTyped<C>,
372        rhs: ExpandElementTyped<C>,
373    ) -> ExpandElementTyped<C> {
374        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
375    }
376}
377
378pub mod sub_assign_op {
379    use self::ir::Arithmetic;
380    use super::*;
381    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
382
383    pub fn expand<C: CubeType>(
384        scope: &mut Scope,
385        lhs: ExpandElementTyped<C>,
386        rhs: ExpandElementTyped<C>,
387    ) -> ExpandElement {
388        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub)
389    }
390}
391
392pub mod mul_assign_op {
393    use self::ir::Arithmetic;
394    use super::*;
395    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
396
397    pub fn expand<C: CubeType>(
398        scope: &mut Scope,
399        lhs: ExpandElementTyped<C>,
400        rhs: ExpandElementTyped<C>,
401    ) -> ExpandElement {
402        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul)
403    }
404}
405
406pub mod div_assign_op {
407    use self::ir::Arithmetic;
408    use super::*;
409    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
410
411    pub fn expand<C: CubeType>(
412        scope: &mut Scope,
413        lhs: ExpandElementTyped<C>,
414        rhs: ExpandElementTyped<C>,
415    ) -> ExpandElement {
416        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div)
417    }
418}
419
420pub mod rem_assign_op {
421    use self::ir::Arithmetic;
422    use super::*;
423    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
424
425    pub fn expand<C: CubeType>(
426        scope: &mut Scope,
427        lhs: ExpandElementTyped<C>,
428        rhs: ExpandElementTyped<C>,
429    ) -> ExpandElement {
430        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo)
431    }
432}
433
434pub mod bitor_assign_op {
435
436    use super::*;
437    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
438
439    pub fn expand<C: CubeType>(
440        scope: &mut Scope,
441        lhs: ExpandElementTyped<C>,
442        rhs: ExpandElementTyped<C>,
443    ) -> ExpandElement {
444        assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr)
445    }
446}
447
448pub mod bitand_assign_op {
449
450    use super::*;
451    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
452
453    pub fn expand<C: CubeType>(
454        scope: &mut Scope,
455        lhs: ExpandElementTyped<C>,
456        rhs: ExpandElementTyped<C>,
457    ) -> ExpandElement {
458        assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd)
459    }
460}
461
462pub mod bitxor_assign_op {
463
464    use super::*;
465    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
466
467    pub fn expand<C: CubeType>(
468        scope: &mut Scope,
469        lhs: ExpandElementTyped<C>,
470        rhs: ExpandElementTyped<C>,
471    ) -> ExpandElement {
472        assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor)
473    }
474}
475
476pub mod shl_assign_op {
477
478    use super::*;
479    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
480
481    pub fn expand<C: CubeType>(
482        scope: &mut Scope,
483        lhs: ExpandElementTyped<C>,
484        rhs: ExpandElementTyped<u32>,
485    ) -> ExpandElement {
486        assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft)
487    }
488}
489
490pub mod shr_assign_op {
491    use super::*;
492    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
493
494    pub fn expand<C: CubeType>(
495        scope: &mut Scope,
496        lhs: ExpandElementTyped<C>,
497        rhs: ExpandElementTyped<u32>,
498    ) -> ExpandElement {
499        assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight)
500    }
501}
502
503pub mod add_assign {
504    use cubecl_ir::Arithmetic;
505
506    use super::*;
507    use crate::prelude::{CubePrimitive, ExpandElementTyped, assign_op_expand};
508
509    pub fn expand<C: CubePrimitive>(
510        scope: &mut Scope,
511        lhs: ExpandElementTyped<C>,
512        rhs: ExpandElementTyped<C>,
513    ) -> ExpandElementTyped<C> {
514        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
515    }
516}