use crate::families::jet_tower::Tower4;
use ndarray::{Array1, Array2, ArrayView1};
/// Jet type used to differentiate basis columns with respect to γ.
///
/// `row_jet` seeds it with `GJet::variable(gamma, 0)` and reads `.g[0]` / `.h[0][0]` as the first
/// and second γ-derivatives. NOTE(review): the const argument `1` presumably is the number of
/// independent variables — confirm against `Tower4`.
type GJet = Tower4<1>;
/// Parameters of the closure basis: an intercept column plus `harmonics` cos/sin column pairs.
#[derive(Clone, Debug)]
pub struct ClosureFamily {
    /// Number of harmonic pairs; the basis has `1 + 2 * harmonics` columns (see `raw_dim`).
    pub harmonics: usize,
    /// NOTE(review): stored by `new` but not read by any method in this file — confirm its role
    /// (e.g. a period used elsewhere).
    pub window: f64,
}
/// Value and first four derivatives of `cos` at `theta`.
///
/// Differentiating cosine cycles with period four (cos → −sin → −cos → sin → cos), so the ladder
/// is one lap of that cycle with the first entry repeated at the end.
#[inline]
fn cos_stack(theta: f64) -> [f64; 5] {
    let (sin_t, cos_t) = theta.sin_cos();
    let cycle = [cos_t, -sin_t, -cos_t, sin_t];
    [cycle[0], cycle[1], cycle[2], cycle[3], cycle[0]]
}
/// Value and first four derivatives of `sin` at `theta`.
///
/// Sine's derivatives cycle with period four (sin → cos → −sin → −cos → sin), so the ladder is
/// one lap of that cycle with the first entry repeated at the end.
#[inline]
fn sin_stack(theta: f64) -> [f64; 5] {
    let (sin_t, cos_t) = theta.sin_cos();
    let cycle = [sin_t, cos_t, -sin_t, -cos_t];
    [cycle[0], cycle[1], cycle[2], cycle[3], cycle[0]]
}
impl ClosureFamily {
pub fn new(harmonics: usize, window: f64) -> Self {
Self { harmonics, window }
}
#[inline]
pub fn raw_dim(&self) -> usize {
1 + 2 * self.harmonics
}
pub fn row_jet(&self, s: f64, gamma: f64) -> (Array1<f64>, Array1<f64>, Array1<f64>) {
let d = self.raw_dim();
let mut value = Array1::zeros(d);
let mut dg = Array1::zeros(d);
let mut dgg = Array1::zeros(d);
value[0] = 1.0;
let g = GJet::variable(gamma, 0);
for m in 1..=self.harmonics {
let theta = g.scale(m as f64 * s);
let cos_col = theta.compose_unary(cos_stack(theta.v));
let sin_col = theta.compose_unary(sin_stack(theta.v));
let ci = 2 * m - 1;
let si = 2 * m;
value[ci] = cos_col.v;
dg[ci] = cos_col.g[0];
dgg[ci] = cos_col.h[0][0];
value[si] = sin_col.v;
dg[si] = sin_col.g[0];
dgg[si] = sin_col.h[0][0];
}
(value, dg, dgg)
}
pub fn design(&self, s: ArrayView1<'_, f64>, gamma: f64) -> Array2<f64> {
let n = s.len();
let d = self.raw_dim();
let mut phi = Array2::zeros((n, d));
for (i, &si) in s.iter().enumerate() {
let (v, _, _) = self.row_jet(si, gamma);
phi.row_mut(i).assign(&v);
}
phi
}
pub fn design_jet(
&self,
s: ArrayView1<'_, f64>,
gamma: f64,
) -> (Array2<f64>, Array2<f64>, Array2<f64>) {
let n = s.len();
let d = self.raw_dim();
let mut phi = Array2::zeros((n, d));
let mut dphi = Array2::zeros((n, d));
let mut ddphi = Array2::zeros((n, d));
for (i, &si) in s.iter().enumerate() {
let (v, dv, ddv) = self.row_jet(si, gamma);
phi.row_mut(i).assign(&v);
dphi.row_mut(i).assign(&dv);
ddphi.row_mut(i).assign(&ddv);
}
(phi, dphi, ddphi)
}
}
/// Smoothstep conductance `c(γ) = 3γ² − 2γ³` with its first and second derivatives.
///
/// `gamma` is clamped to `[0, 1]`, so `c(0) = 0`, `c(1) = 1` and `c'(0) = c'(1) = 0`.
/// Strictly outside `[0, 1]` the clamped map is constant, so both derivatives are reported as 0.
/// Exactly at the endpoints, the interior polynomial derivatives are returned.
///
/// Returns `(c, dc/dγ, d²c/dγ²)`. A NaN `gamma` propagates NaN into all three.
pub fn boundary_conductance(gamma: f64) -> (f64, f64, f64) {
    let g = gamma.clamp(0.0, 1.0);
    let c = 3.0 * g * g - 2.0 * g * g * g;
    // Explicit comparisons (rather than a range test) so a NaN gamma falls through and stays NaN.
    if gamma < 0.0 || gamma > 1.0 {
        return (c, 0.0, 0.0);
    }
    let cp = 6.0 * g - 6.0 * g * g;
    let cpp = 6.0 - 12.0 * g;
    (c, cp, cpp)
}
/// Builds `S(γ) = S_open + c(γ) · S_wrap` and its first two γ-derivatives.
///
/// `c` is [`boundary_conductance`]. Because `S_wrap` enters linearly, the derivatives are just
/// `c'(γ) · S_wrap` and `c''(γ) · S_wrap`.
///
/// Returns `(S, dS/dγ, d²S/dγ²)`. NOTE(review): the two penalties are presumably the same shape.
/// If they differ, ndarray's `+` decides whether they broadcast or panic.
pub fn conductance_penalty_jet(
    s_open: &Array2<f64>,
    s_wrap: &Array2<f64>,
    gamma: f64,
) -> (Array2<f64>, Array2<f64>, Array2<f64>) {
    let (c, dc, ddc) = boundary_conductance(gamma);
    let wrap_times = |factor: f64| s_wrap * factor;
    let penalty = s_open + &wrap_times(c);
    (penalty, wrap_times(dc), wrap_times(ddc))
}
/// Summary of a profile confidence interval for γ, as produced by `profile_ci_from_grid`.
#[derive(Clone, Copy, Debug)]
pub struct ClosureProfileCi {
    /// Grid γ with the smallest profile value (the earliest such point on ties).
    pub gamma_hat: f64,
    /// Lower end of the interval, clamped to `[0, 1]`.
    pub ci_lo: f64,
    /// Upper end of the interval, clamped to `[0, 1]`.
    pub ci_hi: f64,
    /// `ci_hi` reaches γ = 1 (the circle end), within 1e-9.
    pub ci_includes_circle: bool,
    /// `ci_lo` reaches γ = 0, within 1e-9.
    pub ci_includes_interval: bool,
    /// The minimum lies on the γ = 0 boundary while the grid also has interior points.
    pub singular_boundary: bool,
}
/// Quantile of the χ² distribution with one degree of freedom at probability `level`.
///
/// χ²₁ is the law of `Z²` for standard normal `Z`. Its `level` quantile is therefore the square
/// of the normal quantile at `(1 + level) / 2`.
fn chi2_1_quantile(level: f64) -> f64 {
    let p = 0.5 * (1.0 + level);
    let z = inv_std_normal(p);
    z * z
}
/// Profile confidence interval for γ from a tabulated profile.
///
/// `grid` holds `(γ, V(γ))` pairs, and `V` is treated as an objective to minimise. The retained
/// set is `{γ : V(γ) − V_min ≤ χ²₁(level) / 2}`. NOTE(review): that is the Wilks cut when `V` is
/// a negative log-likelihood; confirm with callers.
///
/// The interval spans the outermost grid points in that set. Wherever set membership changes
/// between consecutive grid entries, the edge is refined by linear interpolation. The result is
/// clamped to `[0, 1]`. Entries are assumed ordered by γ, since interpolation only happens between
/// neighbours. If the retained set is disconnected, the interval covers all of its pieces.
///
/// # Errors
///
/// Returns `Err` if:
/// - `grid` has fewer than two points;
/// - `level` is NaN or outside `[0, 1]`;
/// - any γ or `V` entry is non-finite.
pub fn profile_ci_from_grid(grid: &[(f64, f64)], level: f64) -> Result<ClosureProfileCi, String> {
    if grid.len() < 2 {
        return Err("closure profile CI needs at least two grid points".into());
    }
    // `contains` is false for NaN, so this rejects NaN as well. Out-of-range levels would
    // otherwise be fed into the normal quantile and produce a meaningless cut-off.
    if !(0.0..=1.0).contains(&level) {
        return Err(format!(
            "closure profile CI level must lie in [0, 1], got {level}"
        ));
    }
    if grid.iter().any(|&(g, v)| !g.is_finite() || !v.is_finite()) {
        return Err("closure profile grid has non-finite entries".into());
    }

    // Earliest grid point attaining the minimum (strict `<` keeps the first on ties).
    let (gamma_hat, v_min) = grid
        .iter()
        .fold(grid[0], |best, &pt| if pt.1 < best.1 { pt } else { best });

    let half_chi2 = 0.5 * chi2_1_quantile(level);
    let threshold = v_min + half_chi2;
    // Small slack so points sitting on the cut-off are not lost to rounding.
    let in_set = |v: f64| v - v_min <= half_chi2 + 1e-12;

    let mut ci_lo = gamma_hat;
    let mut ci_hi = gamma_hat;
    let mut widen = |g: f64| {
        ci_lo = ci_lo.min(g);
        ci_hi = ci_hi.max(g);
    };
    for &(g, v) in grid {
        if in_set(v) {
            widen(g);
        }
    }
    for pair in grid.windows(2) {
        let ((g0, v0), (g1, v1)) = (pair[0], pair[1]);
        // Membership differs, so v0 != v1 and the division is well defined.
        if in_set(v0) != in_set(v1) {
            let t = ((threshold - v0) / (v1 - v0)).clamp(0.0, 1.0);
            widen(g0 + t * (g1 - g0));
        }
    }
    let ci_lo = ci_lo.clamp(0.0, 1.0);
    let ci_hi = ci_hi.clamp(0.0, 1.0);

    let ci_includes_circle = ci_hi >= 1.0 - 1e-9;
    let ci_includes_interval = ci_lo <= 1e-9;
    // Every interior V is >= v_min (it is the grid minimum), so the test reduces to this: the
    // minimum sits at γ = 0 and the grid actually probes the interior.
    let singular_boundary = gamma_hat <= 1e-9 && grid.iter().any(|&(g, _)| g > 1e-9);

    Ok(ClosureProfileCi {
        gamma_hat,
        ci_lo,
        ci_hi,
        ci_includes_circle,
        ci_includes_interval,
        singular_boundary,
    })
}
/// Inverse of the standard normal CDF: returns `x` with `Φ(x) = p`.
///
/// The initial guess comes from Peter Acklam's rational approximation. The central region
/// `P_LOW ≤ p ≤ 1 − P_LOW` uses a rational function of `p − 0.5`; both tails use a rational
/// function of `sqrt(-2 ln(tail prob))`. One Halley correction step then polishes the guess,
/// evaluating `Φ` via `libm::erfc`.
///
/// Returns `-∞` for `p ≤ 0` and `+∞` for `p ≥ 1`. A NaN `p` passes both guards, reaches the
/// upper-tail branch, and yields NaN.
fn inv_std_normal(p: f64) -> f64 {
    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }
    // Central-region numerator (A) and denominator (B) coefficients.
    const A: [f64; 6] = [
        -3.969_683_028_665_376e1,
        2.209_460_984_245_205e2,
        -2.759_285_104_469_687e2,
        1.383_577_518_672_690e2,
        -3.066_479_806_614_716e1,
        2.506_628_277_459_239e0,
    ];
    const B: [f64; 5] = [
        -5.447_609_879_822_406e1,
        1.615_858_368_580_409e2,
        -1.556_989_798_598_866e2,
        6.680_131_188_771_972e1,
        -1.328_068_155_288_572e1,
    ];
    // Tail-region numerator (C) and denominator (D) coefficients.
    const C: [f64; 6] = [
        -7.784_894_002_430_293e-3,
        -3.223_964_580_411_365e-1,
        -2.400_758_277_161_838e0,
        -2.549_732_539_343_734e0,
        4.374_664_141_464_968e0,
        2.938_163_982_698_783e0,
    ];
    const D: [f64; 4] = [
        7.784_695_709_041_462e-3,
        3.224_671_290_700_398e-1,
        2.445_134_137_142_996e0,
        3.754_408_661_907_416e0,
    ];
    // Boundary between the tail and central approximations.
    const P_LOW: f64 = 0.024_25;
    let x = if p < P_LOW {
        // Lower tail.
        let q = (-2.0 * p.ln()).sqrt();
        (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    } else if p <= 1.0 - P_LOW {
        // Central region: odd rational function of q = p - 0.5, evaluated in r = q².
        let q = p - 0.5;
        let r = q * q;
        (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
            / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
    } else {
        // Upper tail: mirror of the lower tail using 1 - p.
        let q = (-2.0 * (1.0 - p).ln()).sqrt();
        -(((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
            / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
    };
    // Halley refinement step:
    // - e = Φ(x) - p, with Φ(x) = erfc(-x/√2) / 2;
    // - u = e / φ(x), where 1/φ(x) = √(2π)·exp(x²/2);
    // - x ← x - u / (1 + x·u/2).
    let e = 0.5 * libm::erfc(-x / std::f64::consts::SQRT_2) - p;
    let u = e * (2.0 * std::f64::consts::PI).sqrt() * (0.5 * x * x).exp();
    x - u / (1.0 + 0.5 * x * u)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Tabulates `profile` on the 101-point grid γ = 0, 0.01, …, 1.
    fn unit_grid(profile: impl Fn(f64) -> f64) -> Vec<(f64, f64)> {
        (0..=100)
            .map(|k| k as f64 / 100.0)
            .map(|g| (g, profile(g)))
            .collect()
    }

    #[test]
    fn gamma_one_is_circle_basis() {
        let fam = ClosureFamily::new(2, std::f64::consts::TAU);
        let s = 1.3_f64;
        let (v, _, _) = fam.row_jet(s, 1.0);
        assert!((v[0] - 1.0).abs() < 1e-15);
        assert!((v[1] - s.cos()).abs() < 1e-14);
        assert!((v[2] - s.sin()).abs() < 1e-14);
        assert!((v[3] - (2.0 * s).cos()).abs() < 1e-14);
        assert!((v[4] - (2.0 * s).sin()).abs() < 1e-14);
    }

    #[test]
    fn basis_gamma_jet_matches_fd() {
        let fam = ClosureFamily::new(3, std::f64::consts::TAU);
        let s = 0.8_f64;
        for &g0 in &[1.0_f64, 0.5, 0.05, 1e-6] {
            let (_, dg, dgg) = fam.row_jet(s, g0);
            let h = 1e-5;
            let (vp, _, _) = fam.row_jet(s, g0 + h);
            let (vm, _, _) = fam.row_jet(s, g0 - h);
            let (v0, _, _) = fam.row_jet(s, g0);
            for j in 0..fam.raw_dim() {
                let fd1 = (vp[j] - vm[j]) / (2.0 * h);
                let fd2 = (vp[j] - 2.0 * v0[j] + vm[j]) / (h * h);
                assert!(
                    (dg[j] - fd1).abs() < 1e-5,
                    "d/dγ col {j} at γ={g0}: analytic {} fd {fd1}",
                    dg[j]
                );
                assert!(
                    (dgg[j] - fd2).abs() < 1e-3,
                    "d²/dγ² col {j} at γ={g0}: analytic {} fd {fd2}",
                    dgg[j]
                );
            }
        }
    }

    #[test]
    fn conductance_endpoints_and_flat() {
        let (c0, cp0, _) = boundary_conductance(0.0);
        let (c1, cp1, _) = boundary_conductance(1.0);
        assert!(c0.abs() < 1e-15 && (c1 - 1.0).abs() < 1e-15);
        assert!(cp0.abs() < 1e-15 && cp1.abs() < 1e-15);
    }

    #[test]
    fn conductance_penalty_interpolates() {
        let s_open = array![[2.0, -1.0], [-1.0, 2.0]];
        let s_wrap = array![[1.0, -1.0], [-1.0, 1.0]];
        let (s0, _, _) = conductance_penalty_jet(&s_open, &s_wrap, 0.0);
        let (s1, _, _) = conductance_penalty_jet(&s_open, &s_wrap, 1.0);
        assert!((&s0 - &s_open).iter().all(|v| v.abs() < 1e-14));
        let circle = &s_open + &s_wrap;
        assert!((&s1 - &circle).iter().all(|v| v.abs() < 1e-14));
        let g = 0.4;
        let (_, ds, _) = conductance_penalty_jet(&s_open, &s_wrap, g);
        let h = 1e-6;
        let (sp, _, _) = conductance_penalty_jet(&s_open, &s_wrap, g + h);
        let (sm, _, _) = conductance_penalty_jet(&s_open, &s_wrap, g - h);
        let fd = (&sp - &sm).mapv(|v| v / (2.0 * h));
        assert!((&ds - &fd).iter().all(|v| v.abs() < 1e-6));
    }

    #[test]
    fn profile_ci_interior_minimum() {
        // Quadratic profile with curvature 50 about γ = 0.6.
        let grid = unit_grid(|g: f64| 100.0 + 50.0 * (g - 0.6).powi(2));
        let ci = profile_ci_from_grid(&grid, 0.95).unwrap();
        assert!((ci.gamma_hat - 0.6).abs() < 0.02, "γ̂ {}", ci.gamma_hat);
        assert!(!ci.ci_includes_circle, "CI hi {}", ci.ci_hi);
        assert!(!ci.ci_includes_interval, "CI lo {}", ci.ci_lo);
        assert!(!ci.singular_boundary);
        let want = (chi2_1_quantile(0.95) / (2.0 * 50.0)).sqrt();
        assert!(((ci.ci_hi - ci.ci_lo) / 2.0 - want).abs() < 0.02);
    }

    #[test]
    fn profile_ci_includes_circle_at_boundary() {
        // Minimum just beyond γ = 1, so the grid minimum sits on the circle end.
        let grid = unit_grid(|g: f64| 10.0 + 30.0 * (g - 1.05).powi(2));
        let ci = profile_ci_from_grid(&grid, 0.95).unwrap();
        assert!(ci.ci_includes_circle);
        assert!(!ci.singular_boundary);
    }

    #[test]
    fn profile_flags_singular_boundary() {
        // Strictly increasing profile: minimum pinned at γ = 0.
        let grid = unit_grid(|g: f64| 10.0 + 20.0 * g);
        let ci = profile_ci_from_grid(&grid, 0.95).unwrap();
        assert!(ci.gamma_hat.abs() < 1e-9);
        assert!(ci.singular_boundary);
        assert!(ci.ci_includes_interval);
    }

    #[test]
    fn chi2_quantile_known_value() {
        assert!((chi2_1_quantile(0.95) - 3.841_458_820_694_124).abs() < 1e-6);
    }
}