// cubecl_core/ir/macros.rs

1use crate::{flex32, ir::ConstantScalarValue, tf32};
2use half::{bf16, f16};
3
4use super::{FloatKind, IntKind, UIntKind, Variable};
5
/// Cube Pseudo Assembly.
///
/// Builds IR instructions with an assembly-like syntax and registers them on the scope
/// given as the first argument, for example `cpa!(scope, out = add(lhs, rhs))` or the
/// infix sugar `cpa!(scope, out = lhs + rhs)`.
///
/// - Arms are matched top to bottom. Infix forms require a plain identifier on the left
///   of the operator; the function-call forms accept arbitrary expressions.
/// - Operands built by the `binary`/`unary` helpers are converted with `.into()`, so any
///   value with a matching `Into` conversion (such as the `From<u32> for Variable` impl in
///   this module) can be passed directly.
/// - `cpa!(binary lhs, rhs)` and `cpa!(unary input)` are internal helpers that only build
///   operator payloads; they register nothing on their own.
#[macro_export(local_inner_macros)]
macro_rules! cpa {
    // out = lhs + rhs
    ($scope:expr, $out:ident = $lhs:ident + $rhs:expr) => {
        cpa!($scope, $out = add($lhs, $rhs))
    };
    // out += input
    ($scope:expr, $out:ident += $input:ident) => {
        cpa!($scope, $out = add($out, $input))
    };
    // out = add(lhs, rhs)
    ($scope:expr, $out:ident = add($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Add(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs - rhs
    ($scope:expr, $out:ident = $lhs:ident - $rhs:expr) => {
        cpa!($scope, $out = sub($lhs, $rhs))
    };
    // out = sub(lhs, rhs)
    ($scope:expr, $out:ident = sub($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Sub(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs * rhs
    ($scope:expr, $out:ident = $lhs:ident * $rhs:expr) => {
        cpa!($scope, $out = mul($lhs, $rhs))
    };
    // out *= input
    ($scope:expr, $out:ident *= $input:ident) => {
        cpa!($scope, $out = mul($out, $input))
    };
    // out = mul(lhs, rhs)
    ($scope:expr, $out:ident = mul($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Mul(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs / rhs
    ($scope:expr, $out:ident = $lhs:ident / $rhs:expr) => {
        cpa!($scope, $out = div($lhs, $rhs))
    };
    // out = div(lhs, rhs)
    ($scope:expr, $out:ident = div($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Div(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs % rhs
    ($scope:expr, $out:ident = $lhs:ident % $rhs:expr) => {
        cpa!($scope, $out = modulo($lhs, $rhs))
    };
    // out = modulo(lhs, rhs)
    ($scope:expr, $out:ident = modulo($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Modulo(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = powf(lhs, rhs)
    ($scope:expr, $out:ident = powf($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Powf(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs && rhs
    ($scope:expr, $out:ident = $lhs:ident && $rhs:expr) => {
        cpa!($scope, $out = and($lhs, $rhs))
    };
    // out = and(lhs, rhs)
    ($scope:expr, $out:ident = and($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::And(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs || rhs
    ($scope:expr, $out:ident = $lhs:ident || $rhs:expr) => {
        cpa!($scope, $out = or($lhs, $rhs))
    };
    // out = or(lhs, rhs)
    ($scope:expr, $out:ident = or($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Or(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = !input
    ($scope:expr, $out:ident = !$input:expr) => {
        cpa!($scope, $out = not($input))
    };
    // out = not(input)
    ($scope:expr, $out:ident = not($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Not(
            cpa!(unary $input)
        ), $out));
    };
    // out = lhs & rhs
    // The `&&` arm above is tried first, so `&` here is always the bitwise operator.
    ($scope:expr, $out:ident = $lhs:ident & $rhs:expr) => {
        cpa!($scope, $out = bitwise_and($lhs, $rhs))
    };
    // out = bitwise_and(lhs, rhs)
    ($scope:expr, $out:ident = bitwise_and($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::BitwiseAnd(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs ^ rhs
    ($scope:expr, $out:ident = $lhs:ident ^ $rhs:expr) => {
        cpa!($scope, $out = bitwise_xor($lhs, $rhs))
    };
    // out = bitwise_xor(lhs, rhs)
    ($scope:expr, $out:ident = bitwise_xor($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::BitwiseXor(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = select(cond, then, or_else)
    ($scope:expr, $out:ident = select($cond:expr, $then:expr, $or_else:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Select($crate::ir::Select{
            cond: $cond,
            then: $then,
            or_else: $or_else,
        }), $out));
    };
    // out = lhs << rhs
    // Placed before the `<` comparison arm so the shift token is matched as a shift.
    ($scope:expr, $out:ident = $lhs:ident << $rhs:expr) => {
        cpa!($scope, $out = shift_left($lhs, $rhs))
    };
    // out = shift_left(lhs, rhs)
    ($scope:expr, $out:ident = shift_left($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::ShiftLeft(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs >> rhs
    // Placed before the `>` comparison arm so the shift token is matched as a shift.
    ($scope:expr, $out:ident = $lhs:ident >> $rhs:expr) => {
        cpa!($scope, $out = shift_right($lhs, $rhs))
    };
    // out = shift_right(lhs, rhs)
    ($scope:expr, $out:ident = shift_right($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::ShiftRight(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs == rhs
    ($scope:expr, $out:ident = $lhs:ident == $rhs:expr) => {
        cpa!($scope, $out = equal($lhs, $rhs))
    };
    // out = equal(lhs, rhs)
    ($scope:expr, $out:ident = equal($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Equal(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs != rhs
    ($scope:expr, $out:ident = $lhs:ident != $rhs:expr) => {
        cpa!($scope, $out = not_equal($lhs, $rhs))
    };
    // out = not_equal(lhs, rhs)
    ($scope:expr, $out:ident = not_equal($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::NotEqual(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs > rhs
    ($scope:expr, $out:ident = $lhs:ident > $rhs:expr) => {
        cpa!($scope, $out = greater($lhs, $rhs))
    };
    // out = greater(lhs, rhs)
    ($scope:expr, $out:ident = greater($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Greater(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs >= rhs
    ($scope:expr, $out:ident = $lhs:ident >= $rhs:expr) => {
        cpa!($scope, $out = greater_equal($lhs, $rhs))
    };
    // out = greater_equal(lhs, rhs)
    ($scope:expr, $out:ident = greater_equal($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::GreaterEqual(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs < rhs
    ($scope:expr, $out:ident = $lhs:ident < $rhs:expr) => {
        cpa!($scope, $out = lower($lhs, $rhs))
    };
    // out = lower(lhs, rhs)
    ($scope:expr, $out:ident = lower($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Lower(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs <= rhs
    ($scope:expr, $out:ident = $lhs:ident <= $rhs:expr) => {
        cpa!($scope, $out = lower_equal($lhs, $rhs))
    };
    // out = lower_equal(lhs, rhs)
    ($scope:expr, $out:ident = lower_equal($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::LowerEqual(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = max(lhs, rhs)
    ($scope:expr, $out:ident = max($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Max(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = min(lhs, rhs)
    ($scope:expr, $out:ident = min($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Min(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs[rhs]
    ($scope:expr, $out:ident = $lhs:ident[$rhs:expr]) => {
        cpa!($scope, $out = index($lhs, $rhs))
    };
    // out = index(lhs, rhs)
    ($scope:expr, $out:ident = index($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Index(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = unchecked(lhs[rhs])
    ($scope:expr, $out:ident = unchecked($lhs:ident[$rhs:expr])) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::UncheckedIndex(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out[lhs] = rhs
    ($scope:expr, $out:ident[$lhs:ident] = $rhs:expr) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::IndexAssign(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // unchecked(out[lhs]) = rhs
    ($scope:expr, unchecked($out:ident[$lhs:ident]) = $rhs:expr) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::UncheckedIndexAssign(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = |input|
    ($scope:expr, $out:ident = |$input:ident|) => {
        cpa!($scope, $out = abs($input))
    };
    // out = abs(input)
    ($scope:expr, $out:ident = abs($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Abs(
            cpa!(unary $input)
        ), $out));
    };
    // out = exp(input)
    ($scope:expr, $out:ident = exp($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Exp(
            cpa!(unary $input)
        ), $out));
    };
    // out = log(input)
    ($scope:expr, $out:ident = log($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Log(
            cpa!(unary $input)
        ), $out));
    };
    // out = log1p(input)
    ($scope:expr, $out:ident = log1p($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Log1p(
            cpa!(unary $input)
        ), $out));
    };
    // out = cos(input)
    ($scope:expr, $out:ident = cos($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Cos(
            cpa!(unary $input)
        ), $out));
    };
    // out = normalize(input)
    ($scope:expr, $out:ident = normalize($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Normalize(
            cpa!(unary $input)
        ), $out));
    };
    // out = sin(input)
    ($scope:expr, $out:ident = sin($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Sin(
            cpa!(unary $input)
        ), $out));
    };
    // out = tanh(input)
    ($scope:expr, $out:ident = tanh($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Tanh(
            cpa!(unary $input)
        ), $out));
    };
    // out = sqrt(input)
    ($scope:expr, $out:ident = sqrt($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Sqrt(
            cpa!(unary $input)
        ), $out));
    };
    // out = floor(input)
    ($scope:expr, $out:ident = floor($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Floor(
            cpa!(unary $input)
        ), $out));
    };
    // out = ceil(input)
    ($scope:expr, $out:ident = ceil($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Ceil(
            cpa!(unary $input)
        ), $out));
    };
    // out = erf(input)
    ($scope:expr, $out:ident = erf($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Erf(
            cpa!(unary $input)
        ), $out));
    };
    // out = input (plain copy)
    // This arm catches every `out = <identifier>` form, so a conversion must be
    // spelled explicitly as `cast(input)` (see the arm below).
    ($scope:expr, $out:ident = $input:ident) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operation::Copy(
            $input
        ), $out));
    };
    // out = vec4(a, b, c, d): writes `a`..`d` into `out` at indices 0, 1, 2, 3.
    ($scope:expr, $out:ident = vec4($a:ident, $b:ident, $c:ident, $d:ident)) => {
        // Fully qualified so the expansion does not depend on the caller importing
        // `Elem`; the unsigned variant carries its kind as a payload.
        let i = $scope.zero($crate::ir::Elem::UInt($crate::ir::UIntKind::U32));
        cpa!($scope, $out[i] = $a);
        cpa!($scope, i = i + 1u32);
        cpa!($scope, $out[i] = $b);
        cpa!($scope, i = i + 1u32);
        cpa!($scope, $out[i] = $c);
        cpa!($scope, i = i + 1u32);
        cpa!($scope, $out[i] = $d);
    };
    // out = cast(input)
    ($scope:expr, $out:ident = cast($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Cast(
            cpa!(unary $input)
        ), $out));
    };
    // out = shape(tensor, dim)
    ($scope:expr, $out:ident = shape($input:expr, $dim:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Metadata::Shape {
            dim: $dim.into(),
            var: $input.into(),
        }, $out));
    };
    // out = stride(tensor, dim)
    ($scope:expr, $out:ident = stride($input:expr, $dim:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Metadata::Stride {
            dim: $dim.into(),
            var: $input.into(),
        }, $out));
    };
    // out = len(array)
    ($scope:expr, $out:ident = len($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Metadata::Length {
            var: $input.into(),
        }, $out));
    };
    // out = buffer_len(array)
    ($scope:expr, $out:ident = buffer_len($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Metadata::BufferLength {
            var: $input.into(),
        }, $out));
    };
    // range(start, end).for_each(|i, scope| { ... })
    ($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => {
        $crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), None, false, $arg);
    };
    // range(start, end, unroll).for_each(|i, scope| { ... })
    ($scope:expr, range($start:expr, $end:expr, $unroll:expr).for_each($arg:expr)) => {
        if $unroll {
            $crate::ir::UnrolledRangeLoop::register($scope, $start.into(), $end.into(), None, false, $arg);
        } else {
            $crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), None, false, $arg);
        }
    };
    // range_stepped(start, end, step).for_each(|i, scope| { ... })
    // Stepped loops use their own keyword: a three-argument `range(..)` is already
    // claimed by the `unroll` arm above, so a stepped `range(..)` arm would never match.
    ($scope:expr, range_stepped($start:expr, $end:expr, $step:expr).for_each($arg:expr)) => {
        $crate::ir::RangeLoop::register(
            $scope,
            $start.into(),
            $end.into(),
            Some($step.into()),
            false,
            $arg,
        );
    };
    // range_stepped(start, end, step, unroll).for_each(|i, scope| { ... })
    ($scope:expr, range_stepped($start:expr, $end:expr, $step:expr, $unroll:expr).for_each($arg:expr)) => {
        if $unroll {
            $crate::ir::UnrolledRangeLoop::register(
                $scope,
                $start.into(),
                $end.into(),
                Some($step.into()),
                false,
                $arg,
            );
        } else {
            $crate::ir::RangeLoop::register(
                $scope,
                $start.into(),
                $end.into(),
                Some($step.into()),
                false,
                $arg,
            );
        }
    };
    // range(start, end, step, unroll).for_each(|i, scope| { ... })
    // Older spelling of the four-argument stepped loop; forwards to `range_stepped`.
    ($scope:expr, range($start:expr, $end:expr, $step:expr, $unroll:expr).for_each($arg:expr)) => {
        cpa!($scope, range_stepped($start, $end, $step, $unroll).for_each($arg))
    };
    // loop(|scope| { ... })
    ($scope:expr, loop($arg:expr)) => {
        $crate::ir::Loop::register($scope, $arg);
    };
    // if (cond).then(|scope| { ... })
    ($scope:expr, if ($cond:expr).then($arg:expr)) => {
        $crate::ir::If::register($scope, $cond.into(), $arg);
    };
    // if (cond).then(|scope| { ... }).else(|scope| { ... })
    ($scope:expr, if ($cond:expr).then($arg_if:expr).else($arg_else:expr)) => {
        $crate::ir::IfElse::register($scope, $cond.into(), $arg_if, $arg_else);
    };
    // Internal helper: two-operand payload, each operand converted with `.into()`.
    (binary $lhs:expr, $rhs:expr) => {
        $crate::ir::BinaryOperator {
            lhs: $lhs.into(),
            rhs: $rhs.into(),
        }
    };
    // Internal helper: one-operand payload, converted with `.into()`.
    (unary $input:expr) => {
        $crate::ir::UnaryOperator {
            input: $input.into(),
        }
    };
}
427
428impl From<bool> for Variable {
429    fn from(value: bool) -> Self {
430        Variable::constant(ConstantScalarValue::Bool(value))
431    }
432}
433
434impl From<i8> for Variable {
435    fn from(value: i8) -> Self {
436        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I8))
437    }
438}
439
440impl From<i16> for Variable {
441    fn from(value: i16) -> Self {
442        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I16))
443    }
444}
445
446impl From<i32> for Variable {
447    fn from(value: i32) -> Self {
448        Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I32))
449    }
450}
451
452impl From<i64> for Variable {
453    fn from(value: i64) -> Self {
454        Variable::constant(ConstantScalarValue::Int(value, IntKind::I64))
455    }
456}
457
458impl From<f16> for Variable {
459    fn from(value: f16) -> Self {
460        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::F16))
461    }
462}
463
464impl From<bf16> for Variable {
465    fn from(value: bf16) -> Self {
466        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::BF16))
467    }
468}
469
470impl From<flex32> for Variable {
471    fn from(value: flex32) -> Self {
472        Variable::constant(ConstantScalarValue::Float(
473            value.to_f64(),
474            FloatKind::Flex32,
475        ))
476    }
477}
478
479impl From<tf32> for Variable {
480    fn from(value: tf32) -> Self {
481        Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::TF32))
482    }
483}
484
485impl From<f32> for Variable {
486    fn from(value: f32) -> Self {
487        Variable::constant(ConstantScalarValue::Float(value as f64, FloatKind::F32))
488    }
489}
490
491impl From<f64> for Variable {
492    fn from(value: f64) -> Self {
493        Variable::constant(ConstantScalarValue::Float(value, FloatKind::F64))
494    }
495}
496
497impl From<u8> for Variable {
498    fn from(value: u8) -> Self {
499        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U8))
500    }
501}
502
503impl From<u16> for Variable {
504    fn from(value: u16) -> Self {
505        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U16))
506    }
507}
508
509impl From<u32> for Variable {
510    fn from(value: u32) -> Self {
511        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
512    }
513}
514
515impl From<u64> for Variable {
516    fn from(value: u64) -> Self {
517        Variable::constant(ConstantScalarValue::UInt(value, UIntKind::U64))
518    }
519}
520
521impl From<usize> for Variable {
522    fn from(value: usize) -> Self {
523        Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
524    }
525}
526
527pub(crate) use cpa;