cubecl_core/frontend/operation/
assignation.rs

1use half::{bf16, f16};
2
3use crate::{
4    frontend::{Array, CubeContext, ExpandElement, SharedMemory, Tensor},
5    prelude::{CubeIndex, CubeIndexMut, CubeType},
6};
7use crate::{ir, prelude::Index};
8
9pub mod cast {
10    use ir::Instruction;
11
12    use crate::prelude::ExpandElementTyped;
13
14    use self::ir::{Operator, UnaryOperator};
15
16    use super::*;
17
18    pub fn expand<C: CubeType>(
19        context: &mut CubeContext,
20        input: ExpandElementTyped<C>,
21        output: ExpandElementTyped<C>,
22    ) {
23        context.register(Instruction::new(
24            Operator::Cast(UnaryOperator {
25                input: *input.expand,
26            }),
27            *output.expand,
28        ));
29    }
30}
31
32pub mod assign {
33    use ir::{Instruction, Operation};
34
35    use crate::prelude::ExpandElementTyped;
36
37    use super::*;
38
39    pub fn expand<C: CubeType>(
40        context: &mut CubeContext,
41        input: ExpandElementTyped<C>,
42        output: ExpandElementTyped<C>,
43    ) {
44        context.register(Instruction::new(
45            Operation::Copy(*input.expand),
46            *output.expand,
47        ));
48    }
49}
50
51pub mod index_assign {
52    use ir::{Instruction, UIntKind, VariableKind};
53
54    use crate::{
55        flex32,
56        frontend::CubeType,
57        prelude::{ExpandElementTyped, SliceMut},
58        tf32,
59    };
60
61    use self::ir::{BinaryOperator, Operator, Variable};
62
63    use super::*;
64
65    pub fn expand<A: CubeType + CubeIndex<u32>>(
66        context: &mut CubeContext,
67        array: ExpandElementTyped<A>,
68        index: ExpandElementTyped<u32>,
69        value: ExpandElementTyped<A::Output>,
70    ) where
71        A::Output: CubeType + Sized,
72    {
73        let index: Variable = index.expand.into();
74        let index = match index.kind {
75            VariableKind::ConstantScalar(value) => {
76                Variable::constant(ir::ConstantScalarValue::UInt(value.as_u64(), UIntKind::U32))
77            }
78            _ => index,
79        };
80        context.register(Instruction::new(
81            Operator::IndexAssign(BinaryOperator {
82                lhs: index,
83                rhs: value.expand.into(),
84            }),
85            array.expand.into(),
86        ));
87    }
88
89    macro_rules! impl_index {
90        ($type:ident) => {
91            impl<E: CubeType, I: Index> CubeIndexMut<I> for $type<E> {}
92        };
93    }
94    macro_rules! impl_index_vec {
95        ($($type:ident),*) => {
96            $(
97                impl<I: Index> CubeIndexMut<I> for $type {}
98            )*
99        };
100    }
101
102    impl_index!(Array);
103    impl_index!(Tensor);
104    impl_index!(SharedMemory);
105    impl_index_vec!(i64, i32, i16, i8, f16, bf16, flex32, tf32, f32, f64, u64, u32, u16, u8);
106
107    impl<E: CubeType, I: Index> CubeIndexMut<I> for SliceMut<E> {}
108}
109
110pub mod index {
111    use ir::{UIntKind, VariableKind};
112
113    use crate::{
114        flex32,
115        frontend::{
116            operation::base::{binary_expand, binary_expand_no_vec},
117            CubeType,
118        },
119        prelude::{ExpandElementTyped, Slice, SliceMut},
120        tf32,
121    };
122
123    use self::ir::{Operator, Variable};
124
125    use super::*;
126
127    pub fn expand<A: CubeType + CubeIndex<ExpandElementTyped<u32>>>(
128        context: &mut CubeContext,
129        array: ExpandElementTyped<A>,
130        index: ExpandElementTyped<u32>,
131    ) -> ExpandElementTyped<A::Output>
132    where
133        A::Output: CubeType + Sized,
134    {
135        let index: ExpandElement = index.into();
136        let index_var: Variable = *index;
137        let index = match index_var.kind {
138            VariableKind::ConstantScalar(value) => ExpandElement::Plain(Variable::constant(
139                ir::ConstantScalarValue::UInt(value.as_u64(), UIntKind::U32),
140            )),
141            _ => index,
142        };
143        let array: ExpandElement = array.into();
144        let var: Variable = *array;
145        let var = match var.kind {
146            VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
147                binary_expand_no_vec(context, array, index, Operator::Index)
148            }
149            _ => binary_expand(context, array, index, Operator::Index),
150        };
151
152        ExpandElementTyped::new(var)
153    }
154
155    macro_rules! impl_index {
156        ($type:ident) => {
157            impl<E: CubeType, I: Index> CubeIndex<I> for $type<E> {
158                type Output = E;
159            }
160        };
161    }
162    macro_rules! impl_index_vec {
163        ($($type:ident),*) => {
164            $(
165                impl<I: Index> CubeIndex<I> for $type {
166                    type Output = Self;
167                }
168            )*
169        };
170    }
171
172    impl_index!(Array);
173    impl_index!(Tensor);
174    impl_index!(SharedMemory);
175    impl_index_vec!(i64, i32, i16, i8, f16, flex32, tf32, bf16, f32, f64, u64, u32, u16, u8);
176
177    impl<E: CubeType, I: Index> CubeIndex<I> for Slice<E> {
178        type Output = E;
179    }
180
181    impl<E: CubeType, I: Index> CubeIndex<I> for SliceMut<E> {
182        type Output = E;
183    }
184}
185
186pub mod add_assign_array_op {
187    use self::ir::Operator;
188    use super::*;
189    use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
190
191    pub fn expand<A: CubeType + CubeIndex<u32>>(
192        context: &mut CubeContext,
193        array: ExpandElementTyped<A>,
194        index: ExpandElementTyped<u32>,
195        value: ExpandElementTyped<A::Output>,
196    ) where
197        A::Output: CubeType + Sized,
198    {
199        array_assign_binary_op_expand(context, array, index, value, Operator::Add);
200    }
201}
202
203pub mod sub_assign_array_op {
204    use self::ir::Operator;
205    use super::*;
206    use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
207
208    pub fn expand<A: CubeType + CubeIndex<u32>>(
209        context: &mut CubeContext,
210        array: ExpandElementTyped<A>,
211        index: ExpandElementTyped<u32>,
212        value: ExpandElementTyped<A::Output>,
213    ) where
214        A::Output: CubeType + Sized,
215    {
216        array_assign_binary_op_expand(context, array, index, value, Operator::Sub);
217    }
218}
219
220pub mod mul_assign_array_op {
221    use self::ir::Operator;
222    use super::*;
223    use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
224
225    pub fn expand<A: CubeType + CubeIndex<u32>>(
226        context: &mut CubeContext,
227        array: ExpandElementTyped<A>,
228        index: ExpandElementTyped<u32>,
229        value: ExpandElementTyped<A::Output>,
230    ) where
231        A::Output: CubeType + Sized,
232    {
233        array_assign_binary_op_expand(context, array, index, value, Operator::Mul);
234    }
235}
236
237pub mod div_assign_array_op {
238    use self::ir::Operator;
239    use super::*;
240    use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
241
242    pub fn expand<A: CubeType + CubeIndex<u32>>(
243        context: &mut CubeContext,
244        array: ExpandElementTyped<A>,
245        index: ExpandElementTyped<u32>,
246        value: ExpandElementTyped<A::Output>,
247    ) where
248        A::Output: CubeType + Sized,
249    {
250        array_assign_binary_op_expand(context, array, index, value, Operator::Div);
251    }
252}
253
254pub mod rem_assign_array_op {
255    use self::ir::Operator;
256    use super::*;
257    use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
258
259    pub fn expand<A: CubeType + CubeIndex<u32>>(
260        context: &mut CubeContext,
261        array: ExpandElementTyped<A>,
262        index: ExpandElementTyped<u32>,
263        value: ExpandElementTyped<A::Output>,
264    ) where
265        A::Output: CubeType + Sized,
266    {
267        array_assign_binary_op_expand(context, array, index, value, Operator::Modulo);
268    }
269}
270
271pub mod bitor_assign_array_op {
272    use self::ir::Operator;
273    use super::*;
274    use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
275
276    pub fn expand<A: CubeType + CubeIndex<u32>>(
277        context: &mut CubeContext,
278        array: ExpandElementTyped<A>,
279        index: ExpandElementTyped<u32>,
280        value: ExpandElementTyped<A::Output>,
281    ) where
282        A::Output: CubeType + Sized,
283    {
284        array_assign_binary_op_expand(context, array, index, value, Operator::BitwiseOr);
285    }
286}
287
288pub mod bitand_assign_array_op {
289    use self::ir::Operator;
290    use super::*;
291    use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
292
293    pub fn expand<A: CubeType + CubeIndex<u32>>(
294        context: &mut CubeContext,
295        array: ExpandElementTyped<A>,
296        index: ExpandElementTyped<u32>,
297        value: ExpandElementTyped<A::Output>,
298    ) where
299        A::Output: CubeType + Sized,
300    {
301        array_assign_binary_op_expand(context, array, index, value, Operator::BitwiseAnd);
302    }
303}
304
305pub mod bitxor_assign_array_op {
306    use self::ir::Operator;
307    use super::*;
308    use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
309
310    pub fn expand<A: CubeType + CubeIndex<u32>>(
311        context: &mut CubeContext,
312        array: ExpandElementTyped<A>,
313        index: ExpandElementTyped<u32>,
314        value: ExpandElementTyped<A::Output>,
315    ) where
316        A::Output: CubeType + Sized,
317    {
318        array_assign_binary_op_expand(context, array, index, value, Operator::BitwiseXor);
319    }
320}
321
322pub mod shl_assign_array_op {
323    use self::ir::Operator;
324    use super::*;
325    use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
326
327    pub fn expand<A: CubeType + CubeIndex<u32>>(
328        context: &mut CubeContext,
329        array: ExpandElementTyped<A>,
330        index: ExpandElementTyped<u32>,
331        value: ExpandElementTyped<u32>,
332    ) where
333        A::Output: CubeType + Sized,
334    {
335        array_assign_binary_op_expand(context, array, index, value, Operator::ShiftLeft);
336    }
337}
338
339pub mod shr_assign_array_op {
340    use self::ir::Operator;
341    use super::*;
342    use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
343
344    pub fn expand<A: CubeType + CubeIndex<u32>>(
345        context: &mut CubeContext,
346        array: ExpandElementTyped<A>,
347        index: ExpandElementTyped<u32>,
348        value: ExpandElementTyped<u32>,
349    ) where
350        A::Output: CubeType + Sized,
351    {
352        array_assign_binary_op_expand(context, array, index, value, Operator::ShiftRight);
353    }
354}
355
356pub mod add_assign_op {
357    use std::ops::AddAssign;
358
359    use self::ir::Operator;
360    use crate::{
361        frontend::operation::base::assign_op_expand,
362        prelude::{CubeType, ExpandElementTyped},
363    };
364
365    use super::*;
366
367    pub fn expand<C: CubeType + AddAssign>(
368        context: &mut CubeContext,
369        lhs: ExpandElementTyped<C>,
370        rhs: ExpandElementTyped<C>,
371    ) -> ExpandElementTyped<C> {
372        assign_op_expand(context, lhs.into(), rhs.into(), Operator::Add).into()
373    }
374}
375
376pub mod sub_assign_op {
377    use self::ir::Operator;
378    use super::*;
379    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
380
381    pub fn expand<C: CubeType>(
382        context: &mut CubeContext,
383        lhs: ExpandElementTyped<C>,
384        rhs: ExpandElementTyped<C>,
385    ) -> ExpandElement {
386        assign_op_expand(context, lhs.into(), rhs.into(), Operator::Sub)
387    }
388}
389
390pub mod mul_assign_op {
391    use self::ir::Operator;
392    use super::*;
393    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
394
395    pub fn expand<C: CubeType>(
396        context: &mut CubeContext,
397        lhs: ExpandElementTyped<C>,
398        rhs: ExpandElementTyped<C>,
399    ) -> ExpandElement {
400        assign_op_expand(context, lhs.into(), rhs.into(), Operator::Mul)
401    }
402}
403
404pub mod div_assign_op {
405    use self::ir::Operator;
406    use super::*;
407    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
408
409    pub fn expand<C: CubeType>(
410        context: &mut CubeContext,
411        lhs: ExpandElementTyped<C>,
412        rhs: ExpandElementTyped<C>,
413    ) -> ExpandElement {
414        assign_op_expand(context, lhs.into(), rhs.into(), Operator::Div)
415    }
416}
417
418pub mod rem_assign_op {
419    use self::ir::Operator;
420    use super::*;
421    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
422
423    pub fn expand<C: CubeType>(
424        context: &mut CubeContext,
425        lhs: ExpandElementTyped<C>,
426        rhs: ExpandElementTyped<C>,
427    ) -> ExpandElement {
428        assign_op_expand(context, lhs.into(), rhs.into(), Operator::Modulo)
429    }
430}
431
432pub mod bitor_assign_op {
433    use self::ir::Operator;
434    use super::*;
435    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
436
437    pub fn expand<C: CubeType>(
438        context: &mut CubeContext,
439        lhs: ExpandElementTyped<C>,
440        rhs: ExpandElementTyped<C>,
441    ) -> ExpandElement {
442        assign_op_expand(context, lhs.into(), rhs.into(), Operator::BitwiseOr)
443    }
444}
445
446pub mod bitand_assign_op {
447    use self::ir::Operator;
448    use super::*;
449    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
450
451    pub fn expand<C: CubeType>(
452        context: &mut CubeContext,
453        lhs: ExpandElementTyped<C>,
454        rhs: ExpandElementTyped<C>,
455    ) -> ExpandElement {
456        assign_op_expand(context, lhs.into(), rhs.into(), Operator::BitwiseAnd)
457    }
458}
459
460pub mod bitxor_assign_op {
461    use self::ir::Operator;
462    use super::*;
463    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
464
465    pub fn expand<C: CubeType>(
466        context: &mut CubeContext,
467        lhs: ExpandElementTyped<C>,
468        rhs: ExpandElementTyped<C>,
469    ) -> ExpandElement {
470        assign_op_expand(context, lhs.into(), rhs.into(), Operator::BitwiseXor)
471    }
472}
473
474pub mod shl_assign_op {
475    use self::ir::Operator;
476    use super::*;
477    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
478
479    pub fn expand<C: CubeType>(
480        context: &mut CubeContext,
481        lhs: ExpandElementTyped<C>,
482        rhs: ExpandElementTyped<u32>,
483    ) -> ExpandElement {
484        assign_op_expand(context, lhs.into(), rhs.into(), Operator::ShiftLeft)
485    }
486}
487
488pub mod shr_assign_op {
489    use self::ir::Operator;
490    use super::*;
491    use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
492
493    pub fn expand<C: CubeType>(
494        context: &mut CubeContext,
495        lhs: ExpandElementTyped<C>,
496        rhs: ExpandElementTyped<u32>,
497    ) -> ExpandElement {
498        assign_op_expand(context, lhs.into(), rhs.into(), Operator::ShiftRight)
499    }
500}