use super::*;
use crate::bms::signed_probit_neglog_derivatives_up_to_fourth;
use crate::survival::marginal_slope::gpu;
use crate::inference::probability::signed_probit_logcdf_and_mills_ratio;
/// Value/derivative stack `[f, f', f'', f''', f'''']` for the exit/entry survival term at `eta`,
/// in the layout expected by [`FlexJet::compose_unary`].
///
/// The value is `signed_probit_logcdf_and_mills_ratio(-eta).0`. The four derivatives come from
/// `signed_probit_neglog_derivatives_up_to_fourth(-eta, 1.0)`, with even orders sign-flipped
/// (`[.., k1, -k2, k3, -k4]`).
///
/// NOTE(review): presumably this is the stack of `eta -> ln Phi(-eta)`, with the sign pattern
/// coming from the chain rule through `-eta`; confirm against the helper definitions.
///
/// # Errors
/// Propagates any error from `signed_probit_neglog_derivatives_up_to_fourth`.
#[inline]
fn surv_stack(eta: f64) -> Result<[f64; 5], String> {
    let reflected = -eta;
    let (log_value, _mills) = signed_probit_logcdf_and_mills_ratio(reflected);
    let (d1, d2, d3, d4) = signed_probit_neglog_derivatives_up_to_fourth(reflected, 1.0)?;
    Ok([log_value, d1, -d2, d3, -d4])
}
/// Natural logarithm and its first four derivatives at `x`:
/// `[ln x, 1/x, -1/x^2, 2/x^3, -6/x^4]`.
#[inline]
fn ln_stack(x: f64) -> [f64; 5] {
    let r = x.recip();
    let r2 = r * r;
    let third = 2.0 * r2 * r;
    let fourth = -6.0 * r2 * r2;
    [x.ln(), r, -r2, third, fourth]
}
/// Arithmetic interface shared by the truncated Taylor jets in this module (`Jet2`, `Jet3`,
/// `Jet4`). The same row/timepoint formulas can then be evaluated at different derivative
/// orders.
trait FlexJet: Sized + Clone {
    /// Primal (order-zero) value.
    fn value(&self) -> f64;
    /// Component-wise sum of two jets.
    fn add(&self, o: &Self) -> Self;
    /// Component-wise difference `self - o`.
    fn sub(&self, o: &Self) -> Self;
    /// Product of two jets (Leibniz rule on every stored component).
    fn mul(&self, o: &Self) -> Self;
    /// Multiply every stored component by the scalar `s`.
    fn scale(&self, s: f64) -> Self;
    /// Compose a scalar function `f` with this jet.
    ///
    /// `d` is `[f, f', f'', f''', f'''']` evaluated at `self.value()`. Each implementation
    /// reads only as many entries as its order needs.
    fn compose_unary(&self, d: [f64; 5]) -> Self;
    /// Natural logarithm, composed from [`ln_stack`].
    #[inline]
    fn ln(&self) -> Self {
        self.compose_unary(ln_stack(self.value()))
    }
}
/// Per-row negative log-likelihood of the flexible survival model, evaluated on any jet type
/// `J` so the derivatives are carried along automatically.
///
/// With `wd = wi * di` the result is
/// `wi*S0(eta0) - wi*(1 - di)*S1(eta1) + wd*eta1^2/2 + wd*q1^2/2 - wd*ln(chi1) + wd*ln(d1)
///  - wd*ln(qd1)`.
/// Here `S0`/`S1` are the scalar functions whose value/derivative stacks are `surv0`/`surv1`.
/// `wi` is the row weight and `di` the event indicator (presumably 0/1).
///
/// The constant `wi * di * ln(2*pi)` is not included here.
#[inline]
fn flex_row_nll<J: FlexJet>(
    eta0: &J,
    eta1: &J,
    chi1: &J,
    d1: &J,
    q1: &J,
    qd1: &J,
    surv0: [f64; 5],
    surv1: [f64; 5],
    wi: f64,
    di: f64,
) -> J {
    let event_weight = wi * di;
    let censor_weight = -wi * (1.0 - di);
    let half_event = 0.5 * event_weight;

    let entry_term = eta0.compose_unary(surv0).scale(wi);
    let exit_term = eta1.compose_unary(surv1).scale(censor_weight);
    let eta_quadratic = eta1.mul(eta1).scale(half_event);
    let q_quadratic = q1.mul(q1).scale(half_event);
    let log_chi = chi1.ln().scale(event_weight);
    let log_d = d1.ln().scale(event_weight);
    let log_qd = qd1.ln().scale(event_weight);

    // Accumulated left-to-right in the same order as the formula above.
    entry_term
        .add(&exit_term)
        .add(&eta_quadratic)
        .add(&q_quadratic)
        .sub(&log_chi)
        .add(&log_d)
        .sub(&log_qd)
}
/// Second-order jet over `p` parameters: value, gradient and dense row-major Hessian.
#[derive(Clone)]
struct Jet2 {
    /// Value.
    v: f64,
    /// Gradient; its length defines `p`.
    g: Vec<f64>,
    /// Hessian as `p * p` row-major entries (`h[i * p + j]`).
    h: Vec<f64>,
}
impl Jet2 {
    /// Build a jet from a value, a gradient slice and a row-major Hessian slice.
    ///
    /// An empty `h` stands for a zero Hessian.
    ///
    /// # Panics
    /// If `h` is non-empty and its length differs from `g.len() * g.len()`.
    fn from_parts(v: f64, g: &[f64], h: &[f64]) -> Self {
        let n = g.len();
        let hessian = if !h.is_empty() {
            assert_eq!(h.len(), n * n, "Jet2::from_parts Hessian length");
            h.to_vec()
        } else {
            vec![0.0; n * n]
        };
        Jet2 {
            v,
            g: g.to_vec(),
            h: hessian,
        }
    }

    /// Build a jet from ndarray views.
    ///
    /// A missing Hessian is zero. A present one is read at `[i, j]` for `i, j < g.len()` and
    /// flattened row-major.
    fn from_view(
        v: f64,
        g: ndarray::ArrayView1<'_, f64>,
        h: Option<ndarray::ArrayView2<'_, f64>>,
    ) -> Self {
        let n = g.len();
        let gradient: Vec<f64> = g.iter().copied().collect();
        let hessian = h.map_or_else(
            || vec![0.0; n * n],
            |hm| {
                let hm = &hm;
                (0..n)
                    .flat_map(|row| (0..n).map(move |col| hm[[row, col]]))
                    .collect()
            },
        );
        Jet2 {
            v,
            g: gradient,
            h: hessian,
        }
    }

    /// Seed an independent variable: value `x`, unit gradient on `axis`, zero Hessian.
    ///
    /// An `axis >= p` (e.g. `usize::MAX`) activates no direction, so the result is a constant.
    fn primary(x: f64, axis: usize, p: usize) -> Self {
        let g = (0..p).map(|k| if k == axis { 1.0 } else { 0.0 }).collect();
        Jet2 {
            v: x,
            g,
            h: vec![0.0; p * p],
        }
    }

    /// Number of parameters the jet is differentiated with respect to.
    #[inline]
    fn p(&self) -> usize {
        self.g.len()
    }
}
impl FlexJet for Jet2 {
    #[inline]
    fn value(&self) -> f64 {
        self.v
    }

    fn add(&self, o: &Self) -> Self {
        let n = self.p();
        Jet2 {
            v: self.v + o.v,
            g: (0..n).map(|i| self.g[i] + o.g[i]).collect(),
            h: (0..n * n).map(|k| self.h[k] + o.h[k]).collect(),
        }
    }

    fn sub(&self, o: &Self) -> Self {
        let n = self.p();
        Jet2 {
            v: self.v - o.v,
            g: (0..n).map(|i| self.g[i] - o.g[i]).collect(),
            h: (0..n * n).map(|k| self.h[k] - o.h[k]).collect(),
        }
    }

    /// Product rule: `(fg)_i = f g_i + f_i g` and
    /// `(fg)_ij = f g_ij + f_i g_j + f_j g_i + f_ij g`.
    fn mul(&self, o: &Self) -> Self {
        let n = self.p();
        let g = (0..n).map(|i| self.v * o.g[i] + self.g[i] * o.v).collect();
        let h = (0..n * n)
            .map(|k| {
                let (i, j) = (k / n, k % n);
                self.v * o.h[k] + self.g[i] * o.g[j] + self.g[j] * o.g[i] + self.h[k] * o.v
            })
            .collect();
        Jet2 {
            v: self.v * o.v,
            g,
            h,
        }
    }

    fn scale(&self, s: f64) -> Self {
        let times = |x: &f64| x * s;
        Jet2 {
            v: self.v * s,
            g: self.g.iter().map(times).collect(),
            h: self.h.iter().map(times).collect(),
        }
    }

    /// Second-order chain rule: only `f`, `f'` and `f''` from `d` are needed.
    fn compose_unary(&self, d: [f64; 5]) -> Self {
        let [f0, f1, f2, _, _] = d;
        let n = self.p();
        let g = self.g.iter().map(|gi| f1 * gi).collect();
        let h = (0..n * n)
            .map(|k| {
                let (i, j) = (k / n, k % n);
                f2 * self.g[i] * self.g[j] + f1 * self.h[k]
            })
            .collect();
        Jet2 { v: f0, g, h }
    }
}
/// `Jet2` extended with one directional tangent.
///
/// It represents `base + ε·eps` with `ε² = 0`, where `eps` is itself a `Jet2`: the directional
/// derivative together with its own gradient and Hessian. The Hessian of `eps` is therefore the
/// third derivative contracted once with the seeded direction.
#[derive(Clone)]
struct Jet3 {
    /// Second-order jet of the quantity itself.
    base: Jet2,
    /// Second-order jet of the directional derivative along the seeded direction.
    eps: Jet2,
}
impl Jet3 {
    /// Seed parameter `axis`.
    ///
    /// The base part is `Jet2::primary(x, axis, p)`. The tangent part is the constant
    /// `dir_axis`, the direction's component on this axis.
    fn primary(x: f64, axis: usize, p: usize, dir_axis: f64) -> Self {
        let zero_grad = vec![0.0; p];
        let base = Jet2::primary(x, axis, p);
        let eps = Jet2::from_parts(dir_axis, &zero_grad, &[]);
        Jet3 { base, eps }
    }

    /// Third derivative contracted once with the seeded direction, as a row-major `p * p` buffer.
    fn contracted_third(&self) -> Vec<f64> {
        self.eps.h.to_vec()
    }
}
impl FlexJet for Jet3 {
    #[inline]
    fn value(&self) -> f64 {
        self.base.v
    }

    fn add(&self, o: &Self) -> Self {
        let base = self.base.add(&o.base);
        let eps = self.eps.add(&o.eps);
        Jet3 { base, eps }
    }

    fn sub(&self, o: &Self) -> Self {
        let base = self.base.sub(&o.base);
        let eps = self.eps.sub(&o.eps);
        Jet3 { base, eps }
    }

    /// Dual-number product: `(a + ε a')(b + ε b') = ab + ε (a b' + a' b)`.
    fn mul(&self, o: &Self) -> Self {
        let base = self.base.mul(&o.base);
        let left = self.base.mul(&o.eps);
        let right = self.eps.mul(&o.base);
        Jet3 {
            base,
            eps: left.add(&right),
        }
    }

    fn scale(&self, s: f64) -> Self {
        let base = self.base.scale(s);
        let eps = self.eps.scale(s);
        Jet3 { base, eps }
    }

    /// `f(x + ε x') = f(x) + ε f'(x) x'`.
    ///
    /// `f'(x)` is itself built as a `Jet2` from the derivative stack shifted by one order.
    fn compose_unary(&self, d: [f64; 5]) -> Self {
        let [f0, f1, f2, f3, f4] = d;
        let base = self.base.compose_unary([f0, f1, f2, f3, f4]);
        let derivative = self.base.compose_unary([f1, f2, f3, f4, f4]);
        Jet3 {
            base,
            eps: derivative.mul(&self.eps),
        }
    }
}
/// `Jet2` extended with two directional tangents `ε`, `δ` (`ε² = δ² = 0`).
///
/// It represents `base + ε·eps + δ·del + εδ·eps_del`, with every component a `Jet2`. The
/// Hessian of `eps_del` is the fourth derivative contracted with both seeded directions.
#[derive(Clone)]
struct Jet4 {
    /// Second-order jet of the quantity itself.
    base: Jet2,
    /// Derivative along the first direction (`ε`).
    eps: Jet2,
    /// Derivative along the second direction (`δ`).
    del: Jet2,
    /// Mixed second directional derivative (`εδ`).
    eps_del: Jet2,
}
impl Jet4 {
    /// Seed parameter `axis`.
    ///
    /// The base is `Jet2::primary(x, axis, p)`. The `ε` and `δ` tangents are the constants
    /// `du` and `dv`, and the mixed `εδ` part is zero.
    fn primary(x: f64, axis: usize, p: usize, du: f64, dv: f64) -> Self {
        let zero_grad = vec![0.0; p];
        let constant = |value: f64| Jet2::from_parts(value, &zero_grad, &[]);
        Jet4 {
            base: Jet2::primary(x, axis, p),
            eps: constant(du),
            del: constant(dv),
            eps_del: constant(0.0),
        }
    }

    /// Fourth derivative contracted with both seeded directions, as a row-major `p * p` buffer.
    fn contracted_fourth(&self) -> Vec<f64> {
        self.eps_del.h.to_vec()
    }
}
impl FlexJet for Jet4 {
#[inline]
fn value(&self) -> f64 {
self.base.v
}
fn add(&self, o: &Self) -> Self {
Jet4 {
base: self.base.add(&o.base),
eps: self.eps.add(&o.eps),
del: self.del.add(&o.del),
eps_del: self.eps_del.add(&o.eps_del),
}
}
fn sub(&self, o: &Self) -> Self {
Jet4 {
base: self.base.sub(&o.base),
eps: self.eps.sub(&o.eps),
del: self.del.sub(&o.del),
eps_del: self.eps_del.sub(&o.eps_del),
}
}
fn mul(&self, o: &Self) -> Self {
let base = self.base.mul(&o.base);
let eps = self.base.mul(&o.eps).add(&self.eps.mul(&o.base));
let del = self.base.mul(&o.del).add(&self.del.mul(&o.base));
let eps_del = self
.base
.mul(&o.eps_del)
.add(&self.eps.mul(&o.del))
.add(&self.del.mul(&o.eps))
.add(&self.eps_del.mul(&o.base));
Jet4 {
base,
eps,
del,
eps_del,
}
}
fn scale(&self, s: f64) -> Self {
Jet4 {
base: self.base.scale(s),
eps: self.eps.scale(s),
del: self.del.scale(s),
eps_del: self.eps_del.scale(s),
}
}
fn compose_unary(&self, d: [f64; 5]) -> Self {
let base = self.base.compose_unary([d[0], d[1], d[2], d[3], d[4]]);
let fprime = self.base.compose_unary([d[1], d[2], d[3], d[4], d[4]]);
let fsecond = self.base.compose_unary([d[2], d[3], d[4], d[4], d[4]]);
let eps = fprime.mul(&self.eps);
let del = fprime.mul(&self.del);
let eps_del = fsecond
.mul(&self.eps)
.mul(&self.del)
.add(&fprime.mul(&self.eps_del));
Jet4 {
base,
eps,
del,
eps_del,
}
}
}
/// Inner product of two slices. Trailing elements of the longer slice are ignored.
#[inline]
fn dot(x: &[f64], y: &[f64]) -> f64 {
    x.iter().zip(y).map(|(a, b)| a * b).sum()
}
/// Dense product `M v` for a row-major `p x p` matrix `m`.
///
/// # Panics
/// If `m` has fewer than `p * p` entries or `v` fewer than `p`.
fn mat_vec(m: &[f64], v: &[f64], p: usize) -> Vec<f64> {
    (0..p)
        .map(|row| {
            let offset = row * p;
            (0..p).fold(0.0, |acc, col| acc + m[offset + col] * v[col])
        })
        .collect()
}
/// Bilinear form `v1^T M v2` for a row-major `p x p` matrix `m`, accumulated row by row.
fn quad_form(m: &[f64], v1: &[f64], v2: &[f64], p: usize) -> f64 {
    (0..p).fold(0.0, |total, row| {
        let row_slice = &m[row * p..(row + 1) * p];
        total + v1[row] * dot(row_slice, v2)
    })
}
/// Value, gradient and optional Hessian views for the row quantities consumed by
/// `SurvivalMarginalSlopeFamily::flex_row_nll_value_grad_hess`.
///
/// A `None` Hessian is read as zero. Whether a Hessian is returned at all is decided by
/// `eta1_h`.
pub(crate) struct FlexRowJet2Channels<'a> {
    /// `eta` at the entry timepoint.
    pub eta0_v: f64,
    pub eta0_g: ndarray::ArrayView1<'a, f64>,
    pub eta0_h: Option<ndarray::ArrayView2<'a, f64>>,
    /// `eta` at the exit timepoint.
    pub eta1_v: f64,
    pub eta1_g: ndarray::ArrayView1<'a, f64>,
    pub eta1_h: Option<ndarray::ArrayView2<'a, f64>>,
    /// `chi` at the exit timepoint (enters the NLL through `-ln chi1`).
    pub chi1_v: f64,
    pub chi1_g: ndarray::ArrayView1<'a, f64>,
    pub chi1_h: Option<ndarray::ArrayView2<'a, f64>>,
    /// `d` at the exit timepoint (enters the NLL through `+ln d1`).
    pub d1_v: f64,
    pub d1_g: ndarray::ArrayView1<'a, f64>,
    pub d1_h: Option<ndarray::ArrayView2<'a, f64>>,
}
/// Timepoint packs needed for the once-contracted third derivative of the row NLL.
pub(crate) struct FlexThirdPacks<'a> {
    /// Value/gradient/Hessian of `eta` at entry (only `eta` fields are read for entry).
    pub entry_base: &'a gpu::SurvivalFlexBlock10TimepointBase,
    /// Value/gradient/Hessian of `eta`, `chi`, `d` at exit.
    pub exit_base: &'a gpu::SurvivalFlexBlock10TimepointBase,
    /// Directional (`*_u_dir`, `*_uv_dir`) pack at entry for the contraction direction.
    pub entry_ext: &'a gpu::SurvivalFlexBlock10TimepointDirectional,
    /// Directional pack at exit for the contraction direction.
    pub exit_ext: &'a gpu::SurvivalFlexBlock10TimepointDirectional,
}
/// Timepoint packs needed for the twice-contracted fourth derivative of the row NLL.
pub(crate) struct FlexFourthPacks<'a> {
    /// Base value/gradient/Hessian pack at entry.
    pub entry_base: &'a gpu::SurvivalFlexBlock10TimepointBase,
    /// Base value/gradient/Hessian pack at exit.
    pub exit_base: &'a gpu::SurvivalFlexBlock10TimepointBase,
    /// Directional pack at entry for the first direction `dir_u`.
    pub entry_ext_u: &'a gpu::SurvivalFlexBlock10TimepointDirectional,
    /// Directional pack at exit for the first direction `dir_u`.
    pub exit_ext_u: &'a gpu::SurvivalFlexBlock10TimepointDirectional,
    /// Directional pack at entry for the second direction `dir_v`.
    pub entry_ext_v: &'a gpu::SurvivalFlexBlock10TimepointDirectional,
    /// Directional pack at exit for the second direction `dir_v`.
    pub exit_ext_v: &'a gpu::SurvivalFlexBlock10TimepointDirectional,
    /// Bidirectional (`*_uv_uv`) pack at entry.
    pub entry_bi: &'a gpu::SurvivalFlexBlock10TimepointBiDirectional,
    /// Bidirectional (`*_uv_uv`) pack at exit.
    pub exit_bi: &'a gpu::SurvivalFlexBlock10TimepointBiDirectional,
}
impl SurvivalMarginalSlopeFamily {
    /// Value, gradient and Hessian of the per-row flexible-survival negative log-likelihood
    /// with respect to the `primary.total` primary coordinates.
    ///
    /// The timepoint quantities arrive as value/gradient/(optional) Hessian views in `ch`.
    /// `q1` and `qd1` are seeded as independent primary coordinates at `primary.q1` and
    /// `primary.qd1`. The returned value includes the constant `wi * di * ln(2*pi)`.
    ///
    /// A Hessian is returned only when `ch.eta1_h` is present; otherwise the third element is
    /// a zero `p x p` matrix. Missing input Hessians are treated as zero.
    ///
    /// # Errors
    /// Propagates failures from `surv_stack` and from reshaping the Hessian buffer.
    pub(crate) fn flex_row_nll_value_grad_hess(
        &self,
        row: usize,
        primary: &FlexPrimarySlices,
        q1: f64,
        qd1: f64,
        ch: FlexRowJet2Channels<'_>,
    ) -> Result<(f64, Array1<f64>, Array2<f64>), String> {
        let FlexRowJet2Channels {
            eta0_v,
            eta0_g,
            eta0_h,
            eta1_v,
            eta1_g,
            eta1_h,
            chi1_v,
            chi1_g,
            chi1_h,
            d1_v,
            d1_g,
            d1_h,
        } = ch;
        let p = primary.total;
        let wi = self.weights[row];
        let di = self.event[row];
        // Value/derivative stacks of the survival term at entry and exit.
        let surv0 = surv_stack(eta0_v)?;
        let surv1 = surv_stack(eta1_v)?;
        // The exit-time eta Hessian signals whether the caller wants a Hessian back.
        let want_hess = eta1_h.is_some();
        let eta0 = Jet2::from_view(eta0_v, eta0_g, eta0_h);
        let eta1 = Jet2::from_view(eta1_v, eta1_g, eta1_h);
        let chi1 = Jet2::from_view(chi1_v, chi1_g, chi1_h);
        let d1 = Jet2::from_view(d1_v, d1_g, d1_h);
        // q1 and qd1 are primary coordinates themselves: seed them as independent variables.
        let q1j = Jet2::primary(q1, primary.q1, p);
        let qd1j = Jet2::primary(qd1, primary.qd1, p);
        let out = flex_row_nll(&eta0, &eta1, &chi1, &d1, &q1j, &qd1j, surv0, surv1, wi, di);
        // Constant omitted by `flex_row_nll`; it affects the value only.
        let value = out.v + wi * di * std::f64::consts::TAU.ln();
        let grad = Array1::from(out.g);
        let hess = if want_hess {
            Array2::from_shape_vec((p, p), out.h).map_err(|e| e.to_string())?
        } else {
            Array2::zeros((p, p))
        };
        Ok((value, grad, hess))
    }
    /// Third derivative of the per-row NLL contracted once with `dir`, as a `p x p` matrix.
    ///
    /// Each timepoint quantity is lifted to a `Jet3` whose tangent part combines the base
    /// pack with the directional pack:
    /// - tangent value: `grad · dir`;
    /// - tangent gradient: `*_u_dir`;
    /// - tangent Hessian: `*_uv_dir`.
    ///
    /// # Errors
    /// Propagates failures from `surv_stack` and from reshaping the result buffer.
    ///
    /// # Panics
    /// Indexes `dir` at `primary.q1` and `primary.qd1`.
    pub(crate) fn flex_row_nll_third_contracted(
        &self,
        row: usize,
        primary: &FlexPrimarySlices,
        q1: f64,
        qd1: f64,
        dir: &[f64],
        packs: FlexThirdPacks<'_>,
    ) -> Result<Array2<f64>, String> {
        let FlexThirdPacks {
            entry_base,
            exit_base,
            entry_ext,
            exit_ext,
        } = packs;
        let p = primary.total;
        let wi = self.weights[row];
        let di = self.event[row];
        let surv0 = surv_stack(entry_base.eta)?;
        let surv1 = surv_stack(exit_base.eta)?;
        // Lift one timepoint quantity (base pack + directional pack) to a Jet3.
        let mk =
            |base_v: f64, base_g: &[f64], base_h: &[f64], ext_g: &[f64], ext_h: &[f64]| -> Jet3 {
                Jet3 {
                    base: Jet2::from_parts(base_v, base_g, base_h),
                    eps: Jet2::from_parts(dot(base_g, dir), ext_g, ext_h),
                }
            };
        let eta0 = mk(
            entry_base.eta,
            &entry_base.eta_u,
            &entry_base.eta_uv,
            &entry_ext.eta_u_dir,
            &entry_ext.eta_uv_dir,
        );
        let eta1 = mk(
            exit_base.eta,
            &exit_base.eta_u,
            &exit_base.eta_uv,
            &exit_ext.eta_u_dir,
            &exit_ext.eta_uv_dir,
        );
        let chi1 = mk(
            exit_base.chi,
            &exit_base.chi_u,
            &exit_base.chi_uv,
            &exit_ext.chi_u_dir,
            &exit_ext.chi_uv_dir,
        );
        let d1 = mk(
            exit_base.d,
            &exit_base.d_u,
            &exit_base.d_uv,
            &exit_ext.d_u_dir,
            &exit_ext.d_uv_dir,
        );
        let q1j = Jet3::primary(q1, primary.q1, p, dir[primary.q1]);
        let qd1j = Jet3::primary(qd1, primary.qd1, p, dir[primary.qd1]);
        let out = flex_row_nll(&eta0, &eta1, &chi1, &d1, &q1j, &qd1j, surv0, surv1, wi, di);
        Array2::from_shape_vec((p, p), out.contracted_third()).map_err(|e| e.to_string())
    }
    /// Fourth derivative of the per-row NLL contracted with `dir_u` and `dir_v`, as a `p x p`
    /// matrix.
    ///
    /// Each timepoint quantity is lifted to a `Jet4`:
    /// - `ε` part: built from the `dir_u` directional pack;
    /// - `δ` part: built from the `dir_v` directional pack;
    /// - mixed `εδ` part: value `dir_u^T H dir_v`, gradient `(*_uv_dir for dir_u) · dir_v`,
    ///   and Hessian taken from the bidirectional pack.
    ///
    /// # Errors
    /// Propagates failures from `surv_stack` and from reshaping the result buffer.
    ///
    /// # Panics
    /// Indexes `dir_u`/`dir_v` at `primary.q1` and `primary.qd1`.
    pub(crate) fn flex_row_nll_fourth_contracted(
        &self,
        row: usize,
        primary: &FlexPrimarySlices,
        q1: f64,
        qd1: f64,
        dir_u: &[f64],
        dir_v: &[f64],
        packs: FlexFourthPacks<'_>,
    ) -> Result<Array2<f64>, String> {
        let FlexFourthPacks {
            entry_base,
            exit_base,
            entry_ext_u,
            exit_ext_u,
            entry_ext_v,
            exit_ext_v,
            entry_bi,
            exit_bi,
        } = packs;
        let p = primary.total;
        let wi = self.weights[row];
        let di = self.event[row];
        let surv0 = surv_stack(entry_base.eta)?;
        let surv1 = surv_stack(exit_base.eta)?;
        // Lift one timepoint quantity (base + two directional packs + bidirectional pack)
        // to a Jet4.
        let mk = |base_v: f64,
                  base_g: &[f64],
                  base_h: &[f64],
                  ext_u_g: &[f64],
                  ext_u_h: &[f64],
                  ext_v_g: &[f64],
                  ext_v_h: &[f64],
                  bi_h: &[f64]|
         -> Jet4 {
            // Mixed second directional derivative u^T H v and its gradient T[u]·v.
            let eps_del_v = quad_form(base_h, dir_u, dir_v, p);
            let eps_del_g = mat_vec(ext_u_h, dir_v, p);
            Jet4 {
                base: Jet2::from_parts(base_v, base_g, base_h),
                eps: Jet2::from_parts(dot(base_g, dir_u), ext_u_g, ext_u_h),
                del: Jet2::from_parts(dot(base_g, dir_v), ext_v_g, ext_v_h),
                eps_del: Jet2::from_parts(eps_del_v, &eps_del_g, bi_h),
            }
        };
        let eta0 = mk(
            entry_base.eta,
            &entry_base.eta_u,
            &entry_base.eta_uv,
            &entry_ext_u.eta_u_dir,
            &entry_ext_u.eta_uv_dir,
            &entry_ext_v.eta_u_dir,
            &entry_ext_v.eta_uv_dir,
            &entry_bi.eta_uv_uv,
        );
        let eta1 = mk(
            exit_base.eta,
            &exit_base.eta_u,
            &exit_base.eta_uv,
            &exit_ext_u.eta_u_dir,
            &exit_ext_u.eta_uv_dir,
            &exit_ext_v.eta_u_dir,
            &exit_ext_v.eta_uv_dir,
            &exit_bi.eta_uv_uv,
        );
        let chi1 = mk(
            exit_base.chi,
            &exit_base.chi_u,
            &exit_base.chi_uv,
            &exit_ext_u.chi_u_dir,
            &exit_ext_u.chi_uv_dir,
            &exit_ext_v.chi_u_dir,
            &exit_ext_v.chi_uv_dir,
            &exit_bi.chi_uv_uv,
        );
        let d1 = mk(
            exit_base.d,
            &exit_base.d_u,
            &exit_base.d_uv,
            &exit_ext_u.d_u_dir,
            &exit_ext_u.d_uv_dir,
            &exit_ext_v.d_u_dir,
            &exit_ext_v.d_uv_dir,
            &exit_bi.d_uv_uv,
        );
        let q1j = Jet4::primary(q1, primary.q1, p, dir_u[primary.q1], dir_v[primary.q1]);
        let qd1j = Jet4::primary(qd1, primary.qd1, p, dir_u[primary.qd1], dir_v[primary.qd1]);
        let out = flex_row_nll(&eta0, &eta1, &chi1, &d1, &q1j, &qd1j, surv0, surv1, wi, di);
        Array2::from_shape_vec((p, p), out.contracted_fourth()).map_err(|e| e.to_string())
    }
}
/// Reciprocal `1/x` of a jet.
///
/// Composed from the closed-form derivatives `-1/x^2, 2/x^3, -6/x^4, 24/x^5`.
fn recip<J: FlexJet>(x: &J) -> J {
    let r = 1.0 / x.value();
    let r2 = r * r;
    let stack = [
        r,
        -r2,
        2.0 * r2 * r,
        -6.0 * r2 * r2,
        24.0 * r2 * r2 * r,
    ];
    x.compose_unary(stack)
}
/// Exponential of a jet. Every derivative of `exp` equals its value.
fn exp_jet<J: FlexJet>(x: &J) -> J {
    x.compose_unary([x.value().exp(); 5])
}
/// Shift a jet by the scalar `c`. Only the value changes; derivative parts pass through.
fn add_const<J: FlexJet>(x: &J, c: f64) -> J {
    let shifted = x.value() + c;
    x.compose_unary([shifted, 1.0, 0.0, 0.0, 0.0])
}
/// Jets that can form one coefficient-times-moment contribution to the calibration residual
/// (see `calibration_residual_jet`).
///
/// The combination is not a plain product. For `Jet2` the value slot is zero and the
/// cross-derivative terms are symmetrised with weight 1/2.
trait MomentTerm: FlexJet {
    /// Contribution of coefficient jet `self` against its moment jet `m`.
    fn moment_term(&self, m: &Self) -> Self;
}
impl MomentTerm for Jet2 {
    /// Zero value, gradient `g_c * m.v`, and Hessian `H_c * m.v + (g_c g_m^T + g_m g_c^T) / 2`.
    fn moment_term(&self, m: &Self) -> Self {
        let n = self.p();
        let g = self.g.iter().map(|gc| gc * m.v).collect();
        let h = (0..n * n)
            .map(|k| {
                let (i, j) = (k / n, k % n);
                self.h[k] * m.v + 0.5 * (self.g[i] * m.g[j] + self.g[j] * m.g[i])
            })
            .collect();
        Jet2 { v: 0.0, g, h }
    }
}
/// Moment jets `M_n`, `n = 0..=4`, of one calibration cell under the perturbed weight.
///
/// `c` holds the cubic coefficient jets of `eta(z) = c0 + c1 z + c2 z^2 + c3 z^3`. Let
/// `eta0` be the same polynomial built from the base values of `c`. The ratio
/// `exp(-(eta^2 - eta0^2)/2)` is expanded as a truncated exponential series (powers 1..=4) in
/// the polynomial `-(eta^2 - eta0^2)/2`, which gives z-polynomial coefficients `s_m`. Every
/// coefficient of that polynomial has zero value, so higher powers vanish for jets of order
/// ≤ 4.
///
/// Each moment is `sum_m s_m * numeric_moments[n + m]`, plus the edge-sliver corrections from
/// `edge_sliver_jet` (right edge added, left edge subtracted).
///
/// NOTE(review): `numeric_moments` is presumably the base-point moments of the cell. Indices
/// past its end (up to `4 + 24`) are silently treated as zero; confirm the slice is long
/// enough.
fn base_moment_jets<J: FlexJet>(
    c: &[J; 4],
    z_left: &J,
    left_finite: bool,
    z_right: &J,
    right_finite: bool,
    numeric_moments: &[f64],
) -> [J; 5] {
    // Coefficients frozen at their base values (no derivative parts).
    let c0_const: [J; 4] = std::array::from_fn(|k| const_jet_like(&c[k], c[k].value()));
    // Polynomial product (coefficient convolution) of two jet-coefficient polynomials.
    let conv = |lhs: &[J], rhs: &[J]| -> Vec<J> {
        let mut out: Vec<J> = (0..lhs.len() + rhs.len() - 1)
            .map(|_| const_jet_like(&c[0], 0.0))
            .collect();
        for (i, li) in lhs.iter().enumerate() {
            for (j, rj) in rhs.iter().enumerate() {
                out[i + j] = out[i + j].add(&li.mul(rj));
            }
        }
        out
    };
    let eta_sq = conv(c, c);
    let eta0_sq = conv(&c0_const, &c0_const);
    // -(eta^2 - eta0^2) / 2, coefficient by coefficient.
    let neg_dq: Vec<J> = eta_sq
        .iter()
        .zip(eta0_sq.iter())
        .map(|(a, b)| a.sub(b).scale(-0.5))
        .collect();
    // s_poly = sum_{k=0..=4} neg_dq^k / k!, as a polynomial in z.
    let mut s_poly: Vec<J> = vec![const_jet_like(&c[0], 1.0)];
    let mut power: Vec<J> = s_poly.clone();
    let factorials = [1.0_f64, 1.0, 2.0, 6.0, 24.0];
    for fact in factorials.iter().skip(1) {
        power = conv(&power, &neg_dq);
        for (m, coeff) in power.iter().enumerate() {
            let term = coeff.scale(1.0 / fact);
            if m < s_poly.len() {
                s_poly[m] = s_poly[m].add(&term);
            } else {
                s_poly.push(term);
            }
        }
    }
    std::array::from_fn(|n| {
        let mut acc = const_jet_like(&c[0], 0.0);
        for (m, s_m) in s_poly.iter().enumerate() {
            let m_npm = numeric_moments.get(n + m).copied().unwrap_or(0.0);
            // Skip zero (or missing) base moments.
            if m_npm != 0.0 {
                acc = acc.add(&s_m.scale(m_npm));
            }
        }
        // Corrections for parameter-dependent cell edges.
        if let Some(sr) = edge_sliver_jet(n, c, z_right, right_finite) {
            acc = acc.add(&sr);
        }
        if let Some(sl) = edge_sliver_jet(n, c, z_left, left_finite) {
            acc = acc.sub(&sl);
        }
        acc
    })
}
/// Taylor-series integral over the moving end of a cell whose edge `z_e` depends on the
/// parameters.
///
/// The integrand is `g(z) = z^n * exp(-(z^2 + eta(z)^2) / 2)` with
/// `eta(z) = c0 + c1 z + c2 z^2 + c3 z^3`. The function returns
/// `∫_{z0}^{z0+δ} g ≈ g δ + g' δ²/2 + g'' δ³/6 + g''' δ⁴/24`, where `z0 = z_e.value()` and
/// `δ = z_e - z0` is the edge's perturbation jet (zero value). The `z`-derivatives are taken
/// at the fixed point `z0`; the coefficient jets `c` carry the parameter dependence.
///
/// Returns `None` when the edge is infinite (no sliver).
///
/// The derivatives of `g` come from the Leibniz rule on `z^n` and `w = exp(-q)`. This avoids
/// the logarithmic derivative `n/z - q'`, which is singular at `z0 == 0` and loses precision
/// near it.
fn edge_sliver_jet<J: FlexJet>(n: usize, c: &[J; 4], z_e: &J, finite: bool) -> Option<J> {
    if !finite {
        return None;
    }
    let z0 = z_e.value();
    let zc = const_jet_like(z_e, z0);

    // eta(z) and its z-derivatives at z0 (Horner form).
    let eta = c[3]
        .mul(&zc)
        .add(&c[2])
        .mul(&zc)
        .add(&c[1])
        .mul(&zc)
        .add(&c[0]);
    let eta_z = c[2]
        .scale(2.0)
        .add(&c[3].scale(3.0).mul(&zc))
        .mul(&zc)
        .add(&c[1]);
    let eta_zz = c[2].scale(2.0).add(&c[3].scale(6.0).mul(&zc));
    let eta_zzz = c[3].scale(6.0);

    // q(z) = (z^2 + eta^2) / 2 and its z-derivatives.
    let q = zc.mul(&zc).add(&eta.mul(&eta)).scale(0.5);
    let q_z = zc.add(&eta.mul(&eta_z));
    let q_zz = add_const(&eta_z.mul(&eta_z).add(&eta.mul(&eta_zz)), 1.0);
    let q_zzz = eta_z.scale(3.0).mul(&eta_zz).add(&eta.mul(&eta_zzz));

    // w = exp(-q) and its z-derivatives:
    // w' = -q' w,  w'' = (q'^2 - q'') w,  w''' = (3 q' q'' - q'^3 - q''') w.
    let w = exp_jet(&q.scale(-1.0));
    let q_z_sq = q_z.mul(&q_z);
    let w_z = q_z.mul(&w).scale(-1.0);
    let w_zz = q_z_sq.sub(&q_zz).mul(&w);
    let w_zzz = q_z
        .mul(&q_zz)
        .scale(3.0)
        .sub(&q_z_sq.mul(&q_z))
        .sub(&q_zzz)
        .mul(&w);

    // k-th derivative of z^n at z0: n (n-1) ... (n-k+1) z0^(n-k), and zero for k > n.
    // These are plain scalars because z0 is held fixed. `n` is a small moment order.
    let mono = |k: usize| -> f64 {
        if k > n {
            return 0.0;
        }
        let falling: f64 = (0..k).map(|j| (n - j) as f64).product();
        falling * z0.powi((n - k) as i32)
    };
    let (p0, p1, p2, p3) = (mono(0), mono(1), mono(2), mono(3));

    // Leibniz rule for g = z^n * w.
    let g = w.scale(p0);
    let g_z = w.scale(p1).add(&w_z.scale(p0));
    let g_zz = w
        .scale(p2)
        .add(&w_z.scale(2.0 * p1))
        .add(&w_zz.scale(p0));
    let g_zzz = w
        .scale(p3)
        .add(&w_z.scale(3.0 * p2))
        .add(&w_zz.scale(3.0 * p1))
        .add(&w_zzz.scale(p0));

    // Integrate the Taylor polynomial of g over [0, δ].
    let delta = tangent_jet(z_e);
    let d2 = delta.mul(&delta);
    let d3 = d2.mul(&delta);
    let d4 = d3.mul(&delta);
    Some(
        g.mul(&delta)
            .add(&g_z.mul(&d2).scale(0.5))
            .add(&g_zz.mul(&d3).scale(1.0 / 6.0))
            .add(&g_zzz.mul(&d4).scale(1.0 / 24.0)),
    )
}
/// Evaluate the timepoint quantities `(eta, chi, d)` as jets of type `J`.
///
/// Steps:
/// 1. Lift the intercept: starting from `a0`, run 4 fixed-slope Newton steps (slope
///    `1 / d_check`) on `calibration_residual_jet`. This gives `a_jet`.
/// 2. Build the observed-point coefficient jets and evaluate `eta` and `chi` at `z_obs`.
///    `eta` also gets the offset `o_infl` and, when `infl` is `Some`, the perturbation
///    `du[infl]`.
/// 3. Sum `flex_timepoint_d_cell` over all calibration cells to obtain `d`.
///
/// `template` fixes the jet shape, `b_jet` is the slope jet, and `du[u]` is the perturbation
/// jet of primary axis `u` (`primary_g` being the slope axis).
///
/// Currently always returns `Ok`.
///
/// # Panics
/// If `infl` is `Some(i)` with `i >= du.len()`, or if `primary_g >= du.len()`.
fn flex_timepoint_inputs_generic<J: FlexJet + MomentTerm>(
    template: &J,
    b_jet: &J,
    du: &[J],
    a0: f64,
    d_check: f64,
    primary_g: usize,
    infl: Option<usize>,
    q_index: usize,
    q: f64,
    z_obs: f64,
    o_infl: f64,
    obs_coeff: [f64; 4],
    obs_fixed: &DenestedCellPrimaryFixedPartials,
    cells: &[CalibrationCellJetInputs<'_>],
) -> Result<(J, J, J), String> {
    let residual = |a: &J| calibration_residual_jet(a, b_jet, primary_g, du, q_index, q, cells);
    let a_jet = lift_intercept_flex(template, a0, 1.0 / d_check, 4, residual);
    // Intercept perturbation (value zero) shared by all coefficient expansions.
    let da = tangent_jet(&a_jet);
    let eta_coeff = cell_coeff_jets(&a_jet, obs_coeff, obs_fixed, primary_g, &da, du);
    let chi_coeff = cell_chi_poly_jets(&a_jet, obs_fixed, primary_g, &da, du);
    let mut eta = add_const(&eval_coeff_jet_at(&eta_coeff, z_obs), o_infl);
    if let Some(infl_axis) = infl {
        eta = eta.add(&du[infl_axis]);
    }
    let chi = eval_coeff_jet_at(&chi_coeff, z_obs);
    let mut d = const_jet_like(template, 0.0);
    for cell in cells {
        let c_pos = cell_coeff_jets(&a_jet, cell.base_pos_coeffs, cell.fixed, primary_g, &da, du);
        let chi_jets = cell_chi_poly_jets(&a_jet, cell.fixed, primary_g, &da, du);
        let edge_l = cell_edge_jet(&a_jet, b_jet, cell.left_edge, cell.cell_left);
        let edge_r = cell_edge_jet(&a_jet, b_jet, cell.right_edge, cell.cell_right);
        d = d.add(&flex_timepoint_d_cell(
            template,
            &c_pos,
            &chi_jets,
            &edge_l,
            cell.cell_left.is_finite(),
            &edge_r,
            cell.cell_right.is_finite(),
            cell.numeric_moments,
        ));
    }
    Ok((eta, chi, d))
}
/// Perturbation part of a jet: same derivative parts, value shifted to zero.
#[inline]
fn tangent_jet<J: FlexJet>(x: &J) -> J {
    let base_value = x.value();
    add_const(x, -base_value)
}
/// Constant jet with value `v` and the same shape (parameter count, tangent slots) as
/// `template`.
#[inline]
fn const_jet_like<J: FlexJet>(template: &J, v: f64) -> J {
    let zeroed = template.scale(0.0);
    add_const(&zeroed, v)
}
/// Fixed-slope Newton iteration for the intercept jet.
///
/// Starts from the constant jet `a0` and applies `a <- a - inv_fa * residual(a)` exactly
/// `iters` times.
fn lift_intercept_flex<J: FlexJet>(
    template: &J,
    a0: f64,
    inv_fa: f64,
    iters: usize,
    residual: impl Fn(&J) -> J,
) -> J {
    (0..iters).fold(const_jet_like(template, a0), |a, _| {
        let step = residual(&a).scale(inv_fa);
        a.sub(&step)
    })
}
/// Residual of the intercept calibration equation, evaluated as a jet at intercept `a_jet`.
///
/// Cell part: for every cell, sum the `MomentTerm::moment_term` contributions of its four
/// coefficient jets against the cell's moment jets, then scale by `1/(2*pi)`.
///
/// Target part (only when `q_index < du.len()`): add `G(q + du[q_index]) - G(q)` with
/// `G(x) = Phi(-x)`, composed from the closed-form derivatives of `Phi(-x)`.
///
/// For `Jet2` both parts have zero value, so the residual only constrains derivative
/// components. The scalar intercept `a0` is presumably already solved by the caller —
/// NOTE(review): confirm.
fn calibration_residual_jet<J: FlexJet + MomentTerm>(
    a_jet: &J,
    b_jet: &J,
    g_axis: usize,
    du: &[J],
    q_index: usize,
    q: f64,
    cells: &[CalibrationCellJetInputs<'_>],
) -> J {
    let da = tangent_jet(a_jet);
    let inv_two_pi = std::f64::consts::TAU.recip();
    let mut r = const_jet_like(a_jet, 0.0);
    for cell in cells {
        let c_pos = cell_coeff_jets(a_jet, cell.base_pos_coeffs, cell.fixed, g_axis, &da, du);
        let edge_l = cell_edge_jet(a_jet, b_jet, cell.left_edge, cell.cell_left);
        let edge_r = cell_edge_jet(a_jet, b_jet, cell.right_edge, cell.cell_right);
        let m = base_moment_jets(
            &c_pos,
            &edge_l,
            cell.cell_left.is_finite(),
            &edge_r,
            cell.cell_right.is_finite(),
            cell.numeric_moments,
        );
        let mut cell_r = const_jet_like(a_jet, 0.0);
        for k in 0..4 {
            cell_r = cell_r.add(&c_pos[k].moment_term(&m[k]));
        }
        r = r.add(&cell_r.scale(inv_two_pi));
    }
    if q_index < du.len() {
        // Phi(-x) and its derivatives at x = q: -phi, q phi, (1 - q^2) phi, (q^3 - 3q) phi.
        let phi_q = crate::probability::normal_pdf(q);
        let g0 = crate::probability::normal_cdf(-q);
        let g1 = -phi_q;
        let g2 = q * phi_q;
        let g3 = (1.0 - q * q) * phi_q;
        let g4 = (q * q * q - 3.0 * q) * phi_q;
        let q_jet = add_const(&du[q_index], q);
        // Subtract the base value so only the perturbation of Phi(-q) enters.
        let q_self = add_const(&q_jet.compose_unary([g0, g1, g2, g3, g4]), -g0);
        r = r.add(&q_self);
    }
    r
}
/// Borrowed per-cell inputs for the jet-level calibration and `d` evaluation.
struct CalibrationCellJetInputs<'a> {
    /// Cubic coefficients `c0..c3` of the cell at the expansion point.
    base_pos_coeffs: [f64; 4],
    /// Partials of the coefficients with respect to intercept, slope and primary axes.
    fixed: &'a DenestedCellPrimaryFixedPartials,
    /// Left end of the cell (may be infinite).
    cell_left: f64,
    /// Right end of the cell (may be infinite).
    cell_right: f64,
    /// How the left edge moves with the parameters (fixed or a crossing point).
    left_edge: crate::cubic_cell_kernel::PartitionEdge,
    /// How the right edge moves with the parameters.
    right_edge: crate::cubic_cell_kernel::PartitionEdge,
    /// Precomputed base-point moments used by `base_moment_jets`.
    numeric_moments: &'a [f64],
}
fn cell_edge_jet<J: FlexJet>(
a_jet: &J,
b_jet: &J,
edge: crate::cubic_cell_kernel::PartitionEdge,
z_value: f64,
) -> J {
match edge {
crate::cubic_cell_kernel::PartitionEdge::Crossing { tau } => {
const_jet_like(a_jet, tau).sub(a_jet).mul(&recip(b_jet))
}
crate::cubic_cell_kernel::PartitionEdge::Fixed(_) => {
const_jet_like(a_jet, z_value)
}
}
}
/// Taylor-expanded cubic coefficient jets `c_k` of one cell around the expansion point.
///
/// Inputs:
/// - `base_c`: coefficients at the expansion point.
/// - `fixed`: partials with respect to the intercept `a`, the slope `b` (primary axis
///   `g_axis`) and every other primary axis `u`.
/// - `da`: intercept perturbation jet.
/// - `du[u]`: perturbation jet of axis `u`, so `du[g_axis]` is the slope perturbation `db`.
///
/// The expansion keeps every `(da, db)` monomial of total degree ≤ 3, each either alone or
/// multiplied by a single `du[u]` with `u != g_axis`.
///
/// Products that do not depend on the coefficient index `k` are built once, per axis, and
/// shared by all four coefficients. The per-coefficient summation order is unchanged.
///
/// # Panics
/// If `g_axis >= du.len()` or `fixed` has fewer per-axis entries than `du`.
fn cell_coeff_jets<J: FlexJet>(
    template: &J,
    base_c: [f64; 4],
    fixed: &DenestedCellPrimaryFixedPartials,
    g_axis: usize,
    da: &J,
    du: &[J],
) -> [J; 4] {
    let db = &du[g_axis];
    let dada = da.mul(da);
    let dadada = dada.mul(da);
    let dadb = da.mul(db);
    let dbdb = db.mul(db);
    // Cubic monomials involving the slope, shared by the axis chains and the slope terms.
    let dadadb = dada.mul(db);
    let dadbdb = dadb.mul(db);
    let dbdbdb = dbdb.mul(db);

    // Base value plus the pure-intercept part of the expansion.
    let mut coeffs: [J; 4] = std::array::from_fn(|k| {
        const_jet_like(template, base_c[k])
            .add(&da.scale(fixed.dc_da[k]))
            .add(&dada.scale(0.5 * fixed.dc_daa[k]))
            .add(&dadada.scale(fixed.dc_daaa[k] / 6.0))
    });

    // Terms carrying one perturbation of a non-slope primary axis.
    for (u, duu) in du.iter().enumerate() {
        if u == g_axis {
            continue;
        }
        let a_u = da.mul(duu);
        let aa_u = dada.mul(duu);
        let b_u = db.mul(duu);
        let ab_u = dadb.mul(duu);
        let bb_u = dbdb.mul(duu);
        let aaa_u = dadada.mul(duu);
        let aab_u = dadadb.mul(duu);
        let abb_u = dadbdb.mul(duu);
        let bbb_u = dbdbdb.mul(duu);
        for (k, c) in coeffs.iter_mut().enumerate() {
            let chain = duu
                .scale(fixed.coeff_u[u][k])
                .add(&a_u.scale(fixed.coeff_au[u][k]))
                .add(&aa_u.scale(0.5 * fixed.coeff_aau[u][k]))
                .add(&b_u.scale(fixed.coeff_bu[u][k]))
                .add(&ab_u.scale(fixed.coeff_abu[u][k]))
                .add(&bb_u.scale(0.5 * fixed.coeff_bbu[u][k]))
                .add(&aaa_u.scale(fixed.coeff_aaau[u][k] / 6.0))
                .add(&aab_u.scale(0.5 * fixed.coeff_aabu[u][k]))
                .add(&abb_u.scale(0.5 * fixed.coeff_abbu[u][k]))
                .add(&bbb_u.scale(fixed.coeff_bbbu[u][k] / 6.0));
            *c = c.add(&chain);
        }
    }

    // Slope terms. For the slope axis, the `*_u` slots hold the b-partials
    // (e.g. `coeff_bu[g_axis]` holds `dc/dbb`).
    for (k, c) in coeffs.iter_mut().enumerate() {
        *c = c
            .add(&db.scale(fixed.coeff_u[g_axis][k]))
            .add(&dadb.scale(fixed.coeff_au[g_axis][k]))
            .add(&dadadb.scale(0.5 * fixed.coeff_aau[g_axis][k]))
            .add(&dbdb.scale(0.5 * fixed.coeff_bu[g_axis][k]))
            .add(&dadbdb.scale(0.5 * fixed.coeff_abu[g_axis][k]))
            .add(&dbdbdb.scale(fixed.coeff_bbu[g_axis][k] / 6.0));
    }
    coeffs
}
/// Jets of the `chi` polynomial coefficients of one cell.
///
/// The expansion is the intercept derivative of the `cell_coeff_jets` expansion: every
/// partial carries one extra `a` index. The constant term is `dc_da`, the `du` term uses
/// `coeff_au`, and so on.
///
/// Products that do not depend on the coefficient index `k` are built once, per axis, and
/// shared by all four coefficients. The per-coefficient summation order is unchanged.
///
/// # Panics
/// If `g_axis >= du.len()` or `fixed` has fewer per-axis entries than `du`.
fn cell_chi_poly_jets<J: FlexJet>(
    template: &J,
    fixed: &DenestedCellPrimaryFixedPartials,
    g_axis: usize,
    da: &J,
    du: &[J],
) -> [J; 4] {
    let db = &du[g_axis];
    let dada = da.mul(da);
    let dadb = da.mul(db);
    let dbdb = db.mul(db);

    // Intercept-only part: dc/da + dc/daa da + dc/daaa da^2 / 2.
    let mut chi: [J; 4] = std::array::from_fn(|k| {
        const_jet_like(template, fixed.dc_da[k])
            .add(&da.scale(fixed.dc_daa[k]))
            .add(&dada.scale(0.5 * fixed.dc_daaa[k]))
    });

    // Terms carrying one perturbation of a non-slope primary axis.
    for (u, duu) in du.iter().enumerate() {
        if u == g_axis {
            continue;
        }
        let a_u = da.mul(duu);
        let b_u = db.mul(duu);
        let aa_u = dada.mul(duu);
        let ab_u = dadb.mul(duu);
        let bb_u = dbdb.mul(duu);
        for (k, c) in chi.iter_mut().enumerate() {
            let chain = duu
                .scale(fixed.coeff_au[u][k])
                .add(&a_u.scale(fixed.coeff_aau[u][k]))
                .add(&b_u.scale(fixed.coeff_abu[u][k]))
                .add(&aa_u.scale(0.5 * fixed.coeff_aaau[u][k]))
                .add(&ab_u.scale(fixed.coeff_aabu[u][k]))
                .add(&bb_u.scale(0.5 * fixed.coeff_abbu[u][k]));
            *c = c.add(&chain);
        }
    }

    // Slope terms.
    for (k, c) in chi.iter_mut().enumerate() {
        *c = c
            .add(&db.scale(fixed.coeff_au[g_axis][k]))
            .add(&dadb.scale(fixed.coeff_aau[g_axis][k]))
            .add(&dbdb.scale(0.5 * fixed.coeff_abu[g_axis][k]));
    }
    chi
}
/// Contribution of one cell to the timepoint quantity `d`.
///
/// This is `(1/(2*pi)) * sum_k chi_k * M_k`, where `M_k` are the cell's moment jets from
/// `base_moment_jets`.
fn flex_timepoint_d_cell<J: FlexJet>(
    template: &J,
    c_jets: &[J; 4],
    chi_jets: &[J; 4],
    edge_l: &J,
    left_finite: bool,
    edge_r: &J,
    right_finite: bool,
    numeric_moments: &[f64],
) -> J {
    let moments = base_moment_jets(
        c_jets,
        edge_l,
        left_finite,
        edge_r,
        right_finite,
        numeric_moments,
    );
    let weighted = chi_jets
        .iter()
        .zip(moments.iter())
        .fold(const_jet_like(template, 0.0), |acc, (chi_k, m_k)| {
            acc.add(&chi_k.mul(m_k))
        });
    weighted.scale(std::f64::consts::TAU.recip())
}
/// Evaluate the cubic `c0 + c1 z + c2 z^2 + c3 z^3`, with jet coefficients, at the scalar `z`.
#[inline]
fn eval_coeff_jet_at<J: FlexJet>(coeff_jet: &[J; 4], z: f64) -> J {
    let mut power = 1.0;
    coeff_jet
        .iter()
        .fold(const_jet_like(&coeff_jet[0], 0.0), |acc, c| {
            let term = c.scale(power);
            power *= z;
            acc.add(&term)
        })
}
/// Borrow-view of every cached partition cell in the form used by the jet evaluators.
fn cells_from_cached(cached: &CachedPartitionCells) -> Vec<CalibrationCellJetInputs<'_>> {
    let mut inputs = Vec::new();
    for entry in cached.cells.iter() {
        let partition = &entry.partition_cell;
        let cell = partition.cell;
        inputs.push(CalibrationCellJetInputs {
            base_pos_coeffs: [cell.c0, cell.c1, cell.c2, cell.c3],
            fixed: &entry.fixed,
            cell_left: cell.left,
            cell_right: cell.right,
            left_edge: partition.left_edge,
            right_edge: partition.right_edge,
            numeric_moments: entry.state.moments.as_slice(),
        });
    }
    inputs
}
/// Cubic coefficients of the observed-point cell and their primary partials, packed in the
/// same `DenestedCellPrimaryFixedPartials` layout that partition cells use.
///
/// Per-axis entries:
/// - slope axis `primary.g`: the b-partials from `observed_denested_cell_partials`.
/// - score-warp axes (`primary.h`, only when `family.score_warp` is set): `coeff_u` and
///   `coeff_bu` from `observed_score_basis_coefficients`, evaluated with slope `b` and `1.0`
///   respectively.
/// - link-deviation axes (`primary.w`, only when `family.link_dev` is set): the link basis
///   cell coefficients and their a/b partials up to third order.
///
/// Score-warp and link-deviation entries are scaled by `probit_frailty_scale()`. All other
/// entries stay zero.
///
/// # Errors
/// Propagates errors from the observed-partials, score-basis and link-basis helpers.
fn observed_fixed_for(
    family: &SurvivalMarginalSlopeFamily,
    primary: &FlexPrimarySlices,
    row: usize,
    a: f64,
    b: f64,
    beta_h: Option<&Array1<f64>>,
    beta_w: Option<&Array1<f64>>,
) -> Result<([f64; 4], DenestedCellPrimaryFixedPartials), String> {
    let r = primary.total;
    let scale = family.probit_frailty_scale();
    let z_obs = family.observed_score_projection(row);
    // Linear index at the observed score; used to locate the link-deviation basis span.
    let u_obs = a + b * z_obs;
    let obs = family.observed_denested_cell_partials(row, a, b, beta_h, beta_w)?;
    let mut coeff_u = vec![[0.0; 4]; r];
    let mut coeff_au = vec![[0.0; 4]; r];
    let mut coeff_bu = vec![[0.0; 4]; r];
    let mut coeff_aau = vec![[0.0; 4]; r];
    let mut coeff_abu = vec![[0.0; 4]; r];
    let mut coeff_bbu = vec![[0.0; 4]; r];
    let mut coeff_aaau = vec![[0.0; 4]; r];
    let mut coeff_aabu = vec![[0.0; 4]; r];
    let mut coeff_abbu = vec![[0.0; 4]; r];
    let mut coeff_bbbu = vec![[0.0; 4]; r];
    // For the slope axis, the `*_u` slots hold b-derivatives (u = b).
    coeff_u[primary.g] = obs.dc_db;
    coeff_au[primary.g] = obs.dc_dab;
    coeff_bu[primary.g] = obs.dc_dbb;
    coeff_aau[primary.g] = obs.dc_daab;
    coeff_abu[primary.g] = obs.dc_dabb;
    coeff_bbu[primary.g] = obs.dc_dbbb;
    if let Some(h_range) = primary.h.as_ref().filter(|_| family.score_warp.is_some()) {
        for local_idx in 0..h_range.len() {
            let idx = h_range.start + local_idx;
            coeff_u[idx] = scale_coeff4(
                family.observed_score_basis_coefficients(row, local_idx, z_obs, b)?,
                scale,
            );
            // NOTE(review): evaluating at slope 1.0 is used as the b-derivative, presumably
            // because the score basis enters linearly in b — confirm.
            coeff_bu[idx] = scale_coeff4(
                family.observed_score_basis_coefficients(row, local_idx, z_obs, 1.0)?,
                scale,
            );
        }
    }
    if let (Some(w_range), Some(runtime)) = (primary.w.as_ref(), family.link_dev.as_ref()) {
        for local_idx in 0..w_range.len() {
            let span = runtime.basis_cubic_at(local_idx, u_obs)?;
            let idx = w_range.start + local_idx;
            coeff_u[idx] = scale_coeff4(
                exact_kernel::link_basis_cell_coefficients(span, a, b),
                scale,
            );
            let (dc_aw, dc_bw) = exact_kernel::link_basis_cell_coefficient_partials(span, a, b);
            let (dc_aaw, dc_abw, dc_bbw) =
                exact_kernel::link_basis_cell_second_partials(span, a, b);
            let (dc_aaaw, dc_aabw, dc_abbw, dc_bbbw) =
                exact_kernel::link_basis_cell_third_partials(span);
            coeff_au[idx] = scale_coeff4(dc_aw, scale);
            coeff_bu[idx] = scale_coeff4(dc_bw, scale);
            coeff_aau[idx] = scale_coeff4(dc_aaw, scale);
            coeff_abu[idx] = scale_coeff4(dc_abw, scale);
            coeff_bbu[idx] = scale_coeff4(dc_bbw, scale);
            coeff_aaau[idx] = scale_coeff4(dc_aaaw, scale);
            coeff_aabu[idx] = scale_coeff4(dc_aabw, scale);
            coeff_abbu[idx] = scale_coeff4(dc_abbw, scale);
            coeff_bbbu[idx] = scale_coeff4(dc_bbbw, scale);
        }
    }
    let fixed = DenestedCellPrimaryFixedPartials {
        dc_da: obs.dc_da,
        dc_daa: obs.dc_daa,
        dc_daaa: obs.dc_daaa,
        coeff_u,
        coeff_au,
        coeff_bu,
        coeff_aau,
        coeff_abu,
        coeff_bbu,
        coeff_aaau,
        coeff_aabu,
        coeff_abbu,
        coeff_bbbu,
    };
    Ok((obs.coeff, fixed))
}
impl SurvivalMarginalSlopeFamily {
    /// Exact second-order jet (value, gradient, Hessian over the primary coordinates) of the
    /// timepoint quantities `eta`, `chi`, `d`.
    ///
    /// Builds the partition cache and delegates to
    /// [`Self::compute_survival_timepoint_exact_jet_from_cached`].
    ///
    /// # Errors
    /// Propagates failures from building the partition or from the cached evaluation.
    pub(crate) fn compute_survival_timepoint_exact_jet(
        &self,
        row: usize,
        primary: &FlexPrimarySlices,
        q: f64,
        q_index: usize,
        a: f64,
        b: f64,
        beta_h: Option<&Array1<f64>>,
        beta_w: Option<&Array1<f64>>,
        o_infl: f64,
    ) -> Result<SurvivalFlexTimepointExact, String> {
        let cached = self.build_cached_partition(primary, a, b, beta_h, beta_w)?;
        self.compute_survival_timepoint_exact_jet_from_cached(
            row, primary, q, q_index, a, b, beta_h, beta_w, o_infl, &cached,
        )
    }
    /// Same as [`Self::compute_survival_timepoint_exact_jet`], but reuses an existing
    /// partition cache.
    ///
    /// Every primary coordinate `u` is seeded as an independent `Jet2` direction, and the
    /// slope is seeded on axis `primary.g`. The generic timepoint evaluation is then run and
    /// the value, gradient and Hessian of `eta`, `chi`, `d` are unpacked.
    ///
    /// # Errors
    /// Propagates failures from the denominator evaluation, the observed-cell partials, the
    /// generic evaluation, or reshaping the Hessians.
    pub(crate) fn compute_survival_timepoint_exact_jet_from_cached(
        &self,
        row: usize,
        primary: &FlexPrimarySlices,
        q: f64,
        q_index: usize,
        a: f64,
        b: f64,
        beta_h: Option<&Array1<f64>>,
        beta_w: Option<&Array1<f64>>,
        o_infl: f64,
        cached: &CachedPartitionCells,
    ) -> Result<SurvivalFlexTimepointExact, String> {
        let p = primary.total;
        // Used as the Newton slope (1 / d_check) when lifting the intercept.
        let d_check = self.evaluate_survival_denom_d(a, b, beta_h, beta_w)?;
        let z_obs = self.observed_score_projection(row);
        let (obs_coeff, obs_fixed) = observed_fixed_for(self, primary, row, a, b, beta_h, beta_w)?;
        let cells = cells_from_cached(cached);
        // `usize::MAX` activates no axis: a shape-only template.
        let template = Jet2::primary(0.0, usize::MAX, p);
        let b_jet = Jet2::primary(b, primary.g, p);
        let du: Vec<Jet2> = (0..p).map(|u| Jet2::primary(0.0, u, p)).collect();
        let (eta, chi, d) = flex_timepoint_inputs_generic(
            &template,
            &b_jet,
            &du,
            a,
            d_check,
            primary.g,
            primary.infl,
            q_index,
            q,
            z_obs,
            o_infl,
            obs_coeff,
            &obs_fixed,
            &cells,
        )?;
        let to_g = |j: &Jet2| Array1::from(j.g.clone());
        let to_h = |j: &Jet2| -> Result<Array2<f64>, String> {
            Array2::from_shape_vec((p, p), j.h.clone()).map_err(|e| e.to_string())
        };
        Ok(SurvivalFlexTimepointExact {
            eta: eta.value(),
            chi: chi.value(),
            d: d.value(),
            eta_u: to_g(&eta),
            eta_uv: to_h(&eta)?,
            chi_u: to_g(&chi),
            chi_uv: to_h(&chi)?,
            d_u: to_g(&d),
            d_uv: to_h(&d)?,
        })
    }
}
impl MomentTerm for Jet3 {
    /// Base part from the `Jet2` rule. The tangent part comes from `jet2_moment_eps`.
    fn moment_term(&self, m: &Self) -> Self {
        Jet3 {
            base: self.base.moment_term(&m.base),
            eps: jet2_moment_eps(&self.base, &self.eps, &m.base, &m.eps),
        }
    }
}
impl MomentTerm for Jet4 {
    /// Base part from the `Jet2` rule. Each single tangent comes from `jet2_moment_eps`, and
    /// the mixed tangent from `jet2_moment_eps_del`.
    fn moment_term(&self, m: &Self) -> Self {
        Jet4 {
            base: self.base.moment_term(&m.base),
            eps: jet2_moment_eps(&self.base, &self.eps, &m.base, &m.eps),
            del: jet2_moment_eps(&self.base, &self.del, &m.base, &m.del),
            eps_del: jet2_moment_eps_del(self, m),
        }
    }
}
/// Mixed (`εδ`) component of `MomentTerm::moment_term` for `Jet4`.
///
/// It combines the base, single-tangent and mixed-tangent parts of the coefficient jet `c`
/// and the moment jet `m`, using the fixed rational weights below.
///
/// NOTE(review): the weights (1/2, 1/3, 2/3, 1/4, 3/4) presumably come from symmetrising the
/// higher-order terms of the `Jet2` rule; the derivation is not visible here. Confirm against
/// a finite-difference check before modifying.
fn jet2_moment_eps_del(c: &Jet4, m: &Jet4) -> Jet2 {
    // Short aliases: b = base, a = ε tangent, d = δ tangent, ad = εδ tangent.
    let (cb, ca, cd, cad) = (&c.base, &c.eps, &c.del, &c.eps_del);
    let (mb, ma, md, mad) = (&m.base, &m.eps, &m.del, &m.eps_del);
    let p = cb.p();
    let v = 0.5 * ca.v * md.v + cad.v * mb.v + 0.5 * cd.v * ma.v;
    let mut g = vec![0.0; p];
    for i in 0..p {
        g[i] = (1.0 / 3.0) * ca.v * md.g[i]
            + (2.0 / 3.0) * cad.v * mb.g[i]
            + cad.g[i] * mb.v
            + (2.0 / 3.0) * ca.g[i] * md.v
            + (1.0 / 3.0) * cd.v * ma.g[i]
            + (2.0 / 3.0) * cd.g[i] * ma.v
            + (1.0 / 3.0) * cb.g[i] * mad.v;
    }
    let mut h = vec![0.0; p * p];
    for i in 0..p {
        for j in 0..p {
            let k = i * p + j;
            h[k] = 0.25 * ca.v * md.h[k]
                + 0.5 * cad.v * mb.h[k]
                + 0.75 * (cad.g[i] * mb.g[j] + cad.g[j] * mb.g[i])
                + cad.h[k] * mb.v
                + 0.5 * (ca.g[i] * md.g[j] + ca.g[j] * md.g[i])
                + 0.75 * ca.h[k] * md.v
                + 0.25 * cd.v * ma.h[k]
                + 0.5 * (cd.g[i] * ma.g[j] + cd.g[j] * ma.g[i])
                + 0.75 * cd.h[k] * ma.v
                + 0.25 * (cb.g[i] * mad.g[j] + cb.g[j] * mad.g[i])
                + 0.5 * cb.h[k] * mad.v;
        }
    }
    Jet2 { v, g, h }
}
/// Single-tangent (`ε`) component of `MomentTerm::moment_term`, shared by `Jet3` and `Jet4`.
///
/// Arguments: `cb`/`ce` are the base and tangent parts of the coefficient jet; `mb`/`me` are
/// the base and tangent parts of the moment jet.
///
/// The value `ce.v * mb.v` equals the `dir`-contraction of the base rule's gradient
/// `cb.g * mb.v` when `ce.v = cb.g · dir`.
///
/// NOTE(review): the 1/2, 1/3 and 2/3 weights presumably symmetrise the third-order terms;
/// the derivation is not visible here.
fn jet2_moment_eps(cb: &Jet2, ce: &Jet2, mb: &Jet2, me: &Jet2) -> Jet2 {
    let p = cb.p();
    let v = ce.v * mb.v;
    let mut g = vec![0.0; p];
    for i in 0..p {
        g[i] = ce.g[i] * mb.v + 0.5 * (ce.v * mb.g[i] + cb.g[i] * me.v);
    }
    let mut h = vec![0.0; p * p];
    for i in 0..p {
        for j in 0..p {
            h[i * p + j] = ce.h[i * p + j] * mb.v
                + (2.0 / 3.0) * (ce.g[i] * mb.g[j] + ce.g[j] * mb.g[i])
                + (2.0 / 3.0) * cb.h[i * p + j] * me.v
                + (1.0 / 3.0) * ce.v * mb.h[i * p + j]
                + (1.0 / 3.0) * (cb.g[i] * me.g[j] + cb.g[j] * me.g[i]);
        }
    }
    Jet2 { v, g, h }
}
impl SurvivalMarginalSlopeFamily {
    /// Directional third-order pack for one timepoint.
    ///
    /// It holds the `ε`-tangent gradient and Hessian of `eta`, `chi` and `d` when every primary
    /// coordinate `u` is perturbed along `dir[u]`. The tangent Hessian is the third derivative
    /// contracted once with `dir`.
    ///
    /// The additive offset given to the generic evaluation is `0.0`. `add_const` only shifts
    /// the value part, which this pack does not return.
    ///
    /// # Errors
    /// Propagates failures from the denominator evaluation, the observed-cell partials, or the
    /// generic timepoint evaluation.
    ///
    /// # Panics
    /// Indexes `dir` at `primary.g` and at every `u < primary.total`.
    pub(crate) fn compute_survival_timepoint_directional_jet_from_cached(
        &self,
        row: usize,
        primary: &FlexPrimarySlices,
        q: f64,
        q_index: usize,
        a: f64,
        b: f64,
        beta_h: Option<&Array1<f64>>,
        beta_w: Option<&Array1<f64>>,
        cached: &CachedPartitionCells,
        dir: &Array1<f64>,
    ) -> Result<
        crate::survival::marginal_slope::gpu::SurvivalFlexBlock10TimepointDirectional,
        String,
    > {
        let p = primary.total;
        let d_check = self.evaluate_survival_denom_d(a, b, beta_h, beta_w)?;
        let z_obs = self.observed_score_projection(row);
        let (obs_coeff, obs_fixed) = observed_fixed_for(self, primary, row, a, b, beta_h, beta_w)?;
        let cells = cells_from_cached(cached);
        // Shape-only template: no active axis, zero tangent.
        let template = Jet3::primary(0.0, usize::MAX, p, 0.0);
        let b_jet = Jet3::primary(b, primary.g, p, dir[primary.g]);
        let du: Vec<Jet3> = (0..p).map(|u| Jet3::primary(0.0, u, p, dir[u])).collect();
        let (eta, chi, d) = flex_timepoint_inputs_generic(
            &template,
            &b_jet,
            &du,
            a,
            d_check,
            primary.g,
            primary.infl,
            q_index,
            q,
            z_obs,
            0.0,
            obs_coeff,
            &obs_fixed,
            &cells,
        )?;
        Ok(
            crate::survival::marginal_slope::gpu::SurvivalFlexBlock10TimepointDirectional {
                eta_u_dir: eta.eps.g.clone(),
                eta_uv_dir: eta.eps.h.clone(),
                chi_u_dir: chi.eps.g.clone(),
                chi_uv_dir: chi.eps.h.clone(),
                d_u_dir: d.eps.g.clone(),
                d_uv_dir: d.eps.h.clone(),
            },
        )
    }
    /// Bidirectional fourth-order pack for one timepoint.
    ///
    /// It holds the Hessian of the mixed `εδ` tangent of `eta`, `chi`, `d`, i.e. the fourth
    /// derivative contracted with `dir1` and `dir2`.
    ///
    /// As in the directional variant, the additive offset is `0.0` because only derivative
    /// parts are returned.
    ///
    /// # Errors
    /// Propagates failures from the denominator evaluation, the observed-cell partials, or the
    /// generic timepoint evaluation.
    ///
    /// # Panics
    /// Indexes `dir1`/`dir2` at `primary.g` and at every `u < primary.total`.
    pub(crate) fn compute_survival_timepoint_bidirectional_jet_from_cached(
        &self,
        row: usize,
        primary: &FlexPrimarySlices,
        q: f64,
        q_index: usize,
        a: f64,
        b: f64,
        beta_h: Option<&Array1<f64>>,
        beta_w: Option<&Array1<f64>>,
        cached: &CachedPartitionCells,
        dir1: &Array1<f64>,
        dir2: &Array1<f64>,
    ) -> Result<
        crate::survival::marginal_slope::gpu::SurvivalFlexBlock10TimepointBiDirectional,
        String,
    > {
        let p = primary.total;
        let d_check = self.evaluate_survival_denom_d(a, b, beta_h, beta_w)?;
        let z_obs = self.observed_score_projection(row);
        let (obs_coeff, obs_fixed) = observed_fixed_for(self, primary, row, a, b, beta_h, beta_w)?;
        let cells = cells_from_cached(cached);
        // Shape-only template: no active axis, zero tangents.
        let template = Jet4::primary(0.0, usize::MAX, p, 0.0, 0.0);
        let b_jet = Jet4::primary(b, primary.g, p, dir1[primary.g], dir2[primary.g]);
        let du: Vec<Jet4> = (0..p)
            .map(|u| Jet4::primary(0.0, u, p, dir1[u], dir2[u]))
            .collect();
        let (eta, chi, d) = flex_timepoint_inputs_generic(
            &template,
            &b_jet,
            &du,
            a,
            d_check,
            primary.g,
            primary.infl,
            q_index,
            q,
            z_obs,
            0.0,
            obs_coeff,
            &obs_fixed,
            &cells,
        )?;
        Ok(
            crate::survival::marginal_slope::gpu::SurvivalFlexBlock10TimepointBiDirectional {
                eta_uv_uv: eta.eps_del.h.clone(),
                chi_uv_uv: chi.eps_del.h.clone(),
                d_uv_uv: d.eps_del.h.clone(),
            },
        )
    }
}
#[cfg(test)]
mod moment_engine_tests {
use super::*;
use crate::cubic_cell_kernel::{DenestedCubicCell, reduce_sextic_moments};
use crate::marginal_slope_shared::eval_coeff4_at;
use gam_math::jet_scalar::{Order2, filtered_implicit_solve_scalar};
use gam_math::jet_tower::Tower2;
/// Hessian of the implicitly defined intercept `a(u)` for a compile-time
/// dimension `K`, obtained from `filtered_implicit_solve_scalar`.
///
/// The residual returned for a trial `Order2` jet `a` has zero value and
/// - gradient `d_check * a_i - f_u[i]`,
/// - Hessian `d_check * a_ij + f_aa a_i a_j + d_u[i] a_j + d_u[j] a_i - f_uv[i*K + j]`,
///
/// so the solution satisfies the first- and second-order calibration constraints.
/// `f_uv` is read as a row-major `K x K` buffer. The solver is started at `a0`
/// with step `1 / d_check` and the argument `2` (presumably the number of
/// refinement passes / jet order — confirm in `gam_math::jet_scalar`).
fn lift_intercept_order2<const K: usize>(
    d_check: f64,
    f_u: &[f64],
    f_uv: &[f64],
    f_aa: f64,
    d_u: &[f64],
    a0: f64,
) -> [[f64; K]; K] {
    let residual = |a: &Order2<K>| -> Order2<K> {
        let grad_a = a.g();
        let hess_a = a.h();
        let g: [f64; K] = std::array::from_fn(|i| d_check * grad_a[i] - f_u[i]);
        let h: [[f64; K]; K] = std::array::from_fn(|i| {
            std::array::from_fn(|j| {
                d_check * hess_a[i][j] + f_aa * grad_a[i] * grad_a[j] + d_u[i] * grad_a[j] + d_u[j] * grad_a[i]
                    - f_uv[i * K + j]
            })
        });
        Order2(Tower2 { v: 0.0, g, h })
    };
    filtered_implicit_solve_scalar::<K, Order2<K>>(a0, 1.0 / d_check, 2, residual).h()
}
impl SurvivalMarginalSlopeFamily {
    /// Second derivatives `a_uv` of the implicitly defined intercept `a(u)`.
    ///
    /// The intercept satisfies the lifted calibration constraint whose first- and
    /// second-order parts are
    /// `d_check * a_u = f_u[u]` and
    /// `d_check * a_uv + f_aa a_u a_v + d_u[u] a_v + d_u[v] a_u = f_uv[u, v]`.
    /// For `1 <= p <= 24` the fixed-size `lift_intercept_order2` solver is used;
    /// any other `p` uses the closed-form solution of the same two equations.
    ///
    /// # Errors
    /// Returns `Err` when `f_u` / `d_u` do not have length `p`, when `f_uv` is not
    /// `p x p`, or when an input is not contiguous in memory.
    pub(crate) fn lift_flex_intercept_hessian(
        &self,
        p: usize,
        d_check: f64,
        f_u: &Array1<f64>,
        f_uv: &Array2<f64>,
        f_aa: f64,
        d_u: &Array1<f64>,
        a0: f64,
    ) -> Result<Array2<f64>, String> {
        // The fixed-size path reads `f_uv` as a flat row-major buffer with stride
        // `K == p` and indexes `f_u` / `d_u` up to `K`. A mis-sized input would
        // either panic or, for a larger `f_uv`, be silently read at the wrong
        // offsets, so validate shapes up front.
        if f_u.len() != p {
            return Err(format!(
                "intercept lift: f_u has length {} but p = {p}",
                f_u.len()
            ));
        }
        if d_u.len() != p {
            return Err(format!(
                "intercept lift: d_u has length {} but p = {p}",
                d_u.len()
            ));
        }
        if f_uv.dim() != (p, p) {
            return Err(format!(
                "intercept lift: f_uv has shape {:?} but expected ({p}, {p})",
                f_uv.dim()
            ));
        }
        let fu = f_u
            .as_slice()
            .ok_or_else(|| "intercept lift: f_u must be contiguous".to_string())?;
        let fuv = f_uv
            .as_slice()
            .ok_or_else(|| "intercept lift: f_uv must be contiguous".to_string())?;
        let du = d_u
            .as_slice()
            .ok_or_else(|| "intercept lift: d_u must be contiguous".to_string())?;
        // Dispatch the runtime `p` onto a const-generic solver instantiation.
        macro_rules! go {
            ($k:literal) => {{
                let a_uv = lift_intercept_order2::<$k>(d_check, fu, fuv, f_aa, du, a0);
                Array2::from_shape_fn((p, p), |(i, j)| a_uv[i][j])
            }};
        }
        let a_uv = match p {
            1 => go!(1),
            2 => go!(2),
            3 => go!(3),
            4 => go!(4),
            5 => go!(5),
            6 => go!(6),
            7 => go!(7),
            8 => go!(8),
            9 => go!(9),
            10 => go!(10),
            11 => go!(11),
            12 => go!(12),
            13 => go!(13),
            14 => go!(14),
            15 => go!(15),
            16 => go!(16),
            17 => go!(17),
            18 => go!(18),
            19 => go!(19),
            20 => go!(20),
            21 => go!(21),
            22 => go!(22),
            23 => go!(23),
            24 => go!(24),
            _ => {
                // Closed-form solve of the first- then second-order constraint.
                let inv = 1.0 / d_check;
                let mut a_u = Array1::<f64>::zeros(p);
                for u in 0..p {
                    a_u[u] = fu[u] * inv;
                }
                let mut a_uv = Array2::<f64>::zeros((p, p));
                for u in 0..p {
                    for v in u..p {
                        let value = (f_uv[[u, v]]
                            - d_u[u] * a_u[v]
                            - d_u[v] * a_u[u]
                            - f_aa * a_u[u] * a_u[v])
                            * inv;
                        a_uv[[u, v]] = value;
                        a_uv[[v, u]] = value;
                    }
                }
                a_uv
            }
        };
        Ok(a_uv)
    }
}
/// Coefficients `[d0, ..., d5]` (ascending powers of `z`) of
/// `Q'(z) = z + eta(z) * eta'(z)`, where `Q(z) = (z^2 + eta(z)^2) / 2` and
/// `eta(z) = c0 + c1 z + c2 z^2 + c3 z^3`.
///
/// The `+ 1` on the linear coefficient is the derivative of `z^2 / 2`.
fn qprime_coeffs_jet<J: FlexJet>(c: &[J; 4]) -> [J; 6] {
    let [c0, c1, c2, c3] = c;
    let z0_coeff = c0.mul(c1);
    let z1_without_unit = c1.mul(c1).add(&c0.mul(c2).scale(2.0));
    let z1_coeff = add_const(&z1_without_unit, 1.0);
    let z2_coeff = c0.mul(c3).add(&c1.mul(c2)).scale(3.0);
    let z3_coeff = c1.mul(c3).scale(4.0).add(&c2.mul(c2).scale(2.0));
    let z4_coeff = c2.mul(c3).scale(5.0);
    let z5_coeff = c3.mul(c3).scale(3.0);
    [z0_coeff, z1_coeff, z2_coeff, z3_coeff, z4_coeff, z5_coeff]
}
/// `Q(z) = (z^2 + eta(z)^2) / 2` for the cubic `eta` with coefficients `c`
/// (ascending powers), evaluated as a jet.
fn cell_q_at_jet<J: FlexJet>(c: &[J; 4], z: &J) -> J {
    // Horner: ((c3 * z + c2) * z + c1) * z + c0.
    let eta = c[..3]
        .iter()
        .rev()
        .fold(c[3].clone(), |acc, coeff| acc.mul(z).add(coeff));
    let z_sq = z.mul(z);
    let eta_sq = eta.mul(&eta);
    z_sq.add(&eta_sq).scale(0.5)
}
/// Boundary term `z^n * exp(-Q(z))` of the moment recurrence at one cell edge.
///
/// `z_pow_n` must already hold `z^n`. Returns `None` for an infinite edge,
/// where the term is dropped.
fn boundary_edge_term_jet<J: FlexJet>(
    c: &[J; 4],
    z: &J,
    z_pow_n: &J,
    finite: bool,
) -> Option<J> {
    finite.then(|| {
        let neg_q = cell_q_at_jet(c, z).scale(-1.0);
        z_pow_n.mul(&exp_jet(&neg_q))
    })
}
fn cell_moment_recurrence_jet<J: FlexJet>(
c: &[J; 4],
z_left: &J,
left_finite: bool,
z_right: &J,
right_finite: bool,
base_m0_m4: &[J; 5],
max_degree: usize,
) -> Vec<J> {
let d = qprime_coeffs_jet(c);
let inv_lead = recip(&d[5]);
let mut moments: Vec<J> = base_m0_m4.iter().cloned().collect();
if max_degree < 5 {
moments.truncate(max_degree + 1);
return moments;
}
let one_l = recip(z_left).mul(z_left);
let one_r = recip(z_right).mul(z_right);
let mut left_pow = one_l;
let mut right_pow = one_r;
for n in 0..=(max_degree - 5) {
let b_left = boundary_edge_term_jet(c, z_left, &left_pow, left_finite);
let b_right = boundary_edge_term_jet(c, z_right, &right_pow, right_finite);
let mut b_n = match (b_right, b_left) {
(Some(r), Some(l)) => r.sub(&l),
(Some(r), None) => r,
(None, Some(l)) => l.scale(-1.0),
(None, None) => moments[0].scale(0.0),
};
let mut numer = if n == 0 {
moments[0].scale(0.0)
} else {
moments[n - 1].scale(n as f64)
};
for j in 0..=4 {
numer = numer.sub(&d[j].mul(&moments[n + j]));
}
numer = numer.sub(&b_n);
moments.push(numer.mul(&inv_lead));
left_pow = if left_finite {
left_pow.mul(z_left)
} else {
b_n.scale(0.0)
};
right_pow = if right_finite {
right_pow.mul(z_right)
} else {
b_n = b_n.scale(0.0);
b_n
};
}
moments
}
/// Second-order (value / gradient / Hessian) jets of the observed predictor
/// `eta`, its intercept slope `chi`, and the calibration denominator `d` at one
/// timepoint, differentiated with respect to all `primary.total` primary
/// coordinates.
///
/// Steps:
/// 1. The intercept `a` is lifted implicitly through `calibration_residual_jet`
///    (started at `a0` with step `1 / d_check`).
/// 2. `eta` / `chi` are built from the observed-coefficient expansion plus the
///    `rho` / `tau` channels.
/// 3. `d` is the sum of `flex_timepoint_d_cell` over all cells.
///
/// # Errors
/// Returns `Err` if a Hessian buffer cannot be reshaped to `p x p`.
fn flex_timepoint_inputs_jet2_impl(
    primary: &FlexPrimarySlices,
    q_index: usize,
    q: f64,
    a0: f64,
    b: f64,
    d_check: f64,
    z_obs: f64,
    o_infl: f64,
    pack: &ObservedCoeffPack,
    channels: &FlexChannelInputs<'_>,
    cells: &[CalibrationCellJetInputs<'_>],
) -> Result<FlexTimepointJet2Out, String> {
    let dim = primary.total;
    let g_axis = primary.g;
    let template = Jet2::from_parts(0.0, &vec![0.0; dim], &[]);
    let slope_jet = Jet2::primary(b, g_axis, dim);
    let unit_dirs: Vec<Jet2> = (0..dim).map(|axis| Jet2::primary(0.0, axis, dim)).collect();

    // Implicit lift of the intercept a(u) from the calibration constraint.
    let intercept = lift_intercept_flex(&template, a0, 1.0 / d_check, 2, |a: &Jet2| {
        calibration_residual_jet(a, &slope_jet, g_axis, &unit_dirs, q_index, q, cells)
    });
    let a_grad = intercept.g.clone();
    let rho_jet = channel_jet2(dim, channels.rho, channels.tau, &a_grad, channels.eta_fixed_uv);
    let tau_jet = channel_jet2(dim, channels.tau, channels.tau_a, &a_grad, channels.chi_fixed_uv);

    // The observed-coefficient expansion sees `b` as a constant (zero tangent);
    // any dependence along the g axis has to come through the rho/tau channels.
    let frozen_b = const_jet_like(&template, b);
    let (eta, chi) = flex_timepoint_eta_chi(
        &intercept, &frozen_b, z_obs, o_infl, pack, &rho_jet, &tau_jet,
    );

    let a_tangent = tangent_jet(&intercept);
    let d = cells
        .iter()
        .fold(const_jet_like(&template, 0.0), |acc, cell| {
            let c_pos = cell_coeff_jets(
                &intercept,
                cell.base_pos_coeffs,
                cell.fixed,
                g_axis,
                &a_tangent,
                &unit_dirs,
            );
            let chi_poly = cell_chi_poly_jets(&intercept, cell.fixed, g_axis, &a_tangent, &unit_dirs);
            let left = cell_edge_jet(&intercept, &slope_jet, cell.left_edge, cell.cell_left);
            let right = cell_edge_jet(&intercept, &slope_jet, cell.right_edge, cell.cell_right);
            acc.add(&flex_timepoint_d_cell(
                &template,
                &c_pos,
                &chi_poly,
                &left,
                cell.cell_left.is_finite(),
                &right,
                cell.cell_right.is_finite(),
                cell.numeric_moments,
            ))
        });

    let grad_of = |jet: &Jet2| Array1::from(jet.g.clone());
    let hess_of = |jet: &Jet2| -> Result<Array2<f64>, String> {
        Array2::from_shape_vec((dim, dim), jet.h.clone()).map_err(|e| e.to_string())
    };
    Ok(FlexTimepointJet2Out {
        eta: grad_of(&eta),
        eta_v: eta.value(),
        eta_h: hess_of(&eta)?,
        chi: grad_of(&chi),
        chi_v: chi.value(),
        chi_h: hess_of(&chi)?,
        d: grad_of(&d),
        d_v: d.value(),
        d_h: hess_of(&d)?,
    })
}
/// Per-timepoint channel inputs added on top of the observed-coefficient
/// expansion in `flex_timepoint_inputs_jet2_impl` (each slice has one entry per
/// primary coordinate; each matrix is `p x p`).
struct FlexChannelInputs<'a> {
/// Gradient of the `eta` channel.
rho: &'a [f64],
/// Gradient of the `chi` channel; also the cross coefficient `tau[u] * a_v` in the `eta` channel Hessian.
tau: &'a [f64],
/// Cross coefficient `tau_a[u] * a_v` in the `chi` channel Hessian.
tau_a: &'a [f64],
/// Fixed (intercept-independent) part of the `eta` channel Hessian.
eta_fixed_uv: &'a Array2<f64>,
/// Fixed (intercept-independent) part of the `chi` channel Hessian.
chi_fixed_uv: &'a Array2<f64>,
}
/// Builds a zero-valued `Jet2` channel with gradient `grad` and Hessian
/// `H[u, v] = cross[u] * a_u[v] + cross[v] * a_u[u] + fixed_uv[u, v]`.
///
/// `grad` is expected to have length `p` (the `Jet2` constructor asserts that
/// the `p * p` Hessian matches it).
fn channel_jet2(
    p: usize,
    grad: &[f64],
    cross: &[f64],
    a_u: &[f64],
    fixed_uv: &Array2<f64>,
) -> Jet2 {
    let hessian: Vec<f64> = (0..p)
        .flat_map(|row| {
            (0..p).map(move |col| {
                cross[row] * a_u[col] + cross[col] * a_u[row] + fixed_uv[[row, col]]
            })
        })
        .collect();
    Jet2::from_parts(0.0, grad, &hessian)
}
/// Value, gradient and `p x p` Hessian of the timepoint inputs `eta`, `chi` and
/// `d` produced by `flex_timepoint_inputs_jet2_impl`.
struct FlexTimepointJet2Out {
/// Value of `eta`.
eta_v: f64,
/// Gradient of `eta` over the primary coordinates.
eta: Array1<f64>,
/// Hessian of `eta`.
eta_h: Array2<f64>,
/// Value of `chi`.
chi_v: f64,
/// Gradient of `chi`.
chi: Array1<f64>,
/// Hessian of `chi`.
chi_h: Array2<f64>,
/// Value of `d`.
d_v: f64,
/// Gradient of `d`.
d: Array1<f64>,
/// Hessian of `d`.
d_h: Array2<f64>,
}
/// Coefficient `k` of the observed cubic as a third-order Taylor expansion in
/// the tangents `da`, `db`:
///
/// `c + c_a da + c_b db + (c_aa da^2 + 2 c_ab da db + c_bb db^2) / 2
///  + (c_aaa da^3 + 3 c_aab da^2 db + 3 c_abb da db^2 + c_bbb db^3) / 6`.
///
/// `template` only fixes the jet shape of the constant term.
fn observed_coeff_component_jet<J: FlexJet>(
    template: &J,
    k: usize,
    coeff: [f64; 4],
    dc_da: [f64; 4],
    dc_db: [f64; 4],
    dc_daa: [f64; 4],
    dc_dab: [f64; 4],
    dc_dbb: [f64; 4],
    dc_daaa: [f64; 4],
    dc_daab: [f64; 4],
    dc_dabb: [f64; 4],
    dc_dbbb: [f64; 4],
    da: &J,
    db: &J,
) -> J {
    let sixth = 1.0 / 6.0;
    let da2 = da.mul(da);
    let dadb = da.mul(db);
    let db2 = db.mul(db);
    let da3 = da2.mul(da);
    let da2db = da2.mul(db);
    let dadb2 = dadb.mul(db);
    let db3 = db2.mul(db);
    // Monomials paired with their Taylor weights, in the order they are summed.
    let terms: [(&J, f64); 9] = [
        (da, dc_da[k]),
        (db, dc_db[k]),
        (&da2, 0.5 * dc_daa[k]),
        (&dadb, dc_dab[k]),
        (&db2, 0.5 * dc_dbb[k]),
        (&da3, sixth * dc_daaa[k]),
        (&da2db, 0.5 * dc_daab[k]),
        (&dadb2, 0.5 * dc_dabb[k]),
        (&db3, sixth * dc_dbbb[k]),
    ];
    terms
        .iter()
        .fold(const_jet_like(template, coeff[k]), |acc, &(monomial, weight)| {
            acc.add(&monomial.scale(weight))
        })
}
/// Observed cubic coefficients (ascending powers of `z`) and their partial
/// derivatives with respect to the intercept `a` and slope `b`, up to third
/// order, as consumed by `observed_coeff_component_jet`.
struct ObservedCoeffPack {
/// Base coefficients.
coeff: [f64; 4],
/// First partials: d/da, d/db.
dc_da: [f64; 4],
dc_db: [f64; 4],
/// Second partials: d2/da2, d2/dadb, d2/db2.
dc_daa: [f64; 4],
dc_dab: [f64; 4],
dc_dbb: [f64; 4],
/// Third partials: d3/da3, d3/da2db, d3/dadb2, d3/db3.
dc_daaa: [f64; 4],
dc_daab: [f64; 4],
dc_dabb: [f64; 4],
dc_dbbb: [f64; 4],
}
/// Observed linear predictor `eta` and its intercept derivative `chi` as jets.
///
/// Both are cubics in `z_obs`:
/// - `eta`'s coefficients are the third-order Taylor expansion of `pack` in the
///   tangents of `a_jet` and `b_jet`; then `o_infl` and `rho_jet` are added.
/// - `chi`'s coefficients are the `a`-derivative of that expansion
///   (second order); then `tau_jet` is added.
fn flex_timepoint_eta_chi<J: FlexJet>(
    a_jet: &J,
    b_jet: &J,
    z_obs: f64,
    o_infl: f64,
    pack: &ObservedCoeffPack,
    rho_jet: &J,
    tau_jet: &J,
) -> (J, J) {
    let a_step = tangent_jet(a_jet);
    let b_step = tangent_jet(b_jet);
    // Differentiating in `a` shifts every partial down one order in `a`, so the
    // third-order slots of the chi expansion are empty.
    let no_third_order = [0.0_f64; 4];
    let eta_coeffs: [J; 4] = std::array::from_fn(|k| {
        observed_coeff_component_jet(
            a_jet,
            k,
            pack.coeff,
            pack.dc_da,
            pack.dc_db,
            pack.dc_daa,
            pack.dc_dab,
            pack.dc_dbb,
            pack.dc_daaa,
            pack.dc_daab,
            pack.dc_dabb,
            pack.dc_dbbb,
            &a_step,
            &b_step,
        )
    });
    let chi_coeffs: [J; 4] = std::array::from_fn(|k| {
        observed_coeff_component_jet(
            a_jet,
            k,
            pack.dc_da,
            pack.dc_daa,
            pack.dc_dab,
            pack.dc_daaa,
            pack.dc_daab,
            pack.dc_dabb,
            no_third_order,
            no_third_order,
            no_third_order,
            no_third_order,
            &a_step,
            &b_step,
        )
    });
    let eta_poly = eval_coeff_jet_at(&eta_coeffs, z_obs);
    let eta = add_const(&eta_poly, o_infl).add(rho_jet);
    let chi = eval_coeff_jet_at(&chi_coeffs, z_obs).add(tau_jet);
    (eta, chi)
}
/// Checks that `cell_moment_recurrence_jet`, fed constant jets (zero gradients),
/// reproduces the values of `reduce_sextic_moments` for `M_0..M_12` on a finite
/// cell, to a relative tolerance of 1e-9.
#[test]
fn cell_moment_recurrence_jet_value_matches_numeric_932() {
let cell = DenestedCubicCell {
left: -1.5,
right: 2.0,
c0: 0.3,
c1: -0.4,
c2: 0.5,
c3: 0.2,
};
// Arbitrary seed moments M_0..M_4; both implementations extend the same seeds.
let base = [1.0_f64, 0.1, 0.6, -0.05, 0.4];
let max_degree = 12usize;
let reference =
reduce_sextic_moments(cell, base, max_degree).expect("numeric sextic moments");
let p = 3usize;
let konst = |x: f64| Jet2::from_parts(x, &vec![0.0; p], &[]);
let c = [
konst(cell.c0),
konst(cell.c1),
konst(cell.c2),
konst(cell.c3),
];
let zl = konst(cell.left);
let zr = konst(cell.right);
let base_jets = [
konst(base[0]),
konst(base[1]),
konst(base[2]),
konst(base[3]),
konst(base[4]),
];
let moments = cell_moment_recurrence_jet(
&c,
&zl,
cell.left.is_finite(),
&zr,
cell.right.is_finite(),
&base_jets,
max_degree,
);
assert_eq!(moments.len(), reference.len(), "moment count");
for (n, (m, r)) in moments.iter().zip(reference.iter()).enumerate() {
assert!(
(m.value() - r).abs() <= 1e-9 * (1.0 + r.abs()),
"moment {n}: jet value {} != numeric {}",
m.value(),
r
);
}
}
/// Checks `base_moment_jets` for `M_0..M_4` along a one-parameter path `theta`
/// that moves both cell edges and all four coefficients:
/// - the values match `evaluate_cell_moments` at `theta = 0`;
/// - the first derivatives match a central finite difference (h = 1e-6).
#[test]
fn base_moment_jets_first_derivative_matches_fd_932() {
use crate::cubic_cell_kernel::evaluate_cell_moments;
let c0 = [0.25_f64, -0.35, 0.4, 0.15];
let zl0 = -1.2_f64;
let zr0 = 1.7_f64;
let dc = [0.13_f64, 0.21, -0.17, 0.09];
let v_l = -0.23_f64;
let v_r = 0.31_f64;
// Cell linearly displaced along theta; the jets below are seeded with these velocities.
let cell_at = |theta: f64| DenestedCubicCell {
left: zl0 + theta * v_l,
right: zr0 + theta * v_r,
c0: c0[0] + theta * dc[0],
c1: c0[1] + theta * dc[1],
c2: c0[2] + theta * dc[2],
c3: c0[3] + theta * dc[3],
};
let max_degree = 10usize;
let moments_at = |theta: f64| -> Vec<f64> {
evaluate_cell_moments(cell_at(theta), max_degree)
.expect("numeric cell moments")
.moments
.into_vec()
};
let numeric0 = moments_at(0.0);
let p = 1usize;
let seeded = |x: f64, vel: f64| {
let mut g = vec![0.0; p];
g[0] = vel;
Jet2::from_parts(x, &g, &[])
};
let c_jets = [
seeded(c0[0], dc[0]),
seeded(c0[1], dc[1]),
seeded(c0[2], dc[2]),
seeded(c0[3], dc[3]),
];
let zl_jet = seeded(zl0, v_l);
let zr_jet = seeded(zr0, v_r);
let m_jets = base_moment_jets(&c_jets, &zl_jet, true, &zr_jet, true, &numeric0);
let h = 1e-6_f64;
let mp = moments_at(h);
let mm = moments_at(-h);
for n in 0..5 {
let fd = (mp[n] - mm[n]) / (2.0 * h);
let jet = &m_jets[n];
assert!(
(jet.value() - numeric0[n]).abs() <= 1e-12 * (1.0 + numeric0[n].abs()),
"M_{n} value {} != numeric {}",
jet.value(),
numeric0[n]
);
assert!(
(jet.g[0] - fd).abs() <= 1e-5 * (1.0 + fd.abs()),
"M_{n} dθ analytic {} != FD {}",
jet.g[0],
fd
);
}
}
/// Checks the second `theta`-derivative of `base_moment_jets` (`M_0..M_4`)
/// against a central finite difference of the jet's own first derivative
/// (h = 1e-5), along the same linear cell path as the first-derivative test.
#[test]
fn base_moment_jets_second_derivative_matches_fd_932() {
use crate::cubic_cell_kernel::evaluate_cell_moments;
let c0 = [0.25_f64, -0.35, 0.4, 0.15];
let zl0 = -1.2_f64;
let zr0 = 1.7_f64;
let dc = [0.13_f64, 0.21, -0.17, 0.09];
let v_l = -0.23_f64;
let v_r = 0.31_f64;
let cell_at = |theta: f64| DenestedCubicCell {
left: zl0 + theta * v_l,
right: zr0 + theta * v_r,
c0: c0[0] + theta * dc[0],
c1: c0[1] + theta * dc[1],
c2: c0[2] + theta * dc[2],
c3: c0[3] + theta * dc[3],
};
let max_degree = 27usize;
let moments_at = |theta: f64| -> Vec<f64> {
evaluate_cell_moments(cell_at(theta), max_degree)
.expect("numeric cell moments")
.moments
.into_vec()
};
// First derivative of M_n produced by the jet at a shifted theta.
let analytic_first = |theta: f64, n: usize| -> f64 {
let numeric = moments_at(theta);
let seeded = |x: f64, vel: f64| {
let g = vec![vel];
Jet2::from_parts(x, &g, &[])
};
let cell = cell_at(theta);
let c_jets = [
seeded(cell.c0, dc[0]),
seeded(cell.c1, dc[1]),
seeded(cell.c2, dc[2]),
seeded(cell.c3, dc[3]),
];
let zl_jet = seeded(cell.left, v_l);
let zr_jet = seeded(cell.right, v_r);
let m = base_moment_jets(&c_jets, &zl_jet, true, &zr_jet, true, &numeric);
m[n].g[0]
};
let numeric0 = moments_at(0.0);
let seeded = |x: f64, vel: f64| {
let g = vec![vel];
Jet2::from_parts(x, &g, &[])
};
let c_jets = [
seeded(c0[0], dc[0]),
seeded(c0[1], dc[1]),
seeded(c0[2], dc[2]),
seeded(c0[3], dc[3]),
];
let zl_jet = seeded(zl0, v_l);
let zr_jet = seeded(zr0, v_r);
let m_jets = base_moment_jets(&c_jets, &zl_jet, true, &zr_jet, true, &numeric0);
let h = 1e-5_f64;
for n in 0..5 {
let fd2 = (analytic_first(h, n) - analytic_first(-h, n)) / (2.0 * h);
let hess = m_jets[n].h[0];
assert!(
(hess - fd2).abs() <= 2e-4 * (1.0 + fd2.abs()),
"M_{n} d²θ analytic {} != FD-of-analytic {}",
hess,
fd2
);
}
}
/// Checks `flex_timepoint_eta_chi` against a scalar re-implementation of the
/// coefficient Taylor expansion:
/// - values at the expansion point;
/// - directional gradients via a central finite difference along `(a_u, b_u) * theta`.
#[test]
fn flex_timepoint_eta_chi_value_and_grad_932() {
let z_obs = 0.7_f64;
let o_infl = 0.05_f64;
let pack = ObservedCoeffPack {
coeff: [0.2, -0.3, 0.15, 0.05],
dc_da: [1.1, 0.2, 0.03, 0.0],
dc_db: [0.4, 1.05, 0.1, 0.02],
dc_daa: [0.07, 0.02, 0.0, 0.0],
dc_dab: [0.2, 0.09, 0.01, 0.0],
dc_dbb: [0.11, 0.04, 0.005, 0.0],
dc_daaa: [0.003, 0.0, 0.0, 0.0],
dc_daab: [0.006, 0.001, 0.0, 0.0],
dc_dabb: [0.004, 0.002, 0.0, 0.0],
dc_dbbb: [0.008, 0.001, 0.0, 0.0],
};
let a0 = 0.3_f64;
let b0 = 1.2_f64;
let a_u = 0.25_f64;
let b_u = -0.4_f64;
let p = 1usize;
let a_jet = Jet2::from_parts(a0, &[a_u], &[]);
let b_jet = Jet2::from_parts(b0, &[b_u], &[]);
// Zero rho/tau channels so only the coefficient expansion contributes.
let zero = Jet2::from_parts(0.0, &vec![0.0; p], &[]);
let (eta, chi) = flex_timepoint_eta_chi(&a_jet, &b_jet, z_obs, o_infl, &pack, &zero, &zero);
let coeff_scalar = |da: f64, db: f64| -> [f64; 4] {
std::array::from_fn(|k| {
pack.coeff[k]
+ pack.dc_da[k] * da
+ pack.dc_db[k] * db
+ 0.5 * pack.dc_daa[k] * da * da
+ pack.dc_dab[k] * da * db
+ 0.5 * pack.dc_dbb[k] * db * db
+ pack.dc_daaa[k] * da * da * da / 6.0
+ 0.5 * pack.dc_daab[k] * da * da * db
+ 0.5 * pack.dc_dabb[k] * da * db * db
+ pack.dc_dbbb[k] * db * db * db / 6.0
})
};
let eta_scalar = |theta: f64| -> f64 {
let c = coeff_scalar(a_u * theta, b_u * theta);
eval_coeff4_scalar(&c, z_obs) + o_infl
};
let chi_scalar = |theta: f64| -> f64 {
let dc = coeff_scalar_da(&pack, a_u * theta, b_u * theta);
eval_coeff4_scalar(&dc, z_obs)
};
assert!(
(eta.value() - eta_scalar(0.0)).abs() <= 1e-12 * (1.0 + eta_scalar(0.0).abs()),
"eta value {} != {}",
eta.value(),
eta_scalar(0.0)
);
assert!(
(chi.value() - chi_scalar(0.0)).abs() <= 1e-12 * (1.0 + chi_scalar(0.0).abs()),
"chi value {} != {}",
chi.value(),
chi_scalar(0.0)
);
let h = 1e-6_f64;
let eta_fd = (eta_scalar(h) - eta_scalar(-h)) / (2.0 * h);
let chi_fd = (chi_scalar(h) - chi_scalar(-h)) / (2.0 * h);
assert!(
(eta.g[0] - eta_fd).abs() <= 1e-5 * (1.0 + eta_fd.abs()),
"eta grad {} != FD {}",
eta.g[0],
eta_fd
);
assert!(
(chi.g[0] - chi_fd).abs() <= 1e-5 * (1.0 + chi_fd.abs()),
"chi grad {} != FD {}",
chi.g[0],
chi_fd
);
}
/// Evaluates the cubic `c[0] + c[1] z + c[2] z^2 + c[3] z^3` by Horner's rule.
fn eval_coeff4_scalar(c: &[f64; 4], z: f64) -> f64 {
    c.iter().rev().fold(0.0, |partial, &coeff| partial * z + coeff)
}
/// Scalar `a`-derivative of the observed-coefficient Taylor expansion at the
/// displacement `(da, db)`. This is the test reference for the `chi` coefficients.
fn coeff_scalar_da(pack: &ObservedCoeffPack, da: f64, db: f64) -> [f64; 4] {
    let mut out = [0.0_f64; 4];
    for (k, slot) in out.iter_mut().enumerate() {
        *slot = pack.dc_da[k]
            + pack.dc_daa[k] * da
            + pack.dc_dab[k] * db
            + 0.5 * pack.dc_daaa[k] * da * da
            + pack.dc_daab[k] * da * db
            + 0.5 * pack.dc_dabb[k] * db * db;
    }
    out
}
/// Checks `cell_coeff_jets` in two ways:
/// - values and directional gradient against a scalar model of the
///   coefficient expansion (central FD, h = 1e-6);
/// - with unit primary seeds, the g-axis Hessian diagonal equals `coeff_bu[g]`
///   (a regression check for a previously doubled g-diagonal).
#[test]
fn cell_coeff_jets_value_and_grad_932() {
let p = 3usize;
let g_axis = 1usize;
let base_c = [0.2_f64, -0.3, 0.15, 0.05];
let mk_run = |seed: f64| -> Vec<[f64; 4]> {
(0..p)
.map(|u| std::array::from_fn(|k| seed * (1.0 + u as f64) * (1.0 + k as f64) * 0.01))
.collect()
};
let fixed = DenestedCellPrimaryFixedPartials {
dc_da: [1.1, 0.2, 0.03, 0.0],
dc_daa: [0.07, 0.02, 0.0, 0.0],
dc_daaa: [0.003, 0.0, 0.0, 0.0],
coeff_u: mk_run(0.9),
coeff_au: mk_run(0.4),
coeff_bu: mk_run(0.5),
coeff_aau: mk_run(0.12),
coeff_abu: mk_run(0.16),
coeff_bbu: mk_run(0.11),
coeff_aaau: mk_run(0.02),
coeff_aabu: mk_run(0.03),
coeff_abbu: mk_run(0.04),
coeff_bbbu: mk_run(0.05),
};
let a0 = 0.3_f64;
let a_u = 0.25_f64;
let v = [0.2_f64, -0.4, 0.33];
let seeded = |x: f64, vel: f64| {
let g = vec![vel];
Jet2::from_parts(x, &g, &[])
};
let a_jet = seeded(a0, a_u);
let da = tangent_jet(&a_jet);
let du: Vec<Jet2> = (0..p).map(|u| seeded(0.0, v[u])).collect();
let jets = cell_coeff_jets(&a_jet, base_c, &fixed, g_axis, &da, &du);
// Scalar reference: non-g axes enter linearly in du; the g axis doubles as the b displacement.
let scalar_c = |theta: f64| -> [f64; 4] {
let da = a_u * theta;
let db = v[g_axis] * theta;
std::array::from_fn(|k| {
let mut acc = base_c[k]
+ fixed.dc_da[k] * da
+ 0.5 * fixed.dc_daa[k] * da * da
+ fixed.dc_daaa[k] * da * da * da / 6.0;
for u in 0..p {
if u == g_axis {
continue;
}
let duu = v[u] * theta;
acc += fixed.coeff_u[u][k] * duu
+ fixed.coeff_au[u][k] * da * duu
+ 0.5 * fixed.coeff_aau[u][k] * da * da * duu
+ fixed.coeff_bu[u][k] * db * duu
+ fixed.coeff_abu[u][k] * da * db * duu
+ 0.5 * fixed.coeff_bbu[u][k] * db * db * duu
+ fixed.coeff_aaau[u][k] * da * da * da * duu / 6.0
+ 0.5 * fixed.coeff_aabu[u][k] * da * da * db * duu
+ 0.5 * fixed.coeff_abbu[u][k] * da * db * db * duu
+ fixed.coeff_bbbu[u][k] * db * db * db * duu / 6.0;
}
acc += fixed.coeff_u[g_axis][k] * db
+ fixed.coeff_au[g_axis][k] * da * db
+ 0.5 * fixed.coeff_aau[g_axis][k] * da * da * db
+ 0.5 * fixed.coeff_bu[g_axis][k] * db * db
+ 0.5 * fixed.coeff_abu[g_axis][k] * da * db * db
+ fixed.coeff_bbu[g_axis][k] * db * db * db / 6.0;
acc
})
};
let h = 1e-6_f64;
let c0 = scalar_c(0.0);
let cp = scalar_c(h);
let cm = scalar_c(-h);
for k in 0..4 {
assert!(
(jets[k].value() - c0[k]).abs() <= 1e-12 * (1.0 + c0[k].abs()),
"c_{k} value {} != {}",
jets[k].value(),
c0[k]
);
let fd = (cp[k] - cm[k]) / (2.0 * h);
assert!(
(jets[k].g[0] - fd).abs() <= 1e-5 * (1.0 + fd.abs()),
"c_{k} grad {} != FD {}",
jets[k].g[0],
fd
);
}
// Isolated check: zero intercept tangent, unit seeds on every primary axis.
let da_iso = Jet2::from_parts(0.0, &vec![0.0; p], &[]);
let du_iso: Vec<Jet2> = (0..p).map(|u| Jet2::primary(0.0, u, p)).collect();
let jets_iso = cell_coeff_jets(&da_iso, base_c, &fixed, g_axis, &da_iso, &du_iso);
for k in 0..4 {
let hgg = jets_iso[k].h[g_axis * p + g_axis];
assert!(
(hgg - fixed.coeff_bu[g_axis][k]).abs()
<= 1e-12 * (1.0 + fixed.coeff_bu[g_axis][k].abs()),
"c_{k} Hess[g,g] {} != dc_dbb {} (2× = the pre-fix g-diagonal bug)",
hgg,
fixed.coeff_bu[g_axis][k]
);
}
}
/// Checks `flex_timepoint_d_cell` against the scalar reference
/// `D = (1 / 2π) Σ_k chi_k M_k` on a fixed-edge cell whose coefficients move
/// along a cubic path in `theta`:
/// - value at `theta = 0`;
/// - first derivative via central FD (h = 1e-6).
#[test]
fn flex_timepoint_d_cell_value_and_grad_932() {
use crate::cubic_cell_kernel::evaluate_cell_moments;
let zl = -1.1_f64;
let zr = 1.6_f64;
let c_base = [0.2_f64, -0.3, 0.18, 0.06];
let dc_da = [1.05_f64, 0.22, 0.04, 0.0];
let dc_daa = [0.08_f64, 0.03, 0.0, 0.0];
let dc_daaa = [0.004_f64, 0.0, 0.0, 0.0];
let cell_at = |theta: f64| {
let c: [f64; 4] = std::array::from_fn(|k| {
c_base[k]
+ dc_da[k] * theta
+ 0.5 * dc_daa[k] * theta * theta
+ dc_daaa[k] * theta * theta * theta / 6.0
});
DenestedCubicCell {
left: zl,
right: zr,
c0: c[0],
c1: c[1],
c2: c[2],
c3: c[3],
}
};
// theta-derivative of the coefficient path; plays the role of chi.
let dc_da_at = |theta: f64| -> [f64; 4] {
std::array::from_fn(|k| dc_da[k] + dc_daa[k] * theta + 0.5 * dc_daaa[k] * theta * theta)
};
let max_degree = 10usize;
let moments_at = |theta: f64| -> Vec<f64> {
evaluate_cell_moments(cell_at(theta), max_degree)
.expect("numeric cell moments")
.moments
.into_vec()
};
let d_scalar = |theta: f64| -> f64 {
let m = moments_at(theta);
let chi = dc_da_at(theta);
let mut acc = 0.0;
for k in 0..4 {
acc += chi[k] * m[k];
}
acc * std::f64::consts::TAU.recip()
};
let seeded = |x: f64, vel: f64| {
let g = vec![vel];
Jet2::from_parts(x, &g, &[])
};
let cell0 = cell_at(0.0);
let c_jets = [
seeded(cell0.c0, dc_da[0]),
seeded(cell0.c1, dc_da[1]),
seeded(cell0.c2, dc_da[2]),
seeded(cell0.c3, dc_da[3]),
];
let dc_da0 = dc_da_at(0.0);
let chi_jets = [
seeded(dc_da0[0], dc_daa[0]),
seeded(dc_da0[1], dc_daa[1]),
seeded(dc_da0[2], dc_daa[2]),
seeded(dc_da0[3], dc_daa[3]),
];
let template = seeded(0.0, 0.0);
// Fixed edges: zero tangent.
let edge_l = seeded(zl, 0.0); let edge_r = seeded(zr, 0.0);
let numeric0 = moments_at(0.0);
let d_jet = flex_timepoint_d_cell(
&template, &c_jets, &chi_jets, &edge_l, true, &edge_r, true, &numeric0,
);
assert!(
(d_jet.value() - d_scalar(0.0)).abs() <= 1e-10 * (1.0 + d_scalar(0.0).abs()),
"D value {} != {}",
d_jet.value(),
d_scalar(0.0)
);
let h = 1e-6_f64;
let fd = (d_scalar(h) - d_scalar(-h)) / (2.0 * h);
assert!(
(d_jet.g[0] - fd).abs() <= 1e-4 * (1.0 + fd.abs()),
"D grad (d_u) {} != FD {}",
d_jet.g[0],
fd
);
}
/// Checks `cell_chi_poly_jets` against a scalar model of the `a`-derivative of
/// the cell coefficients:
/// - values at the expansion point;
/// - directional gradient via central FD (h = 1e-6).
#[test]
fn cell_chi_poly_jets_value_and_grad_932() {
let p = 2usize;
let g_axis = 1usize;
let mk_run = |seed: f64| -> Vec<[f64; 4]> {
(0..p)
.map(|u| std::array::from_fn(|k| seed * (1.0 + u as f64) * (1.0 + k as f64) * 0.01))
.collect()
};
let fixed = DenestedCellPrimaryFixedPartials {
dc_da: [1.05, 0.22, 0.04, 0.0],
dc_daa: [0.08, 0.03, 0.0, 0.0],
dc_daaa: [0.004, 0.0, 0.0, 0.0],
coeff_u: mk_run(0.9),
coeff_au: mk_run(0.4),
coeff_bu: mk_run(0.5),
coeff_aau: mk_run(0.12),
coeff_abu: mk_run(0.16),
coeff_bbu: mk_run(0.11),
coeff_aaau: mk_run(0.02),
coeff_aabu: mk_run(0.03),
coeff_abbu: mk_run(0.04),
coeff_bbbu: mk_run(0.05),
};
let a_u = 0.25_f64;
let v = [0.2_f64, -0.4];
let seeded = |x: f64, vel: f64| {
let g = vec![vel];
Jet2::from_parts(x, &g, &[])
};
let a_jet = seeded(0.3, a_u);
let da = tangent_jet(&a_jet);
let du: Vec<Jet2> = (0..p).map(|u| seeded(0.0, v[u])).collect();
let chi = cell_chi_poly_jets(&a_jet, &fixed, g_axis, &da, &du);
let chi_scalar = |theta: f64| -> [f64; 4] {
let da = a_u * theta;
let db = v[g_axis] * theta;
std::array::from_fn(|k| {
let mut acc =
fixed.dc_da[k] + fixed.dc_daa[k] * da + 0.5 * fixed.dc_daaa[k] * da * da;
for u in 0..p {
let duu = v[u] * theta;
acc += fixed.coeff_au[u][k] * duu
+ fixed.coeff_aau[u][k] * da * duu
+ fixed.coeff_abu[u][k] * db * duu;
}
acc
})
};
let h = 1e-6_f64;
let c0 = chi_scalar(0.0);
let cp = chi_scalar(h);
let cm = chi_scalar(-h);
for k in 0..4 {
assert!(
(chi[k].value() - c0[k]).abs() <= 1e-12 * (1.0 + c0[k].abs()),
"chi_{k} value {} != {}",
chi[k].value(),
c0[k]
);
let fd = (cp[k] - cm[k]) / (2.0 * h);
assert!(
(chi[k].g[0] - fd).abs() <= 1e-5 * (1.0 + fd.abs()),
"chi_{k} grad {} != FD {}",
chi[k].g[0],
fd
);
}
}
/// Checks the first-order intercept lift on a single fixed-edge cell against a
/// hand implicit-function-theorem computation:
/// - the residual gradient equals `-f_u`;
/// - the lifted `a_u` equals `f_u / D`, with `D = |f_a|`.
#[test]
fn lift_intercept_flex_first_order_matches_hand_ift_932() {
use crate::cubic_cell_kernel::{
DenestedCubicCell, PartitionEdge, cell_first_derivative_from_moments,
evaluate_cell_moments,
};
let cell = DenestedCubicCell {
left: -1.0,
right: 1.4,
c0: 0.2,
c1: -0.25,
c2: 0.15,
c3: 0.05,
};
let numeric = evaluate_cell_moments(cell, 9)
.expect("numeric moments")
.moments
.into_vec();
let base_pos = [cell.c0, cell.c1, cell.c2, cell.c3];
let p = 1usize;
let dc_da = [0.9_f64, 0.2, 0.05, 0.0];
let coeff_u0 = [0.3_f64, -0.15, 0.08, 0.0];
let zero_run: Vec<[f64; 4]> = vec![[0.0; 4]; p];
let mut coeff_u = zero_run.clone();
coeff_u[0] = coeff_u0;
let fixed = DenestedCellPrimaryFixedPartials {
dc_da,
dc_daa: [0.0; 4],
dc_daaa: [0.0; 4],
coeff_u,
coeff_au: zero_run.clone(),
coeff_bu: zero_run.clone(),
coeff_aau: zero_run.clone(),
coeff_abu: zero_run.clone(),
coeff_bbu: zero_run.clone(),
coeff_aaau: zero_run.clone(),
coeff_aabu: zero_run.clone(),
coeff_abbu: zero_run.clone(),
coeff_bbbu: zero_run,
};
// Hand IFT reference: f_u and f_a from the negated coefficient partials.
let neg_coeff_u0 = coeff_u0.map(|v| -v);
let neg_dc_da = dc_da.map(|v| -v);
let f_u0 = cell_first_derivative_from_moments(&neg_coeff_u0, &numeric).expect("f_u");
let f_a = cell_first_derivative_from_moments(&neg_dc_da, &numeric).expect("f_a");
let d_check = f_a.abs();
let a_u_hand = f_u0 / d_check;
let a0 = 0.31_f64;
let template = Jet2::from_parts(0.0, &vec![0.0; p], &[]);
let a_jet0 = Jet2::from_parts(a0, &vec![0.0; p], &[]);
let b_jet = Jet2::from_parts(1.1, &vec![0.0; p], &[]);
let du: Vec<Jet2> = (0..p).map(|u| Jet2::primary(0.0, u, p)).collect();
let cells = vec![CalibrationCellJetInputs {
base_pos_coeffs: base_pos,
fixed: &fixed,
cell_left: cell.left,
cell_right: cell.right,
left_edge: PartitionEdge::Fixed(cell.left),
right_edge: PartitionEdge::Fixed(cell.right),
numeric_moments: &numeric,
}];
// q_index = p (out of range) — presumably disables the q axis; confirm in calibration_residual_jet.
let r0 = calibration_residual_jet(&a_jet0, &b_jet, 0, &du, p, 0.0, &cells);
assert!(
(r0.g[0] - (-f_u0)).abs() <= 1e-9 * (1.0 + f_u0.abs()),
"residual grad {} != -f_u {}",
r0.g[0],
-f_u0
);
let residual = |a: &Jet2| calibration_residual_jet(a, &b_jet, 0, &du, p, 0.0, &cells);
let a_lift = lift_intercept_flex(&template, a0, 1.0 / d_check, 2, residual);
assert!(
(a_lift.g[0] - a_u_hand).abs() <= 1e-6 * (1.0 + a_u_hand.abs()),
"lifted a_u {} != hand f_u/D {}",
a_lift.g[0],
a_u_hand
);
}
/// Smoke test for `flex_timepoint_inputs_jet2_impl` with zero channels:
/// - `eta`, `chi` and `d` values match closed-form scalar references;
/// - every gradient and Hessian entry is finite;
/// - the Hessians have shape `p x p`.
#[test]
fn flex_timepoint_inputs_jet2_assembly_composes_932() {
use crate::cubic_cell_kernel::{
DenestedCubicCell, PartitionEdge, cell_first_derivative_from_moments,
evaluate_cell_moments,
};
let p = 2usize;
let primary = FlexPrimarySlices {
q0: 0,
q1: 0,
qd1: 0,
g: 1,
h: None,
w: None,
infl: None,
total: p,
};
let cell = DenestedCubicCell {
left: -1.0,
right: 1.4,
c0: 0.2,
c1: -0.25,
c2: 0.15,
c3: 0.05,
};
let numeric = evaluate_cell_moments(cell, 27)
.expect("numeric moments")
.moments
.into_vec();
let dc_da = [0.9_f64, 0.2, 0.05, 0.0];
let zero_run: Vec<[f64; 4]> = vec![[0.0; 4]; p];
let fixed = DenestedCellPrimaryFixedPartials {
dc_da,
dc_daa: [0.0; 4],
dc_daaa: [0.0; 4],
coeff_u: zero_run.clone(),
coeff_au: zero_run.clone(),
coeff_bu: zero_run.clone(),
coeff_aau: zero_run.clone(),
coeff_abu: zero_run.clone(),
coeff_bbu: zero_run.clone(),
coeff_aaau: zero_run.clone(),
coeff_aabu: zero_run.clone(),
coeff_abbu: zero_run.clone(),
coeff_bbbu: zero_run,
};
let neg_dc_da = dc_da.map(|v| -v);
let f_a = cell_first_derivative_from_moments(&neg_dc_da, &numeric).expect("f_a");
let d_check = f_a.abs();
let z_obs = 0.6_f64;
let o_infl = 0.04_f64;
let pack = ObservedCoeffPack {
coeff: [0.2, -0.3, 0.15, 0.05],
dc_da: [1.1, 0.2, 0.03, 0.0],
dc_db: [0.4, 1.05, 0.1, 0.02],
dc_daa: [0.07, 0.02, 0.0, 0.0],
dc_dab: [0.2, 0.09, 0.01, 0.0],
dc_dbb: [0.11, 0.04, 0.005, 0.0],
dc_daaa: [0.003, 0.0, 0.0, 0.0],
dc_daab: [0.006, 0.001, 0.0, 0.0],
dc_dabb: [0.004, 0.002, 0.0, 0.0],
dc_dbbb: [0.008, 0.001, 0.0, 0.0],
};
// All channels zero: eta/chi values come from the observed pack alone.
let rho = vec![0.0_f64; p];
let tau = vec![0.0_f64; p];
let tau_a = vec![0.0_f64; p];
let eta_fixed_uv = Array2::<f64>::zeros((p, p));
let chi_fixed_uv = Array2::<f64>::zeros((p, p));
let channels = FlexChannelInputs {
rho: &rho,
tau: &tau,
tau_a: &tau_a,
eta_fixed_uv: &eta_fixed_uv,
chi_fixed_uv: &chi_fixed_uv,
};
let cells = vec![CalibrationCellJetInputs {
base_pos_coeffs: [cell.c0, cell.c1, cell.c2, cell.c3],
fixed: &fixed,
cell_left: cell.left,
cell_right: cell.right,
left_edge: PartitionEdge::Fixed(cell.left),
right_edge: PartitionEdge::Fixed(cell.right),
numeric_moments: &numeric,
}];
let out = flex_timepoint_inputs_jet2_impl(
&primary, primary.q1, 0.0, 0.31, 1.1, d_check, z_obs, o_infl, &pack, &channels, &cells,
)
.expect("jet timepoint inputs");
let eta_ref = {
let mut acc = 0.0;
for &c in pack.coeff.iter().rev() {
acc = acc * z_obs + c;
}
acc + o_infl
};
let chi_ref = {
let mut acc = 0.0;
for &c in pack.dc_da.iter().rev() {
acc = acc * z_obs + c;
}
acc
};
let d_ref = {
let mut acc = 0.0;
for k in 0..4 {
acc += dc_da[k] * numeric[k];
}
acc * std::f64::consts::TAU.recip()
};
assert!(
(out.eta_v - eta_ref).abs() <= 1e-9 * (1.0 + eta_ref.abs()),
"eta value {} != {}",
out.eta_v,
eta_ref
);
assert!(
(out.chi_v - chi_ref).abs() <= 1e-9 * (1.0 + chi_ref.abs()),
"chi value {} != {}",
out.chi_v,
chi_ref
);
assert!(
(out.d_v - d_ref).abs() <= 1e-9 * (1.0 + d_ref.abs()),
"d value {} != {}",
out.d_v,
d_ref
);
assert_eq!(out.eta.len(), p);
for arr in [&out.eta, &out.chi, &out.d] {
for v in arr.iter() {
assert!(v.is_finite(), "gradient channel finite");
}
}
for mat in [&out.eta_h, &out.chi_h, &out.d_h] {
assert_eq!(mat.shape(), [p, p]);
for v in mat.iter() {
assert!(v.is_finite(), "Hessian channel finite");
}
}
}
/// Checks that the `eta` / `chi` gradients and Hessians from
/// `flex_timepoint_inputs_jet2_impl` match the hand chain rule
/// `eta_uv = chi a_uv + eta_aa a_u a_v + tau[u] a_v + tau[v] a_u + eta_fixed_uv`
/// (and the analogous `chi_uv`), using an independently re-lifted intercept.
/// It also checks that the `d` Hessian is finite and symmetric.
#[test]
fn flex_timepoint_inputs_jet2_hessian_matches_hand_channel_coupling_932() {
use crate::cubic_cell_kernel::{
DenestedCubicCell, PartitionEdge, cell_first_derivative_from_moments,
evaluate_cell_moments,
};
let p = 4usize;
let g_axis = 1usize;
let h_axis = 2usize;
let w_axis = 3usize;
let primary = FlexPrimarySlices {
q0: 0,
q1: 0,
qd1: 0,
g: g_axis,
h: Some(h_axis..h_axis + 1),
w: Some(w_axis..w_axis + 1),
infl: None,
total: p,
};
let cell = DenestedCubicCell {
left: -1.1,
right: 1.3,
c0: 0.22,
c1: -0.18,
c2: 0.13,
c3: 0.04,
};
let numeric = evaluate_cell_moments(cell, 27)
.expect("numeric moments")
.moments
.into_vec();
let dc_da = [0.85_f64, 0.21, 0.06, 0.0];
let mk_run = |base: [f64; 4], step: f64| -> Vec<[f64; 4]> {
(0..p)
.map(|u| base.map(|c| c * (0.1 + step * (u as f64 + 1.0))))
.collect()
};
let fixed = DenestedCellPrimaryFixedPartials {
dc_da,
dc_daa: [0.05, 0.02, 0.0, 0.0],
dc_daaa: [0.004, 0.0, 0.0, 0.0],
coeff_u: mk_run([0.3, 0.1, 0.02, 0.0], 0.07),
coeff_au: mk_run([0.12, 0.04, 0.0, 0.0], 0.05),
coeff_bu: mk_run([0.09, 0.03, 0.0, 0.0], 0.04),
coeff_aau: mk_run([0.02, 0.0, 0.0, 0.0], 0.01),
coeff_abu: mk_run([0.015, 0.0, 0.0, 0.0], 0.01),
coeff_bbu: mk_run([0.01, 0.0, 0.0, 0.0], 0.008),
coeff_aaau: vec![[0.0; 4]; p],
coeff_aabu: vec![[0.0; 4]; p],
coeff_abbu: vec![[0.0; 4]; p],
coeff_bbbu: vec![[0.0; 4]; p],
};
let neg_dc_da = dc_da.map(|v| -v);
let f_a = cell_first_derivative_from_moments(&neg_dc_da, &numeric).expect("f_a");
let d_check = f_a.abs();
let z_obs = 0.55_f64;
let o_infl = 0.0_f64;
let b = 1.07_f64;
let a0 = 0.29_f64;
let pack = ObservedCoeffPack {
coeff: [0.21, -0.27, 0.14, 0.05],
dc_da: [1.05, 0.19, 0.04, 0.0],
dc_db: [0.41, 1.02, 0.09, 0.02],
dc_daa: [0.08, 0.03, 0.0, 0.0],
dc_dab: [0.22, 0.1, 0.012, 0.0],
dc_dbb: [0.13, 0.05, 0.006, 0.0],
dc_daaa: [0.0035, 0.0, 0.0, 0.0],
dc_daab: [0.007, 0.0012, 0.0, 0.0],
dc_dabb: [0.0045, 0.0023, 0.0, 0.0],
dc_dbbb: [0.0085, 0.0011, 0.0, 0.0],
};
// The g-axis channel entries carry the observed b-partials (b is frozen in the expansion).
let mut rho = vec![0.0_f64; p];
let mut tau = vec![0.0_f64; p];
let mut tau_a = vec![0.0_f64; p];
rho[g_axis] = eval_coeff4_at(&pack.dc_db, z_obs);
rho[h_axis] = 0.37;
rho[w_axis] = 0.29;
tau[g_axis] = eval_coeff4_at(&pack.dc_dab, z_obs);
tau[w_axis] = 0.18;
tau_a[g_axis] = eval_coeff4_at(&pack.dc_daab, z_obs);
tau_a[w_axis] = 0.11;
let mut eta_fixed_uv = Array2::<f64>::zeros((p, p));
let mut chi_fixed_uv = Array2::<f64>::zeros((p, p));
let set_sym = |m: &mut Array2<f64>, i: usize, j: usize, v: f64| {
m[[i, j]] = v;
m[[j, i]] = v;
};
set_sym(&mut eta_fixed_uv, g_axis, g_axis, 0.14);
set_sym(&mut eta_fixed_uv, g_axis, h_axis, 0.21);
set_sym(&mut eta_fixed_uv, g_axis, w_axis, 0.17);
set_sym(&mut chi_fixed_uv, g_axis, g_axis, 0.09);
set_sym(&mut chi_fixed_uv, g_axis, w_axis, 0.12);
let channels = FlexChannelInputs {
rho: &rho,
tau: &tau,
tau_a: &tau_a,
eta_fixed_uv: &eta_fixed_uv,
chi_fixed_uv: &chi_fixed_uv,
};
let cells = vec![CalibrationCellJetInputs {
base_pos_coeffs: [cell.c0, cell.c1, cell.c2, cell.c3],
fixed: &fixed,
cell_left: cell.left,
cell_right: cell.right,
left_edge: PartitionEdge::Crossing {
tau: cell.left * b + a0,
},
right_edge: PartitionEdge::Crossing {
tau: cell.right * b + a0,
},
numeric_moments: &numeric,
}];
let out = flex_timepoint_inputs_jet2_impl(
&primary, primary.q1, 0.0, a0, b, d_check, z_obs, o_infl, &pack, &channels, &cells,
)
.expect("jet timepoint inputs");
// Re-lift the intercept independently to obtain a_u and a_uv for the hand formula.
let template = Jet2::from_parts(0.0, &vec![0.0; p], &[]);
let b_jet = Jet2::primary(b, g_axis, p);
let du: Vec<Jet2> = (0..p).map(|u| Jet2::primary(0.0, u, p)).collect();
let residual =
|a: &Jet2| calibration_residual_jet(a, &b_jet, g_axis, &du, primary.q1, 0.0, &cells);
let a_jet = lift_intercept_flex(&template, a0, 1.0 / d_check, 2, residual);
let a_u = a_jet.g.clone();
let a_uv = |u: usize, v: usize| a_jet.h[u * p + v];
let chi = eval_coeff4_at(&pack.dc_da, z_obs);
let eta_aa = eval_coeff4_at(&pack.dc_daa, z_obs);
let eta_aaa = eval_coeff4_at(&pack.dc_daaa, z_obs);
for u in 0..p {
for v in 0..p {
let eta_hand = chi * a_uv(u, v)
+ eta_aa * a_u[u] * a_u[v]
+ tau[u] * a_u[v]
+ tau[v] * a_u[u]
+ eta_fixed_uv[[u, v]];
let chi_hand = eta_aa * a_uv(u, v)
+ eta_aaa * a_u[u] * a_u[v]
+ tau_a[u] * a_u[v]
+ tau_a[v] * a_u[u]
+ chi_fixed_uv[[u, v]];
assert!(
(out.eta_h[[u, v]] - eta_hand).abs() <= 1e-9 * (1.0 + eta_hand.abs()),
"eta_uv[{u},{v}] jet {} != hand {}",
out.eta_h[[u, v]],
eta_hand
);
assert!(
(out.chi_h[[u, v]] - chi_hand).abs() <= 1e-9 * (1.0 + chi_hand.abs()),
"chi_uv[{u},{v}] jet {} != hand {}",
out.chi_h[[u, v]],
chi_hand
);
}
}
for u in 0..p {
let eta_u_hand = chi * a_u[u] + rho[u];
let chi_u_hand = eta_aa * a_u[u] + tau[u];
assert!(
(out.eta[u] - eta_u_hand).abs() <= 1e-9 * (1.0 + eta_u_hand.abs()),
"eta_u[{u}] jet {} != hand {}",
out.eta[u],
eta_u_hand
);
assert!(
(out.chi[u] - chi_u_hand).abs() <= 1e-9 * (1.0 + chi_u_hand.abs()),
"chi_u[{u}] jet {} != hand {}",
out.chi[u],
chi_u_hand
);
}
for u in 0..p {
for v in 0..p {
assert!(out.d_h[[u, v]].is_finite(), "d_uv finite");
assert!(
(out.d_h[[u, v]] - out.d_h[[v, u]]).abs() <= 1e-9,
"d_uv symmetric at [{u},{v}]"
);
}
}
}
/// Builds a deterministic `n`-row `SurvivalMarginalSlopeFamily` for the #932 flex tests.
///
/// The family has a single marginal-design column and a single log-slope column,
/// zero-width entry/exit/derivative time designs, a one-column score matrix `z`, and
/// no score warp, link deviation, influence absorber, time constraints, or time wiggle.
/// Every per-row quantity is produced by a fixed modular-arithmetic pattern, so two
/// calls with the same `n` yield identical data.
fn make_g_only_flex_family(n: usize) -> SurvivalMarginalSlopeFamily {
    let n_f = n as f64;
    // Bin-centred scatter: `lo + span * (k + 0.5) / n` with `k = (i * mult + add) mod n`.
    let centred = |i: usize, mult: usize, add: usize, lo: f64, span: f64| -> f64 {
        lo + span * (((i * mult + add) % n) as f64 + 0.5) / n_f
    };
    // Same scatter anchored at the left edge of each bin: `lo + span * k / n`.
    let left_edged = |i: usize, mult: usize, add: usize, lo: f64, span: f64| -> f64 {
        lo + span * (((i * mult + add) % n) as f64) / n_f
    };
    // Event indicator: 1.0 when (31 i + 7) mod 5 lands in {3, 4}, otherwise 0.0.
    let event: Array1<f64> = Array1::from_iter((0..n).map(|i| {
        let bucket = (i * 31 + 7) % 5;
        if bucket >= 3 {
            1.0
        } else {
            0.0
        }
    }));
    // Weights cycle through (approximately) 0.5, 0.6, 0.7, 0.8, 0.9.
    let weights: Array1<f64> =
        Array1::from_iter((0..n).map(|i| 0.5 + ((i * 13 + 4) % 5) as f64 * 0.1));
    let z: Array1<f64> = Array1::from_iter((0..n).map(|i| centred(i, 17, 5, -1.0, 2.0)));
    let offset_entry: Array1<f64> =
        Array1::from_iter((0..n).map(|i| centred(i, 11, 3, -0.4, 0.7)));
    let offset_exit: Array1<f64> =
        Array1::from_iter((0..n).map(|i| centred(i, 19, 7, 0.1, 0.6)));
    let derivative_offset_exit: Array1<f64> =
        Array1::from_iter((0..n).map(|i| 0.5 + 0.05 * ((i * 23 + 1) % 3) as f64));
    let marginal_design =
        Array2::from_shape_fn((n, 1), |(i, _)| left_edged(i, 29, 11, 0.3, 0.4));
    let logslope_design =
        Array2::from_shape_fn((n, 1), |(i, _)| left_edged(i, 37, 9, 0.2, 0.5));
    SurvivalMarginalSlopeFamily {
        n,
        event: Arc::new(event),
        weights: Arc::new(weights),
        z: Arc::new(z.insert_axis(Axis(1))),
        score_covariance: MarginalSlopeCovariance::Diagonal(Array1::from(vec![1.0])),
        gaussian_frailty_sd: None,
        derivative_guard: 1e-6,
        design_entry: DesignMatrix::from(Array2::zeros((n, 0))),
        design_exit: DesignMatrix::from(Array2::zeros((n, 0))),
        design_derivative_exit: DesignMatrix::from(Array2::zeros((n, 0))),
        offset_entry: Arc::new(offset_entry),
        offset_exit: Arc::new(offset_exit),
        derivative_offset_exit: Arc::new(derivative_offset_exit),
        marginal_design: DesignMatrix::from(marginal_design),
        logslope_design: DesignMatrix::from(logslope_design),
        logslope_surface_ranges: vec![0..0],
        score_warp: None,
        link_dev: None,
        influence_absorber: None,
        time_linear_constraints: None,
        time_wiggle_knots: None,
        time_wiggle_degree: None,
        time_wiggle_ncols: 0,
        intercept_warm_starts: None,
        auto_subsample_phase_counter: Arc::new(AtomicUsize::new(0)),
        auto_subsample_last_rho: Arc::new(Mutex::new(None)),
    }
}
/// #932: the generic timepoint-input pipeline evaluated with `Jet3` seeds (a single
/// direction) must reproduce, for the g-only flex family, both the hand-coded exact
/// base partials (`compute_survival_timepoint_exact_from_cached`) and the hand-coded
/// directional partials (`compute_survival_timepoint_directional_exact_from_cached`)
/// to a relative tolerance of 1e-6.
#[test]
fn flex_timepoint_inputs_jet3_directional_matches_hand_932() {
    // Asserts `jet[u]` matches `hand[u]` to a relative 1e-6 for every `u < p`.
    fn check_vec(label: &str, p: usize, jet: &[f64], hand: &[f64]) {
        for u in 0..p {
            let (j, h) = (jet[u], hand[u]);
            assert!(
                (j - h).abs() <= 1e-6 * (1.0 + h.abs()),
                "{label}[{u}] jet {} != hand {}",
                j,
                h
            );
        }
    }
    // Asserts the row-major `p x p` jet Hessian matches the dense `hand` matrix to a
    // relative 1e-6, entry by entry.
    fn check_mat(label: &str, p: usize, jet: &[f64], hand: &Array2<f64>) {
        for u in 0..p {
            for v in 0..p {
                let (j, h) = (jet[u * p + v], hand[[u, v]]);
                assert!(
                    (j - h).abs() <= 1e-6 * (1.0 + h.abs()),
                    "{label}[{u},{v}] jet {} != hand {}",
                    j,
                    h
                );
            }
        }
    }

    let n = 16usize;
    let family = make_g_only_flex_family(n);
    let primary = flex_primary_slices(&family);
    let p = primary.total;
    let row = 5usize;
    let g = 0.21_f64;
    let m_beta = 0.15_f64;
    let o_infl = 0.0_f64;
    let q1 = family.offset_exit[row] + family.marginal_design.to_dense()[[row, 0]] * m_beta;
    let exit_slot = Some((row, SurvivalInterceptSlotKind::Exit));
    let (a1, d1) = family
        .solve_row_survival_intercept_with_slot(q1, g, None, None, exit_slot)
        .expect("intercept solve");
    let cached = family
        .build_cached_partition(&primary, a1, g, None, None)
        .expect("cached partition");
    // Direction whose components seed the directional part of each Jet3 below.
    let direction =
        Array1::from_iter((0..p).map(|c| 0.1 + 0.05 * (c as f64) - 0.02 * ((c % 3) as f64)));
    let hand_dir = family
        .compute_survival_timepoint_directional_exact_from_cached(
            row, &primary, q1, primary.q1, a1, g, None, None, &cached, &direction, true,
        )
        .expect("hand directional");
    let (obs_coeff, obs_fixed) =
        observed_fixed_for(&family, &primary, row, a1, g, None, None).expect("obs fixed");
    let cell_inputs = cells_from_cached(&cached);
    let z_obs = family.observed_score_projection(row);
    let d_check = family
        .evaluate_survival_denom_d(a1, g, None, None)
        .expect("denom");
    // Seeds: a constant template (axis usize::MAX), g on its own primary axis, and one
    // seed per primary axis; the last argument is the matching `direction` component.
    let seed_template = Jet3::primary(0.0, usize::MAX, p, 0.0);
    let g_seed = Jet3::primary(g, primary.g, p, direction[primary.g]);
    let unit_seeds: Vec<Jet3> = (0..p)
        .map(|axis| Jet3::primary(0.0, axis, p, direction[axis]))
        .collect();
    let (eta_jet, chi_jet, d_jet) = flex_timepoint_inputs_generic(
        &seed_template,
        &g_seed,
        &unit_seeds,
        a1,
        d_check,
        primary.g,
        primary.infl,
        primary.q1,
        q1,
        z_obs,
        o_infl,
        obs_coeff,
        &obs_fixed,
        &cell_inputs,
    )
    .expect("generic jet3");
    let hand_base = family
        .compute_survival_timepoint_exact_from_cached(
            row, &primary, q1, primary.q1, a1, g, d1, None, None, o_infl, true, &cached,
        )
        .expect("hand base");
    assert!(
        (eta_jet.base.v - hand_base.eta).abs() <= 1e-6 * (1.0 + hand_base.eta.abs()),
        "eta base value {} != hand {}",
        eta_jet.base.v,
        hand_base.eta
    );
    // Base (order-0 in the direction) gradient/Hessian blocks.
    check_vec("eta_u", p, &eta_jet.base.g, hand_base.eta_u.as_slice().unwrap());
    check_mat("eta_uv", p, &eta_jet.base.h, &hand_base.eta_uv);
    check_vec("chi_u", p, &chi_jet.base.g, hand_base.chi_u.as_slice().unwrap());
    check_mat("chi_uv", p, &chi_jet.base.h, &hand_base.chi_uv);
    check_vec("d_u", p, &d_jet.base.g, hand_base.d_u.as_slice().unwrap());
    check_mat("d_uv", p, &d_jet.base.h, &hand_base.d_uv);
    // Directional (eps) gradient/Hessian blocks.
    check_vec("eta_u_dir", p, &eta_jet.eps.g, hand_dir.eta_u_dir.as_slice().unwrap());
    check_mat("eta_uv_dir", p, &eta_jet.eps.h, &hand_dir.eta_uv_dir);
    check_vec("chi_u_dir", p, &chi_jet.eps.g, hand_dir.chi_u_dir.as_slice().unwrap());
    check_mat("chi_uv_dir", p, &chi_jet.eps.h, &hand_dir.chi_uv_dir);
    check_vec("d_u_dir", p, &d_jet.eps.g, hand_dir.d_u_dir.as_slice().unwrap());
    check_mat("d_uv_dir", p, &d_jet.eps.h, &hand_dir.d_uv_dir);
}
/// #932: with two seed directions (`dir1`, `dir2`), the `Jet4` evaluation of the
/// generic timepoint-input pipeline must reproduce the hand-coded bidirectional exact
/// blocks `eta_uv_uv`, `chi_uv_uv`, and `d_uv_uv` for the g-only flex family to a
/// relative tolerance of 1e-6. Only the `eps_del` Hessian part of each jet is compared.
#[test]
fn flex_timepoint_inputs_jet4_bidirectional_matches_hand_932() {
    // Entrywise relative (1e-6) comparison of a row-major `p x p` jet Hessian against
    // the dense hand-computed matrix.
    fn check_mat(label: &str, p: usize, jet: &[f64], hand: &Array2<f64>) {
        for u in 0..p {
            for v in 0..p {
                let (j, h) = (jet[u * p + v], hand[[u, v]]);
                assert!(
                    (j - h).abs() <= 1e-6 * (1.0 + h.abs()),
                    "{label}[{u},{v}] jet {} != hand {}",
                    j,
                    h
                );
            }
        }
    }

    let n = 16usize;
    let family = make_g_only_flex_family(n);
    let primary = flex_primary_slices(&family);
    let p = primary.total;
    let row = 7usize;
    let g = 0.18_f64;
    let m_beta = 0.15_f64;
    let o_infl = 0.0_f64;
    let q1 = family.offset_exit[row] + family.marginal_design.to_dense()[[row, 0]] * m_beta;
    let exit_slot = Some((row, SurvivalInterceptSlotKind::Exit));
    let (a1, _d1) = family
        .solve_row_survival_intercept_with_slot(q1, g, None, None, exit_slot)
        .expect("intercept solve");
    let cached = family
        .build_cached_partition(&primary, a1, g, None, None)
        .expect("cached partition");
    // The two seed directions (eps and del components of each Jet4).
    let dir1 =
        Array1::from_iter((0..p).map(|c| 0.12 + 0.04 * (c as f64) - 0.01 * ((c % 2) as f64)));
    let dir2 =
        Array1::from_iter((0..p).map(|c| -0.07 + 0.05 * ((c % 3) as f64) + 0.02 * (c as f64)));
    let hand = family
        .compute_survival_timepoint_bidirectional_exact_from_cached(
            row, &primary, q1, primary.q1, a1, g, None, None, &cached, &dir1, &dir2,
        )
        .expect("hand bidirectional");
    let (obs_coeff, obs_fixed) =
        observed_fixed_for(&family, &primary, row, a1, g, None, None).expect("obs fixed");
    let cell_inputs = cells_from_cached(&cached);
    let z_obs = family.observed_score_projection(row);
    let d_check = family
        .evaluate_survival_denom_d(a1, g, None, None)
        .expect("denom");
    let seed_template = Jet4::primary(0.0, usize::MAX, p, 0.0, 0.0);
    let g_seed = Jet4::primary(g, primary.g, p, dir1[primary.g], dir2[primary.g]);
    let unit_seeds: Vec<Jet4> = (0..p)
        .map(|axis| Jet4::primary(0.0, axis, p, dir1[axis], dir2[axis]))
        .collect();
    let (eta_jet, chi_jet, d_jet) = flex_timepoint_inputs_generic(
        &seed_template,
        &g_seed,
        &unit_seeds,
        a1,
        d_check,
        primary.g,
        primary.infl,
        primary.q1,
        q1,
        z_obs,
        o_infl,
        obs_coeff,
        &obs_fixed,
        &cell_inputs,
    )
    .expect("generic jet4");
    for (label, jet, hand_block) in [
        ("eta_uv_uv", &eta_jet.eps_del.h, &hand.eta_uv_uv),
        ("chi_uv_uv", &chi_jet.eps_del.h, &hand.chi_uv_uv),
        ("d_uv_uv", &d_jet.eps_del.h, &hand.d_uv_uv),
    ] {
        check_mat(label, p, jet, hand_block);
    }
}
/// Builds the small deviation runtime shared by the #932 flex tests.
///
/// The seed is the three points {-1, 0, 1}. The block is cubic (`degree: 3`) with one
/// internal knot, penalty orders 1 to 3, no double penalty, and a 1e-4 monotonicity
/// epsilon. Panics if the builder rejects this configuration.
fn flex_test_deviation_runtime() -> DeviationRuntime {
    let seed_points = Array1::from(vec![-1.0, 0.0, 1.0]);
    let block_config = DeviationBlockConfig {
        degree: 3,
        num_internal_knots: 1,
        penalty_order: 2,
        penalty_orders: vec![1, 2, 3],
        double_penalty: false,
        monotonicity_eps: 1e-4,
    };
    let built = build_score_warp_deviation_block_from_seed(&seed_points, &block_config)
        .expect("build test deviation runtime");
    built.runtime
}
/// Builds the g-only flex family of `make_g_only_flex_family(n)` and additionally
/// enables both optional deviation blocks, `score_warp` and `link_dev`. Each block gets
/// its own runtime from `flex_test_deviation_runtime()`.
fn make_ghw_flex_family(n: usize) -> SurvivalMarginalSlopeFamily {
    let mut ghw_family = make_g_only_flex_family(n);
    let (warp_runtime, link_runtime) =
        (flex_test_deviation_runtime(), flex_test_deviation_runtime());
    ghw_family.score_warp = Some(warp_runtime);
    ghw_family.link_dev = Some(link_runtime);
    ghw_family
}
/// #932 regression for the family with both deviation blocks enabled
/// (`make_ghw_flex_family`), at fixed `beta_h` / `beta_w` coefficients.
///
/// * `Jet3` (one seed direction `dir`) through `flex_timepoint_inputs_generic` must
///   match the hand-coded exact base partials and directional partials to a relative
///   1e-6.
/// * `Jet4` (two seed directions `dir1`, `dir2`) mixed fourth-order blocks `*_uv_uv`
///   must match a scalar central-difference oracle to a relative 1e-3. At every
///   perturbed parameter vector, the oracle re-solves the row intercept and
///   re-evaluates the observed cell partials and the denominator.
/// * As a localisation probe, the intercept jet's `a_uv_uv[q1, q1]` from
///   `lift_intercept_flex` is checked against the same oracle before the `*_uv_uv`
///   comparisons.
#[test]
fn flex_timepoint_inputs_ghw_jet3_jet4_match_hand_932() {
    // Relative (1e-6) comparison of the first `p` entries of a jet gradient against
    // the hand-computed vector.
    fn cmp_vec(label: &str, p: usize, jet: &[f64], hand: &[f64]) {
        for u in 0..p {
            assert!(
                (jet[u] - hand[u]).abs() <= 1e-6 * (1.0 + hand[u].abs()),
                "{label}[{u}] jet {} != hand {}",
                jet[u],
                hand[u]
            );
        }
    }
    // Relative (1e-6) entrywise comparison of a row-major `p x p` jet Hessian against
    // the dense hand-computed matrix.
    fn cmp_mat(label: &str, p: usize, jet: &[f64], hand: &Array2<f64>) {
        for u in 0..p {
            for v in 0..p {
                assert!(
                    (jet[u * p + v] - hand[[u, v]]).abs() <= 1e-6 * (1.0 + hand[[u, v]].abs()),
                    "{label}[{u},{v}] jet {} != hand {}",
                    jet[u * p + v],
                    hand[[u, v]]
                );
            }
        }
    }
    // Looser (1e-3) comparison against the finite-difference oracle. It collects every
    // failing entry first so that a single panic reports the whole mismatch pattern.
    fn cmp_mat_oracle(label: &str, p: usize, jet: &[f64], oracle: &Array2<f64>) {
        let mut fails: Vec<String> = Vec::new();
        for u in 0..p {
            for v in 0..p {
                let o = oracle[[u, v]];
                let j = jet[u * p + v];
                if (j - o).abs() > 1e-3 * (1.0 + o.abs()) {
                    let rel = (j - o).abs() / (1.0 + o.abs());
                    fails.push(format!("[{u},{v}] jet {j:.6} oracle {o:.6} rel {rel:.2e}"));
                }
            }
        }
        assert!(
            fails.is_empty(),
            "{label} jet != scalar-FD oracle at {} entr{}: {}",
            fails.len(),
            if fails.len() == 1 { "y" } else { "ies" },
            fails.join("; "),
        );
    }

    let n = 16usize;
    let family = make_ghw_flex_family(n);
    let primary = flex_primary_slices(&family);
    let p = primary.total;
    let row = 6usize;
    let g = 0.2_f64;
    let h_len = primary.h.as_ref().map(|r| r.len()).unwrap_or(0);
    let w_len = primary.w.as_ref().map(|r| r.len()).unwrap_or(0);
    let beta_h = Array1::from_iter(
        (0..h_len).map(|i| 0.1 + 0.05 * (i as f64) - 0.02 * ((i % 2) as f64)),
    );
    let beta_w = Array1::from_iter(
        (0..w_len).map(|i| -0.08 + 0.04 * (i as f64) + 0.01 * ((i % 3) as f64)),
    );
    let bh = Some(&beta_h);
    let bw = Some(&beta_w);
    let m_beta = 0.15_f64;
    let q1 = family.offset_exit[row] + family.marginal_design.to_dense()[[row, 0]] * m_beta;
    let o_infl = 0.0_f64;
    let solved = family
        .solve_row_survival_intercept_with_slot(
            q1,
            g,
            bh,
            bw,
            Some((row, SurvivalInterceptSlotKind::Exit)),
        )
        .expect("intercept solve");
    let a1 = solved.0;
    let d1 = solved.1;
    let cached = family
        .build_cached_partition(&primary, a1, g, bh, bw)
        .expect("cached partition");
    let (obs_coeff, obs_fixed) =
        observed_fixed_for(&family, &primary, row, a1, g, bh, bw).expect("obs fixed");
    let cells = cells_from_cached(&cached);
    let z_obs = family.observed_score_projection(row);
    let d_check = family
        .evaluate_survival_denom_d(a1, g, bh, bw)
        .expect("denom");

    // Bidirectional seeds shared by the finite-difference oracle and by the Jet4
    // evaluation at the end of the test. They are defined exactly once so that both
    // sides are guaranteed to differentiate along the same (dir1, dir2) pair.
    let dir1 =
        Array1::from_iter((0..p).map(|c| 0.12 + 0.04 * (c as f64) - 0.01 * ((c % 2) as f64)));
    let dir2 =
        Array1::from_iter((0..p).map(|c| -0.07 + 0.05 * ((c % 3) as f64) + 0.02 * (c as f64)));
    let template4 = Jet4::primary(0.0, usize::MAX, p, 0.0, 0.0);
    let b_jet4 = Jet4::primary(g, primary.g, p, dir1[primary.g], dir2[primary.g]);
    let du4: Vec<Jet4> = (0..p)
        .map(|u| Jet4::primary(0.0, u, p, dir1[u], dir2[u]))
        .collect();

    let (oracle_eta_uvuv, oracle_chi_uvuv, oracle_d_uvuv) = {
        // Scalar oracle at parameters shifted by `pert`. It returns the re-solved intercept
        // `a`, `eta` (observed coefficients evaluated at z_obs, plus o_infl), `chi` (their
        // a-derivative evaluated at z_obs) and the denominator `d`.
        let scalars_of = |pert: &Array1<f64>| -> (f64, f64, f64, f64) {
            let q1_pert = q1 + pert[primary.q1];
            let g_pert = g + pert[primary.g];
            let bh_pert: Array1<f64> = Array1::from_iter(
                (0..h_len).map(|i| beta_h[i] + pert[primary.h.as_ref().unwrap().start + i]),
            );
            let bw_pert: Array1<f64> = Array1::from_iter(
                (0..w_len).map(|i| beta_w[i] + pert[primary.w.as_ref().unwrap().start + i]),
            );
            let a_pert = family
                .solve_row_survival_intercept_with_slot(
                    q1_pert,
                    g_pert,
                    Some(&bh_pert),
                    Some(&bw_pert),
                    None,
                )
                .expect("oracle intercept solve")
                .0;
            let obs = family
                .observed_denested_cell_partials(
                    row,
                    a_pert,
                    g_pert,
                    Some(&bh_pert),
                    Some(&bw_pert),
                )
                .expect("oracle observed partials");
            let d_pert = family
                .evaluate_survival_denom_d(a_pert, g_pert, Some(&bh_pert), Some(&bw_pert))
                .expect("oracle denom");
            (
                a_pert,
                eval_coeff4_at(&obs.coeff, z_obs) + o_infl,
                eval_coeff4_at(&obs.dc_da, z_obs),
                d_pert,
            )
        };
        // Step sizes: `hq` for the (u, v) coordinate stencil and `ht` for the
        // (dir1, dir2) stencil.
        let hq = 2.0e-3_f64;
        let ht = 3.0e-3_f64;
        // Perturbation t1*dir1 + t2*dir2, plus `su` on axis u and `sv` on axis v.
        let pert_vec =
            |su: f64, u: usize, sv: f64, v: usize, t1: f64, t2: f64| -> Array1<f64> {
                let mut pert = &dir1 * t1 + &dir2 * t2;
                pert[u] += su;
                pert[v] += sv;
                pert
            };
        // Central-difference estimate of d^2/(du dv) d^2/(dt1 dt2) for all four scalars.
        // The inner stencil is 3-point on the diagonal and 4-point off it; the outer
        // stencil is the 4-point mixed (t1, t2) stencil.
        let mixed = |u: usize, v: usize| -> (f64, f64, f64, f64) {
            let acc =
                |w: f64, su: f64, sv: f64, t1: f64, t2: f64, out: &mut (f64, f64, f64, f64)| {
                    let s = scalars_of(&pert_vec(su, u, sv, v, t1, t2));
                    out.0 += w * s.0;
                    out.1 += w * s.1;
                    out.2 += w * s.2;
                    out.3 += w * s.3;
                };
            let hess_uv = |t1: f64, t2: f64| -> (f64, f64, f64, f64) {
                let mut o = (0.0, 0.0, 0.0, 0.0);
                if u == v {
                    acc(1.0, hq, 0.0, t1, t2, &mut o);
                    acc(-2.0, 0.0, 0.0, t1, t2, &mut o);
                    acc(1.0, -hq, 0.0, t1, t2, &mut o);
                    let inv = 1.0 / (hq * hq);
                    (o.0 * inv, o.1 * inv, o.2 * inv, o.3 * inv)
                } else {
                    acc(1.0, hq, hq, t1, t2, &mut o);
                    acc(-1.0, hq, -hq, t1, t2, &mut o);
                    acc(-1.0, -hq, hq, t1, t2, &mut o);
                    acc(1.0, -hq, -hq, t1, t2, &mut o);
                    let inv = 1.0 / (4.0 * hq * hq);
                    (o.0 * inv, o.1 * inv, o.2 * inv, o.3 * inv)
                }
            };
            let a = hess_uv(ht, ht);
            let b = hess_uv(ht, -ht);
            let c = hess_uv(-ht, ht);
            let d = hess_uv(-ht, -ht);
            let inv = 1.0 / (4.0 * ht * ht);
            (
                (a.0 - b.0 - c.0 + d.0) * inv,
                (a.1 - b.1 - c.1 + d.1) * inv,
                (a.2 - b.2 - c.2 + d.2) * inv,
                (a.3 - b.3 - c.3 + d.3) * inv,
            )
        };
        // Intercept jet from the same seeds that are used for the Jet4 comparison below.
        let residual_probe = |a: &Jet4| {
            calibration_residual_jet(a, &b_jet4, primary.g, &du4, primary.q1, q1, &cells)
        };
        let a_jet_probe = lift_intercept_flex(&template4, a1, 1.0 / d_check, 4, residual_probe);
        let jet_a_uvuv = a_jet_probe.eps_del.h[primary.q1 * p + primary.q1];
        let mut o_eta = Array2::<f64>::zeros((p, p));
        let mut o_chi = Array2::<f64>::zeros((p, p));
        let mut o_d = Array2::<f64>::zeros((p, p));
        let mut ref_a_uvuv = 0.0_f64;
        // Fill the upper triangle by finite differences and mirror it.
        for u in 0..p {
            for v in u..p {
                let (a_uvuv, eta_uvuv, chi_uvuv, d_uvuv) = mixed(u, v);
                o_eta[[u, v]] = eta_uvuv;
                o_eta[[v, u]] = eta_uvuv;
                o_chi[[u, v]] = chi_uvuv;
                o_chi[[v, u]] = chi_uvuv;
                o_d[[u, v]] = d_uvuv;
                o_d[[v, u]] = d_uvuv;
                if u == primary.q1 && v == primary.q1 {
                    ref_a_uvuv = a_uvuv;
                }
            }
        }
        assert!(
            (jet_a_uvuv - ref_a_uvuv).abs() <= 1e-3 * (1.0 + ref_a_uvuv.abs()),
            "#932 PROBE a_uv_uv[q1,q1]: jet {jet_a_uvuv} != scalar-FD {ref_a_uvuv} \
(diff {})",
            jet_a_uvuv - ref_a_uvuv,
        );
        (o_eta, o_chi, o_d)
    };

    // --- Jet3: base and directional partials versus the hand-coded exact values. ---
    let dir =
        Array1::from_iter((0..p).map(|c| 0.1 + 0.05 * (c as f64) - 0.02 * ((c % 3) as f64)));
    let hand_dir = family
        .compute_survival_timepoint_directional_exact_from_cached(
            row, &primary, q1, primary.q1, a1, g, bh, bw, &cached, &dir, true,
        )
        .expect("hand directional");
    let base = family
        .compute_survival_timepoint_exact_from_cached(
            row, &primary, q1, primary.q1, a1, g, d1, bh, bw, o_infl, true, &cached,
        )
        .expect("hand base");
    let template3 = Jet3::primary(0.0, usize::MAX, p, 0.0);
    let b_jet3 = Jet3::primary(g, primary.g, p, dir[primary.g]);
    let du3: Vec<Jet3> = (0..p).map(|u| Jet3::primary(0.0, u, p, dir[u])).collect();
    let (eta3, chi3, d3) = flex_timepoint_inputs_generic(
        &template3,
        &b_jet3,
        &du3,
        a1,
        d_check,
        primary.g,
        primary.infl,
        primary.q1,
        q1,
        z_obs,
        o_infl,
        obs_coeff,
        &obs_fixed,
        &cells,
    )
    .expect("generic jet3");
    cmp_vec("eta_u", p, &eta3.base.g, base.eta_u.as_slice().unwrap());
    cmp_mat("eta_uv", p, &eta3.base.h, &base.eta_uv);
    cmp_vec("chi_u", p, &chi3.base.g, base.chi_u.as_slice().unwrap());
    cmp_mat("chi_uv", p, &chi3.base.h, &base.chi_uv);
    cmp_vec("d_u", p, &d3.base.g, base.d_u.as_slice().unwrap());
    cmp_mat("d_uv", p, &d3.base.h, &base.d_uv);
    cmp_vec("eta_u_dir", p, &eta3.eps.g, hand_dir.eta_u_dir.as_slice().unwrap());
    cmp_mat("eta_uv_dir", p, &eta3.eps.h, &hand_dir.eta_uv_dir);
    cmp_vec("chi_u_dir", p, &chi3.eps.g, hand_dir.chi_u_dir.as_slice().unwrap());
    cmp_mat("chi_uv_dir", p, &chi3.eps.h, &hand_dir.chi_uv_dir);
    cmp_vec("d_u_dir", p, &d3.eps.g, hand_dir.d_u_dir.as_slice().unwrap());
    cmp_mat("d_uv_dir", p, &d3.eps.h, &hand_dir.d_uv_dir);

    // --- Jet4: mixed fourth-order blocks versus the finite-difference oracle. ---
    // Reuses the seeds defined before the oracle, so the directions cannot drift apart.
    let (eta4, chi4, d4) = flex_timepoint_inputs_generic(
        &template4,
        &b_jet4,
        &du4,
        a1,
        d_check,
        primary.g,
        primary.infl,
        primary.q1,
        q1,
        z_obs,
        o_infl,
        obs_coeff,
        &obs_fixed,
        &cells,
    )
    .expect("generic jet4");
    cmp_mat_oracle("eta_uv_uv", p, &eta4.eps_del.h, &oracle_eta_uvuv);
    cmp_mat_oracle("chi_uv_uv", p, &chi4.eps_del.h, &oracle_chi_uvuv);
    cmp_mat_oracle("d_uv_uv", p, &d4.eps_del.h, &oracle_d_uvuv);
}
}