use crate::dual::Dual;
use crate::float::Float;
use crate::reverse::Reverse;
use crate::tape::{Tape, TapeGuard, TapeThreadLocal};
#[cfg(feature = "bytecode")]
use crate::breverse::BReverse;
#[cfg(feature = "bytecode")]
use crate::bytecode_tape::{BtapeGuard, BtapeThreadLocal, BytecodeTape, CONSTANT};
/// Computes the gradient of a scalar-valued closure with a reverse-mode tape.
///
/// Every entry of `x` is registered on a tape obtained from `Tape::take_pooled`,
/// `f` is evaluated while a [`TapeGuard`] for that tape is alive, and the tape is
/// handed back through `Tape::return_to_pool` before returning.
///
/// # Returns
///
/// The first `x.len()` adjoints produced by `tape.reverse` seeded at the output
/// of `f`. If the output carries the `CONSTANT` index, a zero vector of length
/// `x.len()` is returned instead and no reverse sweep is run.
pub fn grad<F: Float + TapeThreadLocal>(
    f: impl FnOnce(&[Reverse<F>]) -> Reverse<F>,
    x: &[F],
) -> Vec<F> {
    let n = x.len();
    let mut tape = Tape::take_pooled(n * 10);

    // Register the inputs on the tape, in order, before the guard is created.
    let mut inputs: Vec<Reverse<F>> = Vec::with_capacity(n);
    for &val in x {
        let (idx, v) = tape.new_variable(val);
        inputs.push(Reverse::from_tape(v, idx));
    }

    // The guard lives exactly as long as the call to `f`.
    let output = {
        let _guard = TapeGuard::new(&mut tape);
        f(&inputs)
    };

    let result = if output.index == crate::tape::CONSTANT {
        // Output is not recorded on the tape: report all-zero derivatives.
        vec![F::zero(); n]
    } else {
        let adjoints = tape.reverse(output.index);
        adjoints[..n].to_vec()
    };

    Tape::return_to_pool(tape);
    result
}
/// Evaluates `f` on dual numbers and splits the outputs into their components.
///
/// Input `i` is built as `Dual::new(x[i], v[i])`.
///
/// # Returns
///
/// `(values, tangents)`: the `re` fields and the `eps` fields of the outputs of
/// `f`, in output order.
///
/// # Panics
///
/// Panics if `x` and `v` have different lengths.
pub fn jvp<F: Float>(
    f: impl Fn(&[Dual<F>]) -> Vec<Dual<F>>,
    x: &[F],
    v: &[F],
) -> (Vec<F>, Vec<F>) {
    assert_eq!(x.len(), v.len(), "x and v must have the same length");

    let mut inputs: Vec<Dual<F>> = Vec::with_capacity(x.len());
    for (&xi, &vi) in x.iter().zip(v) {
        inputs.push(Dual::new(xi, vi));
    }

    // One pass over the outputs fills both result vectors.
    f(&inputs).iter().map(|d| (d.re, d.eps)).unzip()
}
/// Evaluates a vector-valued closure on a reverse-mode tape and runs one seeded
/// reverse sweep weighted by `w`.
///
/// Every entry of `x` is registered on a pooled tape, `f` is evaluated while a
/// [`TapeGuard`] for that tape is alive, and each output whose index is not
/// `CONSTANT` contributes a `(index, w[i])` seed to `tape.reverse_seeded`.
///
/// # Returns
///
/// `(values, grad)`: the `value` field of every output of `f`, and the first
/// `x.len()` adjoints returned by the seeded sweep.
///
/// # Panics
///
/// Panics if the number of outputs of `f` differs from `w.len()`.
pub fn vjp<F: Float + TapeThreadLocal>(
    f: impl FnOnce(&[Reverse<F>]) -> Vec<Reverse<F>>,
    x: &[F],
    w: &[F],
) -> (Vec<F>, Vec<F>) {
    let n = x.len();
    let mut tape = Tape::take_pooled(n * 10);

    let mut inputs: Vec<Reverse<F>> = Vec::with_capacity(n);
    for &val in x {
        let (idx, v) = tape.new_variable(val);
        inputs.push(Reverse::from_tape(v, idx));
    }

    // The guard lives exactly as long as the call to `f`.
    let outputs = {
        let _guard = TapeGuard::new(&mut tape);
        f(&inputs)
    };
    assert_eq!(
        outputs.len(),
        w.len(),
        "output length must match weight vector length"
    );

    // Lengths are equal here, so zipping visits every output exactly once.
    let mut values: Vec<F> = Vec::with_capacity(outputs.len());
    let mut seeds: Vec<(u32, F)> = Vec::new();
    for (out, &weight) in outputs.iter().zip(w) {
        values.push(out.value);
        // Outputs with the CONSTANT index are left out of the seed list.
        if out.index != crate::tape::CONSTANT {
            seeds.push((out.index, weight));
        }
    }

    let gradient: Vec<F> = {
        let adjoints = tape.reverse_seeded(&seeds);
        adjoints[..n].to_vec()
    };

    Tape::return_to_pool(tape);
    (values, gradient)
}
/// Builds a dense Jacobian with forward-mode dual numbers, one column per pass.
///
/// `f` is called `x.len() + 1` times: once with every input created through
/// `Dual::constant` to obtain the output values and the output count, then once
/// per input index `j` with only input `j` created through `Dual::variable`.
///
/// # Returns
///
/// `(values, jac)` where `values[i]` is the `re` field of output `i` from the
/// all-constant pass and `jac[i][j]` is the `eps` field of output `i` from
/// pass `j`. `jac` has one row per output of the all-constant pass and
/// `x.len()` columns.
pub fn jacobian<F: Float>(
    f: impl Fn(&[Dual<F>]) -> Vec<Dual<F>>,
    x: &[F],
) -> (Vec<F>, Vec<Vec<F>>) {
    let n = x.len();

    // A single input buffer is reused; it starts out all-constant.
    let mut probe: Vec<Dual<F>> = x.iter().map(|&xi| Dual::constant(xi)).collect();

    let baseline = f(&probe);
    let values: Vec<F> = baseline.iter().map(|d| d.re).collect();
    let mut jac = vec![vec![F::zero(); n]; baseline.len()];

    for col in 0..n {
        // Mark exactly one input as the variable for this pass.
        probe[col] = Dual::variable(x[col]);
        let outputs = f(&probe);
        for (row, out) in jac.iter_mut().zip(&outputs) {
            row[col] = out.eps;
        }
        // Restore the slot so the next pass again has a single variable.
        probe[col] = Dual::constant(x[col]);
    }

    (values, jac)
}
/// Records a scalar-valued closure onto a fresh [`BytecodeTape`].
///
/// Every entry of `x` becomes a tape input, `f` runs while a [`BtapeGuard`] for
/// the tape is alive, and the result of `f` is registered as the tape output.
/// When the result carries the `CONSTANT` index, its value is first pushed onto
/// the tape with `push_const` and that new index is used as the output.
///
/// # Returns
///
/// The recorded tape together with the `value` field of the result of `f`.
#[cfg(feature = "bytecode")]
pub fn record<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
) -> (BytecodeTape<F>, F) {
    let mut tape = BytecodeTape::with_capacity(x.len() * 10);

    let mut inputs: Vec<BReverse<F>> = Vec::with_capacity(x.len());
    for &val in x {
        let idx = tape.new_input(val);
        inputs.push(BReverse::from_tape(val, idx));
    }

    // The guard is released as soon as `f` has returned.
    let guard = BtapeGuard::new(&mut tape);
    let output = f(&inputs);
    drop(guard);

    let output_index = if output.index != CONSTANT {
        output.index
    } else {
        // The result is not on the tape yet; give it a slot of its own.
        tape.push_const(output.value)
    };
    tape.set_output(output_index);

    (tape, output.value)
}
/// Records a vector-valued closure onto a fresh [`BytecodeTape`].
///
/// Every entry of `x` becomes a tape input and `f` runs while a [`BtapeGuard`]
/// for the tape is alive. Each output that carries the `CONSTANT` index is
/// pushed onto the tape with `push_const`; the resulting indices are registered
/// with `set_outputs`, and the first of them additionally with `set_output`.
///
/// # Returns
///
/// The recorded tape together with the `value` field of every output of `f`.
///
/// # Panics
///
/// Panics if `f` returns no outputs.
#[cfg(feature = "bytecode")]
pub fn record_multi<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> Vec<BReverse<F>>,
    x: &[F],
) -> (BytecodeTape<F>, Vec<F>) {
    let mut tape = BytecodeTape::with_capacity(x.len() * 10);

    let mut inputs: Vec<BReverse<F>> = Vec::with_capacity(x.len());
    for &val in x {
        let idx = tape.new_input(val);
        inputs.push(BReverse::from_tape(val, idx));
    }

    // The guard is released as soon as `f` has returned.
    let guard = BtapeGuard::new(&mut tape);
    let outputs = f(&inputs);
    drop(guard);

    assert!(
        !outputs.is_empty(),
        "record_multi: closure returned zero outputs; record_multi is for \
         vector-valued f : R^n -> R^m with m >= 1"
    );

    let mut values: Vec<F> = Vec::with_capacity(outputs.len());
    let mut indices: Vec<u32> = Vec::with_capacity(outputs.len());
    for out in &outputs {
        values.push(out.value);
        let slot = if out.index != CONSTANT {
            out.index
        } else {
            // This output is not on the tape yet; give it a slot of its own.
            tape.push_const(out.value)
        };
        indices.push(slot);
    }

    tape.set_outputs(&indices);
    // The assertion above guarantees that at least one index exists.
    tape.set_output(indices[0]);

    (tape, values)
}
/// Records `f` at `x` with [`record`] and forwards to `BytecodeTape::hvp`.
///
/// The value reported by `record` is discarded; the returned pair is exactly
/// what `tape.hvp(x, v)` yields.
// NOTE(review): presumably the pair is (gradient, Hessian-vector product) —
// confirm against `BytecodeTape::hvp`.
#[cfg(feature = "bytecode")]
pub fn hvp<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
    v: &[F],
) -> (Vec<F>, Vec<F>) {
    let recorded = record(f, x).0;
    recorded.hvp(x, v)
}
/// Records `f` at `x` with [`record`] and forwards to `BytecodeTape::hessian`.
///
/// The value reported by `record` is discarded; the returned triple is exactly
/// what `tape.hessian(x)` yields.
// NOTE(review): presumably the triple is (value, gradient, dense Hessian) —
// confirm against `BytecodeTape::hessian`.
#[cfg(feature = "bytecode")]
pub fn hessian<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
) -> (F, Vec<F>, Vec<Vec<F>>) {
    let recorded = record(f, x).0;
    recorded.hessian(x)
}
/// Records `f` at `x` with [`record`] and forwards to
/// `BytecodeTape::hessian_vec::<N>`.
///
/// The value reported by `record` is discarded; the returned triple is exactly
/// what `tape.hessian_vec::<N>(x)` yields.
// NOTE(review): `N` looks like a batch/lane width for the tape method —
// confirm against `BytecodeTape::hessian_vec`.
#[cfg(feature = "bytecode")]
pub fn hessian_vec<F: Float + BtapeThreadLocal, const N: usize>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
) -> (F, Vec<F>, Vec<Vec<F>>) {
    let recorded = record(f, x).0;
    recorded.hessian_vec::<N>(x)
}
/// Records `f` at `x` with [`record`] and forwards to
/// `BytecodeTape::sparse_hessian`.
///
/// The value reported by `record` is discarded; the returned tuple, which
/// includes a `crate::sparse::SparsityPattern`, is exactly what
/// `tape.sparse_hessian(x)` yields.
// NOTE(review): presumably (value, gradient, pattern, non-zero entries) —
// confirm against `BytecodeTape::sparse_hessian`.
#[cfg(feature = "bytecode")]
pub fn sparse_hessian<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
) -> (F, Vec<F>, crate::sparse::SparsityPattern, Vec<F>) {
    let recorded = record(f, x).0;
    recorded.sparse_hessian(x)
}
/// Records `f` at `x` with [`record`] and forwards to
/// `BytecodeTape::sparse_hessian_vec::<N>`.
///
/// The value reported by `record` is discarded; the returned tuple, which
/// includes a `crate::sparse::SparsityPattern`, is exactly what
/// `tape.sparse_hessian_vec::<N>(x)` yields.
// NOTE(review): `N` looks like a batch/lane width for the tape method —
// confirm against `BytecodeTape::sparse_hessian_vec`.
#[cfg(feature = "bytecode")]
pub fn sparse_hessian_vec<F: Float + BtapeThreadLocal, const N: usize>(
    f: impl FnOnce(&[BReverse<F>]) -> BReverse<F>,
    x: &[F],
) -> (F, Vec<F>, crate::sparse::SparsityPattern, Vec<F>) {
    let recorded = record(f, x).0;
    recorded.sparse_hessian_vec::<N>(x)
}
/// Records the vector-valued `f` at `x` with [`record_multi`] and forwards to
/// `BytecodeTape::sparse_jacobian`.
///
/// The output values reported by `record_multi` are discarded; the returned
/// tuple, which includes a `crate::sparse::JacobianSparsityPattern`, is exactly
/// what `tape.sparse_jacobian(x)` yields.
///
/// # Panics
///
/// Panics if `f` returns no outputs (asserted inside `record_multi`).
#[cfg(feature = "bytecode")]
pub fn sparse_jacobian<F: Float + BtapeThreadLocal>(
    f: impl FnOnce(&[BReverse<F>]) -> Vec<BReverse<F>>,
    x: &[F],
) -> (Vec<F>, crate::sparse::JacobianSparsityPattern, Vec<F>) {
    // The binding is mutable, as in the original destructuring form.
    let mut recorded = record_multi(f, x).0;
    recorded.sparse_jacobian(x)
}
/// Forward-over-reverse evaluation: records `f` on `Dual<BReverse<F>>` inputs
/// and runs up to two reverse sweeps over the same tape.
///
/// Input `i` has a real part that is tape input `x[i]` and a dual part that is
/// the constant `v[i]`. After `f` has run under a [`BtapeGuard`], one sweep is
/// seeded at the index of the output's real part and one at the index of its
/// dual part.
///
/// # Returns
///
/// `(value, gradient, hvp)`:
/// - `value` is the `value` field of the output's real part;
/// - `gradient` holds the first `x.len()` adjoints of the sweep from the real part;
/// - `hvp` holds the first `x.len()` adjoints of the sweep from the dual part.
///
/// A sweep is skipped, and a zero vector returned in its place, when the
/// corresponding index is `CONSTANT`.
///
/// # Panics
///
/// Panics if `x` and `v` have different lengths.
#[cfg(feature = "bytecode")]
pub fn composed_hvp<F, Func>(f: Func, x: &[F], v: &[F]) -> (F, Vec<F>, Vec<F>)
where
    F: Float + BtapeThreadLocal,
    Func: FnOnce(&[Dual<BReverse<F>>]) -> Dual<BReverse<F>>,
{
    let n = x.len();
    assert_eq!(x.len(), v.len(), "x and v must have the same length");
    let mut tape = BytecodeTape::with_capacity(n * 30);

    let mut inputs: Vec<Dual<BReverse<F>>> = Vec::with_capacity(n);
    for (&xi, &vi) in x.iter().zip(v) {
        let idx = tape.new_input(xi);
        // Only the real part is tied to the tape; the dual part is a constant.
        inputs.push(Dual::new(BReverse::from_tape(xi, idx), BReverse::constant(vi)));
    }

    // The guard is released as soon as `f` has returned.
    let guard = BtapeGuard::new(&mut tape);
    let output = f(&inputs);
    drop(guard);

    let constant = crate::bytecode_tape::CONSTANT;

    let gradient = if output.re.index == constant {
        vec![F::zero(); n]
    } else {
        tape.reverse(output.re.index)[..n].to_vec()
    };

    let hvp = if output.eps.index == constant {
        vec![F::zero(); n]
    } else {
        tape.reverse(output.eps.index)[..n].to_vec()
    };

    (output.re.value, gradient, hvp)
}