use crate::jet_scalar::JetScalar;
use crate::jet_tower::{
KernelChannels, RowNllProgramGeneric, Tower4, generic_full_tower, verify_kernel_channels,
};
/// One observation of the Gaussian location-scale fixture.
///
/// Holds the response together with the location and log-scale at which the
/// closed-form channels are evaluated.
#[derive(Clone, Copy, Debug)]
struct GaussianRow {
    /// Observed response.
    y: f64,
    /// Location (mean) primary.
    eta: f64,
    /// Log standard deviation primary; the precision is `exp(-2 * s)`.
    s: f64,
}
/// Row-wise negative log-likelihood program for a Gaussian location-scale
/// model; each element of `rows` is one observation.
struct GaussianLocScaleRow {
    /// Observations, addressed by the `row` argument of the program methods.
    rows: Vec<GaussianRow>,
}
impl RowNllProgramGeneric<2> for GaussianLocScaleRow {
fn n_rows(&self) -> usize {
self.rows.len()
}
fn primaries(&self, row: usize) -> Result<[f64; 2], String> {
let r = self
.rows
.get(row)
.ok_or_else(|| format!("GaussianLocScaleRow: row {row} out of range"))?;
Ok([r.eta, r.s])
}
fn row_nll_generic<S: JetScalar<2>>(&self, row: usize, p: &[S; 2]) -> Result<S, String> {
let data = self
.rows
.get(row)
.ok_or_else(|| format!("GaussianLocScaleRow: row {row} out of range"))?;
let eta = &p[0];
let s = &p[1];
let r = S::constant(data.y).sub(eta);
let w = s.scale(-2.0).exp();
Ok(s.add(&w.mul(&r).mul(&r).scale(0.5)))
}
}
/// Closed-form channels of the Gaussian location-scale row NLL
/// `f(eta, s) = s + 0.5 * exp(-2 s) * (y - eta)^2`, primaries ordered `[eta, s]`.
///
/// Returns the following, all evaluated at `row`:
/// - the value, gradient and Hessian;
/// - for every direction in `third_dirs`, the third-derivative tensor contracted
///   once along that direction;
/// - for every `(u, v)` in `fourth_pairs`, the fourth-derivative tensor
///   contracted along `u` and `v`.
///
/// Mixed partials commute and there are only two primaries. So each third- or
/// fourth-order entry depends only on how many of its indices equal 1 (`s`).
/// Those entries are therefore tabulated by that count, writing
/// `w = exp(-2 s)` and `r = y - eta`.
fn gaussian_closed_form_channels(
    row: &GaussianRow,
    third_dirs: &[[f64; 2]],
    fourth_pairs: &[([f64; 2], [f64; 2])],
) -> KernelChannels<2> {
    let r = row.y - row.eta;
    let w = (-2.0 * row.s).exp();

    // Entry k is the partial derivative having exactly k indices equal to `s`.
    let d3_by_s_count = [0.0, -2.0 * w, -4.0 * w * r, -4.0 * w * r * r];
    let d4_by_s_count = [0.0, 0.0, 4.0 * w, 8.0 * w * r, 8.0 * w * r * r];

    let third = third_dirs
        .iter()
        .map(|dir| {
            // sum_c T3[a][b][c] * dir[c], accumulated from 0.0 in c order.
            let entry = |a: usize, b: usize| {
                (0..2).fold(0.0_f64, |acc, c| acc + d3_by_s_count[a + b + c] * dir[c])
            };
            (*dir, [[entry(0, 0), entry(0, 1)], [entry(1, 0), entry(1, 1)]])
        })
        .collect();

    let fourth = fourth_pairs
        .iter()
        .map(|(u, v)| {
            // sum_{c,d} T4[a][b][c][d] * u[c] * v[d], c outer and d inner.
            let entry = |a: usize, b: usize| {
                (0..2)
                    .flat_map(|c| (0..2).map(move |d| (c, d)))
                    .fold(0.0_f64, |acc, (c, d)| {
                        acc + d4_by_s_count[a + b + c + d] * u[c] * v[d]
                    })
            };
            (*u, *v, [[entry(0, 0), entry(0, 1)], [entry(1, 0), entry(1, 1)]])
        })
        .collect();

    KernelChannels {
        value: row.s + 0.5 * w * r * r,
        gradient: [-w * r, 1.0 - w * r * r],
        hessian: [[w, 2.0 * w * r], [2.0 * w * r, 2.0 * w * r * r]],
        third,
        fourth,
    }
}
/// Minimal 64-bit linear congruential generator used to draw reproducible
/// test fixtures without an external RNG crate.
///
/// The tuple field is the raw generator state.
struct Lcg(u64);

impl Lcg {
    /// LCG multiplier (Knuth's MMIX constant).
    const MULTIPLIER: u64 = 6364136223846793005;
    /// LCG increment (Knuth's MMIX constant).
    const INCREMENT: u64 = 1442695040888963407;

    /// Advances the state and returns a sample in `[0, 1)`.
    ///
    /// The sample is built from the top 53 bits of the new state, which is
    /// exactly the precision of an `f64` mantissa.
    fn next_f64(&mut self) -> f64 {
        let advanced = self
            .0
            .wrapping_mul(Self::MULTIPLIER)
            .wrapping_add(Self::INCREMENT);
        self.0 = advanced;
        let top_bits = advanced >> 11;
        top_bits as f64 / (1u64 << 53) as f64
    }

    /// Sample in `[lo, hi)`, obtained by affinely mapping `next_f64`.
    fn uniform(&mut self, lo: f64, hi: f64) -> f64 {
        let span = hi - lo;
        lo + span * self.next_f64()
    }
}
/// Cross-checks the hand-derived Gaussian location-scale channels against the
/// generic jet tower on 24 pseudo-random rows.
///
/// Third-order channels are contracted along three fixed directions. Fourth-order
/// channels are contracted along each direction paired with its successor
/// (wrapping around). Everything is compared through `verify_kernel_channels`
/// at `REL_TOL`.
#[test]
fn gaussian_loc_scale_jet_tower_matches_hand_derived_via_universal_oracle() {
    const REL_TOL: f64 = 1e-11;
    const N_ROWS: usize = 24;

    let third_dirs: [[f64; 2]; 3] = [[0.7, -1.3], [-0.4, 0.6], [1.2, 0.2]];
    // (d0, d1), (d1, d2), (d2, d0).
    let fourth_pairs: [([f64; 2], [f64; 2]); 3] =
        [0, 1, 2].map(|i| (third_dirs[i], third_dirs[(i + 1) % 3]));

    let mut rng = Lcg(0x9322_0203_face_b00c);
    // Struct-literal field order fixes the draw order: y, then eta, then s.
    let rows: Vec<GaussianRow> = (0..N_ROWS)
        .map(|_| GaussianRow {
            y: rng.uniform(-3.0, 3.0),
            eta: rng.uniform(-2.0, 2.0),
            s: rng.uniform(-1.0, 1.0),
        })
        .collect();
    let program = GaussianLocScaleRow { rows };

    for (row, fixture) in program.rows.iter().enumerate() {
        let tower: Tower4<2> = generic_full_tower(&program, row).expect("gaussian jet tower");
        let claims = gaussian_closed_form_channels(fixture, &third_dirs, &fourth_pairs);
        if let Err(e) = verify_kernel_channels(&tower, &claims, REL_TOL) {
            panic!(
                "row {row}: Gaussian location-scale hand channels disagree with #932 \
                 jet-tower truth: {e}"
            );
        }
    }
}
/// Checks the packed scalar entry points on 18 pseudo-random rows.
///
/// - The Order2 kernel (`generic_row_kernel`) is compared against the
///   hand-derived value, gradient and Hessian.
/// - The one-seed third contraction (`generic_third_contracted`) and the
///   two-seed fourth contraction (`generic_fourth_contracted`) are compared
///   against the full jet tower.
/// - Fourth contractions pair each direction with its successor (wrapping).
///
/// Each comparison allows an error of `REL_TOL * (1 + max(|jet|, |reference|))`.
#[test]
fn gaussian_loc_scale_packed_scalars_match_hand_derived_contractions() {
    use crate::jet_tower::{
        generic_fourth_contracted, generic_row_kernel, generic_third_contracted,
    };
    const REL_TOL: f64 = 1e-11;
    let mut rng = Lcg(0x0bad_c0de_9322_0203);
    let third_dirs: [[f64; 2]; 3] = [[0.9, -0.5], [-1.1, 0.3], [0.2, 1.4]];
    let mut rows = Vec::new();
    for _ in 0..18 {
        // Field order fixes the draw order: y, then eta, then s.
        rows.push(GaussianRow {
            y: rng.uniform(-3.0, 3.0),
            eta: rng.uniform(-2.0, 2.0),
            s: rng.uniform(-1.0, 1.0),
        });
    }
    // The program owns the rows; fixtures are read back through it, so no clone is needed.
    let program = GaussianLocScaleRow { rows };
    // `reference` names where `b` came from ("hand" or "tower"). A failure then
    // points at the oracle that actually disagreed, instead of always blaming
    // the hand derivation.
    let close = |a: f64, b: f64, reference: &str, label: &str| {
        let band = REL_TOL + REL_TOL * a.abs().max(b.abs());
        assert!(
            (a - b).abs() <= band,
            "{label}: jet {a:+.15e} vs {reference} {b:+.15e} (band {band:.3e})"
        );
    };
    for (row, fixture) in program.rows.iter().enumerate() {
        let hand = gaussian_closed_form_channels(fixture, &[], &[]);
        let (v, g, h) = generic_row_kernel(&program, row).expect("Order2 channel");
        close(v, hand.value, "hand", &format!("row {row} Order2 value"));
        for i in 0..2 {
            close(
                g[i],
                hand.gradient[i],
                "hand",
                &format!("row {row} Order2 grad[{i}]"),
            );
            for j in 0..2 {
                close(
                    h[i][j],
                    hand.hessian[i][j],
                    "hand",
                    &format!("row {row} Order2 hess[{i}][{j}]"),
                );
            }
        }
        let tower: Tower4<2> = generic_full_tower(&program, row).expect("tower");
        for (di, dir) in third_dirs.iter().enumerate() {
            let third = generic_third_contracted(&program, row, dir).expect("OneSeed third");
            let truth = tower.third_contracted(dir);
            for i in 0..2 {
                for j in 0..2 {
                    close(
                        third[i][j],
                        truth[i][j],
                        "tower",
                        &format!("row {row} dir {di} OneSeed third[{i}][{j}]"),
                    );
                }
            }
        }
        for (ui, u) in third_dirs.iter().enumerate() {
            // Pair each direction with its successor, wrapping around.
            let v = third_dirs[(ui + 1) % third_dirs.len()];
            let fourth =
                generic_fourth_contracted(&program, row, u, &v).expect("TwoSeed fourth");
            let truth = tower.fourth_contracted(u, &v);
            for i in 0..2 {
                for j in 0..2 {
                    close(
                        fourth[i][j],
                        truth[i][j],
                        "tower",
                        &format!("row {row} pair {ui} TwoSeed fourth[{i}][{j}]"),
                    );
                }
            }
        }
    }
}