1use cubecl_ir::{Bitwise, ExpandElement, Operator, Scope};
2
3use crate::ir;
4use crate::{
5 frontend::{Array, SharedMemory, Tensor},
6 prelude::{CubeIndex, CubeIndexMut, CubePrimitive, CubeType},
7};
8
9pub mod cast {
10 use ir::Instruction;
11
12 use crate::prelude::ExpandElementTyped;
13
14 use self::ir::UnaryOperator;
15
16 use super::*;
17
18 pub fn expand<C: CubeType>(
19 scope: &mut Scope,
20 input: ExpandElementTyped<C>,
21 output: ExpandElementTyped<C>,
22 ) {
23 scope.register(Instruction::new(
24 Operator::Cast(UnaryOperator {
25 input: *input.expand,
26 }),
27 *output.expand,
28 ));
29 }
30}
31
32pub mod assign {
33 use ir::{Instruction, Operation};
34
35 use crate::prelude::ExpandElementTyped;
36
37 use super::*;
38
39 pub fn expand<C: CubeType>(
44 scope: &mut Scope,
45 input: ExpandElementTyped<C>,
46 output: ExpandElementTyped<C>,
47 ) {
48 let output = *output.expand;
49 let input = *input.expand;
50
51 if output.is_immutable() {
52 panic!("Can't assign a value to a const variable. Try to use `RuntimeCell`.");
53 }
54
55 scope.register(Instruction::new(Operation::Copy(input), output));
56 }
57 pub fn expand_no_check<C: CubeType>(
61 scope: &mut Scope,
62 input: ExpandElementTyped<C>,
63 output: ExpandElementTyped<C>,
64 ) {
65 let output = *output.expand;
66 let input = *input.expand;
67
68 scope.register(Instruction::new(Operation::Copy(input), output));
69 }
70}
71
72pub mod index_assign {
73 use super::*;
74 use crate::prelude::{
75 CubeIndexMutExpand, ExpandElementTyped, Line, expand_index_assign_native,
76 };
77
78 pub fn expand<A: CubeIndexMutExpand<Output = ExpandElementTyped<V>>, V: CubePrimitive>(
79 scope: &mut Scope,
80 expand: A,
81 index: ExpandElementTyped<u32>,
82 value: ExpandElementTyped<V>,
83 ) {
84 expand.expand_index_mut(scope, index, value)
85 }
86
87 macro_rules! impl_index {
88 ($type:ident) => {
89 impl<E: CubePrimitive> CubeIndexMut for $type<E> {}
90
91 impl<E: CubePrimitive> CubeIndexMutExpand for ExpandElementTyped<$type<E>> {
92 type Output = ExpandElementTyped<E>;
93
94 fn expand_index_mut(
95 self,
96 scope: &mut Scope,
97 index: ExpandElementTyped<u32>,
98 value: Self::Output,
99 ) {
100 expand_index_assign_native::<$type<E>>(scope, self, index, value, None, true);
101 }
102 }
103 };
104 }
105
106 impl_index!(Array);
107 impl_index!(Tensor);
108 impl_index!(SharedMemory);
109 impl_index!(Line);
110}
111
112pub mod index {
113 use super::*;
114 use crate::prelude::{CubeIndexExpand, ExpandElementTyped, Line, expand_index_native};
115
116 pub fn expand<A: CubeIndexExpand<Output = ExpandElementTyped<V>>, V: CubeType>(
117 scope: &mut Scope,
118 expand: A,
119 index: ExpandElementTyped<u32>,
120 ) -> ExpandElementTyped<V> {
121 expand.expand_index(scope, index)
122 }
123
124 pub fn expand_with<A: CubeIndexExpand<Output = ExpandElementTyped<V>>, V: CubeType>(
125 scope: &mut Scope,
126 expand: A,
127 index: ExpandElementTyped<u32>,
128 ) -> ExpandElementTyped<V> {
129 expand.expand_index(scope, index)
130 }
131
132 macro_rules! impl_index {
133 ($type:ident) => {
134 impl<E: CubePrimitive> CubeIndex for $type<E> {
135 type Output = E;
136 }
137
138 impl<E: CubePrimitive> CubeIndexExpand for ExpandElementTyped<$type<E>> {
139 type Output = ExpandElementTyped<E>;
140
141 fn expand_index(
142 self,
143 scope: &mut Scope,
144 index: ExpandElementTyped<u32>,
145 ) -> Self::Output {
146 expand_index_native(scope, self, index, None, true)
147 }
148 fn expand_index_unchecked(
149 self,
150 scope: &mut Scope,
151 index: ExpandElementTyped<u32>,
152 ) -> Self::Output {
153 expand_index_native(scope, self, index, None, false)
154 }
155 }
156 };
157 }
158
159 impl_index!(Array);
160 impl_index!(Tensor);
161 impl_index!(SharedMemory);
162 impl_index!(Line);
163}
164
165pub mod index_unchecked {
166 use super::*;
167 use crate::prelude::{CubeIndexExpand, ExpandElementTyped};
168
169 pub fn expand<A: CubeIndexExpand<Output = ExpandElementTyped<V>>, V: CubeType>(
170 scope: &mut Scope,
171 expand: A,
172 index: ExpandElementTyped<u32>,
173 ) -> ExpandElementTyped<V> {
174 expand.expand_index_unchecked(scope, index)
175 }
176}
177
178pub mod add_assign_array_op {
179 use self::ir::Arithmetic;
180 use super::*;
181 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
182
183 pub fn expand<A: CubeType + CubeIndex>(
184 scope: &mut Scope,
185 array: ExpandElementTyped<A>,
186 index: ExpandElementTyped<u32>,
187 value: ExpandElementTyped<A::Output>,
188 ) where
189 A::Output: CubeType + Sized,
190 {
191 array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Add);
192 }
193}
194
195pub mod sub_assign_array_op {
196 use self::ir::Arithmetic;
197 use super::*;
198 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
199
200 pub fn expand<A: CubeType + CubeIndex>(
201 scope: &mut Scope,
202 array: ExpandElementTyped<A>,
203 index: ExpandElementTyped<u32>,
204 value: ExpandElementTyped<A::Output>,
205 ) where
206 A::Output: CubeType + Sized,
207 {
208 array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Sub);
209 }
210}
211
212pub mod mul_assign_array_op {
213 use self::ir::Arithmetic;
214 use super::*;
215 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
216
217 pub fn expand<A: CubeType + CubeIndex>(
218 scope: &mut Scope,
219 array: ExpandElementTyped<A>,
220 index: ExpandElementTyped<u32>,
221 value: ExpandElementTyped<A::Output>,
222 ) where
223 A::Output: CubeType + Sized,
224 {
225 array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Mul);
226 }
227}
228
229pub mod div_assign_array_op {
230 use self::ir::Arithmetic;
231 use super::*;
232 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
233
234 pub fn expand<A: CubeType + CubeIndex>(
235 scope: &mut Scope,
236 array: ExpandElementTyped<A>,
237 index: ExpandElementTyped<u32>,
238 value: ExpandElementTyped<A::Output>,
239 ) where
240 A::Output: CubeType + Sized,
241 {
242 array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Div);
243 }
244}
245
246pub mod rem_assign_array_op {
247 use self::ir::Arithmetic;
248 use super::*;
249 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
250
251 pub fn expand<A: CubeType + CubeIndex>(
252 scope: &mut Scope,
253 array: ExpandElementTyped<A>,
254 index: ExpandElementTyped<u32>,
255 value: ExpandElementTyped<A::Output>,
256 ) where
257 A::Output: CubeType + Sized,
258 {
259 array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Modulo);
260 }
261}
262
263pub mod bitor_assign_array_op {
264 use super::*;
265 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
266
267 pub fn expand<A: CubeType + CubeIndex>(
268 scope: &mut Scope,
269 array: ExpandElementTyped<A>,
270 index: ExpandElementTyped<u32>,
271 value: ExpandElementTyped<A::Output>,
272 ) where
273 A::Output: CubeType + Sized,
274 {
275 array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseOr);
276 }
277}
278
279pub mod bitand_assign_array_op {
280 use super::*;
281 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
282
283 pub fn expand<A: CubeType + CubeIndex>(
284 scope: &mut Scope,
285 array: ExpandElementTyped<A>,
286 index: ExpandElementTyped<u32>,
287 value: ExpandElementTyped<A::Output>,
288 ) where
289 A::Output: CubeType + Sized,
290 {
291 array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseAnd);
292 }
293}
294
295pub mod bitxor_assign_array_op {
296 use super::*;
297 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
298
299 pub fn expand<A: CubeType + CubeIndex>(
300 scope: &mut Scope,
301 array: ExpandElementTyped<A>,
302 index: ExpandElementTyped<u32>,
303 value: ExpandElementTyped<A::Output>,
304 ) where
305 A::Output: CubeType + Sized,
306 {
307 array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseXor);
308 }
309}
310
311pub mod shl_assign_array_op {
312
313 use super::*;
314 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
315
316 pub fn expand<A: CubeType + CubeIndex>(
317 scope: &mut Scope,
318 array: ExpandElementTyped<A>,
319 index: ExpandElementTyped<u32>,
320 value: ExpandElementTyped<u32>,
321 ) where
322 A::Output: CubeType + Sized,
323 {
324 array_assign_binary_op_expand(scope, array, index, value, Bitwise::ShiftLeft);
325 }
326}
327
328pub mod shr_assign_array_op {
329
330 use super::*;
331 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
332
333 pub fn expand<A: CubeType + CubeIndex>(
334 scope: &mut Scope,
335 array: ExpandElementTyped<A>,
336 index: ExpandElementTyped<u32>,
337 value: ExpandElementTyped<u32>,
338 ) where
339 A::Output: CubeType + Sized,
340 {
341 array_assign_binary_op_expand(scope, array, index, value, Bitwise::ShiftRight);
342 }
343}
344
345pub mod add_assign_op {
346 use std::ops::AddAssign;
347
348 use self::ir::Arithmetic;
349 use crate::{
350 frontend::operation::base::assign_op_expand,
351 prelude::{CubeType, ExpandElementTyped},
352 };
353
354 use super::*;
355
356 pub fn expand<C: CubeType + AddAssign>(
357 scope: &mut Scope,
358 lhs: ExpandElementTyped<C>,
359 rhs: ExpandElementTyped<C>,
360 ) -> ExpandElementTyped<C> {
361 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
362 }
363}
364
365pub mod sub_assign_op {
366 use self::ir::Arithmetic;
367 use super::*;
368 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
369
370 pub fn expand<C: CubeType>(
371 scope: &mut Scope,
372 lhs: ExpandElementTyped<C>,
373 rhs: ExpandElementTyped<C>,
374 ) -> ExpandElement {
375 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub)
376 }
377}
378
379pub mod mul_assign_op {
380 use self::ir::Arithmetic;
381 use super::*;
382 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
383
384 pub fn expand<C: CubeType>(
385 scope: &mut Scope,
386 lhs: ExpandElementTyped<C>,
387 rhs: ExpandElementTyped<C>,
388 ) -> ExpandElement {
389 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul)
390 }
391}
392
393pub mod div_assign_op {
394 use self::ir::Arithmetic;
395 use super::*;
396 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
397
398 pub fn expand<C: CubeType>(
399 scope: &mut Scope,
400 lhs: ExpandElementTyped<C>,
401 rhs: ExpandElementTyped<C>,
402 ) -> ExpandElement {
403 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div)
404 }
405}
406
407pub mod rem_assign_op {
408 use self::ir::Arithmetic;
409 use super::*;
410 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
411
412 pub fn expand<C: CubeType>(
413 scope: &mut Scope,
414 lhs: ExpandElementTyped<C>,
415 rhs: ExpandElementTyped<C>,
416 ) -> ExpandElement {
417 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo)
418 }
419}
420
421pub mod bitor_assign_op {
422
423 use super::*;
424 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
425
426 pub fn expand<C: CubeType>(
427 scope: &mut Scope,
428 lhs: ExpandElementTyped<C>,
429 rhs: ExpandElementTyped<C>,
430 ) -> ExpandElement {
431 assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr)
432 }
433}
434
435pub mod bitand_assign_op {
436
437 use super::*;
438 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
439
440 pub fn expand<C: CubeType>(
441 scope: &mut Scope,
442 lhs: ExpandElementTyped<C>,
443 rhs: ExpandElementTyped<C>,
444 ) -> ExpandElement {
445 assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd)
446 }
447}
448
449pub mod bitxor_assign_op {
450
451 use super::*;
452 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
453
454 pub fn expand<C: CubeType>(
455 scope: &mut Scope,
456 lhs: ExpandElementTyped<C>,
457 rhs: ExpandElementTyped<C>,
458 ) -> ExpandElement {
459 assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor)
460 }
461}
462
463pub mod shl_assign_op {
464
465 use super::*;
466 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
467
468 pub fn expand<C: CubeType>(
469 scope: &mut Scope,
470 lhs: ExpandElementTyped<C>,
471 rhs: ExpandElementTyped<u32>,
472 ) -> ExpandElement {
473 assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft)
474 }
475}
476
477pub mod shr_assign_op {
478 use super::*;
479 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
480
481 pub fn expand<C: CubeType>(
482 scope: &mut Scope,
483 lhs: ExpandElementTyped<C>,
484 rhs: ExpandElementTyped<u32>,
485 ) -> ExpandElement {
486 assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight)
487 }
488}
489
490pub mod add_assign {
491 use cubecl_ir::Arithmetic;
492
493 use super::*;
494 use crate::prelude::{CubePrimitive, ExpandElementTyped, assign_op_expand};
495
496 pub fn expand<C: CubePrimitive>(
497 scope: &mut Scope,
498 lhs: ExpandElementTyped<C>,
499 rhs: ExpandElementTyped<C>,
500 ) -> ExpandElementTyped<C> {
501 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
502 }
503}