1use crate::{flex32, ir::ConstantScalarValue, tf32};
2use half::{bf16, f16};
3
4use super::{FloatKind, IntKind, UIntKind, Variable};
5
/// Registers IR instructions and control flow on a scope with an assignment-like syntax.
///
/// Every invocation starts with the scope expression, followed by one of the
/// forms listed on the arms below, e.g. `cpa!(scope, out = lhs + rhs)` or
/// `cpa!(scope, out = exp(input))`.
///
/// * Infix forms (`out = lhs + rhs`, `out = lhs == rhs`, ...) forward to the
///   equivalent named form (`add`, `equal`, ...).
/// * Named forms build a `$crate::ir::Instruction` from the matching operator
///   and the output, and hand it to `$scope.register(..)`.
/// * Operands of binary/unary operators go through `.into()`, so anything
///   convertible into the operator's field type is accepted.
/// * `range(..)`, `range_stepped(..)`, `loop(..)` and `if (..)` forms delegate
///   to the `register` functions of the corresponding `$crate::ir` types.
///
/// Arm order matters: `macro_rules!` tries arms from top to bottom and uses
/// the first one whose pattern matches, so two arms with the same pattern
/// leave the later one unreachable.
#[macro_export(local_inner_macros)]
macro_rules! cpa {
    // out = lhs + rhs
    ($scope:expr, $out:ident = $lhs:ident + $rhs:expr) => {
        cpa!($scope, $out = add($lhs, $rhs))
    };
    // out += input
    ($scope:expr, $out:ident += $input:ident) => {
        cpa!($scope, $out = add($out, $input))
    };
    // out = add(lhs, rhs)
    ($scope:expr, $out:ident = add($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Add(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs - rhs
    ($scope:expr, $out:ident = $lhs:ident - $rhs:expr) => {
        cpa!($scope, $out = sub($lhs, $rhs))
    };
    // out = sub(lhs, rhs)
    ($scope:expr, $out:ident = sub($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Sub(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs * rhs
    ($scope:expr, $out:ident = $lhs:ident * $rhs:expr) => {
        cpa!($scope, $out = mul($lhs, $rhs))
    };
    // out *= input
    ($scope:expr, $out:ident *= $input:ident) => {
        cpa!($scope, $out = mul($out, $input))
    };
    // out = mul(lhs, rhs)
    ($scope:expr, $out:ident = mul($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Mul(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs / rhs
    ($scope:expr, $out:ident = $lhs:ident / $rhs:expr) => {
        cpa!($scope, $out = div($lhs, $rhs))
    };
    // out = div(lhs, rhs)
    ($scope:expr, $out:ident = div($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Div(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs % rhs
    ($scope:expr, $out:ident = $lhs:ident % $rhs:expr) => {
        cpa!($scope, $out = modulo($lhs, $rhs))
    };
    // out = modulo(lhs, rhs)
    ($scope:expr, $out:ident = modulo($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Modulo(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = powf(lhs, rhs)
    ($scope:expr, $out:ident = powf($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Powf(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs && rhs
    ($scope:expr, $out:ident = $lhs:ident && $rhs:expr) => {
        cpa!($scope, $out = and($lhs, $rhs))
    };
    // out = and(lhs, rhs)
    ($scope:expr, $out:ident = and($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::And(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs || rhs
    ($scope:expr, $out:ident = $lhs:ident || $rhs:expr) => {
        cpa!($scope, $out = or($lhs, $rhs))
    };
    // out = or(lhs, rhs)
    ($scope:expr, $out:ident = or($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Or(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = !input
    ($scope:expr, $out:ident = !$input:expr) => {
        cpa!($scope, $out = not($input))
    };
    // out = not(input)
    ($scope:expr, $out:ident = not($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Not(
            cpa!(unary $input)
        ), $out));
    };
    // out = lhs & rhs
    ($scope:expr, $out:ident = $lhs:ident & $rhs:ident) => {
        cpa!($scope, $out = bitwise_and($lhs, $rhs))
    };
    // out = bitwise_and(lhs, rhs)
    ($scope:expr, $out:ident = bitwise_and($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::BitwiseAnd(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs ^ rhs
    ($scope:expr, $out:ident = $lhs:ident ^ $rhs:ident) => {
        cpa!($scope, $out = bitwise_xor($lhs, $rhs))
    };
    // out = bitwise_xor(lhs, rhs)
    ($scope:expr, $out:ident = bitwise_xor($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::BitwiseXor(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = select(cond, then, or_else)
    // Unlike the binary/unary arms, the three operands are stored as given (no `.into()`).
    ($scope:expr, $out:ident = select($cond:expr, $then:expr, $or_else:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Select($crate::ir::Select{
            cond: $cond,
            then: $then,
            or_else: $or_else,
        }), $out));
    };
    // out = lhs << rhs
    ($scope:expr, $out:ident = $lhs:ident << $rhs:ident) => {
        cpa!($scope, $out = shift_left($lhs, $rhs))
    };
    // out = shift_left(lhs, rhs)
    ($scope:expr, $out:ident = shift_left($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::ShiftLeft(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs >> rhs
    ($scope:expr, $out:ident = $lhs:ident >> $rhs:ident) => {
        cpa!($scope, $out = shift_right($lhs, $rhs))
    };
    // out = shift_right(lhs, rhs)
    ($scope:expr, $out:ident = shift_right($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::ShiftRight(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs == rhs
    ($scope:expr, $out:ident = $lhs:ident == $rhs:expr) => {
        cpa!($scope, $out = equal($lhs, $rhs))
    };
    // out = equal(lhs, rhs)
    ($scope:expr, $out:ident = equal($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Equal(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs != rhs
    ($scope:expr, $out:ident = $lhs:ident != $rhs:expr) => {
        cpa!($scope, $out = not_equal($lhs, $rhs))
    };
    // out = not_equal(lhs, rhs)
    ($scope:expr, $out:ident = not_equal($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::NotEqual(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs > rhs
    ($scope:expr, $out:ident = $lhs:ident > $rhs:expr) => {
        cpa!($scope, $out = greater($lhs, $rhs))
    };
    // out = greater(lhs, rhs)
    ($scope:expr, $out:ident = greater($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Greater(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs >= rhs
    ($scope:expr, $out:ident = $lhs:ident >= $rhs:expr) => {
        cpa!($scope, $out = greater_equal($lhs, $rhs))
    };
    // out = greater_equal(lhs, rhs)
    ($scope:expr, $out:ident = greater_equal($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::GreaterEqual(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs < rhs
    ($scope:expr, $out:ident = $lhs:ident < $rhs:expr) => {
        cpa!($scope, $out = lower($lhs, $rhs))
    };
    // out = lower(lhs, rhs)
    ($scope:expr, $out:ident = lower($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Lower(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs <= rhs
    ($scope:expr, $out:ident = $lhs:ident <= $rhs:expr) => {
        cpa!($scope, $out = lower_equal($lhs, $rhs))
    };
    // out = lower_equal(lhs, rhs)
    ($scope:expr, $out:ident = lower_equal($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::LowerEqual(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = max(lhs, rhs)
    ($scope:expr, $out:ident = max($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Max(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = min(lhs, rhs)
    ($scope:expr, $out:ident = min($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Min(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = lhs[rhs]
    ($scope:expr, $out:ident = $lhs:ident[$rhs:expr]) => {
        cpa!($scope, $out = index($lhs, $rhs))
    };
    // out = index(lhs, rhs)
    ($scope:expr, $out:ident = index($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Index(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = unchecked(lhs[rhs])
    ($scope:expr, $out:ident = unchecked($lhs:ident[$rhs:expr])) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::UncheckedIndex(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out[lhs] = rhs
    // The index goes into the operator's `lhs`, the stored value into `rhs`.
    ($scope:expr, $out:ident[$lhs:ident] = $rhs:expr) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::IndexAssign(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // unchecked(out[lhs]) = rhs
    ($scope:expr, unchecked($out:ident[$lhs:ident]) = $rhs:expr) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::UncheckedIndexAssign(
            cpa!(binary $lhs, $rhs)
        ), $out));
    };
    // out = |input|
    ($scope:expr, $out:ident = |$input:ident|) => {
        cpa!($scope, $out = abs($input))
    };
    // out = abs(input)
    ($scope:expr, $out:ident = abs($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Abs(
            cpa!(unary $input)
        ), $out));
    };
    // out = exp(input)
    ($scope:expr, $out:ident = exp($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Exp(
            cpa!(unary $input)
        ), $out));
    };
    // out = log(input)
    ($scope:expr, $out:ident = log($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Log(
            cpa!(unary $input)
        ), $out));
    };
    // out = log1p(input)
    ($scope:expr, $out:ident = log1p($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Log1p(
            cpa!(unary $input)
        ), $out));
    };
    // out = cos(input)
    ($scope:expr, $out:ident = cos($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Cos(
            cpa!(unary $input)
        ), $out));
    };
    // out = normalize(input)
    ($scope:expr, $out:ident = normalize($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Normalize(
            cpa!(unary $input)
        ), $out));
    };
    // out = sin(input)
    ($scope:expr, $out:ident = sin($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Sin(
            cpa!(unary $input)
        ), $out));
    };
    // out = tanh(input)
    ($scope:expr, $out:ident = tanh($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Tanh(
            cpa!(unary $input)
        ), $out));
    };
    // out = sqrt(input)
    ($scope:expr, $out:ident = sqrt($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Sqrt(
            cpa!(unary $input)
        ), $out));
    };
    // out = floor(input)
    ($scope:expr, $out:ident = floor($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Floor(
            cpa!(unary $input)
        ), $out));
    };
    // out = ceil(input)
    ($scope:expr, $out:ident = ceil($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Ceil(
            cpa!(unary $input)
        ), $out));
    };
    // out = erf(input)
    ($scope:expr, $out:ident = erf($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Erf(
            cpa!(unary $input)
        ), $out));
    };
    // out = input
    // Registers a copy. This is the only arm for the bare `out = input` form:
    // a later arm with the same pattern could never match, so a cast has to be
    // requested explicitly with `out = cast(input)`.
    ($scope:expr, $out:ident = $input:ident) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operation::Copy(
            $input
        ), $out));
    };
    // out = vec4(a, b, c, d)
    // Writes the four inputs to out[0..=3], bumping a scope-allocated index in between.
    // NOTE(review): `Elem` is resolved at the call site (no `$crate` path) and
    // `Elem::UInt` is used without arguments — confirm this arm still compiles
    // where it is used.
    ($scope:expr, $out:ident = vec4($a:ident,$b:ident,$c:ident,$d:ident)) => {
        let i = $scope.zero(Elem::UInt);
        cpa!($scope, $out[i] = $a);
        cpa!($scope, i = i + 1u32);
        cpa!($scope, $out[i] = $b);
        cpa!($scope, i = i + 1u32);
        cpa!($scope, $out[i] = $c);
        cpa!($scope, i = i + 1u32);
        cpa!($scope, $out[i] = $d);
    };
    // out = cast(input)
    ($scope:expr, $out:ident = cast($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Operator::Cast(
            cpa!(unary $input)
        ), $out));
    };
    // out = shape(tensor, dim)
    ($scope:expr, $out:ident = shape($input:expr, $dim:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Metadata::Shape {
            dim: $dim.into(),
            var: $input.into(),
        }, $out));
    };
    // out = stride(tensor, dim)
    ($scope:expr, $out:ident = stride($input:expr, $dim:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Metadata::Stride {
            dim: $dim.into(),
            var: $input.into(),
        }, $out));
    };
    // out = len(input)
    ($scope:expr, $out:ident = len($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Metadata::Length {
            var: $input.into(),
        }, $out));
    };
    // out = buffer_len(input)
    ($scope:expr, $out:ident = buffer_len($input:expr)) => {
        $scope.register($crate::ir::Instruction::new($crate::ir::Metadata::BufferLength {
            var: $input.into(),
        }, $out));
    };
    // range(start, end).for_each(|i, scope| { ... })
    ($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => {
        $crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), None, false, $arg);
    };
    // range(start, end, unroll).for_each(|i, scope| { ... })
    // Both branches are type-checked, so the arguments must be valid for both loop kinds.
    ($scope:expr, range($start:expr, $end:expr, $unroll:expr).for_each($arg:expr)) => {
        if $unroll {
            $crate::ir::UnrolledRangeLoop::register(
                $scope, $start.into(), $end.into(), None, false, $arg
            );
        } else {
            $crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), None, false, $arg);
        }
    };
    // range_stepped(start, end, step).for_each(|i, scope| { ... })
    // Spelled `range_stepped` because a three-argument `range(..)` is already
    // taken by the `unroll` form above and `macro_rules!` cannot tell a step
    // from a flag. The `false` mirrors the flag the non-stepped arms pass
    // (presumably "inclusive end" — confirm against `RangeLoop::register`).
    ($scope:expr, range_stepped($start:expr, $end:expr, $step:expr).for_each($arg:expr)) => {
        $crate::ir::RangeLoop::register(
            $scope, $start.into(), $end.into(), Some($step.into()), false, $arg
        );
    };
    // range(start, end, step, unroll).for_each(|i, scope| { ... })
    ($scope:expr, range($start:expr, $end:expr, $step:expr, $unroll:expr).for_each($arg:expr)) => {
        if $unroll {
            $crate::ir::UnrolledRangeLoop::register(
                $scope, $start.into(), $end.into(), Some($step.into()), false, $arg
            );
        } else {
            $crate::ir::RangeLoop::register(
                $scope, $start.into(), $end.into(), Some($step.into()), false, $arg
            );
        }
    };
    // loop(|scope| { ... })
    ($scope:expr, loop($arg:expr)) => {
        $crate::ir::Loop::register($scope, $arg);
    };
    // if (cond).then(|scope| { ... })
    ($scope:expr, if ($cond:expr).then($arg:expr)) => {
        $crate::ir::If::register($scope, $cond.into(), $arg);
    };
    // if (cond).then(|scope| { ... }).else(|scope| { ... })
    ($scope:expr, if ($cond:expr).then($arg_if:expr).else($arg_else:expr)) => {
        $crate::ir::IfElse::register($scope, $cond.into(), $arg_if, $arg_else);
    };
    // Internal helper: builds the operand pair shared by all binary operators.
    (binary $lhs:expr, $rhs:expr) => {
        $crate::ir::BinaryOperator {
            lhs: $lhs.into(),
            rhs: $rhs.into(),
        }
    };
    // Internal helper: builds the single operand shared by all unary operators.
    (unary $input:expr) => {
        $crate::ir::UnaryOperator {
            input: $input.into(),
        }
    };
}
427
428impl From<bool> for Variable {
429 fn from(value: bool) -> Self {
430 Variable::constant(ConstantScalarValue::Bool(value))
431 }
432}
433
434impl From<i8> for Variable {
435 fn from(value: i8) -> Self {
436 Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I8))
437 }
438}
439
440impl From<i16> for Variable {
441 fn from(value: i16) -> Self {
442 Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I16))
443 }
444}
445
446impl From<i32> for Variable {
447 fn from(value: i32) -> Self {
448 Variable::constant(ConstantScalarValue::Int(value as i64, IntKind::I32))
449 }
450}
451
452impl From<i64> for Variable {
453 fn from(value: i64) -> Self {
454 Variable::constant(ConstantScalarValue::Int(value, IntKind::I64))
455 }
456}
457
458impl From<f16> for Variable {
459 fn from(value: f16) -> Self {
460 Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::F16))
461 }
462}
463
464impl From<bf16> for Variable {
465 fn from(value: bf16) -> Self {
466 Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::BF16))
467 }
468}
469
470impl From<flex32> for Variable {
471 fn from(value: flex32) -> Self {
472 Variable::constant(ConstantScalarValue::Float(
473 value.to_f64(),
474 FloatKind::Flex32,
475 ))
476 }
477}
478
479impl From<tf32> for Variable {
480 fn from(value: tf32) -> Self {
481 Variable::constant(ConstantScalarValue::Float(value.to_f64(), FloatKind::TF32))
482 }
483}
484
485impl From<f32> for Variable {
486 fn from(value: f32) -> Self {
487 Variable::constant(ConstantScalarValue::Float(value as f64, FloatKind::F32))
488 }
489}
490
491impl From<f64> for Variable {
492 fn from(value: f64) -> Self {
493 Variable::constant(ConstantScalarValue::Float(value, FloatKind::F64))
494 }
495}
496
497impl From<u8> for Variable {
498 fn from(value: u8) -> Self {
499 Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U8))
500 }
501}
502
503impl From<u16> for Variable {
504 fn from(value: u16) -> Self {
505 Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U16))
506 }
507}
508
509impl From<u32> for Variable {
510 fn from(value: u32) -> Self {
511 Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
512 }
513}
514
515impl From<u64> for Variable {
516 fn from(value: u64) -> Self {
517 Variable::constant(ConstantScalarValue::UInt(value, UIntKind::U64))
518 }
519}
520
521impl From<usize> for Variable {
522 fn from(value: usize) -> Self {
523 Variable::constant(ConstantScalarValue::UInt(value as u64, UIntKind::U32))
524 }
525}
526
527pub(crate) use cpa;