1use std::collections::HashMap;
4
5use crate::node::{ExprId, Node};
6
7#[derive(Clone)]
12pub struct ExprGraph {
13 nodes: Vec<Node>,
14 intern: HashMap<Node, ExprId>,
15}
16
17impl ExprGraph {
18 pub fn new() -> Self {
20 let mut g = Self {
21 nodes: Vec::new(),
22 intern: HashMap::new(),
23 };
24 let z = g.insert(Node::lit(0.0));
26 debug_assert_eq!(z, ExprId::ZERO);
27 let o = g.insert(Node::lit(1.0));
29 debug_assert_eq!(o, ExprId::ONE);
30 let t = g.insert(Node::lit(2.0));
32 debug_assert_eq!(t, ExprId::TWO);
33 g
34 }
35
36 #[inline]
38 pub fn len(&self) -> usize {
39 self.nodes.len()
40 }
41
42 #[inline]
44 pub fn is_empty(&self) -> bool {
45 self.nodes.is_empty()
46 }
47
48 #[inline]
50 pub fn node(&self, id: ExprId) -> Node {
51 self.nodes[id.0 as usize]
52 }
53
54 #[inline]
56 pub fn nodes_slice(&self) -> &[Node] {
57 &self.nodes
58 }
59
60 fn insert(&mut self, node: Node) -> ExprId {
62 if let Some(&id) = self.intern.get(&node) {
63 return id;
64 }
65 let id = ExprId(self.nodes.len() as u32);
66 self.nodes.push(node);
67 self.intern.insert(node, id);
68 id
69 }
70
71 #[inline]
73 pub fn var(&mut self, n: u16) -> ExprId {
74 self.insert(Node::Var(n))
75 }
76
77 #[inline]
79 pub fn lit(&mut self, v: f64) -> ExprId {
80 self.insert(Node::lit(v))
81 }
82
83 #[inline]
85 pub fn add(&mut self, a: ExprId, b: ExprId) -> ExprId {
86 self.insert(Node::Add(a, b))
87 }
88
89 #[inline]
91 pub fn mul(&mut self, a: ExprId, b: ExprId) -> ExprId {
92 self.insert(Node::Mul(a, b))
93 }
94
95 #[inline]
97 pub fn neg(&mut self, a: ExprId) -> ExprId {
98 self.insert(Node::Neg(a))
99 }
100
101 #[inline]
103 pub fn recip(&mut self, a: ExprId) -> ExprId {
104 self.insert(Node::Recip(a))
105 }
106
107 #[inline]
109 pub fn sqrt(&mut self, a: ExprId) -> ExprId {
110 self.insert(Node::Sqrt(a))
111 }
112
113 #[inline]
115 pub fn sin(&mut self, a: ExprId) -> ExprId {
116 self.insert(Node::Sin(a))
117 }
118
119 #[inline]
121 pub fn atan2(&mut self, y: ExprId, x: ExprId) -> ExprId {
122 self.insert(Node::Atan2(y, x))
123 }
124
125 #[inline]
127 pub fn exp2(&mut self, a: ExprId) -> ExprId {
128 self.insert(Node::Exp2(a))
129 }
130
131 #[inline]
133 pub fn log2(&mut self, a: ExprId) -> ExprId {
134 self.insert(Node::Log2(a))
135 }
136
137 #[inline]
139 pub fn select(&mut self, cond: ExprId, a: ExprId, b: ExprId) -> ExprId {
140 self.insert(Node::Select(cond, a, b))
141 }
142}
143
144impl Default for ExprGraph {
145 fn default() -> Self {
146 Self::new()
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` seeds the graph with 0, 1 and 2 at their well-known ids.
    #[test]
    fn pre_populated() {
        let g = ExprGraph::new();
        assert_eq!(g.node(ExprId::ZERO).as_f64(), Some(0.0));
        assert_eq!(g.node(ExprId::ONE).as_f64(), Some(1.0));
        assert_eq!(g.node(ExprId::TWO).as_f64(), Some(2.0));
        assert_eq!(g.len(), 3);
    }

    /// Re-inserting an equal node returns the same id and stores nothing new.
    #[test]
    fn interning() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let x2 = g.var(0);
        assert_eq!(x, x2);

        let a = g.add(x, ExprId::ONE);
        let len_after_first = g.len();
        let a2 = g.add(x, ExprId::ONE);
        assert_eq!(a, a2);
        assert_eq!(g.len(), len_after_first);
    }

    /// Unequal nodes get different ids, each growing the graph by one.
    #[test]
    fn distinct_nodes_get_distinct_ids() {
        let mut g = ExprGraph::new();
        let before = g.len();
        let x = g.var(0);
        let y = g.var(1);
        assert_ne!(x, y);
        assert_eq!(g.len(), before + 2);
    }

    /// Two NaN literals intern to the same id, even though `NaN != NaN` under
    /// IEEE comparison. This relies on how `Node::lit` stores the value
    /// (presumably by bit pattern — confirm in `crate::node`).
    #[test]
    fn lit_nan_interned_once() {
        let mut g = ExprGraph::new();
        let a = g.lit(f64::NAN);
        let b = g.lit(f64::NAN);
        assert_eq!(a, b);
    }
}