use std::collections::HashMap;
use std::sync::Arc;
use super::grid_spline_2d::{chol_solve, cholesky_logdet};
// ---- Multilevel geometry -------------------------------------------------------------------

/// Ratio of a level's support radius to its net spacing: `delta = OVERLAP * h`.
const OVERLAP: f64 = 2.0;
/// Coarsest net spacing as a fraction of the largest scaled axis range
/// (`h0 = H0_FRACTION * max_range`); each further level halves the spacing.
const H0_FRACTION: f64 = 0.5;
/// Number of levels `fit_residual_cascade` starts from before adaptive refinement.
const INITIAL_LEVELS: usize = 3;
/// Largest number of levels accepted by `build` and reached by refinement.
const MAX_LEVELS: usize = 16;
/// Largest total number of centers (all levels together).
const MAX_CENTERS: usize = 200_000;
/// Refinement stops once the next-level gain bound is at most `REFINE_TOL * rss_pen`.
const REFINE_TOL: f64 = 1e-3;

// ---- Linear algebra ------------------------------------------------------------------------

/// Coefficient count up to which `X'WX` is cached densely, enabling Cholesky solves and exact
/// log-determinants; larger systems use PCG and stochastic Lanczos quadrature.
const DENSE_GRAM_MAX: usize = 1536;
/// Relative residual `||r|| / ||b||` at which PCG declares convergence.
const CG_RTOL: f64 = 1e-10;
/// Maximum number of PCG updates.
const CG_MAX_ITERS: usize = 4000;
/// Number of Rademacher probes averaged by the SLQ log-determinant estimator.
const SLQ_PROBES: usize = 24;
/// Lanczos steps per SLQ probe (capped by the system size).
const SLQ_LANCZOS_STEPS: usize = 48;
/// Smallest value accepted as strictly positive for preconditioner diagonals and Ritz values.
const EIG_FLOOR: f64 = 1e-300;

// ---- Smoothing-parameter search ------------------------------------------------------------

/// Number of equally spaced `log_lambda` values in the coarse REML scan.
const LOG_LAMBDA_GRID: usize = 25;
/// Lower end of the `log_lambda` scan.
const LOG_LAMBDA_LO: f64 = -18.0;
/// Upper end of the `log_lambda` scan.
const LOG_LAMBDA_HI: f64 = 18.0;
/// Bracket width at which the golden-section refinement of `log_lambda` stops.
const LOG_LAMBDA_TOL: f64 = 1e-6;

// ---- Randomness ----------------------------------------------------------------------------

/// Base seed for every pseudo-random stream (SLQ probes, posterior sampling), so results are
/// reproducible run to run.
const RNG_SEED: u64 = 0x1032_CA5C_ADE0_5EED;
/// Minimal SplitMix64 generator: a 64-bit Weyl sequence followed by an avalanche mix.
/// Deterministic for a given seed; not suitable for cryptographic use.
struct SplitMix64(u64);

impl SplitMix64 {
    /// Weyl-sequence increment (the 64-bit golden ratio).
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    /// `2^53`, the number of distinct 53-bit mantissa values.
    const TWO_POW_53: f64 = 9_007_199_254_740_992.0;

    /// Creates a generator whose state starts at `seed`.
    fn new(seed: u64) -> Self {
        Self(seed)
    }

    /// Advances the state and returns the next mixed 64-bit output.
    fn next_u64(&mut self) -> u64 {
        self.0 = self.0.wrapping_add(Self::GOLDEN_GAMMA);
        let stage1 = (self.0 ^ (self.0 >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        let stage2 = (stage1 ^ (stage1 >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        stage2 ^ (stage2 >> 31)
    }

    /// Uniform draw on the open interval (0, 1): the top 53 bits, offset by half a step so
    /// neither endpoint is reachable (which keeps `ln` in `next_normal` finite).
    fn next_unit(&mut self) -> f64 {
        let mantissa = (self.next_u64() >> 11) as f64;
        (mantissa + 0.5) / Self::TWO_POW_53
    }

    /// Standard normal draw via the cosine branch of Box-Muller (consumes two uniforms).
    fn next_normal(&mut self) -> f64 {
        let radius_source = self.next_unit();
        let angle_source = self.next_unit();
        let radius = (-2.0 * radius_source.ln()).sqrt();
        radius * (std::f64::consts::TAU * angle_source).cos()
    }

    /// Rademacher draw: `+1.0` or `-1.0` from the lowest output bit.
    fn next_sign(&mut self) -> f64 {
        match self.next_u64() & 1 {
            0 => 1.0,
            _ => -1.0,
        }
    }
}
/// Integer cell coordinates of `z` in a uniform grid of cell edge `width`.
///
/// Only the first `dim` axes are bucketed; the remaining key components are 0. The float to
/// `i32` conversion saturates for coordinates far outside the representable cell range.
#[inline]
fn cell_of(z: &[f64; 3], dim: usize, width: f64) -> (i32, i32, i32) {
    let mut idx = [0_i32; 3];
    (0..dim).for_each(|axis| idx[axis] = (z[axis] / width).floor() as i32);
    let [i, j, k] = idx;
    (i, j, k)
}
/// Uniform spatial hash over points in 2 or 3 dimensions, keyed by the integer cell
/// coordinates from `cell_of`. Each inserted index is stored once, in its own cell's bucket.
struct HashGrid {
    /// Cell edge length.
    width: f64,
    /// Number of active axes; key components for unused axes are always 0.
    dim: usize,
    /// Point indices bucketed by cell.
    cells: HashMap<(i32, i32, i32), Vec<u32>>,
}

impl HashGrid {
    /// Creates an empty grid with cell edge `width` over the first `dim` axes.
    fn new(width: f64, dim: usize) -> Self {
        HashGrid {
            width,
            dim,
            cells: HashMap::new(),
        }
    }

    /// Records index `idx` in the bucket of the cell containing `z`.
    fn insert(&mut self, idx: u32, z: &[f64; 3]) {
        let key = cell_of(z, self.dim, self.width);
        self.cells.entry(key).or_default().push(idx);
    }

    /// Calls `visit` for every stored index in the cell containing `z` and in its immediate
    /// neighbours (3x3 block in 2-D, 3x3x3 in 3-D).
    ///
    /// With cell width `width`, this reaches every stored point whose per-axis offset from `z`
    /// is at most `width`. Callers still filter by exact distance.
    ///
    /// Neighbour keys are formed with `checked_add`. For a query far outside the data
    /// (e.g. a prediction point), `cell_of` saturates at `i32::MAX`/`i32::MIN`, and a plain
    /// `c + 1` / `c - 1` would overflow: a panic in debug builds, a wrap to an unrelated cell
    /// in release. Offsets that leave the `i32` key range cannot name a stored cell, so they
    /// are skipped.
    fn for_neighbors(&self, z: &[f64; 3], mut visit: impl FnMut(u32)) {
        let (c0, c1, c2) = cell_of(z, self.dim, self.width);
        let d2 = if self.dim > 2 { 1 } else { 0 };
        let d1 = if self.dim > 1 { 1 } else { 0 };
        for i0 in -1..=1_i32 {
            let k0 = match c0.checked_add(i0) {
                Some(k) => k,
                None => continue,
            };
            for i1 in -d1..=d1 {
                let k1 = match c1.checked_add(i1) {
                    Some(k) => k,
                    None => continue,
                };
                for i2 in -d2..=d2 {
                    let k2 = match c2.checked_add(i2) {
                        Some(k) => k,
                        None => continue,
                    };
                    if let Some(bucket) = self.cells.get(&(k0, k1, k2)) {
                        for &idx in bucket {
                            visit(idx);
                        }
                    }
                }
            }
        }
    }
}
/// Squared Euclidean distance between `a` and `b` over the first `dim` axes.
#[inline]
fn dist2(a: &[f64; 3], b: &[f64; 3], dim: usize) -> f64 {
    (0..dim).fold(0.0, |acc, axis| {
        let diff = a[axis] - b[axis];
        acc + diff * diff
    })
}
/// Wendland bump `(1 - r)^4 (4r + 1)` for `r < 1`, and `0` for `r >= 1`.
/// `r` is the distance already divided by the support radius.
#[inline]
fn wendland(r: f64) -> f64 {
    if r >= 1.0 {
        0.0
    } else {
        let gap = 1.0 - r;
        let gap_sq = gap * gap;
        gap_sq * gap_sq * (4.0 * r + 1.0)
    }
}
/// One level of the cascade: a set of compactly supported Wendland bumps placed on a net of
/// spacing `h`.
struct Level {
    /// Centers introduced at this level, in scaled coordinates.
    centers: Vec<[f64; 3]>,
    /// Design-matrix column of `centers[0]`; center `j` lives in column `col_offset + j`.
    col_offset: usize,
    /// Net spacing: after this level, every data point lies within `h` of a center of this or
    /// a coarser level.
    h: f64,
    /// Support radius of this level's bumps (`OVERLAP * h`).
    delta: f64,
    /// Penalty weight applied to this level's coefficients.
    weight: f64,
    /// Spatial hash of `centers` with cell width `delta`, used for support queries.
    grid: HashGrid,
}
/// Everything a fit needs once `ResidualCascadeDesign::build` has run: the scaled geometry,
/// the sparse design matrix, cached normal-equation pieces and the diagonal penalty.
struct Core {
    // -- Geometry of the scaled coordinate frame --
    /// Number of active coordinate axes (2 or 3).
    dim: usize,
    /// Per-axis metric scale; 1.0 on unused axes.
    metric: [f64; 3],
    /// Lower corner of the scaled data box; 0.0 on unused axes.
    z_lo: [f64; 3],
    /// Extent of the scaled data box per axis; 1.0 on unused axes.
    z_range: [f64; 3],
    /// Sobolev smoothness that sets the per-level penalty weights.
    sobolev_s: f64,
    /// Nested levels, coarse to fine.
    levels: Vec<Level>,
    /// Union of all level centers in insertion order (scaled coordinates).
    net: Vec<[f64; 3]>,

    // -- Observations --
    /// Observation weights.
    w: Vec<f64>,
    /// Responses.
    y: Vec<f64>,
    /// Observation locations in scaled coordinates.
    z: Vec<[f64; 3]>,

    // -- Design matrix X in CSR form, one row per observation --
    /// Total number of coefficients: `dim + 1` polynomial columns plus all centers.
    m: usize,
    /// Row start offsets into `col_idx` / `vals` (length `n + 1`).
    row_ptr: Vec<usize>,
    /// Column index of each stored entry.
    col_idx: Vec<u32>,
    /// Value of each stored entry.
    vals: Vec<f64>,

    // -- Cached normal-equation pieces --
    /// `X'Wy`.
    rhs: Vec<f64>,
    /// `y'Wy`.
    ytwy: f64,
    /// Diagonal of `X'WX`.
    gram_diag: Vec<f64>,
    /// Upper triangle of `X'WX` (row-major, `m x m`), present only when `m <= DENSE_GRAM_MAX`.
    dense_gram: Option<Vec<f64>>,

    // -- Penalty D --
    /// Diagonal penalty: 0 on polynomial columns, the level weight on center columns.
    pen_diag: Vec<f64>,
    /// Sum of `ln(weight)` over all penalized columns.
    pen_logdet_const: f64,
}
/// How the log-determinant of the penalized normal matrix was obtained.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LogdetMethod {
    /// Exact, from a dense Cholesky factorization.
    DenseExact,
    /// Stochastic Lanczos quadrature estimate.
    Slq,
}

/// Numerical diagnostics attached to a fit.
#[derive(Clone, Copy, Debug)]
pub struct CascadeCertificate {
    /// Relative residual of the coefficient solve (0.0 on the dense Cholesky path).
    pub solve_rel_residual: f64,
    /// Conjugate-gradient iterations used (0 on the dense Cholesky path).
    pub solve_iters: usize,
    /// Method used for the log-determinant.
    pub logdet_method: LogdetMethod,
}

/// Outcome of the adaptive level refinement in `fit_residual_cascade`.
#[derive(Clone, Copy, Debug)]
pub struct RefinementCertificate {
    /// Bound on the gain from one more level (0.0 when no bound could be computed).
    pub next_level_gain_bound: f64,
    /// Tolerance the bound was compared against (`REFINE_TOL * rss_pen`).
    pub tolerance: f64,
    /// `true` when refinement stopped without the bound certifying the tolerance, including
    /// every case where no bound could be computed.
    pub exhausted: bool,
}
/// Immutable multilevel design (basis, weights, cached normal-equation pieces), shared by
/// reference count with every fit produced from it.
pub struct ResidualCascadeDesign {
    core: Arc<Core>,
}

/// Result of fitting a [`ResidualCascadeDesign`] at one smoothing parameter.
pub struct ResidualCascadeFit {
    /// The design this fit was computed on.
    core: Arc<Core>,
    /// Coefficients: `dim + 1` polynomial terms followed by every level's centers.
    pub coeff: Vec<f64>,
    /// Natural log of the smoothing parameter `lambda`.
    pub log_lambda: f64,
    /// Noise variance (profiled from the data or supplied by the caller).
    pub sigma2: f64,
    /// Restricted (REML) log-likelihood at `log_lambda`, up to an additive constant.
    pub restricted_loglik: f64,
    /// Penalized weighted residual sum of squares.
    pub rss_pen: f64,
    /// Diagnostics of the linear solve and log-determinant.
    pub certificate: CascadeCertificate,
    /// Refinement diagnostics: `None` from `fit_at`, filled in by `fit_residual_cascade`.
    pub refinement: Option<RefinementCertificate>,
}
impl Core {
fn scale_point(&self, x: &[f64]) -> [f64; 3] {
let mut z = [0.0_f64; 3];
for a in 0..self.dim {
z[a] = self.metric[a] * x[a] - self.z_lo[a];
}
z
}
fn basis_row_scaled(&self, z: &[f64; 3]) -> Vec<(usize, f64)> {
let mut row = Vec::with_capacity(self.dim + 1 + self.levels.len() * 8);
row.push((0, 1.0));
for a in 0..self.dim {
row.push((a + 1, 2.0 * z[a] / self.z_range[a] - 1.0));
}
for level in &self.levels {
let start = row.len();
level.grid.for_neighbors(z, |j| {
let c = &level.centers[j as usize];
let r = dist2(z, c, self.dim).sqrt() / level.delta;
let v = wendland(r);
if v > 0.0 {
row.push((level.col_offset + j as usize, v));
}
});
row[start..].sort_unstable_by_key(|&(col, _)| col);
}
row
}
fn matvec(&self, lambda: f64, v: &[f64], out: &mut [f64]) {
for (o, (&d, &x)) in out.iter_mut().zip(self.pen_diag.iter().zip(v.iter())) {
*o = lambda * d * x;
}
for i in 0..self.w.len() {
let lo = self.row_ptr[i];
let hi = self.row_ptr[i + 1];
let mut t = 0.0;
for e in lo..hi {
t += self.vals[e] * v[self.col_idx[e] as usize];
}
t *= self.w[i];
for e in lo..hi {
out[self.col_idx[e] as usize] += self.vals[e] * t;
}
}
}
fn precond_diag(&self, lambda: f64) -> Vec<f64> {
self.gram_diag
.iter()
.zip(self.pen_diag.iter())
.map(|(&g, &d)| g + lambda * d)
.collect()
}
fn pcg(
&self,
lambda: f64,
b: &[f64],
warm: Option<&[f64]>,
) -> Result<(Vec<f64>, f64, usize), String> {
let m = self.m;
let prec = self.precond_diag(lambda);
for (j, &p) in prec.iter().enumerate() {
if !(p.is_finite() && p > EIG_FLOOR) {
return Err(format!(
"residual cascade: non-positive preconditioner diagonal {p} at column {j}"
));
}
}
let b_norm = b.iter().map(|v| v * v).sum::<f64>().sqrt();
if b_norm == 0.0 {
return Ok((vec![0.0; m], 0.0, 0));
}
let mut x = match warm {
Some(x0) => x0.to_vec(),
None => vec![0.0; m],
};
let mut r = vec![0.0; m];
self.matvec(lambda, &x, &mut r);
for (ri, &bi) in r.iter_mut().zip(b.iter()) {
*ri = bi - *ri;
}
let mut zv: Vec<f64> = r.iter().zip(prec.iter()).map(|(&ri, &p)| ri / p).collect();
let mut p_dir = zv.clone();
let mut rz: f64 = r.iter().zip(zv.iter()).map(|(&a, &c)| a * c).sum();
let mut ap = vec![0.0; m];
for iter in 0..CG_MAX_ITERS {
let r_norm = r.iter().map(|v| v * v).sum::<f64>().sqrt();
if r_norm <= CG_RTOL * b_norm {
return Ok((x, r_norm / b_norm, iter));
}
self.matvec(lambda, &p_dir, &mut ap);
let pap: f64 = p_dir.iter().zip(ap.iter()).map(|(&a, &c)| a * c).sum();
if !(pap.is_finite() && pap > 0.0) {
return Err(format!(
"residual cascade: CG curvature breakdown (p'Ap = {pap}) at iteration {iter}"
));
}
let alpha = rz / pap;
for j in 0..m {
x[j] += alpha * p_dir[j];
r[j] -= alpha * ap[j];
}
for j in 0..m {
zv[j] = r[j] / prec[j];
}
let rz_new: f64 = r.iter().zip(zv.iter()).map(|(&a, &c)| a * c).sum();
let beta = rz_new / rz;
rz = rz_new;
for j in 0..m {
p_dir[j] = zv[j] + beta * p_dir[j];
}
}
Err(format!(
"residual cascade: CG failed to reach relative residual {CG_RTOL} within \
{CG_MAX_ITERS} iterations (the norm-equivalence preconditioner should make this \
n-independent; this indicates a degenerate design)"
))
}
fn dense_system(&self, lambda: f64) -> Option<Vec<f64>> {
let gram = self.dense_gram.as_ref()?;
let m = self.m;
let mut a = vec![0.0; m * m];
for i in 0..m {
for j in i..m {
let mut v = gram[i * m + j];
if i == j {
v += lambda * self.pen_diag[i];
}
a[i * m + j] = v;
a[j * m + i] = v;
}
}
Some(a)
}
fn logdet_dense(&self, lambda: f64) -> Result<f64, String> {
let mut a = self.dense_system(lambda).ok_or_else(|| {
format!(
"residual cascade: dense logdet requested past the sizing cap \
(m = {} > {DENSE_GRAM_MAX})",
self.m
)
})?;
cholesky_logdet(&mut a, self.m)
}
fn logdet_slq(&self, lambda: f64) -> Result<f64, String> {
let m = self.m;
let prec = self.precond_diag(lambda);
let mut logdet = 0.0;
for (j, &p) in prec.iter().enumerate() {
if !(p.is_finite() && p > EIG_FLOOR) {
return Err(format!(
"residual cascade: non-positive diagonal {p} at column {j} in SLQ"
));
}
logdet += p.ln();
}
let sqrt_p: Vec<f64> = prec.iter().map(|&p| p.sqrt()).collect();
let mut scratch_in = vec![0.0; m];
let mut scratch_out = vec![0.0; m];
let mut trace_est = 0.0;
let steps = SLQ_LANCZOS_STEPS.min(m);
let mut basis: Vec<Vec<f64>> = Vec::with_capacity(steps);
for probe in 0..SLQ_PROBES {
let mut rng =
SplitMix64::new(RNG_SEED ^ (probe as u64).wrapping_mul(0xD134_2543_DE82_EF95));
let mut q = vec![0.0; m];
for qj in q.iter_mut() {
*qj = rng.next_sign();
}
let z_norm2 = m as f64;
let inv_norm = 1.0 / (m as f64).sqrt();
for qj in q.iter_mut() {
*qj *= inv_norm;
}
basis.clear();
let mut alpha = Vec::with_capacity(steps);
let mut beta: Vec<f64> = Vec::with_capacity(steps);
let mut q_prev: Option<Vec<f64>> = None;
for _step in 0..steps {
for j in 0..m {
scratch_in[j] = q[j] / sqrt_p[j];
}
self.matvec(lambda, &scratch_in, &mut scratch_out);
let mut v: Vec<f64> = (0..m).map(|j| scratch_out[j] / sqrt_p[j]).collect();
let a: f64 = v.iter().zip(q.iter()).map(|(&x, &y)| x * y).sum();
alpha.push(a);
for j in 0..m {
v[j] -= a * q[j];
}
if let Some(prev) = &q_prev {
let b_prev = beta.last().copied().unwrap_or(0.0);
for j in 0..m {
v[j] -= b_prev * prev[j];
}
}
basis.push(q.clone());
for qb in &basis {
let proj: f64 = v.iter().zip(qb.iter()).map(|(&x, &y)| x * y).sum();
for j in 0..m {
v[j] -= proj * qb[j];
}
}
let b: f64 = v.iter().map(|x| x * x).sum::<f64>().sqrt();
if !(b.is_finite()) {
return Err("residual cascade: Lanczos breakdown (non-finite norm)".into());
}
if b < 1e-13 {
break;
}
beta.push(b);
q_prev = Some(std::mem::replace(&mut q, v));
for qj in q.iter_mut() {
*qj /= b;
}
}
beta.truncate(alpha.len().saturating_sub(1));
let (theta, tau) = symmetric_tridiagonal_eigen(&alpha, &beta)?;
let mut quad = 0.0;
for (&t, &w0) in theta.iter().zip(tau.iter()) {
if !(t.is_finite() && t > EIG_FLOOR) {
return Err(format!(
"residual cascade: non-positive Ritz value {t} in SLQ (system not PD)"
));
}
quad += w0 * w0 * t.ln();
}
trace_est += z_norm2 * quad;
}
Ok(logdet + trace_est / SLQ_PROBES as f64)
}
fn logdet(&self, lambda: f64) -> Result<(f64, LogdetMethod), String> {
if self.dense_gram.is_some() {
Ok((self.logdet_dense(lambda)?, LogdetMethod::DenseExact))
} else {
Ok((self.logdet_slq(lambda)?, LogdetMethod::Slq))
}
}
fn solve_coeff(
&self,
lambda: f64,
b: &[f64],
warm: Option<&[f64]>,
) -> Result<(Vec<f64>, f64, usize), String> {
if let Some(mut a) = self.dense_system(lambda) {
cholesky_logdet(&mut a, self.m)?;
return Ok((chol_solve(&a, self.m, b), 0.0, 0));
}
self.pcg(lambda, b, warm)
}
fn rss_pen(&self, coeff: &[f64]) -> f64 {
let mut quad = 0.0;
for (c, r) in coeff.iter().zip(self.rhs.iter()) {
quad += c * r;
}
self.ytwy - quad
}
fn nullity(&self) -> usize {
self.dim + 1
}
fn residuals(&self, coeff: &[f64]) -> Vec<f64> {
let n = self.y.len();
let mut r = Vec::with_capacity(n);
for i in 0..n {
let mut fit = 0.0;
for e in self.row_ptr[i]..self.row_ptr[i + 1] {
fit += self.vals[e] * coeff[self.col_idx[e] as usize];
}
r.push(self.y[i] - fit);
}
r
}
}
/// Eigen-decomposition of the symmetric tridiagonal matrix with diagonal `d` and
/// off-diagonal `e` (`e[i]` couples `d[i]` and `d[i + 1]`), by QL iteration with implicit
/// shifts (in the style of the classic `tqli` routine).
///
/// Returns `(eigenvalues, first)`, where `first[k]` is the first component of the unit
/// eigenvector for `eigenvalues[k]`. Only that first row of the eigenvector matrix is
/// tracked, which is all the Gauss-quadrature weights of SLQ need. Eigenvalues come back in
/// the order the iteration leaves them, not sorted.
///
/// Requires `e.len() >= d.len() - 1` for non-empty `d` (slicing panics otherwise); extra
/// entries of `e` are ignored.
///
/// # Errors
/// Returns `Err` if any eigenvalue needs more than 60 QL sweeps.
fn symmetric_tridiagonal_eigen(d: &[f64], e: &[f64]) -> Result<(Vec<f64>, Vec<f64>), String> {
let n = d.len();
if n == 0 {
return Ok((Vec::new(), Vec::new()));
}
let mut diag = d.to_vec();
// off[n - 1] stays 0 as a sentinel so the split search always terminates.
let mut off = vec![0.0; n];
off[..n - 1].copy_from_slice(&e[..n - 1]);
// First row of the accumulated rotation (eigenvector) matrix, starting from the identity.
let mut first = vec![0.0; n];
first[0] = 1.0;
for l in 0..n {
let mut iter = 0;
loop {
// Find the first negligible off-diagonal at or after `l`; it bounds the unreduced block.
let mut msplit = n - 1;
for mm in l..n - 1 {
let dd = diag[mm].abs() + diag[mm + 1].abs();
if off[mm].abs() <= f64::EPSILON * dd {
msplit = mm;
break;
}
}
// diag[l] is decoupled: this eigenvalue has converged.
if msplit == l {
break;
}
iter += 1;
if iter > 60 {
return Err("residual cascade: tridiagonal QL failed to converge".into());
}
// Shift taken from the leading 2x2 block of the unreduced part.
let mut g = (diag[l + 1] - diag[l]) / (2.0 * off[l]);
let mut r = g.hypot(1.0);
g = diag[msplit] - diag[l] + off[l] / (g + r.copysign(g));
let (mut s, mut c) = (1.0, 1.0);
let mut p = 0.0;
let mut broke_early = false;
// Sweep Givens rotations from msplit - 1 down to l.
for i in (l..msplit).rev() {
let mut f = s * off[i];
let b = c * off[i];
r = f.hypot(g);
off[i + 1] = r;
if r == 0.0 {
// Degenerate rotation: deflate and restart the sweep for this `l`.
diag[i + 1] -= p;
off[msplit] = 0.0;
broke_early = true;
break;
}
s = f / r;
c = g / r;
g = diag[i + 1] - p;
r = (diag[i] - g) * s + 2.0 * c * b;
p = s * r;
diag[i + 1] = g + p;
g = c * r - b;
// Apply the same rotation to the tracked first row of the eigenvector matrix.
f = first[i + 1];
first[i + 1] = s * first[i] + c * f;
first[i] = c * first[i] - s * f;
}
if broke_early {
continue;
}
diag[l] -= p;
off[l] = g;
off[msplit] = 0.0;
}
}
Ok((diag, first))
}
/// Greedily extends the covering net `net` with spacing `h`.
///
/// Every point of `points` farther than `h` from all current net members (including members
/// added earlier in this call) is appended to `net`. Returns the newly added points in
/// insertion order.
fn extend_net(net: &mut Vec<[f64; 3]>, points: &[[f64; 3]], dim: usize, h: f64) -> Vec<[f64; 3]> {
    let radius2 = h * h;
    let mut index = HashGrid::new(h, dim);
    net.iter()
        .enumerate()
        .for_each(|(slot, center)| index.insert(slot as u32, center));
    let mut added = Vec::new();
    for point in points {
        let mut hit = false;
        index.for_neighbors(point, |slot| {
            hit = hit || dist2(point, &net[slot as usize], dim) <= radius2;
        });
        if hit {
            continue;
        }
        let slot = net.len() as u32;
        net.push(*point);
        index.insert(slot, point);
        added.push(*point);
    }
    added
}
impl ResidualCascadeDesign {
    /// Builds the multilevel design for scattered data in 2 or 3 dimensions.
    ///
    /// Arguments:
    /// - `xs`: one coordinate slice per axis, each of length `n = y.len()`;
    /// - `y`: the responses;
    /// - `w`: positive observation weights;
    /// - `metric`: a positive scale per axis, applied before any geometry;
    /// - `sobolev_s`: the smoothness that sets the per-level penalty weights;
    /// - `levels`: the number of nested nets. Level `l` has spacing
    ///   `H0_FRACTION * max_range / 2^l` and support radius `OVERLAP` times that.
    ///
    /// The Gram matrix is also cached densely when the coefficient count is at most
    /// `DENSE_GRAM_MAX`.
    ///
    /// # Errors
    /// Returns `Err` for any of the following:
    /// - an axis count other than 2 or 3;
    /// - mismatched lengths, or too few rows (`n <= dim + 1`);
    /// - an invalid `metric`, or `sobolev_s` outside `(d/2, (d+3)/2]`;
    /// - `levels` outside `1..=MAX_LEVELS`;
    /// - non-finite or non-positive inputs, or a degenerate axis;
    /// - a total center count above `MAX_CENTERS`.
    pub fn build(
        xs: &[&[f64]],
        y: &[f64],
        w: &[f64],
        metric: &[f64],
        sobolev_s: f64,
        levels: usize,
    ) -> Result<Self, String> {
        let dim = xs.len();
        if !(dim == 2 || dim == 3) {
            return Err(format!(
                "residual cascade: built for scattered 2-3D smooths, got {dim} axes"
            ));
        }
        let n = y.len();
        if w.len() != n || xs.iter().any(|x| x.len() != n) {
            return Err(format!(
                "residual cascade: length mismatch (y={n}, w={}, axes={:?})",
                w.len(),
                xs.iter().map(|x| x.len()).collect::<Vec<_>>()
            ));
        }
        if n <= dim + 1 {
            return Err(format!(
                "residual cascade: needs more than {} rows for the profiled REML degrees of \
                 freedom, got {n}",
                dim + 1
            ));
        }
        if metric.len() != dim || metric.iter().any(|&s| !(s.is_finite() && s > 0.0)) {
            return Err(format!(
                "residual cascade: metric must be {dim} finite positive scales, got {metric:?}"
            ));
        }
        if !(sobolev_s > dim as f64 / 2.0 && sobolev_s <= (dim as f64 + 3.0) / 2.0) {
            return Err(format!(
                "residual cascade: sobolev_s must lie in (d/2, (d+3)/2] = ({}, {}] for the \
                 Wendland-(3,1) bump, got {sobolev_s}",
                dim as f64 / 2.0,
                (dim as f64 + 3.0) / 2.0
            ));
        }
        if levels == 0 || levels > MAX_LEVELS {
            return Err(format!(
                "residual cascade: levels must be in 1..={MAX_LEVELS}, got {levels}"
            ));
        }
        for i in 0..n {
            if !(y[i].is_finite() && w[i].is_finite() && w[i] > 0.0)
                || xs.iter().any(|x| !x[i].is_finite())
            {
                return Err(format!(
                    "residual cascade: non-finite or non-positive input at row {i}"
                ));
            }
        }
        // Bounding box of the metric-scaled coordinates.
        let mut z_lo = [f64::INFINITY; 3];
        let mut z_hi = [f64::NEG_INFINITY; 3];
        for a in 0..dim {
            for &v in xs[a] {
                let s = metric[a] * v;
                z_lo[a] = z_lo[a].min(s);
                z_hi[a] = z_hi[a].max(s);
            }
        }
        let mut z_range = [1.0_f64; 3];
        let mut max_range = 0.0_f64;
        for a in 0..dim {
            if !(z_hi[a] > z_lo[a]) {
                return Err(format!(
                    "residual cascade: degenerate axis {a} bounding box [{}, {}]",
                    z_lo[a], z_hi[a]
                ));
            }
            z_range[a] = z_hi[a] - z_lo[a];
            max_range = max_range.max(z_range[a]);
        }
        for a in dim..3 {
            z_lo[a] = 0.0;
        }
        let z: Vec<[f64; 3]> = (0..n)
            .map(|i| {
                let mut p = [0.0_f64; 3];
                for a in 0..dim {
                    p[a] = metric[a] * xs[a][i] - z_lo[a];
                }
                p
            })
            .collect();
        let mut metric3 = [1.0_f64; 3];
        metric3[..dim].copy_from_slice(metric);
        // Nested nets: each level halves the spacing and adds only uncovered data points.
        let h0 = H0_FRACTION * max_range;
        let mut net: Vec<[f64; 3]> = Vec::new();
        let mut level_specs = Vec::with_capacity(levels);
        let mut col = dim + 1;
        let mut pen_logdet_const = 0.0;
        for l in 0..levels {
            let h = h0 * 0.5_f64.powi(l as i32);
            let new_centers = extend_net(&mut net, &z, dim, h);
            if net.len() > MAX_CENTERS {
                return Err(format!(
                    "residual cascade: center cap {MAX_CENTERS} exceeded at level {l}"
                ));
            }
            let weight = level_weight(l, sobolev_s, dim);
            pen_logdet_const += new_centers.len() as f64 * weight.ln();
            let delta = OVERLAP * h;
            let mut grid = HashGrid::new(delta, dim);
            for (j, c) in new_centers.iter().enumerate() {
                grid.insert(j as u32, c);
            }
            let col_offset = col;
            col += new_centers.len();
            level_specs.push(Level {
                h,
                delta,
                weight,
                centers: new_centers,
                col_offset,
                grid,
            });
        }
        let m = col;
        // Assemble X in CSR form together with X'Wy, diag(X'WX) and y'Wy.
        let mut row_ptr = Vec::with_capacity(n + 1);
        row_ptr.push(0_usize);
        let mut col_idx: Vec<u32> = Vec::new();
        let mut vals: Vec<f64> = Vec::new();
        let mut rhs = vec![0.0_f64; m];
        let mut gram_diag = vec![0.0_f64; m];
        let mut ytwy = 0.0_f64;
        let probe_core = CoreScaffold {
            dim,
            z_range,
            levels: &level_specs,
        };
        for i in 0..n {
            let row = probe_core.basis_row(&z[i]);
            for &(c, v) in &row {
                col_idx.push(c as u32);
                vals.push(v);
                rhs[c] += w[i] * y[i] * v;
                gram_diag[c] += w[i] * v * v;
            }
            ytwy += w[i] * y[i] * y[i];
            row_ptr.push(col_idx.len());
        }
        let mut pen_diag = vec![0.0_f64; m];
        for level in &level_specs {
            for j in 0..level.centers.len() {
                pen_diag[level.col_offset + j] = level.weight;
            }
        }
        // Rows have ascending columns, so `eb >= ea` fills only the upper triangle.
        let dense_gram = if m <= DENSE_GRAM_MAX {
            let mut gram = vec![0.0_f64; m * m];
            for i in 0..n {
                let lo = row_ptr[i];
                let hi = row_ptr[i + 1];
                for ea in lo..hi {
                    let ca = col_idx[ea] as usize;
                    let va = w[i] * vals[ea];
                    for eb in ea..hi {
                        gram[ca * m + col_idx[eb] as usize] += va * vals[eb];
                    }
                }
            }
            Some(gram)
        } else {
            None
        };
        Ok(ResidualCascadeDesign {
            core: Arc::new(Core {
                dim,
                metric: metric3,
                z_lo,
                z_range,
                sobolev_s,
                levels: level_specs,
                net,
                m,
                row_ptr,
                col_idx,
                vals,
                w: w.to_vec(),
                y: y.to_vec(),
                z,
                rhs,
                ytwy,
                gram_diag,
                pen_diag,
                pen_logdet_const,
                dense_gram,
            }),
        })
    }

    /// Number of levels in the design.
    pub fn num_levels(&self) -> usize {
        self.core.levels.len()
    }

    /// Total number of coefficients (polynomial terms plus centers).
    pub fn num_coeffs(&self) -> usize {
        self.core.m
    }

    /// Total number of centers over all levels.
    pub fn num_centers(&self) -> usize {
        self.core.m - self.core.nullity()
    }

    /// Centers introduced at `level`, mapped back to the caller's original coordinates.
    ///
    /// # Panics
    /// Panics if `level >= self.num_levels()`.
    pub fn centers(&self, level: usize) -> Vec<Vec<f64>> {
        let lv = &self.core.levels[level];
        lv.centers
            .iter()
            .map(|c| {
                (0..self.core.dim)
                    .map(|a| (c[a] + self.core.z_lo[a]) / self.core.metric[a])
                    .collect()
            })
            .collect()
    }

    /// Support radius of `level`'s bumps, in metric-scaled units.
    ///
    /// # Panics
    /// Panics if `level >= self.num_levels()`.
    pub fn support_radius(&self, level: usize) -> f64 {
        self.core.levels[level].delta
    }

    /// Penalty weight of `level`'s coefficients.
    ///
    /// # Panics
    /// Panics if `level >= self.num_levels()`.
    pub fn level_weight(&self, level: usize) -> f64 {
        self.core.levels[level].weight
    }

    /// Sparse basis row `(column, value)` at the raw point `x`.
    ///
    /// # Errors
    /// Fails unless `x` has exactly `dim` finite coordinates.
    pub fn basis_row(&self, x: &[f64]) -> Result<Vec<(usize, f64)>, String> {
        self.check_point(x)?;
        Ok(self.core.basis_row_scaled(&self.core.scale_point(x)))
    }

    /// Validates that `x` has exactly `dim` finite coordinates.
    fn check_point(&self, x: &[f64]) -> Result<(), String> {
        if x.len() != self.core.dim || x.iter().any(|v| !v.is_finite()) {
            return Err(format!(
                "residual cascade: point must be {} finite coordinates, got {x:?}",
                self.core.dim
            ));
        }
        Ok(())
    }

    /// Penalty `coeff' D coeff` (without the `lambda` factor).
    ///
    /// # Errors
    /// Fails if `coeff` does not have `num_coeffs()` entries.
    pub fn penalty_value(&self, coeff: &[f64]) -> Result<f64, String> {
        if coeff.len() != self.core.m {
            return Err(format!(
                "residual cascade: coefficient length {} != {}",
                coeff.len(),
                self.core.m
            ));
        }
        Ok(coeff
            .iter()
            .zip(self.core.pen_diag.iter())
            .map(|(&c, &d)| d * c * c)
            .sum())
    }

    /// Exact `log det(X'WX + lambda D)`. Fails past `DENSE_GRAM_MAX` coefficients.
    pub fn logdet_exact(&self, log_lambda: f64) -> Result<f64, String> {
        self.core.logdet_dense(log_lambda.exp())
    }

    /// Stochastic Lanczos quadrature estimate of `log det(X'WX + lambda D)`.
    pub fn logdet_slq(&self, log_lambda: f64) -> Result<f64, String> {
        self.core.logdet_slq(log_lambda.exp())
    }

    /// Profiled restricted log-likelihood at `log_lambda`, up to an additive constant.
    ///
    /// The result is `-0.5 * (log|A| - r log lambda - log|D|_+ + dof * ln sigma2_hat)`,
    /// where `A = X'WX + lambda D` and `sigma2_hat = rss_pen / dof`. It omits the constant
    /// `dof` term that [`Self::fit_at`] includes through `rss_pen / sigma2`.
    ///
    /// # Errors
    /// Non-finite `log_lambda`, solve/log-determinant failures, or `rss_pen <= 0`.
    pub fn criterion(&self, log_lambda: f64) -> Result<f64, String> {
        if !log_lambda.is_finite() {
            return Err(format!(
                "residual cascade: non-finite log lambda {log_lambda}"
            ));
        }
        let core = &self.core;
        let lambda = log_lambda.exp();
        let (coeff, _, _) = core.solve_coeff(lambda, &core.rhs, None)?;
        let rss_pen = core.rss_pen(&coeff);
        if !(rss_pen > 0.0) {
            return Err(format!(
                "residual cascade: degenerate penalized residual {rss_pen}"
            ));
        }
        let (logdet, _) = core.logdet(lambda)?;
        let dof = (core.y.len() - core.nullity()) as f64;
        let r = (core.m - core.nullity()) as f64;
        let sigma2 = rss_pen / dof;
        Ok(-0.5 * (logdet - r * log_lambda - core.pen_logdet_const + dof * sigma2.ln()))
    }

    /// Fits the coefficients at `log_lambda`.
    ///
    /// With `sigma2 = None` the noise variance is profiled as `rss_pen / dof`; otherwise the
    /// supplied positive finite value is used in the likelihood.
    ///
    /// # Errors
    /// Non-finite `log_lambda`, an invalid supplied `sigma2`, a non-positive profiled
    /// residual, or solve/log-determinant failures.
    pub fn fit_at(
        &self,
        log_lambda: f64,
        sigma2: Option<f64>,
    ) -> Result<ResidualCascadeFit, String> {
        if !log_lambda.is_finite() {
            return Err(format!(
                "residual cascade: non-finite log lambda {log_lambda}"
            ));
        }
        let core = &self.core;
        let lambda = log_lambda.exp();
        let (coeff, rel_res, iters) = core.solve_coeff(lambda, &core.rhs, None)?;
        let rss_pen = core.rss_pen(&coeff);
        let dof = (core.y.len() - core.nullity()) as f64;
        let sigma2 = match sigma2 {
            Some(s) => {
                if !(s.is_finite() && s > 0.0) {
                    return Err(format!("residual cascade: invalid sigma2 {s}"));
                }
                s
            }
            None => {
                if !(rss_pen > 0.0) {
                    return Err(format!(
                        "residual cascade: degenerate penalized residual {rss_pen}"
                    ));
                }
                rss_pen / dof
            }
        };
        let (logdet, logdet_method) = core.logdet(lambda)?;
        let r = (core.m - core.nullity()) as f64;
        let restricted_loglik = -0.5
            * (logdet - r * log_lambda - core.pen_logdet_const
                + dof * sigma2.ln()
                + rss_pen / sigma2);
        Ok(ResidualCascadeFit {
            core: Arc::clone(&self.core),
            coeff,
            log_lambda,
            sigma2,
            restricted_loglik,
            rss_pen,
            certificate: CascadeCertificate {
                solve_rel_residual: rel_res,
                solve_iters: iters,
                logdet_method,
            },
            refinement: None,
        })
    }

    /// Chooses `log_lambda` by maximizing [`Self::criterion`] and fits there.
    ///
    /// The search runs in three steps:
    /// 1. a coarse scan of `LOG_LAMBDA_GRID` equally spaced points on
    ///    `[LOG_LAMBDA_LO, LOG_LAMBDA_HI]`;
    /// 2. golden-section search on the bracket around the best grid point, down to
    ///    `LOG_LAMBDA_TOL`;
    /// 3. [`Self::fit_at`] at the bracket midpoint.
    ///
    /// A point where the criterion cannot be evaluated is treated as infinitely bad instead
    /// of aborting the whole search. At an extreme lambda, PCG may not converge or the
    /// penalized residual may cancel to `<= 0`. If the final fit fails, the best grid point
    /// is used instead.
    ///
    /// # Errors
    /// Returns the first evaluation error if no grid point could be evaluated, or the error
    /// from fitting the best grid point when both final fits fail.
    pub fn fit_reml(&self) -> Result<ResidualCascadeFit, String> {
        let mut best_i = 0usize;
        let mut best_v = f64::NEG_INFINITY;
        let mut any_ok = false;
        let mut first_err: Option<String> = None;
        let step = (LOG_LAMBDA_HI - LOG_LAMBDA_LO) / (LOG_LAMBDA_GRID - 1) as f64;
        for i in 0..LOG_LAMBDA_GRID {
            let ll = LOG_LAMBDA_LO + step * i as f64;
            match self.criterion(ll) {
                Ok(v) => {
                    any_ok = true;
                    if v > best_v {
                        best_v = v;
                        best_i = i;
                    }
                }
                Err(e) => {
                    if first_err.is_none() {
                        first_err = Some(e);
                    }
                }
            }
        }
        if !any_ok {
            return Err(first_err.unwrap_or_else(|| {
                "residual cascade: REML grid scan produced no evaluations".to_string()
            }));
        }
        // Failed evaluations inside the bracket rank below every successful one.
        let eval = |ll: f64| self.criterion(ll).unwrap_or(f64::NEG_INFINITY);
        let mut lo = LOG_LAMBDA_LO + step * best_i.saturating_sub(1) as f64;
        let mut hi = (LOG_LAMBDA_LO + step * (best_i + 1) as f64).min(LOG_LAMBDA_HI);
        let inv_phi = 0.618_033_988_749_894_9_f64;
        let mut x1 = hi - inv_phi * (hi - lo);
        let mut x2 = lo + inv_phi * (hi - lo);
        let mut f1 = eval(x1);
        let mut f2 = eval(x2);
        while hi - lo > LOG_LAMBDA_TOL {
            if f1 < f2 {
                lo = x1;
                x1 = x2;
                f1 = f2;
                x2 = lo + inv_phi * (hi - lo);
                f2 = eval(x2);
            } else {
                hi = x2;
                x2 = x1;
                f2 = f1;
                x1 = hi - inv_phi * (hi - lo);
                f1 = eval(x1);
            }
        }
        let grid_best = LOG_LAMBDA_LO + step * best_i as f64;
        self.fit_at(0.5 * (lo + hi), None)
            .or_else(|_| self.fit_at(grid_best, None))
    }

    /// Bound on the penalized-RSS reduction available from one more (finer) level.
    ///
    /// The next level's candidate centers are formed at half the finest spacing. With
    /// `g_j = sum_i w_i r_i phi_j(z_i)` (residuals `r` of `fit`), the result is
    /// `||g||^2 / (lambda * d_next)`, where `d_next` is the next level's penalty weight.
    ///
    /// Returns `Ok(None)` when no bound is available:
    /// - the level cap `MAX_LEVELS` is reached;
    /// - no new candidate center exists;
    /// - the center cap `MAX_CENTERS` would be exceeded.
    ///
    /// # Errors
    /// Fails if `fit` was not produced from this design.
    pub fn next_level_gain_bound(&self, fit: &ResidualCascadeFit) -> Result<Option<f64>, String> {
        let core = &self.core;
        if !Arc::ptr_eq(core, &fit.core) {
            return Err("residual cascade: fit does not belong to this design".into());
        }
        let next_l = core.levels.len();
        if next_l >= MAX_LEVELS {
            return Ok(None);
        }
        let h = core.levels[next_l - 1].h * 0.5;
        let mut net = core.net.clone();
        let candidates = extend_net(&mut net, &core.z, core.dim, h);
        if candidates.is_empty() || net.len() > MAX_CENTERS {
            return Ok(None);
        }
        let delta = OVERLAP * h;
        let mut grid = HashGrid::new(delta, core.dim);
        for (j, c) in candidates.iter().enumerate() {
            grid.insert(j as u32, c);
        }
        let r = core.residuals(&fit.coeff);
        let mut g = vec![0.0_f64; candidates.len()];
        for (i, zi) in core.z.iter().enumerate() {
            let wr = core.w[i] * r[i];
            grid.for_neighbors(zi, |j| {
                let rad = dist2(zi, &candidates[j as usize], core.dim).sqrt() / delta;
                g[j as usize] += wr * wendland(rad);
            });
        }
        let g2: f64 = g.iter().map(|v| v * v).sum();
        let d_next = level_weight(next_l, core.sobolev_s, core.dim);
        Ok(Some(g2 / (fit.log_lambda.exp() * d_next)))
    }
}
/// Penalty weight of level `l`: `4^(l * (sobolev_s - dim / 2))`.
///
/// The spacing halves per level, so this grows like `h_l^{-(2s - d)}`. Level 0 always has
/// weight 1.
fn level_weight(l: usize, sobolev_s: f64, dim: usize) -> f64 {
    let excess_smoothness = sobolev_s - 0.5 * dim as f64;
    let exponent = excess_smoothness * l as f64;
    4.0_f64.powf(exponent)
}
/// Borrowed view holding just what is needed to evaluate basis rows. It lets `build` assemble
/// the design matrix before the owning `Core` exists.
struct CoreScaffold<'a> {
    /// Number of active axes.
    dim: usize,
    /// Extent of the scaled data box per axis, used to centre the linear terms on `[-1, 1]`.
    z_range: [f64; 3],
    /// Levels whose bumps contribute to each row.
    levels: &'a [Level],
}

impl CoreScaffold<'_> {
    /// Sparse basis row `(column, value)` at the scaled point `z`.
    ///
    /// The row holds:
    /// - column 0: the intercept;
    /// - columns `1..=dim`: the linear terms `2 z_a / range_a - 1`;
    /// - then, level by level, every bump with a non-zero value at `z`, sorted by column
    ///   within each level.
    fn basis_row(&self, z: &[f64; 3]) -> Vec<(usize, f64)> {
        let dim = self.dim;
        let mut entries = Vec::with_capacity(dim + 1 + self.levels.len() * 8);
        entries.push((0, 1.0));
        entries.extend((0..dim).map(|a| (a + 1, 2.0 * z[a] / self.z_range[a] - 1.0)));
        for level in self.levels {
            let first = entries.len();
            level.grid.for_neighbors(z, |j| {
                let center = &level.centers[j as usize];
                let value = wendland(dist2(z, center, dim).sqrt() / level.delta);
                if value > 0.0 {
                    entries.push((level.col_offset + j as usize, value));
                }
            });
            entries[first..].sort_unstable_by_key(|&(col, _)| col);
        }
        entries
    }
}
impl ResidualCascadeFit {
    /// Posterior mean and variance of the fitted surface at the raw point `x`.
    ///
    /// Returns `(mean, variance)` with `mean = b(x)' coeff` and
    /// `variance = sigma2 * b(x)' (X'WX + lambda D)^{-1} b(x)`, where `b(x)` is the basis row.
    ///
    /// # Errors
    /// Fails unless `x` has exactly `dim` finite coordinates, or if the linear solve fails.
    pub fn predict(&self, x: &[f64]) -> Result<(f64, f64), String> {
        let core = &*self.core;
        let well_formed = x.len() == core.dim && x.iter().all(|v| v.is_finite());
        if !well_formed {
            return Err(format!(
                "residual cascade: prediction point must be {} finite coordinates, got {x:?}",
                core.dim
            ));
        }
        let z = core.scale_point(x);
        let mut mean = 0.0;
        let mut probe = vec![0.0_f64; core.m];
        for (col, val) in core.basis_row_scaled(&z) {
            mean += val * self.coeff[col];
            probe[col] += val;
        }
        let (solved, _, _) = core.solve_coeff(self.log_lambda.exp(), &probe, None)?;
        let mut quad = 0.0;
        for (p, s) in probe.iter().zip(&solved) {
            quad += p * s;
        }
        Ok((mean, self.sigma2 * quad))
    }

    /// Draws `n_samples` coefficient vectors from the Gaussian posterior by perturbation.
    ///
    /// Each draw solves
    /// `(X'WX + lambda D) c = X'Wy + sigma X'W^{1/2} e1 + sigma sqrt(lambda) D^{1/2} e2`
    /// with standard-normal `e1`, `e2`. The draws therefore have mean `coeff` and covariance
    /// `sigma2 (X'WX + lambda D)^{-1}`. The random stream is seeded from `RNG_SEED`, so
    /// repeated calls return the same draws.
    ///
    /// # Errors
    /// Propagates the first failure of the linear solve.
    pub fn sample_coefficients(&self, n_samples: usize) -> Result<Vec<Vec<f64>>, String> {
        let core = &*self.core;
        let lambda = self.log_lambda.exp();
        let sigma = self.sigma2.sqrt();
        let prior_scale = sigma * lambda.sqrt();
        let mut rng = SplitMix64::new(RNG_SEED ^ 0xA11C_E5A_u64);
        (0..n_samples)
            .map(|_| {
                let mut b = core.rhs.clone();
                // Data-noise perturbation: sigma * X' W^{1/2} e1.
                for (i, &wi) in core.w.iter().enumerate() {
                    let f = sigma * wi.sqrt() * rng.next_normal();
                    for e in core.row_ptr[i]..core.row_ptr[i + 1] {
                        b[core.col_idx[e] as usize] += f * core.vals[e];
                    }
                }
                // Prior perturbation on penalized columns only.
                for (bj, &dj) in b.iter_mut().zip(&core.pen_diag) {
                    if dj > 0.0 {
                        *bj += prior_scale * dj.sqrt() * rng.next_normal();
                    }
                }
                core.solve_coeff(lambda, &b, Some(&self.coeff))
                    .map(|(c, _, _)| c)
            })
            .collect()
    }

    /// Number of levels in the underlying design.
    pub fn num_levels(&self) -> usize {
        self.core.levels.len()
    }

    /// Total number of coefficients in the underlying design.
    pub fn num_coeffs(&self) -> usize {
        self.core.m
    }
}
/// Fits a residual cascade with adaptive depth.
///
/// The procedure starts from `INITIAL_LEVELS` and, at each depth:
/// - builds the design and fits it by REML;
/// - computes the next-level gain bound;
/// - adds a level while that bound exceeds `REFINE_TOL * rss_pen`.
///
/// The returned fit carries a [`RefinementCertificate`]. Its `exhausted` flag is `true`
/// when no bound was available, or when the level cap was hit without meeting the tolerance.
///
/// # Errors
/// Propagates errors from `build`, `fit_reml` and `next_level_gain_bound`.
pub fn fit_residual_cascade(
    xs: &[&[f64]],
    y: &[f64],
    w: &[f64],
    metric: &[f64],
    sobolev_s: f64,
) -> Result<ResidualCascadeFit, String> {
    let mut levels = INITIAL_LEVELS;
    let (mut fit, certificate) = loop {
        let design = ResidualCascadeDesign::build(xs, y, w, metric, sobolev_s, levels)?;
        let fit = design.fit_reml()?;
        let gain = design.next_level_gain_bound(&fit)?;
        let tolerance = REFINE_TOL * fit.rss_pen;
        let certificate = match gain {
            None => RefinementCertificate {
                next_level_gain_bound: 0.0,
                tolerance,
                exhausted: true,
            },
            Some(bound) if bound <= tolerance || levels >= MAX_LEVELS => RefinementCertificate {
                next_level_gain_bound: bound,
                tolerance,
                exhausted: bound > tolerance,
            },
            Some(_) => {
                levels += 1;
                continue;
            }
        };
        break (fit, certificate);
    };
    fit.refinement = Some(certificate);
    Ok(fit)
}