// radiate_core/stats/expression/compile.rs

1use super::{
2    Expr,
3    ops::{BinaryExpr, BinaryOp, TrinaryExpr, UnaryExpr, UnaryOp, fuse_affine},
4};
5use radiate_utils::AnyValue;
6
7impl Expr {
8    /// Walks the tree bottom-up and rewrites algebraically equivalent shapes
9    /// into the smallest possible form. Specifically:
10    ///
11    /// - Pure-literal subtrees fold (`Lit(2) + Lit(3)` → `Lit(5)`)
12    /// - `Add` / `Sub` / `Mul` / `Div` with one literal operand fuses into a
13    ///   `Unary(Affine)` (`x * 5 + 3` → `Affine { scale: 5, bias: 3 }`)
14    /// - Nested affines collapse: `s2 * (s1*x + b1) + b2` → `Affine(s2*s1, s2*b1 + b2)`
15    ///
16    /// Called automatically when wrapping in `Rate::Expr` or `NamedExpr`. Safe
17    /// to call multiple times — idempotent. Mathematically lossless within
18    /// f32 precision (and within the existing arithmetic semantics for Null /
19    /// non-finite operands).
20    pub fn compile(self) -> Expr {
21        match self {
22            // Leaves — nothing to rewrite.
23            Expr::Literal(_) | Expr::Selector(_) | Expr::Schedule(_) => self,
24
25            Expr::Unary(u) => {
26                let UnaryExpr { child, op } = u;
27                let child = (*child).compile();
28                match op {
29                    // Affine on top of a compiled child: re-run fusion so any
30                    // newly-revealed Affine nested below collapses upward.
31                    UnaryOp::Affine { scale, bias } => fuse_affine(child, scale, bias),
32                    other_op => Expr::Unary(UnaryExpr::new(child, other_op)),
33                }
34            }
35
36            Expr::Trinary(t) => Expr::Trinary(TrinaryExpr::new(
37                (*t.first).compile(),
38                (*t.second).compile(),
39                (*t.third).compile(),
40                t.operation,
41            )),
42
43            Expr::Binary(b) => {
44                let lhs = (*b.lhs).compile();
45                let rhs = (*b.rhs).compile();
46                reduce_binary(lhs, rhs, b.op)
47            }
48
49            // Stateful nodes — keep the rollup/buffer intact, just compile the child.
50            Expr::Aggregate(mut a) => {
51                let child = std::mem::replace(a.child.as_mut(), Expr::Literal(AnyValue::Null));
52                *a.child = child.compile();
53                Expr::Aggregate(a)
54            }
55        }
56    }
57}
58
59fn reduce_binary(lhs: Expr, rhs: Expr, op: BinaryOp) -> Expr {
60    // Pure literal-on-literal: fold to a Literal.
61    if let (Expr::Literal(l), Expr::Literal(r)) = (&lhs, &rhs)
62        && let Some(folded) = fold_literals(l, r, op)
63    {
64        return Expr::Literal(folded);
65    }
66
67    // Affine fusion patterns. Only Add/Sub/Mul/Div are linear; the rest fall through.
68    match op {
69        BinaryOp::Add => {
70            if let Expr::Literal(v) = &rhs
71                && let Some(k) = v.extract::<f32>()
72            {
73                return fuse_affine(lhs, 1.0, k);
74            }
75            if let Expr::Literal(v) = &lhs
76                && let Some(k) = v.extract::<f32>()
77            {
78                return fuse_affine(rhs, 1.0, k);
79            }
80        }
81        BinaryOp::Sub => {
82            if let Expr::Literal(v) = &rhs
83                && let Some(k) = v.extract::<f32>()
84            {
85                // x - k → Affine(x, 1, -k)
86                return fuse_affine(lhs, 1.0, -k);
87            }
88            if let Expr::Literal(v) = &lhs
89                && let Some(k) = v.extract::<f32>()
90            {
91                // k - x → Affine(x, -1, k)
92                return fuse_affine(rhs, -1.0, k);
93            }
94        }
95        BinaryOp::Mul => {
96            if let Expr::Literal(v) = &rhs
97                && let Some(s) = v.extract::<f32>()
98            {
99                return fuse_affine(lhs, s, 0.0);
100            }
101            if let Expr::Literal(v) = &lhs
102                && let Some(s) = v.extract::<f32>()
103            {
104                return fuse_affine(rhs, s, 0.0);
105            }
106        }
107        BinaryOp::Div => {
108            // Only fold `x / Lit` (constant divisor). `Lit / x` is non-linear in x.
109            if let Expr::Literal(v) = &rhs
110                && let Some(d) = v.extract::<f32>()
111                && d != 0.0
112                && d.is_finite()
113            {
114                return fuse_affine(lhs, 1.0 / d, 0.0);
115            }
116        }
117        _ => {}
118    }
119
120    Expr::Binary(BinaryExpr::new(lhs, rhs, op))
121}
122
123fn fold_literals(
124    l: &AnyValue<'static>,
125    r: &AnyValue<'static>,
126    op: BinaryOp,
127) -> Option<AnyValue<'static>> {
128    let a = l.extract::<f32>()?;
129    let b = r.extract::<f32>()?;
130    let result = match op {
131        BinaryOp::Add => a + b,
132        BinaryOp::Sub => a - b,
133        BinaryOp::Mul => a * b,
134        BinaryOp::Div if b != 0.0 => a / b,
135        _ => return None,
136    };
137    if result.is_finite() {
138        Some(AnyValue::Float32(result))
139    } else {
140        None
141    }
142}