/// One distinct abscissa after `pool_nodes` has merged rows sharing the same `x`.
#[derive(Clone, Copy, Debug)]
struct PooledNode {
    /// Knot location (a distinct input abscissa).
    x: f64,
    /// Weighted mean of the `y` values pooled into this knot.
    y: f64,
    /// Sum of the weights pooled into this knot. `run_filter` uses `1 / w` as
    /// the observation variance (in units of `sigma2`).
    w: f64,
}
63
/// Number of evenly spaced `log_lambda` points scanned by `fit_spline_scan`
/// before golden-section refinement.
const LOG_LAMBDA_GRID: usize = 25;
/// Lower end of the `log_lambda` grid, before `fit_spline_scan` shifts it by
/// `(2 * order - 1) * ln(span)` to follow the data's x-range.
const LOG_LAMBDA_LO: f64 = -18.0;
/// Upper end of the `log_lambda` grid, before the same shift.
const LOG_LAMBDA_HI: f64 = 18.0;
/// Golden-section search stops once its bracket is narrower than this.
const LOG_LAMBDA_TOL: f64 = 1e-7;
/// Innovation variances at or below this are not usable. A diffuse `F_inf`
/// at/below it skips the diffuse update; a proper `F_*` at/below it is an error.
const INNOVATION_VAR_FLOOR: f64 = 1e-300;

/// Largest supported model order (state dimension).
const MAX_ORDER: usize = 3;

/// Fixed-size `MAX_ORDER x MAX_ORDER` matrix storage. The helpers only use
/// the leading `m x m` block. Despite the name, this is not limited to 2x2.
type Mat2 = [[f64; MAX_ORDER]; MAX_ORDER];
/// Fixed-size vector storage. Only the leading `m` entries are used.
type Vec2 = [f64; MAX_ORDER];
90
91#[inline]
92fn mat_mul(a: &Mat2, b: &Mat2, m: usize) -> Mat2 {
93 let mut c = [[0.0; MAX_ORDER]; MAX_ORDER];
94 for i in 0..m {
95 for j in 0..m {
96 let mut acc = 0.0;
97 for k in 0..m {
98 acc += a[i][k] * b[k][j];
99 }
100 c[i][j] = acc;
101 }
102 }
103 c
104}
105
106#[inline]
107fn mat_t(a: &Mat2, m: usize) -> Mat2 {
108 let mut c = [[0.0; MAX_ORDER]; MAX_ORDER];
109 for i in 0..m {
110 for j in 0..m {
111 c[i][j] = a[j][i];
112 }
113 }
114 c
115}
116
117#[inline]
118fn mat_vec(a: &Mat2, v: &Vec2, m: usize) -> Vec2 {
119 let mut out = [0.0; MAX_ORDER];
120 for i in 0..m {
121 let mut acc = 0.0;
122 for j in 0..m {
123 acc += a[i][j] * v[j];
124 }
125 out[i] = acc;
126 }
127 out
128}
129
130#[inline]
131fn mat_add(a: &Mat2, b: &Mat2, m: usize) -> Mat2 {
132 let mut c = [[0.0; MAX_ORDER]; MAX_ORDER];
133 for i in 0..m {
134 for j in 0..m {
135 c[i][j] = a[i][j] + b[i][j];
136 }
137 }
138 c
139}
140
141#[inline]
142fn mat_sub(a: &Mat2, b: &Mat2, m: usize) -> Mat2 {
143 let mut c = [[0.0; MAX_ORDER]; MAX_ORDER];
144 for i in 0..m {
145 for j in 0..m {
146 c[i][j] = a[i][j] - b[i][j];
147 }
148 }
149 c
150}
151
/// Inverts the leading `m x m` block of `a` in closed form, for `m` in 1..=3.
///
/// `what` names the caller's matrix and is embedded in error messages.
///
/// # Errors
/// Returns `Err` when the single entry (`m == 1`) or the determinant is zero
/// or non-finite, or when `m` is outside 1..=3. No conditioning check is made
/// beyond that, so near-singular blocks are inverted as-is.
fn mat_inv(a: &Mat2, m: usize, what: &str) -> Result<Mat2, String> {
    let mut out = [[0.0; MAX_ORDER]; MAX_ORDER];
    match m {
        1 => {
            let d = a[0][0];
            if !(d.is_finite() && d.abs() > 0.0) {
                return Err(format!("spline scan: singular 1x1 in {what} (a00={d})"));
            }
            out[0][0] = 1.0 / d;
        }
        2 => {
            let det = a[0][0] * a[1][1] - a[0][1] * a[1][0];
            if !(det.is_finite() && det.abs() > 0.0) {
                return Err(format!("spline scan: singular 2x2 in {what} (det={det})"));
            }
            // Standard 2x2 adjugate divided by the determinant.
            out[0][0] = a[1][1] / det;
            out[0][1] = -a[0][1] / det;
            out[1][0] = -a[1][0] / det;
            out[1][1] = a[0][0] / det;
        }
        3 => {
            // Cofactors of the first row; their dot product with row 0 is the
            // determinant, and they form the first column of the adjugate.
            let c00 = a[1][1] * a[2][2] - a[1][2] * a[2][1];
            let c01 = a[1][2] * a[2][0] - a[1][0] * a[2][2];
            let c02 = a[1][0] * a[2][1] - a[1][1] * a[2][0];
            let det = a[0][0] * c00 + a[0][1] * c01 + a[0][2] * c02;
            if !(det.is_finite() && det.abs() > 0.0) {
                return Err(format!("spline scan: singular 3x3 in {what} (det={det})"));
            }
            let inv_det = 1.0 / det;
            // Inverse = adjugate (transposed cofactor matrix) / det.
            out[0][0] = c00 * inv_det;
            out[0][1] = (a[0][2] * a[2][1] - a[0][1] * a[2][2]) * inv_det;
            out[0][2] = (a[0][1] * a[1][2] - a[0][2] * a[1][1]) * inv_det;
            out[1][0] = c01 * inv_det;
            out[1][1] = (a[0][0] * a[2][2] - a[0][2] * a[2][0]) * inv_det;
            out[1][2] = (a[0][2] * a[1][0] - a[0][0] * a[1][2]) * inv_det;
            out[2][0] = c02 * inv_det;
            out[2][1] = (a[0][1] * a[2][0] - a[0][0] * a[2][1]) * inv_det;
            out[2][2] = (a[0][0] * a[1][1] - a[0][1] * a[1][0]) * inv_det;
        }
        _ => return Err(format!("spline scan: unsupported order {m} in {what}")),
    }
    Ok(out)
}
200
201fn dense_spd_inverse(a: &[Vec<f64>], what: &str) -> Result<Vec<Vec<f64>>, String> {
217 let d = a.len();
218 let s: Vec<f64> = (0..d)
220 .map(|i| {
221 let dii = a[i][i];
222 if dii.is_finite() && dii > 0.0 {
223 1.0 / dii.sqrt()
224 } else {
225 1.0
226 }
227 })
228 .collect();
229 let a_s: Vec<Vec<f64>> = (0..d)
230 .map(|i| (0..d).map(|j| s[i] * a[i][j] * s[j]).collect())
231 .collect();
232 let mut inv_s = gauss_jordan_inverse(&a_s, what)?;
234 let mut resid = vec![vec![0.0_f64; d]; d]; for i in 0..d {
238 for j in 0..d {
239 let mut ax = 0.0;
240 for k in 0..d {
241 ax += a_s[i][k] * inv_s[k][j];
242 }
243 resid[i][j] = f64::from(u8::from(i == j)) - ax;
244 }
245 }
246 let mut delta = vec![vec![0.0_f64; d]; d]; for i in 0..d {
248 for j in 0..d {
249 let mut acc = 0.0;
250 for k in 0..d {
251 acc += inv_s[i][k] * resid[k][j];
252 }
253 delta[i][j] = acc;
254 }
255 }
256 for i in 0..d {
257 for j in 0..d {
258 inv_s[i][j] += delta[i][j];
259 }
260 }
261 Ok((0..d)
263 .map(|i| (0..d).map(|j| s[i] * inv_s[i][j] * s[j]).collect())
264 .collect())
265}
266
/// Inverts a dense square matrix by Gauss–Jordan elimination with partial
/// (row) pivoting.
///
/// # Errors
/// Returns `Err` when the selected pivot is zero or non-finite. `what` is
/// embedded in the message.
fn gauss_jordan_inverse(a: &[Vec<f64>], what: &str) -> Result<Vec<Vec<f64>>, String> {
    let n = a.len();
    let mut work = a.to_vec();
    let mut inverse: Vec<Vec<f64>> = (0..n)
        .map(|r| (0..n).map(|c| if r == c { 1.0 } else { 0.0 }).collect())
        .collect();

    for col in 0..n {
        // Largest |entry| at or below the diagonal. Ties resolve to the last
        // such row, because that is what `max_by` returns.
        let pivot_row = (col..n)
            .max_by(|&i, &j| work[i][col].abs().total_cmp(&work[j][col].abs()))
            .expect("col < n, so the pivot range is non-empty");
        let pivot = work[pivot_row][col];
        if !pivot.is_finite() || pivot.abs() <= 0.0 {
            return Err(format!(
                "spline scan: singular {n}x{n} in {what} (pivot={pivot})"
            ));
        }
        work.swap(col, pivot_row);
        inverse.swap(col, pivot_row);

        let divisor = work[col][col];
        work[col][..n].iter_mut().for_each(|v| *v /= divisor);
        inverse[col][..n].iter_mut().for_each(|v| *v /= divisor);

        // The pivot row is not modified below (r != col), so a snapshot is exact.
        let pivot_work = work[col].clone();
        let pivot_inv = inverse[col].clone();
        for r in (0..n).filter(|&r| r != col) {
            let factor = work[r][col];
            // Skipping zero factors also keeps signed zeros untouched.
            if factor == 0.0 {
                continue;
            }
            for k in 0..n {
                work[r][k] -= factor * pivot_work[k];
                inverse[r][k] -= factor * pivot_inv[k];
            }
        }
    }
    Ok(inverse)
}
308
/// Returns `k!` as an `f64` (`0! = 1`).
#[inline]
fn factorial(k: usize) -> f64 {
    let mut acc = 1.0_f64;
    for v in 1..=k {
        acc *= v as f64;
    }
    acc.max(1.0)
}
315
316#[inline]
320fn transition(delta: f64, m: usize) -> Mat2 {
321 let mut f = [[0.0; MAX_ORDER]; MAX_ORDER];
322 for i in 0..m {
323 for j in i..m {
324 f[i][j] = delta.powi((j - i) as i32) / factorial(j - i);
325 }
326 }
327 f
328}
329
330#[inline]
335fn process_noise(delta: f64, q: f64, m: usize) -> Mat2 {
336 let mut out = [[0.0; MAX_ORDER]; MAX_ORDER];
337 for i in 0..m {
338 for j in 0..m {
339 let p = 2 * m - 1 - i - j;
340 out[i][j] = q * delta.powi(p as i32)
341 / (factorial(m - 1 - i) * factorial(m - 1 - j) * (p as f64));
342 }
343 }
344 out
345}
346
347#[inline]
349fn symmetrize(a: &mut Mat2, m: usize) {
350 for i in 0..m {
351 for j in (i + 1)..m {
352 let off = 0.5 * (a[i][j] + a[j][i]);
353 a[i][j] = off;
354 a[j][i] = off;
355 }
356 }
357}
358
/// Per-node output of `run_filter`.
struct FilterStep {
    /// State mean after assimilating this node's observation.
    a_filt: Vec2,
    /// Proper (`P_*`) part of the state covariance after assimilating this node.
    p_filt: Mat2,
    /// State mean predicted for this node, before its observation.
    a_pred: Vec2,
    /// Proper (`P_*`) part of the predicted covariance, before the observation.
    p_pred: Mat2,
}
368
/// Full result of one `run_filter` pass.
struct FilterPass {
    /// One entry per pooled node, in node order.
    steps: Vec<FilterStep>,
    /// Sum of `ln F_*` over the proper (non-diffuse) innovations.
    sum_log_f: f64,
    /// Sum of `v^2 / F_*` over the proper innovations.
    sum_v2_over_f: f64,
    /// Number of proper innovations. Diffuse steps are not counted.
    n_proper: usize,
}
379
380fn run_filter(nodes: &[PooledNode], q: f64, order: usize) -> Result<FilterPass, String> {
381 let n = nodes.len();
382 let mut steps = Vec::with_capacity(n);
383 let mut a: Vec2 = [0.0; MAX_ORDER];
388 let mut p_star: Mat2 = [[0.0; MAX_ORDER]; MAX_ORDER];
389 let mut p_inf: Mat2 = [[0.0; MAX_ORDER]; MAX_ORDER];
390 for i in 0..order {
391 p_inf[i][i] = 1.0;
392 }
393 let mut diffuse_rank = order;
394 let mut sum_log_f = 0.0;
395 let mut sum_v2_over_f = 0.0;
396 let mut n_proper = 0usize;
397 for t in 0..n {
398 let a_pred = a;
399 let p_pred = p_star;
400 let r = 1.0 / nodes[t].w;
401 let v = nodes[t].y - a[0];
402 let mut m_star: Vec2 = [0.0; MAX_ORDER];
404 for i in 0..order {
405 m_star[i] = p_star[i][0];
406 }
407 let f_star = m_star[0] + r;
408 if diffuse_rank > 0 {
409 let mut m_inf: Vec2 = [0.0; MAX_ORDER];
410 for i in 0..order {
411 m_inf[i] = p_inf[i][0];
412 }
413 let f_inf = m_inf[0];
414 if f_inf > INNOVATION_VAR_FLOOR {
415 for i in 0..order {
419 a[i] += (m_inf[i] / f_inf) * v;
420 }
421 let mut p_new = p_star;
422 for i in 0..order {
423 for j in 0..order {
424 p_new[i][j] += -m_inf[i] * m_star[j] / f_inf - m_star[i] * m_inf[j] / f_inf
425 + m_inf[i] * m_inf[j] * f_star / (f_inf * f_inf);
426 }
427 }
428 p_star = p_new;
429 symmetrize(&mut p_star, order);
430 for i in 0..order {
431 for j in 0..order {
432 p_inf[i][j] -= m_inf[i] * m_inf[j] / f_inf;
433 }
434 }
435 symmetrize(&mut p_inf, order);
436 diffuse_rank -= 1;
437 if diffuse_rank == 0 {
438 p_inf = [[0.0; MAX_ORDER]; MAX_ORDER];
439 }
440 } else {
441 if f_star <= INNOVATION_VAR_FLOOR {
444 return Err("spline scan: non-positive innovation variance".to_string());
445 }
446 for i in 0..order {
447 a[i] += (m_star[i] / f_star) * v;
448 }
449 for i in 0..order {
450 for j in 0..order {
451 p_star[i][j] -= m_star[i] * m_star[j] / f_star;
452 }
453 }
454 symmetrize(&mut p_star, order);
455 sum_log_f += f_star.ln();
456 sum_v2_over_f += v * v / f_star;
457 n_proper += 1;
458 }
459 } else {
460 if f_star <= INNOVATION_VAR_FLOOR {
461 return Err("spline scan: non-positive innovation variance".to_string());
462 }
463 for i in 0..order {
464 a[i] += (m_star[i] / f_star) * v;
465 }
466 for i in 0..order {
467 for j in 0..order {
468 p_star[i][j] -= m_star[i] * m_star[j] / f_star;
469 }
470 }
471 symmetrize(&mut p_star, order);
472 sum_log_f += f_star.ln();
473 sum_v2_over_f += v * v / f_star;
474 n_proper += 1;
475 }
476 steps.push(FilterStep {
477 a_filt: a,
478 p_filt: p_star,
479 a_pred,
480 p_pred,
481 });
482 if t + 1 < n {
484 let delta = nodes[t + 1].x - nodes[t].x;
485 let f_t = transition(delta, order);
486 a = mat_vec(&f_t, &a, order);
487 let mut p_next = mat_add(
488 &mat_mul(&mat_mul(&f_t, &p_star, order), &mat_t(&f_t, order), order),
489 &process_noise(delta, q, order),
490 order,
491 );
492 symmetrize(&mut p_next, order);
493 p_star = p_next;
494 if diffuse_rank > 0 {
495 let mut pi_next =
496 mat_mul(&mat_mul(&f_t, &p_inf, order), &mat_t(&f_t, order), order);
497 symmetrize(&mut pi_next, order);
498 p_inf = pi_next;
499 }
500 }
501 }
502 Ok(FilterPass {
503 steps,
504 sum_log_f,
505 sum_v2_over_f,
506 n_proper,
507 })
508}
509
/// A fitted state-space smoothing spline, evaluated at its distinct knots.
#[derive(Clone, Debug)]
pub struct SplineScanFit {
    /// Model order (state dimension), in 1..=3.
    pub order: usize,
    /// Distinct, sorted abscissae after pooling duplicate inputs.
    pub knots: Vec<f64>,
    /// Smoothed fitted value (first state component) at each knot.
    pub mean: Vec<f64>,
    /// Smoothed first derivative (second state component) at each knot.
    /// Zero for `order == 1`.
    pub deriv: Vec<f64>,
    /// Posterior variance of `mean` at each knot, already scaled by `sigma2`.
    pub var: Vec<f64>,
    /// Log smoothing parameter; the process-noise scale is `exp(-log_lambda)`.
    pub log_lambda: f64,
    /// Residual variance scale, either estimated or supplied by the caller.
    pub sigma2: f64,
    /// Restricted log-likelihood at (`log_lambda`, `sigma2`), computed without
    /// the usual `2π` constant.
    pub restricted_loglik: f64,
    /// Number of raw observations before pooling.
    pub n_obs: usize,
    /// Full smoothed state vector per knot.
    smoothed_state: Vec<Vec2>,
    /// Smoothed state covariance per knot, not scaled by `sigma2`.
    smoothed_cov: Vec<Mat2>,
    /// Smoother gain per knot; `rts_gain[t] * smoothed_cov[t + 1]` is the
    /// cross-covariance of the states at knots `t` and `t + 1`.
    rts_gain: Vec<Mat2>,
    /// Process-noise scale, `exp(-log_lambda)`.
    q: f64,
    /// Pooled weight per knot.
    node_weight: Vec<f64>,
}
550
551fn pool_nodes(
554 x: &[f64],
555 y: &[f64],
556 w: &[f64],
557 order: usize,
558) -> Result<(Vec<PooledNode>, f64, usize), String> {
559 let n = x.len();
560 if y.len() != n || w.len() != n {
561 return Err(format!(
562 "spline scan: length mismatch x={n}, y={}, w={}",
563 y.len(),
564 w.len()
565 ));
566 }
567 for i in 0..n {
568 if !(x[i].is_finite() && y[i].is_finite() && w[i].is_finite() && w[i] > 0.0) {
569 return Err(format!(
570 "spline scan: non-finite or non-positive input at row {i} (x={}, y={}, w={})",
571 x[i], y[i], w[i]
572 ));
573 }
574 }
575 let mut perm: Vec<usize> = (0..n).collect();
576 perm.sort_by(|&i, &j| x[i].total_cmp(&x[j]));
577 let mut nodes: Vec<PooledNode> = Vec::new();
578 for &i in &perm {
579 match nodes.last_mut() {
580 Some(last) if last.x == x[i] => {
581 let w_new = last.w + w[i];
582 last.y = (last.y * last.w + y[i] * w[i]) / w_new;
583 last.w = w_new;
584 }
585 _ => nodes.push(PooledNode {
586 x: x[i],
587 y: y[i],
588 w: w[i],
589 }),
590 }
591 }
592 if nodes.len() < order + 1 {
594 return Err(format!(
595 "spline scan: order {order} needs at least {} distinct abscissae, got {}",
596 order + 1,
597 nodes.len()
598 ));
599 }
600 let mut ssr_within = 0.0;
602 let mut k = 0usize;
603 for &i in &perm {
604 while nodes[k].x != x[i] {
605 k += 1;
606 }
607 let d = y[i] - nodes[k].y;
608 ssr_within += w[i] * d * d;
609 }
610 Ok((nodes, ssr_within, n))
611}
612
613fn concentrated_criterion(
615 nodes: &[PooledNode],
616 ssr_within: f64,
617 n_obs: usize,
618 log_lambda: f64,
619 order: usize,
620) -> Result<f64, String> {
621 let pass = run_filter(nodes, (-log_lambda).exp(), order)?;
622 let dof = (n_obs - order) as f64;
625 let rss = pass.sum_v2_over_f + ssr_within;
626 if rss <= 0.0 {
627 return Err("spline scan: degenerate zero residual sum".to_string());
628 }
629 let sigma2 = rss / dof;
630 if pass.n_proper != nodes.len() - order {
631 return Err(format!(
632 "spline scan: expected {} proper innovations, got {} (diffuse rank not consumed)",
633 nodes.len() - order,
634 pass.n_proper
635 ));
636 }
637 Ok(-0.5 * (pass.sum_log_f + dof * sigma2.ln()))
638}
639
/// Smooths the leading `nb = order - 1` knots that the RTS loop in
/// `fit_spline_scan_at` does not visit (it stops at knot `order - 1`). This is
/// done by solving a small dense Gaussian problem.
///
/// The unknowns are `u = (α_0, …, α_{nb-1})`, stacked into `d = nb * order`
/// entries.
///
/// Conditionally on the already-smoothed state at knot `pin = order - 1`:
/// - `lambda` collects the transition penalties
///   `(α_{t+1} - F α_t)' Q^{-1} (α_{t+1} - F α_t)` and the observation terms
///   `w_t (y_t - α_t[0])^2`;
/// - `u | α_pin ~ N(dvec + cmat · α_pin, sigma)`.
///
/// That conditional is then marginalised over
/// `α_pin ~ N(sm_state[pin], sm_cov[pin])`.
///
/// Writes smoothed means and covariances for knots `0..nb`, and gains
/// `gains[j] = Cov(α_j, α_{j+1}) · Var(α_{j+1})^{-1}` for `j in 0..nb`. These
/// are in the same form as the RTS gains used by `predict`.
///
/// NOTE(review): intended for `order >= 2` (the only caller guards this), and
/// requires `sm_state[pin]` / `sm_cov[pin]` to be smoothed already.
///
/// # Errors
/// Propagates singular-matrix errors from `mat_inv` / `dense_spd_inverse`.
fn leading_block_smooth(
    sm_state: &mut [Vec2],
    sm_cov: &mut [Mat2],
    gains: &mut [Mat2],
    nodes: &[PooledNode],
    q: f64,
    order: usize,
) -> Result<(), String> {
    // Number of leading knots solved here, and the index of the conditioning knot.
    let nb = order - 1;
    let pin = order - 1;
    let d = nb * order;
    // Conditional precision of u given α_pin.
    let mut lambda = vec![vec![0.0_f64; d]; d];
    // Part of the linear term that does not depend on α_pin (observations).
    let mut b_const = vec![0.0_f64; d];
    // Coefficient of α_pin in the linear term (last transition into `pin`).
    let mut bmat = vec![vec![0.0_f64; order]; d];
    for t in 0..order - 1 {
        let delta = nodes[t + 1].x - nodes[t].x;
        let f = transition(delta, order);
        let qn = process_noise(delta, q, order);
        let a = mat_inv(&qn, order, "leading-block increment noise")?;
        let ft = mat_t(&f, order);
        let fta = mat_mul(&ft, &a, order);
        let ftaf = mat_mul(&fta, &f, order);
        let af = mat_mul(&a, &f, order);
        // α_t' F'AF α_t term of the transition penalty.
        for i in 0..order {
            for j in 0..order {
                lambda[t * order + i][t * order + j] += ftaf[i][j];
            }
        }
        if t + 1 <= nb - 1 {
            // Both endpoints are unknowns: fill the diagonal block of t+1 and
            // the two off-diagonal coupling blocks.
            for i in 0..order {
                for j in 0..order {
                    lambda[(t + 1) * order + i][(t + 1) * order + j] += a[i][j];
                    lambda[t * order + i][(t + 1) * order + j] -= fta[i][j];
                    lambda[(t + 1) * order + i][t * order + j] -= af[i][j];
                }
            }
        } else {
            // The endpoint is the conditioning knot: its cross term becomes
            // part of the linear coefficient on α_pin.
            for i in 0..order {
                for j in 0..order {
                    bmat[t * order + i][j] += fta[i][j];
                }
            }
        }
    }
    // Observations load only the first component of each leading state.
    for t in 0..nb {
        let w = nodes[t].w;
        lambda[t * order][t * order] += w;
        b_const[t * order] += w * nodes[t].y;
    }

    // Conditional covariance and the affine map α_pin -> E[u | α_pin].
    let sigma = dense_spd_inverse(&lambda, "leading-block precision")?;
    let dvec: Vec<f64> = (0..d)
        .map(|i| (0..d).map(|k| sigma[i][k] * b_const[k]).sum())
        .collect();
    let cmat: Vec<Vec<f64>> = (0..d)
        .map(|i| {
            (0..order)
                .map(|j| (0..d).map(|k| sigma[i][k] * bmat[k][j]).sum())
                .collect()
        })
        .collect();

    // Marginalise over the smoothed α_pin.
    let ahat_p = sm_state[pin];
    let vp = sm_cov[pin];
    // Cross-covariance Cov(u, α_pin) = C V_pin.
    let cvp: Vec<Vec<f64>> = (0..d)
        .map(|i| {
            (0..order)
                .map(|j| (0..order).map(|k| cmat[i][k] * vp[k][j]).sum())
                .collect()
        })
        .collect();
    // E[u] = C ahat_pin + dvec.
    let mean_u: Vec<f64> = (0..d)
        .map(|i| (0..order).map(|j| cmat[i][j] * ahat_p[j]).sum::<f64>() + dvec[i])
        .collect();
    // Var[u] = C V_pin C' + Sigma.
    let cov_u: Vec<Vec<f64>> = (0..d)
        .map(|i| {
            (0..d)
                .map(|k| (0..order).map(|j| cvp[i][j] * cmat[k][j]).sum::<f64>() + sigma[i][k])
                .collect()
        })
        .collect();

    // Unstack the per-knot means and covariance blocks.
    for j in 0..nb {
        for i in 0..order {
            sm_state[j][i] = mean_u[j * order + i];
        }
        let mut cov = [[0.0_f64; MAX_ORDER]; MAX_ORDER];
        for i in 0..order {
            for k in 0..order {
                cov[i][k] = cov_u[j * order + i][j * order + k];
            }
        }
        symmetrize(&mut cov, order);
        sm_cov[j] = cov;
    }
    // Gains: the cross-covariance with the next knot times that knot's
    // inverse covariance. For the last leading knot, the next knot is `pin`.
    for j in 0..nb {
        let mut cross = [[0.0_f64; MAX_ORDER]; MAX_ORDER];
        if j + 1 <= nb - 1 {
            for i in 0..order {
                for k in 0..order {
                    cross[i][k] = cov_u[j * order + i][(j + 1) * order + k];
                }
            }
        } else {
            for i in 0..order {
                for k in 0..order {
                    cross[i][k] = cvp[j * order + i][k];
                }
            }
        }
        let denom_inv = mat_inv(&sm_cov[j + 1], order, "leading-block gain denominator")?;
        gains[j] = mat_mul(&cross, &denom_inv, order);
    }
    Ok(())
}
806
807pub fn fit_spline_scan_at(
810 x: &[f64],
811 y: &[f64],
812 w: &[f64],
813 log_lambda: f64,
814 sigma2: Option<f64>,
815 order: usize,
816) -> Result<SplineScanFit, String> {
817 if order == 0 || order > MAX_ORDER {
818 return Err(format!(
819 "spline scan: order must be in 1..={MAX_ORDER}, got {order}"
820 ));
821 }
822 let (nodes, ssr_within, n_obs) = pool_nodes(x, y, w, order)?;
823 let q = (-log_lambda).exp();
824 let pass = run_filter(&nodes, q, order)?;
825 let n = nodes.len();
826 let dof = (n_obs - order) as f64;
827 let sigma2 = match sigma2 {
828 Some(s) => {
829 if !(s.is_finite() && s > 0.0) {
830 return Err(format!("spline scan: invalid sigma2 {s}"));
831 }
832 s
833 }
834 None => (pass.sum_v2_over_f + ssr_within) / dof,
835 };
836 let rss = pass.sum_v2_over_f + ssr_within;
841 let restricted_loglik = -0.5 * (pass.sum_log_f + dof * sigma2.ln() + rss / sigma2);
842
843 let mut sm_state = vec![[0.0_f64; MAX_ORDER]; n];
854 let mut sm_cov = vec![[[0.0_f64; MAX_ORDER]; MAX_ORDER]; n];
855 let mut gains = vec![[[0.0_f64; MAX_ORDER]; MAX_ORDER]; n];
856 sm_state[n - 1] = pass.steps[n - 1].a_filt;
857 sm_cov[n - 1] = pass.steps[n - 1].p_filt;
858 for t in (order - 1..n - 1).rev() {
859 let p_next_pred = &pass.steps[t + 1].p_pred;
860 let delta = nodes[t + 1].x - nodes[t].x;
861 let f_t = transition(delta, order);
862 let p_inv = mat_inv(p_next_pred, order, "RTS predicted covariance")?;
863 let g = mat_mul(
864 &mat_mul(&pass.steps[t].p_filt, &mat_t(&f_t, order), order),
865 &p_inv,
866 order,
867 );
868 let mut dm: Vec2 = [0.0; MAX_ORDER];
869 for i in 0..order {
870 dm[i] = sm_state[t + 1][i] - pass.steps[t + 1].a_pred[i];
871 }
872 let corr = mat_vec(&g, &dm, order);
873 for i in 0..order {
874 sm_state[t][i] = pass.steps[t].a_filt[i] + corr[i];
875 }
876 let dp = mat_sub(&sm_cov[t + 1], p_next_pred, order);
877 let mut cov = mat_add(
878 &pass.steps[t].p_filt,
879 &mat_mul(&mat_mul(&g, &dp, order), &mat_t(&g, order), order),
880 order,
881 );
882 symmetrize(&mut cov, order);
883 sm_cov[t] = cov;
884 gains[t] = g;
885 }
886 if order >= 2 {
889 leading_block_smooth(&mut sm_state, &mut sm_cov, &mut gains, &nodes, q, order)?;
890 }
891
892 let knots: Vec<f64> = nodes.iter().map(|n| n.x).collect();
893 let mean: Vec<f64> = sm_state.iter().map(|s| s[0]).collect();
894 let deriv: Vec<f64> = sm_state
896 .iter()
897 .map(|s| if order >= 2 { s[1] } else { 0.0 })
898 .collect();
899 let var: Vec<f64> = sm_cov.iter().map(|p| p[0][0] * sigma2).collect();
900 Ok(SplineScanFit {
901 order,
902 knots,
903 mean,
904 deriv,
905 var,
906 log_lambda,
907 sigma2,
908 restricted_loglik,
909 n_obs,
910 smoothed_state: sm_state,
911 smoothed_cov: sm_cov,
912 rts_gain: gains,
913 q,
914 node_weight: nodes.iter().map(|n| n.w).collect(),
915 })
916}
917
918pub fn fit_spline_scan(
922 x: &[f64],
923 y: &[f64],
924 w: &[f64],
925 order: usize,
926) -> Result<SplineScanFit, String> {
927 if order == 0 || order > MAX_ORDER {
928 return Err(format!(
929 "spline scan: order must be in 1..={MAX_ORDER}, got {order}"
930 ));
931 }
932 let (nodes, ssr_within, n_obs) = pool_nodes(x, y, w, order)?;
933 let span = nodes.last().map(|n| n.x).unwrap_or(0.0) - nodes.first().map(|n| n.x).unwrap_or(0.0);
948 let scale_shift = if span.is_finite() && span > 0.0 {
949 (2 * order - 1) as f64 * span.ln()
950 } else {
951 0.0
952 };
953 let lo_anchor = LOG_LAMBDA_LO + scale_shift;
954 let hi_anchor = LOG_LAMBDA_HI + scale_shift;
955 let crit = |ll: f64| concentrated_criterion(&nodes, ssr_within, n_obs, ll, order);
956 let mut best_i = 0usize;
957 let mut best_v = f64::NEG_INFINITY;
958 let step = (hi_anchor - lo_anchor) / (LOG_LAMBDA_GRID - 1) as f64;
959 for i in 0..LOG_LAMBDA_GRID {
960 let ll = lo_anchor + step * i as f64;
961 let v = crit(ll)?;
962 if v > best_v {
963 best_v = v;
964 best_i = i;
965 }
966 }
967 let mut lo = lo_anchor + step * best_i.saturating_sub(1) as f64;
968 let mut hi = (lo_anchor + step * (best_i + 1) as f64).min(hi_anchor);
969 let inv_phi = 0.618_033_988_749_894_9_f64;
971 let mut x1 = hi - inv_phi * (hi - lo);
972 let mut x2 = lo + inv_phi * (hi - lo);
973 let mut f1 = crit(x1)?;
974 let mut f2 = crit(x2)?;
975 while hi - lo > LOG_LAMBDA_TOL {
976 if f1 < f2 {
977 lo = x1;
978 x1 = x2;
979 f1 = f2;
980 x2 = lo + inv_phi * (hi - lo);
981 f2 = crit(x2)?;
982 } else {
983 hi = x2;
984 x2 = x1;
985 f2 = f1;
986 x1 = hi - inv_phi * (hi - lo);
987 f1 = crit(x1)?;
988 }
989 }
990 fit_spline_scan_at(x, y, w, 0.5 * (lo + hi), None, order)
991}
992
/// Serializable snapshot of a [`SplineScanFit`]. It is produced by
/// [`SplineScanFit::to_state`] and restored by [`SplineScanFit::from_state`].
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct SplineScanState {
    /// Model order (state dimension). Defaults to 2 when missing from the
    /// serialized data (see `default_spline_scan_order`).
    #[serde(default = "default_spline_scan_order")]
    pub order: usize,
    /// Knot abscissae; `from_state` requires them to be strictly increasing.
    pub knots: Vec<f64>,
    /// Smoothed state vectors, `order` values per knot.
    pub state: Vec<f64>,
    /// Upper triangles of the smoothed covariances, not scaled by `sigma2`.
    /// There are `order * (order + 1) / 2` values per knot, stored row by row.
    pub cov: Vec<f64>,
    /// Smoother gains, `order * order` values per knot, stored row-major.
    pub gain: Vec<f64>,
    /// Pooled weight per knot.
    pub node_weight: Vec<f64>,
    /// Log smoothing parameter; the process-noise scale is `exp(-log_lambda)`.
    pub log_lambda: f64,
    /// Residual variance scale.
    pub sigma2: f64,
    /// Restricted log-likelihood recorded at fit time.
    pub restricted_loglik: f64,
    /// Number of raw observations before pooling.
    pub n_obs: u64,
}
1030
/// Serde default for [`SplineScanState::order`], used when the field is absent.
/// NOTE(review): presumably keeps snapshots written before `order` existed
/// loadable as order-2 fits; confirm.
fn default_spline_scan_order() -> usize {
    2
}
1036
1037impl SplineScanFit {
1038 pub fn to_state(&self) -> SplineScanState {
1040 let order = self.order;
1041 let tri = order * (order + 1) / 2;
1042 let nk = self.knots.len();
1043 let mut state = Vec::with_capacity(order * nk);
1044 for s in &self.smoothed_state {
1045 state.extend_from_slice(&s[..order]);
1046 }
1047 let mut cov = Vec::with_capacity(tri * nk);
1048 for c in &self.smoothed_cov {
1049 for i in 0..order {
1050 for j in i..order {
1051 cov.push(c[i][j]);
1052 }
1053 }
1054 }
1055 let mut gain = Vec::with_capacity(order * order * nk);
1056 for g in &self.rts_gain {
1057 for i in 0..order {
1058 for j in 0..order {
1059 gain.push(g[i][j]);
1060 }
1061 }
1062 }
1063 SplineScanState {
1064 order: self.order,
1065 knots: self.knots.clone(),
1066 state,
1067 cov,
1068 gain,
1069 node_weight: self.node_weight.clone(),
1070 log_lambda: self.log_lambda,
1071 sigma2: self.sigma2,
1072 restricted_loglik: self.restricted_loglik,
1073 n_obs: self.n_obs as u64,
1074 }
1075 }
1076
1077 pub fn from_state(state: &SplineScanState) -> Result<Self, String> {
1085 let order = state.order;
1086 if order == 0 || order > MAX_ORDER {
1087 return Err(format!(
1088 "spline scan state: order must be in 1..={MAX_ORDER}, got {order}"
1089 ));
1090 }
1091 let m = state.knots.len();
1092 if m < order + 1 {
1093 return Err(format!(
1094 "spline scan state: order {order} needs at least {} knots, got {m}",
1095 order + 1
1096 ));
1097 }
1098 let tri = order * (order + 1) / 2;
1099 if state.state.len() != order * m
1100 || state.cov.len() != tri * m
1101 || state.gain.len() != order * order * m
1102 || state.node_weight.len() != m
1103 {
1104 return Err(format!(
1105 "spline scan state: inconsistent lengths (order={order}, m={m}, state={}, cov={}, gain={}, weights={})",
1106 state.state.len(),
1107 state.cov.len(),
1108 state.gain.len(),
1109 state.node_weight.len()
1110 ));
1111 }
1112 let all = state
1113 .state
1114 .iter()
1115 .chain(&state.cov)
1116 .chain(&state.gain)
1117 .chain(&state.knots)
1118 .chain(&state.node_weight);
1119 for (i, v) in all.enumerate() {
1120 if !v.is_finite() {
1121 return Err(format!("spline scan state: non-finite entry at {i}"));
1122 }
1123 }
1124 if !(state.log_lambda.is_finite()
1125 && state.restricted_loglik.is_finite()
1126 && state.sigma2.is_finite()
1127 && state.sigma2 > 0.0)
1128 {
1129 return Err(format!(
1130 "spline scan state: invalid scalars (log_lambda={}, sigma2={}, restricted_loglik={})",
1131 state.log_lambda, state.sigma2, state.restricted_loglik
1132 ));
1133 }
1134 if state.knots.windows(2).any(|kk| !(kk[0] < kk[1])) {
1135 return Err("spline scan state: knots must be strictly increasing".to_string());
1136 }
1137 if state.node_weight.iter().any(|&w| w <= 0.0) {
1138 return Err("spline scan state: node weights must be positive".to_string());
1139 }
1140 let smoothed_state: Vec<Vec2> = state
1141 .state
1142 .chunks_exact(order)
1143 .map(|s| {
1144 let mut v = [0.0_f64; MAX_ORDER];
1145 v[..order].copy_from_slice(s);
1146 v
1147 })
1148 .collect();
1149 let smoothed_cov: Vec<Mat2> = state
1150 .cov
1151 .chunks_exact(tri)
1152 .map(|c| {
1153 let mut mm = [[0.0_f64; MAX_ORDER]; MAX_ORDER];
1154 let mut idx = 0;
1155 for i in 0..order {
1156 for j in i..order {
1157 mm[i][j] = c[idx];
1158 mm[j][i] = c[idx];
1159 idx += 1;
1160 }
1161 }
1162 mm
1163 })
1164 .collect();
1165 let rts_gain: Vec<Mat2> = state
1166 .gain
1167 .chunks_exact(order * order)
1168 .map(|g| {
1169 let mut mm = [[0.0_f64; MAX_ORDER]; MAX_ORDER];
1170 for i in 0..order {
1171 for j in 0..order {
1172 mm[i][j] = g[i * order + j];
1173 }
1174 }
1175 mm
1176 })
1177 .collect();
1178 let sigma2 = state.sigma2;
1179 if state.n_obs == 0 {
1180 return Err("spline scan state: n_obs must be positive".to_string());
1181 }
1182 let n_obs = state.n_obs as usize;
1183 Ok(Self {
1184 order,
1185 knots: state.knots.clone(),
1186 mean: smoothed_state.iter().map(|s| s[0]).collect(),
1187 deriv: smoothed_state.iter().map(|s| s[1]).collect(),
1188 var: smoothed_cov.iter().map(|c| c[0][0] * sigma2).collect(),
1189 log_lambda: state.log_lambda,
1190 sigma2,
1191 restricted_loglik: state.restricted_loglik,
1192 n_obs,
1193 smoothed_state,
1194 smoothed_cov,
1195 rts_gain,
1196 q: (-state.log_lambda).exp(),
1197 node_weight: state.node_weight.clone(),
1198 })
1199 }
1200
1201 pub fn predict(&self, x_new: f64) -> Result<(f64, f64), String> {
1208 if !x_new.is_finite() {
1209 return Err("spline scan: non-finite prediction abscissa".to_string());
1210 }
1211 let n = self.knots.len();
1212 let order = self.order;
1213 let first = self.knots[0];
1214 let last = self.knots[n - 1];
1215 if x_new <= first {
1216 let delta = first - x_new;
1217 let f_t = transition(delta, order);
1219 let f_inv = mat_inv(&f_t, order, "backward extrapolation transition")?;
1220 let mean_s = mat_vec(&f_inv, &self.smoothed_state[0], order);
1221 let qm = process_noise(delta, self.q, order);
1222 let cov = mat_add(
1223 &mat_mul(
1224 &mat_mul(&f_inv, &self.smoothed_cov[0], order),
1225 &mat_t(&f_inv, order),
1226 order,
1227 ),
1228 &mat_mul(&mat_mul(&f_inv, &qm, order), &mat_t(&f_inv, order), order),
1229 order,
1230 );
1231 return Ok((mean_s[0], cov[0][0] * self.sigma2));
1232 }
1233 if x_new >= last {
1234 let delta = x_new - last;
1235 let f_t = transition(delta, order);
1236 let mean_s = mat_vec(&f_t, &self.smoothed_state[n - 1], order);
1237 let cov = mat_add(
1238 &mat_mul(
1239 &mat_mul(&f_t, &self.smoothed_cov[n - 1], order),
1240 &mat_t(&f_t, order),
1241 order,
1242 ),
1243 &process_noise(delta, self.q, order),
1244 order,
1245 );
1246 return Ok((mean_s[0], cov[0][0] * self.sigma2));
1247 }
1248 let t = match self.knots.binary_search_by(|k| k.total_cmp(&x_new)) {
1250 Ok(idx) => return Ok((self.mean[idx], self.var[idx])),
1251 Err(idx) => idx - 1,
1252 };
1253 let (xa, xb) = (self.knots[t], self.knots[t + 1]);
1254 let (d1, d2) = (x_new - xa, xb - x_new);
1255 let (f1m, f2m) = (transition(d1, order), transition(d2, order));
1256 let (q1, q2) = (
1257 process_noise(d1, self.q, order),
1258 process_noise(d2, self.q, order),
1259 );
1260 let q1_inv = mat_inv(&q1, order, "bridge left noise")?;
1261 let q2_inv = mat_inv(&q2, order, "bridge right noise")?;
1262 let lambda = mat_add(
1265 &q1_inv,
1266 &mat_mul(&mat_mul(&mat_t(&f2m, order), &q2_inv, order), &f2m, order),
1267 order,
1268 );
1269 let lam_inv = mat_inv(&lambda, order, "bridge precision")?;
1270 let ca = mat_mul(&lam_inv, &mat_mul(&q1_inv, &f1m, order), order);
1271 let cb = mat_mul(
1272 &lam_inv,
1273 &mat_mul(&mat_t(&f2m, order), &q2_inv, order),
1274 order,
1275 );
1276 let ma = mat_vec(&ca, &self.smoothed_state[t], order);
1277 let mb = mat_vec(&cb, &self.smoothed_state[t + 1], order);
1278 let mut mean_s = [0.0_f64; MAX_ORDER];
1279 for i in 0..order {
1280 mean_s[i] = ma[i] + mb[i];
1281 }
1282 let cross = mat_mul(&self.rts_gain[t], &self.smoothed_cov[t + 1], order);
1285 let mut cov = mat_add(
1286 &mat_add(
1287 &mat_mul(
1288 &mat_mul(&ca, &self.smoothed_cov[t], order),
1289 &mat_t(&ca, order),
1290 order,
1291 ),
1292 &mat_mul(
1293 &mat_mul(&cb, &self.smoothed_cov[t + 1], order),
1294 &mat_t(&cb, order),
1295 order,
1296 ),
1297 order,
1298 ),
1299 &lam_inv,
1300 order,
1301 );
1302 let cab = mat_mul(&mat_mul(&ca, &cross, order), &mat_t(&cb, order), order);
1303 cov = mat_add(&cov, &mat_add(&cab, &mat_t(&cab, order), order), order);
1304 symmetrize(&mut cov, order);
1305 Ok((mean_s[0], cov[0][0] * self.sigma2))
1306 }
1307
1308 pub fn edf(&self) -> f64 {
1322 self.node_weight
1323 .iter()
1324 .zip(self.smoothed_cov.iter())
1325 .map(|(w, c)| w * c[0][0])
1326 .sum()
1327 }
1328
1329 pub fn deriv_at_knot(&self, t: usize) -> (f64, f64) {
1331 (
1332 self.smoothed_state[t][1],
1333 self.smoothed_cov[t][1][1] * self.sigma2,
1334 )
1335 }
1336
1337 pub fn lambda(&self) -> f64 {
1339 self.log_lambda.exp()
1340 }
1341
1342 pub fn n_obs(&self) -> usize {
1344 self.n_obs
1345 }
1346
1347 pub fn deviance(&self) -> f64 {
1352 self.sigma2 * (self.n_obs as f64 - self.order as f64).max(0.0)
1353 }
1354}
1355
#[cfg(test)]
mod tests {
    use super::*;

    /// Fits `order` on a wiggly, weighted sample that contains one duplicated
    /// abscissa. It serializes the fit through JSON, restores it, and asserts
    /// that every exposed quantity and every prediction (inside the range,
    /// at knots and outside the range) is bit-identical. It also checks that
    /// corrupted snapshots are rejected.
    fn round_trip_predict_bit_for_bit(order: usize) {
        let n = 60usize;
        let x: Vec<f64> = (0..n).map(|i| (i as f64) / (n as f64 - 1.0)).collect();
        let mut x = x;
        // Duplicate abscissa so that the pooling path is exercised.
        x[7] = x[6];
        let y: Vec<f64> = x
            .iter()
            .enumerate()
            .map(|(i, &xi)| {
                (6.0 * xi).sin() + 0.3 * (17.0 * xi).cos() + 0.05 * ((i * 37 % 11) as f64 - 5.0)
            })
            .collect();
        let w: Vec<f64> = (0..n).map(|i| 1.0 + 0.5 * ((i % 3) as f64)).collect();
        let fit = fit_spline_scan(&x, &y, &w, order).expect("scan fit");
        assert_eq!(fit.order, order);
        // n_obs counts raw rows, not pooled knots.
        assert_eq!(fit.n_obs, n);

        let json = serde_json::to_string(&fit.to_state()).expect("serialize state");
        let state: SplineScanState = serde_json::from_str(&json).expect("deserialize state");
        let restored = SplineScanFit::from_state(&state).expect("restore fit");

        assert_eq!(fit.n_obs, restored.n_obs);
        assert_eq!(fit.deviance().to_bits(), restored.deviance().to_bits());
        assert_eq!(fit.knots, restored.knots);
        assert_eq!(fit.mean, restored.mean);
        assert_eq!(fit.var, restored.var);
        assert_eq!(fit.deriv, restored.deriv);
        assert_eq!(fit.log_lambda.to_bits(), restored.log_lambda.to_bits());
        assert_eq!(fit.sigma2.to_bits(), restored.sigma2.to_bits());
        assert_eq!(fit.edf().to_bits(), restored.edf().to_bits());
        for t in 0..fit.knots.len() {
            let (d0, v0) = fit.deriv_at_knot(t);
            let (d1, v1) = restored.deriv_at_knot(t);
            assert_eq!(d0.to_bits(), d1.to_bits());
            assert_eq!(v0.to_bits(), v1.to_bits());
        }
        // Probes: below/above the range, exact knots (ends and the pooled one)
        // and interior points between knots.
        for &xq in &[-0.2, 0.0, 0.013, 0.5, x[6], 0.987, 1.0, 1.3] {
            let (m0, v0) = fit.predict(xq).expect("predict original");
            let (m1, v1) = restored.predict(xq).expect("predict restored");
            assert_eq!(
                m0.to_bits(),
                m1.to_bits(),
                "mean drift at x={xq} (m={order})"
            );
            assert_eq!(
                v0.to_bits(),
                v1.to_bits(),
                "variance drift at x={xq} (m={order})"
            );
        }

        // Corrupted snapshots must be rejected rather than restored.
        let mut bad = fit.to_state();
        bad.cov.truncate(bad.cov.len() - 1);
        SplineScanFit::from_state(&bad).expect_err("length mismatch must error");
        let mut bad = fit.to_state();
        bad.sigma2 = -1.0;
        SplineScanFit::from_state(&bad).expect_err("non-positive sigma2 must error");
        let mut bad = fit.to_state();
        bad.knots[2] = bad.knots[1];
        SplineScanFit::from_state(&bad).expect_err("non-increasing knots must error");
    }

    #[test]
    fn state_snapshot_round_trips_predict_bit_for_bit() {
        round_trip_predict_bit_for_bit(2);
    }

    #[test]
    fn state_snapshot_round_trips_predict_bit_for_bit_order1() {
        round_trip_predict_bit_for_bit(1);
    }

    #[test]
    fn state_snapshot_round_trips_predict_bit_for_bit_order3() {
        round_trip_predict_bit_for_bit(3);
    }

    /// Dense reference for the order-1 (random-walk) model. It builds the
    /// posterior precision from observation weights plus `1 / (q * dx)`
    /// random-walk couplings (no prior term on the first state), inverts it by
    /// Gauss–Jordan elimination, and returns the posterior means and the
    /// diagonal of the inverse (not scaled by `sigma2`).
    fn dense_rw_truth(x: &[f64], y: &[f64], w: &[f64], log_lambda: f64) -> (Vec<f64>, Vec<f64>) {
        let n = x.len();
        let q = (-log_lambda).exp();
        let mut prec = vec![vec![0.0_f64; n]; n];
        let mut rhs = vec![0.0_f64; n];
        for t in 0..n {
            prec[t][t] += w[t];
            rhs[t] += w[t] * y[t];
        }
        for t in 0..n - 1 {
            let p = 1.0 / (q * (x[t + 1] - x[t]));
            prec[t][t] += p;
            prec[t + 1][t + 1] += p;
            prec[t][t + 1] -= p;
            prec[t + 1][t] -= p;
        }
        let mut aug = prec.clone();
        let mut inv = vec![vec![0.0_f64; n]; n];
        for i in 0..n {
            inv[i][i] = 1.0;
        }
        for col in 0..n {
            let piv = (col..n)
                .max_by(|&a, &b| aug[a][col].abs().total_cmp(&aug[b][col].abs()))
                .unwrap();
            aug.swap(col, piv);
            inv.swap(col, piv);
            let d = aug[col][col];
            for k in 0..n {
                aug[col][k] /= d;
                inv[col][k] /= d;
            }
            for r in 0..n {
                if r == col {
                    continue;
                }
                let f = aug[r][col];
                if f == 0.0 {
                    continue;
                }
                for k in 0..n {
                    aug[r][k] -= f * aug[col][k];
                    inv[r][k] -= f * inv[col][k];
                }
            }
        }
        let mean: Vec<f64> = (0..n)
            .map(|i| (0..n).map(|j| inv[i][j] * rhs[j]).sum())
            .collect();
        let var: Vec<f64> = (0..n).map(|i| inv[i][i]).collect();
        (mean, var)
    }

    /// The order-1 scan must agree with the dense random-walk posterior
    /// (means, standard errors and EDF) at the selected `log_lambda`, to
    /// relative tolerance 1e-7. Its derivatives must be identically zero.
    #[test]
    fn order_one_scan_matches_dense_random_walk_posterior() {
        let n = 30usize;
        let x: Vec<f64> = (0..n).map(|i| i as f64 / (n as f64 - 1.0)).collect();
        let y: Vec<f64> = x
            .iter()
            .enumerate()
            .map(|(i, &xi)| 2.0 * xi + 0.4 * (5.0 * xi).sin() + 0.05 * ((i * 13 % 7) as f64 - 3.0))
            .collect();
        let w = vec![1.0_f64; n];
        let fit = fit_spline_scan(&x, &y, &w, 1).expect("order-1 scan fit");
        assert_eq!(fit.order, 1);

        let (mean, var) = dense_rw_truth(&x, &y, &w, fit.log_lambda);
        for t in 0..n {
            assert!(
                (fit.mean[t] - mean[t]).abs() <= 1e-7 * mean[t].abs().max(1e-3),
                "order-1 mean mismatch at {t}: scan={} dense={}",
                fit.mean[t],
                mean[t]
            );
            let se_scan = fit.var[t].sqrt();
            let se_dense = (var[t] * fit.sigma2).sqrt();
            assert!(
                (se_scan - se_dense).abs() <= 1e-7 * se_dense.max(1e-12),
                "order-1 SE mismatch at {t}: scan={se_scan} dense={se_dense}"
            );
        }
        // EDF reference: sum of weight times the dense posterior variance.
        let dense_edf: f64 = w.iter().zip(var.iter()).map(|(wt, vt)| wt * vt).sum();
        assert!(
            (fit.edf() - dense_edf).abs() <= 1e-7 * dense_edf.max(1e-12),
            "order-1 EDF mismatch: scan={} dense={dense_edf}",
            fit.edf()
        );
        assert!(fit.deriv.iter().all(|&d| d == 0.0));
    }
}