1use half::{bf16, f16};
2
3use crate::{
4 frontend::{Array, CubeContext, ExpandElement, SharedMemory, Tensor},
5 prelude::{CubeIndex, CubeIndexMut, CubeType},
6};
7use crate::{ir, prelude::Index};
8
9pub mod cast {
10 use ir::Instruction;
11
12 use crate::prelude::ExpandElementTyped;
13
14 use self::ir::{Operator, UnaryOperator};
15
16 use super::*;
17
18 pub fn expand<C: CubeType>(
19 context: &mut CubeContext,
20 input: ExpandElementTyped<C>,
21 output: ExpandElementTyped<C>,
22 ) {
23 context.register(Instruction::new(
24 Operator::Cast(UnaryOperator {
25 input: *input.expand,
26 }),
27 *output.expand,
28 ));
29 }
30}
31
32pub mod assign {
33 use ir::{Instruction, Operation};
34
35 use crate::prelude::ExpandElementTyped;
36
37 use super::*;
38
39 pub fn expand<C: CubeType>(
40 context: &mut CubeContext,
41 input: ExpandElementTyped<C>,
42 output: ExpandElementTyped<C>,
43 ) {
44 context.register(Instruction::new(
45 Operation::Copy(*input.expand),
46 *output.expand,
47 ));
48 }
49}
50
51pub mod index_assign {
52 use ir::{Instruction, UIntKind, VariableKind};
53
54 use crate::{
55 flex32,
56 frontend::CubeType,
57 prelude::{ExpandElementTyped, SliceMut},
58 tf32,
59 };
60
61 use self::ir::{BinaryOperator, Operator, Variable};
62
63 use super::*;
64
65 pub fn expand<A: CubeType + CubeIndex<u32>>(
66 context: &mut CubeContext,
67 array: ExpandElementTyped<A>,
68 index: ExpandElementTyped<u32>,
69 value: ExpandElementTyped<A::Output>,
70 ) where
71 A::Output: CubeType + Sized,
72 {
73 let index: Variable = index.expand.into();
74 let index = match index.kind {
75 VariableKind::ConstantScalar(value) => {
76 Variable::constant(ir::ConstantScalarValue::UInt(value.as_u64(), UIntKind::U32))
77 }
78 _ => index,
79 };
80 context.register(Instruction::new(
81 Operator::IndexAssign(BinaryOperator {
82 lhs: index,
83 rhs: value.expand.into(),
84 }),
85 array.expand.into(),
86 ));
87 }
88
89 macro_rules! impl_index {
90 ($type:ident) => {
91 impl<E: CubeType, I: Index> CubeIndexMut<I> for $type<E> {}
92 };
93 }
94 macro_rules! impl_index_vec {
95 ($($type:ident),*) => {
96 $(
97 impl<I: Index> CubeIndexMut<I> for $type {}
98 )*
99 };
100 }
101
102 impl_index!(Array);
103 impl_index!(Tensor);
104 impl_index!(SharedMemory);
105 impl_index_vec!(i64, i32, i16, i8, f16, bf16, flex32, tf32, f32, f64, u64, u32, u16, u8);
106
107 impl<E: CubeType, I: Index> CubeIndexMut<I> for SliceMut<E> {}
108}
109
110pub mod index {
111 use ir::{UIntKind, VariableKind};
112
113 use crate::{
114 flex32,
115 frontend::{
116 operation::base::{binary_expand, binary_expand_no_vec},
117 CubeType,
118 },
119 prelude::{ExpandElementTyped, Slice, SliceMut},
120 tf32,
121 };
122
123 use self::ir::{Operator, Variable};
124
125 use super::*;
126
127 pub fn expand<A: CubeType + CubeIndex<ExpandElementTyped<u32>>>(
128 context: &mut CubeContext,
129 array: ExpandElementTyped<A>,
130 index: ExpandElementTyped<u32>,
131 ) -> ExpandElementTyped<A::Output>
132 where
133 A::Output: CubeType + Sized,
134 {
135 let index: ExpandElement = index.into();
136 let index_var: Variable = *index;
137 let index = match index_var.kind {
138 VariableKind::ConstantScalar(value) => ExpandElement::Plain(Variable::constant(
139 ir::ConstantScalarValue::UInt(value.as_u64(), UIntKind::U32),
140 )),
141 _ => index,
142 };
143 let array: ExpandElement = array.into();
144 let var: Variable = *array;
145 let var = match var.kind {
146 VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. } => {
147 binary_expand_no_vec(context, array, index, Operator::Index)
148 }
149 _ => binary_expand(context, array, index, Operator::Index),
150 };
151
152 ExpandElementTyped::new(var)
153 }
154
155 macro_rules! impl_index {
156 ($type:ident) => {
157 impl<E: CubeType, I: Index> CubeIndex<I> for $type<E> {
158 type Output = E;
159 }
160 };
161 }
162 macro_rules! impl_index_vec {
163 ($($type:ident),*) => {
164 $(
165 impl<I: Index> CubeIndex<I> for $type {
166 type Output = Self;
167 }
168 )*
169 };
170 }
171
172 impl_index!(Array);
173 impl_index!(Tensor);
174 impl_index!(SharedMemory);
175 impl_index_vec!(i64, i32, i16, i8, f16, flex32, tf32, bf16, f32, f64, u64, u32, u16, u8);
176
177 impl<E: CubeType, I: Index> CubeIndex<I> for Slice<E> {
178 type Output = E;
179 }
180
181 impl<E: CubeType, I: Index> CubeIndex<I> for SliceMut<E> {
182 type Output = E;
183 }
184}
185
186pub mod add_assign_array_op {
187 use self::ir::Operator;
188 use super::*;
189 use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
190
191 pub fn expand<A: CubeType + CubeIndex<u32>>(
192 context: &mut CubeContext,
193 array: ExpandElementTyped<A>,
194 index: ExpandElementTyped<u32>,
195 value: ExpandElementTyped<A::Output>,
196 ) where
197 A::Output: CubeType + Sized,
198 {
199 array_assign_binary_op_expand(context, array, index, value, Operator::Add);
200 }
201}
202
203pub mod sub_assign_array_op {
204 use self::ir::Operator;
205 use super::*;
206 use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
207
208 pub fn expand<A: CubeType + CubeIndex<u32>>(
209 context: &mut CubeContext,
210 array: ExpandElementTyped<A>,
211 index: ExpandElementTyped<u32>,
212 value: ExpandElementTyped<A::Output>,
213 ) where
214 A::Output: CubeType + Sized,
215 {
216 array_assign_binary_op_expand(context, array, index, value, Operator::Sub);
217 }
218}
219
220pub mod mul_assign_array_op {
221 use self::ir::Operator;
222 use super::*;
223 use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
224
225 pub fn expand<A: CubeType + CubeIndex<u32>>(
226 context: &mut CubeContext,
227 array: ExpandElementTyped<A>,
228 index: ExpandElementTyped<u32>,
229 value: ExpandElementTyped<A::Output>,
230 ) where
231 A::Output: CubeType + Sized,
232 {
233 array_assign_binary_op_expand(context, array, index, value, Operator::Mul);
234 }
235}
236
237pub mod div_assign_array_op {
238 use self::ir::Operator;
239 use super::*;
240 use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
241
242 pub fn expand<A: CubeType + CubeIndex<u32>>(
243 context: &mut CubeContext,
244 array: ExpandElementTyped<A>,
245 index: ExpandElementTyped<u32>,
246 value: ExpandElementTyped<A::Output>,
247 ) where
248 A::Output: CubeType + Sized,
249 {
250 array_assign_binary_op_expand(context, array, index, value, Operator::Div);
251 }
252}
253
254pub mod rem_assign_array_op {
255 use self::ir::Operator;
256 use super::*;
257 use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
258
259 pub fn expand<A: CubeType + CubeIndex<u32>>(
260 context: &mut CubeContext,
261 array: ExpandElementTyped<A>,
262 index: ExpandElementTyped<u32>,
263 value: ExpandElementTyped<A::Output>,
264 ) where
265 A::Output: CubeType + Sized,
266 {
267 array_assign_binary_op_expand(context, array, index, value, Operator::Modulo);
268 }
269}
270
271pub mod bitor_assign_array_op {
272 use self::ir::Operator;
273 use super::*;
274 use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
275
276 pub fn expand<A: CubeType + CubeIndex<u32>>(
277 context: &mut CubeContext,
278 array: ExpandElementTyped<A>,
279 index: ExpandElementTyped<u32>,
280 value: ExpandElementTyped<A::Output>,
281 ) where
282 A::Output: CubeType + Sized,
283 {
284 array_assign_binary_op_expand(context, array, index, value, Operator::BitwiseOr);
285 }
286}
287
288pub mod bitand_assign_array_op {
289 use self::ir::Operator;
290 use super::*;
291 use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
292
293 pub fn expand<A: CubeType + CubeIndex<u32>>(
294 context: &mut CubeContext,
295 array: ExpandElementTyped<A>,
296 index: ExpandElementTyped<u32>,
297 value: ExpandElementTyped<A::Output>,
298 ) where
299 A::Output: CubeType + Sized,
300 {
301 array_assign_binary_op_expand(context, array, index, value, Operator::BitwiseAnd);
302 }
303}
304
305pub mod bitxor_assign_array_op {
306 use self::ir::Operator;
307 use super::*;
308 use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
309
310 pub fn expand<A: CubeType + CubeIndex<u32>>(
311 context: &mut CubeContext,
312 array: ExpandElementTyped<A>,
313 index: ExpandElementTyped<u32>,
314 value: ExpandElementTyped<A::Output>,
315 ) where
316 A::Output: CubeType + Sized,
317 {
318 array_assign_binary_op_expand(context, array, index, value, Operator::BitwiseXor);
319 }
320}
321
322pub mod shl_assign_array_op {
323 use self::ir::Operator;
324 use super::*;
325 use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
326
327 pub fn expand<A: CubeType + CubeIndex<u32>>(
328 context: &mut CubeContext,
329 array: ExpandElementTyped<A>,
330 index: ExpandElementTyped<u32>,
331 value: ExpandElementTyped<u32>,
332 ) where
333 A::Output: CubeType + Sized,
334 {
335 array_assign_binary_op_expand(context, array, index, value, Operator::ShiftLeft);
336 }
337}
338
339pub mod shr_assign_array_op {
340 use self::ir::Operator;
341 use super::*;
342 use crate::prelude::{array_assign_binary_op_expand, CubeType, ExpandElementTyped};
343
344 pub fn expand<A: CubeType + CubeIndex<u32>>(
345 context: &mut CubeContext,
346 array: ExpandElementTyped<A>,
347 index: ExpandElementTyped<u32>,
348 value: ExpandElementTyped<u32>,
349 ) where
350 A::Output: CubeType + Sized,
351 {
352 array_assign_binary_op_expand(context, array, index, value, Operator::ShiftRight);
353 }
354}
355
356pub mod add_assign_op {
357 use std::ops::AddAssign;
358
359 use self::ir::Operator;
360 use crate::{
361 frontend::operation::base::assign_op_expand,
362 prelude::{CubeType, ExpandElementTyped},
363 };
364
365 use super::*;
366
367 pub fn expand<C: CubeType + AddAssign>(
368 context: &mut CubeContext,
369 lhs: ExpandElementTyped<C>,
370 rhs: ExpandElementTyped<C>,
371 ) -> ExpandElementTyped<C> {
372 assign_op_expand(context, lhs.into(), rhs.into(), Operator::Add).into()
373 }
374}
375
376pub mod sub_assign_op {
377 use self::ir::Operator;
378 use super::*;
379 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
380
381 pub fn expand<C: CubeType>(
382 context: &mut CubeContext,
383 lhs: ExpandElementTyped<C>,
384 rhs: ExpandElementTyped<C>,
385 ) -> ExpandElement {
386 assign_op_expand(context, lhs.into(), rhs.into(), Operator::Sub)
387 }
388}
389
390pub mod mul_assign_op {
391 use self::ir::Operator;
392 use super::*;
393 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
394
395 pub fn expand<C: CubeType>(
396 context: &mut CubeContext,
397 lhs: ExpandElementTyped<C>,
398 rhs: ExpandElementTyped<C>,
399 ) -> ExpandElement {
400 assign_op_expand(context, lhs.into(), rhs.into(), Operator::Mul)
401 }
402}
403
404pub mod div_assign_op {
405 use self::ir::Operator;
406 use super::*;
407 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
408
409 pub fn expand<C: CubeType>(
410 context: &mut CubeContext,
411 lhs: ExpandElementTyped<C>,
412 rhs: ExpandElementTyped<C>,
413 ) -> ExpandElement {
414 assign_op_expand(context, lhs.into(), rhs.into(), Operator::Div)
415 }
416}
417
418pub mod rem_assign_op {
419 use self::ir::Operator;
420 use super::*;
421 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
422
423 pub fn expand<C: CubeType>(
424 context: &mut CubeContext,
425 lhs: ExpandElementTyped<C>,
426 rhs: ExpandElementTyped<C>,
427 ) -> ExpandElement {
428 assign_op_expand(context, lhs.into(), rhs.into(), Operator::Modulo)
429 }
430}
431
432pub mod bitor_assign_op {
433 use self::ir::Operator;
434 use super::*;
435 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
436
437 pub fn expand<C: CubeType>(
438 context: &mut CubeContext,
439 lhs: ExpandElementTyped<C>,
440 rhs: ExpandElementTyped<C>,
441 ) -> ExpandElement {
442 assign_op_expand(context, lhs.into(), rhs.into(), Operator::BitwiseOr)
443 }
444}
445
446pub mod bitand_assign_op {
447 use self::ir::Operator;
448 use super::*;
449 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
450
451 pub fn expand<C: CubeType>(
452 context: &mut CubeContext,
453 lhs: ExpandElementTyped<C>,
454 rhs: ExpandElementTyped<C>,
455 ) -> ExpandElement {
456 assign_op_expand(context, lhs.into(), rhs.into(), Operator::BitwiseAnd)
457 }
458}
459
460pub mod bitxor_assign_op {
461 use self::ir::Operator;
462 use super::*;
463 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
464
465 pub fn expand<C: CubeType>(
466 context: &mut CubeContext,
467 lhs: ExpandElementTyped<C>,
468 rhs: ExpandElementTyped<C>,
469 ) -> ExpandElement {
470 assign_op_expand(context, lhs.into(), rhs.into(), Operator::BitwiseXor)
471 }
472}
473
474pub mod shl_assign_op {
475 use self::ir::Operator;
476 use super::*;
477 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
478
479 pub fn expand<C: CubeType>(
480 context: &mut CubeContext,
481 lhs: ExpandElementTyped<C>,
482 rhs: ExpandElementTyped<u32>,
483 ) -> ExpandElement {
484 assign_op_expand(context, lhs.into(), rhs.into(), Operator::ShiftLeft)
485 }
486}
487
488pub mod shr_assign_op {
489 use self::ir::Operator;
490 use super::*;
491 use crate::{frontend::operation::base::assign_op_expand, prelude::ExpandElementTyped};
492
493 pub fn expand<C: CubeType>(
494 context: &mut CubeContext,
495 lhs: ExpandElementTyped<C>,
496 rhs: ExpandElementTyped<u32>,
497 ) -> ExpandElement {
498 assign_op_expand(context, lhs.into(), rhs.into(), Operator::ShiftRight)
499 }
500}