1use cubecl_ir::{Bitwise, ManagedVariable, Operator, Scope};
2
3use crate::ir;
4use crate::{
5 frontend::{Array, SharedMemory, Tensor},
6 prelude::*,
7};
8
9pub mod cast {
10 use ir::Instruction;
11
12 use crate::prelude::NativeExpand;
13
14 use self::ir::UnaryOperator;
15
16 use super::*;
17
18 pub fn expand<From: CubeType, To: CubeType>(
19 scope: &mut Scope,
20 input: NativeExpand<From>,
21 output: NativeExpand<To>,
22 ) {
23 scope.register(Instruction::new(
24 Operator::Cast(UnaryOperator {
25 input: *input.expand,
26 }),
27 *output.expand,
28 ));
29 }
30}
31
32pub mod assign {
33 use ir::{Instruction, Operation};
34
35 use crate::prelude::NativeExpand;
36
37 use super::*;
38
39 pub fn expand<C: CubeType>(scope: &mut Scope, input: NativeExpand<C>, output: NativeExpand<C>) {
44 let output = *output.expand;
45 let input = *input.expand;
46
47 if output.is_immutable() {
48 panic!("Can't assign a value to a const variable. Try to use `RuntimeCell`.");
49 }
50
51 scope.register(Instruction::new(Operation::Copy(input), output));
52 }
53 pub fn expand_no_check<C: CubeType>(
57 scope: &mut Scope,
58 input: NativeExpand<C>,
59 output: NativeExpand<C>,
60 ) {
61 let output = *output.expand;
62 let input = *input.expand;
63
64 scope.register(Instruction::new(Operation::Copy(input), output));
65 }
66
67 pub fn expand_element(scope: &mut Scope, input: ManagedVariable, output: ManagedVariable) {
68 if output.is_immutable() {
69 panic!("Can't assign a value to a const variable. Try to use `RuntimeCell`.");
70 }
71
72 scope.register(Instruction::new(Operation::Copy(*input), *output));
73 }
74}
75
76pub mod index_assign {
77 use super::*;
78
79 pub fn expand<A: CubeIndexMutExpand<Output = NativeExpand<V>>, V: CubePrimitive>(
80 scope: &mut Scope,
81 expand: A,
82 index: A::Idx,
83 value: NativeExpand<V>,
84 ) {
85 expand.expand_index_mut(scope, index, value)
86 }
87
88 macro_rules! impl_index {
89 ($type:ident) => {
90 impl<E: CubePrimitive> CubeIndexMut for $type<E> {}
91
92 impl<E: CubePrimitive> CubeIndexMutExpand for NativeExpand<$type<E>> {
93 fn expand_index_mut(
94 self,
95 scope: &mut Scope,
96 index: NativeExpand<usize>,
97 value: Self::Output,
98 ) {
99 expand_index_assign_native::<$type<E>>(scope, self, index, value, None, true);
100 }
101 }
102 };
103 }
104
105 impl<E: Scalar, N: Size> CubeIndexMut for Vector<E, N> {}
106
107 impl<E: Scalar, N: Size> CubeIndexMutExpand for NativeExpand<Vector<E, N>> {
108 fn expand_index_mut(
109 self,
110 scope: &mut Scope,
111 index: NativeExpand<usize>,
112 value: Self::Output,
113 ) {
114 expand_index_assign_native::<Vector<E, N>>(scope, self, index, value, None, true);
115 }
116 }
117
118 impl_index!(Array);
119 impl_index!(Tensor);
120 impl_index!(SharedMemory);
121}
122
123pub mod index {
124 use super::*;
125
126 pub fn expand<A: CubeIndexExpand<Output = NativeExpand<V>>, V: CubeType>(
127 scope: &mut Scope,
128 expand: A,
129 index: A::Idx,
130 ) -> NativeExpand<V> {
131 expand.expand_index(scope, index)
132 }
133
134 pub fn expand_with<A: CubeIndexExpand<Output = NativeExpand<V>>, V: CubeType>(
135 scope: &mut Scope,
136 expand: A,
137 index: A::Idx,
138 ) -> NativeExpand<V> {
139 expand.expand_index(scope, index)
140 }
141
142 macro_rules! impl_index {
143 ($type:ident) => {
144 impl<E: CubePrimitive> CubeIndex for $type<E> {
145 type Output = E;
146 type Idx = usize;
147 }
148
149 impl<E: CubePrimitive> CubeIndexExpand for NativeExpand<$type<E>> {
150 type Output = NativeExpand<E>;
151 type Idx = NativeExpand<usize>;
152
153 fn expand_index(
154 self,
155 scope: &mut Scope,
156 index: NativeExpand<usize>,
157 ) -> Self::Output {
158 expand_index_native(scope, self, index, None, true)
159 }
160 fn expand_index_unchecked(
161 self,
162 scope: &mut Scope,
163 index: NativeExpand<usize>,
164 ) -> Self::Output {
165 expand_index_native(scope, self, index, None, false)
166 }
167 }
168 };
169 }
170
171 impl<E: Scalar, N: Size> CubeIndex for Vector<E, N> {
172 type Output = E;
173 type Idx = usize;
174 }
175 impl<E: Scalar, N: Size> CubeIndexExpand for NativeExpand<Vector<E, N>> {
176 type Output = NativeExpand<E>;
177 type Idx = NativeExpand<usize>;
178 fn expand_index(self, scope: &mut Scope, index: NativeExpand<usize>) -> Self::Output {
179 expand_index_native(scope, self, index, None, true)
180 }
181 fn expand_index_unchecked(
182 self,
183 scope: &mut Scope,
184 index: NativeExpand<usize>,
185 ) -> Self::Output {
186 expand_index_native(scope, self, index, None, false)
187 }
188 }
189
190 impl_index!(Array);
191 impl_index!(Tensor);
192 impl_index!(SharedMemory);
193}
194
195pub mod index_unchecked {
196 use super::*;
197 use crate::prelude::{CubeIndexExpand, NativeExpand};
198
199 pub fn expand<A: CubeIndexExpand<Output = NativeExpand<V>>, V: CubeType>(
200 scope: &mut Scope,
201 expand: A,
202 index: A::Idx,
203 ) -> NativeExpand<V> {
204 expand.expand_index_unchecked(scope, index)
205 }
206}
207
208pub mod add_assign_array_op {
209 use self::ir::Arithmetic;
210 use super::*;
211 use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
212
213 pub fn expand<A: CubeType + CubeIndex>(
214 scope: &mut Scope,
215 array: NativeExpand<A>,
216 index: NativeExpand<usize>,
217 value: NativeExpand<A::Output>,
218 ) where
219 A::Output: CubeType + Sized,
220 {
221 array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Add);
222 }
223}
224
225pub mod sub_assign_array_op {
226 use self::ir::Arithmetic;
227 use super::*;
228 use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
229
230 pub fn expand<A: CubeType + CubeIndex>(
231 scope: &mut Scope,
232 array: NativeExpand<A>,
233 index: NativeExpand<usize>,
234 value: NativeExpand<A::Output>,
235 ) where
236 A::Output: CubeType + Sized,
237 {
238 array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Sub);
239 }
240}
241
242pub mod mul_assign_array_op {
243 use self::ir::Arithmetic;
244 use super::*;
245 use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
246
247 pub fn expand<A: CubeType + CubeIndex>(
248 scope: &mut Scope,
249 array: NativeExpand<A>,
250 index: NativeExpand<usize>,
251 value: NativeExpand<A::Output>,
252 ) where
253 A::Output: CubeType + Sized,
254 {
255 array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Mul);
256 }
257}
258
259pub mod div_assign_array_op {
260 use self::ir::Arithmetic;
261 use super::*;
262 use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
263
264 pub fn expand<A: CubeType + CubeIndex>(
265 scope: &mut Scope,
266 array: NativeExpand<A>,
267 index: NativeExpand<usize>,
268 value: NativeExpand<A::Output>,
269 ) where
270 A::Output: CubeType + Sized,
271 {
272 array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Div);
273 }
274}
275
276pub mod rem_assign_array_op {
277 use self::ir::Arithmetic;
278 use super::*;
279 use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
280
281 pub fn expand<A: CubeType + CubeIndex>(
282 scope: &mut Scope,
283 array: NativeExpand<A>,
284 index: NativeExpand<usize>,
285 value: NativeExpand<A::Output>,
286 ) where
287 A::Output: CubeType + Sized,
288 {
289 array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Modulo);
290 }
291}
292
293pub mod bitor_assign_array_op {
294 use super::*;
295 use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
296
297 pub fn expand<A: CubeType + CubeIndex>(
298 scope: &mut Scope,
299 array: NativeExpand<A>,
300 index: NativeExpand<usize>,
301 value: NativeExpand<A::Output>,
302 ) where
303 A::Output: CubeType + Sized,
304 {
305 array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseOr);
306 }
307}
308
309pub mod bitand_assign_array_op {
310 use super::*;
311 use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
312
313 pub fn expand<A: CubeType + CubeIndex>(
314 scope: &mut Scope,
315 array: NativeExpand<A>,
316 index: NativeExpand<usize>,
317 value: NativeExpand<A::Output>,
318 ) where
319 A::Output: CubeType + Sized,
320 {
321 array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseAnd);
322 }
323}
324
325pub mod bitxor_assign_array_op {
326 use super::*;
327 use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
328
329 pub fn expand<A: CubeType + CubeIndex>(
330 scope: &mut Scope,
331 array: NativeExpand<A>,
332 index: NativeExpand<usize>,
333 value: NativeExpand<A::Output>,
334 ) where
335 A::Output: CubeType + Sized,
336 {
337 array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseXor);
338 }
339}
340
341pub mod shl_assign_array_op {
342
343 use super::*;
344 use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
345
346 pub fn expand<A: CubeType + CubeIndex>(
347 scope: &mut Scope,
348 array: NativeExpand<A>,
349 index: NativeExpand<usize>,
350 value: NativeExpand<u32>,
351 ) where
352 A::Output: CubeType + Sized,
353 {
354 array_assign_binary_op_expand(scope, array, index, value, Bitwise::ShiftLeft);
355 }
356}
357
358pub mod shr_assign_array_op {
359
360 use super::*;
361 use crate::prelude::{CubeType, NativeExpand, array_assign_binary_op_expand};
362
363 pub fn expand<A: CubeType + CubeIndex>(
364 scope: &mut Scope,
365 array: NativeExpand<A>,
366 index: NativeExpand<usize>,
367 value: NativeExpand<u32>,
368 ) where
369 A::Output: CubeType + Sized,
370 {
371 array_assign_binary_op_expand(scope, array, index, value, Bitwise::ShiftRight);
372 }
373}
374
375pub mod add_assign_op {
376 use core::ops::AddAssign;
377
378 use self::ir::Arithmetic;
379 use crate::{
380 frontend::operation::base::assign_op_expand,
381 prelude::{CubeType, NativeExpand},
382 };
383
384 use super::*;
385
386 pub fn expand<C: CubeType + AddAssign>(
387 scope: &mut Scope,
388 lhs: NativeExpand<C>,
389 rhs: NativeExpand<C>,
390 ) -> NativeExpand<C> {
391 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
392 }
393}
394
395pub mod sub_assign_op {
396 use self::ir::Arithmetic;
397 use super::*;
398 use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
399
400 pub fn expand<C: CubeType>(
401 scope: &mut Scope,
402 lhs: NativeExpand<C>,
403 rhs: NativeExpand<C>,
404 ) -> ManagedVariable {
405 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub)
406 }
407}
408
409pub mod mul_assign_op {
410 use self::ir::Arithmetic;
411 use super::*;
412 use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
413
414 pub fn expand<C: CubeType>(
415 scope: &mut Scope,
416 lhs: NativeExpand<C>,
417 rhs: NativeExpand<C>,
418 ) -> ManagedVariable {
419 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul)
420 }
421}
422
423pub mod div_assign_op {
424 use self::ir::Arithmetic;
425 use super::*;
426 use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
427
428 pub fn expand<C: CubeType>(
429 scope: &mut Scope,
430 lhs: NativeExpand<C>,
431 rhs: NativeExpand<C>,
432 ) -> ManagedVariable {
433 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div)
434 }
435}
436
437pub mod rem_assign_op {
438 use self::ir::Arithmetic;
439 use super::*;
440 use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
441
442 pub fn expand<C: CubeType>(
443 scope: &mut Scope,
444 lhs: NativeExpand<C>,
445 rhs: NativeExpand<C>,
446 ) -> ManagedVariable {
447 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo)
448 }
449}
450
451pub mod bitor_assign_op {
452
453 use super::*;
454 use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
455
456 pub fn expand<C: CubeType>(
457 scope: &mut Scope,
458 lhs: NativeExpand<C>,
459 rhs: NativeExpand<C>,
460 ) -> ManagedVariable {
461 assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr)
462 }
463}
464
465pub mod bitand_assign_op {
466
467 use super::*;
468 use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
469
470 pub fn expand<C: CubeType>(
471 scope: &mut Scope,
472 lhs: NativeExpand<C>,
473 rhs: NativeExpand<C>,
474 ) -> ManagedVariable {
475 assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd)
476 }
477}
478
479pub mod bitxor_assign_op {
480
481 use super::*;
482 use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
483
484 pub fn expand<C: CubeType>(
485 scope: &mut Scope,
486 lhs: NativeExpand<C>,
487 rhs: NativeExpand<C>,
488 ) -> ManagedVariable {
489 assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor)
490 }
491}
492
493pub mod shl_assign_op {
494
495 use super::*;
496 use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
497
498 pub fn expand<C: CubeType>(
499 scope: &mut Scope,
500 lhs: NativeExpand<C>,
501 rhs: NativeExpand<u32>,
502 ) -> ManagedVariable {
503 assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft)
504 }
505}
506
507pub mod shr_assign_op {
508 use super::*;
509 use crate::{frontend::operation::base::assign_op_expand, prelude::NativeExpand};
510
511 pub fn expand<C: CubeType>(
512 scope: &mut Scope,
513 lhs: NativeExpand<C>,
514 rhs: NativeExpand<u32>,
515 ) -> ManagedVariable {
516 assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight)
517 }
518}
519
520pub mod add_assign {
521 use cubecl_ir::Arithmetic;
522
523 use super::*;
524 use crate::prelude::{CubePrimitive, NativeExpand, assign_op_expand};
525
526 pub fn expand<C: CubePrimitive>(
527 scope: &mut Scope,
528 lhs: NativeExpand<C>,
529 rhs: NativeExpand<C>,
530 ) -> NativeExpand<C> {
531 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
532 }
533}