1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
use crate::ir::ConstantScalarValue;

use super::{FloatKind, IntKind, Variable};

/// Cube Pseudo Assembly (`cpa!`): a compact DSL for emitting IR into a scope.
///
/// The first argument is always the scope; the second is a pseudo-assembly
/// statement. Each statement expands either to one or more `$scope.register(...)`
/// calls, or to a `register` call on one of the control-flow helpers
/// (`RangeLoop`, `UnrolledRangeLoop`, `Loop`, `If`, `IfElse`). Operands are passed
/// through `.into()`, so literals such as `1u32` are accepted wherever an
/// expression operand is allowed.
///
/// ```text
/// // Binary operators (the left operand of an infix form must be an identifier):
/// cpa!(scope, out = lhs + rhs);     // also -, *, /, %, &&, ||, &, ^, <<, >>
/// cpa!(scope, out = lhs == rhs);    // also !=, >, >=, <, <=
/// cpa!(scope, out += input);        // also *=
/// cpa!(scope, out = add(lhs, rhs)); // named forms: add, sub, mul, div, modulo, powf,
///                                   // and, or, bitwise_and, bitwise_xor, shift_left,
///                                   // shift_right, equal, not_equal, greater,
///                                   // greater_equal, lower, lower_equal, max, min
/// // Unary operators:
/// cpa!(scope, out = !input);        // or not(input)
/// cpa!(scope, out = |input|);       // or abs(input)
/// cpa!(scope, out = exp(input));    // also log, log1p, cos, sin, tanh, sqrt, floor, ceil, erf
/// cpa!(scope, out = input);         // or cast(input); both emit `Assign`
/// // Indexing:
/// cpa!(scope, out = lhs[index]);    // or index(lhs, index), unchecked(lhs[index])
/// cpa!(scope, out[index] = value);  // or unchecked(out[index]) = value
/// cpa!(scope, out = vec4(a, b, c, d));
/// // Metadata:
/// cpa!(scope, out = shape(tensor, dim)); // also stride(tensor, dim), len(array)
/// // Control flow:
/// cpa!(scope, range(start, end).for_each(|i, scope| { ... }));
/// cpa!(scope, range(start, end, unroll).for_each(|i, scope| { ... }));
/// cpa!(scope, range_stepped(start, end, step).for_each(|i, scope| { ... }));
/// cpa!(scope, range_stepped(start, end, step, unroll).for_each(|i, scope| { ... }));
/// cpa!(scope, loop(|scope| { ... }));
/// cpa!(scope, if (cond).then(|scope| { ... }));
/// cpa!(scope, if (cond).then(|scope| { ... }).else(|scope| { ... }));
/// ```
///
/// `local_inner_macros` rewrites the recursive `cpa!` calls below to
/// `$crate::cpa!`, so they resolve regardless of what the caller has imported.
#[macro_export(local_inner_macros)]
macro_rules! cpa {
    // out = lhs + rhs
    ($scope:expr, $out:ident = $lhs:ident + $rhs:expr) => {
        cpa!($scope, $out = add($lhs, $rhs))
    };
    // out += input
    ($scope:expr, $out:ident += $input:expr) => {
        cpa!($scope, $out = add($out, $input))
    };
    // out = add(lhs, rhs)
    ($scope:expr, $out:ident = add($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::Add(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = lhs - rhs
    ($scope:expr, $out:ident = $lhs:ident - $rhs:expr) => {
        cpa!($scope, $out = sub($lhs, $rhs))
    };
    // out = sub(lhs, rhs)
    ($scope:expr, $out:ident = sub($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::Sub(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = lhs * rhs
    ($scope:expr, $out:ident = $lhs:ident * $rhs:expr) => {
        cpa!($scope, $out = mul($lhs, $rhs))
    };
    // out *= input
    ($scope:expr, $out:ident *= $input:expr) => {
        cpa!($scope, $out = mul($out, $input))
    };
    // out = mul(lhs, rhs)
    ($scope:expr, $out:ident = mul($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::Mul(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = lhs / rhs
    ($scope:expr, $out:ident = $lhs:ident / $rhs:expr) => {
        cpa!($scope, $out = div($lhs, $rhs))
    };
    // out = div(lhs, rhs)
    ($scope:expr, $out:ident = div($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::Div(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = lhs % rhs
    ($scope:expr, $out:ident = $lhs:ident % $rhs:expr) => {
        cpa!($scope, $out = modulo($lhs, $rhs))
    };
    // out = modulo(lhs, rhs)
    ($scope:expr, $out:ident = modulo($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::Modulo(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = powf(lhs, rhs)
    ($scope:expr, $out:ident = powf($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::Powf(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = lhs && rhs
    ($scope:expr, $out:ident = $lhs:ident && $rhs:expr) => {
        cpa!($scope, $out = and($lhs, $rhs))
    };
    // out = and(lhs, rhs)
    ($scope:expr, $out:ident = and($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::And(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = lhs || rhs
    ($scope:expr, $out:ident = $lhs:ident || $rhs:expr) => {
        cpa!($scope, $out = or($lhs, $rhs))
    };
    // out = or(lhs, rhs)
    ($scope:expr, $out:ident = or($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::Or(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = !input
    ($scope:expr, $out:ident = !$input:expr) => {
        cpa!($scope, $out = not($input))
    };
    // out = not(input)
    ($scope:expr, $out:ident = not($input:expr)) => {
        $scope.register($crate::ir::Operator::Not(
            cpa!(unary $input, $out)
        ));
    };
    // out = lhs & rhs
    ($scope:expr, $out:ident = $lhs:ident & $rhs:expr) => {
        cpa!($scope, $out = bitwise_and($lhs, $rhs))
    };
    // out = bitwise_and(lhs, rhs)
    ($scope:expr, $out:ident = bitwise_and($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::BitwiseAnd(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = lhs ^ rhs
    ($scope:expr, $out:ident = $lhs:ident ^ $rhs:expr) => {
        cpa!($scope, $out = bitwise_xor($lhs, $rhs))
    };
    // out = bitwise_xor(lhs, rhs)
    ($scope:expr, $out:ident = bitwise_xor($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::BitwiseXor(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = lhs << rhs (must stay before the `<` arm so `<<` is not read as `<`)
    ($scope:expr, $out:ident = $lhs:ident << $rhs:expr) => {
        cpa!($scope, $out = shift_left($lhs, $rhs))
    };
    // out = shift_left(lhs, rhs)
    ($scope:expr, $out:ident = shift_left($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::ShiftLeft(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = lhs >> rhs (must stay before the `>` arm so `>>` is not read as `>`)
    ($scope:expr, $out:ident = $lhs:ident >> $rhs:expr) => {
        cpa!($scope, $out = shift_right($lhs, $rhs))
    };
    // out = shift_right(lhs, rhs)
    ($scope:expr, $out:ident = shift_right($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::ShiftRight(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = lhs == rhs
    ($scope:expr, $out:ident = $lhs:ident == $rhs:expr) => {
        cpa!($scope, $out = equal($lhs, $rhs))
    };
    // out = equal(lhs, rhs)
    ($scope:expr, $out:ident = equal($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::Equal(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = lhs != rhs
    ($scope:expr, $out:ident = $lhs:ident != $rhs:expr) => {
        cpa!($scope, $out = not_equal($lhs, $rhs))
    };
    // out = not_equal(lhs, rhs)
    ($scope:expr, $out:ident = not_equal($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::NotEqual(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = lhs > rhs
    ($scope:expr, $out:ident = $lhs:ident > $rhs:expr) => {
        cpa!($scope, $out = greater($lhs, $rhs))
    };
    // out = greater(lhs, rhs)
    ($scope:expr, $out:ident = greater($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::Greater(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = lhs >= rhs
    ($scope:expr, $out:ident = $lhs:ident >= $rhs:expr) => {
        cpa!($scope, $out = greater_equal($lhs, $rhs))
    };
    // out = greater_equal(lhs, rhs)
    ($scope:expr, $out:ident = greater_equal($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::GreaterEqual(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = lhs < rhs
    ($scope:expr, $out:ident = $lhs:ident < $rhs:expr) => {
        cpa!($scope, $out = lower($lhs, $rhs))
    };
    // out = lower(lhs, rhs)
    ($scope:expr, $out:ident = lower($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::Lower(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = lhs <= rhs
    ($scope:expr, $out:ident = $lhs:ident <= $rhs:expr) => {
        cpa!($scope, $out = lower_equal($lhs, $rhs))
    };
    // out = lower_equal(lhs, rhs)
    ($scope:expr, $out:ident = lower_equal($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::LowerEqual(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = max(lhs, rhs)
    ($scope:expr, $out:ident = max($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::Max(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = min(lhs, rhs)
    ($scope:expr, $out:ident = min($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::Min(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = lhs[rhs]
    ($scope:expr, $out:ident = $lhs:ident[$rhs:expr]) => {
        cpa!($scope, $out = index($lhs, $rhs))
    };
    // out = index(lhs, rhs)
    ($scope:expr, $out:ident = index($lhs:expr, $rhs:expr)) => {
        $scope.register($crate::ir::Operator::Index(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = unchecked(lhs[rhs])
    ($scope:expr, $out:ident = unchecked($lhs:ident[$rhs:expr])) => {
        $scope.register($crate::ir::Operator::UncheckedIndex(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out[lhs] = rhs  (lhs is the index, rhs the stored value)
    ($scope:expr, $out:ident[$lhs:ident] = $rhs:expr) => {
        $scope.register($crate::ir::Operator::IndexAssign(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // unchecked(out[lhs]) = rhs
    ($scope:expr, unchecked($out:ident[$lhs:ident]) = $rhs:expr) => {
        $scope.register($crate::ir::Operator::UncheckedIndexAssign(
            cpa!(binary $lhs, $rhs, $out)
        ));
    };
    // out = |input|
    ($scope:expr, $out:ident = |$input:ident|) => {
        cpa!($scope, $out = abs($input))
    };
    // out = abs(input)
    ($scope:expr, $out:ident = abs($input:expr)) => {
        $scope.register($crate::ir::Operator::Abs(
            cpa!(unary $input, $out)
        ));
    };
    // out = exp(input)
    ($scope:expr, $out:ident = exp($input:expr)) => {
        $scope.register($crate::ir::Operator::Exp(
            cpa!(unary $input, $out)
        ));
    };
    // out = log(input)
    ($scope:expr, $out:ident = log($input:expr)) => {
        $scope.register($crate::ir::Operator::Log(
            cpa!(unary $input, $out)
        ));
    };
    // out = log1p(input)
    ($scope:expr, $out:ident = log1p($input:expr)) => {
        $scope.register($crate::ir::Operator::Log1p(
            cpa!(unary $input, $out)
        ));
    };
    // out = cos(input)
    ($scope:expr, $out:ident = cos($input:expr)) => {
        $scope.register($crate::ir::Operator::Cos(
            cpa!(unary $input, $out)
        ));
    };
    // out = sin(input)
    ($scope:expr, $out:ident = sin($input:expr)) => {
        $scope.register($crate::ir::Operator::Sin(
            cpa!(unary $input, $out)
        ));
    };
    // out = tanh(input)
    ($scope:expr, $out:ident = tanh($input:expr)) => {
        $scope.register($crate::ir::Operator::Tanh(
            cpa!(unary $input, $out)
        ));
    };
    // out = sqrt(input)
    ($scope:expr, $out:ident = sqrt($input:expr)) => {
        $scope.register($crate::ir::Operator::Sqrt(
            cpa!(unary $input, $out)
        ));
    };
    // out = floor(input)
    ($scope:expr, $out:ident = floor($input:expr)) => {
        $scope.register($crate::ir::Operator::Floor(
            cpa!(unary $input, $out)
        ));
    };
    // out = ceil(input)
    ($scope:expr, $out:ident = ceil($input:expr)) => {
        $scope.register($crate::ir::Operator::Ceil(
            cpa!(unary $input, $out)
        ));
    };
    // out = erf(input)
    ($scope:expr, $out:ident = erf($input:expr)) => {
        $scope.register($crate::ir::Operator::Erf(
            cpa!(unary $input, $out)
        ));
    };
    // out = input
    ($scope:expr, $out:ident = $input:ident) => {
        $scope.register($crate::ir::Operator::Assign(
            cpa!(unary $input, $out)
        ));
    };
    // out = vec4(a, b, c, d)
    // Writes a, b, c, d into out[0..4] through a counter that starts at
    // `scope.zero(UInt)`. `$scope` is re-evaluated for every emitted
    // instruction, so pass a plain binding rather than an expression with
    // side effects. `Elem` is path-qualified because `macro_rules!` resolves
    // item paths at the call site.
    ($scope:expr, $out:ident = vec4($a:ident,$b:ident,$c:ident,$d:ident)) => {
        let i = $scope.zero($crate::ir::Elem::UInt);
        cpa!($scope, $out[i] = $a);
        cpa!($scope, i = i + 1u32);
        cpa!($scope, $out[i] = $b);
        cpa!($scope, i = i + 1u32);
        cpa!($scope, $out[i] = $c);
        cpa!($scope, i = i + 1u32);
        cpa!($scope, $out[i] = $d);
    };
    // out = cast(input)
    ($scope:expr, $out:ident = cast($input:expr)) => {
        $scope.register($crate::ir::Operator::Assign(
            cpa!(unary $input, $out)
        ));
    };
    // out = shape(tensor, dim)
    ($scope:expr, $out:ident = shape($input:expr, $dim:expr)) => {
        $scope.register($crate::ir::Metadata::Shape {
            dim: $dim.into(),
            var: $input.into(),
            out: $out.into(),
        });
    };
    // out = stride(tensor, dim)
    ($scope:expr, $out:ident = stride($input:expr, $dim:expr)) => {
        $scope.register($crate::ir::Metadata::Stride {
            dim: $dim.into(),
            var: $input.into(),
            out: $out.into(),
        });
    };
    // out = len(array)
    ($scope:expr, $out:ident = len($input:expr)) => {
        $scope.register($crate::ir::Metadata::Length {
            var: $input.into(),
            out: $out.into(),
        });
    };
    // range(start, end).for_each(|i, scope| { ... })
    ($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => {
        $crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), None, $arg);
    };
    // range(start, end, unroll).for_each(|i, scope| { ... })
    // The third argument is a `bool` unroll flag, not a step; use
    // `range_stepped` to pass a step.
    ($scope:expr, range($start:expr, $end:expr, $unroll:expr).for_each($arg:expr)) => {
        if $unroll {
            $crate::ir::UnrolledRangeLoop::register($scope, $start.into(), $end.into(), None, $arg);
        } else {
            $crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), None, $arg);
        }
    };
    // range(start, end, step, unroll).for_each(|i, scope| { ... })
    // Older spelling of `range_stepped(start, end, step, unroll)`, kept so
    // existing callers keep compiling.
    ($scope:expr, range($start:expr, $end:expr, $step:expr, $unroll:expr).for_each($arg:expr)) => {
        if $unroll {
            $crate::ir::UnrolledRangeLoop::register($scope, $start.into(), $end.into(), Some($step.into()), $arg);
        } else {
            $crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), Some($step.into()), $arg);
        }
    };
    // range_stepped(start, end, step).for_each(|i, scope| { ... })
    ($scope:expr, range_stepped($start:expr, $end:expr, $step:expr).for_each($arg:expr)) => {
        $crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), Some($step.into()), $arg);
    };
    // range_stepped(start, end, step, unroll).for_each(|i, scope| { ... })
    ($scope:expr, range_stepped($start:expr, $end:expr, $step:expr, $unroll:expr).for_each($arg:expr)) => {
        if $unroll {
            $crate::ir::UnrolledRangeLoop::register($scope, $start.into(), $end.into(), Some($step.into()), $arg);
        } else {
            $crate::ir::RangeLoop::register($scope, $start.into(), $end.into(), Some($step.into()), $arg);
        }
    };
    // loop(|scope| { ... })
    ($scope:expr, loop($arg:expr)) => {
        $crate::ir::Loop::register($scope, $arg);
    };
    // if (cond).then(|scope| { ... })
    ($scope:expr, if ($cond:expr).then($arg:expr)) => {
        $crate::ir::If::register($scope, $cond.into(), $arg);
    };
    // if (cond).then(|scope| { ... }).else(|scope| { ... })
    ($scope:expr, if ($cond:expr).then($arg_if:expr).else($arg_else:expr)) => {
        $crate::ir::IfElse::register($scope, $cond.into(), $arg_if, $arg_else);
    };
    // Internal helper: builds a `BinaryOperator` from already-matched operands.
    (binary $lhs:expr, $rhs:expr, $out:expr) => {
        $crate::ir::BinaryOperator {
            lhs: $lhs.into(),
            rhs: $rhs.into(),
            out: $out.into(),
        }
    };
    // Internal helper: builds a `UnaryOperator` from already-matched operands.
    (unary $input:expr, $out:expr) => {
        $crate::ir::UnaryOperator {
            input: $input.into(),
            out: $out.into(),
        }
    };
}

impl From<bool> for Variable {
    /// Wraps a Rust `bool` as a boolean constant scalar.
    fn from(flag: bool) -> Self {
        let constant = ConstantScalarValue::Bool(flag);
        Self::ConstantScalar(constant)
    }
}

impl From<i32> for Variable {
    /// Losslessly widens an `i32` into an integer constant tagged `IntKind::I32`.
    fn from(number: i32) -> Self {
        let widened = i64::from(number);
        Self::ConstantScalar(ConstantScalarValue::Int(widened, IntKind::I32))
    }
}

impl From<i64> for Variable {
    /// Stores an `i64` as-is in an integer constant tagged `IntKind::I64`.
    fn from(number: i64) -> Self {
        let constant = ConstantScalarValue::Int(number, IntKind::I64);
        Self::ConstantScalar(constant)
    }
}

impl From<f32> for Variable {
    /// Losslessly widens an `f32` into a float constant tagged `FloatKind::F32`.
    fn from(number: f32) -> Self {
        let widened = f64::from(number);
        Self::ConstantScalar(ConstantScalarValue::Float(widened, FloatKind::F32))
    }
}

impl From<f64> for Variable {
    /// Stores an `f64` as-is in a float constant tagged `FloatKind::F64`.
    fn from(number: f64) -> Self {
        let constant = ConstantScalarValue::Float(number, FloatKind::F64);
        Self::ConstantScalar(constant)
    }
}

impl From<u32> for Variable {
    /// Losslessly widens a `u32` into an unsigned integer constant.
    fn from(number: u32) -> Self {
        let widened = u64::from(number);
        Self::ConstantScalar(ConstantScalarValue::UInt(widened))
    }
}

impl From<usize> for Variable {
    /// Converts a `usize` into an unsigned integer constant.
    ///
    /// The `as u64` cast is lossless on targets whose `usize` is at most 64 bits.
    fn from(number: usize) -> Self {
        let widened = number as u64;
        Self::ConstantScalar(ConstantScalarValue::UInt(widened))
    }
}

pub(crate) use cpa;