radiate_core/stats/expression/
compile.rs1use super::{
2 Expr,
3 ops::{BinaryExpr, BinaryOp, TrinaryExpr, UnaryExpr, UnaryOp, fuse_affine},
4};
5use radiate_utils::AnyValue;
6
7impl Expr {
8 pub fn compile(self) -> Expr {
21 match self {
22 Expr::Literal(_) | Expr::Selector(_) | Expr::Schedule(_) => self,
24
25 Expr::Unary(u) => {
26 let UnaryExpr { child, op } = u;
27 let child = (*child).compile();
28 match op {
29 UnaryOp::Affine { scale, bias } => fuse_affine(child, scale, bias),
32 other_op => Expr::Unary(UnaryExpr::new(child, other_op)),
33 }
34 }
35
36 Expr::Trinary(t) => Expr::Trinary(TrinaryExpr::new(
37 (*t.first).compile(),
38 (*t.second).compile(),
39 (*t.third).compile(),
40 t.operation,
41 )),
42
43 Expr::Binary(b) => {
44 let lhs = (*b.lhs).compile();
45 let rhs = (*b.rhs).compile();
46 reduce_binary(lhs, rhs, b.op)
47 }
48
49 Expr::Aggregate(mut a) => {
51 let child = std::mem::replace(a.child.as_mut(), Expr::Literal(AnyValue::Null));
52 *a.child = child.compile();
53 Expr::Aggregate(a)
54 }
55 }
56 }
57}
58
59fn reduce_binary(lhs: Expr, rhs: Expr, op: BinaryOp) -> Expr {
60 if let (Expr::Literal(l), Expr::Literal(r)) = (&lhs, &rhs)
62 && let Some(folded) = fold_literals(l, r, op)
63 {
64 return Expr::Literal(folded);
65 }
66
67 match op {
69 BinaryOp::Add => {
70 if let Expr::Literal(v) = &rhs
71 && let Some(k) = v.extract::<f32>()
72 {
73 return fuse_affine(lhs, 1.0, k);
74 }
75 if let Expr::Literal(v) = &lhs
76 && let Some(k) = v.extract::<f32>()
77 {
78 return fuse_affine(rhs, 1.0, k);
79 }
80 }
81 BinaryOp::Sub => {
82 if let Expr::Literal(v) = &rhs
83 && let Some(k) = v.extract::<f32>()
84 {
85 return fuse_affine(lhs, 1.0, -k);
87 }
88 if let Expr::Literal(v) = &lhs
89 && let Some(k) = v.extract::<f32>()
90 {
91 return fuse_affine(rhs, -1.0, k);
93 }
94 }
95 BinaryOp::Mul => {
96 if let Expr::Literal(v) = &rhs
97 && let Some(s) = v.extract::<f32>()
98 {
99 return fuse_affine(lhs, s, 0.0);
100 }
101 if let Expr::Literal(v) = &lhs
102 && let Some(s) = v.extract::<f32>()
103 {
104 return fuse_affine(rhs, s, 0.0);
105 }
106 }
107 BinaryOp::Div => {
108 if let Expr::Literal(v) = &rhs
110 && let Some(d) = v.extract::<f32>()
111 && d != 0.0
112 && d.is_finite()
113 {
114 return fuse_affine(lhs, 1.0 / d, 0.0);
115 }
116 }
117 _ => {}
118 }
119
120 Expr::Binary(BinaryExpr::new(lhs, rhs, op))
121}
122
123fn fold_literals(
124 l: &AnyValue<'static>,
125 r: &AnyValue<'static>,
126 op: BinaryOp,
127) -> Option<AnyValue<'static>> {
128 let a = l.extract::<f32>()?;
129 let b = r.extract::<f32>()?;
130 let result = match op {
131 BinaryOp::Add => a + b,
132 BinaryOp::Sub => a - b,
133 BinaryOp::Mul => a * b,
134 BinaryOp::Div if b != 0.0 => a / b,
135 _ => return None,
136 };
137 if result.is_finite() {
138 Some(AnyValue::Float32(result))
139 } else {
140 None
141 }
142}