/// Dimension of the penalty null space: the affine functions `a + b·x1 + c·x2` have zero
/// second derivatives and are therefore unpenalized. Subtracted from both the row count and
/// the coefficient count in the profiled REML degrees of freedom.
const PENALTY_NULLITY: usize = 3;
/// Number of equally spaced points (both endpoints included) in the coarse log-λ scan that
/// seeds the REML golden-section search.
const LOG_LAMBDA_GRID: usize = 25;
/// Lower end of the log-λ search range.
const LOG_LAMBDA_LO: f64 = -18.0;
/// Upper end of the log-λ search range.
const LOG_LAMBDA_HI: f64 = 18.0;
/// The golden-section search stops once its log-λ bracket is narrower than this.
const LOG_LAMBDA_TOL: f64 = 1e-7;
/// Smallest Cholesky pivot accepted before the penalized system is reported as not positive
/// definite.
const PIVOT_FLOOR: f64 = 1e-300;
/// Largest accepted cell count `k` per axis. The solver factors a dense `(k+3)² × (k+3)²`
/// system, so its cost grows like `(k+3)⁶`.
const MAX_CELLS_PER_AXIS: usize = 32;
/// 4-point Gauss–Legendre nodes on `[-1, 1]`, used for the per-cell penalty quadrature. The
/// rule is exact for polynomials of degree ≤ 7, which covers the per-axis degree-≤ 6 integrand
/// of the curvature penalty on cubic pieces.
const GL4_NODES: [f64; 4] = [
    -0.861_136_311_594_052_6,
    -0.339_981_043_584_856_26,
    0.339_981_043_584_856_26,
    0.861_136_311_594_052_6,
];
/// Weights matching `GL4_NODES`; they sum to 2, the length of `[-1, 1]`.
const GL4_WEIGHTS: [f64; 4] = [
    0.347_854_845_137_453_85,
    0.652_145_154_862_546_2,
    0.652_145_154_862_546_2,
    0.347_854_845_137_453_85,
];
/// Values of the four uniform cubic B-spline pieces that are active on one cell, evaluated at
/// the local coordinate `u` (normally in `[0, 1]`; values outside extrapolate the cubics).
///
/// Entry `e` is the piece of basis function `cell + e`. The four entries sum to 1 for any `u`.
#[inline]
fn bspline_value(u: f64) -> [f64; 4] {
    let w = 1.0 - u;
    let left = w * w * w;
    let mid_left = 3.0 * u * u * u - 6.0 * u * u + 4.0;
    let mid_right = -3.0 * u * u * u + 3.0 * u * u + 3.0 * u + 1.0;
    let right = u * u * u;
    [left, mid_left, mid_right, right].map(|piece| piece / 6.0)
}
/// First derivatives (with respect to the local coordinate `u`) of the four cubic pieces
/// returned by `bspline_value`, in the same order. The four entries sum to 0 for any `u`.
#[inline]
fn bspline_d1(u: f64) -> [f64; 4] {
    let w = 1.0 - u;
    let first = -0.5 * w * w;
    let second = 0.5 * (3.0 * u * u - 4.0 * u);
    let third = 0.5 * (-3.0 * u * u + 2.0 * u + 1.0);
    let fourth = 0.5 * u * u;
    [first, second, third, fourth]
}
/// Second derivatives (with respect to the local coordinate `u`) of the four cubic pieces
/// returned by `bspline_value`, in the same order. Each entry is linear in `u`.
#[inline]
fn bspline_d2(u: f64) -> [f64; 4] {
    let three_u = 3.0 * u;
    [1.0 - u, three_u - 2.0, 1.0 - three_u, u]
}
/// One coordinate axis of the spline grid: `cells` equal cells of width `h` starting at `lo`.
#[derive(Clone, Copy)]
struct Axis {
    /// Left edge of the first cell.
    lo: f64,
    /// Width of every cell.
    h: f64,
    /// Number of cells; must be at least 1.
    cells: usize,
}

impl Axis {
    /// Maps `x` to `(cell, u)`: the cell index clamped into `0..cells`, and the local coordinate
    /// `u = (x - lo) / h - cell`.
    ///
    /// Points left of `lo` land in cell 0 and points at or right of the last edge land in the
    /// last cell, so `u` then falls outside `[0, 1]`.
    #[inline]
    fn locate(&self, x: f64) -> (usize, f64) {
        let scaled = (x - self.lo) / self.h;
        let floored = scaled.floor().max(0.0) as usize;
        let last_cell = self.cells - 1;
        let cell = if floored > last_cell { last_cell } else { floored };
        (cell, scaled - cell as f64)
    }
}
/// Locates `x` on a `cells`-cell axis that starts at `lo` with cell width `h`.
///
/// Returns the clamped cell index together with the four cubic B-spline values at the local
/// coordinate inside that cell. Basis function `cell + e` gets value `e`. `cells` must be at
/// least 1.
pub fn axis_basis_at(lo: f64, h: f64, cells: usize, x: f64) -> (usize, [f64; 4]) {
    let axis = Axis { lo, h, cells };
    let (cell, local) = axis.locate(x);
    let basis = bspline_value(local);
    (cell, basis)
}
/// Returns the 16 non-zero tensor-product basis values at `(x1, x2)` with their coefficient
/// indices.
///
/// Entry `4 * i + j` pairs x1-basis `cell1 + i` with x2-basis `cell2 + j`. Its coefficient index
/// is `(cell1 + i) * m_axis + (cell2 + j)`, so the returned indices are strictly increasing.
#[inline]
fn basis_row(axes: &[Axis; 2], m_axis: usize, x1: f64, x2: f64) -> ([usize; 16], [f64; 16]) {
    let (row_cell, u_row) = axes[0].locate(x1);
    let (col_cell, u_col) = axes[1].locate(x2);
    let along_x1 = bspline_value(u_row);
    let along_x2 = bspline_value(u_col);
    let mut indices = [0usize; 16];
    let mut weights = [0f64; 16];
    for e in 0..16 {
        let (i, j) = (e / 4, e % 4);
        indices[e] = (row_cell + i) * m_axis + col_cell + j;
        weights[e] = along_x1[i] * along_x2[j];
    }
    (indices, weights)
}
/// Factors the symmetric `p × p` row-major matrix `a` in place as `L Lᵀ` and returns
/// `log det a = 2 Σ ln L_jj`.
///
/// Only the lower triangle of `a` is read. On success `a` holds `L` with its strict upper
/// triangle zeroed. On failure `a` is left partially overwritten.
///
/// # Errors
/// Returns `Err` when a pivot is non-finite or not above `PIVOT_FLOOR`.
pub(crate) fn cholesky_logdet(a: &mut [f64], p: usize) -> Result<f64, String> {
    let mut logdet = 0.0;
    for j in 0..p {
        let row_j = j * p;
        // Pivot a_jj − Σ_{t<j} L_jt², subtracting in increasing t.
        let s = a[row_j..row_j + j]
            .iter()
            .fold(a[row_j + j], |acc, &l_jt| acc - l_jt * l_jt);
        if !(s.is_finite() && s > PIVOT_FLOOR) {
            return Err(format!(
                "grid spline 2d: penalized system not positive definite at pivot {j} (value {s})"
            ));
        }
        let l_jj = s.sqrt();
        a[row_j + j] = l_jj;
        logdet += 2.0 * l_jj.ln();
        for i in j + 1..p {
            // Row j lies strictly before row i, so splitting at row i lets us read row j while
            // updating entry (i, j).
            let (above, from_i) = a.split_at_mut(i * p);
            let l_row_j = &above[row_j..row_j + j];
            let (l_row_i, rest) = from_i.split_at_mut(j);
            let reduced = l_row_j
                .iter()
                .zip(l_row_i.iter())
                .fold(rest[0], |acc, (&l_jt, &l_it)| acc - l_it * l_jt);
            rest[0] = reduced / l_jj;
        }
    }
    for i in 0..p {
        a[i * p + i + 1..(i + 1) * p].fill(0.0);
    }
    Ok(logdet)
}
/// Solves `(L Lᵀ) x = b` given the row-major `p × p` lower-triangular factor `l`.
///
/// Uses forward substitution (`L y = b`) followed by back substitution (`Lᵀ x = y`). Only the
/// lower triangle of `l` is read.
pub(crate) fn chol_solve(l: &[f64], p: usize, b: &[f64]) -> Vec<f64> {
    let mut x = b.to_vec();
    // Forward sweep: row i of L against the already-solved prefix.
    for i in 0..p {
        let row = &l[i * p..i * p + i];
        let acc = row
            .iter()
            .zip(&x[..i])
            .fold(x[i], |acc, (&l_it, &y_t)| acc - l_it * y_t);
        x[i] = acc / l[i * p + i];
    }
    // Backward sweep: column i of L (= row i of Lᵀ) against the already-solved suffix.
    for i in (0..p).rev() {
        let acc = (i + 1..p).fold(x[i], |acc, t| acc - l[t * p + i] * x[t]);
        x[i] = acc / l[i * p + i];
    }
    x
}
/// Sufficient statistics for a penalized tensor-product cubic B-spline fit on a 2-D grid.
///
/// Produced by `build` / `build_multi`. The observations themselves are not retained, only the
/// weighted normal-equation quantities below.
pub struct GridSpline2dDesign {
    /// Grid geometry for the `x1` and `x2` axes.
    axes: [Axis; 2],
    /// Basis functions per axis (`k + 3`).
    m_axis: usize,
    /// Total number of coefficients, `m_axis²`; coefficient `(i, j)` has index `i * m_axis + j`.
    p: usize,
    /// Half bandwidth of the Gram and penalty matrices (`3 * m_axis + 3`).
    band_half: usize,
    /// Upper band of the weighted Gram matrix `XᵀWX`: entry `(g, g + d)` is stored at
    /// `g * (band_half + 1) + d`.
    gram_band: Vec<f64>,
    /// Upper band of the curvature penalty matrix, same layout as `gram_band`.
    pen_band: Vec<f64>,
    /// `XᵀW y_d` for each response dimension `d`.
    rhs: Vec<Vec<f64>>,
    /// Row-major `D × D` weighted cross moments `y_dᵀ W y_e`.
    cross_moments: Vec<f64>,
    /// Number of observations accumulated.
    n_obs: usize,
}
/// Result of factoring and solving the penalized system `XᵀWX + λP` at one smoothing parameter.
struct Solved {
    /// Lower Cholesky factor of `XᵀWX + λP` (row-major, strict upper triangle zeroed).
    chol: Vec<f64>,
    /// `log det(XᵀWX + λP)`.
    logdet: f64,
    /// Penalized coefficient vector for each response dimension.
    coeffs: Vec<Vec<f64>>,
    /// Penalized residual sum of squares `y_dᵀ W y_d − bᵀc` for each response dimension, with
    /// `b = XᵀW y_d` and `c` the matching coefficients.
    rss_pen: Vec<f64>,
}
impl GridSpline2dDesign {
pub fn build(
x1: &[f64],
x2: &[f64],
y: &[f64],
w: &[f64],
k: usize,
metric: [f64; 2],
) -> Result<Self, String> {
Self::build_multi(x1, x2, &[y], w, k, metric)
}
pub fn build_multi(
x1: &[f64],
x2: &[f64],
responses: &[&[f64]],
w: &[f64],
k: usize,
metric: [f64; 2],
) -> Result<Self, String> {
let n = x1.len();
if responses.is_empty() {
return Err("grid spline 2d: no response dimensions supplied".to_string());
}
if x2.len() != n || w.len() != n {
return Err(format!(
"grid spline 2d: length mismatch x1={n}, x2={}, w={}",
x2.len(),
w.len()
));
}
for (d, y) in responses.iter().enumerate() {
if y.len() != n {
return Err(format!(
"grid spline 2d: response dimension {d} has length {} != {n}",
y.len()
));
}
}
if n <= PENALTY_NULLITY {
return Err(format!(
"grid spline 2d: needs more than {PENALTY_NULLITY} rows for the profiled REML \
degrees of freedom, got {n}"
));
}
if k == 0 || k > MAX_CELLS_PER_AXIS {
return Err(format!(
"grid spline 2d: k must be in 1..={MAX_CELLS_PER_AXIS} (dense Cholesky on \
(k+3)² coefficients — see module sizing contract), got {k}"
));
}
if !(metric[0].is_finite() && metric[0] > 0.0 && metric[1].is_finite() && metric[1] > 0.0) {
return Err(format!(
"grid spline 2d: metric diagonal must be finite and positive, got [{}, {}]",
metric[0], metric[1]
));
}
for i in 0..n {
if !(x1[i].is_finite() && x2[i].is_finite()) || !(w[i] > 0.0) || !w[i].is_finite() {
return Err(format!(
"grid spline 2d: non-finite or non-positive input at row {i} \
(x1={}, x2={}, w={})",
x1[i], x2[i], w[i]
));
}
for (d, y) in responses.iter().enumerate() {
if !y[i].is_finite() {
return Err(format!(
"grid spline 2d: non-finite response at row {i}, dimension {d} ({})",
y[i]
));
}
}
}
let mut axes = [Axis {
lo: 0.0,
h: 1.0,
cells: k,
}; 2];
for (axis, xs) in axes.iter_mut().zip([x1, x2]) {
let mut lo = f64::INFINITY;
let mut hi = f64::NEG_INFINITY;
for &v in xs {
lo = lo.min(v);
hi = hi.max(v);
}
if !(hi > lo) {
return Err(format!(
"grid spline 2d: degenerate axis bounding box [{lo}, {hi}]"
));
}
axis.lo = lo;
axis.h = (hi - lo) / k as f64;
}
let m_axis = k + 3;
let p = m_axis * m_axis;
let band_half = 3 * m_axis + 3;
let stride = band_half + 1;
let n_dims = responses.len();
let mut gram_band = vec![0.0_f64; p * stride];
let mut rhs = vec![vec![0.0_f64; p]; n_dims];
let mut cross_moments = vec![0.0_f64; n_dims * n_dims];
for i in 0..n {
let (idx, val) = basis_row(&axes, m_axis, x1[i], x2[i]);
let wi = w[i];
for (d, y) in responses.iter().enumerate() {
let wy = wi * y[i];
for e in 0..16 {
rhs[d][idx[e]] += wy * val[e];
}
for (e, ye) in responses.iter().enumerate().skip(d) {
cross_moments[d * n_dims + e] += wy * ye[i];
}
}
for a in 0..16 {
let base = idx[a] * stride - idx[a];
let wa = wi * val[a];
for b in a..16 {
gram_band[base + idx[b]] += wa * val[b];
}
}
}
for d in 0..n_dims {
for e in 0..d {
cross_moments[d * n_dims + e] = cross_moments[e * n_dims + d];
}
}
let mut tab = [[[[0.0_f64; 4]; 4]; 3]; 2]; for ax in 0..2 {
let h = axes[ax].h;
for q in 0..4 {
let u = 0.5 * (1.0 + GL4_NODES[q]);
let v0 = bspline_value(u);
let v1 = bspline_d1(u);
let v2 = bspline_d2(u);
for e in 0..4 {
tab[ax][0][q][e] = v0[e];
tab[ax][1][q][e] = v1[e] / h;
tab[ax][2][q][e] = v2[e] / (h * h);
}
}
}
let s11 = metric[0] * metric[0];
let s12 = 2.0 * metric[0] * metric[1];
let s22 = metric[1] * metric[1];
let cell_area_jac = 0.25 * axes[0].h * axes[1].h; let mut pen_band = vec![0.0_f64; p * stride];
let mut r11 = [0.0_f64; 16];
let mut r12 = [0.0_f64; 16];
let mut r22 = [0.0_f64; 16];
let mut idx = [0usize; 16];
for c1 in 0..k {
for c2 in 0..k {
for i in 0..4 {
for j in 0..4 {
idx[4 * i + j] = (c1 + i) * m_axis + (c2 + j);
}
}
for q1 in 0..4 {
for q2 in 0..4 {
let wq = cell_area_jac * GL4_WEIGHTS[q1] * GL4_WEIGHTS[q2];
for i in 0..4 {
for j in 0..4 {
let e = 4 * i + j;
r11[e] = tab[0][2][q1][i] * tab[1][0][q2][j];
r12[e] = tab[0][1][q1][i] * tab[1][1][q2][j];
r22[e] = tab[0][0][q1][i] * tab[1][2][q2][j];
}
}
for a in 0..16 {
let base = idx[a] * stride - idx[a];
let (pa11, pa12, pa22) =
(wq * s11 * r11[a], wq * s12 * r12[a], wq * s22 * r22[a]);
for b in a..16 {
pen_band[base + idx[b]] +=
pa11 * r11[b] + pa12 * r12[b] + pa22 * r22[b];
}
}
}
}
}
}
Ok(GridSpline2dDesign {
axes,
m_axis,
p,
band_half,
gram_band,
pen_band,
rhs,
cross_moments,
n_obs: n,
})
}
pub fn num_cells(&self) -> usize {
self.axes[0].cells
}
pub fn basis_per_axis(&self) -> usize {
self.m_axis
}
pub fn num_coeffs(&self) -> usize {
self.p
}
pub fn lower_corner(&self) -> [f64; 2] {
[self.axes[0].lo, self.axes[1].lo]
}
pub fn cell_widths(&self) -> [f64; 2] {
[self.axes[0].h, self.axes[1].h]
}
pub fn num_rows(&self) -> usize {
self.n_obs
}
pub fn num_responses(&self) -> usize {
self.rhs.len()
}
pub fn axis_basis(&self, axis: usize, x: f64) -> Result<(usize, [f64; 4]), String> {
if axis > 1 {
return Err(format!("grid spline 2d: axis {axis} out of range"));
}
if !x.is_finite() {
return Err(format!("grid spline 2d: non-finite axis-{axis} point {x}"));
}
let ax = self.axes[axis];
Ok(axis_basis_at(ax.lo, ax.h, ax.cells, x))
}
pub fn penalty_value(&self, coeff: &[f64]) -> Result<f64, String> {
if coeff.len() != self.p {
return Err(format!(
"grid spline 2d: coefficient length {} != {}",
coeff.len(),
self.p
));
}
let stride = self.band_half + 1;
let mut j = 0.0;
for g in 0..self.p {
let dmax = self.band_half.min(self.p - 1 - g);
j += self.pen_band[g * stride] * coeff[g] * coeff[g];
for d in 1..=dmax {
j += 2.0 * self.pen_band[g * stride + d] * coeff[g] * coeff[g + d];
}
}
Ok(j)
}
fn dense_system(&self, lambda: f64) -> Vec<f64> {
let p = self.p;
let stride = self.band_half + 1;
let mut a = vec![0.0_f64; p * p];
for g in 0..p {
let dmax = self.band_half.min(p - 1 - g);
for d in 0..=dmax {
let v = self.gram_band[g * stride + d] + lambda * self.pen_band[g * stride + d];
a[g * p + g + d] = v;
a[(g + d) * p + g] = v;
}
}
a
}
fn solve_at(&self, log_lambda: f64) -> Result<Solved, String> {
if !log_lambda.is_finite() {
return Err(format!(
"grid spline 2d: non-finite log lambda {log_lambda}"
));
}
let mut a = self.dense_system(log_lambda.exp());
let logdet = cholesky_logdet(&mut a, self.p)?;
let n_dims = self.rhs.len();
let mut coeffs = Vec::with_capacity(n_dims);
let mut rss_pen = Vec::with_capacity(n_dims);
for (d, rhs) in self.rhs.iter().enumerate() {
let coeff = chol_solve(&a, self.p, rhs);
let mut quad = 0.0;
for g in 0..self.p {
quad += rhs[g] * coeff[g];
}
rss_pen.push(self.cross_moments[d * n_dims + d] - quad);
coeffs.push(coeff);
}
Ok(Solved {
chol: a,
logdet,
coeffs,
rss_pen,
})
}
fn criterion(&self, log_lambda: f64) -> Result<f64, String> {
let solved = self.solve_at(log_lambda)?;
let dof = (self.n_obs - PENALTY_NULLITY) as f64;
let r = (self.p - PENALTY_NULLITY) as f64;
let shared = solved.logdet - r * log_lambda;
let mut v = 0.0;
for &rss in &solved.rss_pen {
if !(rss > 0.0) {
return Err(format!(
"grid spline 2d: degenerate penalized residual {rss}"
));
}
v += shared + dof * (rss / dof).ln();
}
Ok(-0.5 * v)
}
pub fn fit_at(&self, log_lambda: f64, sigma2: Option<f64>) -> Result<GridSpline2dFit, String> {
let solved = self.solve_at(log_lambda)?;
let dof = (self.n_obs - PENALTY_NULLITY) as f64;
let mut sigma2_dims = Vec::with_capacity(solved.rss_pen.len());
for &rss in &solved.rss_pen {
match sigma2 {
Some(s) => {
if !(s.is_finite() && s > 0.0) {
return Err(format!("grid spline 2d: invalid sigma2 {s}"));
}
sigma2_dims.push(s);
}
None => {
if !(rss > 0.0) {
return Err(format!(
"grid spline 2d: degenerate penalized residual {rss}"
));
}
sigma2_dims.push(rss / dof);
}
}
}
let r = (self.p - PENALTY_NULLITY) as f64;
let mut restricted_loglik = 0.0;
for (d, &rss) in solved.rss_pen.iter().enumerate() {
restricted_loglik -= 0.5
* (solved.logdet - r * log_lambda
+ dof * sigma2_dims[d].ln()
+ rss / sigma2_dims[d]);
}
Ok(GridSpline2dFit {
coeffs: solved.coeffs,
log_lambda,
sigma2: sigma2_dims,
restricted_loglik,
chol: solved.chol,
axes: self.axes,
m_axis: self.m_axis,
})
}
pub fn fit_reml(&self) -> Result<GridSpline2dFit, String> {
let mut best_i = 0usize;
let mut best_v = f64::NEG_INFINITY;
let step = (LOG_LAMBDA_HI - LOG_LAMBDA_LO) / (LOG_LAMBDA_GRID - 1) as f64;
for i in 0..LOG_LAMBDA_GRID {
let ll = LOG_LAMBDA_LO + step * i as f64;
let v = self.criterion(ll)?;
if v > best_v {
best_v = v;
best_i = i;
}
}
let mut lo = LOG_LAMBDA_LO + step * best_i.saturating_sub(1) as f64;
let mut hi = (LOG_LAMBDA_LO + step * (best_i + 1) as f64).min(LOG_LAMBDA_HI);
let inv_phi = 0.618_033_988_749_894_9_f64;
let mut x1 = hi - inv_phi * (hi - lo);
let mut x2 = lo + inv_phi * (hi - lo);
let mut f1 = self.criterion(x1)?;
let mut f2 = self.criterion(x2)?;
while hi - lo > LOG_LAMBDA_TOL {
if f1 < f2 {
lo = x1;
x1 = x2;
f1 = f2;
x2 = lo + inv_phi * (hi - lo);
f2 = self.criterion(x2)?;
} else {
hi = x2;
x2 = x1;
f2 = f1;
x1 = hi - inv_phi * (hi - lo);
f1 = self.criterion(x1)?;
}
}
self.fit_at(0.5 * (lo + hi), None)
}
fn gram_quadratic(&self, a: &[f64], b: &[f64]) -> f64 {
let stride = self.band_half + 1;
let mut q = 0.0;
for g in 0..self.p {
let dmax = self.band_half.min(self.p - 1 - g);
q += self.gram_band[g * stride] * a[g] * b[g];
for d in 1..=dmax {
q += self.gram_band[g * stride + d] * (a[g] * b[g + d] + a[g + d] * b[g]);
}
}
q
}
pub fn posterior(&self, fit: &GridSpline2dFit) -> Result<GridSpline2dPosterior, String> {
let p = self.p;
let n_dims = self.rhs.len();
if fit.coeffs.len() != n_dims || fit.coeffs.iter().any(|c| c.len() != p) {
return Err(format!(
"grid spline 2d: posterior asked for a fit with {} dimensions of length {}, \
design has {n_dims} of {p}",
fit.coeffs.len(),
fit.coeffs.first().map_or(0, Vec::len)
));
}
let mut unit_covariance = vec![0.0_f64; p * p];
let mut e_g = vec![0.0_f64; p];
for g in 0..p {
e_g[g] = 1.0;
let col = chol_solve(&fit.chol, p, &e_g);
e_g[g] = 0.0;
for (r, &v) in col.iter().enumerate() {
unit_covariance[r * p + g] = v;
}
}
let stride = self.band_half + 1;
let mut edf = 0.0;
for g in 0..p {
let dmax = self.band_half.min(p - 1 - g);
edf += self.gram_band[g * stride] * unit_covariance[g * p + g];
for d in 1..=dmax {
edf += 2.0 * self.gram_band[g * stride + d] * unit_covariance[g * p + g + d];
}
}
let residual_df = self.n_obs as f64 - edf;
if !(residual_df >= 1.0) {
return Err(format!(
"grid spline 2d: too few rows for a scale estimate \
(n = {}, edf = {edf:.2}; need n − edf ≥ 1)",
self.n_obs
));
}
let mut residual_cross_cov = vec![0.0_f64; n_dims * n_dims];
for d in 0..n_dims {
for e in d..n_dims {
let mut cd_rhse = 0.0;
let mut ce_rhsd = 0.0;
for g in 0..p {
cd_rhse += fit.coeffs[d][g] * self.rhs[e][g];
ce_rhsd += fit.coeffs[e][g] * self.rhs[d][g];
}
let quad = self.gram_quadratic(&fit.coeffs[d], &fit.coeffs[e]);
let v =
(self.cross_moments[d * n_dims + e] - cd_rhse - ce_rhsd + quad) / residual_df;
residual_cross_cov[d * n_dims + e] = v;
residual_cross_cov[e * n_dims + d] = v;
}
}
Ok(GridSpline2dPosterior {
unit_covariance,
edf,
residual_df,
residual_cross_cov,
})
}
}
/// Posterior summaries returned by `GridSpline2dDesign::posterior`.
pub struct GridSpline2dPosterior {
    /// Dense row-major `p × p` matrix `(XᵀWX + λP)⁻¹`. The predictive variance in
    /// `GridSpline2dFit::predict` is this quadratic form scaled by `sigma2`.
    pub unit_covariance: Vec<f64>,
    /// Effective degrees of freedom `tr(XᵀWX (XᵀWX + λP)⁻¹)`.
    pub edf: f64,
    /// `n − edf`, the divisor used for `residual_cross_cov`.
    pub residual_df: f64,
    /// Row-major `D × D` weighted residual cross-products of the response dimensions divided by
    /// `residual_df`.
    pub residual_cross_cov: Vec<f64>,
}
/// A fitted grid spline at one smoothing parameter, produced by `GridSpline2dDesign::fit_at`
/// or `GridSpline2dDesign::fit_reml`.
pub struct GridSpline2dFit {
    /// Coefficients per response dimension, each of length `(k + 3)²` and indexed
    /// `i * (k + 3) + j`.
    pub coeffs: Vec<Vec<f64>>,
    /// The `log λ` used for this fit.
    pub log_lambda: f64,
    /// Residual variance per response dimension: either supplied by the caller or estimated as
    /// penalized RSS / `(n − PENALTY_NULLITY)`.
    pub sigma2: Vec<f64>,
    /// Restricted log-likelihood summed over response dimensions, up to additive constants.
    pub restricted_loglik: f64,
    /// Lower Cholesky factor of `XᵀWX + λP` (row-major), used for predictive variances.
    chol: Vec<f64>,
    /// Grid geometry copied from the design.
    axes: [Axis; 2],
    /// Basis functions per axis (`k + 3`).
    m_axis: usize,
}
impl GridSpline2dFit {
    /// Predicts response dimension `dim` at `(x1, x2)`.
    ///
    /// Returns `(mean, variance)`, where `mean = xᵀ c_dim` and
    /// `variance = sigma2[dim] · xᵀ (XᵀWX + λP)⁻¹ x`, and `x` is the tensor-product basis row at
    /// the point. Points outside the fitted bounding box are extrapolated from the nearest cell.
    ///
    /// # Errors
    /// Returns `Err` when `dim` is out of range or the point is not finite.
    pub fn predict(&self, dim: usize, x1: f64, x2: f64) -> Result<(f64, f64), String> {
        if dim >= self.coeffs.len() {
            return Err(format!(
                "grid spline 2d: response dimension {dim} out of range (D = {})",
                self.coeffs.len()
            ));
        }
        if !(x1.is_finite() && x2.is_finite()) {
            return Err(format!(
                "grid spline 2d: non-finite prediction point ({x1}, {x2})"
            ));
        }
        let (idx, val) = basis_row(&self.axes, self.m_axis, x1, x2);
        let coeff = &self.coeffs[dim];
        let p = coeff.len();
        let mean = idx
            .iter()
            .zip(&val)
            .fold(0.0, |acc, (&g, &v)| acc + v * coeff[g]);

        let mut row = vec![0.0_f64; p];
        for (&g, &v) in idx.iter().zip(&val) {
            row[g] += v;
        }
        // With A = L Lᵀ, xᵀA⁻¹x = ‖L⁻¹x‖², so one forward substitution suffices and the result
        // is a sum of squares (never negative through round-off). `row` is zero before its
        // first non-zero index idx[0], hence so is L⁻¹x, and the sweep can start there.
        let start = idx[0];
        let chol = &self.chol;
        let mut z = vec![0.0_f64; p];
        let mut quad = 0.0;
        for i in start..p {
            let l_row = &chol[i * p..i * p + i];
            let s = (start..i).fold(row[i], |acc, t| acc - l_row[t] * z[t]);
            let zi = s / chol[i * p + i];
            z[i] = zi;
            quad += zi * zi;
        }
        Ok((mean, self.sigma2[dim] * quad))
    }
}
/// Builds a single-response design and fits it with the smoothing parameter chosen by REML.
///
/// # Errors
/// Propagates any error from `GridSpline2dDesign::build` or `GridSpline2dDesign::fit_reml`.
pub fn fit_grid_spline_2d(
    x1: &[f64],
    x2: &[f64],
    y: &[f64],
    w: &[f64],
    k: usize,
    metric: [f64; 2],
) -> Result<GridSpline2dFit, String> {
    let design = GridSpline2dDesign::build(x1, x2, y, w, k, metric)?;
    design.fit_reml()
}
/// Builds a single-response design and fits it at the given `log_lambda`.
///
/// `sigma2 = None` estimates the residual variance from the penalized RSS; `Some(s)` fixes it.
///
/// # Errors
/// Propagates any error from `GridSpline2dDesign::build` or `GridSpline2dDesign::fit_at`.
pub fn fit_grid_spline_2d_at(
    x1: &[f64],
    x2: &[f64],
    y: &[f64],
    w: &[f64],
    k: usize,
    metric: [f64; 2],
    log_lambda: f64,
    sigma2: Option<f64>,
) -> Result<GridSpline2dFit, String> {
    let design = GridSpline2dDesign::build(x1, x2, y, w, k, metric)?;
    design.fit_at(log_lambda, sigma2)
}