Skip to main content

cubecl_core/frontend/operation/
assignation.rs

1use cubecl_ir::{Bitwise, ManagedVariable, Operator, Scope};
2
3use crate::ir;
4use crate::{
5    frontend::{Array, SharedMemory, Tensor},
6    prelude::*,
7};
8
9pub mod cast {
10    use ir::Instruction;
11
12    use crate::prelude::NativeExpand;
13
14    use self::ir::UnaryOperator;
15
16    use super::*;
17
18    pub fn expand<From: CubeType, To: CubeType>(
19        scope: &mut Scope,
20        input: NativeExpand<From>,
21        output: NativeExpand<To>,
22    ) {
23        scope.register(Instruction::new(
24            Operator::Cast(UnaryOperator {
25                input: *input.expand,
26            }),
27            *output.expand,
28        ));
29    }
30}
31
32pub mod assign {
33    use ir::{Instruction, Operation};
34
35    use crate::prelude::NativeExpand;
36
37    use super::*;
38
39    /// Expand the assign operation.
40    ///
41    /// If you want to assign to a manually initialized const variable, look into
42    /// [`expand_no_check()`].
43    pub fn expand<C: CubeType>(scope: &mut Scope, input: NativeExpand<C>, output: NativeExpand<C>) {
44        let output = *output.expand;
45        let input = *input.expand;
46
47        if output.is_immutable() {
48            panic!("Can't assign a value to a const variable. Try to use `RuntimeCell`.");
49        }
50
51        scope.register(Instruction::new(Operation::Copy(input), output));
52    }
53    /// Expand the assign operation without any check.
54    ///
55    /// You can't assign to a const variable with this [`expand()`].
56    pub fn expand_no_check<C: CubeType>(
57        scope: &mut Scope,
58        input: NativeExpand<C>,
59        output: NativeExpand<C>,
60    ) {
61        let output = *output.expand;
62        let input = *input.expand;
63
64        scope.register(Instruction::new(Operation::Copy(input), output));
65    }
66
67    pub fn expand_element(scope: &mut Scope, input: ManagedVariable, output: ManagedVariable) {
68        if output.is_immutable() {
69            panic!("Can't assign a value to a const variable. Try to use `RuntimeCell`.");
70        }
71
72        scope.register(Instruction::new(Operation::Copy(*input), *output));
73    }
74}
75
76pub mod index_assign {
77    use super::*;
78
79    pub fn expand<A: CubeIndexMutExpand<Output = NativeExpand<V>>, V: CubePrimitive>(
80        scope: &mut Scope,
81        expand: A,
82        index: A::Idx,
83        value: NativeExpand<V>,
84    ) {
85        expand.expand_index_mut(scope, index, value)
86    }
87
88    macro_rules! impl_index {
89        ($type:ident) => {
90            impl<E: CubePrimitive> CubeIndexMut for $type<E> {}
91
92            impl<E: CubePrimitive> CubeIndexMutExpand for NativeExpand<$type<E>> {
93                fn expand_index_mut(
94                    self,
95                    scope: &mut Scope,
96                    index: NativeExpand<usize>,
97                    value: Self::Output,
98                ) {
99                    expand_index_assign_native::<$type<E>>(scope, self, index, value, None, true);
100                }
101            }
102        };
103    }
104
105    impl<E: Scalar, N: Size> CubeIndexMut for Vector<E, N> {}
106
107    impl<E: Scalar, N: Size> CubeIndexMutExpand for NativeExpand<Vector<E, N>> {
108        fn expand_index_mut(
109            self,
110            scope: &mut Scope,
111            index: NativeExpand<usize>,
112            value: Self::Output,
113        ) {
114            expand_index_assign_native::<Vector<E, N>>(scope, self, index, value, None, true);
115        }
116    }
117
118    impl_index!(Array);
119    impl_index!(Tensor);
120    impl_index!(SharedMemory);
121}
122
123pub mod index {
124    use super::*;
125
126    pub fn expand<A: CubeIndexExpand<Output = NativeExpand<V>>, V: CubeType>(
127        scope: &mut Scope,
128        expand: A,
129        index: A::Idx,
130    ) -> NativeExpand<V> {
131        expand.expand_index(scope, index)
132    }
133
134    pub fn expand_with<A: CubeIndexExpand<Output = NativeExpand<V>>, V: CubeType>(
135        scope: &mut Scope,
136        expand: A,
137        index: A::Idx,
138    ) -> NativeExpand<V> {
139        expand.expand_index(scope, index)
140    }
141
142    macro_rules! impl_index {
143        ($type:ident) => {
144            impl<E: CubePrimitive> CubeIndex for $type<E> {
145                type Output = E;
146                type Idx = usize;
147            }
148
149            impl<E: CubePrimitive> CubeIndexExpand for NativeExpand<$type<E>> {
150                type Output = NativeExpand<E>;
151                type Idx = NativeExpand<usize>;
152
153                fn expand_index(
154                    self,
155                    scope: &mut Scope,
156                    index: NativeExpand<usize>,
157                ) -> Self::Output {
158                    expand_index_native(scope, self, index, None, true)
159                }
160                fn expand_index_unchecked(
161                    self,
162                    scope: &mut Scope,
163                    index: NativeExpand<usize>,
164                ) -> Self::Output {
165                    expand_index_native(scope, self, index, None, false)
166                }
167            }
168        };
169    }
170
171    impl<E: Scalar, N: Size> CubeIndex for Vector<E, N> {
172        type Output = E;
173        type Idx = usize;
174    }
175    impl<E: Scalar, N: Size> CubeIndexExpand for NativeExpand<Vector<E, N>> {
176        type Output = NativeExpand<E>;
177        type Idx = NativeExpand<usize>;
178        fn expand_index(self, scope: &mut Scope, index: NativeExpand<usize>) -> Self::Output {
179            expand_index_native(scope, self, index, None, true)
180        }
181        fn expand_index_unchecked(
182            self,
183            scope: &mut Scope,
184            index: NativeExpand<usize>,
185        ) -> Self::Output {
186            expand_index_native(scope, self, index, None, false)
187        }
188    }
189
190    impl_index!(Array);
191    impl_index!(Tensor);
192    impl_index!(SharedMemory);
193}
194
195pub mod index_unchecked {
196    use super::*;
197    use crate::prelude::{CubeIndexExpand, NativeExpand};
198
199    pub fn expand<A: CubeIndexExpand<Output = NativeExpand<V>>, V: CubeType>(
200        scope: &mut Scope,
201        expand: A,
202        index: A::Idx,
203    ) -> NativeExpand<V> {
204        expand.expand_index_unchecked(scope, index)
205    }
206}
207
208pub mod add_assign_array_op {
209    use self::ir::Arithmetic;
210    use super::*;
211    use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
212
213    pub fn expand<A: CubeType + CubeIndex>(
214        scope: &mut Scope,
215        array: NativeExpand<A>,
216        index: NativeExpand<usize>,
217        value: NativeExpand<A::Output>,
218    ) where
219        A::Output: CubeType + Sized,
220    {
221        array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Add);
222    }
223}
224
225pub mod sub_assign_array_op {
226    use self::ir::Arithmetic;
227    use super::*;
228    use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
229
230    pub fn expand<A: CubeType + CubeIndex>(
231        scope: &mut Scope,
232        array: NativeExpand<A>,
233        index: NativeExpand<usize>,
234        value: NativeExpand<A::Output>,
235    ) where
236        A::Output: CubeType + Sized,
237    {
238        array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Sub);
239    }
240}
241
242pub mod mul_assign_array_op {
243    use self::ir::Arithmetic;
244    use super::*;
245    use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
246
247    pub fn expand<A: CubeType + CubeIndex>(
248        scope: &mut Scope,
249        array: NativeExpand<A>,
250        index: NativeExpand<usize>,
251        value: NativeExpand<A::Output>,
252    ) where
253        A::Output: CubeType + Sized,
254    {
255        array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Mul);
256    }
257}
258
259pub mod div_assign_array_op {
260    use self::ir::Arithmetic;
261    use super::*;
262    use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
263
264    pub fn expand<A: CubeType + CubeIndex>(
265        scope: &mut Scope,
266        array: NativeExpand<A>,
267        index: NativeExpand<usize>,
268        value: NativeExpand<A::Output>,
269    ) where
270        A::Output: CubeType + Sized,
271    {
272        array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Div);
273    }
274}
275
276pub mod rem_assign_array_op {
277    use self::ir::Arithmetic;
278    use super::*;
279    use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
280
281    pub fn expand<A: CubeType + CubeIndex>(
282        scope: &mut Scope,
283        array: NativeExpand<A>,
284        index: NativeExpand<usize>,
285        value: NativeExpand<A::Output>,
286    ) where
287        A::Output: CubeType + Sized,
288    {
289        array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Modulo);
290    }
291}
292
293pub mod bitor_assign_array_op {
294    use super::*;
295    use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
296
297    pub fn expand<A: CubeType + CubeIndex>(
298        scope: &mut Scope,
299        array: NativeExpand<A>,
300        index: NativeExpand<usize>,
301        value: NativeExpand<A::Output>,
302    ) where
303        A::Output: CubeType + Sized,
304    {
305        array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseOr);
306    }
307}
308
309pub mod bitand_assign_array_op {
310    use super::*;
311    use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
312
313    pub fn expand<A: CubeType + CubeIndex>(
314        scope: &mut Scope,
315        array: NativeExpand<A>,
316        index: NativeExpand<usize>,
317        value: NativeExpand<A::Output>,
318    ) where
319        A::Output: CubeType + Sized,
320    {
321        array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseAnd);
322    }
323}
324
325pub mod bitxor_assign_array_op {
326    use super::*;
327    use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
328
329    pub fn expand<A: CubeType + CubeIndex>(
330        scope: &mut Scope,
331        array: NativeExpand<A>,
332        index: NativeExpand<usize>,
333        value: NativeExpand<A::Output>,
334    ) where
335        A::Output: CubeType + Sized,
336    {
337        array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseXor);
338    }
339}
340
341pub mod shl_assign_array_op {
342
343    use super::*;
344    use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
345
346    pub fn expand<A: CubeType + CubeIndex>(
347        scope: &mut Scope,
348        array: NativeExpand<A>,
349        index: NativeExpand<usize>,
350        value: NativeExpand<u32>,
351    ) where
352        A::Output: CubeType + Sized,
353    {
354        array_assign_binary_op_expand(scope, array, index, value, Bitwise::ShiftLeft);
355    }
356}
357
358pub mod shr_assign_array_op {
359
360    use super::*;
361    use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
362
363    pub fn expand<A: CubeType + CubeIndex>(
364        scope: &mut Scope,
365        array: NativeExpand<A>,
366        index: NativeExpand<usize>,
367        value: NativeExpand<u32>,
368    ) where
369        A::Output: CubeType + Sized,
370    {
371        array_assign_binary_op_expand(scope, array, index, value, Bitwise::ShiftRight);
372    }
373}
374
375pub mod add_assign_op {
376    use core::ops::AddAssign;
377
378    use self::ir::Arithmetic;
379    use crate::{
380        frontend::operation::base::assign_op_expand,
381        prelude::{CubeType, NativeExpand},
382    };
383
384    use super::*;
385
386    pub fn expand<C: CubeType + AddAssign>(
387        scope: &mut Scope,
388        lhs: NativeExpand<C>,
389        rhs: NativeExpand<C>,
390    ) -> NativeExpand<C> {
391        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
392    }
393}
394
395pub mod sub_assign_op {
396    use self::ir::Arithmetic;
397    use super::*;
398    use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
399
400    pub fn expand<C: CubeType>(
401        scope: &mut Scope,
402        lhs: NativeExpand<C>,
403        rhs: NativeExpand<C>,
404    ) -> ManagedVariable {
405        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub)
406    }
407}
408
409pub mod mul_assign_op {
410    use self::ir::Arithmetic;
411    use super::*;
412    use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
413
414    pub fn expand<C: CubeType>(
415        scope: &mut Scope,
416        lhs: NativeExpand<C>,
417        rhs: NativeExpand<C>,
418    ) -> ManagedVariable {
419        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul)
420    }
421}
422
423pub mod div_assign_op {
424    use self::ir::Arithmetic;
425    use super::*;
426    use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
427
428    pub fn expand<C: CubeType>(
429        scope: &mut Scope,
430        lhs: NativeExpand<C>,
431        rhs: NativeExpand<C>,
432    ) -> ManagedVariable {
433        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div)
434    }
435}
436
437pub mod rem_assign_op {
438    use self::ir::Arithmetic;
439    use super::*;
440    use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
441
442    pub fn expand<C: CubeType>(
443        scope: &mut Scope,
444        lhs: NativeExpand<C>,
445        rhs: NativeExpand<C>,
446    ) -> ManagedVariable {
447        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo)
448    }
449}
450
451pub mod bitor_assign_op {
452
453    use super::*;
454    use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
455
456    pub fn expand<C: CubeType>(
457        scope: &mut Scope,
458        lhs: NativeExpand<C>,
459        rhs: NativeExpand<C>,
460    ) -> ManagedVariable {
461        assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr)
462    }
463}
464
465pub mod bitand_assign_op {
466
467    use super::*;
468    use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
469
470    pub fn expand<C: CubeType>(
471        scope: &mut Scope,
472        lhs: NativeExpand<C>,
473        rhs: NativeExpand<C>,
474    ) -> ManagedVariable {
475        assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd)
476    }
477}
478
479pub mod bitxor_assign_op {
480
481    use super::*;
482    use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
483
484    pub fn expand<C: CubeType>(
485        scope: &mut Scope,
486        lhs: NativeExpand<C>,
487        rhs: NativeExpand<C>,
488    ) -> ManagedVariable {
489        assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor)
490    }
491}
492
493pub mod shl_assign_op {
494
495    use super::*;
496    use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
497
498    pub fn expand<C: CubeType>(
499        scope: &mut Scope,
500        lhs: NativeExpand<C>,
501        rhs: NativeExpand<u32>,
502    ) -> ManagedVariable {
503        assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft)
504    }
505}
506
507pub mod shr_assign_op {
508    use super::*;
509    use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
510
511    pub fn expand<C: CubeType>(
512        scope: &mut Scope,
513        lhs: NativeExpand<C>,
514        rhs: NativeExpand<u32>,
515    ) -> ManagedVariable {
516        assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight)
517    }
518}
519
520pub mod add_assign {
521    use cubecl_ir::Arithmetic;
522
523    use super::*;
524    use crate::prelude::{CubePrimitive, NativeExpand, assign_op_expand};
525
526    pub fn expand<C: CubePrimitive>(
527        scope: &mut Scope,
528        lhs: NativeExpand<C>,
529        rhs: NativeExpand<C>,
530    ) -> NativeExpand<C> {
531        assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
532    }
533}