1use crate::dual::Dual;
8use crate::float::Float;
9use crate::reverse::Reverse;
10use crate::tape::{Tape, TapeGuard, TapeThreadLocal};
11
12#[cfg(feature = "bytecode")]
13use crate::breverse::BReverse;
14#[cfg(feature = "bytecode")]
15use crate::bytecode_tape::{BtapeGuard, BtapeThreadLocal, BytecodeTape, CONSTANT};
16
17pub fn grad<F: Float + TapeThreadLocal>(
27 f: impl FnOnce(&[Reverse<F>]) -> Reverse<F>,
28 x: &[F],
29) -> Vec<F> {
30 let n = x.len();
31 let mut tape = Tape::take_pooled(n * 10);
32
33 let inputs: Vec<Reverse<F>> = x
35 .iter()
36 .map(|&val| {
37 let (idx, v) = tape.new_variable(val);
38 Reverse::from_tape(v, idx)
39 })
40 .collect();
41
42 let guard = TapeGuard::new(&mut tape);
43 let output = f(&inputs);
44 drop(guard);
45
46 if output.index == crate::tape::CONSTANT {
48 Tape::return_to_pool(tape);
49 return vec![F::zero(); n];
50 }
51
52 let adjoints = tape.reverse(output.index);
54
55 let result = (0..n).map(|i| adjoints[i]).collect();
57 Tape::return_to_pool(tape);
58 result
59}
60
61pub fn jvp<F: Float>(f: impl Fn(&[Dual<F>]) -> Vec<Dual<F>>, x: &[F], v: &[F]) -> (Vec<F>, Vec<F>) {
65 assert_eq!(x.len(), v.len(), "x and v must have the same length");
66 let inputs: Vec<Dual<F>> = x
67 .iter()
68 .zip(v.iter())
69 .map(|(&xi, &vi)| Dual::new(xi, vi))
70 .collect();
71 let outputs = f(&inputs);
72 let values = outputs.iter().map(|d| d.re).collect();
73 let tangents = outputs.iter().map(|d| d.eps).collect();
74 (values, tangents)
75}
76
77pub fn vjp<F: Float + TapeThreadLocal>(
81 f: impl FnOnce(&[Reverse<F>]) -> Vec<Reverse<F>>,
82 x: &[F],
83 w: &[F],
84) -> (Vec<F>, Vec<F>) {
85 let n = x.len();
86 let mut tape = Tape::take_pooled(n * 10);
87
88 let inputs: Vec<Reverse<F>> = x
89 .iter()
90 .map(|&val| {
91 let (idx, v) = tape.new_variable(val);
92 Reverse::from_tape(v, idx)
93 })
94 .collect();
95
96 let guard = TapeGuard::new(&mut tape);
97 let outputs = f(&inputs);
98 drop(guard);
99
100 assert_eq!(
101 outputs.len(),
102 w.len(),
103 "output length must match weight vector length"
104 );
105
106 let values: Vec<F> = outputs.iter().map(|r| r.value).collect();
107
108 let seeds: Vec<(u32, F)> = outputs
110 .iter()
111 .zip(w.iter())
112 .filter(|(r, _)| r.index != crate::tape::CONSTANT)
113 .map(|(r, &wi)| (r.index, wi))
114 .collect();
115 let adjoints = tape.reverse_seeded(&seeds);
116
117 let grad: Vec<F> = (0..n).map(|i| adjoints[i]).collect();
118 let result = (values, grad);
119 Tape::return_to_pool(tape);
120 result
121}
122
123pub fn jacobian<F: Float>(
127 f: impl Fn(&[Dual<F>]) -> Vec<Dual<F>>,
128 x: &[F],
129) -> (Vec<F>, Vec<Vec<F>>) {
130 let n = x.len();
131
132 let const_inputs: Vec<Dual<F>> = x.iter().map(|&xi| Dual::constant(xi)).collect();
134 let const_outputs = f(&const_inputs);
135 let m = const_outputs.len();
136 let values: Vec<F> = const_outputs.iter().map(|d| d.re).collect();
137
138 let mut jac = vec![vec![F::zero(); n]; m];
140 for j in 0..n {
141 let inputs: Vec<Dual<F>> = x
142 .iter()
143 .enumerate()
144 .map(|(k, &xi)| {
145 if k == j {
146 Dual::variable(xi)
147 } else {
148 Dual::constant(xi)
149 }
150 })
151 .collect();
152 let outputs = f(&inputs);
153 for (row, out) in jac.iter_mut().zip(outputs.iter()) {
154 row[j] = out.eps;
155 }
156 }
157
158 (values, jac)
159}
160
/// Records one evaluation of the scalar-valued closure `f` at `x` onto a new
/// [`BytecodeTape`] and returns the tape together with the output value.
///
/// Each element of `x` becomes a tape input. The closure runs while a
/// [`BtapeGuard`] for the tape is alive. If the output carries the `CONSTANT`
/// index, its value is pushed onto the tape as a constant, so the tape always
/// has a concrete output index.
#[cfg(feature = "bytecode")]
pub fn record<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
) -> (BytecodeTape<F>, F) {
    let n = x.len();
    // NOTE(review): `n * 10` is presumably an initial capacity hint — confirm
    // against `BytecodeTape::with_capacity`.
    let mut tape = BytecodeTape::with_capacity(n * 10);

    // Register every input on the tape and wrap it as a tracked value.
    let mut inputs: Vec<BReverse<F>> = Vec::with_capacity(n);
    for &val in x {
        let idx = tape.new_input(val);
        inputs.push(BReverse::from_tape(val, idx));
    }

    // Keep the guard alive for exactly the closure call.
    let guard = BtapeGuard::new(&mut tape);
    let output = f(&inputs);
    drop(guard);

    let output_index = if output.index != CONSTANT {
        output.index
    } else {
        // The output is not on the tape; store its value as a constant entry.
        tape.push_const(output.value)
    };
    tape.set_output(output_index);

    (tape, output.value)
}
218
/// Records one evaluation of the vector-valued closure `f` at `x` onto a new
/// [`BytecodeTape`] and returns the tape together with the output values.
///
/// Each element of `x` becomes a tape input. The closure runs while a
/// [`BtapeGuard`] for the tape is alive. Outputs carrying the `CONSTANT`
/// index are pushed onto the tape as constants. All output indices are
/// registered with `set_outputs`, and the first one is also registered with
/// `set_output`.
///
/// # Panics
///
/// Panics if `f` returns an empty vector.
#[cfg(feature = "bytecode")]
pub fn record_multi<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> Vec<BReverse<F>>,
    x: &[F],
) -> (BytecodeTape<F>, Vec<F>) {
    let n = x.len();
    // NOTE(review): `n * 10` is presumably an initial capacity hint — confirm
    // against `BytecodeTape::with_capacity`.
    let mut tape = BytecodeTape::with_capacity(n * 10);

    // Register every input on the tape and wrap it as a tracked value.
    let mut inputs: Vec<BReverse<F>> = Vec::with_capacity(n);
    for &val in x {
        let idx = tape.new_input(val);
        inputs.push(BReverse::from_tape(val, idx));
    }

    // Keep the guard alive for exactly the closure call.
    let guard = BtapeGuard::new(&mut tape);
    let outputs = f(&inputs);
    drop(guard);

    assert!(
        !outputs.is_empty(),
        "record_multi: closure returned zero outputs; record_multi is for \
         vector-valued f : R^n -> R^m with m >= 1"
    );

    // Gather the output values and a tape index for every output in one pass.
    let mut values: Vec<F> = Vec::with_capacity(outputs.len());
    let mut indices: Vec<u32> = Vec::with_capacity(outputs.len());
    for out in &outputs {
        values.push(out.value);
        let index = if out.index != CONSTANT {
            out.index
        } else {
            // Not on the tape yet: store the value as a constant entry.
            tape.push_const(out.value)
        };
        indices.push(index);
    }

    tape.set_outputs(&indices);
    // Also register the first output through the single-output setter.
    if let Some(first) = indices.first().copied() {
        tape.set_output(first);
    }

    (tape, values)
}
280
/// Records `f` at `x` with [`record`] and delegates to the tape's `hvp`
/// method with the same point `x` and the direction `v`.
///
/// The returned pair is whatever `BytecodeTape::hvp` yields.
/// NOTE(review): presumably `(gradient, Hessian-vector product)` — confirm
/// against `BytecodeTape::hvp`.
///
/// # Panics
///
/// Panics if `x` and `v` differ in length. The check runs before the closure
/// is recorded, so a mismatched call fails fast with a clear message instead
/// of doing the recording work first.
#[cfg(feature = "bytecode")]
pub fn hvp<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
    v: &[F],
) -> (Vec<F>, Vec<F>) {
    // A direction vector must have one component per input coordinate.
    assert_eq!(x.len(), v.len(), "x and v must have the same length");

    let (tape, _) = record(f, x);
    tape.hvp(x, v)
}
296
/// Records `f` at `x` with [`record`] and delegates to the tape's `hessian`
/// method at the same point.
///
/// The returned triple is whatever `BytecodeTape::hessian` yields.
/// NOTE(review): presumably `(value, gradient, dense Hessian rows)` — confirm
/// against `BytecodeTape::hessian`.
#[cfg(feature = "bytecode")]
pub fn hessian<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
) -> (F, Vec<F>, Vec<Vec<F>>) {
    // Only the tape is needed; the recorded output value is discarded here.
    let tape = record(f, x).0;
    tape.hessian(x)
}
311
/// Records `f` at `x` with [`record`] and delegates to the tape's
/// `hessian_vec::<N>` method at the same point.
///
/// The returned triple is whatever `BytecodeTape::hessian_vec` yields.
/// NOTE(review): `N` is presumably the number of directions processed per
/// pass, and the triple presumably `(value, gradient, dense Hessian rows)` —
/// confirm against `BytecodeTape::hessian_vec`.
#[cfg(feature = "bytecode")]
pub fn hessian_vec<F: Float + BtapeThreadLocal, const N: usize>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
) -> (F, Vec<F>, Vec<Vec<F>>) {
    // Only the tape is needed; the recorded output value is discarded here.
    let tape = record(f, x).0;
    tape.hessian_vec::<N>(x)
}
324
/// Records `f` at `x` with [`record`] and delegates to the tape's
/// `sparse_hessian` method at the same point.
///
/// The returned tuple is whatever `BytecodeTape::sparse_hessian` yields.
/// NOTE(review): presumably `(value, gradient, sparsity pattern, nonzero
/// Hessian entries)` — confirm against `BytecodeTape::sparse_hessian`.
#[cfg(feature = "bytecode")]
pub fn sparse_hessian<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
) -> (F, Vec<F>, crate::sparse::SparsityPattern, Vec<F>) {
    // Only the tape is needed; the recorded output value is discarded here.
    let tape = record(f, x).0;
    tape.sparse_hessian(x)
}
337
/// Records `f` at `x` with [`record`] and delegates to the tape's
/// `sparse_hessian_vec::<N>` method at the same point.
///
/// The returned tuple is whatever `BytecodeTape::sparse_hessian_vec` yields.
/// NOTE(review): `N` is presumably the number of directions processed per
/// pass, and the tuple presumably `(value, gradient, sparsity pattern,
/// nonzero Hessian entries)` — confirm against
/// `BytecodeTape::sparse_hessian_vec`.
#[cfg(feature = "bytecode")]
pub fn sparse_hessian_vec<F: Float + BtapeThreadLocal, const N: usize>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
) -> (F, Vec<F>, crate::sparse::SparsityPattern, Vec<F>) {
    // Only the tape is needed; the recorded output value is discarded here.
    let tape = record(f, x).0;
    tape.sparse_hessian_vec::<N>(x)
}
350
/// Records the vector-valued `f` at `x` with [`record_multi`] and delegates
/// to the tape's `sparse_jacobian` method at the same point.
///
/// The returned tuple is whatever `BytecodeTape::sparse_jacobian` yields.
/// NOTE(review): presumably `(values, sparsity pattern, nonzero Jacobian
/// entries)` — confirm against `BytecodeTape::sparse_jacobian`.
///
/// # Panics
///
/// Panics if `f` returns no outputs (enforced by [`record_multi`]).
#[cfg(feature = "bytecode")]
pub fn sparse_jacobian<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> Vec<BReverse<F>>,
    x: &[F],
) -> (Vec<F>, crate::sparse::JacobianSparsityPattern, Vec<F>) {
    // The tape is bound mutably for the `sparse_jacobian` call; the recorded
    // output values are discarded here.
    let mut tape = record_multi(f, x).0;
    tape.sparse_jacobian(x)
}
364
/// Evaluates `f` on dual numbers whose components are bytecode-tape values
/// and returns `(value, gradient, hvp)`.
///
/// For each coordinate, the real part is a tape input holding `x[i]` and the
/// `eps` part is an untracked constant holding `v[i]`. After running `f`
/// under a [`BtapeGuard`]:
///
/// * `value` is the primal value of the output's real part;
/// * `gradient` is the first `x.len()` adjoints of a reverse sweep started
///   at the real part's tape index;
/// * `hvp` is the first `x.len()` adjoints of a reverse sweep started at the
///   `eps` part's tape index.
///
/// A component whose index is `CONSTANT` gets an all-zero vector and no
/// sweep.
///
/// # Panics
///
/// Panics if `x` and `v` differ in length.
#[cfg(feature = "bytecode")]
pub fn composed_hvp<F, Func>(f: Func, x: &[F], v: &[F]) -> (F, Vec<F>, Vec<F>)
where
    F: Float + BtapeThreadLocal,
    Func: FnOnce(&[Dual<BReverse<F>>]) -> Dual<BReverse<F>>,
{
    let n = x.len();
    assert_eq!(x.len(), v.len(), "x and v must have the same length");

    // NOTE(review): `n * 30` is presumably an initial capacity hint — confirm
    // against `BytecodeTape::with_capacity`.
    let mut tape = BytecodeTape::with_capacity(n * 30);

    // Real parts are tape inputs; direction components are plain constants.
    let mut inputs: Vec<Dual<BReverse<F>>> = Vec::with_capacity(n);
    for (&xi, &vi) in x.iter().zip(v) {
        let idx = tape.new_input(xi);
        let re = BReverse::from_tape(xi, idx);
        let eps = BReverse::constant(vi);
        inputs.push(Dual::new(re, eps));
    }

    // Keep the guard alive for exactly the closure call.
    let guard = BtapeGuard::new(&mut tape);
    let output = f(&inputs);
    drop(guard);

    let primal = output.re;
    let tangent = output.eps;

    // Sweep from the primal output for the gradient.
    let gradient = if primal.index == CONSTANT {
        vec![F::zero(); n]
    } else {
        let primal_adjoints = tape.reverse(primal.index);
        primal_adjoints[..n].to_vec()
    };

    // Sweep from the tangent output for the second-order product.
    let hvp = if tangent.index == CONSTANT {
        vec![F::zero(); n]
    } else {
        let tangent_adjoints = tape.reverse(tangent.index);
        tangent_adjoints[..n].to_vec()
    };

    (primal.value, gradient, hvp)
}