Skip to main content

echidna/
api.rs

1//! Closure-based API for automatic differentiation.
2//!
3//! Provides top-level functions ([`grad`], [`jvp`], [`vjp`], [`jacobian`]) that handle
4//! tape setup, variable creation, and derivative extraction. With the `bytecode` feature,
5//! adds [`record`], [`hessian`], [`hvp`], [`sparse_jacobian`], [`sparse_hessian`], and more.
6
7use crate::dual::Dual;
8use crate::float::Float;
9use crate::reverse::Reverse;
10use crate::tape::{Tape, TapeGuard, TapeThreadLocal};
11
12#[cfg(feature = "bytecode")]
13use crate::breverse::BReverse;
14#[cfg(feature = "bytecode")]
15use crate::bytecode_tape::{BtapeGuard, BtapeThreadLocal, BytecodeTape, CONSTANT};
16
17/// Compute the gradient of a scalar function `f : R^n → R` using reverse mode.
18///
19/// ```
20/// let g = echidna::grad(|x: &[echidna::Reverse<f64>]| {
21///     x[0] * x[0] + x[1] * x[1]
22/// }, &[3.0, 4.0]);
23/// assert!((g[0] - 6.0).abs() < 1e-10);
24/// assert!((g[1] - 8.0).abs() < 1e-10);
25/// ```
26pub fn grad<F: Float + TapeThreadLocal>(
27    f: impl FnOnce(&[Reverse<F>]) -> Reverse<F>,
28    x: &[F],
29) -> Vec<F> {
30    let n = x.len();
31    let mut tape = Tape::take_pooled(n * 10);
32
33    // Create input variables.
34    let inputs: Vec<Reverse<F>> = x
35        .iter()
36        .map(|&val| {
37            let (idx, v) = tape.new_variable(val);
38            Reverse::from_tape(v, idx)
39        })
40        .collect();
41
42    let guard = TapeGuard::new(&mut tape);
43    let output = f(&inputs);
44    drop(guard);
45
46    // If the output is a constant (independent of all inputs), the gradient is zero.
47    if output.index == crate::tape::CONSTANT {
48        Tape::return_to_pool(tape);
49        return vec![F::zero(); n];
50    }
51
52    // Run reverse sweep.
53    let adjoints = tape.reverse(output.index);
54
55    // Extract gradients for input variables (indices 0..n).
56    let result = (0..n).map(|i| adjoints[i]).collect();
57    Tape::return_to_pool(tape);
58    result
59}
60
61/// Jacobian-vector product (forward mode): `(f(x), J·v)`.
62///
63/// Evaluates `f` at `x` and computes the directional derivative in direction `v`.
64pub fn jvp<F: Float>(f: impl Fn(&[Dual<F>]) -> Vec<Dual<F>>, x: &[F], v: &[F]) -> (Vec<F>, Vec<F>) {
65    assert_eq!(x.len(), v.len(), "x and v must have the same length");
66    let inputs: Vec<Dual<F>> = x
67        .iter()
68        .zip(v.iter())
69        .map(|(&xi, &vi)| Dual::new(xi, vi))
70        .collect();
71    let outputs = f(&inputs);
72    let values = outputs.iter().map(|d| d.re).collect();
73    let tangents = outputs.iter().map(|d| d.eps).collect();
74    (values, tangents)
75}
76
77/// Vector-Jacobian product (reverse mode): `(f(x), wᵀ·J)`.
78///
79/// Evaluates `f` at `x` and computes the adjoint product with weights `w`.
80pub fn vjp<F: Float + TapeThreadLocal>(
81    f: impl FnOnce(&[Reverse<F>]) -> Vec<Reverse<F>>,
82    x: &[F],
83    w: &[F],
84) -> (Vec<F>, Vec<F>) {
85    let n = x.len();
86    let mut tape = Tape::take_pooled(n * 10);
87
88    let inputs: Vec<Reverse<F>> = x
89        .iter()
90        .map(|&val| {
91            let (idx, v) = tape.new_variable(val);
92            Reverse::from_tape(v, idx)
93        })
94        .collect();
95
96    let guard = TapeGuard::new(&mut tape);
97    let outputs = f(&inputs);
98    drop(guard);
99
100    assert_eq!(
101        outputs.len(),
102        w.len(),
103        "output length must match weight vector length"
104    );
105
106    let values: Vec<F> = outputs.iter().map(|r| r.value).collect();
107
108    // Seed adjoints with weights.
109    let seeds: Vec<(u32, F)> = outputs
110        .iter()
111        .zip(w.iter())
112        .filter(|(r, _)| r.index != crate::tape::CONSTANT)
113        .map(|(r, &wi)| (r.index, wi))
114        .collect();
115    let adjoints = tape.reverse_seeded(&seeds);
116
117    let grad: Vec<F> = (0..n).map(|i| adjoints[i]).collect();
118    let result = (values, grad);
119    Tape::return_to_pool(tape);
120    result
121}
122
123/// Compute the full Jacobian of `f : R^n → R^m` using forward mode.
124///
125/// Returns `(f(x), J)` where `J[i][j] = ∂f_i/∂x_j`.
126pub fn jacobian<F: Float>(
127    f: impl Fn(&[Dual<F>]) -> Vec<Dual<F>>,
128    x: &[F],
129) -> (Vec<F>, Vec<Vec<F>>) {
130    let n = x.len();
131
132    // First pass to get output dimension and values.
133    let const_inputs: Vec<Dual<F>> = x.iter().map(|&xi| Dual::constant(xi)).collect();
134    let const_outputs = f(&const_inputs);
135    let m = const_outputs.len();
136    let values: Vec<F> = const_outputs.iter().map(|d| d.re).collect();
137
138    // One forward pass per input variable.
139    let mut jac = vec![vec![F::zero(); n]; m];
140    for j in 0..n {
141        let inputs: Vec<Dual<F>> = x
142            .iter()
143            .enumerate()
144            .map(|(k, &xi)| {
145                if k == j {
146                    Dual::variable(xi)
147                } else {
148                    Dual::constant(xi)
149                }
150            })
151            .collect();
152        let outputs = f(&inputs);
153        for (row, out) in jac.iter_mut().zip(outputs.iter()) {
154            row[j] = out.eps;
155        }
156    }
157
158    (values, jac)
159}
160
/// Record a scalar function into a [`BytecodeTape`] so it can be re-evaluated
/// at other inputs without recording again.
///
/// Returns the tape together with the output value observed during the
/// recording pass.
///
/// # Limitations
///
/// Only the execution path taken during recording is stored. If `f` contains
/// branches (`if x > 0 { ... } else { ... }`), re-evaluating the tape at
/// inputs that would take a different branch produces **incorrect results**.
///
/// # Example
///
/// ```ignore
/// let (mut tape, val) = echidna::record(
///     |x| x[0] * x[0] + x[1] * x[1],
///     &[3.0, 4.0],
/// );
/// assert!((val - 25.0).abs() < 1e-10);
///
/// let g = tape.gradient(&[3.0, 4.0]);
/// assert!((g[0] - 6.0).abs() < 1e-10);
/// assert!((g[1] - 8.0).abs() < 1e-10);
/// ```
#[cfg(feature = "bytecode")]
pub fn record<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
) -> (BytecodeTape<F>, F) {
    let n = x.len();
    let mut tape = BytecodeTape::with_capacity(n * 10);

    // Register one tape input per entry of `x`.
    let mut inputs: Vec<BReverse<F>> = Vec::with_capacity(n);
    for &xi in x {
        let slot = tape.new_input(xi);
        inputs.push(BReverse::from_tape(xi, slot));
    }

    // The guard is held only while the user closure runs.
    let guard = BtapeGuard::new(&mut tape);
    let output = f(&inputs);
    drop(guard);

    let value = output.value;

    // A constant output has no tape entry; push one so `set_output` receives
    // a valid index. The gradient of such a tape is zero.
    let output_index = if output.index != CONSTANT {
        output.index
    } else {
        tape.push_const(value)
    };
    tape.set_output(output_index);

    (tape, value)
}
218
/// Record a vector-valued function `f : R^n → R^m` into a [`BytecodeTape`].
///
/// The multi-output counterpart of [`record`]. The returned tape supports
/// [`jacobian`](BytecodeTape::jacobian),
/// [`vjp_multi`](BytecodeTape::vjp_multi), and [`reverse_seeded`](BytecodeTape::reverse_seeded).
///
/// Returns the tape and the output values from the recording pass.
///
/// # Panics
///
/// Panics if the closure returns an empty vector of outputs.
#[cfg(feature = "bytecode")]
pub fn record_multi<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> Vec<BReverse<F>>,
    x: &[F],
) -> (BytecodeTape<F>, Vec<F>) {
    let n = x.len();
    let mut tape = BytecodeTape::with_capacity(n * 10);

    // Register one tape input per entry of `x`.
    let mut inputs: Vec<BReverse<F>> = Vec::with_capacity(n);
    for &xi in x {
        let slot = tape.new_input(xi);
        inputs.push(BReverse::from_tape(xi, slot));
    }

    let guard = BtapeGuard::new(&mut tape);
    let outputs = f(&inputs);
    drop(guard);

    // With zero outputs, `set_outputs(&[])` would leave `output_index` at its
    // default (0 — typically the first input) while `num_outputs()` still
    // reports 1, so later `jacobian` / `output_values` calls would return
    // data unrelated to the closure. Refuse that case outright.
    assert!(
        !outputs.is_empty(),
        "record_multi: closure returned zero outputs; record_multi is for \
         vector-valued f : R^n -> R^m with m >= 1"
    );

    // Gather values and tape indices together, promoting constant outputs to
    // real tape entries so every output has a valid index.
    let mut values: Vec<F> = Vec::with_capacity(outputs.len());
    let mut indices: Vec<u32> = Vec::with_capacity(outputs.len());
    for out in &outputs {
        values.push(out.value);
        let index = if out.index != CONSTANT {
            out.index
        } else {
            tape.push_const(out.value)
        };
        indices.push(index);
    }

    tape.set_outputs(&indices);
    // Keep the single-output index populated as well, for backward compatibility.
    if let Some(first) = indices.first().copied() {
        tape.set_output(first);
    }

    (tape, values)
}
280
/// Hessian-vector product via forward-over-reverse on a bytecode tape.
///
/// `f` is first recorded into a [`BytecodeTape`] at `x`; the tape then
/// produces the gradient and the Hessian-vector product at `x` in
/// direction `v`.
///
/// Returns `(gradient, H·v)` where both are `Vec<F>` of length `x.len()`.
#[cfg(feature = "bytecode")]
pub fn hvp<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
    v: &[F],
) -> (Vec<F>, Vec<F>) {
    // The value from the recording pass is not needed here.
    let (recorded, _value) = record(f, x);
    recorded.hvp(x, v)
}
296
/// Full Hessian matrix via forward-over-reverse on a bytecode tape.
///
/// `f` is recorded into a [`BytecodeTape`] at `x`, and the tape is then asked
/// for the function value, gradient, and full Hessian at the same point.
///
/// Returns `(value, gradient, hessian)` where `hessian[i][j] = ∂²f/∂x_i∂x_j`.
#[cfg(feature = "bytecode")]
pub fn hessian<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
) -> (F, Vec<F>, Vec<Vec<F>>) {
    // The recording-pass value is discarded; the tape recomputes it.
    let (recorded, _value) = record(f, x);
    recorded.hessian(x)
}
311
/// Full Hessian matrix via batched forward-over-reverse.
///
/// Same result shape as [`hessian`], but `N` tangent directions are processed
/// per sweep, reducing the number of tape traversals from 2n to 2·ceil(n/N).
#[cfg(feature = "bytecode")]
pub fn hessian_vec<F: Float + BtapeThreadLocal, const N: usize>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
) -> (F, Vec<F>, Vec<Vec<F>>) {
    // Record once, then let the tape do the batched evaluation.
    let (recorded, _value) = record(f, x);
    recorded.hessian_vec::<N>(x)
}
324
/// Sparse Hessian via structural sparsity detection and graph coloring.
///
/// `f` is recorded into a bytecode tape at `x`, and the tape computes the
/// sparse Hessian at the same point.
///
/// Returns `(value, gradient, pattern, hessian_values)`.
/// For sparse problems, this is dramatically faster than [`hessian`].
#[cfg(feature = "bytecode")]
pub fn sparse_hessian<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
) -> (F, Vec<F>, crate::sparse::SparsityPattern, Vec<F>) {
    // The recording-pass value is discarded; the tape returns its own.
    let (recorded, _value) = record(f, x);
    recorded.sparse_hessian(x)
}
337
/// Batched sparse Hessian: packs `N` colors per sweep using DualVec.
///
/// Same outputs as [`sparse_hessian`], but the number of sweeps drops from
/// `num_colors` to `ceil(num_colors / N)`.
#[cfg(feature = "bytecode")]
pub fn sparse_hessian_vec<F: Float + BtapeThreadLocal, const N: usize>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
) -> (F, Vec<F>, crate::sparse::SparsityPattern, Vec<F>) {
    // Record once, then let the tape do the batched sparse evaluation.
    let (recorded, _value) = record(f, x);
    recorded.sparse_hessian_vec::<N>(x)
}
350
/// Sparse Jacobian of a multi-output function via sparsity detection and coloring.
///
/// `f` is recorded with [`record_multi`]; forward or reverse mode is then
/// auto-selected based on which requires fewer sweeps.
///
/// Returns `(output_values, pattern, jacobian_values)`.
#[cfg(feature = "bytecode")]
pub fn sparse_jacobian<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> Vec<BReverse<F>>,
    x: &[F],
) -> (Vec<F>, crate::sparse::JacobianSparsityPattern, Vec<F>) {
    // The recording-pass values are dropped; the tape call returns its own.
    let (mut recorded, _values) = record_multi(f, x);
    recorded.sparse_jacobian(x)
}
364
/// Forward-over-reverse HVP via type-level composition.
///
/// `f` is recorded on `Dual<BReverse<F>>` inputs: the real part of each input
/// is a tape input holding `x[i]`, and the tangent part is a `BReverse`
/// constant holding `v[i]` (so the direction is baked into the recording).
/// Two reverse sweeps follow — one from the primal output (gradient) and one
/// from the tangent output (HVP). A sweep is skipped, and zeros returned, when
/// the corresponding output is a tape constant.
///
/// Returns `(f(x), gradient, H·v)`.
///
/// For repeated HVP with different `v`, prefer [`record`] + [`BytecodeTape::hvp`].
/// This function re-records each call.
///
/// # Panics
///
/// Panics if `x` and `v` have different lengths.
#[cfg(feature = "bytecode")]
pub fn composed_hvp<F, Func>(f: Func, x: &[F], v: &[F]) -> (F, Vec<F>, Vec<F>)
where
    F: Float + BtapeThreadLocal,
    Func: FnOnce(&[Dual<BReverse<F>>]) -> Dual<BReverse<F>>,
{
    let n = x.len();
    assert_eq!(x.len(), v.len(), "x and v must have the same length");

    let mut tape = BytecodeTape::with_capacity(n * 30);

    // One tape input per primal value; the tangent entries are plain
    // constants and therefore not tracked on the tape.
    let mut inputs: Vec<Dual<BReverse<F>>> = Vec::with_capacity(n);
    for (&xi, &vi) in x.iter().zip(v) {
        let slot = tape.new_input(xi);
        let primal = BReverse::from_tape(xi, slot);
        let tangent = BReverse::constant(vi);
        inputs.push(Dual::new(primal, tangent));
    }

    let guard = BtapeGuard::new(&mut tape);
    let output = f(&inputs);
    drop(guard);

    // Gradient: reverse sweep seeded at the primal output.
    let gradient = if output.re.index == crate::bytecode_tape::CONSTANT {
        vec![F::zero(); n]
    } else {
        tape.reverse(output.re.index)[..n].to_vec()
    };

    // HVP: reverse sweep seeded at the tangent output.
    let hvp = if output.eps.index == crate::bytecode_tape::CONSTANT {
        vec![F::zero(); n]
    } else {
        tape.reverse(output.eps.index)[..n].to_vec()
    };

    (output.re.value, gradient, hvp)
}