1use cubecl_ir::{Bitwise, ExpandElement, Operator, Scope};
2use half::{bf16, f16};
3
4use crate::{
5 frontend::{Array, SharedMemory, Tensor},
6 prelude::{CubeIndex, CubeIndexMut, CubeType},
7};
8use crate::{ir, prelude::Index};
9
10pub mod cast {
11 use ir::Instruction;
12
13 use crate::prelude::ExpandElementTyped;
14
15 use self::ir::UnaryOperator;
16
17 use super::*;
18
19 pub fn expand<C: CubeType>(
20 scope: &mut Scope,
21 input: ExpandElementTyped<C>,
22 output: ExpandElementTyped<C>,
23 ) {
24 scope.register(Instruction::new(
25 Operator::Cast(UnaryOperator {
26 input: *input.expand,
27 }),
28 *output.expand,
29 ));
30 }
31}
32
33pub mod assign {
34 use ir::{Instruction, Operation};
35
36 use crate::prelude::ExpandElementTyped;
37
38 use super::*;
39
40 pub fn expand<C: CubeType>(
41 scope: &mut Scope,
42 input: ExpandElementTyped<C>,
43 output: ExpandElementTyped<C>,
44 ) {
45 scope.register(Instruction::new(
46 Operation::Copy(*input.expand),
47 *output.expand,
48 ));
49 }
50}
51
52pub mod index_assign {
53 use ir::{Instruction, UIntKind, VariableKind};
54
55 use crate::{
56 flex32,
57 frontend::CubeType,
58 prelude::{ExpandElementTyped, SliceMut},
59 tf32,
60 };
61
62 use self::ir::{BinaryOperator, Variable};
63
64 use super::*;
65
66 pub fn expand<A: CubeType + CubeIndex<u32>>(
67 scope: &mut Scope,
68 array: ExpandElementTyped<A>,
69 index: ExpandElementTyped<u32>,
70 value: ExpandElementTyped<A::Output>,
71 ) where
72 A::Output: CubeType + Sized,
73 {
74 let index: Variable = index.expand.into();
75 let index = match index.kind {
76 VariableKind::ConstantScalar(value) => {
77 Variable::constant(ir::ConstantScalarValue::UInt(value.as_u64(), UIntKind::U32))
78 }
79 _ => index,
80 };
81 scope.register(Instruction::new(
82 Operator::IndexAssign(BinaryOperator {
83 lhs: index,
84 rhs: value.expand.into(),
85 }),
86 array.expand.into(),
87 ));
88 }
89
90 macro_rules! impl_index {
91 ($type:ident) => {
92 impl<E: CubeType, I: Index> CubeIndexMut<I> for $type<E> {}
93 };
94 }
95 macro_rules! impl_index_vec {
96 ($($type:ident),*) => {
97 $(
98 impl<I: Index> CubeIndexMut<I> for $type {}
99 )*
100 };
101 }
102
103 impl_index!(Array);
104 impl_index!(Tensor);
105 impl_index!(SharedMemory);
106 impl_index_vec!(
107 i64, i32, i16, i8, f16, bf16, flex32, tf32, f32, f64, u64, u32, u16, u8
108 );
109
110 impl<E: CubeType, I: Index> CubeIndexMut<I> for SliceMut<E> {}
111}
112
113pub mod index {
114 use ir::{UIntKind, VariableKind};
115
116 use crate::{
117 flex32,
118 frontend::{
119 CubeType,
120 operation::base::{binary_expand, binary_expand_no_vec},
121 },
122 prelude::{ExpandElementTyped, Slice, SliceMut},
123 tf32,
124 };
125
126 use self::ir::Variable;
127
128 use super::*;
129
130 pub fn expand<A: CubeType + CubeIndex<ExpandElementTyped<u32>>>(
131 scope: &mut Scope,
132 array: ExpandElementTyped<A>,
133 index: ExpandElementTyped<u32>,
134 ) -> ExpandElementTyped<A::Output>
135 where
136 A::Output: CubeType + Sized,
137 {
138 let index: ExpandElement = index.into();
139 let index_var: Variable = *index;
140 let index = match index_var.kind {
141 VariableKind::ConstantScalar(value) => ExpandElement::Plain(Variable::constant(
142 ir::ConstantScalarValue::UInt(value.as_u64(), UIntKind::U32),
143 )),
144 _ => index,
145 };
146 let array: ExpandElement = array.into();
147 let var: Variable = *array;
148 let var = match var.kind {
149 VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
150 binary_expand_no_vec(scope, array, index, Operator::Index)
151 }
152 _ => binary_expand(scope, array, index, Operator::Index),
153 };
154
155 ExpandElementTyped::new(var)
156 }
157
158 macro_rules! impl_index {
159 ($type:ident) => {
160 impl<E: CubeType, I: Index> CubeIndex<I> for $type<E> {
161 type Output = E;
162 }
163 };
164 }
165 macro_rules! impl_index_vec {
166 ($($type:ident),*) => {
167 $(
168 impl<I: Index> CubeIndex<I> for $type {
169 type Output = Self;
170 }
171 )*
172 };
173 }
174
175 impl_index!(Array);
176 impl_index!(Tensor);
177 impl_index!(SharedMemory);
178 impl_index_vec!(
179 i64, i32, i16, i8, f16, flex32, tf32, bf16, f32, f64, u64, u32, u16, u8
180 );
181
182 impl<E: CubeType, I: Index> CubeIndex<I> for Slice<E> {
183 type Output = E;
184 }
185
186 impl<E: CubeType, I: Index> CubeIndex<I> for SliceMut<E> {
187 type Output = E;
188 }
189}
190
191pub mod add_assign_array_op {
192 use self::ir::Arithmetic;
193 use super::*;
194 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
195
196 pub fn expand<A: CubeType + CubeIndex<u32>>(
197 scope: &mut Scope,
198 array: ExpandElementTyped<A>,
199 index: ExpandElementTyped<u32>,
200 value: ExpandElementTyped<A::Output>,
201 ) where
202 A::Output: CubeType + Sized,
203 {
204 array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Add);
205 }
206}
207
208pub mod sub_assign_array_op {
209 use self::ir::Arithmetic;
210 use super::*;
211 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
212
213 pub fn expand<A: CubeType + CubeIndex<u32>>(
214 scope: &mut Scope,
215 array: ExpandElementTyped<A>,
216 index: ExpandElementTyped<u32>,
217 value: ExpandElementTyped<A::Output>,
218 ) where
219 A::Output: CubeType + Sized,
220 {
221 array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Sub);
222 }
223}
224
225pub mod mul_assign_array_op {
226 use self::ir::Arithmetic;
227 use super::*;
228 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
229
230 pub fn expand<A: CubeType + CubeIndex<u32>>(
231 scope: &mut Scope,
232 array: ExpandElementTyped<A>,
233 index: ExpandElementTyped<u32>,
234 value: ExpandElementTyped<A::Output>,
235 ) where
236 A::Output: CubeType + Sized,
237 {
238 array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Mul);
239 }
240}
241
242pub mod div_assign_array_op {
243 use self::ir::Arithmetic;
244 use super::*;
245 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
246
247 pub fn expand<A: CubeType + CubeIndex<u32>>(
248 scope: &mut Scope,
249 array: ExpandElementTyped<A>,
250 index: ExpandElementTyped<u32>,
251 value: ExpandElementTyped<A::Output>,
252 ) where
253 A::Output: CubeType + Sized,
254 {
255 array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Div);
256 }
257}
258
259pub mod rem_assign_array_op {
260 use self::ir::Arithmetic;
261 use super::*;
262 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
263
264 pub fn expand<A: CubeType + CubeIndex<u32>>(
265 scope: &mut Scope,
266 array: ExpandElementTyped<A>,
267 index: ExpandElementTyped<u32>,
268 value: ExpandElementTyped<A::Output>,
269 ) where
270 A::Output: CubeType + Sized,
271 {
272 array_assign_binary_op_expand(scope, array, index, value, Arithmetic::Modulo);
273 }
274}
275
276pub mod bitor_assign_array_op {
277 use super::*;
278 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
279
280 pub fn expand<A: CubeType + CubeIndex<u32>>(
281 scope: &mut Scope,
282 array: ExpandElementTyped<A>,
283 index: ExpandElementTyped<u32>,
284 value: ExpandElementTyped<A::Output>,
285 ) where
286 A::Output: CubeType + Sized,
287 {
288 array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseOr);
289 }
290}
291
292pub mod bitand_assign_array_op {
293 use super::*;
294 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
295
296 pub fn expand<A: CubeType + CubeIndex<u32>>(
297 scope: &mut Scope,
298 array: ExpandElementTyped<A>,
299 index: ExpandElementTyped<u32>,
300 value: ExpandElementTyped<A::Output>,
301 ) where
302 A::Output: CubeType + Sized,
303 {
304 array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseAnd);
305 }
306}
307
308pub mod bitxor_assign_array_op {
309 use super::*;
310 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
311
312 pub fn expand<A: CubeType + CubeIndex<u32>>(
313 scope: &mut Scope,
314 array: ExpandElementTyped<A>,
315 index: ExpandElementTyped<u32>,
316 value: ExpandElementTyped<A::Output>,
317 ) where
318 A::Output: CubeType + Sized,
319 {
320 array_assign_binary_op_expand(scope, array, index, value, Bitwise::BitwiseXor);
321 }
322}
323
324pub mod shl_assign_array_op {
325
326 use super::*;
327 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
328
329 pub fn expand<A: CubeType + CubeIndex<u32>>(
330 scope: &mut Scope,
331 array: ExpandElementTyped<A>,
332 index: ExpandElementTyped<u32>,
333 value: ExpandElementTyped<u32>,
334 ) where
335 A::Output: CubeType + Sized,
336 {
337 array_assign_binary_op_expand(scope, array, index, value, Bitwise::ShiftLeft);
338 }
339}
340
341pub mod shr_assign_array_op {
342
343 use super::*;
344 use crate::prelude::{CubeType, ExpandElementTyped, array_assign_binary_op_expand};
345
346 pub fn expand<A: CubeType + CubeIndex<u32>>(
347 scope: &mut Scope,
348 array: ExpandElementTyped<A>,
349 index: ExpandElementTyped<u32>,
350 value: ExpandElementTyped<u32>,
351 ) where
352 A::Output: CubeType + Sized,
353 {
354 array_assign_binary_op_expand(scope, array, index, value, Bitwise::ShiftRight);
355 }
356}
357
358pub mod add_assign_op {
359 use std::ops::AddAssign;
360
361 use self::ir::Arithmetic;
362 use crate::{
363 frontend::operation::base::assign_op_expand,
364 prelude::{CubeType, ExpandElementTyped},
365 };
366
367 use super::*;
368
369 pub fn expand<C: CubeType + AddAssign>(
370 scope: &mut Scope,
371 lhs: ExpandElementTyped<C>,
372 rhs: ExpandElementTyped<C>,
373 ) -> ExpandElementTyped<C> {
374 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
375 }
376}
377
378pub mod sub_assign_op {
379 use self::ir::Arithmetic;
380 use super::*;
381 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
382
383 pub fn expand<C: CubeType>(
384 scope: &mut Scope,
385 lhs: ExpandElementTyped<C>,
386 rhs: ExpandElementTyped<C>,
387 ) -> ExpandElement {
388 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Sub)
389 }
390}
391
392pub mod mul_assign_op {
393 use self::ir::Arithmetic;
394 use super::*;
395 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
396
397 pub fn expand<C: CubeType>(
398 scope: &mut Scope,
399 lhs: ExpandElementTyped<C>,
400 rhs: ExpandElementTyped<C>,
401 ) -> ExpandElement {
402 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Mul)
403 }
404}
405
406pub mod div_assign_op {
407 use self::ir::Arithmetic;
408 use super::*;
409 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
410
411 pub fn expand<C: CubeType>(
412 scope: &mut Scope,
413 lhs: ExpandElementTyped<C>,
414 rhs: ExpandElementTyped<C>,
415 ) -> ExpandElement {
416 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Div)
417 }
418}
419
420pub mod rem_assign_op {
421 use self::ir::Arithmetic;
422 use super::*;
423 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
424
425 pub fn expand<C: CubeType>(
426 scope: &mut Scope,
427 lhs: ExpandElementTyped<C>,
428 rhs: ExpandElementTyped<C>,
429 ) -> ExpandElement {
430 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Modulo)
431 }
432}
433
434pub mod bitor_assign_op {
435
436 use super::*;
437 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
438
439 pub fn expand<C: CubeType>(
440 scope: &mut Scope,
441 lhs: ExpandElementTyped<C>,
442 rhs: ExpandElementTyped<C>,
443 ) -> ExpandElement {
444 assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseOr)
445 }
446}
447
448pub mod bitand_assign_op {
449
450 use super::*;
451 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
452
453 pub fn expand<C: CubeType>(
454 scope: &mut Scope,
455 lhs: ExpandElementTyped<C>,
456 rhs: ExpandElementTyped<C>,
457 ) -> ExpandElement {
458 assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseAnd)
459 }
460}
461
462pub mod bitxor_assign_op {
463
464 use super::*;
465 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
466
467 pub fn expand<C: CubeType>(
468 scope: &mut Scope,
469 lhs: ExpandElementTyped<C>,
470 rhs: ExpandElementTyped<C>,
471 ) -> ExpandElement {
472 assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::BitwiseXor)
473 }
474}
475
476pub mod shl_assign_op {
477
478 use super::*;
479 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
480
481 pub fn expand<C: CubeType>(
482 scope: &mut Scope,
483 lhs: ExpandElementTyped<C>,
484 rhs: ExpandElementTyped<u32>,
485 ) -> ExpandElement {
486 assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftLeft)
487 }
488}
489
490pub mod shr_assign_op {
491 use super::*;
492 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
493
494 pub fn expand<C: CubeType>(
495 scope: &mut Scope,
496 lhs: ExpandElementTyped<C>,
497 rhs: ExpandElementTyped<u32>,
498 ) -> ExpandElement {
499 assign_op_expand(scope, lhs.into(), rhs.into(), Bitwise::ShiftRight)
500 }
501}
502
503pub mod add_assign {
504 use cubecl_ir::Arithmetic;
505
506 use super::*;
507 use crate::prelude::{CubePrimitive, ExpandElementTyped, assign_op_expand};
508
509 pub fn expand<C: CubePrimitive>(
510 scope: &mut Scope,
511 lhs: ExpandElementTyped<C>,
512 rhs: ExpandElementTyped<C>,
513 ) -> ExpandElementTyped<C> {
514 assign_op_expand(scope, lhs.into(), rhs.into(), Arithmetic::Add).into()
515 }
516}