Skip to main content

tang_expr/
lib.rs

1//! tang-expr — RISC expression graph for symbolic computation.
2//!
3//! Builds computation graphs from generic `Scalar` code via a thread-local
4//! graph. Enables symbolic differentiation, sparsity detection, simplification,
5//! and multi-backend compilation (CPU closures, WGSL shaders).
6//!
7//! # Quick start
8//!
9//! ```
10//! use tang::Vec3;
11//! use tang_expr::{trace, ExprId};
12//!
13//! let (mut g, dot) = trace(|| {
14//!     let a = Vec3::new(ExprId::var(0), ExprId::var(1), ExprId::var(2));
15//!     let b = Vec3::new(ExprId::var(3), ExprId::var(4), ExprId::var(5));
16//!     a.dot(b)
17//! });
18//!
19//! // Evaluate with concrete values
20//! let result: f64 = g.eval(dot, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
21//! assert!((result - 32.0).abs() < 1e-10);
22//!
23//! // Symbolic differentiation
24//! let ddot_dx0 = g.diff(dot, 0);
25//! let ddot_dx0 = g.simplify(ddot_dx0);
26//! ```
27
28pub mod codegen;
29pub mod compile;
30pub mod diff;
31pub mod display;
32pub mod eval;
33pub mod graph;
34pub mod node;
35mod scalar;
36pub mod simplify;
37pub mod sparsity;
38pub mod wgsl;
39
40pub use graph::ExprGraph;
41pub use node::{ExprId, Node};
42
43use std::cell::RefCell;
44
45thread_local! {
46    static GRAPH: RefCell<ExprGraph> = RefCell::new(ExprGraph::new());
47}
48
49/// Access the thread-local graph.
50pub fn with_graph<F, R>(f: F) -> R
51where
52    F: FnOnce(&mut ExprGraph) -> R,
53{
54    GRAPH.with(|g| f(&mut g.borrow_mut()))
55}
56
57/// Run a closure with a fresh graph, returning the graph and result.
58///
59/// Installs a new empty graph, runs `f` (which builds the expression via
60/// `ExprId` arithmetic / `Scalar` calls), then extracts the graph.
61pub fn trace<F, R>(f: F) -> (ExprGraph, R)
62where
63    F: FnOnce() -> R,
64{
65    // Swap in a fresh graph
66    GRAPH.with(|g| {
67        let old = std::mem::take(&mut *g.borrow_mut());
68        let result = f();
69        let graph = std::mem::replace(&mut *g.borrow_mut(), old);
70        (graph, result)
71    })
72}
73
74impl ExprId {
75    /// Create a variable node in the thread-local graph.
76    #[inline]
77    pub fn var(n: u16) -> Self {
78        with_graph(|g| g.var(n))
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use tang::{Scalar, Vec3};

    #[test]
    fn trace_vec3_dot() {
        let (g, dot) = trace(|| {
            let a = Vec3::new(ExprId::var(0), ExprId::var(1), ExprId::var(2));
            let b = Vec3::new(ExprId::var(3), ExprId::var(4), ExprId::var(5));
            a.dot(b)
        });

        // Evaluate: [1,2,3] . [4,5,6] = 4+10+18 = 32
        let result: f64 = g.eval(dot, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        assert!((result - 32.0).abs() < 1e-10);
    }

    #[test]
    fn trace_vec3_norm() {
        let (g, norm) = trace(|| {
            let v = Vec3::new(ExprId::var(0), ExprId::var(1), ExprId::var(2));
            v.norm()
        });

        // norm([3,4,0]) = 5
        let result: f64 = g.eval(norm, &[3.0, 4.0, 0.0]);
        assert!((result - 5.0).abs() < 1e-10);
    }

    #[test]
    fn trace_isolation() {
        // Traces should be isolated
        let (g1, _) = trace(|| {
            let _x = ExprId::var(0);
        });
        let (g2, _) = trace(|| {
            let _x = ExprId::var(0);
        });
        // Both graphs should have same size (3 pre-populated + 1 var)
        assert_eq!(g1.len(), g2.len());
    }

    #[test]
    fn trace_nested_restores_outer_graph() {
        // A trace started inside another trace must record only into its own
        // graph, and the enclosing trace's graph must be active again once the
        // inner trace returns.
        let (outer, inner) = trace(|| {
            let _a = ExprId::var(0);
            let (inner, _) = trace(|| ExprId::var(1));
            let _b = ExprId::var(2);
            inner
        });
        // Both graphs start from the same fresh state: the outer one gained two
        // distinct variables, the inner one gained one.
        assert_eq!(outer.len(), inner.len() + 1);
    }

    #[test]
    fn from_f64_creates_lit() {
        let (g, v) = trace(|| ExprId::from_f64(42.0));
        let result: f64 = g.eval(v, &[]);
        assert!((result - 42.0).abs() < 1e-10);
    }

    #[test]
    fn scalar_constants() {
        let (g, (zero, one, two)) = trace(|| {
            let z: ExprId = Scalar::from_f64(0.0);
            let o: ExprId = Scalar::from_f64(1.0);
            let t: ExprId = Scalar::from_f64(2.0);
            (z, o, t)
        });
        assert_eq!(g.eval::<f64>(zero, &[]), 0.0);
        assert_eq!(g.eval::<f64>(one, &[]), 1.0);
        assert_eq!(g.eval::<f64>(two, &[]), 2.0);
    }
}