//! Expression-graph automatic differentiation (tml_utils/autodiff.rs).

use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::{cell::RefCell, rc::Rc};

use crate::Float;
/// Identifies a node inside a specific [`ExprGraph`].
///
/// The `graph_id` component ties the id to the graph that created it, so ids
/// from different graphs can never be confused for one another.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId {
    index: usize,
    graph_id: u64,
}

impl NodeId {
    /// Builds an id for node `index` owned by the graph identified by `graph_id`.
    fn new(index: usize, graph_id: u64) -> Self {
        NodeId { index, graph_id }
    }
}
19
// Monotonically increasing source of unique graph ids; `ExprGraph::new` takes
// the next value so NodeIds can be checked against their owning graph.
static NEXT_GRAPH_ID: AtomicU64 = AtomicU64::new(1);
21
/// Expression graph with optimized performance.
/// Forward evaluation is pure; reuse an [`EvalTape`] to cache intermediates explicitly.
#[derive(Debug)]
pub struct ExprGraph {
    // Process-unique id used to validate that NodeIds belong to this graph.
    graph_id: u64,
    // All nodes in definition (topological) order; `NodeId.index` indexes here.
    nodes: Vec<Node>,
    // Input name -> input node id; used to reject duplicate input names.
    node_map: HashMap<String, NodeId>,
    // Input node ids in definition order.
    inputs: Vec<NodeId>,
    // Input names, parallel to `inputs`.
    input_names: Vec<String>,
    // Output node ids in definition order.
    outputs: Vec<NodeId>,
    // Largest operand count of any operation; sizes tape scratch buffers.
    max_arity: usize,
    // Index assigned to the next node (kept equal to `nodes.len()`).
    next_id: usize,
}
35
/// Node in the computation graph
#[derive(Debug, Clone)]
pub enum Node {
    /// Named external input; its value is supplied at evaluation time.
    Input(String),
    /// Fixed scalar constant.
    Const(Float),
    /// Result of applying an [`Op`] to previously defined nodes.
    AfterOperation(Op, Box<[NodeId]>),
    /// Marks another node's value as a graph output.
    Output(NodeId),
}
44
/// Operations that can be performed on nodes
#[derive(Debug, Clone, Copy)]
pub enum Op {
    /// Multiply by a fixed factor (unary).
    Scale(Float),
    /// Sine (unary).
    Sin,
    /// Cosine (unary).
    Cos,
    /// Integer power (unary).
    Pow(i32),
    /// Sum of all operands (n-ary, at least two).
    Add,
    /// Product of all operands (n-ary, at least two).
    Mul,
}
55
56/// Workspace that stores intermediate primals and gradient vectors during evaluation.
57/// Reuse it across calls to avoid repeated allocations when performance matters.
58#[derive(Debug, Default)]
59pub struct EvalTape {
60    primals: Vec<Float>,
61    tangents: Vec<Float>,
62    input_count: usize,
63    scratch_primals: Vec<Float>,
64    scratch_partials: Vec<Float>,
65}
66
67impl EvalTape {
68    pub fn new() -> Self {
69        Self::default()
70    }
71
72    pub fn with_capacity(nodes: usize, input_count: usize, max_arity: usize) -> Self {
73        Self {
74            primals: Vec::with_capacity(nodes),
75            tangents: Vec::with_capacity(nodes * input_count),
76            input_count,
77            scratch_primals: Vec::with_capacity(max_arity),
78            scratch_partials: Vec::with_capacity(max_arity),
79        }
80    }
81
82    fn reset(&mut self, nodes: usize, input_count: usize, max_arity: usize) {
83        self.input_count = input_count;
84        self.primals.clear();
85        self.tangents.clear();
86        self.primals.resize(nodes, 0.0);
87        self.tangents.resize(nodes * input_count, 0.0);
88        self.scratch_primals.clear();
89        self.scratch_partials.clear();
90        self.scratch_primals.resize(max_arity, 0.0);
91        self.scratch_partials.resize(max_arity, 0.0);
92    }
93
94    fn tangent_index(&self, node_idx: usize, input_idx: usize) -> usize {
95        node_idx * self.input_count + input_idx
96    }
97}
98
99/// Workspace that stores intermediate primals and adjoints during reverse-mode evaluation.
100#[derive(Debug, Default)]
101pub struct ReverseTape {
102    primals: Vec<Float>,
103    adjoints: Vec<Float>,
104    scratch_primals: Vec<Float>,
105    scratch_partials: Vec<Float>,
106}
107
108impl ReverseTape {
109    pub fn new() -> Self {
110        Self::default()
111    }
112
113    pub fn with_capacity(nodes: usize, max_arity: usize) -> Self {
114        Self {
115            primals: Vec::with_capacity(nodes),
116            adjoints: Vec::with_capacity(nodes),
117            scratch_primals: Vec::with_capacity(max_arity),
118            scratch_partials: Vec::with_capacity(max_arity),
119        }
120    }
121
122    fn reset(&mut self, nodes: usize, max_arity: usize) {
123        self.primals.clear();
124        self.adjoints.clear();
125        self.primals.resize(nodes, 0.0);
126        self.adjoints.resize(nodes, 0.0);
127        self.scratch_primals.clear();
128        self.scratch_partials.clear();
129        self.scratch_primals.resize(max_arity, 0.0);
130        self.scratch_partials.resize(max_arity, 0.0);
131    }
132}
133
134impl Op {
135    fn validate_arity(self, inputs_len: usize) {
136        let ok = match self {
137            Op::Scale(_) | Op::Sin | Op::Cos | Op::Pow(_) => inputs_len == 1,
138            Op::Add | Op::Mul => inputs_len >= 2,
139        };
140
141        assert!(
142            ok,
143            "invalid arity for {:?}: expected {}, got {}",
144            self,
145            match self {
146                Op::Scale(_) | Op::Sin | Op::Cos | Op::Pow(_) => "1",
147                Op::Add | Op::Mul => ">= 2",
148            },
149            inputs_len
150        );
151    }
152
153    fn apply(self, inputs: &[Float]) -> Float {
154        match self {
155            Op::Scale(factor) => inputs[0] * factor,
156            Op::Sin => inputs[0].sin(),
157            Op::Cos => inputs[0].cos(),
158            Op::Pow(exp) => inputs[0].powi(exp),
159            Op::Add => inputs.iter().sum(),
160            Op::Mul => inputs.iter().product(),
161        }
162    }
163
164    fn compute_derivative(self, inputs: &[Float], input_idx: usize) -> Float {
165        match self {
166            Op::Scale(factor) => factor,
167            Op::Sin => inputs[0].cos(),
168            Op::Cos => -inputs[0].sin(),
169            Op::Pow(exp) => {
170                if exp == 0 {
171                    0.0
172                } else {
173                    exp as Float * inputs[0].powi(exp - 1)
174                }
175            }
176            Op::Add => 1.0,
177            Op::Mul => inputs
178                .iter()
179                .enumerate()
180                .filter(|(i, _)| *i != input_idx)
181                .map(|(_, &x)| x)
182                .product(),
183        }
184    }
185}
186
187impl ExprGraph {
188    pub fn new() -> Self {
189        Self {
190            graph_id: NEXT_GRAPH_ID.fetch_add(1, Ordering::Relaxed),
191            nodes: Vec::new(),
192            node_map: HashMap::new(),
193            inputs: Vec::new(),
194            input_names: Vec::new(),
195            outputs: Vec::new(),
196            max_arity: 0,
197            next_id: 0,
198        }
199    }
200
201    fn make_node_id(&self, index: usize) -> NodeId {
202        NodeId::new(index, self.graph_id)
203    }
204
205    fn is_valid_node(&self, id: NodeId) -> bool {
206        id.graph_id == self.graph_id && id.index < self.next_id
207    }
208
209    fn assert_valid_node(&self, id: NodeId, context: &str) {
210        assert!(
211            self.is_valid_node(id),
212            "{context} does not belong to this graph or is out of bounds"
213        );
214    }
215
216    pub fn input(&mut self, name: String) -> NodeId {
217        assert!(
218            !self.node_map.contains_key(&name),
219            "input name already exists: {name}"
220        );
221
222        let id = self.make_node_id(self.next_id);
223        self.next_id += 1;
224        self.nodes.push(Node::Input(name.clone()));
225        self.node_map.insert(name.clone(), id);
226        self.inputs.push(id);
227        self.input_names.push(name);
228        id
229    }
230
231    pub fn constant(&mut self, value: Float) -> NodeId {
232        let id = self.make_node_id(self.next_id);
233        self.next_id += 1;
234        self.nodes.push(Node::Const(value));
235        id
236    }
237
238    pub fn operation<I>(&mut self, op: Op, inputs: I) -> NodeId
239    where
240        I: AsRef<[NodeId]>,
241    {
242        let inputs_ref = inputs.as_ref();
243        op.validate_arity(inputs_ref.len());
244        assert!(
245            inputs_ref.iter().all(|id| self.is_valid_node(*id)),
246            "operation inputs must reference earlier nodes in the same graph"
247        );
248        self.max_arity = self.max_arity.max(inputs_ref.len());
249        let id = self.make_node_id(self.next_id);
250        self.next_id += 1;
251        self.nodes
252            .push(Node::AfterOperation(op, Box::from(inputs_ref)));
253        id
254    }
255
256    pub fn output(&mut self, node: NodeId) -> NodeId {
257        self.assert_valid_node(node, "output node");
258        let id = self.make_node_id(self.next_id);
259        self.next_id += 1;
260        self.nodes.push(Node::Output(node));
261        self.outputs.push(id);
262        id
263    }
264
265    /// Allocate a forward-mode tape sized for this graph.
266    /// Reuse it to avoid allocations between runs.
267    pub fn fwd_tape(&self) -> EvalTape {
268        EvalTape::with_capacity(self.nodes.len(), self.inputs.len(), self.max_arity)
269    }
270
271    /// Allocate a reverse-mode tape sized for this graph.
272    pub fn tape(&self) -> ReverseTape {
273        self.reverse_tape()
274    }
275
276    pub fn reverse_tape(&self) -> ReverseTape {
277        ReverseTape::with_capacity(self.nodes.len(), self.max_arity)
278    }
279
280    pub fn input_names(&self) -> &[String] {
281        &self.input_names
282    }
283
284    /// Pure forward evaluation that allocates its own tape. Suitable for single-shot calls.
285    /// Returns a value and per-input gradient vector for each output.
286    pub fn eval_fwd(&self, inputs: &[Float]) -> Vec<(Float, Vec<Float>)> {
287        let mut tape = self.fwd_tape();
288        self.eval_fwd_with_tape(inputs, &mut tape)
289    }
290
291    /// Forward evaluation that reuses the provided tape to cache intermediates.
292    /// Returns a value and per-input gradient vector for each output.
293    pub fn eval_fwd_with_tape(
294        &self,
295        inputs: &[Float],
296        tape: &mut EvalTape,
297    ) -> Vec<(Float, Vec<Float>)> {
298        assert_eq!(
299            inputs.len(),
300            self.inputs.len(),
301            "expected {} inputs, got {}",
302            self.inputs.len(),
303            inputs.len()
304        );
305
306        tape.reset(self.nodes.len(), self.inputs.len(), self.max_arity);
307
308        // First pass: handle inputs (ordered by definition)
309        for (input_idx, node_id) in self.inputs.iter().enumerate() {
310            let node_idx = node_id.index;
311            tape.primals[node_idx] = inputs[input_idx];
312            let tangent_idx = tape.tangent_index(node_idx, input_idx);
313            tape.tangents[tangent_idx] = 1.0;
314        }
315
316        // Second pass: handle operations (topological order)
317        for (i, node) in self.nodes.iter().enumerate() {
318            match node {
319                Node::AfterOperation(op, inputs) => {
320                    let arity = inputs.len();
321                    let input_primals = &mut tape.scratch_primals[..arity];
322                    for (slot, &id) in input_primals.iter_mut().zip(inputs.iter()) {
323                        *slot = tape.primals[id.index];
324                    }
325
326                    tape.primals[i] = op.apply(input_primals);
327
328                    // Compute derivatives using chain rule for each input dimension
329                    let partials = &mut tape.scratch_partials[..arity];
330                    for (j, partial) in partials.iter_mut().enumerate() {
331                        *partial = op.compute_derivative(input_primals, j);
332                    }
333
334                    let input_count = tape.input_count;
335                    let tangents = &mut tape.tangents;
336                    for input_dim in 0..input_count {
337                        let mut total = 0.0;
338                        for (j, &input_id) in inputs.iter().enumerate() {
339                            let idx = input_id.index * input_count + input_dim;
340                            total += tangents[idx] * partials[j];
341                        }
342                        let out_idx = i * input_count + input_dim;
343                        tangents[out_idx] = total;
344                    }
345                }
346                Node::Const(value) => {
347                    tape.primals[i] = *value;
348                }
349                _ => {}
350            }
351        }
352
353        // Third pass: handle outputs
354        for (i, node) in self.nodes.iter().enumerate() {
355            if let Node::Output(input_id) = node {
356                tape.primals[i] = tape.primals[input_id.index];
357                let src_start = tape.tangent_index(input_id.index, 0);
358                let dst_start = tape.tangent_index(i, 0);
359                let len = tape.input_count;
360                tape.tangents
361                    .copy_within(src_start..(src_start + len), dst_start);
362            }
363        }
364
365        self.outputs
366            .iter()
367            .map(|id| {
368                let idx = id.index;
369                let start = tape.tangent_index(idx, 0);
370                let end = start + tape.input_count;
371                (tape.primals[idx], tape.tangents[start..end].to_vec())
372            })
373            .collect()
374    }
375
376    pub fn eval_fwd_one(&self, inputs: &[Float]) -> (Float, Vec<Float>) {
377        let mut tape = self.fwd_tape();
378        self.eval_fwd_one_with_tape(inputs, &mut tape)
379    }
380
381    pub fn eval_fwd_one_with_tape(
382        &self,
383        inputs: &[Float],
384        tape: &mut EvalTape,
385    ) -> (Float, Vec<Float>) {
386        let mut outputs = self.eval_fwd_with_tape(inputs, tape);
387        assert!(
388            outputs.len() == 1,
389            "expected a single output, got {}",
390            outputs.len()
391        );
392        outputs.remove(0)
393    }
394
395    pub fn eval_fwd_named(&self, inputs: &[Float]) -> Vec<(Float, Vec<(String, Float)>)> {
396        let mut tape = self.fwd_tape();
397        self.eval_fwd_named_with_tape(inputs, &mut tape)
398    }
399
400    pub fn eval_fwd_named_with_tape(
401        &self,
402        inputs: &[Float],
403        tape: &mut EvalTape,
404    ) -> Vec<(Float, Vec<(String, Float)>)> {
405        let outputs = self.eval_fwd_with_tape(inputs, tape);
406        outputs
407            .into_iter()
408            .map(|(value, grads)| {
409                let named = self
410                    .input_names
411                    .iter()
412                    .cloned()
413                    .zip(grads)
414                    .collect::<Vec<_>>();
415                (value, named)
416            })
417            .collect()
418    }
419
420    /// Reverse-mode evaluation that allocates its own tape. Suitable for single-shot calls.
421    /// Returns a value and per-input gradient vector for each output.
422    pub fn eval(&self, inputs: &[Float]) -> Vec<(Float, Vec<Float>)> {
423        let mut tape = self.reverse_tape();
424        self.eval_with_tape(inputs, &mut tape)
425    }
426
427    /// Reverse-mode evaluation that reuses the provided tape to cache intermediates.
428    /// Returns a value and per-input gradient vector for each output.
429    pub fn eval_with_tape(
430        &self,
431        inputs: &[Float],
432        tape: &mut ReverseTape,
433    ) -> Vec<(Float, Vec<Float>)> {
434        self.eval_for_with_tape(inputs, &self.outputs, tape)
435    }
436
437    /// Reverse-mode evaluation for a selected set of outputs.
438    pub fn eval_for(&self, inputs: &[Float], outputs: &[NodeId]) -> Vec<(Float, Vec<Float>)> {
439        let mut tape = self.reverse_tape();
440        self.eval_for_with_tape(inputs, outputs, &mut tape)
441    }
442
443    /// Reverse-mode evaluation for a selected set of outputs with a reusable tape.
444    pub fn eval_for_with_tape(
445        &self,
446        inputs: &[Float],
447        outputs: &[NodeId],
448        tape: &mut ReverseTape,
449    ) -> Vec<(Float, Vec<Float>)> {
450        assert_eq!(
451            inputs.len(),
452            self.inputs.len(),
453            "expected {} inputs, got {}",
454            self.inputs.len(),
455            inputs.len()
456        );
457        for &output in outputs {
458            self.assert_valid_node(output, "requested output");
459        }
460
461        tape.reset(self.nodes.len(), self.max_arity);
462
463        // Forward primals
464        for (input_idx, node_id) in self.inputs.iter().enumerate() {
465            tape.primals[node_id.index] = inputs[input_idx];
466        }
467
468        for (i, node) in self.nodes.iter().enumerate() {
469            match node {
470                Node::AfterOperation(op, inputs) => {
471                    let arity = inputs.len();
472                    let input_primals = &mut tape.scratch_primals[..arity];
473                    for (slot, &id) in input_primals.iter_mut().zip(inputs.iter()) {
474                        *slot = tape.primals[id.index];
475                    }
476                    tape.primals[i] = op.apply(input_primals);
477                }
478                Node::Output(input_id) => {
479                    tape.primals[i] = tape.primals[input_id.index];
480                }
481                Node::Const(value) => {
482                    tape.primals[i] = *value;
483                }
484                Node::Input(_) => {}
485            }
486        }
487
488        let mut results = Vec::with_capacity(outputs.len());
489
490        for output_id in outputs {
491            tape.adjoints.fill(0.0);
492            tape.adjoints[output_id.index] = 1.0;
493
494            for (i, node) in self.nodes.iter().enumerate().rev() {
495                match node {
496                    Node::Output(input_id) => {
497                        tape.adjoints[input_id.index] += tape.adjoints[i];
498                    }
499                    Node::AfterOperation(op, inputs) => {
500                        let arity = inputs.len();
501                        let input_primals = &mut tape.scratch_primals[..arity];
502                        for (slot, &id) in input_primals.iter_mut().zip(inputs.iter()) {
503                            *slot = tape.primals[id.index];
504                        }
505
506                        let partials = &mut tape.scratch_partials[..arity];
507                        for (j, partial) in partials.iter_mut().enumerate() {
508                            *partial = op.compute_derivative(input_primals, j);
509                        }
510
511                        let adj = tape.adjoints[i];
512                        if adj != 0.0 {
513                            for (j, &input_id) in inputs.iter().enumerate() {
514                                tape.adjoints[input_id.index] += adj * partials[j];
515                            }
516                        }
517                    }
518                    Node::Const(_) | Node::Input(_) => {}
519                }
520            }
521
522            let grads = self
523                .inputs
524                .iter()
525                .map(|id| tape.adjoints[id.index])
526                .collect::<Vec<_>>();
527            results.push((tape.primals[output_id.index], grads));
528        }
529
530        results
531    }
532
533    pub fn eval_one(&self, inputs: &[Float]) -> (Float, Vec<Float>) {
534        let mut tape = self.reverse_tape();
535        self.eval_one_with_tape(inputs, &mut tape)
536    }
537
538    pub fn eval_one_with_tape(
539        &self,
540        inputs: &[Float],
541        tape: &mut ReverseTape,
542    ) -> (Float, Vec<Float>) {
543        let mut outputs = self.eval_with_tape(inputs, tape);
544        assert!(
545            outputs.len() == 1,
546            "expected a single output, got {}",
547            outputs.len()
548        );
549        outputs.remove(0)
550    }
551
552    pub fn eval_named(&self, inputs: &[Float]) -> Vec<(Float, Vec<(String, Float)>)> {
553        let mut tape = self.reverse_tape();
554        self.eval_named_with_tape(inputs, &mut tape)
555    }
556
557    pub fn eval_named_with_tape(
558        &self,
559        inputs: &[Float],
560        tape: &mut ReverseTape,
561    ) -> Vec<(Float, Vec<(String, Float)>)> {
562        let outputs = self.eval_with_tape(inputs, tape);
563        outputs
564            .into_iter()
565            .map(|(value, grads)| {
566                let named = self
567                    .input_names
568                    .iter()
569                    .cloned()
570                    .zip(grads)
571                    .collect::<Vec<_>>();
572                (value, named)
573            })
574            .collect()
575    }
576
577    pub fn eval_named_for(
578        &self,
579        inputs: &[Float],
580        outputs: &[NodeId],
581    ) -> Vec<(Float, Vec<(String, Float)>)> {
582        let mut tape = self.reverse_tape();
583        self.eval_named_for_with_tape(inputs, outputs, &mut tape)
584    }
585
586    pub fn eval_named_for_with_tape(
587        &self,
588        inputs: &[Float],
589        outputs: &[NodeId],
590        tape: &mut ReverseTape,
591    ) -> Vec<(Float, Vec<(String, Float)>)> {
592        let outputs = self.eval_for_with_tape(inputs, outputs, tape);
593        outputs
594            .into_iter()
595            .map(|(value, grads)| {
596                let named = self
597                    .input_names
598                    .iter()
599                    .cloned()
600                    .zip(grads)
601                    .collect::<Vec<_>>();
602                (value, named)
603            })
604            .collect()
605    }
606}
607
608impl Default for ExprGraph {
609    fn default() -> Self {
610        Self::new()
611    }
612}
613
614#[derive(Debug, Clone)]
615pub struct Gradients {
616    pub value: Float,
617    pub grads: Vec<(String, Float)>,
618}
619
620impl Gradients {
621    pub fn get(&self, name: &str) -> Option<Float> {
622        self.grads
623            .iter()
624            .find_map(|(key, value)| (key == name).then_some(*value))
625    }
626}
627
/// Errors reported by the fallible [`Tape`] setters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TapeError {
    /// The provided value slice does not match the number of registered inputs.
    InputLengthMismatch { expected: usize, got: usize },
    /// No input with the given name exists on the tape.
    UnknownInput(String),
}

impl std::fmt::Display for TapeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TapeError::InputLengthMismatch { expected, got } => {
                write!(f, "expected {expected} inputs, got {got}")
            }
            TapeError::UnknownInput(name) => write!(f, "unknown input name: {name}"),
        }
    }
}

impl std::error::Error for TapeError {}
646
/// Rust-like autodiff tape with operator overloading.
#[derive(Debug, Clone)]
pub struct Tape {
    // Shared with every Var created on this tape; Rc identity is what
    // `assert_same_tape` uses to detect cross-tape mixing.
    inner: Rc<RefCell<TapeInner>>,
}

/// Shared state behind a [`Tape`]: the graph plus the current input values.
#[derive(Debug)]
struct TapeInner {
    graph: ExprGraph,
    // Current value of each input, parallel to the graph's input order.
    values: Vec<Float>,
}

/// A node handle tied to a [`Tape`].
#[derive(Debug, Clone)]
pub struct Var {
    id: NodeId,
    inner: Rc<RefCell<TapeInner>>,
}
665
666impl Tape {
667    pub fn new() -> Self {
668        Self {
669            inner: Rc::new(RefCell::new(TapeInner {
670                graph: ExprGraph::new(),
671                values: Vec::new(),
672            })),
673        }
674    }
675
676    pub fn input(&mut self, name: impl Into<String>, value: Float) -> Var {
677        let mut inner = self.inner.borrow_mut();
678        let id = inner.graph.input(name.into());
679        inner.values.push(value);
680        Var {
681            id,
682            inner: self.inner.clone(),
683        }
684    }
685
686    pub fn input_unnamed(&mut self, value: Float) -> Var {
687        let idx = self.inner.borrow().values.len();
688        self.input(format!("_{}", idx), value)
689    }
690
691    pub fn constant(&mut self, value: Float) -> Var {
692        let mut inner = self.inner.borrow_mut();
693        let id = inner.graph.constant(value);
694        Var {
695            id,
696            inner: self.inner.clone(),
697        }
698    }
699
700    pub fn set_inputs(&mut self, values: &[Float]) {
701        self.try_set_inputs(values)
702            .expect("input length mismatch for Tape::set_inputs");
703    }
704
705    pub fn try_set_inputs(&mut self, values: &[Float]) -> Result<(), TapeError> {
706        let mut inner = self.inner.borrow_mut();
707        let expected = inner.values.len();
708        if values.len() != expected {
709            return Err(TapeError::InputLengthMismatch {
710                expected,
711                got: values.len(),
712            });
713        }
714        inner.values.copy_from_slice(values);
715        Ok(())
716    }
717
718    pub fn set(&mut self, name: &str, value: Float) {
719        self.try_set(name, value)
720            .expect("unknown input name for Tape::set");
721    }
722
723    pub fn try_set(&mut self, name: &str, value: Float) -> Result<(), TapeError> {
724        let mut inner = self.inner.borrow_mut();
725        let Some(idx) = inner.graph.input_names.iter().position(|n| n == name) else {
726            return Err(TapeError::UnknownInput(name.to_string()));
727        };
728        inner.values[idx] = value;
729        Ok(())
730    }
731
732    pub fn input_names(&self) -> Vec<String> {
733        self.inner.borrow().graph.input_names.clone()
734    }
735
736    pub fn gradients(&self, output: &Var) -> Gradients {
737        output.assert_same_tape(self);
738        let inner = self.inner.borrow();
739        let results = inner.graph.eval_named_for(&inner.values, &[output.id]);
740        let (value, grads) = results.into_iter().next().expect("missing output");
741        Gradients { value, grads }
742    }
743
744    pub fn gradients_for(&self, outputs: &[Var]) -> Vec<Gradients> {
745        if outputs.is_empty() {
746            return Vec::new();
747        }
748        outputs[0].assert_same_tape(self);
749        for var in outputs.iter().skip(1) {
750            var.assert_same_tape(self);
751        }
752
753        let inner = self.inner.borrow();
754        let ids = outputs.iter().map(|var| var.id).collect::<Vec<_>>();
755        inner
756            .graph
757            .eval_named_for(&inner.values, &ids)
758            .into_iter()
759            .map(|(value, grads)| Gradients { value, grads })
760            .collect()
761    }
762}
763
764impl Default for Tape {
765    fn default() -> Self {
766        Self::new()
767    }
768}
769
770impl Var {
771    fn assert_same_tape(&self, tape: &Tape) {
772        assert!(
773            Rc::ptr_eq(&self.inner, &tape.inner),
774            "cannot mix Vars from different tapes"
775        );
776    }
777
778    fn assert_same_var_tape(&self, other: &Var) {
779        assert!(
780            Rc::ptr_eq(&self.inner, &other.inner),
781            "cannot mix Vars from different tapes"
782        );
783    }
784
785    fn unary_op(&self, op: Op) -> Var {
786        let mut inner = self.inner.borrow_mut();
787        let id = inner.graph.operation(op, vec![self.id]);
788        Var {
789            id,
790            inner: self.inner.clone(),
791        }
792    }
793
794    fn binary_op(&self, rhs: &Var, op: Op) -> Var {
795        self.assert_same_var_tape(rhs);
796        let mut inner = self.inner.borrow_mut();
797        let id = inner.graph.operation(op, vec![self.id, rhs.id]);
798        Var {
799            id,
800            inner: self.inner.clone(),
801        }
802    }
803
804    fn konst(&self, value: Float) -> Var {
805        let mut inner = self.inner.borrow_mut();
806        let id = inner.graph.constant(value);
807        Var {
808            id,
809            inner: self.inner.clone(),
810        }
811    }
812
813    pub fn sin(&self) -> Var {
814        self.unary_op(Op::Sin)
815    }
816
817    pub fn cos(&self) -> Var {
818        self.unary_op(Op::Cos)
819    }
820
821    pub fn powi(&self, exp: i32) -> Var {
822        self.unary_op(Op::Pow(exp))
823    }
824
825    pub fn scale(&self, factor: Float) -> Var {
826        self.unary_op(Op::Scale(factor))
827    }
828}
829
// `var + var`: two-operand Add node.
impl std::ops::Add for Var {
    type Output = Var;
    fn add(self, rhs: Var) -> Self::Output {
        self.binary_op(&rhs, Op::Add)
    }
}

// `var + scalar`: materializes the scalar as a constant node, then adds.
impl std::ops::Add<Float> for Var {
    type Output = Var;
    fn add(self, rhs: Float) -> Self::Output {
        let rhs = self.konst(rhs);
        self.binary_op(&rhs, Op::Add)
    }
}

// `var - var`: expressed as addition of the negation (Scale(-1) + Add nodes).
impl std::ops::Sub for Var {
    type Output = Var;
    fn sub(self, rhs: Var) -> Self::Output {
        self + (-rhs)
    }
}

// `var - scalar`: negates the scalar, then uses `Add<Float>`.
impl std::ops::Sub<Float> for Var {
    type Output = Var;
    fn sub(self, rhs: Float) -> Self::Output {
        self + (-rhs)
    }
}

// `var * var`: two-operand Mul node.
impl std::ops::Mul for Var {
    type Output = Var;
    fn mul(self, rhs: Var) -> Self::Output {
        self.binary_op(&rhs, Op::Mul)
    }
}

// `var * scalar`: a single Scale node (no constant node needed).
impl std::ops::Mul<Float> for Var {
    type Output = Var;
    fn mul(self, rhs: Float) -> Self::Output {
        self.scale(rhs)
    }
}

// `var / var`: expressed as multiplication by the reciprocal (Pow(-1) + Mul).
impl std::ops::Div for Var {
    type Output = Var;
    fn div(self, rhs: Var) -> Self::Output {
        self * rhs.powi(-1)
    }
}

// `var / scalar`: a single Scale node by the reciprocal.
impl std::ops::Div<Float> for Var {
    type Output = Var;
    fn div(self, rhs: Float) -> Self::Output {
        self.scale(1.0 / rhs)
    }
}

// `-var`: a single Scale(-1) node.
impl std::ops::Neg for Var {
    type Output = Var;
    fn neg(self) -> Self::Output {
        self.scale(-1.0)
    }
}
893
894/// Macro for building differentiable expressions.
895///
896/// # Examples
897///
898/// Single input expression:
899/// ```rust,ignore
900/// let expr = expr! {
901///     input -> Sin -> Cos -> output
902/// };
903/// ```
904///
905/// Multi-input expression:
906/// ```rust,ignore
907/// let expr = expr! {
908///     inputs: [x, y]
909///     x -> Pow(2) -> @x_sq
910///     y -> Sin -> @y_sin
911///     (@x_sq, @y_sin) -> Add -> @result
912///     output @result
913/// };
914/// ```
915///
916/// Mixed expression (operations without intermediate names):
917/// ```rust,ignore
918/// let expr = expr! {
919///     inputs: [x, y]
920///     x -> Pow(2) -> @temp1
921///     y -> Cos -> @temp2
922///     (@temp1, @temp2) -> Mul -> @res
923///     output @res
924/// };
925/// ```
926///
927/// # Performance Notes
928///
929/// The default `eval` path allocates a fresh [`ReverseTape`] each call for purity.
930/// When you need to reuse buffers, create a tape with `expr.tape()` (or
931/// `expr.reverse_tape()`) and call `eval_with_tape` to keep allocations off the hot
932/// path. Operation arity is validated at runtime.
#[macro_export]
macro_rules! expr {
    // ---- Entry points --------------------------------------------------
    // NOTE: arm order matters throughout this macro — more specific arms
    // (including the compile_error! guards) must precede the generic ones,
    // because macro_rules! tries arms top to bottom.

    // Single-input expression: `expr! { input -> Sin -> ... -> output }`.
    // Creates one graph input named "input" and threads it through the
    // `@build_single` rules below.
    (input -> $($rest:tt)*) => {
        {
            use $crate::autodiff::{ExprGraph, Op};
            let mut graph = ExprGraph::new();
            let __input = graph.input("input".to_string());
            $crate::expr! {
                @build_single
                graph,
                __input,
                $($rest)*
            }
        }
    };

    // Multi-input expression: `expr! { inputs: [x, y] ... }`.
    // Each listed identifier becomes both a graph input (named after the
    // identifier) and a local binding usable in the `@build_multi` rules.
    (inputs: [$($input:ident),*] $($rest:tt)*) => {
        {
            use $crate::autodiff::{ExprGraph, Op};
            let mut graph = ExprGraph::new();
            $(let $input = graph.input(stringify!($input).to_string());)*
            $crate::expr! {
                @build_multi
                graph,
                $($rest)*
            }
        }
    };

    // ---- Single-input builder ------------------------------------------
    // Guards: n-ary ops make no sense in a single-chain pipeline, so reject
    // them at compile time with a pointer to the multi-input syntax.
    (@build_single $graph:ident, $node:ident, Add -> $($rest:tt)*) => {
        compile_error!("Add is n-ary; use `inputs: [...]` and `(@a, @b, ...) -> Add`");
    };

    (@build_single $graph:ident, $node:ident, Mul -> $($rest:tt)*) => {
        compile_error!("Mul is n-ary; use `inputs: [...]` and `(@a, @b, ...) -> Mul`");
    };

    // Unary op with no arguments, e.g. `-> Sin ->`.
    (@build_single $graph:ident, $node:ident, $op:ident -> $($rest:tt)*) => {
        let __next = $graph.operation(Op::$op, vec![$node]);
        $crate::expr! {
            @build_single
            $graph,
            __next,
            $($rest)*
        }
    };

    // Unary op with constructor arguments, e.g. `-> Scale(2.0) ->`.
    (@build_single $graph:ident, $node:ident, $op:ident ( $($op_args:tt)* ) -> $($rest:tt)*) => {
        let __next = $graph.operation(Op::$op($($op_args)*), vec![$node]);
        $crate::expr! {
            @build_single
            $graph,
            __next,
            $($rest)*
        }
    };

    // Terminal: mark the current node as the graph output and yield the graph.
    (@build_single $graph:ident, $node:ident, output) => {
        $graph.output($node);
        $graph
    };

    // ---- Multi-input builder -------------------------------------------
    // Guards: reject n-ary ops used on a single node (with or without args).
    (@build_multi $graph:ident, $node:ident -> Add -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Add is n-ary; use (@a, @b, ...) -> Add");
    };

    (@build_multi $graph:ident, $node:ident -> Mul -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Mul is n-ary; use (@a, @b, ...) -> Mul");
    };

    (@build_multi $graph:ident, $node:ident -> Add ( $($op_args:tt)* ) -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Add takes no arguments and is n-ary; use (@a, @b, ...) -> Add");
    };

    (@build_multi $graph:ident, $node:ident -> Mul ( $($op_args:tt)* ) -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Mul takes no arguments and is n-ary; use (@a, @b, ...) -> Mul");
    };

    // Unary op on a single node, binding the result: `x -> Sin -> @s`.
    (@build_multi $graph:ident, $node:ident -> $op:ident -> @ $result:ident $($rest:tt)*) => {
        let $result = $graph.operation(Op::$op, vec![$node]);
        $crate::expr! { @build_multi $graph, $($rest)* }
    };

    // Unary op with constructor arguments: `x -> Pow(2) -> @p`.
    (@build_multi $graph:ident, $node:ident -> $op:ident ( $($op_args:tt)* ) -> @ $result:ident $($rest:tt)*) => {
        let $result = $graph.operation(Op::$op($($op_args)*), vec![$node]);
        $crate::expr! { @build_multi $graph, $($rest)* }
    };

    // Reject unary ops in n-ary position
    (@build_multi $graph:ident, ( $( @ $node:ident ),+ ) -> Sin -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Sin is unary; use x -> Sin");
    };

    (@build_multi $graph:ident, ( $( @ $node:ident ),+ ) -> Cos -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Cos is unary; use x -> Cos");
    };

    (@build_multi $graph:ident, ( $( @ $node:ident ),+ ) -> Scale ( $($op_args:tt)* ) -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Scale is unary; use x -> Scale(factor)");
    };

    (@build_multi $graph:ident, ( $( @ $node:ident ),+ ) -> Pow ( $($op_args:tt)* ) -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Pow is unary; use x -> Pow(exp)");
    };

    // Guards: a one-element tuple is a degenerate use of an n-ary op.
    // Generic N-ary op without extra args: (@a, @b, @c) -> Add -> @result
    (@build_multi $graph:ident, ( @ $node:ident ) -> Add -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Add requires at least 2 inputs");
    };

    (@build_multi $graph:ident, ( @ $node:ident ) -> Mul -> @ $result:ident $($rest:tt)*) => {
        compile_error!("Mul requires at least 2 inputs");
    };

    // N-ary op over a tuple of previously bound nodes.
    (@build_multi $graph:ident, ( $( @ $node:ident ),+ ) -> $op:ident -> @ $result:ident $($rest:tt)*) => {
        let $result = $graph.operation(Op::$op, vec![$($node),+]);
        $crate::expr! { @build_multi $graph, $($rest)* }
    };

    // Generic N-ary op with extra args: (@a, @b, @c) -> scale(2.0) -> @res
    (@build_multi $graph:ident, ( $( @ $node:ident ),+ ) -> $op:ident ( $($op_args:tt)* ) -> @ $result:ident $($rest:tt)*) => {
        let $result = $graph.operation(Op::$op($($op_args)*), vec![$($node),+]);
        $crate::expr! { @build_multi $graph, $($rest)* }
    };

    // Terminal: `output @node` marks the named node as the graph output.
    (@build_multi $graph:ident, output @ $node:ident) => {
        $graph.output($node);
        $graph
    };

    // Terminal: bare `output` yields the graph without registering an
    // output node (outputs may have been set by earlier statements).
    (@build_multi $graph:ident, output) => {
        $graph
    };
}
1070
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the current test unless `a` and `b` differ by at most `eps`.
    fn approx_eq(a: Float, b: Float, eps: Float) {
        let diff = (a - b).abs();
        assert!(diff <= eps, "expected {a} ~= {b} (diff={diff}, eps={eps})");
    }

    #[test]
    fn reverse_matches_forward_and_finite_difference() {
        // Build sin(x^2 + cos(z)) by hand.
        let mut graph = ExprGraph::new();
        let x = graph.input("x".to_string());
        let z = graph.input("z".to_string());
        let x_squared = graph.operation(Op::Pow(2), [x]);
        let cos_z = graph.operation(Op::Cos, [z]);
        let total = graph.operation(Op::Add, [x_squared, cos_z]);
        let result = graph.operation(Op::Sin, [total]);
        graph.output(result);

        let point = [1.3, -0.7];
        let (forward_value, forward_grad) = graph.eval_fwd_one(&point);
        let (reverse_value, reverse_grad) = graph.eval_one(&point);

        // Forward-mode and reverse-mode must agree on value and gradient.
        approx_eq(forward_value, reverse_value, 1e-12);
        approx_eq(forward_grad[0], reverse_grad[0], 1e-10);
        approx_eq(forward_grad[1], reverse_grad[1], 1e-10);

        // Central finite differences as an independent cross-check.
        let eps = 1e-7;
        for i in 0..point.len() {
            let mut bumped_up = point;
            let mut bumped_down = point;
            bumped_up[i] += eps;
            bumped_down[i] -= eps;
            let numeric =
                (graph.eval_fwd_one(&bumped_up).0 - graph.eval_fwd_one(&bumped_down).0)
                    / (2.0 * eps);
            approx_eq(reverse_grad[i], numeric, 1e-6);
        }
    }

    #[test]
    fn output_rejects_foreign_node_id() {
        // A NodeId minted by one graph must not be accepted by another.
        let mut owner = ExprGraph::new();
        let foreign = owner.input("x".to_string());

        let mut other = ExprGraph::new();
        let _ = other.input("y".to_string());
        let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            other.output(foreign);
        }));
        assert!(outcome.is_err());
    }

    #[test]
    fn tape_try_set_variants() {
        let mut tape = Tape::new();
        let x = tape.input("x", 1.0);
        let y = tape.input("y", 2.0);
        let out = x + y;

        // Bulk update with the correct arity succeeds and is reflected
        // in subsequent evaluation.
        tape.try_set_inputs(&[3.0, 4.0])
            .expect("valid input update");
        let grads = tape.gradients(&out);
        approx_eq(grads.value, 7.0, 1e-12);

        // Wrong arity is reported with the expected/got counts.
        let err = tape
            .try_set_inputs(&[1.0])
            .expect_err("length mismatch should fail");
        assert!(matches!(
            err,
            TapeError::InputLengthMismatch {
                expected: 2,
                got: 1
            }
        ));

        // Per-name updates: known names succeed, unknown names error out.
        tape.try_set("x", 5.0).expect("known input should be set");
        let err = tape
            .try_set("missing", 0.0)
            .expect_err("unknown input should fail");
        assert!(matches!(err, TapeError::UnknownInput(_)));
    }

    #[test]
    fn pow_zero_has_zero_gradient_at_zero() {
        // d/dx x^0 is 0 everywhere, including x = 0 where a naive
        // n * x^(n-1) formula would evaluate 0 * 0^(-1) and blow up.
        let mut graph = ExprGraph::new();
        let x = graph.input("x".to_string());
        let result = graph.operation(Op::Pow(0), [x]);
        graph.output(result);

        let (value, grads) = graph.eval_one(&[0.0]);
        approx_eq(value, 1.0, 1e-12);
        approx_eq(grads[0], 0.0, 1e-12);
        assert!(grads[0].is_finite());
    }
}