Skip to main content

tang_expr/
compile.rs

1//! Compile expression graphs to optimized Rust closures.
2
3use std::collections::HashSet;
4
5use crate::graph::ExprGraph;
6use crate::node::{ExprId, Node};
7
/// Boxed closure returned by [`ExprGraph::compile`]: takes the input values
/// and returns the value of the single compiled expression.
pub type CompiledExpr = Box<dyn Fn(&[f64]) -> f64 + 'static>;

/// Boxed closure returned by [`ExprGraph::compile_many`]: takes the input
/// values and writes one result per compiled expression into the output slice.
pub type CompiledMany = Box<dyn Fn(&[f64], &mut [f64]) + 'static>;
13
14impl ExprGraph {
15    /// Compile a single expression to a closure `&[f64] -> f64`.
16    ///
17    /// Shared subexpressions (from interning) are computed once.
18    /// Dead nodes (not reachable from output) are skipped.
19    pub fn compile(&self, expr: ExprId) -> CompiledExpr {
20        let live = self.live_set(&[expr]);
21        let nodes = self.collect_eval_order(&live, expr.0 as usize + 1);
22        let out_idx = expr.0 as usize;
23
24        Box::new(move |inputs: &[f64]| {
25            let mut vals = vec![0.0f64; out_idx + 1];
26            for &(i, ref node) in &nodes {
27                vals[i] = eval_node(node, &vals, inputs);
28            }
29            vals[out_idx]
30        })
31    }
32
33    /// Compile multiple output expressions to a single closure.
34    ///
35    /// Writes results into the output slice.
36    pub fn compile_many(&self, exprs: &[ExprId]) -> CompiledMany {
37        if exprs.is_empty() {
38            return Box::new(|_, _| {});
39        }
40
41        let live = self.live_set(exprs);
42        let max_id = exprs.iter().map(|e| e.0).max().unwrap() as usize;
43        let nodes = self.collect_eval_order(&live, max_id + 1);
44        let out_indices: Vec<usize> = exprs.iter().map(|e| e.0 as usize).collect();
45
46        Box::new(move |inputs: &[f64], outputs: &mut [f64]| {
47            let mut vals = vec![0.0f64; max_id + 1];
48            for &(i, ref node) in &nodes {
49                vals[i] = eval_node(node, &vals, inputs);
50            }
51            for (k, &idx) in out_indices.iter().enumerate() {
52                outputs[k] = vals[idx];
53            }
54        })
55    }
56
57    /// Find all node indices reachable from the given outputs.
58    pub fn live_set(&self, outputs: &[ExprId]) -> HashSet<usize> {
59        let mut live = HashSet::new();
60        let mut stack: Vec<usize> = outputs.iter().map(|e| e.0 as usize).collect();
61        while let Some(i) = stack.pop() {
62            if !live.insert(i) {
63                continue;
64            }
65            match self.node(ExprId(i as u32)) {
66                Node::Var(_) | Node::Lit(_) => {}
67                Node::Add(a, b) | Node::Mul(a, b) | Node::Atan2(a, b) => {
68                    stack.push(a.0 as usize);
69                    stack.push(b.0 as usize);
70                }
71                Node::Neg(a)
72                | Node::Recip(a)
73                | Node::Sqrt(a)
74                | Node::Sin(a)
75                | Node::Exp2(a)
76                | Node::Log2(a) => {
77                    stack.push(a.0 as usize);
78                }
79                Node::Select(c, a, b) => {
80                    stack.push(c.0 as usize);
81                    stack.push(a.0 as usize);
82                    stack.push(b.0 as usize);
83                }
84            }
85        }
86        live
87    }
88
89    /// Collect (index, node) pairs in topological order, only for live nodes.
90    fn collect_eval_order(&self, live: &HashSet<usize>, count: usize) -> Vec<(usize, Node)> {
91        (0..count)
92            .filter(|i| live.contains(i))
93            .map(|i| (i, self.node(ExprId(i as u32))))
94            .collect()
95    }
96}
97
98#[inline]
99fn eval_node(node: &Node, vals: &[f64], inputs: &[f64]) -> f64 {
100    match *node {
101        Node::Var(idx) => inputs[idx as usize],
102        Node::Lit(bits) => f64::from_bits(bits),
103        Node::Add(a, b) => vals[a.0 as usize] + vals[b.0 as usize],
104        Node::Mul(a, b) => vals[a.0 as usize] * vals[b.0 as usize],
105        Node::Neg(a) => -vals[a.0 as usize],
106        Node::Recip(a) => 1.0 / vals[a.0 as usize],
107        Node::Sqrt(a) => vals[a.0 as usize].sqrt(),
108        Node::Sin(a) => vals[a.0 as usize].sin(),
109        Node::Atan2(y, x) => vals[y.0 as usize].atan2(vals[x.0 as usize]),
110        Node::Exp2(a) => vals[a.0 as usize].exp2(),
111        Node::Log2(a) => vals[a.0 as usize].log2(),
112        Node::Select(c, a, b) => {
113            if vals[c.0 as usize] > 0.0 {
114                vals[a.0 as usize]
115            } else {
116                vals[b.0 as usize]
117            }
118        }
119    }
120}
121
#[cfg(test)]
mod tests {
    use crate::graph::ExprGraph;

    #[test]
    fn compile_add_lits() {
        let mut g = ExprGraph::new();
        let a = g.lit(3.0);
        let b = g.lit(4.0);
        let sum = g.add(a, b);
        let f = g.compile(sum);
        assert!((f(&[]) - 7.0).abs() < 1e-10);
    }

    #[test]
    fn compile_with_vars() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let prod = g.mul(sum, x);
        let f = g.compile(prod);
        // (3 + 4) * 3 = 21
        assert!((f(&[3.0, 4.0]) - 21.0).abs() < 1e-10);
    }

    #[test]
    fn compile_sin() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let s = g.sin(x);
        let f = g.compile(s);
        assert!((f(&[std::f64::consts::FRAC_PI_2]) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn compile_many_outputs() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let prod = g.mul(x, y);
        let f = g.compile_many(&[sum, prod]);
        let mut out = [0.0; 2];
        f(&[3.0, 4.0], &mut out);
        assert!((out[0] - 7.0).abs() < 1e-10);
        assert!((out[1] - 12.0).abs() < 1e-10);
    }

    #[test]
    fn compile_dead_code_elimination() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let _dead = g.sin(x); // not used in output
        let result = g.mul(x, x);
        let f = g.compile(result);
        assert!((f(&[5.0]) - 25.0).abs() < 1e-10);
    }

    #[test]
    fn compile_matches_eval() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let xx = g.mul(x, x);
        let yy = g.mul(y, y);
        let sum = g.add(xx, yy);
        let dist = g.sqrt(sum);

        let inputs = [3.0, 4.0];
        let eval_result: f64 = g.eval(dist, &inputs);
        let f = g.compile(dist);
        let compile_result = f(&inputs);
        assert!((eval_result - compile_result).abs() < 1e-10);
        assert!((compile_result - 5.0).abs() < 1e-10);
    }

    #[test]
    fn compile_output_before_later_nodes() {
        // Nodes created after the output must not affect it.
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let sq = g.mul(x, x);
        let _later = g.sin(sq);
        let f = g.compile(sq);
        assert!((f(&[-2.0]) - 4.0).abs() < 1e-10);
    }

    #[test]
    fn compile_shared_subexpression() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let s = g.sin(x);
        let s2 = g.mul(s, s);
        let f = g.compile(s2);
        let v = 0.5f64.sin();
        assert!((f(&[0.5]) - v * v).abs() < 1e-12);
    }

    #[test]
    fn compile_many_empty_is_noop() {
        let g = ExprGraph::new();
        let f = g.compile_many(&[]);
        let mut out = [42.0; 1];
        f(&[], &mut out);
        assert_eq!(out[0], 42.0);
    }

    #[test]
    fn compile_many_interleaved_dead_nodes_and_order() {
        // Outputs are requested out of id order, with dead nodes between them.
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let _dead1 = g.sin(x);
        let xx = g.mul(x, x);
        let _dead2 = g.sqrt(xx);
        let y = g.var(1);
        let sum = g.add(xx, y);
        let f = g.compile_many(&[sum, xx]);
        let mut out = [0.0; 2];
        f(&[3.0, 1.0], &mut out);
        assert!((out[0] - 10.0).abs() < 1e-10);
        assert!((out[1] - 9.0).abs() < 1e-10);
    }

    #[test]
    #[should_panic]
    fn compile_many_short_output_slice_panics() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        let sum = g.add(x, y);
        let prod = g.mul(x, y);
        let f = g.compile_many(&[sum, prod]);
        let mut out = [0.0; 1];
        f(&[3.0, 4.0], &mut out);
    }
}
198}