/// One distinct abscissa after pooling: input rows that share the same `x`
/// are merged into a single node by `pool_nodes`.
#[derive(Clone, Copy, Debug)]
struct PooledNode {
/// Abscissa of the node.
x: f64,
/// Weighted mean of the responses merged into this node.
y: f64,
/// Sum of the weights merged into this node.
w: f64,
}
/// Number of points in the coarse `log_lambda` grid scanned by `fit_spline_scan`.
const LOG_LAMBDA_GRID: usize = 25;
/// Lower end of the `log_lambda` search interval.
const LOG_LAMBDA_LO: f64 = -18.0;
/// Upper end of the `log_lambda` search interval.
const LOG_LAMBDA_HI: f64 = 18.0;
/// Bracket width at which the golden-section refinement of `log_lambda` stops.
const LOG_LAMBDA_TOL: f64 = 1e-7;
/// Threshold below which an innovation variance is treated as unusable by `run_filter`.
const INNOVATION_VAR_FLOOR: f64 = 1e-300;
/// Largest supported model order; also the fixed storage dimension of `Mat2` / `Vec2`.
const MAX_ORDER: usize = 3;
/// Fixed-size square matrix storage (`MAX_ORDER x MAX_ORDER`, despite the `2` in the name).
/// Helpers only touch the leading `m x m` block for the active order `m`.
type Mat2 = [[f64; MAX_ORDER]; MAX_ORDER];
/// Fixed-size vector storage (`MAX_ORDER` entries, despite the `2` in the name).
type Vec2 = [f64; MAX_ORDER];
/// Multiplies the leading `m x m` blocks of `a` and `b`.
///
/// Entries of the result outside the leading block stay at zero.
#[inline]
fn mat_mul(a: &Mat2, b: &Mat2, m: usize) -> Mat2 {
    let mut product = [[0.0; MAX_ORDER]; MAX_ORDER];
    for (out_row, lhs_row) in product[..m].iter_mut().zip(a[..m].iter()) {
        for (col, cell) in out_row[..m].iter_mut().enumerate() {
            // Explicit 0.0 seed and left-to-right accumulation keep the rounding order fixed.
            *cell = (0..m).fold(0.0, |acc, k| acc + lhs_row[k] * b[k][col]);
        }
    }
    product
}
/// Transposes the leading `m x m` block of `a`; everything else in the result is zero.
#[inline]
fn mat_t(a: &Mat2, m: usize) -> Mat2 {
    let mut transposed = [[0.0; MAX_ORDER]; MAX_ORDER];
    for (i, out_row) in transposed[..m].iter_mut().enumerate() {
        for (j, cell) in out_row[..m].iter_mut().enumerate() {
            *cell = a[j][i];
        }
    }
    transposed
}
/// Applies the leading `m x m` block of `a` to the first `m` entries of `v`.
///
/// Entries of the result beyond index `m - 1` stay at zero.
#[inline]
fn mat_vec(a: &Mat2, v: &Vec2, m: usize) -> Vec2 {
    let mut image = [0.0; MAX_ORDER];
    for (slot, row) in image[..m].iter_mut().zip(a[..m].iter()) {
        // Explicit 0.0 seed and left-to-right accumulation keep the rounding order fixed.
        *slot = row[..m]
            .iter()
            .zip(v[..m].iter())
            .fold(0.0, |acc, (&aij, &vj)| acc + aij * vj);
    }
    image
}
/// Adds the leading `m x m` blocks of `a` and `b`; the rest of the result is zero.
#[inline]
fn mat_add(a: &Mat2, b: &Mat2, m: usize) -> Mat2 {
    let mut total = [[0.0; MAX_ORDER]; MAX_ORDER];
    for (i, out_row) in total[..m].iter_mut().enumerate() {
        for (j, cell) in out_row[..m].iter_mut().enumerate() {
            *cell = a[i][j] + b[i][j];
        }
    }
    total
}
/// Subtracts the leading `m x m` block of `b` from that of `a`; the rest of the result is zero.
#[inline]
fn mat_sub(a: &Mat2, b: &Mat2, m: usize) -> Mat2 {
    let mut diff = [[0.0; MAX_ORDER]; MAX_ORDER];
    for (i, out_row) in diff[..m].iter_mut().enumerate() {
        for (j, cell) in out_row[..m].iter_mut().enumerate() {
            *cell = a[i][j] - b[i][j];
        }
    }
    diff
}
/// Inverts the leading `m x m` block of `a` in closed form (`m` in `1..=3`).
///
/// `what` names the matrix in error messages.
///
/// # Errors
/// Returns an error when the pivot/determinant is zero or non-finite, or when
/// `m` is not 1, 2 or 3.
fn mat_inv(a: &Mat2, m: usize, what: &str) -> Result<Mat2, String> {
    // A pivot or determinant is usable only when it is finite and non-zero.
    let usable = |value: f64| value.is_finite() && value.abs() > 0.0;
    let mut inv = [[0.0; MAX_ORDER]; MAX_ORDER];
    if m == 1 {
        let d = a[0][0];
        if !usable(d) {
            return Err(format!("spline scan: singular 1x1 in {what} (a00={d})"));
        }
        inv[0][0] = 1.0 / d;
    } else if m == 2 {
        let det = a[0][0] * a[1][1] - a[0][1] * a[1][0];
        if !usable(det) {
            return Err(format!("spline scan: singular 2x2 in {what} (det={det})"));
        }
        inv[0][0] = a[1][1] / det;
        inv[0][1] = -a[0][1] / det;
        inv[1][0] = -a[1][0] / det;
        inv[1][1] = a[0][0] / det;
    } else if m == 3 {
        let [r0, r1, r2] = *a;
        // Cofactors of the first row give the determinant by Laplace expansion.
        let c00 = r1[1] * r2[2] - r1[2] * r2[1];
        let c01 = r1[2] * r2[0] - r1[0] * r2[2];
        let c02 = r1[0] * r2[1] - r1[1] * r2[0];
        let det = r0[0] * c00 + r0[1] * c01 + r0[2] * c02;
        if !usable(det) {
            return Err(format!("spline scan: singular 3x3 in {what} (det={det})"));
        }
        let inv_det = 1.0 / det;
        // Adjugate (transposed cofactor matrix), scaled entry by entry below.
        let adjugate = [
            [
                c00,
                r0[2] * r2[1] - r0[1] * r2[2],
                r0[1] * r1[2] - r0[2] * r1[1],
            ],
            [
                c01,
                r0[0] * r2[2] - r0[2] * r2[0],
                r0[2] * r1[0] - r0[0] * r1[2],
            ],
            [
                c02,
                r0[1] * r2[0] - r0[0] * r2[1],
                r0[0] * r1[1] - r0[1] * r1[0],
            ],
        ];
        for (inv_row, adj_row) in inv.iter_mut().zip(adjugate.iter()) {
            for (cell, &cofactor) in inv_row.iter_mut().zip(adj_row.iter()) {
                *cell = cofactor * inv_det;
            }
        }
    } else {
        return Err(format!("spline scan: unsupported order {m} in {what}"));
    }
    Ok(inv)
}
/// Inverts a dense square matrix after symmetric diagonal scaling, followed by
/// one step of iterative refinement.
///
/// The matrix is scaled to `S A S` with `S = diag(1 / sqrt(a_ii))` (a scale of
/// 1 is used for diagonal entries that are non-finite or non-positive),
/// inverted by Gauss-Jordan elimination, corrected once with
/// `X <- X + X (I - A_s X)`, and scaled back.
///
/// # Errors
/// Propagates the singularity error of `gauss_jordan_inverse`; `what` names the
/// matrix in that message.
fn dense_spd_inverse(a: &[Vec<f64>], what: &str) -> Result<Vec<Vec<f64>>, String> {
    let d = a.len();

    let mut s: Vec<f64> = Vec::with_capacity(d);
    for (i, row) in a.iter().enumerate() {
        let dii = row[i];
        let usable = dii.is_finite() && dii > 0.0;
        s.push(if usable { 1.0 / dii.sqrt() } else { 1.0 });
    }

    // Symmetric scaling `S M S`, used both on the way in and on the way out.
    let rescale = |mat: &[Vec<f64>]| -> Vec<Vec<f64>> {
        (0..d)
            .map(|i| (0..d).map(|j| s[i] * mat[i][j] * s[j]).collect())
            .collect()
    };
    // Plain `d x d` product, accumulated left to right from an explicit 0.0.
    let product = |lhs: &[Vec<f64>], rhs: &[Vec<f64>]| -> Vec<Vec<f64>> {
        (0..d)
            .map(|i| {
                (0..d)
                    .map(|j| (0..d).fold(0.0, |acc, k| acc + lhs[i][k] * rhs[k][j]))
                    .collect()
            })
            .collect()
    };

    let a_s = rescale(a);
    let mut inv_s = gauss_jordan_inverse(&a_s, what)?;

    // Residual of the scaled system: I - A_s * X.
    let mut resid = product(&a_s, &inv_s);
    for (i, row) in resid.iter_mut().enumerate() {
        for (j, cell) in row.iter_mut().enumerate() {
            let identity = if i == j { 1.0 } else { 0.0 };
            *cell = identity - *cell;
        }
    }

    // One refinement step: X <- X + X * resid.
    let delta = product(&inv_s, &resid);
    for (row, delta_row) in inv_s.iter_mut().zip(delta.iter()) {
        for (cell, &step) in row.iter_mut().zip(delta_row.iter()) {
            *cell += step;
        }
    }

    Ok(rescale(&inv_s))
}
/// Inverts a dense `d x d` matrix by Gauss-Jordan elimination with partial
/// (row) pivoting, where `d = a.len()`.
///
/// # Errors
/// Returns an error naming `what` when a pivot is zero or non-finite.
///
/// # Panics
/// Panics if a row has fewer than `d` entries.
fn gauss_jordan_inverse(a: &[Vec<f64>], what: &str) -> Result<Vec<Vec<f64>>, String> {
    let d = a.len();
    let mut aug: Vec<Vec<f64>> = a.to_vec();
    // Start from the identity; it is turned into the inverse as `aug` is reduced.
    let mut inv: Vec<Vec<f64>> = (0..d)
        .map(|i| (0..d).map(|j| if i == j { 1.0 } else { 0.0 }).collect())
        .collect();

    for col in 0..d {
        // Row at or below the diagonal with the largest magnitude in this column.
        let piv = (col..d)
            .max_by(|&i, &j| aug[i][col].abs().total_cmp(&aug[j][col].abs()))
            .unwrap();
        let p = aug[piv][col];
        if !p.is_finite() || p == 0.0 {
            return Err(format!(
                "spline scan: singular {d}x{d} in {what} (pivot={p})"
            ));
        }
        aug.swap(col, piv);
        inv.swap(col, piv);

        // Scale the pivot row so the pivot becomes 1.
        let d_piv = aug[col][col];
        aug[col][..d].iter_mut().for_each(|v| *v /= d_piv);
        inv[col][..d].iter_mut().for_each(|v| *v /= d_piv);

        // The pivot rows are not modified during elimination, so snapshots are exact.
        let aug_pivot_row = aug[col].clone();
        let inv_pivot_row = inv[col].clone();
        for r in (0..d).filter(|&r| r != col) {
            let f = aug[r][col];
            if f == 0.0 {
                continue;
            }
            for (dst, &src) in aug[r][..d].iter_mut().zip(aug_pivot_row[..d].iter()) {
                *dst -= f * src;
            }
            for (dst, &src) in inv[r][..d].iter_mut().zip(inv_pivot_row[..d].iter()) {
                *dst -= f * src;
            }
        }
    }
    Ok(inv)
}
/// Returns `k!` as an `f64` (`0! = 1`).
#[inline]
fn factorial(k: usize) -> f64 {
    let mut acc = 1.0_f64;
    for v in 1..=k {
        acc *= v as f64;
    }
    acc.max(1.0)
}
/// Builds the upper-triangular transition matrix over a step `delta`:
/// entry `(i, j)` with `j >= i` is `delta^(j - i) / (j - i)!`.
///
/// Only the leading `m x m` block is filled; everything else is zero.
#[inline]
fn transition(delta: f64, m: usize) -> Mat2 {
    let mut f = [[0.0; MAX_ORDER]; MAX_ORDER];
    // Every entry on a given super-diagonal shares the same value, so fill by lag.
    for lag in 0..m {
        let term = delta.powi(lag as i32) / factorial(lag);
        for i in 0..(m - lag) {
            f[i][i + lag] = term;
        }
    }
    f
}
/// Builds the process-noise covariance accumulated over a step `delta`.
///
/// With `p = 2m - 1 - i - j`, entry `(i, j)` of the leading `m x m` block is
/// `q * delta^p / ((m-1-i)! * (m-1-j)! * p)`; everything else is zero.
#[inline]
fn process_noise(delta: f64, q: f64, m: usize) -> Mat2 {
    let mut noise = [[0.0; MAX_ORDER]; MAX_ORDER];
    for (i, row) in noise[..m].iter_mut().enumerate() {
        let row_factorial = factorial(m - 1 - i);
        for (j, cell) in row[..m].iter_mut().enumerate() {
            let p = 2 * m - 1 - i - j;
            let numerator = q * delta.powi(p as i32);
            *cell = numerator / (row_factorial * factorial(m - 1 - j) * (p as f64));
        }
    }
    noise
}
/// Forces the leading `m x m` block of `a` to be symmetric by replacing each
/// off-diagonal pair with its average.
#[inline]
fn symmetrize(a: &mut Mat2, m: usize) {
    for col in 1..m {
        for row in 0..col {
            let mean = 0.5 * (a[row][col] + a[col][row]);
            a[row][col] = mean;
            a[col][row] = mean;
        }
    }
}
/// Per-node record written by `run_filter`, consumed by the backward smoothing pass.
struct FilterStep {
/// State mean after absorbing the node's observation.
a_filt: Vec2,
/// Finite (`p_star`) covariance after absorbing the node's observation.
p_filt: Mat2,
/// State mean before absorbing the node's observation.
a_pred: Vec2,
/// Finite (`p_star`) covariance before absorbing the node's observation.
p_pred: Mat2,
}
/// Result of one forward pass of `run_filter` over all pooled nodes.
struct FilterPass {
/// One record per node, in node order.
steps: Vec<FilterStep>,
/// Sum of `ln(f_star)` over the proper (non-diffuse) updates.
sum_log_f: f64,
/// Sum of `v^2 / f_star` over the proper (non-diffuse) updates.
sum_v2_over_f: f64,
/// Number of proper (non-diffuse) updates performed.
n_proper: usize,
}
/// Runs the forward Kalman filter over the pooled nodes.
///
/// The covariance is tracked as a finite part `p_star` plus a diffuse part
/// `p_inf`, which starts as the identity on the leading `order x order` block.
/// While diffuse rank remains and the diffuse innovation variance exceeds
/// `INNOVATION_VAR_FLOOR`, a diffuse update is applied and the rank is reduced
/// by one; every other update is a proper one and feeds the likelihood sums.
/// Each node is observed through the first state component with variance
/// `1 / w`.
///
/// # Errors
/// Returns an error when a proper update meets an innovation variance at or
/// below `INNOVATION_VAR_FLOOR`.
fn run_filter(nodes: &[PooledNode], q: f64, order: usize) -> Result<FilterPass, String> {
    let n = nodes.len();
    let mut steps = Vec::with_capacity(n);
    let mut a: Vec2 = [0.0; MAX_ORDER];
    let mut p_star: Mat2 = [[0.0; MAX_ORDER]; MAX_ORDER];
    let mut p_inf: Mat2 = [[0.0; MAX_ORDER]; MAX_ORDER];
    for i in 0..order {
        p_inf[i][i] = 1.0;
    }
    let mut diffuse_rank = order;
    let mut sum_log_f = 0.0;
    let mut sum_v2_over_f = 0.0;
    let mut n_proper = 0usize;

    for (t, node) in nodes.iter().enumerate() {
        let a_pred = a;
        let p_pred = p_star;
        let r = 1.0 / node.w;
        let v = node.y - a[0];
        let mut m_star: Vec2 = [0.0; MAX_ORDER];
        for i in 0..order {
            m_star[i] = p_star[i][0];
        }
        let f_star = m_star[0] + r;

        // A diffuse update is possible only while rank remains AND the diffuse
        // innovation variance is numerically usable.
        let diffuse_update = if diffuse_rank > 0 {
            let mut m_inf: Vec2 = [0.0; MAX_ORDER];
            for i in 0..order {
                m_inf[i] = p_inf[i][0];
            }
            let f_inf = m_inf[0];
            if f_inf > INNOVATION_VAR_FLOOR {
                Some((m_inf, f_inf))
            } else {
                None
            }
        } else {
            None
        };

        if let Some((m_inf, f_inf)) = diffuse_update {
            for i in 0..order {
                a[i] += (m_inf[i] / f_inf) * v;
            }
            let mut p_new = p_star;
            for i in 0..order {
                for j in 0..order {
                    p_new[i][j] += -m_inf[i] * m_star[j] / f_inf - m_star[i] * m_inf[j] / f_inf
                        + m_inf[i] * m_inf[j] * f_star / (f_inf * f_inf);
                }
            }
            p_star = p_new;
            symmetrize(&mut p_star, order);
            for i in 0..order {
                for j in 0..order {
                    p_inf[i][j] -= m_inf[i] * m_inf[j] / f_inf;
                }
            }
            symmetrize(&mut p_inf, order);
            diffuse_rank -= 1;
            if diffuse_rank == 0 {
                // Clear rounding residue once the diffuse part is exhausted.
                p_inf = [[0.0; MAX_ORDER]; MAX_ORDER];
            }
        } else {
            // Proper update; only these contribute to the likelihood sums.
            if f_star <= INNOVATION_VAR_FLOOR {
                return Err("spline scan: non-positive innovation variance".to_string());
            }
            for i in 0..order {
                a[i] += (m_star[i] / f_star) * v;
            }
            for i in 0..order {
                for j in 0..order {
                    p_star[i][j] -= m_star[i] * m_star[j] / f_star;
                }
            }
            symmetrize(&mut p_star, order);
            sum_log_f += f_star.ln();
            sum_v2_over_f += v * v / f_star;
            n_proper += 1;
        }

        steps.push(FilterStep {
            a_filt: a,
            p_filt: p_star,
            a_pred,
            p_pred,
        });

        // Propagate to the next node, if there is one.
        if let Some(next) = nodes.get(t + 1) {
            let delta = next.x - node.x;
            let f_t = transition(delta, order);
            let f_t_transposed = mat_t(&f_t, order);
            a = mat_vec(&f_t, &a, order);
            let carried = mat_mul(&mat_mul(&f_t, &p_star, order), &f_t_transposed, order);
            let mut p_next = mat_add(&carried, &process_noise(delta, q, order), order);
            symmetrize(&mut p_next, order);
            p_star = p_next;
            if diffuse_rank > 0 {
                let mut pi_next = mat_mul(&mat_mul(&f_t, &p_inf, order), &f_t_transposed, order);
                symmetrize(&mut pi_next, order);
                p_inf = pi_next;
            }
        }
    }

    Ok(FilterPass {
        steps,
        sum_log_f,
        sum_v2_over_f,
        n_proper,
    })
}
/// A fitted spline scan: smoothed states at the pooled knots plus the scalars
/// needed to predict between and beyond them.
#[derive(Clone, Debug)]
pub struct SplineScanFit {
/// Model order (state dimension), in `1..=MAX_ORDER`.
pub order: usize,
/// Distinct abscissae of the pooled nodes, in increasing order.
pub knots: Vec<f64>,
/// First component of the smoothed state at each knot.
pub mean: Vec<f64>,
/// Second component of the smoothed state at each knot (zero when `order` is 1).
pub deriv: Vec<f64>,
/// `sigma2` times the `(0, 0)` entry of the smoothed covariance at each knot.
pub var: Vec<f64>,
/// Log smoothing parameter; the process-noise scale is `exp(-log_lambda)`.
pub log_lambda: f64,
/// Variance scale applied to the smoothed covariances.
pub sigma2: f64,
/// Restricted log-likelihood value recorded for this fit.
pub restricted_loglik: f64,
/// Number of input rows before pooling.
pub n_obs: usize,
// Full smoothed state at each knot.
smoothed_state: Vec<Vec2>,
// Smoothed covariance at each knot, before scaling by `sigma2`.
smoothed_cov: Vec<Mat2>,
// Smoother gain stored per knot `t`, linking knot `t` to knot `t + 1`.
rts_gain: Vec<Mat2>,
// Process-noise scale, `exp(-log_lambda)`.
q: f64,
// Pooled weight of each knot.
node_weight: Vec<f64>,
}
/// Validates the inputs, sorts them by `x`, and merges rows with identical `x`
/// into single nodes.
///
/// Returns the pooled nodes (weight = sum of weights, response = weighted
/// mean), the weighted within-node sum of squared residuals, and the number of
/// input rows.
///
/// # Errors
/// Returns an error on mismatched lengths, on any non-finite value or
/// non-positive weight, and when fewer than `order + 1` distinct abscissae
/// remain after pooling.
fn pool_nodes(
    x: &[f64],
    y: &[f64],
    w: &[f64],
    order: usize,
) -> Result<(Vec<PooledNode>, f64, usize), String> {
    let n = x.len();
    if !(y.len() == n && w.len() == n) {
        return Err(format!(
            "spline scan: length mismatch x={n}, y={}, w={}",
            y.len(),
            w.len()
        ));
    }
    for (i, ((&xi, &yi), &wi)) in x.iter().zip(y).zip(w).enumerate() {
        let row_ok = xi.is_finite() && yi.is_finite() && wi.is_finite() && wi > 0.0;
        if !row_ok {
            return Err(format!(
                "spline scan: non-finite or non-positive input at row {i} (x={}, y={}, w={})",
                xi, yi, wi
            ));
        }
    }

    // Stable sort, so rows with equal `x` are merged in their input order.
    let mut perm: Vec<usize> = (0..n).collect();
    perm.sort_by(|&i, &j| x[i].total_cmp(&x[j]));

    let mut nodes: Vec<PooledNode> = Vec::new();
    for &row in &perm {
        if let Some(last) = nodes.last_mut().filter(|node| node.x == x[row]) {
            let w_new = last.w + w[row];
            last.y = (last.y * last.w + y[row] * w[row]) / w_new;
            last.w = w_new;
        } else {
            nodes.push(PooledNode {
                x: x[row],
                y: y[row],
                w: w[row],
            });
        }
    }

    let required = order + 1;
    if nodes.len() < required {
        return Err(format!(
            "spline scan: order {order} needs at least {} distinct abscissae, got {}",
            order + 1,
            nodes.len()
        ));
    }

    // Rows and nodes share the same ordering, so a forward-only cursor finds each row's node.
    let mut ssr_within = 0.0;
    let mut cursor = 0usize;
    for &row in &perm {
        while nodes[cursor].x != x[row] {
            cursor += 1;
        }
        let resid = y[row] - nodes[cursor].y;
        ssr_within += w[row] * resid * resid;
    }
    Ok((nodes, ssr_within, n))
}
/// Evaluates the concentrated (variance profiled out) restricted criterion at
/// `log_lambda`; larger is better.
///
/// The residual sum combines the filter's `sum_v2_over_f` with the within-node
/// residuals `ssr_within`, and the degrees of freedom are `n_obs - order`.
///
/// # Errors
/// Propagates filter errors, rejects a non-positive residual sum, and rejects
/// a pass whose number of proper innovations is not `nodes.len() - order`.
fn concentrated_criterion(
    nodes: &[PooledNode],
    ssr_within: f64,
    n_obs: usize,
    log_lambda: f64,
    order: usize,
) -> Result<f64, String> {
    let q = (-log_lambda).exp();
    let pass = run_filter(nodes, q, order)?;
    let dof = (n_obs - order) as f64;
    let rss = pass.sum_v2_over_f + ssr_within;
    if rss <= 0.0 {
        return Err("spline scan: degenerate zero residual sum".to_string());
    }
    let expected_proper = nodes.len() - order;
    if pass.n_proper != expected_proper {
        return Err(format!(
            "spline scan: expected {} proper innovations, got {} (diffuse rank not consumed)",
            expected_proper, pass.n_proper
        ));
    }
    let sigma2 = rss / dof;
    let penalty = pass.sum_log_f + dof * sigma2.ln();
    Ok(-0.5 * penalty)
}
/// Overwrites the smoothed state, covariance and gain of the first `order - 1`
/// nodes using a dense solve over their stacked states.
///
/// The stacked states are coupled to the state at node `order - 1` (already
/// filled in `sm_state` / `sm_cov` by the caller), which acts as the pinned
/// boundary. The caller in this file only invokes this for `order >= 2`.
///
/// # Errors
/// Propagates singularity errors from `mat_inv` and `dense_spd_inverse`.
fn leading_block_smooth(
sm_state: &mut [Vec2],
sm_cov: &mut [Mat2],
gains: &mut [Mat2],
nodes: &[PooledNode],
q: f64,
order: usize,
) -> Result<(), String> {
// nb: number of leading nodes handled here; pin: index of the boundary node;
// d: dimension of the stacked state; lambda: stacked precision matrix.
let nb = order - 1; let pin = order - 1; let d = nb * order; let mut lambda = vec![vec![0.0_f64; d]; d];
// Linear term contributed by the observations.
let mut b_const = vec![0.0_f64; d];
// Coupling of the stacked states to the pinned boundary state.
let mut bmat = vec![vec![0.0_f64; order]; d];
// Add each increment between node t and node t + 1 to the precision.
for t in 0..order - 1 {
let delta = nodes[t + 1].x - nodes[t].x;
let f = transition(delta, order);
let qn = process_noise(delta, q, order);
// a = Q^{-1}; fta = F' Q^{-1}; ftaf = F' Q^{-1} F; af = Q^{-1} F.
let a = mat_inv(&qn, order, "leading-block increment noise")?; let ft = mat_t(&f, order);
let fta = mat_mul(&ft, &a, order); let ftaf = mat_mul(&fta, &f, order); let af = mat_mul(&a, &f, order); for i in 0..order {
for j in 0..order {
lambda[t * order + i][t * order + j] += ftaf[i][j];
}
}
if t + 1 <= nb - 1 {
// Node t + 1 is inside the stacked block: add its diagonal and cross terms.
for i in 0..order {
for j in 0..order {
lambda[(t + 1) * order + i][(t + 1) * order + j] += a[i][j];
lambda[t * order + i][(t + 1) * order + j] -= fta[i][j];
lambda[(t + 1) * order + i][t * order + j] -= af[i][j];
}
}
} else {
// Node t + 1 is the pinned boundary: record the coupling instead.
for i in 0..order {
for j in 0..order {
bmat[t * order + i][j] += fta[i][j];
}
}
}
}
// Observations act on the first state component of each stacked node.
for t in 0..nb {
let w = nodes[t].w;
lambda[t * order][t * order] += w;
b_const[t * order] += w * nodes[t].y;
}
let sigma = dense_spd_inverse(&lambda, "leading-block precision")?;
// dvec = sigma * b_const
let dvec: Vec<f64> = (0..d)
.map(|i| (0..d).map(|k| sigma[i][k] * b_const[k]).sum())
.collect();
// cmat = sigma * bmat
let cmat: Vec<Vec<f64>> = (0..d)
.map(|i| {
(0..order)
.map(|j| (0..d).map(|k| sigma[i][k] * bmat[k][j]).sum())
.collect()
})
.collect();
// Smoothed mean and covariance of the pinned boundary state.
let ahat_p = sm_state[pin];
let vp = sm_cov[pin];
// cvp = cmat * vp
let cvp: Vec<Vec<f64>> = (0..d)
.map(|i| {
(0..order)
.map(|j| (0..order).map(|k| cmat[i][k] * vp[k][j]).sum())
.collect()
})
.collect();
// mean_u = cmat * ahat_p + dvec
let mean_u: Vec<f64> = (0..d)
.map(|i| (0..order).map(|j| cmat[i][j] * ahat_p[j]).sum::<f64>() + dvec[i])
.collect();
// cov_u = cmat * vp * cmat' + sigma
let cov_u: Vec<Vec<f64>> = (0..d)
.map(|i| {
(0..d)
.map(|k| (0..order).map(|j| cvp[i][j] * cmat[k][j]).sum::<f64>() + sigma[i][k])
.collect()
})
.collect();
// Scatter the stacked results back into the per-node arrays.
for j in 0..nb {
for i in 0..order {
sm_state[j][i] = mean_u[j * order + i];
}
let mut cov = [[0.0_f64; MAX_ORDER]; MAX_ORDER];
for i in 0..order {
for k in 0..order {
cov[i][k] = cov_u[j * order + i][j * order + k];
}
}
symmetrize(&mut cov, order);
sm_cov[j] = cov;
}
// Gain for node j = (cross block between node j and node j + 1) * inverse of
// the smoothed covariance at node j + 1.
for j in 0..nb {
let mut cross = [[0.0_f64; MAX_ORDER]; MAX_ORDER];
if j + 1 <= nb - 1 {
for i in 0..order {
for k in 0..order {
cross[i][k] = cov_u[j * order + i][(j + 1) * order + k];
}
}
} else {
// Neighbour is the pinned boundary node: the cross block comes from cvp.
for i in 0..order {
for k in 0..order {
cross[i][k] = cvp[j * order + i][k];
}
}
}
let denom_inv = mat_inv(&sm_cov[j + 1], order, "leading-block gain denominator")?;
gains[j] = mat_mul(&cross, &denom_inv, order);
}
Ok(())
}
/// Fits the spline scan at a fixed `log_lambda`.
///
/// Rows with identical `x` are pooled, a forward filter is run over the pooled
/// nodes, and a backward smoothing pass fills the state, covariance and gain at
/// every node. For `order >= 2` the first `order - 1` nodes are then re-smoothed
/// by `leading_block_smooth`.
///
/// When `sigma2` is `None`, the variance scale is estimated as the residual sum
/// divided by `n_obs - order`; otherwise the supplied value is used.
///
/// # Errors
/// Returns an error when `order` is outside `1..=MAX_ORDER`, when the inputs
/// fail validation in `pool_nodes`, when the filter fails, when the filter did
/// not consume its full diffuse rank, when a supplied `sigma2` is not finite
/// and positive, or when a matrix met during smoothing is singular.
pub fn fit_spline_scan_at(
    x: &[f64],
    y: &[f64],
    w: &[f64],
    log_lambda: f64,
    sigma2: Option<f64>,
    order: usize,
) -> Result<SplineScanFit, String> {
    if order == 0 || order > MAX_ORDER {
        return Err(format!(
            "spline scan: order must be in 1..={MAX_ORDER}, got {order}"
        ));
    }
    let (nodes, ssr_within, n_obs) = pool_nodes(x, y, w, order)?;
    let q = (-log_lambda).exp();
    let pass = run_filter(&nodes, q, order)?;
    let n = nodes.len();

    // Same guard as `concentrated_criterion`: if diffuse rank is left over, the
    // stored covariances omit the diffuse part and the smoother below would
    // return invalid moments, so fail loudly instead.
    if pass.n_proper != n - order {
        return Err(format!(
            "spline scan: expected {} proper innovations, got {} (diffuse rank not consumed)",
            n - order,
            pass.n_proper
        ));
    }

    let dof = (n_obs - order) as f64;
    let rss = pass.sum_v2_over_f + ssr_within;
    let sigma2 = match sigma2 {
        Some(s) => {
            if !(s.is_finite() && s > 0.0) {
                return Err(format!("spline scan: invalid sigma2 {s}"));
            }
            s
        }
        None => rss / dof,
    };
    let restricted_loglik = -0.5 * (pass.sum_log_f + dof * sigma2.ln() + rss / sigma2);

    let mut sm_state = vec![[0.0_f64; MAX_ORDER]; n];
    let mut sm_cov = vec![[[0.0_f64; MAX_ORDER]; MAX_ORDER]; n];
    let mut gains = vec![[[0.0_f64; MAX_ORDER]; MAX_ORDER]; n];

    // The last node's smoothed moments are its filtered moments.
    sm_state[n - 1] = pass.steps[n - 1].a_filt;
    sm_cov[n - 1] = pass.steps[n - 1].p_filt;

    // Backward recursion; stops at node `order - 1`, the earlier nodes are
    // handled by `leading_block_smooth`.
    for t in (order - 1..n - 1).rev() {
        let p_next_pred = &pass.steps[t + 1].p_pred;
        let delta = nodes[t + 1].x - nodes[t].x;
        let f_t = transition(delta, order);
        let p_inv = mat_inv(p_next_pred, order, "RTS predicted covariance")?;
        // Gain: G = P_filt[t] * F' * P_pred[t + 1]^{-1}.
        let g = mat_mul(
            &mat_mul(&pass.steps[t].p_filt, &mat_t(&f_t, order), order),
            &p_inv,
            order,
        );
        let mut dm: Vec2 = [0.0; MAX_ORDER];
        for i in 0..order {
            dm[i] = sm_state[t + 1][i] - pass.steps[t + 1].a_pred[i];
        }
        let corr = mat_vec(&g, &dm, order);
        for i in 0..order {
            sm_state[t][i] = pass.steps[t].a_filt[i] + corr[i];
        }
        let dp = mat_sub(&sm_cov[t + 1], p_next_pred, order);
        let mut cov = mat_add(
            &pass.steps[t].p_filt,
            &mat_mul(&mat_mul(&g, &dp, order), &mat_t(&g, order), order),
            order,
        );
        symmetrize(&mut cov, order);
        sm_cov[t] = cov;
        gains[t] = g;
    }
    if order >= 2 {
        leading_block_smooth(&mut sm_state, &mut sm_cov, &mut gains, &nodes, q, order)?;
    }

    let knots: Vec<f64> = nodes.iter().map(|n| n.x).collect();
    let mean: Vec<f64> = sm_state.iter().map(|s| s[0]).collect();
    let deriv: Vec<f64> = sm_state
        .iter()
        .map(|s| if order >= 2 { s[1] } else { 0.0 })
        .collect();
    let var: Vec<f64> = sm_cov.iter().map(|p| p[0][0] * sigma2).collect();
    Ok(SplineScanFit {
        order,
        knots,
        mean,
        deriv,
        var,
        log_lambda,
        sigma2,
        restricted_loglik,
        n_obs,
        smoothed_state: sm_state,
        smoothed_cov: sm_cov,
        rts_gain: gains,
        q,
        node_weight: nodes.iter().map(|n| n.w).collect(),
    })
}
/// Fits the spline scan, choosing `log_lambda` by maximising the concentrated
/// restricted criterion.
///
/// The search is a coarse grid of `LOG_LAMBDA_GRID` points on
/// `[LOG_LAMBDA_LO, LOG_LAMBDA_HI]`, followed by golden-section refinement on
/// the bracket around the best grid point until its width drops to
/// `LOG_LAMBDA_TOL`. The final fit is produced by `fit_spline_scan_at` at the
/// bracket midpoint with the variance scale estimated from the data.
///
/// # Errors
/// Returns an error when `order` is outside `1..=MAX_ORDER`, when the inputs
/// fail validation, or when any criterion evaluation or the final fit fails.
pub fn fit_spline_scan(
    x: &[f64],
    y: &[f64],
    w: &[f64],
    order: usize,
) -> Result<SplineScanFit, String> {
    if order == 0 || order > MAX_ORDER {
        return Err(format!(
            "spline scan: order must be in 1..={MAX_ORDER}, got {order}"
        ));
    }
    let (nodes, ssr_within, n_obs) = pool_nodes(x, y, w, order)?;
    let crit = |ll: f64| concentrated_criterion(&nodes, ssr_within, n_obs, ll, order);

    // Coarse scan: keep the first grid index that attains the largest criterion.
    let step = (LOG_LAMBDA_HI - LOG_LAMBDA_LO) / (LOG_LAMBDA_GRID - 1) as f64;
    let grid_point = |i: usize| LOG_LAMBDA_LO + step * i as f64;
    let mut best: (usize, f64) = (0, f64::NEG_INFINITY);
    for i in 0..LOG_LAMBDA_GRID {
        let v = crit(grid_point(i))?;
        if v > best.1 {
            best = (i, v);
        }
    }
    let best_i = best.0;

    // Bracket: one grid step either side of the best point, clamped to the interval.
    let mut lo = grid_point(best_i.saturating_sub(1));
    let mut hi = grid_point(best_i + 1).min(LOG_LAMBDA_HI);

    // Golden-section search for the maximum inside the bracket.
    let inv_phi = 0.618_033_988_749_894_9_f64;
    let mut x1 = hi - inv_phi * (hi - lo);
    let mut x2 = lo + inv_phi * (hi - lo);
    let mut f1 = crit(x1)?;
    let mut f2 = crit(x2)?;
    while hi - lo > LOG_LAMBDA_TOL {
        if f1 < f2 {
            // The maximum lies to the right of x1.
            lo = x1;
            (x1, f1) = (x2, f2);
            x2 = lo + inv_phi * (hi - lo);
            f2 = crit(x2)?;
        } else {
            // The maximum lies to the left of x2.
            hi = x2;
            (x2, f2) = (x1, f1);
            x1 = hi - inv_phi * (hi - lo);
            f1 = crit(x1)?;
        }
    }
    let midpoint = 0.5 * (lo + hi);
    fit_spline_scan_at(x, y, w, midpoint, None, order)
}
/// Flat, serialisable snapshot of a `SplineScanFit`.
///
/// Produced by `SplineScanFit::to_state` and validated on the way back in by
/// `SplineScanFit::from_state`.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct SplineScanState {
/// Model order; defaults to 2 when absent from the serialised data.
#[serde(default = "default_spline_scan_order")]
pub order: usize,
/// Knot abscissae, strictly increasing.
pub knots: Vec<f64>,
/// Smoothed states, `order` values per knot.
pub state: Vec<f64>,
/// Smoothed covariances, upper triangle in row-major order
/// (`order * (order + 1) / 2` values per knot).
pub cov: Vec<f64>,
/// Smoother gains, row-major `order * order` values per knot.
pub gain: Vec<f64>,
/// Pooled weight of each knot.
pub node_weight: Vec<f64>,
/// Log smoothing parameter.
pub log_lambda: f64,
/// Variance scale.
pub sigma2: f64,
/// Restricted log-likelihood recorded for the fit.
pub restricted_loglik: f64,
/// Number of input rows; defaults to 0 when absent, in which case
/// `from_state` reconstructs it from the node weights.
#[serde(default)]
pub n_obs: u64,
}
/// Serde default for `SplineScanState::order`, used when the field is absent.
fn default_spline_scan_order() -> usize {
    const FALLBACK_ORDER: usize = 2;
    FALLBACK_ORDER
}
impl SplineScanFit {
/// Flattens the fit into a serialisable `SplineScanState`.
///
/// States are written as `order` values per knot, covariances as their upper
/// triangle in row-major order, and gains as row-major `order * order` values.
pub fn to_state(&self) -> SplineScanState {
let order = self.order;
let tri = order * (order + 1) / 2;
let nk = self.knots.len();
let mut state = Vec::with_capacity(order * nk);
for s in &self.smoothed_state {
state.extend_from_slice(&s[..order]);
}
let mut cov = Vec::with_capacity(tri * nk);
for c in &self.smoothed_cov {
for i in 0..order {
for j in i..order {
cov.push(c[i][j]);
}
}
}
let mut gain = Vec::with_capacity(order * order * nk);
for g in &self.rts_gain {
for i in 0..order {
for j in 0..order {
gain.push(g[i][j]);
}
}
}
SplineScanState {
order: self.order,
knots: self.knots.clone(),
state,
cov,
gain,
node_weight: self.node_weight.clone(),
log_lambda: self.log_lambda,
sigma2: self.sigma2,
restricted_loglik: self.restricted_loglik,
n_obs: self.n_obs as u64,
}
}
/// Rebuilds a fit from a snapshot, validating it first.
///
/// # Errors
/// Returns an error when the order is outside `1..=MAX_ORDER`, when there are
/// fewer than `order + 1` knots, when the flattened arrays have inconsistent
/// lengths, when any entry or scalar is non-finite, when `sigma2` is not
/// positive, when the knots are not strictly increasing, or when a node
/// weight is not positive.
pub fn from_state(state: &SplineScanState) -> Result<Self, String> {
let order = state.order;
if order == 0 || order > MAX_ORDER {
return Err(format!(
"spline scan state: order must be in 1..={MAX_ORDER}, got {order}"
));
}
let m = state.knots.len();
if m < order + 1 {
return Err(format!(
"spline scan state: order {order} needs at least {} knots, got {m}",
order + 1
));
}
let tri = order * (order + 1) / 2;
if state.state.len() != order * m
|| state.cov.len() != tri * m
|| state.gain.len() != order * order * m
|| state.node_weight.len() != m
{
return Err(format!(
"spline scan state: inconsistent lengths (order={order}, m={m}, state={}, cov={}, gain={}, weights={})",
state.state.len(),
state.cov.len(),
state.gain.len(),
state.node_weight.len()
));
}
// The reported index counts through the chained sequence
// state, cov, gain, knots, node_weight.
let all = state
.state
.iter()
.chain(&state.cov)
.chain(&state.gain)
.chain(&state.knots)
.chain(&state.node_weight);
for (i, v) in all.enumerate() {
if !v.is_finite() {
return Err(format!("spline scan state: non-finite entry at {i}"));
}
}
if !(state.log_lambda.is_finite()
&& state.restricted_loglik.is_finite()
&& state.sigma2.is_finite()
&& state.sigma2 > 0.0)
{
return Err(format!(
"spline scan state: invalid scalars (log_lambda={}, sigma2={}, restricted_loglik={})",
state.log_lambda, state.sigma2, state.restricted_loglik
));
}
if state.knots.windows(2).any(|kk| !(kk[0] < kk[1])) {
return Err("spline scan state: knots must be strictly increasing".to_string());
}
if state.node_weight.iter().any(|&w| w <= 0.0) {
return Err("spline scan state: node weights must be positive".to_string());
}
// Unflatten the per-knot states; storage beyond `order` stays zero.
let smoothed_state: Vec<Vec2> = state
.state
.chunks_exact(order)
.map(|s| {
let mut v = [0.0_f64; MAX_ORDER];
v[..order].copy_from_slice(s);
v
})
.collect();
// Unflatten the upper triangles, mirroring them into full symmetric matrices.
let smoothed_cov: Vec<Mat2> = state
.cov
.chunks_exact(tri)
.map(|c| {
let mut mm = [[0.0_f64; MAX_ORDER]; MAX_ORDER];
let mut idx = 0;
for i in 0..order {
for j in i..order {
mm[i][j] = c[idx];
mm[j][i] = c[idx];
idx += 1;
}
}
mm
})
.collect();
let rts_gain: Vec<Mat2> = state
.gain
.chunks_exact(order * order)
.map(|g| {
let mut mm = [[0.0_f64; MAX_ORDER]; MAX_ORDER];
for i in 0..order {
for j in 0..order {
mm[i][j] = g[i * order + j];
}
}
mm
})
.collect();
let sigma2 = state.sigma2;
// A zero `n_obs` (the serde default) is replaced by the rounded sum of node
// weights, but never less than the number of knots.
let n_obs = if state.n_obs > 0 {
state.n_obs as usize
} else {
(state.node_weight.iter().sum::<f64>().round() as usize).max(m)
};
Ok(Self {
order,
knots: state.knots.clone(),
mean: smoothed_state.iter().map(|s| s[0]).collect(),
// For order 1 the unused storage slot is zero, so this yields zeros.
deriv: smoothed_state.iter().map(|s| s[1]).collect(),
var: smoothed_cov.iter().map(|c| c[0][0] * sigma2).collect(),
log_lambda: state.log_lambda,
sigma2,
restricted_loglik: state.restricted_loglik,
n_obs,
smoothed_state,
smoothed_cov,
rts_gain,
q: (-state.log_lambda).exp(),
node_weight: state.node_weight.clone(),
})
}
/// Returns `(mean, variance)` of the fitted curve at `x_new`.
///
/// At or left of the first knot the first smoothed state is carried backwards
/// through the inverse transition; at or right of the last knot the last
/// smoothed state is carried forwards. Exactly at an interior knot the stored
/// `mean` / `var` are returned. Between two knots the result combines both
/// neighbouring smoothed states, using the stored gain for their
/// cross-covariance.
///
/// # Errors
/// Returns an error for a non-finite `x_new` or when a matrix met along the
/// way is singular.
pub fn predict(&self, x_new: f64) -> Result<(f64, f64), String> {
if !x_new.is_finite() {
return Err("spline scan: non-finite prediction abscissa".to_string());
}
let n = self.knots.len();
let order = self.order;
let first = self.knots[0];
let last = self.knots[n - 1];
if x_new <= first {
// Backward extrapolation over the gap `delta` before the first knot.
let delta = first - x_new;
let f_t = transition(delta, order);
let f_inv = mat_inv(&f_t, order, "backward extrapolation transition")?;
let mean_s = mat_vec(&f_inv, &self.smoothed_state[0], order);
let qm = process_noise(delta, self.q, order);
// cov = F^{-1} P F^{-T} + F^{-1} Q F^{-T}
let cov = mat_add(
&mat_mul(
&mat_mul(&f_inv, &self.smoothed_cov[0], order),
&mat_t(&f_inv, order),
order,
),
&mat_mul(&mat_mul(&f_inv, &qm, order), &mat_t(&f_inv, order), order),
order,
);
return Ok((mean_s[0], cov[0][0] * self.sigma2));
}
if x_new >= last {
// Forward extrapolation over the gap `delta` after the last knot.
let delta = x_new - last;
let f_t = transition(delta, order);
let mean_s = mat_vec(&f_t, &self.smoothed_state[n - 1], order);
// cov = F P F' + Q
let cov = mat_add(
&mat_mul(
&mat_mul(&f_t, &self.smoothed_cov[n - 1], order),
&mat_t(&f_t, order),
order,
),
&process_noise(delta, self.q, order),
order,
);
return Ok((mean_s[0], cov[0][0] * self.sigma2));
}
// Strictly between first and last: find the knot interval [t, t + 1].
let t = match self.knots.binary_search_by(|k| k.total_cmp(&x_new)) {
Ok(idx) => return Ok((self.mean[idx], self.var[idx])),
Err(idx) => idx - 1,
};
let (xa, xb) = (self.knots[t], self.knots[t + 1]);
let (d1, d2) = (x_new - xa, xb - x_new);
let (f1m, f2m) = (transition(d1, order), transition(d2, order));
let (q1, q2) = (
process_noise(d1, self.q, order),
process_noise(d2, self.q, order),
);
let q1_inv = mat_inv(&q1, order, "bridge left noise")?;
let q2_inv = mat_inv(&q2, order, "bridge right noise")?;
// lambda = Q1^{-1} + F2' Q2^{-1} F2
let lambda = mat_add(
&q1_inv,
&mat_mul(&mat_mul(&mat_t(&f2m, order), &q2_inv, order), &f2m, order),
order,
);
let lam_inv = mat_inv(&lambda, order, "bridge precision")?;
// Coefficients applied to the left (ca) and right (cb) neighbouring states.
let ca = mat_mul(&lam_inv, &mat_mul(&q1_inv, &f1m, order), order);
let cb = mat_mul(
&lam_inv,
&mat_mul(&mat_t(&f2m, order), &q2_inv, order),
order,
);
let ma = mat_vec(&ca, &self.smoothed_state[t], order);
let mb = mat_vec(&cb, &self.smoothed_state[t + 1], order);
let mut mean_s = [0.0_f64; MAX_ORDER];
for i in 0..order {
mean_s[i] = ma[i] + mb[i];
}
// Cross term between the two neighbours: gain[t] * smoothed_cov[t + 1].
let cross = mat_mul(&self.rts_gain[t], &self.smoothed_cov[t + 1], order);
// cov = ca P_t ca' + cb P_{t+1} cb' + lam_inv, plus the cross terms added below.
let mut cov = mat_add(
&mat_add(
&mat_mul(
&mat_mul(&ca, &self.smoothed_cov[t], order),
&mat_t(&ca, order),
order,
),
&mat_mul(
&mat_mul(&cb, &self.smoothed_cov[t + 1], order),
&mat_t(&cb, order),
order,
),
order,
),
&lam_inv,
order,
);
let cab = mat_mul(&mat_mul(&ca, &cross, order), &mat_t(&cb, order), order);
cov = mat_add(&cov, &mat_add(&cab, &mat_t(&cab, order), order), order);
symmetrize(&mut cov, order);
Ok((mean_s[0], cov[0][0] * self.sigma2))
}
/// Effective degrees of freedom: the sum over knots of the node weight times
/// the `(0, 0)` entry of the smoothed covariance.
pub fn edf(&self) -> f64 {
self.node_weight
.iter()
.zip(self.smoothed_cov.iter())
.map(|(w, c)| w * c[0][0])
.sum()
}
/// Returns the second state component at knot `t` together with its variance
/// (`sigma2` times the `(1, 1)` covariance entry).
///
/// # Panics
/// Panics if `t` is not a valid knot index.
pub fn deriv_at_knot(&self, t: usize) -> (f64, f64) {
(
self.smoothed_state[t][1],
self.smoothed_cov[t][1][1] * self.sigma2,
)
}
/// Smoothing parameter, `exp(log_lambda)`.
pub fn lambda(&self) -> f64 {
self.log_lambda.exp()
}
/// Number of input rows behind the fit.
pub fn n_obs(&self) -> usize {
self.n_obs
}
/// `sigma2` times `n_obs - order`, with the factor clamped at zero.
pub fn deviance(&self) -> f64 {
self.sigma2 * (self.n_obs as f64 - self.order as f64).max(0.0)
}
}
#[cfg(test)]
mod tests {
use super::*;
// Fits at the given order, round-trips the fit through a JSON snapshot, and
// checks that every derived quantity and prediction is bit-identical. Also
// checks that corrupted snapshots are rejected.
fn round_trip_predict_bit_for_bit(order: usize) {
let n = 60usize;
let x: Vec<f64> = (0..n).map(|i| (i as f64) / (n as f64 - 1.0)).collect();
let mut x = x;
// Force one duplicated abscissa so the pooling path is exercised.
x[7] = x[6];
let y: Vec<f64> = x
.iter()
.enumerate()
.map(|(i, &xi)| {
(6.0 * xi).sin() + 0.3 * (17.0 * xi).cos() + 0.05 * ((i * 37 % 11) as f64 - 5.0)
})
.collect();
let w: Vec<f64> = (0..n).map(|i| 1.0 + 0.5 * ((i % 3) as f64)).collect();
let fit = fit_spline_scan(&x, &y, &w, order).expect("scan fit");
assert_eq!(fit.order, order);
assert_eq!(fit.n_obs, n);
let json = serde_json::to_string(&fit.to_state()).expect("serialize state");
let state: SplineScanState = serde_json::from_str(&json).expect("deserialize state");
let restored = SplineScanFit::from_state(&state).expect("restore fit");
assert_eq!(fit.n_obs, restored.n_obs);
assert_eq!(fit.deviance().to_bits(), restored.deviance().to_bits());
assert_eq!(fit.knots, restored.knots);
assert_eq!(fit.mean, restored.mean);
assert_eq!(fit.var, restored.var);
assert_eq!(fit.deriv, restored.deriv);
assert_eq!(fit.log_lambda.to_bits(), restored.log_lambda.to_bits());
assert_eq!(fit.sigma2.to_bits(), restored.sigma2.to_bits());
assert_eq!(fit.edf().to_bits(), restored.edf().to_bits());
for t in 0..fit.knots.len() {
let (d0, v0) = fit.deriv_at_knot(t);
let (d1, v1) = restored.deriv_at_knot(t);
assert_eq!(d0.to_bits(), d1.to_bits());
assert_eq!(v0.to_bits(), v1.to_bits());
}
// Queries cover both extrapolation sides, the end knots, an interior knot
// and points strictly between knots.
for &xq in &[-0.2, 0.0, 0.013, 0.5, x[6], 0.987, 1.0, 1.3] {
let (m0, v0) = fit.predict(xq).expect("predict original");
let (m1, v1) = restored.predict(xq).expect("predict restored");
assert_eq!(
m0.to_bits(),
m1.to_bits(),
"mean drift at x={xq} (m={order})"
);
assert_eq!(
v0.to_bits(),
v1.to_bits(),
"variance drift at x={xq} (m={order})"
);
}
// Corrupted snapshots must be rejected.
let mut bad = fit.to_state();
bad.cov.truncate(bad.cov.len() - 1);
SplineScanFit::from_state(&bad).expect_err("length mismatch must error");
let mut bad = fit.to_state();
bad.sigma2 = -1.0;
SplineScanFit::from_state(&bad).expect_err("non-positive sigma2 must error");
let mut bad = fit.to_state();
bad.knots[2] = bad.knots[1];
SplineScanFit::from_state(&bad).expect_err("non-increasing knots must error");
}
#[test]
fn state_snapshot_round_trips_predict_bit_for_bit() {
round_trip_predict_bit_for_bit(2);
}
#[test]
fn state_snapshot_round_trips_predict_bit_for_bit_order1() {
round_trip_predict_bit_for_bit(1);
}
#[test]
fn state_snapshot_round_trips_predict_bit_for_bit_order3() {
round_trip_predict_bit_for_bit(3);
}
// A snapshot with `n_obs == 0` must have its row count rebuilt from the node weights.
#[test]
fn legacy_snapshot_recovers_n_obs_from_node_weights() {
let n = 40usize;
let x: Vec<f64> = (0..n).map(|i| (i as f64) / (n as f64 - 1.0)).collect();
let y: Vec<f64> = x.iter().map(|&xi| (5.0 * xi).sin()).collect();
let w = vec![1.0; n];
let fit = fit_spline_scan(&x, &y, &w, 2).expect("scan fit");
let mut legacy = fit.to_state();
legacy.n_obs = 0; let restored = SplineScanFit::from_state(&legacy).expect("restore legacy");
assert_eq!(restored.n_obs, n);
assert!(restored.deviance() > 0.0 && restored.deviance().is_finite());
}
// Dense reference for order 1: builds the full tridiagonal precision of a
// random-walk prior plus weighted observations, inverts it by Gauss-Jordan
// elimination, and returns the posterior mean and the diagonal of the inverse.
fn dense_rw_truth(x: &[f64], y: &[f64], w: &[f64], log_lambda: f64) -> (Vec<f64>, Vec<f64>) {
let n = x.len();
let q = (-log_lambda).exp();
let mut prec = vec![vec![0.0_f64; n]; n];
let mut rhs = vec![0.0_f64; n];
for t in 0..n {
prec[t][t] += w[t];
rhs[t] += w[t] * y[t];
}
// Each increment contributes precision 1 / (q * gap) to its two end points.
for t in 0..n - 1 {
let p = 1.0 / (q * (x[t + 1] - x[t]));
prec[t][t] += p;
prec[t + 1][t + 1] += p;
prec[t][t + 1] -= p;
prec[t + 1][t] -= p;
}
let mut aug = prec.clone();
let mut inv = vec![vec![0.0_f64; n]; n];
for i in 0..n {
inv[i][i] = 1.0;
}
for col in 0..n {
let piv = (col..n)
.max_by(|&a, &b| aug[a][col].abs().total_cmp(&aug[b][col].abs()))
.unwrap();
aug.swap(col, piv);
inv.swap(col, piv);
let d = aug[col][col];
for k in 0..n {
aug[col][k] /= d;
inv[col][k] /= d;
}
for r in 0..n {
if r == col {
continue;
}
let f = aug[r][col];
if f == 0.0 {
continue;
}
for k in 0..n {
aug[r][k] -= f * aug[col][k];
inv[r][k] -= f * inv[col][k];
}
}
}
let mean: Vec<f64> = (0..n)
.map(|i| (0..n).map(|j| inv[i][j] * rhs[j]).sum())
.collect();
let var: Vec<f64> = (0..n).map(|i| inv[i][i]).collect();
(mean, var)
}
// The order-1 scan must agree with the dense reference on means, standard
// errors and effective degrees of freedom, and must report zero derivatives.
#[test]
fn order_one_scan_matches_dense_random_walk_posterior() {
let n = 30usize;
let x: Vec<f64> = (0..n).map(|i| i as f64 / (n as f64 - 1.0)).collect();
let y: Vec<f64> = x
.iter()
.enumerate()
.map(|(i, &xi)| 2.0 * xi + 0.4 * (5.0 * xi).sin() + 0.05 * ((i * 13 % 7) as f64 - 3.0))
.collect();
let w = vec![1.0_f64; n];
let fit = fit_spline_scan(&x, &y, &w, 1).expect("order-1 scan fit");
assert_eq!(fit.order, 1);
let (mean, var) = dense_rw_truth(&x, &y, &w, fit.log_lambda);
for t in 0..n {
assert!(
(fit.mean[t] - mean[t]).abs() <= 1e-7 * mean[t].abs().max(1e-3),
"order-1 mean mismatch at {t}: scan={} dense={}",
fit.mean[t],
mean[t]
);
let se_scan = fit.var[t].sqrt();
let se_dense = (var[t] * fit.sigma2).sqrt();
assert!(
(se_scan - se_dense).abs() <= 1e-7 * se_dense.max(1e-12),
"order-1 SE mismatch at {t}: scan={se_scan} dense={se_dense}"
);
}
let dense_edf: f64 = w.iter().zip(var.iter()).map(|(wt, vt)| wt * vt).sum();
assert!(
(fit.edf() - dense_edf).abs() <= 1e-7 * dense_edf.max(1e-12),
"order-1 EDF mismatch: scan={} dense={dense_edf}",
fit.edf()
);
assert!(fit.deriv.iter().all(|&d| d == 0.0));
}
}