use crate::solver::psi_gram_tensor::PsiGramTensor;
use ndarray::{Array1, Array2, ArrayView1};
/// Gram-tensor view of `Xᵀ W X` / `Xᵀ W z` over a ψ window, with the
/// observation weights `W = diag(w)` frozen at construction time.
///
/// Built by `FrozenWeightGramTensor::build`, which folds `sqrt(w)` into the
/// design rows and the working response before delegating to
/// [`PsiGramTensor`] with unit weights.
pub struct FrozenWeightGramTensor {
/// Unit-weight tensor over the `sqrt(w)`-scaled design and response.
inner: PsiGramTensor,
/// Owned copy of the weights the tensor was built with; read by the
/// weight-drift checks and exposed via `frozen_weights()`.
frozen_w: Array1<f64>,
}
impl FrozenWeightGramTensor {
    /// Builds a ψ-window tensor for `Xᵀ W X` and `Xᵀ W z` with `W = diag(frozen_w)`
    /// held fixed across the whole `[psi_lo, psi_hi]` window.
    ///
    /// The weights are folded into the problem: every row `i` of `X(ψ)` and every
    /// entry `z_i` is scaled by `sqrt(w_i)`, and the scaled problem is handed to
    /// [`PsiGramTensor::build`] with unit weights, so that
    /// `(√W X)ᵀ(√W X) = Xᵀ W X` and `(√W X)ᵀ(√W z) = Xᵀ W z`.
    ///
    /// # Arguments
    /// * `eval_design` – evaluates the design `X(ψ)`; it must return exactly
    ///   `frozen_w.len()` rows, otherwise the wrapped evaluator returns an `Err`.
    /// * `frozen_w` – per-observation weights; must be finite and non-negative.
    /// * `working_z` – working response; must be finite and the same length as `frozen_w`.
    /// * `psi_lo`, `psi_hi` – window bounds, forwarded unchanged to the inner build.
    ///
    /// # Returns
    /// `None` if the inputs are empty or have mismatched lengths, if any weight is
    /// negative or non-finite, if any response entry is non-finite, or if
    /// `PsiGramTensor::build` returns an error.
    pub fn build(
        mut eval_design: impl FnMut(f64) -> Result<Array2<f64>, String>,
        frozen_w: ArrayView1<'_, f64>,
        working_z: ArrayView1<'_, f64>,
        psi_lo: f64,
        psi_hi: f64,
    ) -> Option<Self> {
        let n = frozen_w.len();
        if n == 0 || working_z.len() != n {
            return None;
        }
        // `is_finite` is false for NaN, so NaN weights are rejected here as well.
        if frozen_w.iter().any(|&w| !w.is_finite() || w < 0.0) {
            return None;
        }
        if working_z.iter().any(|&z| !z.is_finite()) {
            return None;
        }
        let sqrt_w: Array1<f64> = frozen_w.mapv(f64::sqrt);
        let weighted_z: Array1<f64> =
            Array1::from_iter(working_z.iter().zip(sqrt_w.iter()).map(|(&z, &s)| z * s));
        // The weights now live in the scaled design/response, so the inner tensor
        // is built with all-ones weights.
        let unit_weights: Array1<f64> = Array1::ones(n);
        // `sqrt_w` is not used after this point, so it is moved into the closure
        // rather than cloned.
        let weighted_eval = move |psi: f64| -> Result<Array2<f64>, String> {
            let mut design = eval_design(psi)?;
            if design.nrows() != sqrt_w.len() {
                return Err(format!(
                    "frozen-W tensor: design has {} rows, expected {}",
                    design.nrows(),
                    sqrt_w.len()
                ));
            }
            for (mut row, &s) in design.outer_iter_mut().zip(sqrt_w.iter()) {
                row.mapv_inplace(|v| v * s);
            }
            Ok(design)
        };
        let inner = PsiGramTensor::build(
            weighted_eval,
            unit_weights.view(),
            weighted_z.view(),
            psi_lo,
            psi_hi,
        )
        .ok()?;
        Some(Self {
            inner,
            frozen_w: frozen_w.to_owned(),
        })
    }

    /// Whether the inner tensor reports `psi` as covered (delegates to
    /// `PsiGramTensor::contains`).
    pub fn contains(&self, psi: f64) -> bool {
        self.inner.contains(psi)
    }

    /// Whether the inner tensor reports `psi` as covered for derivative queries
    /// (delegates to `PsiGramTensor::contains_for_gradient`).
    pub fn contains_for_gradient(&self, psi: f64) -> bool {
        self.inner.contains_for_gradient(psi)
    }

    /// Weighted Gram matrix `Xᵀ W X` at `psi`, read from the inner tensor.
    pub fn gram_at(&self, psi: f64) -> Array2<f64> {
        self.inner.gram_at(psi)
    }

    /// Weighted right-hand side `Xᵀ W z` at `psi`, read from the inner tensor.
    pub fn rhs_at(&self, psi: f64) -> Array1<f64> {
        self.inner.rhs_at(psi)
    }

    /// `∂(Xᵀ W X)/∂ψ` at `psi`, with `W` held at the frozen weights.
    pub fn dgram_dpsi(&self, psi: f64) -> Array2<f64> {
        self.inner.dgram_dpsi(psi)
    }

    /// `∂(Xᵀ W z)/∂ψ` at `psi`, with `W` held at the frozen weights.
    pub fn drhs_dpsi(&self, psi: f64) -> Array1<f64> {
        self.inner.drhs_dpsi(psi)
    }

    /// `∂²(Xᵀ W X)/∂ψ²` at `psi`, with `W` held at the frozen weights.
    pub fn d2gram_dpsi2(&self, psi: f64) -> Array2<f64> {
        self.inner.d2gram_dpsi2(psi)
    }

    /// `∂²(Xᵀ W z)/∂ψ²` at `psi`, with `W` held at the frozen weights.
    pub fn d2rhs_dpsi2(&self, psi: f64) -> Array1<f64> {
        self.inner.d2rhs_dpsi2(psi)
    }

    /// The weights the tensor was built with.
    pub fn frozen_weights(&self) -> ArrayView1<'_, f64> {
        self.frozen_w.view()
    }

    /// Drift tolerance used by [`Self::gradient_pair_if_sound`]. It is much tighter
    /// than value-level tolerances because the served derivatives ignore `∂W/∂ψ`.
    pub const GRADIENT_WEIGHT_DRIFT_RTOL: f64 = 1.0e-9;

    /// Returns `(∂(XᵀWX)/∂ψ, ∂(XᵀWz)/∂ψ)` at `psi`, but only when it is sound to
    /// serve them from the frozen-weight tensor.
    ///
    /// Two conditions must hold:
    /// * `psi` is inside the gradient sub-window (`contains_for_gradient`);
    /// * `w_trial` is within [`Self::GRADIENT_WEIGHT_DRIFT_RTOL`] of the frozen
    ///   weights (see [`Self::weight_drift_within`]).
    ///
    /// Returns `None` otherwise.
    pub fn gradient_pair_if_sound(
        &self,
        psi: f64,
        w_trial: ArrayView1<'_, f64>,
    ) -> Option<(Array2<f64>, Array1<f64>)> {
        if !self.contains_for_gradient(psi) {
            return None;
        }
        if !self.weight_drift_within(w_trial, Self::GRADIENT_WEIGHT_DRIFT_RTOL) {
            return None;
        }
        Some((self.dgram_dpsi(psi), self.drhs_dpsi(psi)))
    }

    /// Checks whether `w_trial` is close enough to the frozen weights.
    ///
    /// The tolerance is absolute per entry, scaled by the largest frozen weight:
    /// every `|w_trial[i] - frozen_w[i]|` must be `<= rtol * max_j |frozen_w[j]|`,
    /// with the scale floored at `1e-300`. It is not a per-element relative tolerance.
    ///
    /// Returns `false` in any of these cases:
    /// * the length of `w_trial` differs from the frozen weights;
    /// * `rtol` is not a finite positive number;
    /// * any trial weight is non-finite;
    /// * any entry exceeds the tolerance.
    pub fn weight_drift_within(&self, w_trial: ArrayView1<'_, f64>, rtol: f64) -> bool {
        if w_trial.len() != self.frozen_w.len() || !(rtol.is_finite() && rtol > 0.0) {
            return false;
        }
        let w_scale = self
            .frozen_w
            .iter()
            .fold(0.0_f64, |acc, &w| acc.max(w.abs()))
            .max(1e-300);
        let tol = rtol * w_scale;
        // Both operands are finite here, so the difference is never NaN; an
        // overflow to ±inf still compares as "exceeds tolerance".
        w_trial
            .iter()
            .zip(self.frozen_w.iter())
            .all(|(&wt, &w0)| wt.is_finite() && (wt - w0).abs() <= tol)
    }
}
// NOTE(review): `endgame_path` is an empty public module in this file — presumably a
// placeholder. Confirm nothing elsewhere `use`s it before removing it.
pub mod endgame_path {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Synthetic design with an analytic ψ dependence.
    ///
    /// The first `k - 1` columns are `(1 + s)·e^{-s}` with `s = r·e^ψ`; the last
    /// column is the ψ-free `r³`.
    fn synth_design(psi: f64, n: usize, k: usize) -> Result<Array2<f64>, String> {
        let mut x = Array2::<f64>::zeros((n, k));
        for i in 0..n {
            for j in 0..k {
                let r = 0.05 + (i as f64 + 1.0) * (j as f64 + 1.0) / (n as f64 * k as f64) * 3.0;
                if j == k - 1 {
                    x[[i, j]] = r * r * r;
                } else {
                    let s = r * psi.exp();
                    x[[i, j]] = (1.0 + s) * (-s).exp();
                }
            }
        }
        Ok(x)
    }

    /// Strictly positive weights of the form `p(1 - p)`, with `p` spread over (0.1, 0.9).
    fn frozen_weights(n: usize) -> Array1<f64> {
        Array1::from_shape_fn(n, |i| {
            let p = 0.1 + 0.8 * ((i as f64 + 0.5) / n as f64);
            p * (1.0 - p)
        })
    }

    /// Deterministic, finite working response.
    fn working_z(n: usize) -> Array1<f64> {
        Array1::from_shape_fn(n, |i| ((i as f64 * 0.37).sin()) + 0.5)
    }

    /// Builds the tensor under test over `synth_design(·, n, k)`.
    fn build_tensor(
        n: usize,
        k: usize,
        w: &Array1<f64>,
        z: &Array1<f64>,
        psi_lo: f64,
        psi_hi: f64,
    ) -> Option<FrozenWeightGramTensor> {
        FrozenWeightGramTensor::build(
            |psi| synth_design(psi, n, k),
            w.view(),
            z.view(),
            psi_lo,
            psi_hi,
        )
    }

    /// Largest absolute entry, floored at `1e-300`, used as a relative-error scale.
    fn max_abs_scale<'a>(values: impl IntoIterator<Item = &'a f64>) -> f64 {
        values
            .into_iter()
            .fold(0.0_f64, |a, &v| a.max(v.abs()))
            .max(1e-300)
    }

    /// Reference `Xᵀ W X`, computed directly from the design.
    fn exact_weighted_gram(psi: f64, n: usize, k: usize, w: &Array1<f64>) -> Array2<f64> {
        let design = synth_design(psi, n, k).unwrap();
        let mut wd = design.clone();
        for (mut row, &wi) in wd.outer_iter_mut().zip(w.iter()) {
            row.mapv_inplace(|v| v * wi);
        }
        design.t().dot(&wd)
    }

    /// Reference `Xᵀ W z`, computed directly from the design.
    fn exact_weighted_xty(
        psi: f64,
        n: usize,
        k: usize,
        w: &Array1<f64>,
        z: &Array1<f64>,
    ) -> Array1<f64> {
        let design = synth_design(psi, n, k).unwrap();
        let mut out = Array1::<f64>::zeros(k);
        for i in 0..n {
            for j in 0..k {
                out[j] += design[[i, j]] * w[i] * z[i];
            }
        }
        out
    }

    #[test]
    fn frozen_w_gram_matches_exact_weighted_rebuild() {
        let (n, k) = (200usize, 5usize);
        let (psi_lo, psi_hi) = (-0.6, 0.6);
        let w = frozen_weights(n);
        let z = working_z(n);
        let tensor = build_tensor(n, k, &w, &z, psi_lo, psi_hi)
            .expect("frozen-W tensor must certify on the analytic Matérn-shaped design");
        for &frac in &[0.137_f64, 0.382, 0.618, 0.851] {
            let psi = psi_lo + frac * (psi_hi - psi_lo);
            assert!(tensor.contains(psi));
            let assembled = tensor.gram_at(psi);
            let exact = exact_weighted_gram(psi, n, k, &w);
            let scale = max_abs_scale(exact.iter());
            for (a, b) in assembled.iter().zip(exact.iter()) {
                assert!(
                    (a - b).abs() <= 1e-9 * scale,
                    "XᵀWX(ψ={psi}) tensor vs exact frozen-W rebuild off by {}",
                    (a - b).abs()
                );
            }
            let rhs = tensor.rhs_at(psi);
            let exact_rhs = exact_weighted_xty(psi, n, k, &w, &z);
            let rscale = max_abs_scale(exact_rhs.iter());
            for (a, b) in rhs.iter().zip(exact_rhs.iter()) {
                assert!(
                    (a - b).abs() <= 1e-9 * rscale,
                    "XᵀWz(ψ={psi}) tensor vs exact off by {}",
                    (a - b).abs()
                );
            }
        }
    }

    #[test]
    fn frozen_w_dgram_matches_finite_difference() {
        let (n, k) = (160usize, 4usize);
        let (psi_lo, psi_hi) = (-0.5, 0.5);
        let w = frozen_weights(n);
        let z = working_z(n);
        let tensor =
            build_tensor(n, k, &w, &z, psi_lo, psi_hi).expect("frozen-W tensor must certify");
        let psi = 0.0;
        assert!(
            tensor.contains_for_gradient(psi),
            "gradient sub-window must certify at the window center"
        );
        let analytic = tensor.dgram_dpsi(psi);
        let h = 1e-4;
        let g = |p: f64| exact_weighted_gram(p, n, k, &w);
        // Fourth-order central difference:
        // f'(x) ≈ [f(x-2h) - 8f(x-h) + 8f(x+h) - f(x+2h)] / (12h).
        let fd = (g(psi - 2.0 * h) - 8.0 * &g(psi - h) + 8.0 * &g(psi + h) - g(psi + 2.0 * h))
            / (12.0 * h);
        let scale = max_abs_scale(fd.iter());
        for (a, b) in analytic.iter().zip(fd.iter()) {
            assert!(
                (a - b).abs() <= 1e-7 * scale,
                "∂(XᵀWX)/∂ψ analytic vs FD off by {}",
                (a - b).abs()
            );
        }
    }

    #[test]
    fn weight_drift_guard_accepts_warm_and_rejects_far() {
        let (n, k) = (120usize, 4usize);
        let w = frozen_weights(n);
        let z = working_z(n);
        let tensor = build_tensor(n, k, &w, &z, -0.4, 0.4).expect("certify");
        assert_eq!(tensor.frozen_weights(), w.view());
        assert!(tensor.weight_drift_within(w.view(), 1e-6));
        let w_near: Array1<f64> = w.mapv(|v| v * (1.0 + 1e-4));
        assert!(tensor.weight_drift_within(w_near.view(), 1e-3));
        let w_far: Array1<f64> = w.mapv(|v| v * 2.0);
        assert!(!tensor.weight_drift_within(w_far.view(), 1e-2));
        let mut w_bad = w.clone();
        w_bad[0] = f64::NAN;
        assert!(!tensor.weight_drift_within(w_bad.view(), 1e-1));
        let w_short = Array1::<f64>::ones(n - 1);
        assert!(!tensor.weight_drift_within(w_short.view(), 1e-1));
        // A tolerance that is not finite and positive is always rejected,
        // even for an exact match.
        assert!(!tensor.weight_drift_within(w.view(), 0.0));
        assert!(!tensor.weight_drift_within(w.view(), -1e-3));
        assert!(!tensor.weight_drift_within(w.view(), f64::NAN));
        assert!(!tensor.weight_drift_within(w.view(), f64::INFINITY));
    }

    #[test]
    fn gradient_pair_sound_only_in_window_and_under_tight_drift() {
        let (n, k) = (160usize, 4usize);
        let (psi_lo, psi_hi) = (-0.5, 0.5);
        let w = frozen_weights(n);
        let z = working_z(n);
        let tensor =
            build_tensor(n, k, &w, &z, psi_lo, psi_hi).expect("frozen-W tensor must certify");
        let psi = 0.0;
        assert!(
            tensor.contains_for_gradient(psi),
            "interior must certify for the gradient lane"
        );
        let pair = tensor
            .gradient_pair_if_sound(psi, w.view())
            .expect("zero-drift in-window trial must serve the gradient n-free");
        let dg = tensor.dgram_dpsi(psi);
        let db = tensor.drhs_dpsi(psi);
        for (a, b) in pair.0.iter().zip(dg.iter()) {
            assert_eq!(a, b, "gradient_pair ∂G/∂ψ must be the raw derivative");
        }
        for (a, b) in pair.1.iter().zip(db.iter()) {
            assert_eq!(a, b, "gradient_pair ∂b/∂ψ must be the raw derivative");
        }
        let w_value_grade: Array1<f64> = w.mapv(|v| v * (1.0 + 1e-6));
        assert!(
            tensor.weight_drift_within(w_value_grade.view(), 1e-3),
            "value-grade gate must accept a 1e-6 drift"
        );
        assert!(
            tensor
                .gradient_pair_if_sound(psi, w_value_grade.view())
                .is_none(),
            "gradient lane must refuse a drift past GRADIENT_WEIGHT_DRIFT_RTOL"
        );
        let w_grad_grade: Array1<f64> = w.mapv(|v| v * (1.0 + 1e-11));
        assert!(
            tensor
                .gradient_pair_if_sound(psi, w_grad_grade.view())
                .is_some(),
            "gradient lane must accept a drift within GRADIENT_WEIGHT_DRIFT_RTOL"
        );
        let psi_edge = psi_hi - 1e-6;
        if !tensor.contains_for_gradient(psi_edge) {
            assert!(
                tensor.gradient_pair_if_sound(psi_edge, w.view()).is_none(),
                "near-edge ψ outside the gradient sub-window must refuse the lane"
            );
        }
    }

    #[test]
    fn rejects_degenerate_and_nonfinite_inputs() {
        let (n, k) = (50usize, 3usize);
        let w = frozen_weights(n);
        let z = working_z(n);
        // Empty input.
        assert!(build_tensor(n, k, &Array1::zeros(0), &Array1::zeros(0), -0.3, 0.3).is_none());
        // Negative weight.
        let mut w_neg = w.clone();
        w_neg[1] = -0.1;
        assert!(build_tensor(n, k, &w_neg, &z, -0.3, 0.3).is_none());
        // Length mismatch between weights and response.
        assert!(build_tensor(n, k, &w, &Array1::zeros(n - 1), -0.3, 0.3).is_none());
        // Degenerate (zero-width) window.
        assert!(build_tensor(n, k, &w, &z, 0.3, 0.3).is_none());
        // Non-finite weights.
        let mut w_nan = w.clone();
        w_nan[2] = f64::NAN;
        assert!(build_tensor(n, k, &w_nan, &z, -0.3, 0.3).is_none());
        let mut w_inf = w.clone();
        w_inf[3] = f64::INFINITY;
        assert!(build_tensor(n, k, &w_inf, &z, -0.3, 0.3).is_none());
        // Non-finite working response.
        let mut z_nan = z.clone();
        z_nan[4] = f64::NAN;
        assert!(build_tensor(n, k, &w, &z_nan, -0.3, 0.3).is_none());
        let mut z_inf = z.clone();
        z_inf[5] = f64::NEG_INFINITY;
        assert!(build_tensor(n, k, &w, &z_inf, -0.3, 0.3).is_none());
    }
}