pub mod codegen;
pub mod compile;
pub mod diff;
pub mod display;
pub mod eval;
pub mod graph;
pub mod node;
mod scalar;
pub mod simplify;
pub mod sparsity;
pub mod wgsl;
pub use graph::ExprGraph;
pub use node::{ExprId, Node};
use std::cell::RefCell;
thread_local! {
    /// The ambient expression graph for the current thread.
    ///
    /// Each thread owns an independent graph. It is accessed mutably through
    /// `with_graph` (which `ExprId::var` uses), and `trace` temporarily swaps
    /// in a fresh, empty graph while its closure runs.
    static GRAPH: RefCell<ExprGraph> = RefCell::new(ExprGraph::new());
}
/// Runs `f` with mutable access to this thread's current expression graph and
/// returns whatever `f` produces.
///
/// Node constructors such as [`ExprId::var`] go through this function to add
/// nodes to the ambient graph.
///
/// # Panics
///
/// Panics if the graph is already mutably borrowed. In practice this happens
/// when `f` itself, directly or indirectly, calls `with_graph`, `trace`, or a
/// node constructor such as [`ExprId::var`].
pub fn with_graph<F, R>(f: F) -> R
where
    F: FnOnce(&mut ExprGraph) -> R,
{
    GRAPH.with(|cell| {
        // The mutable borrow is held for exactly the duration of `f`.
        let mut graph = cell.borrow_mut();
        f(&mut *graph)
    })
}
/// Runs `f` against a fresh, empty thread-local graph and returns the graph it
/// built together with `f`'s result.
///
/// The graph that was active before the call is set aside while `f` runs and
/// reinstated afterwards. Calls can therefore be nested, and nodes created
/// inside `f` do not end up in the surrounding graph.
///
/// The outer graph is reinstated even if `f` panics. A caught panic (for
/// example via `std::panic::catch_unwind`, or an inner `trace` failing inside
/// an outer one) neither leaves a half-built graph installed nor discards the
/// outer graph.
///
/// # Panics
///
/// Propagates any panic raised by `f`. Also panics if called while the
/// thread-local graph is mutably borrowed, i.e. from inside a [`with_graph`]
/// closure.
pub fn trace<F, R>(f: F) -> (ExprGraph, R)
where
    F: FnOnce() -> R,
{
    /// Owns the outer graph while `f` runs and writes it back into `GRAPH`
    /// when dropped. The write-back happens on normal return and during
    /// unwinding.
    struct RestoreOuterGraph(ExprGraph);

    impl Drop for RestoreOuterGraph {
        fn drop(&mut self) {
            let outer = std::mem::take(&mut self.0);
            // This may run while unwinding, where a second panic would abort
            // the process, so use the non-panicking accessors. They only fail
            // during thread-local teardown or while the cell is still
            // borrowed. In those cases the outer graph is dropped instead of
            // being restored.
            let _ = GRAPH.try_with(|cell| {
                if let Ok(mut current) = cell.try_borrow_mut() {
                    *current = outer;
                }
            });
        }
    }

    let outer = GRAPH.with(|cell| std::mem::take(&mut *cell.borrow_mut()));
    let guard = RestoreOuterGraph(outer);
    let result = f();
    // Take the graph built by `f`. The guard then puts the outer graph back.
    let traced = GRAPH.with(|cell| std::mem::take(&mut *cell.borrow_mut()));
    drop(guard);
    (traced, result)
}
impl ExprId {
    /// Returns the id of input variable number `n` in the current thread-local
    /// graph, delegating to `ExprGraph::var`.
    ///
    /// In this module's tests, variable `n` takes the value `vars[n]` when the
    /// graph is evaluated.
    ///
    /// # Panics
    ///
    /// Panics if called from inside a [`with_graph`] closure, because the
    /// graph is then already mutably borrowed.
    #[inline]
    pub fn var(n: u16) -> Self {
        with_graph(move |graph| graph.var(n))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Scalar, Vec3};

    /// Asserts that `actual` lies strictly within `1e-10` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-10,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn trace_vec3_dot() {
        // Builds a vector from three consecutive input variables.
        let vec_from =
            |first: u16| Vec3::new(ExprId::var(first), ExprId::var(first + 1), ExprId::var(first + 2));
        let (graph, dot) = trace(|| vec_from(0).dot(vec_from(3)));
        // (1, 2, 3) . (4, 5, 6) = 4 + 10 + 18 = 32
        assert_close(
            graph.eval::<f64>(dot, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]),
            32.0,
        );
    }

    #[test]
    fn trace_vec3_norm() {
        let (graph, norm) = trace(|| {
            Vec3::new(ExprId::var(0), ExprId::var(1), ExprId::var(2)).norm()
        });
        // |(3, 4, 0)| = 5
        assert_close(graph.eval::<f64>(norm, &[3.0, 4.0, 0.0]), 5.0);
    }

    #[test]
    fn trace_isolation() {
        // Tracing the same closure twice must give graphs of equal size: the
        // second trace must not see nodes created by the first.
        let build = || {
            let _x = ExprId::var(0);
        };
        let (first, ()) = trace(build);
        let (second, ()) = trace(build);
        assert_eq!(first.len(), second.len());
    }

    #[test]
    fn from_f64_creates_lit() {
        let (graph, lit) = trace(|| ExprId::from_f64(42.0));
        assert_close(graph.eval::<f64>(lit, &[]), 42.0);
    }

    #[test]
    fn scalar_constants() {
        let (graph, (zero, one, two)) = trace(|| {
            let constant = |value: f64| -> ExprId { Scalar::from_f64(value) };
            (constant(0.0), constant(1.0), constant(2.0))
        });
        for (id, expected) in [(zero, 0.0), (one, 1.0), (two, 2.0)] {
            assert_eq!(graph.eval::<f64>(id, &[]), expected);
        }
    }
}