use echidna::{BytecodeTape, Dual, Float};
/// Checks that a step tape and its arguments have mutually consistent sizes
/// for a fixed-point map `G: R^(m+n) -> R^m`, where `m = num_states` and
/// `n = x.len()`.
///
/// # Panics
///
/// Panics if `z.len() != num_states`, if the tape's input count differs from
/// `num_states + x.len()`, or if the tape's output count differs from
/// `num_states`.
fn validate_step_tape<F: Float>(tape: &BytecodeTape<F>, z: &[F], x: &[F], num_states: usize) {
    assert_eq!(
        z.len(),
        num_states,
        "state vector length must equal num_states"
    );
    assert_eq!(
        tape.num_inputs(),
        num_states + x.len(),
        "step tape must have num_inputs == num_states + x.len()"
    );
    assert_eq!(
        tape.num_outputs(),
        num_states,
        "step tape must have num_outputs == num_states (G: R^(m+n) -> R^m)"
    );
}
/// Performs a single tangent piggyback step, allocating its own scratch buffer.
///
/// This is a convenience wrapper around [`piggyback_tangent_step_with_buf`]
/// for callers that do not need to reuse the dual-number work buffer between
/// calls.
///
/// Returns `(z_new, z_dot_new)`: the primal and tangent parts of the step
/// tape's outputs.
///
/// # Panics
///
/// Panics under the same conditions as [`piggyback_tangent_step_with_buf`].
pub fn piggyback_tangent_step<F: Float>(
    step_tape: &BytecodeTape<F>,
    z: &[F],
    x: &[F],
    z_dot: &[F],
    x_dot: &[F],
    num_states: usize,
) -> (Vec<F>, Vec<F>) {
    // The scratch buffer is throwaway here: it lives only for this one call.
    piggyback_tangent_step_with_buf(step_tape, z, x, z_dot, x_dot, num_states, &mut Vec::new())
}
/// Performs a single tangent piggyback step using a caller-supplied buffer.
///
/// The tape inputs are assembled as dual numbers in the order
/// `[(z, z_dot)..., (x, x_dot)...]`, pushed through
/// `step_tape.forward_tangent`, and the values at the tape's output slots are
/// split into their primal (`re`) and tangent (`eps`) parts.
///
/// `buf` is handed to `forward_tangent` as its work buffer and is read back
/// at the tape's output indices; passing the same buffer across calls lets
/// its allocation be reused.
///
/// Returns `(z_new, z_dot_new)`.
///
/// # Panics
///
/// Panics if the tape/`z`/`x` dimensions are inconsistent with `num_states`,
/// if `z_dot.len() != num_states`, or if `x_dot.len() != x.len()`.
pub fn piggyback_tangent_step_with_buf<F: Float>(
    step_tape: &BytecodeTape<F>,
    z: &[F],
    x: &[F],
    z_dot: &[F],
    x_dot: &[F],
    num_states: usize,
    buf: &mut Vec<Dual<F>>,
) -> (Vec<F>, Vec<F>) {
    validate_step_tape(step_tape, z, x, num_states);
    assert_eq!(z_dot.len(), num_states, "z_dot length must equal num_states");
    assert_eq!(x_dot.len(), x.len(), "x_dot length must equal x length");

    // Pair every primal value with its tangent seed: states first, then parameters.
    let state_duals = z.iter().zip(z_dot).map(|(&re, &eps)| Dual::new(re, eps));
    let param_duals = x.iter().zip(x_dot).map(|(&re, &eps)| Dual::new(re, eps));
    let seeded: Vec<Dual<F>> = state_duals.chain(param_duals).collect();

    step_tape.forward_tangent(&seeded, buf);

    // Read each output slot from the work buffer and split it into (value, tangent).
    step_tape
        .all_output_indices()
        .iter()
        .map(|&slot| {
            let dual = buf[slot as usize];
            (dual.re, dual.eps)
        })
        .unzip()
}
/// Iterates the tangent piggyback step until both the primal state and its
/// tangent have stopped changing.
///
/// Starting from `z = z0` and `z_dot = 0`, each sweep calls
/// [`piggyback_tangent_step_with_buf`] to obtain `(z_new, z_dot_new)`. The
/// relative changes
///
/// ```text
/// ||z_new - z|| / (1 + ||z||)    and    ||z_dot_new - z_dot|| / (1 + ||z_dot||)
/// ```
///
/// must both fall below `tol` for the iteration to stop. Testing the primal
/// alone is not sufficient: when `z0` is already at (or very near) the fixed
/// point, the primal change is below `tol` on the first sweep while `z_dot`
/// has only been propagated once from zero and is not yet the converged
/// tangent.
///
/// # Returns
///
/// * `Some((z, z_dot, iterations))` once both criteria are met, where
///   `iterations` is the number of sweeps performed.
/// * `None` if either relative change becomes non-finite, or if `max_iter`
///   sweeps complete without convergence (including `max_iter == 0`).
///
/// # Panics
///
/// Panics (inside the step) if the tape/`z0`/`x`/`x_dot` dimensions are
/// inconsistent with `num_states`. No validation happens when `max_iter == 0`
/// because no step is executed.
pub fn piggyback_tangent_solve<F: Float>(
    step_tape: &BytecodeTape<F>,
    z0: &[F],
    x: &[F],
    x_dot: &[F],
    num_states: usize,
    max_iter: usize,
    tol: F,
) -> Option<(Vec<F>, Vec<F>, usize)> {
    let mut z = z0.to_vec();
    let mut z_dot = vec![F::zero(); num_states];
    // One work buffer shared by every sweep to avoid re-allocating it.
    let mut buf = Vec::new();

    for k in 0..max_iter {
        let (z_new, z_dot_new) =
            piggyback_tangent_step_with_buf(step_tape, &z, x, &z_dot, x_dot, num_states, &mut buf);

        // Relative change of the primal state.
        let mut z_delta_sq = F::zero();
        let mut z_sq = F::zero();
        for (&new, &old) in z_new.iter().zip(&z) {
            let d = new - old;
            z_delta_sq = z_delta_sq + d * d;
            z_sq = z_sq + old * old;
        }
        let z_norm = z_delta_sq.sqrt() / (F::one() + z_sq.sqrt());

        // Relative change of the tangent; it typically lags behind the primal.
        let mut dot_delta_sq = F::zero();
        let mut dot_sq = F::zero();
        for (&new, &old) in z_dot_new.iter().zip(&z_dot) {
            let d = new - old;
            dot_delta_sq = dot_delta_sq + d * d;
            dot_sq = dot_sq + old * old;
        }
        let dot_norm = dot_delta_sq.sqrt() / (F::one() + dot_sq.sqrt());

        // NaN/inf in either quantity means the iteration has broken down.
        if !z_norm.is_finite() || !dot_norm.is_finite() {
            return None;
        }
        if z_norm < tol && dot_norm < tol {
            return Some((z_new, z_dot_new, k + 1));
        }

        z = z_new;
        z_dot = z_dot_new;
    }
    None
}
/// Solves the adjoint fixed-point iteration at an already-converged state.
///
/// The tape is evaluated once at `[z_star, x]`. The adjoint state then
/// starts at `lambda = z_bar` and is repeatedly updated as
/// `lambda_new[i] = adj[i] + z_bar[i]` for `i < num_states`, where
/// `adj = step_tape.reverse_seeded(&lambda)`. Iteration stops when
/// `||lambda_new - lambda|| / (1 + ||lambda||) < tol`.
///
/// On convergence one more `reverse_seeded(&lambda_new)` sweep is run and the
/// entries from index `num_states` onward are returned. These are presumably
/// the adjoints with respect to `x`; confirm against `reverse_seeded`'s
/// contract.
///
/// # Returns
///
/// * `Some((x_bar, iterations))` on convergence.
/// * `None` if the relative change becomes non-finite or `max_iter` sweeps
///   elapse without convergence.
///
/// # Panics
///
/// Panics if the tape/`z_star`/`x` dimensions are inconsistent with
/// `num_states`, or if `z_bar.len() != num_states`.
pub fn piggyback_adjoint_solve<F: Float>(
    step_tape: &mut BytecodeTape<F>,
    z_star: &[F],
    x: &[F],
    z_bar: &[F],
    num_states: usize,
    max_iter: usize,
    tol: F,
) -> Option<(Vec<F>, usize)> {
    validate_step_tape(step_tape, z_star, x, num_states);
    assert_eq!(z_bar.len(), num_states, "z_bar length must equal num_states");

    // Evaluate the tape once at the fixed point; every reverse sweep below reuses it.
    let point = [z_star, x].concat();
    step_tape.forward(&point);

    let mut lambda = z_bar.to_vec();
    for iteration in 1..=max_iter {
        let adj = step_tape.reverse_seeded(&lambda);

        // Only the leading `num_states` entries of the sweep feed the update.
        let lambda_new: Vec<F> = adj[..num_states]
            .iter()
            .zip(z_bar)
            .map(|(&state_adj, &seed)| state_adj + seed)
            .collect();

        let delta_sq = lambda_new
            .iter()
            .zip(&lambda)
            .fold(F::zero(), |acc, (&new, &old)| {
                let diff = new - old;
                acc + diff * diff
            });
        let lam_sq = lambda.iter().fold(F::zero(), |acc, &l| acc + l * l);

        let rel_change = delta_sq.sqrt() / (F::one() + lam_sq.sqrt());
        if !rel_change.is_finite() {
            return None;
        }
        if rel_change < tol {
            // Final sweep with the converged adjoint; keep the trailing part.
            let x_bar = step_tape.reverse_seeded(&lambda_new)[num_states..].to_vec();
            return Some((x_bar, iteration));
        }
        lambda = lambda_new;
    }
    None
}
/// Runs the primal fixed-point iteration and the adjoint iteration together.
///
/// Each sweep evaluates the tape at the current `[z, x]`, reads the new state
/// from `step_tape.output_values()`, and runs `reverse_seeded(&lambda)` at the
/// same point to form `lambda_new[i] = adj[i] + z_bar[i]`. Convergence
/// requires both relative changes
///
/// ```text
/// ||z_new - z|| / (1 + ||z||)    and    ||lambda_new - lambda|| / (1 + ||lambda||)
/// ```
///
/// to be below `tol`. On convergence the tape is re-evaluated at `z_new`, a
/// final `reverse_seeded(&lambda_new)` sweep is run, and its entries from
/// index `num_states` onward are returned alongside `z_new`.
///
/// # Returns
///
/// * `Some((z, x_bar, iterations))` on convergence.
/// * `None` if either relative change becomes non-finite or `max_iter` sweeps
///   elapse without convergence.
///
/// # Panics
///
/// Panics if the tape/`z0`/`x` dimensions are inconsistent with `num_states`,
/// or if `z_bar.len() != num_states`.
pub fn piggyback_forward_adjoint_solve<F: Float>(
    step_tape: &mut BytecodeTape<F>,
    z0: &[F],
    x: &[F],
    z_bar: &[F],
    num_states: usize,
    max_iter: usize,
    tol: F,
) -> Option<(Vec<F>, Vec<F>, usize)> {
    validate_step_tape(step_tape, z0, x, num_states);
    assert_eq!(z_bar.len(), num_states, "z_bar length must equal num_states");

    // `point` holds [z, x]; only its leading `num_states` entries change.
    let mut point = [z0, x].concat();
    let mut lambda = z_bar.to_vec();

    for iteration in 1..=max_iter {
        step_tape.forward(&point);
        let z_new = step_tape.output_values();
        let adj = step_tape.reverse_seeded(&lambda);

        // Relative change of the primal state.
        let z_old = &point[..num_states];
        let z_delta_sq = z_new[..num_states]
            .iter()
            .zip(z_old)
            .fold(F::zero(), |acc, (&new, &old)| {
                let diff = new - old;
                acc + diff * diff
            });
        let z_sq = z_old.iter().fold(F::zero(), |acc, &v| acc + v * v);
        let z_norm = z_delta_sq.sqrt() / (F::one() + z_sq.sqrt());
        if !z_norm.is_finite() {
            return None;
        }

        // Adjoint update and its relative change.
        let lambda_new: Vec<F> = adj[..num_states]
            .iter()
            .zip(z_bar)
            .map(|(&state_adj, &seed)| state_adj + seed)
            .collect();
        let lam_delta_sq = lambda_new
            .iter()
            .zip(&lambda)
            .fold(F::zero(), |acc, (&new, &old)| {
                let diff = new - old;
                acc + diff * diff
            });
        let lam_sq = lambda.iter().fold(F::zero(), |acc, &l| acc + l * l);
        let lam_norm = lam_delta_sq.sqrt() / (F::one() + lam_sq.sqrt());
        if !lam_norm.is_finite() {
            return None;
        }

        // Both the converged and the continuing path advance the primal state.
        point[..num_states].copy_from_slice(&z_new[..num_states]);

        if z_norm < tol && lam_norm < tol {
            // Re-evaluate at the final state before the last reverse sweep.
            step_tape.forward(&point);
            let x_bar = step_tape.reverse_seeded(&lambda_new)[num_states..].to_vec();
            return Some((z_new, x_bar, iteration));
        }
        lambda = lambda_new;
    }
    None
}