use ndarray::{Array1, Array2, ArrayView1};
/// Relative tolerance of the Chebyshev tail certificate in `PsiGramTensor::build_at`: every
/// entry of the trailing coefficient slabs must not exceed this fraction of the largest
/// coefficient magnitude found in the same design column.
pub const PSI_GRAM_CERT_RTOL: f64 = 1.0e-12;
/// Tolerance used by `PsiGramTensor::spot_check` when comparing the assembled Gram matrix
/// with a directly computed one, relative to the largest magnitude of the direct result.
pub const PSI_GRAM_SPOT_RTOL: f64 = 1.0e-10;
/// Tolerance used by `PsiGramTensor::certify_gradient_window` when comparing `dgram_dpsi`
/// with a finite-difference derivative, relative to the largest finite-difference magnitude.
pub const PSI_GRAM_GRAD_SPOT_RTOL: f64 = 1.0e-11;
/// Number of equal subdivisions of `[psi_lo, psi_hi]` scanned by
/// `PsiGramTensor::certify_gradient_window` (the scan grid has this many points plus one).
pub const PSI_GRAM_GRAD_SCAN_POINTS: usize = 64;
/// Chebyshev node counts attempted by `PsiGramTensor::build`, in the order listed.
pub const PSI_GRAM_NODE_LADDER: [usize; 4] = [9, 17, 33, 65];
/// Number of ψ values at which `PsiGramTensor::spot_check` compares against a direct Gram.
pub const PSI_GRAM_SPOT_POINTS: usize = 3;
/// Chebyshev-in-ψ expansion of the weighted Gram matrix `X(ψ)ᵀ W X(ψ)` and of the
/// right-hand side `X(ψ)ᵀ W z`, valid on the interval `[psi_lo, psi_hi]`.
///
/// Instances are produced by [`PsiGramTensor::build`]; evaluation at a given ψ only
/// combines the stored coefficient blocks and never re-evaluates the design.
pub struct PsiGramTensor {
// Lower end of the ψ interval the expansion was built on.
psi_lo: f64,
// Upper end of the ψ interval the expansion was built on.
psi_hi: f64,
// Lower end of the window accepted by `contains_for_gradient`; set to NaN by
// `certify_gradient_window` when no window could be certified.
grad_psi_lo: f64,
// Upper end of the gradient window; NaN under the same condition as `grad_psi_lo`.
grad_psi_hi: f64,
// Number of Chebyshev coefficients (equal to the node count used by `build_at`).
n_coeff: usize,
// Number of design columns; every Gram block is `k × k`.
k: usize,
// `n_coeff * n_coeff` blocks; index `d * n_coeff + e` holds `C_dᵀ W C_e`, where `C_d` is
// the d-th Chebyshev coefficient slab of the design.
gram: Vec<Array2<f64>>,
// `n_coeff` vectors; entry `d` holds `C_dᵀ W z`.
rhs: Vec<Array1<f64>>,
// The ψ-independent scalar `Σ wᵢ zᵢ²`.
zt_w_z: f64,
}
/// Result of one attempt of `PsiGramTensor::build_at` at a fixed node count.
enum BuildOutcome {
// The design evaluator returned an error, produced a non-finite entry, or the shapes of
// the designs, weights and `z` were inconsistent or empty.
EvalFailed,
// A trailing Chebyshev coefficient exceeded the relative tail bound.
TailNotCertified,
// The tail certificate passed; the tensor still has to pass the spot check in `build`.
Candidate(PsiGramTensor),
}
/// Evaluates the Chebyshev polynomials of the first kind `T_0(x), …, T_{n-1}(x)`.
///
/// The values are generated with the three-term recurrence
/// `T_d = 2·x·T_{d-1} − T_{d-2}`, seeded with `T_0 = 1` and `T_1 = x`.
/// Returns an empty vector when `n == 0`.
fn cheb_t(x: f64, n: usize) -> Vec<f64> {
    let mut values: Vec<f64> = Vec::with_capacity(n);
    for degree in 0..n {
        let next = match degree {
            0 => 1.0,
            1 => x,
            _ => 2.0 * x * values[degree - 1] - values[degree - 2],
        };
        values.push(next);
    }
    values
}
/// Evaluates the derivatives `T_0'(x), …, T_{n-1}'(x)` of the Chebyshev polynomials of the
/// first kind.
///
/// Uses the identity `T_d'(x) = d · U_{d-1}(x)`, where the second-kind polynomials are
/// generated by `U_d = 2·x·U_{d-1} − U_{d-2}` with `U_0 = 1` and `U_1 = 2x`.
/// Returns an empty vector when `n == 0`.
fn cheb_t_prime(x: f64, n: usize) -> Vec<f64> {
    // Second-kind values U_0 .. U_{max(n,1)-1}.
    let mut second_kind: Vec<f64> = Vec::with_capacity(n.max(1));
    second_kind.push(1.0);
    for degree in 1..n {
        let next = if degree == 1 {
            2.0 * x
        } else {
            2.0 * x * second_kind[degree - 1] - second_kind[degree - 2]
        };
        second_kind.push(next);
    }
    // T_0' is identically zero; every higher degree pairs with the preceding U value.
    std::iter::once(0.0)
        .chain(
            second_kind
                .iter()
                .zip(1..n)
                .map(|(&u, degree)| degree as f64 * u),
        )
        .take(n)
        .collect()
}
impl PsiGramTensor {
/// Builds a certified Chebyshev expansion of the weighted Gram system on `[psi_lo, psi_hi]`.
///
/// The node counts in `PSI_GRAM_NODE_LADDER` are tried in order. A rung is accepted once
/// its tail certificate and the off-node spot check both pass; the gradient window is then
/// determined and the tensor returned.
///
/// Returns `None` when the interval is non-finite or not strictly increasing, when a
/// design evaluation fails during construction, or when no rung of the ladder is accepted.
pub fn build(
    mut eval_design: impl FnMut(f64) -> Result<Array2<f64>, String>,
    weights: ArrayView1<'_, f64>,
    z: ArrayView1<'_, f64>,
    psi_lo: f64,
    psi_hi: f64,
) -> Option<Self> {
    let interval_is_valid = psi_lo.is_finite() && psi_hi.is_finite() && psi_lo < psi_hi;
    if !interval_is_valid {
        return None;
    }
    for node_count in PSI_GRAM_NODE_LADDER.iter().copied() {
        let outcome = Self::build_at(&mut eval_design, weights, z, psi_lo, psi_hi, node_count);
        let mut candidate = match outcome {
            BuildOutcome::Candidate(candidate) => candidate,
            // Not resolved at this node count: try the next, finer rung.
            BuildOutcome::TailNotCertified => continue,
            // A failing evaluator will not recover with more nodes.
            BuildOutcome::EvalFailed => return None,
        };
        if !candidate.spot_check(&mut eval_design, weights) {
            continue;
        }
        candidate.certify_gradient_window(&mut eval_design, weights);
        return Some(candidate);
    }
    None
}
fn build_at(
eval_design: &mut impl FnMut(f64) -> Result<Array2<f64>, String>,
weights: ArrayView1<'_, f64>,
z: ArrayView1<'_, f64>,
psi_lo: f64,
psi_hi: f64,
m: usize,
) -> BuildOutcome {
let mut nodes_x = vec![0.0_f64; m];
let mut designs: Vec<Array2<f64>> = Vec::with_capacity(m);
for (i, x_slot) in nodes_x.iter_mut().enumerate() {
let x = (std::f64::consts::PI * (2 * i + 1) as f64 / (2 * m) as f64).cos();
*x_slot = x;
let psi = 0.5 * (psi_lo + psi_hi) + 0.5 * (psi_hi - psi_lo) * x;
let Ok(design) = eval_design(psi) else {
return BuildOutcome::EvalFailed;
};
if design.iter().any(|v| !v.is_finite()) {
return BuildOutcome::EvalFailed;
}
designs.push(design);
}
let (n, k) = designs[0].dim();
if designs.iter().any(|d| d.dim() != (n, k))
|| weights.len() != n
|| z.len() != n
|| n == 0
|| k == 0
{
return BuildOutcome::EvalFailed;
}
let t_at_nodes: Vec<Vec<f64>> = nodes_x.iter().map(|&x| cheb_t(x, m)).collect();
let mut coeff_slabs: Vec<Array2<f64>> = Vec::with_capacity(m);
for d in 0..m {
let gamma = if d == 0 { 1.0 } else { 2.0 };
let mut slab = Array2::<f64>::zeros((n, k));
for (i, design) in designs.iter().enumerate() {
let wgt = gamma / m as f64 * t_at_nodes[i][d];
slab.scaled_add(wgt, design);
}
coeff_slabs.push(slab);
}
let mut col_scale = vec![0.0_f64; k];
for slab in &coeff_slabs {
for (j, scale) in col_scale.iter_mut().enumerate() {
for i in 0..n {
*scale = scale.max(slab[[i, j]].abs());
}
}
}
let tail_start = m - (m / 4).max(1);
for slab in coeff_slabs.iter().skip(tail_start) {
for (j, &scale) in col_scale.iter().enumerate() {
let bound = PSI_GRAM_CERT_RTOL * scale.max(1e-300);
for i in 0..n {
if slab[[i, j]].abs() > bound {
return BuildOutcome::TailNotCertified;
}
}
}
}
let mut weighted: Vec<Array2<f64>> = Vec::with_capacity(m);
for slab in &coeff_slabs {
let mut ws = slab.clone();
for (mut row, &w) in ws.outer_iter_mut().zip(weights.iter()) {
row.mapv_inplace(|v| v * w);
}
weighted.push(ws);
}
let mut wz = Array1::<f64>::zeros(z.len());
let mut zt_w_z = 0.0_f64;
for ((slot, &w), &zv) in wz.iter_mut().zip(weights.iter()).zip(z.iter()) {
*slot = w * zv;
zt_w_z += w * zv * zv;
}
let mut gram: Vec<Array2<f64>> = Vec::with_capacity(m * m);
let mut rhs = Vec::with_capacity(m);
for d in 0..m {
for e in 0..m {
if e < d {
let g: Array2<f64> = gram[e * m + d].t().to_owned();
gram.push(g);
} else {
gram.push(coeff_slabs[d].t().dot(&weighted[e]));
}
}
rhs.push(coeff_slabs[d].t().dot(&wz));
}
BuildOutcome::Candidate(Self {
psi_lo,
psi_hi,
grad_psi_lo: psi_lo,
grad_psi_hi: psi_hi,
n_coeff: m,
k,
gram,
rhs,
zt_w_z,
})
}
/// Compares the assembled Gram matrix with a directly computed `XᵀWX` at
/// `PSI_GRAM_SPOT_POINTS` ψ values spread over the interval by golden-ratio fractions.
///
/// Returns `true` only if, at every probe, the evaluator succeeds, the design is finite
/// and has shape `(weights.len(), k)`, and every entry of the assembled matrix is within
/// `PSI_GRAM_SPOT_RTOL` (relative to the largest exact entry) of the direct result.
///
/// Non-finite values always fail the check: the comparison is written as "all entries are
/// within tolerance", so a NaN difference cannot slip through the way it would with a
/// "some entry exceeds tolerance" test.
fn spot_check(
    &self,
    eval_design: &mut impl FnMut(f64) -> Result<Array2<f64>, String>,
    weights: ArrayView1<'_, f64>,
) -> bool {
    const GOLDEN_FRACTION: f64 = 0.618_033_988_749_894_9;
    let expected_dim = (weights.len(), self.k);
    for s in 0..PSI_GRAM_SPOT_POINTS {
        let frac = ((s as f64 + 1.0) * GOLDEN_FRACTION).fract();
        let psi = self.psi_lo + frac * (self.psi_hi - self.psi_lo);
        let Ok(design) = eval_design(psi) else {
            return false;
        };
        // A mis-shaped or non-finite design cannot certify anything; without this guard
        // the zips below would silently compare truncated or misaligned data.
        if design.dim() != expected_dim || design.iter().any(|v| !v.is_finite()) {
            return false;
        }
        let mut wd = design.clone();
        for (mut row, &w) in wd.outer_iter_mut().zip(weights.iter()) {
            row.mapv_inplace(|v| v * w);
        }
        let exact = design.t().dot(&wd);
        let assembled = self.gram_at(psi);
        let scale = exact
            .iter()
            .fold(0.0_f64, |acc, &v| acc.max(v.abs()))
            .max(1e-300);
        let tol = PSI_GRAM_SPOT_RTOL * scale;
        // `<=` is false for NaN, so non-finite weights or Gram entries are rejected.
        let within_tolerance = assembled
            .iter()
            .zip(exact.iter())
            .all(|(a, b)| (a - b).abs() <= tol);
        if !within_tolerance {
            return false;
        }
    }
    true
}
/// Determines the ψ window on which `dgram_dpsi` is trusted and stores it in
/// `grad_psi_lo` / `grad_psi_hi`.
///
/// A grid of `PSI_GRAM_GRAD_SCAN_POINTS + 1` equally spaced ψ values is scanned. A grid
/// point certifies when its five-point finite-difference stencil (step `h`) lies strictly
/// inside `[psi_lo, psi_hi]`, all four stencil designs evaluate successfully with shape
/// `(weights.len(), k)`, and the analytic derivative matches the finite difference to
/// within `PSI_GRAM_GRAD_SPOT_RTOL` relative to the largest finite-difference entry.
///
/// The window runs from the first certifying grid point to the last one strictly above
/// it. If no such pair exists, both bounds are set to NaN.
///
/// Only the grid points that can still affect the result are evaluated: the reverse scan
/// is skipped when the forward scan found nothing, and it never revisits points at or
/// below the lower endpoint. NOTE(review): this assumes `eval_design` returns the same
/// result when called again with the same ψ.
fn certify_gradient_window(
    &mut self,
    eval_design: &mut impl FnMut(f64) -> Result<Array2<f64>, String>,
    weights: ArrayView1<'_, f64>,
) {
    let span = self.psi_hi - self.psi_lo;
    let h = (span * 1e-3).max(1e-6);
    let psi_lo = self.psi_lo;
    let expected_dim = (weights.len(), self.k);
    // Fourth-order central difference of the directly computed weighted Gram matrix.
    let exact_dgram = |psi: f64,
                       eval: &mut dyn FnMut(f64) -> Result<Array2<f64>, String>|
     -> Option<Array2<f64>> {
        let weighted_gram = |p: f64,
                             eval: &mut dyn FnMut(f64) -> Result<Array2<f64>, String>|
         -> Option<Array2<f64>> {
            let design = eval(p).ok()?;
            // Mis-shaped designs would otherwise panic in the stencil arithmetic below
            // or be compared entry-by-entry against a matrix of a different shape.
            if design.dim() != expected_dim {
                return None;
            }
            let mut wd = design.clone();
            for (mut row, &w) in wd.outer_iter_mut().zip(weights.iter()) {
                row.mapv_inplace(|v| v * w);
            }
            Some(design.t().dot(&wd))
        };
        let g_m2 = weighted_gram(psi - 2.0 * h, eval)?;
        let g_m1 = weighted_gram(psi - h, eval)?;
        let g_p1 = weighted_gram(psi + h, eval)?;
        let g_p2 = weighted_gram(psi + 2.0 * h, eval)?;
        Some((g_m2 - 8.0 * &g_m1 + 8.0 * &g_p1 - g_p2) / (12.0 * h))
    };
    let certifies = |me: &Self,
                     psi: f64,
                     eval: &mut dyn FnMut(f64) -> Result<Array2<f64>, String>|
     -> bool {
        // The whole stencil must stay strictly inside the build interval.
        if psi - 2.0 * h <= me.psi_lo || psi + 2.0 * h >= me.psi_hi {
            return false;
        }
        let Some(exact) = exact_dgram(psi, eval) else {
            return false;
        };
        let analytic = me.dgram_dpsi(psi);
        let scale = exact
            .iter()
            .fold(0.0_f64, |acc, &v| acc.max(v.abs()))
            .max(1e-300);
        analytic
            .iter()
            .zip(exact.iter())
            .all(|(a, b)| (a - b).abs() <= PSI_GRAM_GRAD_SPOT_RTOL * scale)
    };

    let n = PSI_GRAM_GRAD_SCAN_POINTS;
    let grid_point = |i: usize| psi_lo + span * (i as f64) / (n as f64);

    // Forward scan: first certifying grid index, if any.
    let mut first_certified = None;
    for i in 0..=n {
        if certifies(self, grid_point(i), eval_design) {
            first_certified = Some(i);
            break;
        }
    }

    // Reverse scan: last certifying grid index strictly above the first one.
    let mut window = None;
    if let Some(first) = first_certified {
        for i in ((first + 1)..=n).rev() {
            if certifies(self, grid_point(i), eval_design) {
                window = Some((grid_point(first), grid_point(i)));
                break;
            }
        }
    }

    let (lo, hi) = match window {
        Some((lo, hi)) if hi > lo => (lo, hi),
        _ => (f64::NAN, f64::NAN),
    };
    self.grad_psi_lo = lo;
    self.grad_psi_hi = hi;
}
/// Returns `true` when `psi` is finite and lies in the closed interval
/// `[psi_lo, psi_hi]` the tensor was built on.
pub fn contains(&self, psi: f64) -> bool {
    let build_interval = self.psi_lo..=self.psi_hi;
    psi.is_finite() && build_interval.contains(&psi)
}
/// Returns `true` when `psi` is finite and lies in the closed gradient window
/// `[grad_psi_lo, grad_psi_hi]`; always `false` when the window bounds are not finite.
pub fn contains_for_gradient(&self, psi: f64) -> bool {
    let window_is_finite = self.grad_psi_lo.is_finite() && self.grad_psi_hi.is_finite();
    let gradient_window = self.grad_psi_lo..=self.grad_psi_hi;
    psi.is_finite() && window_is_finite && gradient_window.contains(&psi)
}
/// Affinely maps `psi` so that `psi_lo` goes to `-1` and `psi_hi` goes to `+1`, the
/// variable in which the Chebyshev polynomials are evaluated.
fn mapped(&self, psi: f64) -> f64 {
    let endpoint_sum = self.psi_lo + self.psi_hi;
    let width = self.psi_hi - self.psi_lo;
    (2.0 * psi - endpoint_sum) / width
}
/// Evaluates the expanded Gram matrix at `psi` as
/// `Σ_d Σ_e T_d(x) · T_e(x) · gram[d * n_coeff + e]`, where `x` is `psi` mapped to the
/// Chebyshev variable. The result is `k × k`.
pub fn gram_at(&self, psi: f64) -> Array2<f64> {
    let n = self.n_coeff;
    let t = cheb_t(self.mapped(psi), n);
    let mut out = Array2::<f64>::zeros((self.k, self.k));
    for (d, &td) in t.iter().enumerate() {
        // Blocks (d, 0) .. (d, n-1) are stored contiguously.
        let row_blocks = &self.gram[d * n..(d + 1) * n];
        for (&te, block) in t.iter().zip(row_blocks) {
            out.scaled_add(td * te, block);
        }
    }
    out
}
/// Evaluates the expanded right-hand side at `psi` as `Σ_d T_d(x) · rhs[d]`, where `x` is
/// `psi` mapped to the Chebyshev variable. The result has length `k`.
pub fn rhs_at(&self, psi: f64) -> Array1<f64> {
    let t = cheb_t(self.mapped(psi), self.n_coeff);
    t.iter().zip(self.rhs.iter()).fold(
        Array1::<f64>::zeros(self.k),
        |mut acc, (&td, coeff)| {
            acc.scaled_add(td, coeff);
            acc
        },
    )
}
/// Evaluates the ψ-derivative of the expanded Gram matrix at `psi`.
///
/// Applies the product rule to each term `T_d(x) · T_e(x)` and the chain-rule factor
/// `dx/dψ = 2 / (psi_hi − psi_lo)` of the affine map. The result is `k × k`.
pub fn dgram_dpsi(&self, psi: f64) -> Array2<f64> {
    let n = self.n_coeff;
    let x = self.mapped(psi);
    let dx_dpsi = 2.0 / (self.psi_hi - self.psi_lo);
    let t = cheb_t(x, n);
    let tp = cheb_t_prime(x, n);
    let mut out = Array2::<f64>::zeros((self.k, self.k));
    for (d, (&td, &tpd)) in t.iter().zip(tp.iter()).enumerate() {
        // Blocks (d, 0) .. (d, n-1) are stored contiguously.
        let row_blocks = &self.gram[d * n..(d + 1) * n];
        for ((&te, &tpe), block) in t.iter().zip(tp.iter()).zip(row_blocks) {
            out.scaled_add((tpd * te + td * tpe) * dx_dpsi, block);
        }
    }
    out
}
/// Evaluates the ψ-derivative of the expanded right-hand side at `psi` as
/// `Σ_d T_d'(x) · (dx/dψ) · rhs[d]`, with `dx/dψ = 2 / (psi_hi − psi_lo)`.
/// The result has length `k`.
pub fn drhs_dpsi(&self, psi: f64) -> Array1<f64> {
    let dx_dpsi = 2.0 / (self.psi_hi - self.psi_lo);
    let tp = cheb_t_prime(self.mapped(psi), self.n_coeff);
    tp.iter().zip(self.rhs.iter()).fold(
        Array1::<f64>::zeros(self.k),
        |mut acc, (&tpd, coeff)| {
            acc.scaled_add(tpd * dx_dpsi, coeff);
            acc
        },
    )
}
/// Packages the expansion evaluated at `psi` into a `GaussianFixedCache`: the Gram matrix
/// from `gram_at`, the right-hand side from `rhs_at`, the stored `Σ w z²` scalar, and no
/// sparse Gram representation.
pub fn gaussian_fixed_cache_at(&self, psi: f64) -> crate::pirls::GaussianFixedCache {
    let xtwx_orig = self.gram_at(psi);
    let xtwy_orig = self.rhs_at(psi);
    crate::pirls::GaussianFixedCache {
        xtwx_orig,
        xtwy_orig,
        centered_weighted_y_sq: self.zt_w_z,
        xtwx_sparse_orig: None,
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
/// Synthetic `n × k` design that depends smoothly on `psi`.
///
/// Each entry derives from a per-cell radius `r`. The last column is `r³` and does not
/// depend on `psi`; every other column is `(1 + s)·exp(−s)` with `s = r·exp(psi)`.
/// Always returns `Ok`.
fn synth_design(psi: f64, n: usize, k: usize) -> Result<Array2<f64>, String> {
    let design = Array2::from_shape_fn((n, k), |(i, j)| {
        let r = 0.05 + (i as f64 + 1.0) * (j as f64 + 1.0) / (n as f64 * k as f64) * 3.0;
        if j == k - 1 {
            r * r * r
        } else {
            let s = r * psi.exp();
            (1.0 + s) * (-s).exp()
        }
    });
    Ok(design)
}
/// Reference weighted Gram matrix `XᵀWX` of the synthetic design at `psi`, computed
/// directly with row weights `w`.
fn exact_gram(psi: f64, n: usize, k: usize, w: &Array1<f64>) -> Array2<f64> {
    let design = synth_design(psi, n, k).unwrap();
    let mut row_weighted = design.clone();
    row_weighted
        .outer_iter_mut()
        .zip(w.iter())
        .for_each(|(mut row, &wi)| row.mapv_inplace(|v| v * wi));
    design.t().dot(&row_weighted)
}
// End-to-end check on the smooth synthetic design: the tensor must build, and its Gram
// matrix, right-hand side, ψ-derivative and the resulting ridge-penalized solve must agree
// with values computed directly from the design.
#[test]
fn psi_gram_tensor_matches_exact_gram_and_fd_gradient() {
let (n, k) = (160usize, 7usize);
let w = Array1::from_iter((0..n).map(|i| 1.0 + 0.5 * ((i % 3) as f64)));
let z = Array1::from_iter((0..n).map(|i| ((i as f64) * 0.37).sin()));
let (psi_lo, psi_hi) = (-1.2_f64, 1.0_f64);
let tensor = PsiGramTensor::build(
|psi| synth_design(psi, n, k),
w.view(),
z.view(),
psi_lo,
psi_hi,
)
.expect("analytic synthetic design must certify");
// Part 1: pointwise agreement of Gram, rhs and dGram/dψ at several interior ψ values.
for &psi in &[-1.1, -0.63, 0.0, 0.41, 0.97] {
assert!(tensor.contains(psi));
let exact = exact_gram(psi, n, k, &w);
let fast = tensor.gram_at(psi);
let scale = exact.iter().fold(0.0_f64, |a, &v| a.max(v.abs()));
for (a, b) in fast.iter().zip(exact.iter()) {
assert!(
(a - b).abs() <= 1e-9 * scale,
"gram mismatch at psi={psi}: fast={a}, exact={b}"
);
}
let design = synth_design(psi, n, k).unwrap();
let mut wz = Array1::<f64>::zeros(n);
for ((slot, &wi), &zi) in wz.iter_mut().zip(w.iter()).zip(z.iter()) {
*slot = wi * zi;
}
let exact_rhs = design.t().dot(&wz);
let fast_rhs = tensor.rhs_at(psi);
let rscale = exact_rhs.iter().fold(0.0_f64, |a, &v| a.max(v.abs()));
for (a, b) in fast_rhs.iter().zip(exact_rhs.iter()) {
assert!(
(a - b).abs() <= 1e-9 * rscale,
"rhs mismatch at psi={psi}: fast={a}, exact={b}"
);
}
// Second-order central difference of the exact Gram as the derivative reference.
let h = 1e-5;
let g_plus = exact_gram(psi + h, n, k, &w);
let g_minus = exact_gram(psi - h, n, k, &w);
let dg = tensor.dgram_dpsi(psi);
let dscale = dg.iter().fold(0.0_f64, |a, &v| a.max(v.abs())).max(1e-12);
for ((a, p), m_) in dg.iter().zip(g_plus.iter()).zip(g_minus.iter()) {
let fd = (p - m_) / (2.0 * h);
assert!(
(a - fd).abs() <= 1e-5 * dscale,
"dgram/dpsi mismatch at psi={psi}: analytic={a}, fd={fd}"
);
}
}
// Part 2: values outside the build interval are rejected.
assert!(!tensor.contains(psi_hi + 0.5));
assert!(!tensor.contains(psi_lo - 0.5));
// Part 3: the cache built from the tensor must yield the same penalized solution as
// one built directly from the design.
for &psi in &[-0.9, 0.2, 0.8] {
let cache = tensor.gaussian_fixed_cache_at(psi);
let design = synth_design(psi, n, k).unwrap();
let mut wd = design.clone();
for (mut row, &wi) in wd.outer_iter_mut().zip(w.iter()) {
row.mapv_inplace(|v| v * wi);
}
let exact_gram = design.t().dot(&wd);
let exact_rhs = wd.t().dot(&z);
let exact_ztwz: f64 = w.iter().zip(z.iter()).map(|(&wi, &zi)| wi * zi * zi).sum();
assert!(
(cache.centered_weighted_y_sq - exact_ztwz).abs()
<= 1e-12 * exact_ztwz.abs().max(1e-300),
"zᵀWz drift: cache={}, exact={exact_ztwz}",
cache.centered_weighted_y_sq
);
// Solves (G + I) β = r by Gauss-Jordan elimination with partial pivoting on the
// augmented matrix [G + I | r].
let solve = |g: &Array2<f64>, r: &Array1<f64>| -> Array1<f64> {
let mut a = g.clone();
for i in 0..k {
a[[i, i]] += 1.0;
}
let mut aug = Array2::<f64>::zeros((k, k + 1));
aug.slice_mut(ndarray::s![.., ..k]).assign(&a);
aug.slice_mut(ndarray::s![.., k]).assign(r);
for col in 0..k {
// Pick the row at or below the diagonal with the largest pivot magnitude.
let piv = (col..k)
.max_by(|&p, &q| aug[[p, col]].abs().total_cmp(&aug[[q, col]].abs()))
.unwrap();
if piv != col {
for j in 0..=k {
let tmp = aug[[col, j]];
aug[[col, j]] = aug[[piv, j]];
aug[[piv, j]] = tmp;
}
}
let p = aug[[col, col]];
// Eliminate this column from every other row (above and below the pivot).
for row in 0..k {
if row == col {
continue;
}
let f = aug[[row, col]] / p;
for j in col..=k {
aug[[row, j]] -= f * aug[[col, j]];
}
}
}
// The matrix part is now diagonal; divide through to read off the solution.
Array1::from_iter((0..k).map(|i| aug[[i, k]] / aug[[i, i]]))
};
let beta_fast = solve(&cache.xtwx_orig, &cache.xtwy_orig);
let beta_exact = solve(&exact_gram, &exact_rhs);
let bscale = beta_exact
.iter()
.fold(0.0_f64, |a, &v| a.max(v.abs()))
.max(1e-300);
for (a, b) in beta_fast.iter().zip(beta_exact.iter()) {
assert!(
(a - b).abs() <= 1e-8 * bscale,
"penalized solve drift at psi={psi}: fast={a}, exact={b}"
);
}
}
}
/// A design containing `|psi|` has a kink at `psi = 0` inside `[-1, 1]`; `build` must
/// return `None` for it.
#[test]
fn psi_gram_tensor_refuses_non_analytic_design() {
    let (n, k) = (40usize, 3usize);
    let w = Array1::from_elem(n, 1.0);
    let z = Array1::from_elem(n, 0.5);
    let kinked_design = |psi: f64| -> Result<Array2<f64>, String> {
        let mut x = Array2::<f64>::zeros((n, k));
        for ((i, j), entry) in x.indexed_iter_mut() {
            *entry = psi.abs() + (i + j) as f64 / (n + k) as f64;
        }
        Ok(x)
    };
    let tensor = PsiGramTensor::build(kinked_design, w.view(), z.view(), -1.0, 1.0);
    assert!(
        tensor.is_none(),
        "kinked design must fail the tail-decay/spot-check certificates"
    );
}
}