use std::collections::HashMap;
use super::node::{ExprId, Node};
/// A graph of expression nodes in which equal nodes share a single id.
///
/// Nodes are added through the builder methods on this type; inserting a node
/// that compares equal to one already stored returns the existing [`ExprId`]
/// rather than storing a second copy.
#[derive(Clone)]
pub struct ExprGraph {
/// Node storage: `ExprId(i)` refers to `nodes[i]`.
nodes: Vec<Node>,
/// Lookup from a node to the id it was stored under, consulted on every
/// insertion so an equal node is never stored twice.
intern: HashMap<Node, ExprId>,
}
impl ExprGraph {
    /// Creates a graph pre-populated with the literals `0.0`, `1.0` and `2.0`.
    ///
    /// The literals are inserted in that order; debug builds assert that they
    /// received the ids [`ExprId::ZERO`], [`ExprId::ONE`] and [`ExprId::TWO`]
    /// respectively.
    pub fn new() -> Self {
        let mut g = Self {
            nodes: Vec::new(),
            intern: HashMap::new(),
        };
        let z = g.insert(Node::lit(0.0));
        debug_assert_eq!(z, ExprId::ZERO);
        let o = g.insert(Node::lit(1.0));
        debug_assert_eq!(o, ExprId::ONE);
        let t = g.insert(Node::lit(2.0));
        debug_assert_eq!(t, ExprId::TWO);
        g
    }

    /// Returns the number of distinct nodes stored in the graph.
    #[inline]
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns `true` if the graph stores no nodes.
    ///
    /// A graph built by [`ExprGraph::new`] starts with three literals, so this
    /// is `false` for any graph obtained that way.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Returns a copy of the node stored under `id`.
    ///
    /// # Panics
    ///
    /// Panics if `id` does not index a node of this graph.
    #[inline]
    pub fn node(&self, id: ExprId) -> Node {
        self.nodes[id.0 as usize]
    }

    /// Returns all stored nodes as a slice; the node for `ExprId(i)` is at index `i`.
    #[inline]
    pub fn nodes_slice(&self) -> &[Node] {
        &self.nodes
    }

    /// Interns `node`: returns the id of an equal node if one is already
    /// stored, otherwise appends `node` and returns its freshly assigned id.
    ///
    /// # Panics
    ///
    /// Panics if a new node would need an index that does not fit in the `u32`
    /// carried by [`ExprId`]. Truncating instead would hand out an id that
    /// aliases an existing, unrelated node.
    fn insert(&mut self, node: Node) -> ExprId {
        use std::collections::hash_map::Entry;

        // The entry API hashes `node` once for both the lookup and the insertion.
        match self.intern.entry(node) {
            Entry::Occupied(slot) => *slot.get(),
            Entry::Vacant(slot) => {
                let index = <u32 as std::convert::TryFrom<usize>>::try_from(self.nodes.len())
                    .expect("ExprGraph holds more nodes than a u32 ExprId can address");
                let id = ExprId(index);
                self.nodes.push(node);
                slot.insert(id);
                id
            }
        }
    }

    /// Interns the variable node `Node::Var(n)` and returns its id.
    #[inline]
    pub fn var(&mut self, n: u16) -> ExprId {
        self.insert(Node::Var(n))
    }

    /// Interns the literal node built by `Node::lit(v)` and returns its id.
    #[inline]
    pub fn lit(&mut self, v: f64) -> ExprId {
        self.insert(Node::lit(v))
    }

    /// Interns `Node::Add(a, b)`, with the operands in the order given, and returns its id.
    #[inline]
    pub fn add(&mut self, a: ExprId, b: ExprId) -> ExprId {
        self.insert(Node::Add(a, b))
    }

    /// Interns `Node::Mul(a, b)`, with the operands in the order given, and returns its id.
    #[inline]
    pub fn mul(&mut self, a: ExprId, b: ExprId) -> ExprId {
        self.insert(Node::Mul(a, b))
    }

    /// Interns `Node::Neg(a)` and returns its id.
    #[inline]
    pub fn neg(&mut self, a: ExprId) -> ExprId {
        self.insert(Node::Neg(a))
    }

    /// Interns `Node::Recip(a)` and returns its id.
    #[inline]
    pub fn recip(&mut self, a: ExprId) -> ExprId {
        self.insert(Node::Recip(a))
    }

    /// Interns `Node::Sqrt(a)` and returns its id.
    #[inline]
    pub fn sqrt(&mut self, a: ExprId) -> ExprId {
        self.insert(Node::Sqrt(a))
    }

    /// Interns `Node::Sin(a)` and returns its id.
    #[inline]
    pub fn sin(&mut self, a: ExprId) -> ExprId {
        self.insert(Node::Sin(a))
    }

    /// Interns `Node::Atan2(y, x)` and returns its id; `y` is the first operand, `x` the second.
    #[inline]
    pub fn atan2(&mut self, y: ExprId, x: ExprId) -> ExprId {
        self.insert(Node::Atan2(y, x))
    }

    /// Interns `Node::Exp2(a)` and returns its id.
    #[inline]
    pub fn exp2(&mut self, a: ExprId) -> ExprId {
        self.insert(Node::Exp2(a))
    }

    /// Interns `Node::Log2(a)` and returns its id.
    #[inline]
    pub fn log2(&mut self, a: ExprId) -> ExprId {
        self.insert(Node::Log2(a))
    }

    /// Interns `Node::Select(cond, a, b)`, with the operands in the order given, and returns its id.
    #[inline]
    pub fn select(&mut self, cond: ExprId, a: ExprId, b: ExprId) -> ExprId {
        self.insert(Node::Select(cond, a, b))
    }
}
impl Default for ExprGraph {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh graph holds exactly the three seeded literals under their reserved ids.
    #[test]
    fn pre_populated() {
        let g = ExprGraph::new();
        assert_eq!(g.node(ExprId::ZERO).as_f64(), Some(0.0));
        assert_eq!(g.node(ExprId::ONE).as_f64(), Some(1.0));
        assert_eq!(g.node(ExprId::TWO).as_f64(), Some(2.0));
        assert_eq!(g.len(), 3);
        assert!(!g.is_empty());
    }

    /// Building the same node twice returns the same id both times.
    #[test]
    fn interning() {
        let mut g = ExprGraph::new();
        let x = g.var(0);
        let x2 = g.var(0);
        assert_eq!(x, x2);
        let a = g.add(x, ExprId::ONE);
        let a2 = g.add(x, ExprId::ONE);
        assert_eq!(a, a2);
    }

    /// Two literals built from the same NaN constant intern to a single node.
    ///
    /// NaN never equals itself under floating-point comparison, so this relies
    /// on how `Node` implements equality and hashing for literals (presumably
    /// by bit pattern; confirm against `Node::lit`).
    #[test]
    fn lit_nan_interned() {
        let mut g = ExprGraph::new();
        let a = g.lit(f64::NAN);
        let b = g.lit(f64::NAN);
        assert_eq!(a, b);
    }
}