1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
use super::{
Expr,
ops::{BinaryExpr, BinaryOp, TrinaryExpr, UnaryExpr, UnaryOp, fuse_affine},
};
use radiate_utils::AnyValue;
impl Expr {
/// Walks the tree bottom-up and rewrites algebraically equivalent shapes
/// into the smallest possible form. Specifically:
///
/// - Pure-literal subtrees fold (`Lit(2) + Lit(3)` → `Lit(5)`)
/// - `Add` / `Sub` / `Mul` / `Div` with one literal operand fuses into a
/// `Unary(Affine)` (`x * 5 + 3` → `Affine { scale: 5, bias: 3 }`)
/// - Nested affines collapse: `s2 * (s1*x + b1) + b2` → `Affine(s2*s1, s2*b1 + b2)`
///
/// Called automatically when wrapping in `Rate::Expr` or `NamedExpr`. Safe
/// to call multiple times — idempotent. Mathematically lossless within
/// f32 precision (and within the existing arithmetic semantics for Null /
/// non-finite operands).
pub fn compile(self) -> Expr {
match self {
// Leaves — nothing to rewrite.
Expr::Literal(_) | Expr::Selector(_) | Expr::Schedule(_) => self,
Expr::Unary(u) => {
let UnaryExpr { child, op } = u;
let child = (*child).compile();
match op {
// Affine on top of a compiled child: re-run fusion so any
// newly-revealed Affine nested below collapses upward.
UnaryOp::Affine { scale, bias } => fuse_affine(child, scale, bias),
other_op => Expr::Unary(UnaryExpr::new(child, other_op)),
}
}
Expr::Trinary(t) => Expr::Trinary(TrinaryExpr::new(
(*t.first).compile(),
(*t.second).compile(),
(*t.third).compile(),
t.operation,
)),
Expr::Binary(b) => {
let lhs = (*b.lhs).compile();
let rhs = (*b.rhs).compile();
reduce_binary(lhs, rhs, b.op)
}
// Stateful nodes — keep the rollup/buffer intact, just compile the child.
Expr::Aggregate(mut a) => {
let child = std::mem::replace(a.child.as_mut(), Expr::Literal(AnyValue::Null));
*a.child = child.compile();
Expr::Aggregate(a)
}
}
}
}
fn reduce_binary(lhs: Expr, rhs: Expr, op: BinaryOp) -> Expr {
// Pure literal-on-literal: fold to a Literal.
if let (Expr::Literal(l), Expr::Literal(r)) = (&lhs, &rhs)
&& let Some(folded) = fold_literals(l, r, op)
{
return Expr::Literal(folded);
}
// Affine fusion patterns. Only Add/Sub/Mul/Div are linear; the rest fall through.
match op {
BinaryOp::Add => {
if let Expr::Literal(v) = &rhs
&& let Some(k) = v.extract::<f32>()
{
return fuse_affine(lhs, 1.0, k);
}
if let Expr::Literal(v) = &lhs
&& let Some(k) = v.extract::<f32>()
{
return fuse_affine(rhs, 1.0, k);
}
}
BinaryOp::Sub => {
if let Expr::Literal(v) = &rhs
&& let Some(k) = v.extract::<f32>()
{
// x - k → Affine(x, 1, -k)
return fuse_affine(lhs, 1.0, -k);
}
if let Expr::Literal(v) = &lhs
&& let Some(k) = v.extract::<f32>()
{
// k - x → Affine(x, -1, k)
return fuse_affine(rhs, -1.0, k);
}
}
BinaryOp::Mul => {
if let Expr::Literal(v) = &rhs
&& let Some(s) = v.extract::<f32>()
{
return fuse_affine(lhs, s, 0.0);
}
if let Expr::Literal(v) = &lhs
&& let Some(s) = v.extract::<f32>()
{
return fuse_affine(rhs, s, 0.0);
}
}
BinaryOp::Div => {
// Only fold `x / Lit` (constant divisor). `Lit / x` is non-linear in x.
if let Expr::Literal(v) = &rhs
&& let Some(d) = v.extract::<f32>()
&& d != 0.0
&& d.is_finite()
{
return fuse_affine(lhs, 1.0 / d, 0.0);
}
}
_ => {}
}
Expr::Binary(BinaryExpr::new(lhs, rhs, op))
}
fn fold_literals(
l: &AnyValue<'static>,
r: &AnyValue<'static>,
op: BinaryOp,
) -> Option<AnyValue<'static>> {
let a = l.extract::<f32>()?;
let b = r.extract::<f32>()?;
let result = match op {
BinaryOp::Add => a + b,
BinaryOp::Sub => a - b,
BinaryOp::Mul => a * b,
BinaryOp::Div if b != 0.0 => a / b,
_ => return None,
};
if result.is_finite() {
Some(AnyValue::Float32(result))
} else {
None
}
}