use crate::families::jet_tower::Tower4;
/// Axis index of the first Hessian coordinate (`a`) in the local `Tower4` jets.
const A: usize = 0;
/// Axis index of the second Hessian coordinate (`b`).
const B: usize = 1;
/// Axis index of the first directional seed (`u`).
const U: usize = 2;
/// Axis index of the second directional seed (`v`); in the code shown here it is
/// only used by the four-axis jet in `second_directionalhessian_coeff_fromobjective_q_terms`.
const V: usize = 3;
/// Mixed second-order coefficient `m2·q_a·q_b + m1·q_ab`.
///
/// This is the chain rule for `∂²(m∘q)/∂a∂b` when `m1 = m'(q)` and
/// `m2 = m''(q)`, with `q_a`, `q_b` the first partials of `q` and `q_ab`
/// its mixed second partial.
#[inline]
pub(crate) fn hessian_coeff_fromobjective_q_terms(
    m1: f64,
    m2: f64,
    q_a: f64,
    q_b: f64,
    q_ab: f64,
) -> f64 {
    // Outer-map curvature applied to the product of first partials.
    let gradient_outer_term = m2 * q_a * q_b;
    // Outer-map slope applied to q's own mixed partial.
    let mixed_partial_term = m1 * q_ab;
    gradient_outer_term + mixed_partial_term
}
#[inline]
pub(crate) fn directionalhessian_coeff_fromobjective_q_terms(
m1: f64,
m2: f64,
m3: f64,
dq: f64,
q_a: f64,
q_b: f64,
q_ab: f64,
dq_a: f64,
dq_b: f64,
dq_ab: f64,
) -> f64 {
let mut q = Tower4::<3>::zero();
q.g[A] = q_a;
q.g[B] = q_b;
q.g[U] = dq;
q.h[A][B] = q_ab; q.h[A][U] = dq_a; q.h[B][U] = dq_b; q.t3[A][B][U] = dq_ab;
q.compose_unary([0.0, m1, m2, m3, 0.0]).t3[A][B][U]
}
/// Second mixed directional derivative (along `u` then `v`) of the mixed
/// Hessian coefficient of `m∘q`.
///
/// `m1..m4` are passed as slots 1..4 of the outer map handed to
/// `Tower4::compose_unary`. The remaining arguments seed a four-axis jet over
/// `(A, B, U, V)`:
/// - first order: `q_a`, `q_b`, `dq_u`, `dqv`;
/// - second order: all six index-ascending pairs (`q_ab`, `dq_a_u`, `dq_av`,
///   `dq_b_u`, `dq_bv`, `d2q_uv`);
/// - third order: the four index-ascending triples (`dq_ab_u`, `dq_abv`,
///   `d2q_a_uv`, `d2q_b_uv`);
/// - fourth order: `d2q_ab_uv`.
///
/// The result is the `(A, B, U, V)` fourth-order entry after composition. The
/// oracle tests in this file compare it against a hand-expanded chain rule.
#[inline]
pub(crate) fn second_directionalhessian_coeff_fromobjective_q_terms(
    m1: f64,
    m2: f64,
    m3: f64,
    m4: f64,
    dq_u: f64,
    dqv: f64,
    d2q_uv: f64,
    q_a: f64,
    q_b: f64,
    q_ab: f64,
    dq_a_u: f64,
    dq_av: f64,
    dq_b_u: f64,
    dq_bv: f64,
    d2q_a_uv: f64,
    d2q_b_uv: f64,
    dq_ab_u: f64,
    dq_abv: f64,
    d2q_ab_uv: f64,
) -> f64 {
    // Mixed partials of order >= 1 never involve the outer map's value,
    // so slot 0 stays zero.
    let outer = [0.0, m1, m2, m3, m4];

    let mut jet = Tower4::<4>::zero();

    // First-order seeds.
    jet.g[A] = q_a;
    jet.g[B] = q_b;
    jet.g[U] = dq_u;
    jet.g[V] = dqv;

    // Second-order seeds (index-ascending pairs).
    jet.h[A][B] = q_ab;
    jet.h[A][U] = dq_a_u;
    jet.h[A][V] = dq_av;
    jet.h[B][U] = dq_b_u;
    jet.h[B][V] = dq_bv;
    jet.h[U][V] = d2q_uv;

    // Third-order seeds (index-ascending triples).
    jet.t3[A][B][U] = dq_ab_u;
    jet.t3[A][B][V] = dq_abv;
    jet.t3[A][U][V] = d2q_a_uv;
    jet.t3[B][U][V] = d2q_b_uv;

    // Fourth-order seed.
    jet.t4[A][B][U][V] = d2q_ab_uv;

    let composed = jet.compose_unary(outer);
    composed.t4[A][B][U][V]
}
#[cfg(test)]
mod oracle_tests {
    //! Cross-checks between the coefficient functions in the parent module and
    //! independent oracles.
    //!
    //! - The plain Hessian coefficient is checked against a `Tower2` composition.
    //! - The `Tower4`-based directional coefficients are checked against
    //!   hand-expanded chain rules.
    //! - The beta/w cross-channel test checks that explicit linear reconstructions
    //!   agree with the single-source `Tower4` coefficient.
    use super::*;
    use crate::families::jet_tower::Tower2;

    /// Oracle for `H_ab`: composes a two-axis jet seeded with gradient
    /// `(q_a, q_b)` and a symmetric off-diagonal entry `q_ab` through an outer
    /// map with slots `[0, m1, m2]`, then reads back the `(0, 1)` entry.
    fn hessian_via_tower(m1: f64, m2: f64, q_a: f64, q_b: f64, q_ab: f64) -> f64 {
        let mut q = Tower2::<2>::zero();
        q.g[0] = q_a;
        q.g[1] = q_b;
        q.h[0][1] = q_ab;
        q.h[1][0] = q_ab;
        q.compose_unary([0.0, m1, m2]).h[0][1]
    }

    /// Hand-expanded `u`-derivative of `m2·q_a·q_b + m1·q_ab`, product rule term by term.
    fn directional_hand(
        m1: f64,
        m2: f64,
        m3: f64,
        dq: f64,
        q_a: f64,
        q_b: f64,
        q_ab: f64,
        dq_a: f64,
        dq_b: f64,
        dq_ab: f64,
    ) -> f64 {
        m3 * dq * q_a * q_b + m2 * (dq_a * q_b + q_a * dq_b + dq * q_ab) + m1 * dq_ab
    }

    /// Hand-expanded mixed `u`,`v` second derivative of `m2·q_a·q_b + m1·q_ab`.
    fn second_directional_hand(
        m1: f64,
        m2: f64,
        m3: f64,
        m4: f64,
        dq_u: f64,
        dqv: f64,
        d2q_uv: f64,
        q_a: f64,
        q_b: f64,
        q_ab: f64,
        dq_a_u: f64,
        dq_av: f64,
        dq_b_u: f64,
        dq_bv: f64,
        d2q_a_uv: f64,
        d2q_b_uv: f64,
        dq_ab_u: f64,
        dq_abv: f64,
        d2q_ab_uv: f64,
    ) -> f64 {
        // Derivatives of the product q_a·q_b along u, along v, and along both.
        let d_qaqb_u = dq_a_u * q_b + q_a * dq_b_u;
        let d_qaqbv = dq_av * q_b + q_a * dq_bv;
        let d2_qaqb_uv = d2q_a_uv * q_b + dq_a_u * dq_bv + dq_av * dq_b_u + q_a * d2q_b_uv;
        m4 * dq_u * dqv * q_a * q_b
            + m3 * (d2q_uv * q_a * q_b + dq_u * d_qaqbv + dqv * d_qaqb_u + dq_u * dqv * q_ab)
            + m2 * (d2_qaqb_uv + d2q_uv * q_ab + dq_u * dq_abv + dqv * dq_ab_u)
            + m1 * d2q_ab_uv
    }

    /// Deterministic 64-bit LCG (Knuth's MMIX constants) yielding samples in `[-1, 1)`.
    ///
    /// The top 53 bits of the state are scaled to `[0, 1)` and then mapped
    /// affinely to `[-1, 1)`.
    fn stream(seed: u64) -> impl FnMut() -> f64 {
        let mut s = seed;
        move || {
            s = s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            ((s >> 11) as f64 / (1u64 << 53) as f64) * 2.0 - 1.0
        }
    }

    /// Asserts `got` (the tower-based value) matches `want` (the hand/closed-form value).
    ///
    /// The tolerance is `1e-12` relative to `|want|`; it becomes absolute when
    /// `|want| < 1`.
    fn close(label: &str, got: f64, want: f64) {
        let tol = 1e-12 * want.abs().max(1.0);
        assert!(
            (got - want).abs() <= tol,
            "{label}: tower {got:+.17e} vs hand {want:+.17e} (|Δ|={:.3e})",
            (got - want).abs()
        );
    }

    #[test]
    fn hessian_matches_tower() {
        let mut next = stream(0xC0FFEE);
        for _ in 0..200 {
            let (m1, m2) = (next(), next());
            let (q_a, q_b, q_ab) = (next(), next(), next());
            let closed_form = hessian_coeff_fromobjective_q_terms(m1, m2, q_a, q_b, q_ab);
            let tower = hessian_via_tower(m1, m2, q_a, q_b, q_ab);
            // `close` labels its first value "tower" and its second "hand";
            // pass them in that order so a failure message is not mislabelled.
            close("H_ab", tower, closed_form);
        }
    }

    #[test]
    fn directional_matches_hand_chain_rule() {
        let mut next = stream(0xBEEF);
        for _ in 0..400 {
            let m = [next(), next(), next()];
            let (dq, q_a, q_b, q_ab, dq_a, dq_b, dq_ab) =
                (next(), next(), next(), next(), next(), next(), next());
            let tower = directionalhessian_coeff_fromobjective_q_terms(
                m[0], m[1], m[2], dq, q_a, q_b, q_ab, dq_a, dq_b, dq_ab,
            );
            let hand =
                directional_hand(m[0], m[1], m[2], dq, q_a, q_b, q_ab, dq_a, dq_b, dq_ab);
            close("D_u H_ab", tower, hand);
        }
    }

    #[test]
    fn second_directional_matches_hand_chain_rule() {
        let mut next = stream(0xD00D);
        for _ in 0..400 {
            let m = [next(), next(), next(), next()];
            let dq_u = next();
            let dqv = next();
            let d2q_uv = next();
            let q_a = next();
            let q_b = next();
            let q_ab = next();
            let dq_a_u = next();
            let dq_av = next();
            let dq_b_u = next();
            let dq_bv = next();
            let d2q_a_uv = next();
            let d2q_b_uv = next();
            let dq_ab_u = next();
            let dq_abv = next();
            let d2q_ab_uv = next();
            let tower = second_directionalhessian_coeff_fromobjective_q_terms(
                m[0], m[1], m[2], m[3], dq_u, dqv, d2q_uv, q_a, q_b, q_ab, dq_a_u, dq_av, dq_b_u,
                dq_bv, d2q_a_uv, d2q_b_uv, dq_ab_u, dq_abv, d2q_ab_uv,
            );
            let hand = second_directional_hand(
                m[0], m[1], m[2], m[3], dq_u, dqv, d2q_uv, q_a, q_b, q_ab, dq_a_u, dq_av, dq_b_u,
                dq_bv, d2q_a_uv, d2q_b_uv, dq_ab_u, dq_abv, d2q_ab_uv,
            );
            close("D2_uv H_ab", tower, hand);
        }
    }

    /// The w-channel quantities (and the t–w cross terms) are synthesized by
    /// the chain rule from a base channel `q0`. The scalar factors
    /// `br, dr, ddr, d3r` (and `brk, drk, ddrk`) appear in them. The resulting
    /// coefficient must therefore be linear (resp. bilinear) in those factors.
    /// This test checks that the explicit `alpha_*` / `c_ww_*` weights
    /// reproduce the single-source `Tower4` result.
    #[test]
    fn betaw_cross_channel_expansions_match_single_source() {
        let mut next = stream(0x5151);
        let mut max_xw = 0.0_f64;
        let mut max_ww = 0.0_f64;
        for _ in 0..2000 {
            let (m1, m2, m3, m4) = (next(), next(), next(), next());
            let (dq_u, dqv, d2q_uv) = (next(), next(), next());
            let (q_t, dq_t_u, dq_tv, d2q_t_uv) = (next(), next(), next(), next());
            let (q0_t, dq0_u, dq0v, d2q0_uv) = (next(), next(), next(), next());
            let (dq0_t_u, dq0_tv, d2q0_t_uv) = (next(), next(), next());
            let (br, dr, ddr, d3r) = (next(), next(), next(), next());

            // w-channel terms built by the chain rule from q0.
            let qw = br;
            let dqw_u = dr * dq0_u;
            let dqwv = dr * dq0v;
            let d2qw_uv = ddr * dq0_u * dq0v + dr * d2q0_uv;
            let q_tw = dr * q0_t;
            let dq_tw_u = ddr * dq0_u * q0_t + dr * dq0_t_u;
            let dq_twv = ddr * dq0v * q0_t + dr * dq0_tv;
            let d2q_tw_uv = d3r * dq0_u * dq0v * q0_t
                + ddr * (d2q0_uv * q0_t + dq0_u * dq0_tv + dq0v * dq0_t_u)
                + dr * d2q0_t_uv;
            let coeff_tw = second_directionalhessian_coeff_fromobjective_q_terms(
                m1, m2, m3, m4, dq_u, dqv, d2q_uv, q_t, qw, q_tw, dq_t_u, dq_tv, dqw_u, dqwv,
                d2q_t_uv, d2qw_uv, dq_tw_u, dq_twv, d2q_tw_uv,
            );

            // Linear reconstruction of the t–w coefficient in (br, dr, ddr, d3r).
            let alpha_b = m4 * dq_u * dqv * q_t
                + m3 * (d2q_uv * q_t + dq_u * dq_tv + dqv * dq_t_u)
                + m2 * d2q_t_uv;
            let alpha_d = m3 * (dq_u * q_t * dq0v + dqv * q_t * dq0_u + dq_u * dqv * q0_t)
                + m2 * (dq_t_u * dq0v
                    + dq_tv * dq0_u
                    + q_t * d2q0_uv
                    + d2q_uv * q0_t
                    + dq_u * dq0_tv
                    + dqv * dq0_t_u)
                + m1 * d2q0_t_uv;
            let alpha_dd = m2 * (q_t * dq0_u * dq0v + dq_u * dq0v * q0_t + dqv * dq0_u * q0_t)
                + m1 * (d2q0_uv * q0_t + dq0_u * dq0_tv + dq0v * dq0_t_u);
            let alpha_d3 = m1 * dq0_u * dq0v * q0_t;
            let recon_xw = alpha_b * br + alpha_d * dr + alpha_dd * ddr + alpha_d3 * d3r;
            max_xw = max_xw.max((coeff_tw - recon_xw).abs() / coeff_tw.abs().max(1.0));

            // w–w coefficient: two w-type channels sharing q0, with zero cross terms.
            let (brk, drk, ddrk) = (next(), next(), next());
            let coeff_ww = second_directionalhessian_coeff_fromobjective_q_terms(
                m1,
                m2,
                m3,
                m4,
                dq_u,
                dqv,
                d2q_uv,
                br,
                brk,
                0.0,
                dr * dq0_u,
                dr * dq0v,
                drk * dq0_u,
                drk * dq0v,
                ddr * dq0_u * dq0v + dr * d2q0_uv,
                ddrk * dq0_u * dq0v + drk * d2q0_uv,
                0.0,
                0.0,
                0.0,
            );

            // Bilinear reconstruction in (br, dr, ddr) x (brk, drk, ddrk).
            let c_ww_bb = m4 * dq_u * dqv + m3 * d2q_uv;
            let c_ww_bd = m3 * (dq_u * dq0v + dqv * dq0_u) + m2 * d2q0_uv;
            let c_ww_bdd = m2 * dq0_u * dq0v;
            let c_ww_dd_pair = 2.0 * m2 * dq0_u * dq0v;
            let recon_ww = c_ww_bb * br * brk
                + c_ww_bd * (br * drk + dr * brk)
                + c_ww_bdd * (br * ddrk + ddr * brk)
                + c_ww_dd_pair * dr * drk;
            max_ww = max_ww.max((coeff_ww - recon_ww).abs() / coeff_ww.abs().max(1.0));
        }
        assert!(max_xw < 1e-12, "alpha_xw drifted from single source: {max_xw:.3e}");
        assert!(max_ww < 1e-12, "c_ww drifted from single source: {max_ww:.3e}");
    }
}