/// Dimension of the null space of the roughness penalty: the integrated
/// second-derivative penalty vanishes exactly on the affine functions `1, x1, x2`.
const PENALTY_NULLITY: usize = 3;

/// Number of evenly spaced `log λ` values scanned before golden-section refinement.
const LOG_LAMBDA_GRID: usize = 25;
/// Smallest scanned `log λ`.
const LOG_LAMBDA_LO: f64 = -18.0;
/// Largest scanned `log λ`.
const LOG_LAMBDA_HI: f64 = 18.0;
/// Golden-section search stops once its `log λ` bracket is narrower than this.
const LOG_LAMBDA_TOL: f64 = 1e-7;

/// A Cholesky pivot (before its square root) must exceed this to count as positive.
const PIVOT_FLOOR: f64 = 1e-300;

/// Largest number of cells per axis accepted when building a design.
const MAX_CELLS_PER_AXIS: usize = 32;

/// Positive nodes of the 4-point Gauss–Legendre rule on `[-1, 1]` (outer, inner).
const GL4_OUTER_NODE: f64 = 0.861_136_311_594_052_6;
const GL4_INNER_NODE: f64 = 0.339_981_043_584_856_26;
/// Weights paired with the outer and inner nodes.
const GL4_OUTER_WEIGHT: f64 = 0.347_854_845_137_453_85;
const GL4_INNER_WEIGHT: f64 = 0.652_145_154_862_546_2;

/// 4-point Gauss–Legendre nodes, ascending; the rule is symmetric about zero.
const GL4_NODES: [f64; 4] = [-GL4_OUTER_NODE, -GL4_INNER_NODE, GL4_INNER_NODE, GL4_OUTER_NODE];
/// Weights matching `GL4_NODES` element by element.
const GL4_WEIGHTS: [f64; 4] = [
    GL4_OUTER_WEIGHT,
    GL4_INNER_WEIGHT,
    GL4_INNER_WEIGHT,
    GL4_OUTER_WEIGHT,
];
87
/// Values of the four uniform cubic B-spline basis functions that are non-zero on
/// one cell, evaluated at the local coordinate `u` (`0` = left knot, `1` = right
/// knot), ordered from leftmost to rightmost.
#[inline]
fn bspline_value(u: f64) -> [f64; 4] {
    let w = 1.0 - u;
    let left = w * w * w / 6.0;
    let mid_left = (3.0 * u * u * u - 6.0 * u * u + 4.0) / 6.0;
    let mid_right = (-3.0 * u * u * u + 3.0 * u * u + 3.0 * u + 1.0) / 6.0;
    let right = u * u * u / 6.0;
    [left, mid_left, mid_right, right]
}
102
/// First derivatives with respect to the local coordinate `u` of the four cubic
/// B-spline basis functions produced by `bspline_value`.
///
/// The design builder divides these by the cell width to obtain derivatives in
/// data units.
#[inline]
fn bspline_d1(u: f64) -> [f64; 4] {
    let w = 1.0 - u;
    let d_left = -0.5 * w * w;
    let d_mid_left = 0.5 * (3.0 * u * u - 4.0 * u);
    let d_mid_right = 0.5 * (-3.0 * u * u + 2.0 * u + 1.0);
    let d_right = 0.5 * u * u;
    [d_left, d_mid_left, d_mid_right, d_right]
}
114
/// Second derivatives with respect to the local coordinate `u` of the four cubic
/// B-spline basis functions produced by `bspline_value`; each is linear in `u`.
#[inline]
fn bspline_d2(u: f64) -> [f64; 4] {
    let first = 1.0 - u;
    let second = 3.0 * u - 2.0;
    let third = 1.0 - 3.0 * u;
    [first, second, third, u]
}
121
/// One axis of the tensor-product grid: `cells` equal-width cells of width `h`
/// whose first cell starts at `lo`.
#[derive(Clone, Copy, Debug)]
struct Axis {
    /// Left edge of the first cell.
    lo: f64,
    /// Width of every cell.
    h: f64,
    /// Number of cells.
    cells: usize,
}

impl Axis {
    /// Maps `x` to `(cell, u)`: the cell whose polynomial pieces are used and the
    /// offset of `x` from that cell's left edge, measured in cell widths.
    ///
    /// Points left of the grid use cell `0` and points right of it use the last
    /// cell; `u` then lies outside `[0, 1]`, so the edge cell is extrapolated.
    #[inline]
    fn locate(&self, x: f64) -> (usize, f64) {
        let scaled = (x - self.lo) / self.h;
        let below = scaled.floor().max(0.0) as usize;
        let cell = usize::min(below, self.cells - 1);
        (cell, scaled - cell as f64)
    }
}
141
142pub fn axis_basis_at(lo: f64, h: f64, cells: usize, x: f64) -> (usize, [f64; 4]) {
148 let (cell, u) = Axis { lo, h, cells }.locate(x);
149 (cell, bspline_value(u))
150}
151
152#[inline]
155fn basis_row(axes: &[Axis; 2], m_axis: usize, x1: f64, x2: f64) -> ([usize; 16], [f64; 16]) {
156 let (c1, u1) = axes[0].locate(x1);
157 let (c2, u2) = axes[1].locate(x2);
158 let b1 = bspline_value(u1);
159 let b2 = bspline_value(u2);
160 let mut idx = [0usize; 16];
161 let mut val = [0f64; 16];
162 for i in 0..4 {
163 for j in 0..4 {
164 idx[4 * i + j] = (c1 + i) * m_axis + (c2 + j);
165 val[4 * i + j] = b1[i] * b2[j];
166 }
167 }
168 (idx, val)
169}
170
171pub fn cholesky_logdet(a: &mut [f64], p: usize) -> Result<f64, String> {
175 let mut logdet = 0.0;
176 for j in 0..p {
177 let mut s = a[j * p + j];
178 for t in 0..j {
179 s -= a[j * p + t] * a[j * p + t];
180 }
181 if !(s.is_finite() && s > PIVOT_FLOOR) {
182 return Err(format!(
183 "grid spline 2d: penalized system not positive definite at pivot {j} (value {s})"
184 ));
185 }
186 let l = s.sqrt();
187 a[j * p + j] = l;
188 logdet += 2.0 * l.ln();
189 for i in j + 1..p {
190 let mut s2 = a[i * p + j];
191 for t in 0..j {
192 s2 -= a[i * p + t] * a[j * p + t];
193 }
194 a[i * p + j] = s2 / l;
195 }
196 }
197 for i in 0..p {
198 for j in i + 1..p {
199 a[i * p + j] = 0.0;
200 }
201 }
202 Ok(logdet)
203}
204
/// Solves `(L Lᵀ) x = b` given the row-major lower-triangular Cholesky factor `L`
/// of order `p`, returning `x`.
///
/// Only the diagonal and strict lower triangle of `l` are read. The solve runs
/// a forward substitution with `L`, then a back substitution with `Lᵀ`.
pub fn chol_solve(l: &[f64], p: usize, b: &[f64]) -> Vec<f64> {
    let mut x = b.to_vec();
    // Forward substitution: L z = b (z overwrites x).
    for row in 0..p {
        let acc = (0..row).fold(x[row], |acc, col| acc - l[row * p + col] * x[col]);
        x[row] = acc / l[row * p + row];
    }
    // Back substitution: Lᵀ x = z, reading column `row` of L below the diagonal.
    for row in (0..p).rev() {
        let acc = (row + 1..p).fold(x[row], |acc, below| acc - l[below * p + row] * x[below]);
        x[row] = acc / l[row * p + row];
    }
    x
}
224
225pub struct GridSpline2dDesign {
228 axes: [Axis; 2],
229 m_axis: usize,
231 p: usize,
233 band_half: usize,
235 gram_band: Vec<f64>,
237 pen_band: Vec<f64>,
239 rhs: Vec<Vec<f64>>,
243 cross_moments: Vec<f64>,
246 n_obs: usize,
247}
248
/// Outcome of factorising and solving the penalised system at one `log λ`.
struct Solved {
    /// Row-major Cholesky factor of `G + λ S` (lower triangle; upper is zero).
    chol: Vec<f64>,
    /// `log det(G + λ S)`.
    logdet: f64,
    /// Coefficient vector per response dimension.
    coeffs: Vec<Vec<f64>>,
    /// Penalised residual sum of squares per response dimension.
    rss_pen: Vec<f64>,
}
258
259impl GridSpline2dDesign {
260 pub fn build(
262 x1: &[f64],
263 x2: &[f64],
264 y: &[f64],
265 w: &[f64],
266 k: usize,
267 metric: [f64; 2],
268 ) -> Result<Self, String> {
269 Self::build_multi(x1, x2, &[y], w, k, metric)
270 }
271
272 pub fn build_multi(
279 x1: &[f64],
280 x2: &[f64],
281 responses: &[&[f64]],
282 w: &[f64],
283 k: usize,
284 metric: [f64; 2],
285 ) -> Result<Self, String> {
286 let n = x1.len();
287 if responses.is_empty() {
288 return Err("grid spline 2d: no response dimensions supplied".to_string());
289 }
290 if x2.len() != n || w.len() != n {
291 return Err(format!(
292 "grid spline 2d: length mismatch x1={n}, x2={}, w={}",
293 x2.len(),
294 w.len()
295 ));
296 }
297 for (d, y) in responses.iter().enumerate() {
298 if y.len() != n {
299 return Err(format!(
300 "grid spline 2d: response dimension {d} has length {} != {n}",
301 y.len()
302 ));
303 }
304 }
305 if n <= PENALTY_NULLITY {
306 return Err(format!(
307 "grid spline 2d: needs more than {PENALTY_NULLITY} rows for the profiled REML \
308 degrees of freedom, got {n}"
309 ));
310 }
311 if k == 0 || k > MAX_CELLS_PER_AXIS {
312 return Err(format!(
313 "grid spline 2d: k must be in 1..={MAX_CELLS_PER_AXIS} (dense Cholesky on \
314 (k+3)² coefficients — see module sizing contract), got {k}"
315 ));
316 }
317 if !(metric[0].is_finite() && metric[0] > 0.0 && metric[1].is_finite() && metric[1] > 0.0) {
318 return Err(format!(
319 "grid spline 2d: metric diagonal must be finite and positive, got [{}, {}]",
320 metric[0], metric[1]
321 ));
322 }
323 for i in 0..n {
324 if !(x1[i].is_finite() && x2[i].is_finite()) || !(w[i] > 0.0) || !w[i].is_finite() {
325 return Err(format!(
326 "grid spline 2d: non-finite or non-positive input at row {i} \
327 (x1={}, x2={}, w={})",
328 x1[i], x2[i], w[i]
329 ));
330 }
331 for (d, y) in responses.iter().enumerate() {
332 if !y[i].is_finite() {
333 return Err(format!(
334 "grid spline 2d: non-finite response at row {i}, dimension {d} ({})",
335 y[i]
336 ));
337 }
338 }
339 }
340 let mut axes = [Axis {
341 lo: 0.0,
342 h: 1.0,
343 cells: k,
344 }; 2];
345 for (axis, xs) in axes.iter_mut().zip([x1, x2]) {
346 let mut lo = f64::INFINITY;
347 let mut hi = f64::NEG_INFINITY;
348 for &v in xs {
349 lo = lo.min(v);
350 hi = hi.max(v);
351 }
352 if !(hi > lo) {
353 return Err(format!(
354 "grid spline 2d: degenerate axis bounding box [{lo}, {hi}]"
355 ));
356 }
357 axis.lo = lo;
358 axis.h = (hi - lo) / k as f64;
359 }
360 let m_axis = k + 3;
361 let p = m_axis * m_axis;
362 let band_half = 3 * m_axis + 3;
363 let stride = band_half + 1;
364 let n_dims = responses.len();
365 let mut gram_band = vec![0.0_f64; p * stride];
366 let mut rhs = vec![vec![0.0_f64; p]; n_dims];
367 let mut cross_moments = vec![0.0_f64; n_dims * n_dims];
368
369 for i in 0..n {
374 let (idx, val) = basis_row(&axes, m_axis, x1[i], x2[i]);
375 let wi = w[i];
376 for (d, y) in responses.iter().enumerate() {
377 let wy = wi * y[i];
378 for e in 0..16 {
379 rhs[d][idx[e]] += wy * val[e];
380 }
381 for (e, ye) in responses.iter().enumerate().skip(d) {
382 cross_moments[d * n_dims + e] += wy * ye[i];
383 }
384 }
385 for a in 0..16 {
386 let base = idx[a] * stride - idx[a];
387 let wa = wi * val[a];
388 for b in a..16 {
389 gram_band[base + idx[b]] += wa * val[b];
390 }
391 }
392 }
393 for d in 0..n_dims {
394 for e in 0..d {
395 cross_moments[d * n_dims + e] = cross_moments[e * n_dims + d];
396 }
397 }
398
399 let mut tab = [[[[0.0_f64; 4]; 4]; 3]; 2]; for ax in 0..2 {
404 let h = axes[ax].h;
405 for q in 0..4 {
406 let u = 0.5 * (1.0 + GL4_NODES[q]);
407 let v0 = bspline_value(u);
408 let v1 = bspline_d1(u);
409 let v2 = bspline_d2(u);
410 for e in 0..4 {
411 tab[ax][0][q][e] = v0[e];
412 tab[ax][1][q][e] = v1[e] / h;
413 tab[ax][2][q][e] = v2[e] / (h * h);
414 }
415 }
416 }
417 let s11 = metric[0] * metric[0];
419 let s12 = 2.0 * metric[0] * metric[1];
420 let s22 = metric[1] * metric[1];
421 let cell_area_jac = 0.25 * axes[0].h * axes[1].h; let mut pen_band = vec![0.0_f64; p * stride];
423 let mut r11 = [0.0_f64; 16];
424 let mut r12 = [0.0_f64; 16];
425 let mut r22 = [0.0_f64; 16];
426 let mut idx = [0usize; 16];
427 for c1 in 0..k {
428 for c2 in 0..k {
429 for i in 0..4 {
430 for j in 0..4 {
431 idx[4 * i + j] = (c1 + i) * m_axis + (c2 + j);
432 }
433 }
434 for q1 in 0..4 {
435 for q2 in 0..4 {
436 let wq = cell_area_jac * GL4_WEIGHTS[q1] * GL4_WEIGHTS[q2];
437 for i in 0..4 {
438 for j in 0..4 {
439 let e = 4 * i + j;
440 r11[e] = tab[0][2][q1][i] * tab[1][0][q2][j];
441 r12[e] = tab[0][1][q1][i] * tab[1][1][q2][j];
442 r22[e] = tab[0][0][q1][i] * tab[1][2][q2][j];
443 }
444 }
445 for a in 0..16 {
446 let base = idx[a] * stride - idx[a];
447 let (pa11, pa12, pa22) =
448 (wq * s11 * r11[a], wq * s12 * r12[a], wq * s22 * r22[a]);
449 for b in a..16 {
450 pen_band[base + idx[b]] +=
451 pa11 * r11[b] + pa12 * r12[b] + pa22 * r22[b];
452 }
453 }
454 }
455 }
456 }
457 }
458
459 Ok(GridSpline2dDesign {
460 axes,
461 m_axis,
462 p,
463 band_half,
464 gram_band,
465 pen_band,
466 rhs,
467 cross_moments,
468 n_obs: n,
469 })
470 }
471
472 pub fn num_cells(&self) -> usize {
474 self.axes[0].cells
475 }
476
477 pub fn basis_per_axis(&self) -> usize {
479 self.m_axis
480 }
481
482 pub fn num_coeffs(&self) -> usize {
484 self.p
485 }
486
487 pub fn lower_corner(&self) -> [f64; 2] {
489 [self.axes[0].lo, self.axes[1].lo]
490 }
491
492 pub fn cell_widths(&self) -> [f64; 2] {
494 [self.axes[0].h, self.axes[1].h]
495 }
496
497 pub fn num_rows(&self) -> usize {
499 self.n_obs
500 }
501
502 pub fn num_responses(&self) -> usize {
504 self.rhs.len()
505 }
506
507 pub fn axis_basis(&self, axis: usize, x: f64) -> Result<(usize, [f64; 4]), String> {
513 if axis > 1 {
514 return Err(format!("grid spline 2d: axis {axis} out of range"));
515 }
516 if !x.is_finite() {
517 return Err(format!("grid spline 2d: non-finite axis-{axis} point {x}"));
518 }
519 let ax = self.axes[axis];
520 Ok(axis_basis_at(ax.lo, ax.h, ax.cells, x))
521 }
522
523 pub fn penalty_value(&self, coeff: &[f64]) -> Result<f64, String> {
526 if coeff.len() != self.p {
527 return Err(format!(
528 "grid spline 2d: coefficient length {} != {}",
529 coeff.len(),
530 self.p
531 ));
532 }
533 let stride = self.band_half + 1;
534 let mut j = 0.0;
535 for g in 0..self.p {
536 let dmax = self.band_half.min(self.p - 1 - g);
537 j += self.pen_band[g * stride] * coeff[g] * coeff[g];
538 for d in 1..=dmax {
539 j += 2.0 * self.pen_band[g * stride + d] * coeff[g] * coeff[g + d];
540 }
541 }
542 Ok(j)
543 }
544
545 fn dense_system(&self, lambda: f64) -> Vec<f64> {
547 let p = self.p;
548 let stride = self.band_half + 1;
549 let mut a = vec![0.0_f64; p * p];
550 for g in 0..p {
551 let dmax = self.band_half.min(p - 1 - g);
552 for d in 0..=dmax {
553 let v = self.gram_band[g * stride + d] + lambda * self.pen_band[g * stride + d];
554 a[g * p + g + d] = v;
555 a[(g + d) * p + g] = v;
556 }
557 }
558 a
559 }
560
561 fn solve_at(&self, log_lambda: f64) -> Result<Solved, String> {
562 if !log_lambda.is_finite() {
563 return Err(format!(
564 "grid spline 2d: non-finite log lambda {log_lambda}"
565 ));
566 }
567 let mut a = self.dense_system(log_lambda.exp());
568 let logdet = cholesky_logdet(&mut a, self.p)?;
569 let n_dims = self.rhs.len();
570 let mut coeffs = Vec::with_capacity(n_dims);
571 let mut rss_pen = Vec::with_capacity(n_dims);
572 for (d, rhs) in self.rhs.iter().enumerate() {
573 let coeff = chol_solve(&a, self.p, rhs);
574 let mut quad = 0.0;
575 for g in 0..self.p {
576 quad += rhs[g] * coeff[g];
577 }
578 rss_pen.push(self.cross_moments[d * n_dims + d] - quad);
579 coeffs.push(coeff);
580 }
581 Ok(Solved {
582 chol: a,
583 logdet,
584 coeffs,
585 rss_pen,
586 })
587 }
588
589 fn criterion(&self, log_lambda: f64) -> Result<f64, String> {
594 let solved = self.solve_at(log_lambda)?;
595 let dof = (self.n_obs - PENALTY_NULLITY) as f64;
596 let r = (self.p - PENALTY_NULLITY) as f64;
597 let shared = solved.logdet - r * log_lambda;
598 let mut v = 0.0;
599 for &rss in &solved.rss_pen {
600 if !(rss > 0.0) {
601 return Err(format!(
602 "grid spline 2d: degenerate penalized residual {rss}"
603 ));
604 }
605 v += shared + dof * (rss / dof).ln();
606 }
607 Ok(-0.5 * v)
608 }
609
610 pub fn fit_at(&self, log_lambda: f64, sigma2: Option<f64>) -> Result<GridSpline2dFit, String> {
613 let solved = self.solve_at(log_lambda)?;
614 let dof = (self.n_obs - PENALTY_NULLITY) as f64;
615 let mut sigma2_dims = Vec::with_capacity(solved.rss_pen.len());
616 for &rss in &solved.rss_pen {
617 match sigma2 {
618 Some(s) => {
619 if !(s.is_finite() && s > 0.0) {
620 return Err(format!("grid spline 2d: invalid sigma2 {s}"));
621 }
622 sigma2_dims.push(s);
623 }
624 None => {
625 if !(rss > 0.0) {
626 return Err(format!(
627 "grid spline 2d: degenerate penalized residual {rss}"
628 ));
629 }
630 sigma2_dims.push(rss / dof);
631 }
632 }
633 }
634 let r = (self.p - PENALTY_NULLITY) as f64;
639 let mut restricted_loglik = 0.0;
640 for (d, &rss) in solved.rss_pen.iter().enumerate() {
641 restricted_loglik -= 0.5
642 * (solved.logdet - r * log_lambda
643 + dof * sigma2_dims[d].ln()
644 + rss / sigma2_dims[d]);
645 }
646 Ok(GridSpline2dFit {
647 coeffs: solved.coeffs,
648 log_lambda,
649 sigma2: sigma2_dims,
650 restricted_loglik,
651 chol: solved.chol,
652 axes: self.axes,
653 m_axis: self.m_axis,
654 })
655 }
656
657 pub fn fit_reml(&self) -> Result<GridSpline2dFit, String> {
660 let mut best_i = 0usize;
661 let mut best_v = f64::NEG_INFINITY;
662 let step = (LOG_LAMBDA_HI - LOG_LAMBDA_LO) / (LOG_LAMBDA_GRID - 1) as f64;
663 for i in 0..LOG_LAMBDA_GRID {
664 let ll = LOG_LAMBDA_LO + step * i as f64;
665 let v = self.criterion(ll)?;
666 if v > best_v {
667 best_v = v;
668 best_i = i;
669 }
670 }
671 let mut lo = LOG_LAMBDA_LO + step * best_i.saturating_sub(1) as f64;
672 let mut hi = (LOG_LAMBDA_LO + step * (best_i + 1) as f64).min(LOG_LAMBDA_HI);
673 let inv_phi = 0.618_033_988_749_894_9_f64;
675 let mut x1 = hi - inv_phi * (hi - lo);
676 let mut x2 = lo + inv_phi * (hi - lo);
677 let mut f1 = self.criterion(x1)?;
678 let mut f2 = self.criterion(x2)?;
679 while hi - lo > LOG_LAMBDA_TOL {
680 if f1 < f2 {
681 lo = x1;
682 x1 = x2;
683 f1 = f2;
684 x2 = lo + inv_phi * (hi - lo);
685 f2 = self.criterion(x2)?;
686 } else {
687 hi = x2;
688 x2 = x1;
689 f2 = f1;
690 x1 = hi - inv_phi * (hi - lo);
691 f1 = self.criterion(x1)?;
692 }
693 }
694 self.fit_at(0.5 * (lo + hi), None)
695 }
696
697 fn gram_quadratic(&self, a: &[f64], b: &[f64]) -> f64 {
699 let stride = self.band_half + 1;
700 let mut q = 0.0;
701 for g in 0..self.p {
702 let dmax = self.band_half.min(self.p - 1 - g);
703 q += self.gram_band[g * stride] * a[g] * b[g];
704 for d in 1..=dmax {
705 q += self.gram_band[g * stride + d] * (a[g] * b[g + d] + a[g + d] * b[g]);
706 }
707 }
708 q
709 }
710
711 pub fn posterior(&self, fit: &GridSpline2dFit) -> Result<GridSpline2dPosterior, String> {
721 let p = self.p;
722 let n_dims = self.rhs.len();
723 if fit.coeffs.len() != n_dims || fit.coeffs.iter().any(|c| c.len() != p) {
724 return Err(format!(
725 "grid spline 2d: posterior asked for a fit with {} dimensions of length {}, \
726 design has {n_dims} of {p}",
727 fit.coeffs.len(),
728 fit.coeffs.first().map_or(0, Vec::len)
729 ));
730 }
731 let mut unit_covariance = vec![0.0_f64; p * p];
733 let mut e_g = vec![0.0_f64; p];
734 for g in 0..p {
735 e_g[g] = 1.0;
736 let col = chol_solve(&fit.chol, p, &e_g);
737 e_g[g] = 0.0;
738 for (r, &v) in col.iter().enumerate() {
739 unit_covariance[r * p + g] = v;
740 }
741 }
742 let stride = self.band_half + 1;
744 let mut edf = 0.0;
745 for g in 0..p {
746 let dmax = self.band_half.min(p - 1 - g);
747 edf += self.gram_band[g * stride] * unit_covariance[g * p + g];
748 for d in 1..=dmax {
749 edf += 2.0 * self.gram_band[g * stride + d] * unit_covariance[g * p + g + d];
750 }
751 }
752 let residual_df = self.n_obs as f64 - edf;
753 if !(residual_df >= 1.0) {
754 return Err(format!(
755 "grid spline 2d: too few rows for a scale estimate \
756 (n = {}, edf = {edf:.2}; need n − edf ≥ 1)",
757 self.n_obs
758 ));
759 }
760 let mut residual_cross_cov = vec![0.0_f64; n_dims * n_dims];
761 for d in 0..n_dims {
762 for e in d..n_dims {
763 let mut cd_rhse = 0.0;
764 let mut ce_rhsd = 0.0;
765 for g in 0..p {
766 cd_rhse += fit.coeffs[d][g] * self.rhs[e][g];
767 ce_rhsd += fit.coeffs[e][g] * self.rhs[d][g];
768 }
769 let quad = self.gram_quadratic(&fit.coeffs[d], &fit.coeffs[e]);
770 let v =
771 (self.cross_moments[d * n_dims + e] - cd_rhse - ce_rhsd + quad) / residual_df;
772 residual_cross_cov[d * n_dims + e] = v;
773 residual_cross_cov[e * n_dims + d] = v;
774 }
775 }
776 Ok(GridSpline2dPosterior {
777 unit_covariance,
778 edf,
779 residual_df,
780 residual_cross_cov,
781 })
782 }
783}
784
/// Posterior summaries returned by `GridSpline2dDesign::posterior`.
pub struct GridSpline2dPosterior {
    /// Row-major `p × p` inverse `(G + λS)⁻¹` of the penalised system, not
    /// multiplied by any scale.
    pub unit_covariance: Vec<f64>,
    /// Effective degrees of freedom `tr(G (G + λS)⁻¹)`.
    pub edf: f64,
    /// Residual degrees of freedom `n − edf`.
    pub residual_df: f64,
    /// Row-major `D × D` matrix of weighted residual cross products between
    /// response dimensions, divided by `residual_df`.
    pub residual_cross_cov: Vec<f64>,
}
799
800pub struct GridSpline2dFit {
802 pub coeffs: Vec<Vec<f64>>,
805 pub log_lambda: f64,
808 pub sigma2: Vec<f64>,
810 pub restricted_loglik: f64,
813 chol: Vec<f64>,
816 axes: [Axis; 2],
817 m_axis: usize,
818}
819
820#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
832pub struct GridSpline2dState {
833 pub coeffs: Vec<Vec<f64>>,
835 pub log_lambda: f64,
836 pub sigma2: Vec<f64>,
838 pub restricted_loglik: f64,
839 pub chol: Vec<f64>,
842 pub axis_lo: [f64; 2],
844 pub axis_h: [f64; 2],
846 pub axis_cells: [u64; 2],
848 pub m_axis: u64,
850}
851
852impl GridSpline2dFit {
853 pub fn to_state(&self) -> GridSpline2dState {
857 GridSpline2dState {
858 coeffs: self.coeffs.clone(),
859 log_lambda: self.log_lambda,
860 sigma2: self.sigma2.clone(),
861 restricted_loglik: self.restricted_loglik,
862 chol: self.chol.clone(),
863 axis_lo: [self.axes[0].lo, self.axes[1].lo],
864 axis_h: [self.axes[0].h, self.axes[1].h],
865 axis_cells: [self.axes[0].cells as u64, self.axes[1].cells as u64],
866 m_axis: self.m_axis as u64,
867 }
868 }
869
870 pub fn from_state(state: &GridSpline2dState) -> Result<Self, String> {
877 let m_axis = state.m_axis as usize;
878 let p = m_axis * m_axis;
879 for a in 0..2 {
880 let cells = state.axis_cells[a] as usize;
881 if cells == 0 {
882 return Err(format!(
883 "grid spline 2d state: axis {a} must have at least one cell"
884 ));
885 }
886 if m_axis != cells + 3 {
887 return Err(format!(
888 "grid spline 2d state: m_axis {m_axis} must equal K+3 = {} for axis {a}",
889 cells + 3
890 ));
891 }
892 if !(state.axis_lo[a].is_finite()
893 && state.axis_h[a].is_finite()
894 && state.axis_h[a] > 0.0)
895 {
896 return Err(format!(
897 "grid spline 2d state: axis {a} must have finite lo and positive h, got lo={}, h={}",
898 state.axis_lo[a], state.axis_h[a]
899 ));
900 }
901 }
902 if state.chol.len() != p * p {
903 return Err(format!(
904 "grid spline 2d state: chol must be p×p = {p}² = {}, got {}",
905 p * p,
906 state.chol.len()
907 ));
908 }
909 let d = state.coeffs.len();
910 if d == 0 || state.sigma2.len() != d {
911 return Err(format!(
912 "grid spline 2d state: need ≥1 response dimension with matching σ² (coeffs D={d}, sigma2 D={})",
913 state.sigma2.len()
914 ));
915 }
916 for (dim, c) in state.coeffs.iter().enumerate() {
917 if c.len() != p {
918 return Err(format!(
919 "grid spline 2d state: response dimension {dim} has {} coeffs, expected p = {p}",
920 c.len()
921 ));
922 }
923 }
924 for (dim, &s2) in state.sigma2.iter().enumerate() {
925 if !(s2.is_finite() && s2 > 0.0) {
926 return Err(format!(
927 "grid spline 2d state: response dimension {dim} has non-positive σ² = {s2}"
928 ));
929 }
930 }
931 for (i, v) in state
932 .chol
933 .iter()
934 .chain(state.coeffs.iter().flatten())
935 .enumerate()
936 {
937 if !v.is_finite() {
938 return Err(format!("grid spline 2d state: non-finite entry at {i}"));
939 }
940 }
941 for g in 0..p {
945 let piv = state.chol[g * p + g];
946 if !(piv.is_finite() && piv > 0.0) {
947 return Err(format!(
948 "grid spline 2d state: non-positive Cholesky pivot {piv} at index {g}"
949 ));
950 }
951 }
952 if !(state.log_lambda.is_finite() && state.restricted_loglik.is_finite()) {
953 return Err(format!(
954 "grid spline 2d state: invalid scalars (log_lambda={}, restricted_loglik={})",
955 state.log_lambda, state.restricted_loglik
956 ));
957 }
958 let axes = [
959 Axis {
960 lo: state.axis_lo[0],
961 h: state.axis_h[0],
962 cells: state.axis_cells[0] as usize,
963 },
964 Axis {
965 lo: state.axis_lo[1],
966 h: state.axis_h[1],
967 cells: state.axis_cells[1] as usize,
968 },
969 ];
970 Ok(GridSpline2dFit {
971 coeffs: state.coeffs.clone(),
972 log_lambda: state.log_lambda,
973 sigma2: state.sigma2.clone(),
974 restricted_loglik: state.restricted_loglik,
975 chol: state.chol.clone(),
976 axes,
977 m_axis,
978 })
979 }
980
981 pub fn predict(&self, dim: usize, x1: f64, x2: f64) -> Result<(f64, f64), String> {
986 if dim >= self.coeffs.len() {
987 return Err(format!(
988 "grid spline 2d: response dimension {dim} out of range (D = {})",
989 self.coeffs.len()
990 ));
991 }
992 if !(x1.is_finite() && x2.is_finite()) {
993 return Err(format!(
994 "grid spline 2d: non-finite prediction point ({x1}, {x2})"
995 ));
996 }
997 let (idx, val) = basis_row(&self.axes, self.m_axis, x1, x2);
998 let p = self.coeffs[dim].len();
999 let mut mean = 0.0;
1000 let mut row = vec![0.0_f64; p];
1001 for e in 0..16 {
1002 mean += val[e] * self.coeffs[dim][idx[e]];
1003 row[idx[e]] += val[e];
1004 }
1005 let z = chol_solve(&self.chol, p, &row);
1006 let mut quad = 0.0;
1007 for g in 0..p {
1008 quad += row[g] * z[g];
1009 }
1010 Ok((mean, self.sigma2[dim] * quad))
1011 }
1012}
1013
1014pub fn fit_grid_spline_2d(
1016 x1: &[f64],
1017 x2: &[f64],
1018 y: &[f64],
1019 w: &[f64],
1020 k: usize,
1021 metric: [f64; 2],
1022) -> Result<GridSpline2dFit, String> {
1023 GridSpline2dDesign::build(x1, x2, y, w, k, metric)?.fit_reml()
1024}
1025
1026pub fn fit_grid_spline_2d_at(
1028 x1: &[f64],
1029 x2: &[f64],
1030 y: &[f64],
1031 w: &[f64],
1032 k: usize,
1033 metric: [f64; 2],
1034 log_lambda: f64,
1035 sigma2: Option<f64>,
1036) -> Result<GridSpline2dFit, String> {
1037 GridSpline2dDesign::build(x1, x2, y, w, k, metric)?.fit_at(log_lambda, sigma2)
1038}
1039
#[cfg(test)]
mod tests {
    use super::*;

    /// Relative comparison for scalars restored from JSON.
    ///
    /// `serde_json` guarantees bit-exact `f64` round trips only when its
    /// `float_roundtrip` feature is enabled; the default parser may be one ULP
    /// off. A few ULPs of drift are therefore tolerated.
    fn within_ulps(a: f64, b: f64) -> bool {
        (a - b).abs() <= 4.0 * f64::EPSILON * (1.0 + a.abs())
    }

    /// A JSON round trip of the fit state must reproduce predictions.
    #[test]
    fn grid_spline_2d_state_roundtrip_reproduces_predict() {
        let k = 8usize;
        let mut x1 = Vec::new();
        let mut x2 = Vec::new();
        let mut y0 = Vec::new();
        let mut y1 = Vec::new();
        for i in 0..24 {
            for j in 0..24 {
                let a = i as f64 / 23.0;
                let b = j as f64 / 23.0;
                x1.push(a);
                x2.push(b);
                y0.push((2.5 * a).sin() * (1.7 * b).cos() + 0.3 * a * b);
                y1.push(a * a - 0.5 * b + 0.2 * (3.0 * a * b).cos());
            }
        }
        let n = x1.len();
        let w = vec![1.0_f64; n];
        let ys: Vec<&[f64]> = vec![&y0, &y1];
        let fit = GridSpline2dDesign::build_multi(&x1, &x2, &ys, &w, k, [1.0, 1.0])
            .expect("design")
            .fit_reml()
            .expect("fit");

        let json = serde_json::to_string(&fit.to_state()).expect("serialize");
        let state: GridSpline2dState = serde_json::from_str(&json).expect("deserialize");
        let restored = GridSpline2dFit::from_state(&state).expect("restore");

        // Interior points plus one point outside the fitted box (extrapolation).
        let probes = [
            (0.13, 0.77),
            (0.41, 0.05),
            (0.66, 0.92),
            (0.99, 0.31),
            (1.20, -0.10),
        ];
        for dim in 0..2 {
            for &(p1, p2) in &probes {
                let (m0, v0) = fit.predict(dim, p1, p2).expect("orig predict");
                let (m1, v1) = restored.predict(dim, p1, p2).expect("restored predict");
                assert!(
                    (m0 - m1).abs() <= 1e-12 * (1.0 + m0.abs()),
                    "mean drift dim={dim} at ({p1},{p2}): {m0} vs {m1}"
                );
                assert!(
                    (v0 - v1).abs() <= 1e-12 * (1.0 + v0.abs()),
                    "variance drift dim={dim} at ({p1},{p2}): {v0} vs {v1}"
                );
            }
        }
        assert!(
            within_ulps(fit.log_lambda, restored.log_lambda),
            "log_lambda drift: {} vs {}",
            fit.log_lambda,
            restored.log_lambda
        );
        assert!(
            within_ulps(fit.restricted_loglik, restored.restricted_loglik),
            "restricted_loglik drift: {} vs {}",
            fit.restricted_loglik,
            restored.restricted_loglik
        );
    }

    /// `from_state` must reject structurally corrupted states.
    #[test]
    fn grid_spline_2d_state_rejects_corruption() {
        let k = 6usize;
        let side = 12usize;
        let mut x1 = Vec::new();
        let mut x2 = Vec::new();
        for i in 0..side {
            for j in 0..side {
                x1.push(i as f64 / (side - 1) as f64);
                x2.push(j as f64 / (side - 1) as f64);
            }
        }
        let n = x1.len();
        let y: Vec<f64> = x1
            .iter()
            .zip(&x2)
            .map(|(&a, &b)| a + b + (3.0 * a).sin() * (2.5 * b).cos())
            .collect();
        let w = vec![1.0_f64; n];
        let fit = fit_grid_spline_2d(&x1, &x2, &y, &w, k, [1.0, 1.0]).expect("fit");

        let good = fit.to_state();
        let mut bad = good.clone();
        bad.chol.pop();
        assert!(
            GridSpline2dFit::from_state(&bad).is_err(),
            "chol length mismatch must error"
        );

        let mut bad = good.clone();
        bad.sigma2[0] = -1.0;
        assert!(
            GridSpline2dFit::from_state(&bad).is_err(),
            "non-positive σ² must error"
        );

        let mut bad = good.clone();
        bad.m_axis += 1;
        assert!(
            GridSpline2dFit::from_state(&bad).is_err(),
            "m_axis ≠ K+3 must error"
        );

        let mut bad = good.clone();
        bad.axis_h[0] = 0.0;
        assert!(
            GridSpline2dFit::from_state(&bad).is_err(),
            "non-positive cell width must error"
        );

        let mut bad = good;
        bad.chol[0] = 0.0;
        assert!(
            GridSpline2dFit::from_state(&bad).is_err(),
            "zero Cholesky pivot must error"
        );
    }

    /// Invalid construction and prediction arguments must error instead of panicking.
    #[test]
    fn grid_spline_2d_rejects_invalid_arguments() {
        let side = 6usize;
        let mut x1 = Vec::new();
        let mut x2 = Vec::new();
        for i in 0..side {
            for j in 0..side {
                x1.push(i as f64 / (side - 1) as f64);
                x2.push(j as f64 / (side - 1) as f64);
            }
        }
        let y: Vec<f64> = x1
            .iter()
            .zip(&x2)
            .map(|(&a, &b)| a - 2.0 * b + a * b * b)
            .collect();
        let w = vec![1.0_f64; x1.len()];

        assert!(
            GridSpline2dDesign::build(&x1, &x2, &y, &w, 0, [1.0, 1.0]).is_err(),
            "k = 0 must error"
        );
        assert!(
            GridSpline2dDesign::build(&x1, &x2, &y, &w, MAX_CELLS_PER_AXIS + 1, [1.0, 1.0])
                .is_err(),
            "k above the cap must error"
        );
        assert!(
            GridSpline2dDesign::build(&x1, &x2[1..], &y, &w, 3, [1.0, 1.0]).is_err(),
            "length mismatch must error"
        );
        assert!(
            GridSpline2dDesign::build(&x1, &x2, &y, &w, 3, [0.0, 1.0]).is_err(),
            "non-positive metric must error"
        );
        let mut bad_w = w.clone();
        bad_w[0] = 0.0;
        assert!(
            GridSpline2dDesign::build(&x1, &x2, &y, &bad_w, 3, [1.0, 1.0]).is_err(),
            "zero weight must error"
        );

        let fit =
            fit_grid_spline_2d_at(&x1, &x2, &y, &w, 3, [1.0, 1.0], 0.0, None).expect("fit");
        assert!(fit.predict(1, 0.5, 0.5).is_err(), "dimension out of range must error");
        assert!(fit.predict(0, f64::NAN, 0.5).is_err(), "NaN point must error");
        assert!(fit.predict(0, 0.5, f64::INFINITY).is_err(), "infinite point must error");
        assert!(fit.predict(0, 0.5, 0.5).is_ok());
    }
}