/// Arithmetic interface for a jet: a scalar value bundled with derivative information.
///
/// `Jet2` is the only implementor visible in this file; the per-method notes below
/// describe what that implementation does.
pub(crate) trait RuntimeJet: Sized + Clone {
/// Returns the scalar value carried by the jet.
fn value(&self) -> f64;
/// Returns the jet of `self + o`.
fn add(&self, o: &Self) -> Self;
/// Returns the jet of `self - o`.
fn sub(&self, o: &Self) -> Self;
/// Returns the jet of the product `self * o` (product rule on the derivative parts).
fn mul(&self, o: &Self) -> Self;
/// Returns the jet multiplied by the plain scalar `s`.
fn scale(&self, s: f64) -> Self;
/// Returns the jet of `f(self)` for a unary function `f` described by `d`.
///
/// The caller supplies the derivative stack: `Jet2` reads `d[0]`, `d[1]`, `d[2]`
/// as `f`, `f'`, `f''` and ignores `d[3]` and `d[4]`. Presumably the extra slots
/// serve higher-order implementors that are not visible here — confirm.
fn compose_unary(&self, d: [f64; 5]) -> Self;
}
/// Second-order jet: a scalar value together with its gradient and Hessian
/// with respect to a runtime-sized set of `p` parameters.
#[derive(Clone, Debug)]
pub(crate) struct Jet2 {
    /// Scalar value.
    pub(crate) v: f64,
    /// Gradient; its length defines the parameter count `p`.
    pub(crate) g: Vec<f64>,
    /// Hessian as a flat `p * p` buffer; entry `(i, j)` lives at `i * p + j`.
    pub(crate) h: Vec<f64>,
}

impl Jet2 {
    /// Jet of the constant `x` over `p` parameters: all derivatives are zero.
    pub(crate) fn constant(x: f64, p: usize) -> Self {
        Self {
            v: x,
            g: vec![0.0; p],
            h: vec![0.0; p * p],
        }
    }

    /// Jet of the independent variable number `axis`, evaluated at `x`.
    ///
    /// The gradient is the unit vector along `axis`. An `axis` outside `0..p`
    /// leaves the gradient at zero, i.e. the result equals `constant(x, p)`.
    pub(crate) fn primary(x: f64, axis: usize, p: usize) -> Self {
        let mut jet = Self::constant(x, p);
        if let Some(slot) = jet.g.get_mut(axis) {
            *slot = 1.0;
        }
        jet
    }

    /// Number of parameters, taken from the gradient length.
    #[inline]
    pub(crate) fn p(&self) -> usize {
        self.g.len()
    }
}
/// `RuntimeJet` arithmetic for `Jet2`: each operation propagates value,
/// gradient and Hessian through the usual sum, product and chain rules.
///
/// Binary operations size their result by the receiver's `p`. The other
/// operand must carry at least `p` gradient and `p * p` Hessian entries
/// (out-of-range access panics); any entries beyond that are ignored.
impl RuntimeJet for Jet2 {
    #[inline]
    fn value(&self) -> f64 {
        self.v
    }

    /// Entry-wise sum of value, gradient and Hessian.
    fn add(&self, o: &Self) -> Self {
        let p = self.p();
        let n = p * p;
        Jet2 {
            v: self.v + o.v,
            g: self.g.iter().zip(&o.g[..p]).map(|(a, b)| a + b).collect(),
            h: self.h[..n].iter().zip(&o.h[..n]).map(|(a, b)| a + b).collect(),
        }
    }

    /// Entry-wise difference of value, gradient and Hessian.
    fn sub(&self, o: &Self) -> Self {
        let p = self.p();
        let n = p * p;
        Jet2 {
            v: self.v - o.v,
            g: self.g.iter().zip(&o.g[..p]).map(|(a, b)| a - b).collect(),
            h: self.h[..n].iter().zip(&o.h[..n]).map(|(a, b)| a - b).collect(),
        }
    }

    /// Product rule to second order:
    /// `∇(ab) = a∇b + b∇a`, `∇²(ab) = a∇²b + ∇a∇bᵀ + ∇b∇aᵀ + b∇²a`.
    fn mul(&self, o: &Self) -> Self {
        let p = self.p();
        let (a, b) = (self.v, o.v);
        let g: Vec<f64> = self
            .g
            .iter()
            .zip(&o.g[..p])
            .map(|(ga, gb)| a * gb + ga * b)
            .collect();
        // Rows are emitted in order, so pushing reproduces the `i * p + j` layout.
        let mut h = Vec::with_capacity(p * p);
        for i in 0..p {
            for j in 0..p {
                let ij = i * p + j;
                h.push(a * o.h[ij] + self.g[i] * o.g[j] + self.g[j] * o.g[i] + self.h[ij] * b);
            }
        }
        Jet2 { v: a * b, g, h }
    }

    /// Multiplies value, gradient and Hessian by the scalar `s`.
    fn scale(&self, s: f64) -> Self {
        let times_s = |xs: &[f64]| -> Vec<f64> { xs.iter().map(|x| x * s).collect() };
        Jet2 {
            v: self.v * s,
            g: times_s(&self.g),
            h: times_s(&self.h),
        }
    }

    /// Chain rule for a unary function whose value and first two derivatives
    /// are `d[0]`, `d[1]`, `d[2]`; `d[3]` and `d[4]` are not used at second order.
    fn compose_unary(&self, d: [f64; 5]) -> Self {
        let [f, f1, f2, ..] = d;
        let p = self.p();
        let g: Vec<f64> = self.g.iter().map(|gi| f1 * gi).collect();
        let mut h = Vec::with_capacity(p * p);
        for i in 0..p {
            for j in 0..p {
                h.push(f2 * self.g[i] * self.g[j] + f1 * self.h[i * p + j]);
            }
        }
        Jet2 { v: f, g, h }
    }
}
/// Runs `iters` passes of the fixed-slope update `a ← a − inv_fa · f(a)` on jets.
///
/// The iterate starts as the constant jet `a0` over `p` parameters; `f` returns
/// the residual jet for the current iterate and `inv_fa` is the scalar step
/// factor applied to it. With `iters == 0` the constant jet is returned as is.
///
/// The test in this file calls it with `a0` an exact scalar root and
/// `inv_fa = 1 / (∂F/∂a)`, and checks that two passes reproduce the first and
/// second implicit derivatives.
pub(crate) fn filtered_implicit_solve_jet2(
    a0: f64,
    inv_fa: f64,
    iters: usize,
    p: usize,
    f: impl Fn(&Jet2) -> Jet2,
) -> Jet2 {
    (0..iters).fold(Jet2::constant(a0, p), |current, _| {
        let correction = f(&current).scale(inv_fa);
        current.sub(&correction)
    })
}
/// Multiplies two polynomials whose coefficients are jets over `p` parameters.
///
/// `a[i]` and `b[j]` are the coefficients of degree `i` and `j`; the result
/// holds `a.len() + b.len() - 1` coefficients with
/// `out[k] = Σ_{i + j = k} a[i] * b[j]`.
///
/// Two empty inputs yield an empty result. (The output length used to be
/// computed with a plain `- 1`, which underflowed in that case.) If exactly
/// one input is empty the result is, as before, `len - 1` zero jets.
pub(crate) fn poly_conv_jet(a: &[Jet2], b: &[Jet2], p: usize) -> Vec<Jet2> {
    // `saturating_sub` keeps the length at zero when both operands are empty.
    let out_len = (a.len() + b.len()).saturating_sub(1);
    let mut out = vec![Jet2::constant(0.0, p); out_len];
    for (i, ai) in a.iter().enumerate() {
        for (j, bj) in b.iter().enumerate() {
            let k = i + j;
            out[k] = out[k].add(&ai.mul(bj));
        }
    }
    out
}
/// Lifts scalar cell moments to second-order jets in the cell coefficients.
///
/// With `η₀(z)` the cubic built from `c0_scalar` and `δη(z)` the cubic built
/// from `c_jets[k] − c0_scalar[k]`, the function forms the polynomial
/// `δq = η₀·δη + ½·δη²` and the weight `S = 1 − δq + ½·δq²`, then returns, for
/// each `n` in `0..=max_n`, `Σ_m S_m · scalar_moments[n + m]`.
///
/// `S` has 13 coefficients, so `scalar_moments` must hold at least
/// `max_n + 13` entries; a shorter slice panics on indexing. The parameter
/// count is taken from `c_jets[0]`.
pub(crate) fn cell_base_moment_jets(
    c_jets: &[Jet2; 4],
    c0_scalar: [f64; 4],
    scalar_moments: &[f64],
    max_n: usize,
) -> Vec<Jet2> {
    let p = c_jets[0].p();
    let zero = Jet2::constant(0.0, p);
    let one = Jet2::constant(1.0, p);

    // Coefficient polynomials: the perturbation δη and the frozen base η₀.
    let deta: Vec<Jet2> = c_jets
        .iter()
        .zip(c0_scalar.iter())
        .map(|(c_jet, &c_base)| c_jet.sub(&Jet2::constant(c_base, p)))
        .collect();
    let eta0: Vec<Jet2> = c0_scalar
        .iter()
        .map(|&c_base| Jet2::constant(c_base, p))
        .collect();

    // δq = η₀·δη + ½·δη²
    let cross = poly_conv_jet(&eta0, &deta, p);
    let deta_sq = poly_conv_jet(&deta, &deta, p);
    let mut dq: Vec<Jet2> = cross
        .iter()
        .zip(deta_sq.iter())
        .map(|(lin, sq)| lin.add(&sq.scale(0.5)))
        .collect();

    // S = 1 − δq + ½·δq², with δq zero-padded to the degree of δq².
    let dq_sq = poly_conv_jet(&dq, &dq, p);
    dq.resize(dq_sq.len(), zero.clone());
    let mut s_poly: Vec<Jet2> = dq
        .iter()
        .zip(dq_sq.iter())
        .map(|(lin, sq)| lin.scale(-1.0).add(&sq.scale(0.5)))
        .collect();
    s_poly[0] = s_poly[0].add(&one);

    (0..=max_n)
        .map(|n| {
            s_poly.iter().enumerate().fold(zero.clone(), |acc, (m, s_m)| {
                acc.add(&s_m.scale(scalar_moments[n + m]))
            })
        })
        .collect()
}
/// Second-order jet of the contribution from moving one cell edge.
///
/// Let `δ = ze_jet − ze0`, `η₀` the base cubic at `ze0`, and `δη` the
/// coefficient perturbation `Σ_k (c_jets[k] − c0_scalar[k])·ze0^k`. The
/// function returns `g·δ + ½·g_z·δ²`, where
///
/// * `g = ze0ⁿ · exp(−q₀) · (1 − δq + ½·δq²)` with `q₀ = ½(ze0² + η₀²)` and
///   `δq = η₀·δη + ½·δη²`, a jet in the coefficients;
/// * `g_z = (n·ze0ⁿ⁻¹ − ze0ⁿ·(ze0 + η₀·η₀′)) · exp(−q₀)`, a scalar, which is
///   the `z`-derivative of `zⁿ·exp(−q(z))` at `ze0`.
///
/// This reads as the second-order expansion of `∫ zⁿ·exp(−q)` over the sliver
/// between `ze0` and the moved edge. The parameter count is taken from `c_jets[0]`.
pub(crate) fn edge_sliver_jet(
    n: usize,
    c_jets: &[Jet2; 4],
    c0_scalar: [f64; 4],
    ze_jet: &Jet2,
    ze0: f64,
) -> Jet2 {
    let p = c_jets[0].p();
    let [b0, b1, b2, b3] = c0_scalar;

    // δη: coefficient perturbations evaluated at the base edge position.
    let mut deta = Jet2::constant(0.0, p);
    let mut z_pow = 1.0_f64;
    for (c_jet, &c_base) in c_jets.iter().zip(c0_scalar.iter()) {
        let shifted = c_jet.sub(&Jet2::constant(c_base, p));
        deta = deta.add(&shifted.scale(z_pow));
        z_pow *= ze0;
    }

    // Base cubic and its z-derivative at ze0.
    let eta0 = b0 + b1 * ze0 + b2 * ze0 * ze0 + b3 * ze0 * ze0 * ze0;
    let eta_p = b1 + 2.0 * b2 * ze0 + 3.0 * b3 * ze0 * ze0;

    // exp(−δq) truncated at second order.
    let dq = deta.scale(eta0).add(&deta.mul(&deta).scale(0.5));
    let exp_neg_dq = Jet2::constant(1.0, p)
        .sub(&dq)
        .add(&dq.mul(&dq).scale(0.5));

    let q0 = 0.5 * (ze0 * ze0 + eta0 * eta0);
    let damp = (-q0).exp();
    let z_n = ze0.powi(n as i32);
    let g_jet = exp_neg_dq.scale(z_n * damp);

    // Scalar z-derivative of the integrand; the n == 0 guard avoids a negative power.
    let q_z = ze0 + eta0 * eta_p;
    let z_nm1 = match n {
        0 => 0.0,
        _ => (n as f64) * ze0.powi(n as i32 - 1),
    };
    let g_z = (z_nm1 - z_n * q_z) * damp;

    let delta = ze_jet.sub(&Jet2::constant(ze0, p));
    g_jet
        .mul(&delta)
        .add(&delta.mul(&delta).scale(0.5 * g_z))
}
/// Moment jets for a cell whose coefficients and both edges carry derivatives.
///
/// For each `n` in `0..=max_n` the result is the fixed-edge moment jet from
/// `cell_base_moment_jets`, plus the right-edge sliver from `edge_sliver_jet`
/// (`zr_jet`, base `zr0`), minus the left-edge sliver (`zl_jet`, base `zl0`).
pub(crate) fn cell_base_moment_jets_moving(
    c_jets: &[Jet2; 4],
    c0_scalar: [f64; 4],
    scalar_moments: &[f64],
    max_n: usize,
    zl_jet: &Jet2,
    zl0: f64,
    zr_jet: &Jet2,
    zr0: f64,
) -> Vec<Jet2> {
    let fixed_edges = cell_base_moment_jets(c_jets, c0_scalar, scalar_moments, max_n);
    let mut moments = Vec::new();
    for n in 0..=max_n {
        let right = edge_sliver_jet(n, c_jets, c0_scalar, zr_jet, zr0);
        let left = edge_sliver_jet(n, c_jets, c0_scalar, zl_jet, zl0);
        moments.push(fixed_edges[n].add(&right).sub(&left));
    }
    moments
}
/// Four cell coefficients together with their first and second partial
/// derivatives with respect to two scalars `a` and `b`, one array entry per
/// coefficient.
///
/// `cell_coeff_jet_ab` consumes these as the terms of a second-order Taylor
/// expansion around the base point.
pub(crate) struct CellCoeffAbPartials {
/// Coefficient values at the base point.
pub(crate) c0: [f64; 4],
/// First partials with respect to `a`.
pub(crate) dc_da: [f64; 4],
/// First partials with respect to `b`.
pub(crate) dc_db: [f64; 4],
/// Second partials with respect to `a` twice.
pub(crate) dc_daa: [f64; 4],
/// Mixed second partials with respect to `a` and `b`.
pub(crate) dc_dab: [f64; 4],
/// Second partials with respect to `b` twice.
pub(crate) dc_dbb: [f64; 4],
}
/// Builds the four coefficient jets from their `(a, b)` Taylor data.
///
/// With `δa = a_jet − a0` and `δb = b_jet − b0`, coefficient `k` is
/// `c0 + c_a·δa + c_b·δb + ½·c_aa·δa² + c_ab·δa·δb + ½·c_bb·δb²`, using the
/// partials stored in `part`. The parameter count is taken from `a_jet`.
pub(crate) fn cell_coeff_jet_ab(
    part: &CellCoeffAbPartials,
    a_jet: &Jet2,
    b_jet: &Jet2,
    a0: f64,
    b0: f64,
) -> [Jet2; 4] {
    let p = a_jet.p();
    let da = a_jet.sub(&Jet2::constant(a0, p));
    let db = b_jet.sub(&Jet2::constant(b0, p));
    let da_sq = da.mul(&da);
    let da_db = da.mul(&db);
    let db_sq = db.mul(&db);

    std::array::from_fn(|k| {
        // Terms are accumulated onto the base value in this fixed order.
        let terms = [
            da.scale(part.dc_da[k]),
            db.scale(part.dc_db[k]),
            da_sq.scale(0.5 * part.dc_daa[k]),
            da_db.scale(part.dc_dab[k]),
            db_sq.scale(0.5 * part.dc_dbb[k]),
        ];
        terms
            .iter()
            .fold(Jet2::constant(part.c0[k], p), |acc, term| acc.add(term))
    })
}
#[cfg(test)]
mod tests {
use super::*;
/// Central finite-difference gradient and Hessian of `f` at `theta`.
///
/// Returns `(grad, hess)` with `grad[i] = (f(θ + h·eᵢ) − f(θ − h·eᵢ)) / 2h` and
/// `hess[i * p + j]` from the four-point stencil
/// `(f₊₊ − f₊₋ − f₋₊ + f₋₋) / 4h²`, where `h = step` and `p = theta.len()`.
/// On the diagonal both displacements hit the same axis, so the stencil acts
/// as a second difference with spacing `2h`.
fn fd_grad_hess(theta: &[f64], step: f64, f: impl Fn(&[f64]) -> f64) -> (Vec<f64>, Vec<f64>) {
    let p = theta.len();
    // Evaluates `f` at `theta` after applying the listed (axis, offset) moves in order.
    let shifted = |moves: &[(usize, f64)]| -> f64 {
        let mut point = theta.to_vec();
        for &(axis, offset) in moves {
            point[axis] += offset;
        }
        f(&point)
    };

    let grad: Vec<f64> = (0..p)
        .map(|i| (shifted(&[(i, step)]) - shifted(&[(i, -step)])) / (2.0 * step))
        .collect();

    let mut hess = Vec::with_capacity(p * p);
    for i in 0..p {
        for j in 0..p {
            let pp = shifted(&[(i, step), (j, step)]);
            let pm = shifted(&[(i, step), (j, -step)]);
            let mp = shifted(&[(i, -step), (j, step)]);
            let mm = shifted(&[(i, -step), (j, -step)]);
            hess.push((pp - pm - mp + mm) / (4.0 * step * step));
        }
    }
    (grad, hess)
}
/// Checks `Jet2` sum, product and unary composition against central finite
/// differences on `f(x, y, z) = exp(x·y) + (z² + 1)·x`.
#[test]
fn runtime_jet2_algebra_matches_finite_differences_932() {
    let theta = [0.37_f64, -0.21, 0.84];
    let p = theta.len();
    let scalar = |t: &[f64]| -> f64 { (t[0] * t[1]).exp() + (t[2] * t[2] + 1.0) * t[0] };

    // The same expression assembled from jet primitives, one primary per coordinate.
    let vars: Vec<Jet2> = theta
        .iter()
        .enumerate()
        .map(|(axis, &val)| Jet2::primary(val, axis, p))
        .collect();
    let (x, y, z) = (&vars[0], &vars[1], &vars[2]);
    let xy = x.mul(y);
    // exp is its own derivative, so every slot of the derivative stack is exp(x·y).
    let exp_xy = xy.compose_unary([xy.value().exp(); 5]);
    let poly = z.mul(z).add(&Jet2::constant(1.0, p)).mul(x);
    let jet = exp_xy.add(&poly);

    let (fd_grad, fd_hess) = fd_grad_hess(&theta, 1e-4, scalar);
    assert!((jet.value() - scalar(&theta)).abs() < 1e-12);
    for i in 0..p {
        let (got, want) = (jet.g[i], fd_grad[i]);
        assert!(
            (got - want).abs() < 1e-6,
            "grad[{i}]: jet {} vs fd {}",
            got,
            want
        );
        for j in 0..p {
            let idx = i * p + j;
            let (got, want) = (jet.h[idx], fd_hess[idx]);
            assert!(
                (got - want).abs() < 1e-4,
                "hess[{i},{j}]: jet {} vs fd {}",
                got,
                want
            );
        }
    }
}
/// Checks `filtered_implicit_solve_jet2` against closed-form implicit
/// derivatives of the root `a(θ)` of `F(a; θ) = a³ + θ₀·a + θ₁`.
#[test]
fn runtime_jet2_implicit_lift_matches_analytic_ift_932() {
    let theta = [1.3_f64, -0.45];
    let p = theta.len();

    // Scalar root by Newton's method, started from zero.
    let a0 = (0..200).fold(0.0_f64, |a, _| {
        let fa = a * a * a + theta[0] * a + theta[1];
        let dfa = 3.0 * a * a + theta[0];
        a - fa / dfa
    });
    let f_a = 3.0 * a0 * a0 + theta[0];
    let f_aa = 6.0 * a0;

    // Two passes of the jet iteration around the converged root.
    let th0 = Jet2::primary(theta[0], 0, p);
    let th1 = Jet2::primary(theta[1], 1, p);
    let constraint = |a: &Jet2| -> Jet2 {
        let cubic = a.mul(a).mul(a);
        cubic.add(&th0.mul(a)).add(&th1)
    };
    let a_jet = filtered_implicit_solve_jet2(a0, 1.0 / f_a, 2, p, constraint);

    // Closed-form first and second derivatives from implicit differentiation.
    let a_t0 = -a0 / f_a;
    let a_t1 = -1.0 / f_a;
    let a_t0t0 = -(2.0 * a_t0 + f_aa * a_t0 * a_t0) / f_a;
    let a_t0t1 = -(a_t1 + f_aa * a_t0 * a_t1) / f_a;
    let a_t1t1 = -(f_aa * a_t1 * a_t1) / f_a;

    assert!((a_jet.value() - a0).abs() < 1e-12);
    assert!((a_jet.g[0] - a_t0).abs() < 1e-12, "a_t0 {} vs {a_t0}", a_jet.g[0]);
    assert!((a_jet.g[1] - a_t1).abs() < 1e-12, "a_t1 {} vs {a_t1}", a_jet.g[1]);

    let diag_1 = p + 1;
    assert!(
        (a_jet.h[0] - a_t0t0).abs() < 1e-12,
        "a_t0t0 {} vs {a_t0t0}",
        a_jet.h[0]
    );
    assert!(
        (a_jet.h[1] - a_t0t1).abs() < 1e-12,
        "a_t0t1 {} vs {a_t0t1}",
        a_jet.h[1]
    );
    assert!(
        (a_jet.h[diag_1] - a_t1t1).abs() < 1e-12,
        "a_t1t1 {} vs {a_t1t1}",
        a_jet.h[diag_1]
    );
    // The Hessian must come out symmetric.
    assert!((a_jet.h[1] - a_jet.h[p]).abs() < 1e-14);
}
// Compares the value, gradient and Hessian of `cell_base_moment_jets` (with the
// four cell coefficients as the jet parameters) against central finite
// differences of the scalar moments returned by `evaluate_cell_moments`.
#[test]
fn cell_base_moment_jets_match_fd_932() {
use crate::cubic_cell_kernel::{DenestedCubicCell, evaluate_cell_moments};
// Base coefficients and the fixed cell edges used for every evaluation.
let base = [0.10_f64, 0.50, -0.20, 0.10];
let (left, right) = (-0.80_f64, 0.70_f64);
let cell = |c: [f64; 4]| DenestedCubicCell {
left,
right,
c0: c[0],
c1: c[1],
c2: c[2],
c3: c[3],
};
let max_n = 4usize;
// `cell_base_moment_jets` reads `scalar_moments[n + m]` for `m` up to 12, hence the + 12.
let scalar_deg = max_n + 12;
// Scalar moments for a given coefficient set; the kernel's result type is not visible here.
let moments_at = |c: [f64; 4]| -> Vec<f64> {
evaluate_cell_moments(cell(c), scalar_deg)
.expect("cell moments")
.moments
.into_vec()
};
// One jet parameter per coefficient.
let p = 4usize;
let c_jets: [Jet2; 4] = std::array::from_fn(|k| Jet2::primary(base[k], k, p));
let scalar_moments = moments_at(base);
let m_jets = cell_base_moment_jets(&c_jets, base, &scalar_moments, max_n);
// The jet values must reproduce the scalar moments.
for n in 0..=max_n {
assert!(
(m_jets[n].v - scalar_moments[n]).abs() <= 1e-12 * scalar_moments[n].abs().max(1.0),
"M[{n}] value {:+.12e} != scalar {:+.12e}",
m_jets[n].v,
scalar_moments[n]
);
}
// Finite-difference step shared by the gradient and Hessian stencils.
let h = 1e-4_f64;
for n in 0..=max_n {
for k in 0..p {
// Central difference in coefficient k.
let mut cp = base;
let mut cm = base;
cp[k] += h;
cm[k] -= h;
let fd_g = (moments_at(cp)[n] - moments_at(cm)[n]) / (2.0 * h);
assert!(
(m_jets[n].g[k] - fd_g).abs() <= 1e-5 * fd_g.abs().max(1.0) + 1e-9,
"dM[{n}]/dc[{k}] jet {:+.12e} != fd {:+.12e}",
m_jets[n].g[k],
fd_g
);
for l in 0..p {
// Four-point stencil in coefficients (k, l); for k == l both shifts land on one axis.
let mut cpp = base;
let mut cpm = base;
let mut cmp = base;
let mut cmm = base;
cpp[k] += h;
cpp[l] += h;
cpm[k] += h;
cpm[l] -= h;
cmp[k] -= h;
cmp[l] += h;
cmm[k] -= h;
cmm[l] -= h;
let fd_h = (moments_at(cpp)[n] - moments_at(cpm)[n] - moments_at(cmp)[n]
+ moments_at(cmm)[n])
/ (4.0 * h * h);
assert!(
(m_jets[n].h[k * p + l] - fd_h).abs() <= 1e-3 * fd_h.abs().max(1.0) + 1e-6,
"d2M[{n}]/dc[{k}]dc[{l}] jet {:+.12e} != fd {:+.12e}",
m_jets[n].h[k * p + l],
fd_h
);
}
}
}
}
// Compares the value, gradient and Hessian of `cell_base_moment_jets_moving`
// against central finite differences of the scalar moments. The six jet
// parameters are the four cell coefficients followed by the left and right edge.
#[test]
fn cell_base_moment_jets_moving_match_fd_932() {
use crate::cubic_cell_kernel::{DenestedCubicCell, evaluate_cell_moments};
let base_c = [0.10_f64, 0.50, -0.20, 0.10];
let (zl0, zr0) = (-0.80_f64, 0.70_f64);
let max_n = 4usize;
// Same moment depth as the fixed-edge test: `max_n` plus 12 extra orders.
let scalar_deg = max_n + 12;
let p = 6usize;
// Parameter vector layout: [c0, c1, c2, c3, left edge, right edge].
let base6 = [base_c[0], base_c[1], base_c[2], base_c[3], zl0, zr0];
let cell_of = |t: [f64; 6]| DenestedCubicCell {
left: t[4],
right: t[5],
c0: t[0],
c1: t[1],
c2: t[2],
c3: t[3],
};
// Scalar moments for a given parameter vector; the kernel's result type is not visible here.
let moments_at = |t: [f64; 6]| -> Vec<f64> {
evaluate_cell_moments(cell_of(t), scalar_deg)
.expect("cell moments")
.moments
.into_vec()
};
let scalar_moments = moments_at(base6);
// Coefficients occupy axes 0..4, the left edge axis 4 and the right edge axis 5.
let c_jets: [Jet2; 4] = std::array::from_fn(|k| Jet2::primary(base_c[k], k, p));
let zl_jet = Jet2::primary(zl0, 4, p);
let zr_jet = Jet2::primary(zr0, 5, p);
let m_jets = cell_base_moment_jets_moving(
&c_jets,
base_c,
&scalar_moments,
max_n,
&zl_jet,
zl0,
&zr_jet,
zr0,
);
// The jet values must reproduce the scalar moments.
for n in 0..=max_n {
assert!(
(m_jets[n].v - scalar_moments[n]).abs() <= 1e-12 * scalar_moments[n].abs().max(1.0),
"moving M[{n}] value {:+.12e} != scalar {:+.12e}",
m_jets[n].v,
scalar_moments[n]
);
}
// Finite-difference step shared by the gradient and Hessian stencils.
let h = 1e-4_f64;
for n in 0..=max_n {
for k in 0..p {
// Central difference in parameter k.
let mut tp = base6;
let mut tm = base6;
tp[k] += h;
tm[k] -= h;
let fd_g = (moments_at(tp)[n] - moments_at(tm)[n]) / (2.0 * h);
assert!(
(m_jets[n].g[k] - fd_g).abs() <= 1e-5 * fd_g.abs().max(1.0) + 1e-9,
"moving dM[{n}]/dθ[{k}] jet {:+.12e} != fd {:+.12e}",
m_jets[n].g[k],
fd_g
);
for l in 0..p {
// Four-point stencil in parameters (k, l); for k == l both shifts land on one axis.
let mut tpp = base6;
let mut tpm = base6;
let mut tmp = base6;
let mut tmm = base6;
tpp[k] += h;
tpp[l] += h;
tpm[k] += h;
tpm[l] -= h;
tmp[k] -= h;
tmp[l] += h;
tmm[k] -= h;
tmm[l] -= h;
let fd_h = (moments_at(tpp)[n] - moments_at(tpm)[n] - moments_at(tmp)[n]
+ moments_at(tmm)[n])
/ (4.0 * h * h);
// The relative tolerance here (2e-3) is looser than in the fixed-edge test (1e-3).
assert!(
(m_jets[n].h[k * p + l] - fd_h).abs() <= 2e-3 * fd_h.abs().max(1.0) + 1e-6,
"moving d2M[{n}]/dθ[{k}]dθ[{l}] jet {:+.12e} != fd {:+.12e}",
m_jets[n].h[k * p + l],
fd_h
);
}
}
}
}
// Compares the coefficient jets from `cell_coeff_jet_ab`, fed with the kernel's
// analytic first and second partials, against finite differences of
// `denested_cell_coefficients` in `(a, b)`.
#[test]
fn cell_coeff_jet_ab_match_fd_932() {
use crate::cubic_cell_kernel::{
LocalSpanCubic, denested_cell_coefficient_partials, denested_cell_coefficients,
denested_cell_second_partials,
};
let score_span = LocalSpanCubic {
left: -1.0,
right: 1.0,
c0: 0.10,
c1: 0.30,
c2: -0.10,
c3: 0.05,
};
let link_span = LocalSpanCubic {
left: -1.20,
right: 0.90,
c0: 0.20,
c1: 0.25,
c2: 0.08,
c3: -0.03,
};
// Base point in (a, b) and the kernel's value / first / second partials there.
let (a0, b0) = (0.31_f64, 0.60_f64);
let c0 = denested_cell_coefficients(score_span, link_span, a0, b0);
let (dc_da, dc_db) = denested_cell_coefficient_partials(score_span, link_span, a0, b0);
let (dc_daa, dc_dab, dc_dbb) =
denested_cell_second_partials(score_span, link_span, a0, b0);
let part = CellCoeffAbPartials {
c0,
dc_da,
dc_db,
dc_daa,
dc_dab,
dc_dbb,
};
// Two jet parameters: axis 0 is `a`, axis 1 is `b`.
let p = 2usize; let a_jet = Jet2::primary(a0, 0, p);
let b_jet = Jet2::primary(b0, 1, p);
let cj = cell_coeff_jet_ab(&part, &a_jet, &b_jet, a0, b0);
let coeff_at = |a: f64, b: f64| denested_cell_coefficients(score_span, link_span, a, b);
// Finite-difference step shared by all stencils below.
let h = 1e-5_f64;
for k in 0..4 {
assert!(
(cj[k].v - c0[k]).abs() <= 1e-12 * c0[k].abs().max(1.0),
"c[{k}] value {:+.12e} != {:+.12e}",
cj[k].v,
c0[k]
);
// Central first differences in a and in b.
let g_a = (coeff_at(a0 + h, b0)[k] - coeff_at(a0 - h, b0)[k]) / (2.0 * h);
let g_b = (coeff_at(a0, b0 + h)[k] - coeff_at(a0, b0 - h)[k]) / (2.0 * h);
assert!(
(cj[k].g[0] - g_a).abs() <= 1e-6 * g_a.abs().max(1.0) + 1e-9,
"dc[{k}]/da jet {:+.12e} != fd {:+.12e}",
cj[k].g[0],
g_a
);
assert!(
(cj[k].g[1] - g_b).abs() <= 1e-6 * g_b.abs().max(1.0) + 1e-9,
"dc[{k}]/db jet {:+.12e} != fd {:+.12e}",
cj[k].g[1],
g_b
);
// Three-point second differences for the pure terms, four-point stencil for the mixed one.
let h_aa = (coeff_at(a0 + h, b0)[k] - 2.0 * c0[k] + coeff_at(a0 - h, b0)[k]) / (h * h);
let h_bb = (coeff_at(a0, b0 + h)[k] - 2.0 * c0[k] + coeff_at(a0, b0 - h)[k]) / (h * h);
let h_ab = (coeff_at(a0 + h, b0 + h)[k] - coeff_at(a0 + h, b0 - h)[k]
- coeff_at(a0 - h, b0 + h)[k]
+ coeff_at(a0 - h, b0 - h)[k])
/ (4.0 * h * h);
// With p == 2 the flat Hessian is [aa, ab, ba, bb], so indices 0, 3 and 1 are checked.
assert!(
(cj[k].h[0] - h_aa).abs() <= 1e-4 * h_aa.abs().max(1.0) + 1e-5,
"d2c[{k}]/da2 jet {:+.12e} != fd {:+.12e}",
cj[k].h[0],
h_aa
);
assert!(
(cj[k].h[3] - h_bb).abs() <= 1e-4 * h_bb.abs().max(1.0) + 1e-5,
"d2c[{k}]/db2 jet {:+.12e} != fd {:+.12e}",
cj[k].h[3],
h_bb
);
assert!(
(cj[k].h[1] - h_ab).abs() <= 1e-4 * h_ab.abs().max(1.0) + 1e-5,
"d2c[{k}]/dadb jet {:+.12e} != fd {:+.12e}",
cj[k].h[1],
h_ab
);
}
}
}