Skip to main content

tang_expr/
graph.rs

1//! Expression graph with structural interning (automatic CSE).
2
3use std::collections::HashMap;
4
5use crate::node::{ExprId, Node};
6
7/// Arena-based expression graph with structural interning.
8///
9/// Identical subexpressions always return the same `ExprId` — this gives
10/// automatic common subexpression elimination (CSE) for free.
11#[derive(Clone)]
12pub struct ExprGraph {
13    nodes: Vec<Node>,
14    intern: HashMap<Node, ExprId>,
15}
16
17impl ExprGraph {
18    /// Create a new graph pre-populated with ZERO, ONE, TWO.
19    pub fn new() -> Self {
20        let mut g = Self {
21            nodes: Vec::new(),
22            intern: HashMap::new(),
23        };
24        // Index 0 = ZERO
25        let z = g.insert(Node::lit(0.0));
26        debug_assert_eq!(z, ExprId::ZERO);
27        // Index 1 = ONE
28        let o = g.insert(Node::lit(1.0));
29        debug_assert_eq!(o, ExprId::ONE);
30        // Index 2 = TWO
31        let t = g.insert(Node::lit(2.0));
32        debug_assert_eq!(t, ExprId::TWO);
33        g
34    }
35
36    /// Total number of nodes in the graph.
37    #[inline]
38    pub fn len(&self) -> usize {
39        self.nodes.len()
40    }
41
42    /// Whether the graph is empty (it never is after construction).
43    #[inline]
44    pub fn is_empty(&self) -> bool {
45        self.nodes.is_empty()
46    }
47
48    /// Look up the node for an ExprId.
49    #[inline]
50    pub fn node(&self, id: ExprId) -> Node {
51        self.nodes[id.0 as usize]
52    }
53
54    /// Read-only access to the node arena for serialization.
55    #[inline]
56    pub fn nodes_slice(&self) -> &[Node] {
57        &self.nodes
58    }
59
60    /// Internal: insert a node, returning its interned ExprId.
61    fn insert(&mut self, node: Node) -> ExprId {
62        if let Some(&id) = self.intern.get(&node) {
63            return id;
64        }
65        let id = ExprId(self.nodes.len() as u32);
66        self.nodes.push(node);
67        self.intern.insert(node, id);
68        id
69    }
70
71    /// Create a variable node.
72    #[inline]
73    pub fn var(&mut self, n: u16) -> ExprId {
74        self.insert(Node::Var(n))
75    }
76
77    /// Create a literal node.
78    #[inline]
79    pub fn lit(&mut self, v: f64) -> ExprId {
80        self.insert(Node::lit(v))
81    }
82
83    /// Add two expressions.
84    #[inline]
85    pub fn add(&mut self, a: ExprId, b: ExprId) -> ExprId {
86        self.insert(Node::Add(a, b))
87    }
88
89    /// Multiply two expressions.
90    #[inline]
91    pub fn mul(&mut self, a: ExprId, b: ExprId) -> ExprId {
92        self.insert(Node::Mul(a, b))
93    }
94
95    /// Negate an expression.
96    #[inline]
97    pub fn neg(&mut self, a: ExprId) -> ExprId {
98        self.insert(Node::Neg(a))
99    }
100
101    /// Reciprocal (1/x).
102    #[inline]
103    pub fn recip(&mut self, a: ExprId) -> ExprId {
104        self.insert(Node::Recip(a))
105    }
106
107    /// Square root.
108    #[inline]
109    pub fn sqrt(&mut self, a: ExprId) -> ExprId {
110        self.insert(Node::Sqrt(a))
111    }
112
113    /// Sine.
114    #[inline]
115    pub fn sin(&mut self, a: ExprId) -> ExprId {
116        self.insert(Node::Sin(a))
117    }
118
119    /// atan2(y, x).
120    #[inline]
121    pub fn atan2(&mut self, y: ExprId, x: ExprId) -> ExprId {
122        self.insert(Node::Atan2(y, x))
123    }
124
125    /// Base-2 exponential.
126    #[inline]
127    pub fn exp2(&mut self, a: ExprId) -> ExprId {
128        self.insert(Node::Exp2(a))
129    }
130
131    /// Base-2 logarithm.
132    #[inline]
133    pub fn log2(&mut self, a: ExprId) -> ExprId {
134        self.insert(Node::Log2(a))
135    }
136
137    /// Branchless select: returns `a` if `cond > 0`, else `b`.
138    #[inline]
139    pub fn select(&mut self, cond: ExprId, a: ExprId, b: ExprId) -> ExprId {
140        self.insert(Node::Select(cond, a, b))
141    }
142}
143
144impl Default for ExprGraph {
145    fn default() -> Self {
146        Self::new()
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pre_populated() {
        let g = ExprGraph::new();
        assert_eq!(g.node(ExprId::ZERO).as_f64(), Some(0.0));
        assert_eq!(g.node(ExprId::ONE).as_f64(), Some(1.0));
        assert_eq!(g.node(ExprId::TWO).as_f64(), Some(2.0));
        assert_eq!(g.len(), 3);
        assert!(!g.is_empty());
        assert_eq!(g.nodes_slice().len(), g.len());
    }

    #[test]
    fn interning() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let x2 = g.var(0);
        assert_eq!(x, x2);

        let a = g.add(x, ExprId::ONE);
        let a2 = g.add(x, ExprId::ONE);
        assert_eq!(a, a2);
    }

    #[test]
    fn distinct_nodes_get_distinct_ids() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let y = g.var(1);
        assert_ne!(x, y);
        // Three pre-populated constants plus the two new variables.
        assert_eq!(g.len(), 5);
    }

    #[test]
    fn lit_reuses_prepopulated_constants() {
        let mut g = ExprGraph::new();
        assert_eq!(g.lit(0.0), ExprId::ZERO);
        assert_eq!(g.lit(1.0), ExprId::ONE);
        assert_eq!(g.lit(2.0), ExprId::TWO);
        assert_eq!(g.len(), 3);
    }

    #[test]
    fn lit_nan_interned() {
        let mut g = ExprGraph::new();
        // `f64::NAN` always has the same bit pattern, and interning treats
        // literals with identical bits as the same node, so both calls share
        // one id even though NaN != NaN as floats.
        let a = g.lit(f64::NAN);
        let b = g.lit(f64::NAN);
        assert_eq!(a, b);
    }
}
183}