1pub mod codegen;
29pub mod compile;
30pub mod diff;
31pub mod display;
32pub mod eval;
33pub mod graph;
34pub mod node;
35mod scalar;
36pub mod simplify;
37pub mod sparsity;
38pub mod wgsl;
39
40pub use graph::ExprGraph;
41pub use node::{ExprId, Node};
42
43use std::cell::RefCell;
44
thread_local! {
    /// The expression graph that [`with_graph`] hands out on this thread.
    ///
    /// Each thread has its own graph, initialised lazily with `ExprGraph::new()`
    /// on first access. [`trace`] temporarily swaps the contents out so that each
    /// trace records into a fresh graph. The `RefCell` enforces that at most one
    /// mutable borrow (one `with_graph` call) is live at a time.
    static GRAPH: RefCell<ExprGraph> = RefCell::new(ExprGraph::new());
}
48
49pub fn with_graph<F, R>(f: F) -> R
51where
52 F: FnOnce(&mut ExprGraph) -> R,
53{
54 GRAPH.with(|g| f(&mut g.borrow_mut()))
55}
56
57pub fn trace<F, R>(f: F) -> (ExprGraph, R)
62where
63 F: FnOnce() -> R,
64{
65 GRAPH.with(|g| {
67 let old = std::mem::take(&mut *g.borrow_mut());
68 let result = f();
69 let graph = std::mem::replace(&mut *g.borrow_mut(), old);
70 (graph, result)
71 })
72}
73
74impl ExprId {
75 #[inline]
77 pub fn var(n: u16) -> Self {
78 with_graph(|g| g.var(n))
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use tang::{Scalar, Vec3};

    /// Tracing `Vec3::dot` yields a graph that evaluates to the ordinary dot
    /// product of the two input vectors.
    #[test]
    fn trace_vec3_dot() {
        let (g, dot) = trace(|| {
            let a = Vec3::new(ExprId::var(0), ExprId::var(1), ExprId::var(2));
            let b = Vec3::new(ExprId::var(3), ExprId::var(4), ExprId::var(5));
            a.dot(b)
        });

        // 1*4 + 2*5 + 3*6 = 32
        let result: f64 = g.eval(dot, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert!((result - 32.0).abs() < 1e-10);
    }

    /// Tracing `Vec3::norm` yields a graph computing the Euclidean length.
    #[test]
    fn trace_vec3_norm() {
        let (g, norm) = trace(|| {
            let v = Vec3::new(ExprId::var(0), ExprId::var(1), ExprId::var(2));
            v.norm()
        });

        // |(3, 4, 0)| = 5
        let result: f64 = g.eval(norm, &[3.0, 4.0, 0.0]);
        assert!((result - 5.0).abs() < 1e-10);
    }

    /// Every `trace` starts from an empty graph. The first trace creates two
    /// variables and the second only one. If nodes leaked from the first trace
    /// into the second, the second graph could not be the smaller one.
    /// (Equal node sets would not detect a leak if the graph deduplicates nodes.)
    #[test]
    fn trace_isolation() {
        let (g1, _) = trace(|| {
            let _x = ExprId::var(0);
            let _y = ExprId::var(1);
        });
        let (g2, _) = trace(|| {
            let _x = ExprId::var(0);
        });
        assert!(
            g2.len() < g1.len(),
            "second trace saw nodes from the first: {} vs {}",
            g2.len(),
            g1.len()
        );
    }

    /// A nested `trace` records into its own graph and hands the outer graph
    /// back afterwards, so ids created before the inner trace stay valid.
    #[test]
    fn trace_nested_restores_outer_graph() {
        let (outer, (inner, inner_y, x)) = trace(|| {
            let x = ExprId::var(0);
            let (inner, y) = trace(|| ExprId::var(1));
            (inner, y, x)
        });
        assert_eq!(outer.eval::<f64>(x, &[5.0]), 5.0);
        assert_eq!(inner.eval::<f64>(inner_y, &[0.0, 7.0]), 7.0);
    }

    /// `ExprId::from_f64` (via the `Scalar` impl) records a literal node.
    #[test]
    fn from_f64_creates_lit() {
        let (g, v) = trace(|| ExprId::from_f64(42.0));
        let result: f64 = g.eval(v, &[]);
        assert!((result - 42.0).abs() < 1e-10);
    }

    /// Small integer constants round-trip exactly through the graph.
    #[test]
    fn scalar_constants() {
        let (g, (zero, one, two)) = trace(|| {
            let z: ExprId = Scalar::from_f64(0.0);
            let o: ExprId = Scalar::from_f64(1.0);
            let t: ExprId = Scalar::from_f64(2.0);
            (z, o, t)
        });
        assert_eq!(g.eval::<f64>(zero, &[]), 0.0);
        assert_eq!(g.eval::<f64>(one, &[]), 1.0);
        assert_eq!(g.eval::<f64>(two, &[]), 2.0);
    }
}