use std::fmt;
use echidna::{BytecodeTape, Dual, Float};
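/// Errors reported by the piggyback fixed-point differentiation routines in
/// this module. The `*Divergence` variants are returned when a step norm or an
/// iterate becomes non-finite; the `IterationsExhausted*` variants are returned
/// when `max_iter` iterations complete without the norm dropping below `tol`;
/// `DimensionMismatch` is returned when a caller-supplied slice has the wrong
/// length.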
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum PiggybackError {
PrimalDivergence { iteration: usize, last_norm: f64 },
TangentDivergence { iteration: usize, last_norm: f64 },
AdjointDivergence { iteration: usize, last_norm: f64 },
IterationsExhaustedTangent { iteration: usize, z_norm: f64 },
IterationsExhaustedAdjoint { iteration: usize, lam_norm: f64 },
IterationsExhaustedForwardAdjoint {
iteration: usize,
z_norm: f64,
lam_norm: f64,
},
DimensionMismatch {
field: &'static str,
expected: usize,
actual: usize,
},
}
impl fmt::Display for PiggybackError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
PiggybackError::PrimalDivergence {
iteration,
last_norm,
} => {
write!(
f,
"piggyback: primal diverged at iteration {iteration} (last_norm = {last_norm:.3e})"
)
}
PiggybackError::TangentDivergence {
iteration,
last_norm,
} => {
write!(
f,
"piggyback: tangent diverged at iteration {iteration} (last_norm = {last_norm:.3e})"
)
}
PiggybackError::AdjointDivergence {
iteration,
last_norm,
} => {
write!(
f,
"piggyback: adjoint diverged at iteration {iteration} (last_norm = {last_norm:.3e})"
)
}
PiggybackError::IterationsExhaustedTangent { iteration, z_norm } => {
write!(
f,
"piggyback: tangent solve reached max_iter = {iteration} (z_norm = {z_norm:.3e})"
)
}
PiggybackError::IterationsExhaustedAdjoint {
iteration,
lam_norm,
} => {
write!(
f,
"piggyback: adjoint solve reached max_iter = {iteration} (lam_norm = {lam_norm:.3e})"
)
}
PiggybackError::IterationsExhaustedForwardAdjoint {
iteration,
z_norm,
lam_norm,
} => {
write!(
f,
"piggyback: forward-adjoint solve reached max_iter = {iteration} (z_norm = {z_norm:.3e}, lam_norm = {lam_norm:.3e})"
)
}
PiggybackError::DimensionMismatch {
field,
expected,
actual,
} => {
write!(
f,
"piggyback: dimension mismatch for `{field}` (expected {expected}, got {actual})"
)
}
}
}
}
impl std::error::Error for PiggybackError {}
echidna::assert_send_sync!(PiggybackError);
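/// Checks that `step_tape` encodes a step function `G: R^(m+n) -> R^m` with
/// `m = num_states` and `n = x.len()`: the tape inputs are laid out as `[z, x]`
/// and the tape outputs are the `m` components of `G(z, x)`. Panics on
/// mismatch; the public solvers perform their own recoverable
/// `DimensionMismatch` checks for the remaining caller-supplied slices.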
fn validate_step_tape<F: Float>(tape: &BytecodeTape<F>, z: &[F], x: &[F], num_states: usize) {
    assert_eq!(z.len(), num_states, "z length must equal num_states");
    assert_eq!(
        tape.num_inputs(),
        num_states + x.len(),
        "step tape must have num_inputs == num_states + x.len() (inputs are laid out as [z, x])"
    );
assert_eq!(
tape.num_outputs(),
num_states,
"step tape must have num_outputs == num_states (G: R^(m+n) -> R^m)"
);
}
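/// Performs one piggyback tangent step: evaluates the step function at
/// `(z, x)` with dual-number (forward-mode) inputs seeded by `(z_dot, x_dot)`
/// and returns `(G(z, x), G_z(z, x) * z_dot + G_x(z, x) * x_dot)`.
///
/// Allocates a fresh dual buffer on every call; use
/// [`piggyback_tangent_step_with_buf`] to reuse a buffer inside a fixed-point
/// loop.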
pub fn piggyback_tangent_step<F: Float>(
step_tape: &BytecodeTape<F>,
z: &[F],
x: &[F],
z_dot: &[F],
x_dot: &[F],
num_states: usize,
) -> (Vec<F>, Vec<F>) {
let mut buf = Vec::new();
piggyback_tangent_step_with_buf(step_tape, z, x, z_dot, x_dot, num_states, &mut buf)
}
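/// Same as [`piggyback_tangent_step`], but reuses the caller-provided dual
/// buffer `buf` so repeated calls avoid a per-step allocation.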
pub fn piggyback_tangent_step_with_buf<F: Float>(
step_tape: &BytecodeTape<F>,
z: &[F],
x: &[F],
z_dot: &[F],
x_dot: &[F],
num_states: usize,
buf: &mut Vec<Dual<F>>,
) -> (Vec<F>, Vec<F>) {
validate_step_tape(step_tape, z, x, num_states);
let m = num_states;
let n = x.len();
assert_eq!(z_dot.len(), m, "z_dot length must equal num_states");
assert_eq!(x_dot.len(), n, "x_dot length must equal x length");
let mut dual_inputs = Vec::with_capacity(m + n);
for i in 0..m {
dual_inputs.push(Dual::new(z[i], z_dot[i]));
}
for j in 0..n {
dual_inputs.push(Dual::new(x[j], x_dot[j]));
}
step_tape.forward_tangent(&dual_inputs, buf);
let out_indices = step_tape.all_output_indices();
let mut z_new = Vec::with_capacity(m);
let mut z_dot_new = Vec::with_capacity(m);
for &idx in out_indices {
let d = buf[idx as usize];
z_new.push(d.re);
z_dot_new.push(d.eps);
}
(z_new, z_dot_new)
}
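/// Piggyback tangent solve: iterates the primal fixed point
/// `z_{k+1} = G(z_k, x)` from `z0` while propagating the tangent
/// `z_dot_{k+1} = G_z * z_dot_k + G_x * x_dot` alongside it, starting from
/// `z_dot_0 = 0`. Convergence is measured on the relative primal step norm
/// `||z_{k+1} - z_k|| / (1 + ||z_k||)` only; the tangent iterate is returned
/// as-is once that norm drops below `tol`.
///
/// Returns `(z*, z_dot*, iterations_used)`, or `DimensionMismatch` if `x_dot`
/// and `x` differ in length.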
pub fn piggyback_tangent_solve<F: Float>(
step_tape: &BytecodeTape<F>,
z0: &[F],
x: &[F],
x_dot: &[F],
num_states: usize,
max_iter: usize,
tol: F,
) -> Result<(Vec<F>, Vec<F>, usize), PiggybackError> {
if x_dot.len() != x.len() {
return Err(PiggybackError::DimensionMismatch {
field: "x_dot",
expected: x.len(),
actual: x_dot.len(),
});
}
let m = num_states;
let mut z = z0.to_vec();
let mut z_dot = vec![F::zero(); m];
let mut buf = Vec::new();
let mut last_norm: f64 = f64::NAN;
for k in 0..max_iter {
let (z_new, z_dot_new) =
piggyback_tangent_step_with_buf(step_tape, &z, x, &z_dot, x_dot, num_states, &mut buf);
let mut delta_sq = F::zero();
let mut z_sq = F::zero();
for i in 0..m {
let d = z_new[i] - z[i];
delta_sq = delta_sq + d * d;
z_sq = z_sq + z[i] * z[i];
}
let norm = delta_sq.sqrt() / (F::one() + z_sq.sqrt());
if !norm.is_finite() {
return Err(PiggybackError::PrimalDivergence {
iteration: k,
last_norm: norm.to_f64().unwrap_or(f64::NAN),
});
}
if !z_dot_new.iter().all(|v| v.is_finite()) {
debug_assert!(
norm.is_finite(),
"TangentDivergence path must see a finite primal norm"
);
return Err(PiggybackError::TangentDivergence {
iteration: k,
last_norm: norm.to_f64().unwrap_or(f64::NAN),
});
}
last_norm = norm.to_f64().unwrap_or(f64::NAN);
if norm < tol {
return Ok((z_new, z_dot_new, k + 1));
}
z = z_new;
z_dot = z_dot_new;
}
Err(PiggybackError::IterationsExhaustedTangent {
iteration: max_iter,
z_norm: last_norm,
})
}
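/// Piggyback adjoint solve at a converged primal point `z_star`: runs one
/// forward sweep of the tape at `(z_star, x)`, then iterates
/// `lambda_{k+1} = G_z^T * lambda_k + z_bar` via seeded reverse sweeps until
/// the relative step norm drops below `tol`. On convergence the parameter
/// adjoint `x_bar = G_x^T * lambda` is extracted from the tail of one final
/// reverse sweep.
///
/// Returns `(x_bar, iterations_used)`, or `DimensionMismatch` if
/// `z_bar.len() != num_states`.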
pub fn piggyback_adjoint_solve<F: Float>(
step_tape: &mut BytecodeTape<F>,
z_star: &[F],
x: &[F],
z_bar: &[F],
num_states: usize,
max_iter: usize,
tol: F,
) -> Result<(Vec<F>, usize), PiggybackError> {
let m = num_states;
if z_bar.len() != m {
return Err(PiggybackError::DimensionMismatch {
field: "z_bar",
expected: m,
actual: z_bar.len(),
});
}
validate_step_tape(step_tape, z_star, x, num_states);
let mut input = Vec::with_capacity(m + x.len());
input.extend_from_slice(z_star);
input.extend_from_slice(x);
step_tape.forward(&input);
let mut lambda = z_bar.to_vec();
let mut last_norm: f64 = f64::NAN;
for k in 0..max_iter {
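        // A seeded reverse sweep yields adjoints for every tape input: entries
        // 0..m are G_z^T * lambda (the z block), entries m.. are G_x^T * lambda.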
let adj = step_tape.reverse_seeded(&lambda);
let mut lambda_new = Vec::with_capacity(m);
let mut delta_sq = F::zero();
let mut lam_sq = F::zero();
for i in 0..m {
let l_new = adj[i] + z_bar[i];
let d = l_new - lambda[i];
delta_sq = delta_sq + d * d;
lam_sq = lam_sq + lambda[i] * lambda[i];
lambda_new.push(l_new);
}
let norm = delta_sq.sqrt() / (F::one() + lam_sq.sqrt());
if !norm.is_finite() {
return Err(PiggybackError::AdjointDivergence {
iteration: k,
last_norm: norm.to_f64().unwrap_or(f64::NAN),
});
}
if !lambda_new.iter().all(|v| v.is_finite()) {
debug_assert!(
norm.is_finite(),
"AdjointDivergence componentwise path must see a finite norm"
);
return Err(PiggybackError::AdjointDivergence {
iteration: k,
last_norm: norm.to_f64().unwrap_or(f64::NAN),
});
}
last_norm = norm.to_f64().unwrap_or(f64::NAN);
if norm < tol {
let adj_final = step_tape.reverse_seeded(&lambda_new);
return Ok((adj_final[m..].to_vec(), k + 1));
}
lambda = lambda_new;
}
Err(PiggybackError::IterationsExhaustedAdjoint {
iteration: max_iter,
lam_norm: last_norm,
})
}
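/// Coupled piggyback solve: advances the primal iterate `z_{k+1} = G(z_k, x)`
/// and the adjoint iterate `lambda_{k+1} = G_z^T * lambda_k + z_bar` in
/// lockstep, with each reverse sweep evaluated at the current primal point.
/// Both relative step norms must drop below `tol`; the tape is then
/// re-evaluated at the converged point to extract `x_bar = G_x^T * lambda`.
///
/// Returns `(z*, x_bar, iterations_used)`, or `DimensionMismatch` if
/// `z_bar.len() != num_states`.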
pub fn piggyback_forward_adjoint_solve<F: Float>(
step_tape: &mut BytecodeTape<F>,
z0: &[F],
x: &[F],
z_bar: &[F],
num_states: usize,
max_iter: usize,
tol: F,
) -> Result<(Vec<F>, Vec<F>, usize), PiggybackError> {
let m = num_states;
if z_bar.len() != m {
return Err(PiggybackError::DimensionMismatch {
field: "z_bar",
expected: m,
actual: z_bar.len(),
});
}
validate_step_tape(step_tape, z0, x, num_states);
let mut input = Vec::with_capacity(m + x.len());
input.extend_from_slice(z0);
input.extend_from_slice(x);
let mut lambda = z_bar.to_vec();
let mut last_z_norm: f64 = f64::NAN;
let mut last_lam_norm: f64 = f64::NAN;
for k in 0..max_iter {
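        // One forward sweep gives G(z_k, x); one reverse sweep seeded with
        // lambda_k gives G_z^T * lambda_k (first m entries) and
        // G_x^T * lambda_k (remaining entries), both evaluated at z_k.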
step_tape.forward(&input);
let z_new = step_tape.output_values();
let adj = step_tape.reverse_seeded(&lambda);
let mut z_delta_sq = F::zero();
let mut z_sq = F::zero();
for i in 0..m {
let d = z_new[i] - input[i];
z_delta_sq = z_delta_sq + d * d;
z_sq = z_sq + input[i] * input[i];
}
let z_norm = z_delta_sq.sqrt() / (F::one() + z_sq.sqrt());
if !z_norm.is_finite() {
return Err(PiggybackError::PrimalDivergence {
iteration: k,
last_norm: z_norm.to_f64().unwrap_or(f64::NAN),
});
}
let mut lam_delta_sq = F::zero();
let mut lam_sq = F::zero();
let mut lambda_new = Vec::with_capacity(m);
for i in 0..m {
let l_new = adj[i] + z_bar[i];
let d = l_new - lambda[i];
lam_delta_sq = lam_delta_sq + d * d;
lam_sq = lam_sq + lambda[i] * lambda[i];
lambda_new.push(l_new);
}
let lam_norm = lam_delta_sq.sqrt() / (F::one() + lam_sq.sqrt());
if !lam_norm.is_finite() {
return Err(PiggybackError::AdjointDivergence {
iteration: k,
last_norm: lam_norm.to_f64().unwrap_or(f64::NAN),
});
}
if !lambda_new.iter().all(|v| v.is_finite()) {
debug_assert!(
lam_norm.is_finite(),
"AdjointDivergence componentwise path must see a finite lam_norm"
);
return Err(PiggybackError::AdjointDivergence {
iteration: k,
last_norm: lam_norm.to_f64().unwrap_or(f64::NAN),
});
}
if !z_new.iter().all(|v| v.is_finite()) {
debug_assert!(
z_norm.is_finite(),
"PrimalDivergence componentwise path must see a finite z_norm"
);
return Err(PiggybackError::PrimalDivergence {
iteration: k,
last_norm: z_norm.to_f64().unwrap_or(f64::NAN),
});
}
last_z_norm = z_norm.to_f64().unwrap_or(f64::NAN);
last_lam_norm = lam_norm.to_f64().unwrap_or(f64::NAN);
if z_norm < tol && lam_norm < tol {
input[..m].copy_from_slice(&z_new[..m]);
step_tape.forward(&input);
let adj_final = step_tape.reverse_seeded(&lambda_new);
return Ok((z_new, adj_final[m..].to_vec(), k + 1));
}
input[..m].copy_from_slice(&z_new[..m]);
lambda = lambda_new;
}
Err(PiggybackError::IterationsExhaustedForwardAdjoint {
iteration: max_iter,
z_norm: last_z_norm,
lam_norm: last_lam_norm,
})
}
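// A minimal usage sketch (a hypothetical helper, not part of the public API).
// It assumes `step_tape` has already been recorded elsewhere, with inputs laid
// out as [z, x] and outputs G(z, x) as `validate_step_tape` requires, and that
// `f64` implements the `Float` trait. The names `example_piggyback_usage`,
// `x_dot`, and `z_bar` are illustrative only.
#[allow(dead_code)]
fn example_piggyback_usage(
    step_tape: &mut BytecodeTape<f64>,
    z0: &[f64],
    x: &[f64],
) -> Result<(), PiggybackError> {
    let m = z0.len();
    let (max_iter, tol) = (10_000, 1e-10);

    // Forward mode: sensitivity of the fixed point z* with respect to x[0],
    // obtained by seeding the corresponding tangent direction.
    let mut x_dot = vec![0.0_f64; x.len()];
    if let Some(first) = x_dot.first_mut() {
        *first = 1.0;
    }
    let (z_star, _z_dot_star, _fwd_iters) =
        piggyback_tangent_solve(step_tape, z0, x, &x_dot, m, max_iter, tol)?;

    // Reverse mode: gradient of the scalar objective sum(z*) with respect to x,
    // obtained by seeding z_bar = 1 and iterating the adjoint fixed point.
    let z_bar = vec![1.0_f64; m];
    let (_x_bar, _adj_iters) =
        piggyback_adjoint_solve(step_tape, &z_star, x, &z_bar, m, max_iter, tol)?;

    Ok(())
}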