use ndarray::{s, Array2};
use crate::ebayes::quantile_type7;
use crate::linalg::qr_full_q;
/// Evaluates all B-spline basis functions of order `ord` on the knot vector `t`
/// at the single point `xv`, or their first/second derivative.
///
/// * `t` - knot vector; the recursion assumes it is non-decreasing.
/// * `xv` - evaluation point.
/// * `ord` - spline order (degree + 1).
/// * `deriv` - `0` for values, `1` for first derivatives, `2` for second derivatives.
///
/// Returns `t.len() - ord` entries, one per basis function.
///
/// # Panics
///
/// Panics if `deriv` is not 0, 1 or 2, and on knot vectors too short for the
/// requested order (index/subtraction overflow).
fn bspline_row(t: &[f64], xv: f64, ord: usize, deriv: usize) -> Vec<f64> {
    let knot_count = t.len();
    let span_count = knot_count - 1;

    // Order-1 basis: indicator of the half-open knot span [t[i], t[i + 1]) holding `xv`.
    let mut indicator = vec![0.0_f64; span_count];
    for (slot, pair) in indicator.iter_mut().zip(t.windows(2)) {
        if pair[0] <= xv && xv < pair[1] {
            *slot = 1.0;
        }
    }
    // Points at or beyond the last knot are assigned to the last span of
    // positive width, so the right end of the knot range is not left empty.
    if xv >= t[span_count] {
        if let Some(last) = t.windows(2).rposition(|pair| pair[0] < pair[1]) {
            indicator[last] = 1.0;
        }
    }

    // levels[k] holds the order-(k + 1) basis values at `xv`.
    let mut levels: Vec<Vec<f64>> = Vec::with_capacity(ord);
    levels.push(indicator);
    for order in 2..=ord {
        let below = &levels[order - 2];
        let width = knot_count - order;
        // Cox-de Boor recursion; terms over zero-width supports are dropped.
        let current: Vec<f64> = (0..width)
            .map(|j| {
                let left_span = t[j + order - 1] - t[j];
                let left_weight = if left_span != 0.0 {
                    (xv - t[j]) / left_span
                } else {
                    0.0
                };
                let right_span = t[j + order] - t[j + 1];
                let right_weight = if right_span != 0.0 {
                    (t[j + order] - xv) / right_span
                } else {
                    0.0
                };
                left_weight * below[j] + right_weight * below[j + 1]
            })
            .collect();
        levels.push(current);
    }

    match deriv {
        0 => levels.swap_remove(ord - 1),
        // First derivative: one differencing step applied to the order-(ord - 1) values.
        1 => deriv_combine(t, &levels[ord - 2], ord),
        // Second derivative: differentiate the order-(ord - 1) basis first, then
        // apply the differencing step again at full order.
        2 => {
            let first = deriv_combine(t, &levels[ord - 3], ord - 1);
            deriv_combine(t, &first, ord)
        }
        _ => panic!("bspline_row: deriv must be 0, 1 or 2"),
    }
}
/// Applies one B-spline differentiation step.
///
/// Given `prev`, the values of the order-`ord - 1` quantities on knot vector
/// `t`, returns for each `j` in `0..t.len() - ord`
/// `(ord - 1) * (prev[j] / (t[j + ord - 1] - t[j]) - prev[j + 1] / (t[j + ord] - t[j + 1]))`,
/// where any quotient with a zero denominator is replaced by `0.0`.
fn deriv_combine(t: &[f64], prev: &[f64], ord: usize) -> Vec<f64> {
    let width = t.len() - ord;
    let factor = (ord - 1) as f64;

    // Quotient that vanishes on zero-width supports; `prev` is only read when
    // the denominator is non-zero.
    let guarded = |idx: usize, span: f64| if span != 0.0 { prev[idx] / span } else { 0.0 };

    (0..width)
        .map(|j| {
            let left = guarded(j, t[j + ord - 1] - t[j]);
            let right = guarded(j + 1, t[j + ord] - t[j + 1]);
            factor * (left - right)
        })
        .collect()
}
/// Builds the B-spline design matrix for the points `x`.
///
/// Row `r` holds `bspline_row(knots, x[r], ord, deriv)`, so the result has
/// `x.len()` rows and `knots.len() - ord` columns.
fn spline_design(knots: &[f64], x: &[f64], ord: usize, deriv: usize) -> Array2<f64> {
    let basis_count = knots.len() - ord;
    let mut design = Array2::<f64>::zeros((x.len(), basis_count));
    for (r, &point) in x.iter().enumerate() {
        let values = bspline_row(knots, point, ord, deriv);
        let mut dest = design.row_mut(r);
        for (slot, value) in dest.iter_mut().zip(values) {
            *slot = value;
        }
    }
    design
}
/// Fitted spline basis: the knot layout and constraint projection computed by
/// `NsBasis::fit`, reused by `NsBasis::eval` to evaluate the basis at new points.
pub(crate) struct NsBasis {
/// Sorted knot vector: `xmin` repeated 4 times, the interior knots, then
/// `xmax` repeated 4 times.
aknots: Vec<f64>,
/// Smallest value of the `x` passed to `fit`; `eval` extrapolates linearly below it.
xmin: f64,
/// Largest value of the `x` passed to `fit`; `eval` extrapolates linearly above it.
xmax: f64,
/// Columns `2..` of the matrix returned by `qr_full_q` for the transposed
/// boundary second-derivative constraint; `eval` right-multiplies the raw
/// B-spline design matrix by it.
nbasis: Array2<f64>,
}
impl NsBasis {
    /// Order of the underlying B-splines (4, i.e. cubic pieces).
    const ORDER: usize = 4;

    /// Fits a spline basis with `df` columns to the sample `x`.
    ///
    /// The boundary knots are the minimum and maximum of `x`; `df - 2`
    /// interior knots are taken from `quantile_type7` of the sorted sample at
    /// equally spaced probabilities. The second derivatives of the cubic
    /// B-splines at both boundary knots form a two-row constraint matrix, and
    /// the stored projection is columns `2..` of `qr_full_q` applied to its
    /// transpose (presumably the full Q factor, so that those columns span
    /// the coefficient vectors whose second derivative vanishes at both
    /// boundaries - confirm against `qr_full_q`).
    ///
    /// NOTE(review): the construction appears to follow R's
    /// `splines::ns(x, df, intercept = TRUE)` - confirm against the reference.
    ///
    /// # Panics
    ///
    /// Panics if `df < 2`, if `x` is empty, or if a NaN is encountered while
    /// sorting the sample (only sorted when `df > 2`) or the knots.
    pub(crate) fn fit(x: &[f64], df: usize) -> NsBasis {
        assert!(df >= 2, "NsBasis requires df >= 2");
        // Without data the boundary knots would be +/- infinity and every
        // basis value NaN; fail loudly instead of returning a poisoned basis.
        assert!(!x.is_empty(), "NsBasis requires at least one data point");

        let xmin = x.iter().copied().fold(f64::INFINITY, f64::min);
        let xmax = x.iter().copied().fold(f64::NEG_INFINITY, f64::max);

        let n_iknots = df - 2;
        let iknots: Vec<f64> = if n_iknots > 0 {
            // Probabilities i / (n_iknots + 1) for i = 1..=n_iknots.
            let probs: Vec<f64> = (1..=n_iknots)
                .map(|i| i as f64 / (n_iknots as f64 + 1.0))
                .collect();
            let mut xs = x.to_vec();
            xs.sort_by(|a, b| {
                a.partial_cmp(b)
                    .expect("NsBasis::fit: x must not contain NaN")
            });
            quantile_type7(&xs, &probs)
        } else {
            Vec::new()
        };

        let ord = Self::ORDER;
        // Augmented knots: each boundary repeated `ord` times around the interior knots.
        let mut aknots: Vec<f64> = Vec::with_capacity(2 * ord + n_iknots);
        aknots.extend(std::iter::repeat_n(xmin, ord));
        aknots.extend_from_slice(&iknots);
        aknots.extend(std::iter::repeat_n(xmax, ord));
        aknots.sort_by(|a, b| {
            a.partial_cmp(b)
                .expect("NsBasis::fit: knots must not be NaN")
        });

        // Second derivatives of every B-spline at the two boundary knots.
        let constraint = spline_design(&aknots, &[xmin, xmax], ord, 2);
        let ct = constraint.t().to_owned();
        let q = qr_full_q(&ct);
        let nb = ct.nrows();
        // Drop the first two columns of Q (one per boundary constraint).
        let nbasis = q.slice(s![.., 2..nb]).to_owned();

        NsBasis {
            aknots,
            xmin,
            xmax,
            nbasis,
        }
    }

    /// Evaluates the fitted basis at `newx`.
    ///
    /// Points inside `[xmin, xmax]` use the cubic B-spline values directly;
    /// points outside are extrapolated linearly from the nearer boundary.
    /// The raw B-spline rows are then multiplied by the stored projection, so
    /// the result has `newx.len()` rows and as many columns as `nbasis`.
    pub(crate) fn eval(&self, newx: &[f64]) -> Array2<f64> {
        let ord = Self::ORDER;
        let npre = self.aknots.len() - ord;
        let mut pre = Array2::<f64>::zeros((newx.len(), npre));
        for (r, &xv) in newx.iter().enumerate() {
            let row = if xv < self.xmin {
                self.extrapolate_row(self.xmin, xv)
            } else if xv > self.xmax {
                self.extrapolate_row(self.xmax, xv)
            } else {
                bspline_row(&self.aknots, xv, ord, 0)
            };
            for (c, val) in row.into_iter().enumerate() {
                pre[[r, c]] = val;
            }
        }
        pre.dot(&self.nbasis)
    }

    /// First-order Taylor extension of the B-spline row around `pivot`:
    /// `value(pivot) + (xv - pivot) * derivative(pivot)` for every basis function.
    fn extrapolate_row(&self, pivot: f64, xv: f64) -> Vec<f64> {
        let ord = Self::ORDER;
        let val = bspline_row(&self.aknots, pivot, ord, 0);
        let der = bspline_row(&self.aknots, pivot, ord, 1);
        let dx = xv - pivot;
        val.iter()
            .zip(der.iter())
            .map(|(&v, &d)| v + dx * d)
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linalg::{qr_econ, solve_upper};
    use ndarray::Array1;

    /// Fits a basis on `x` and evaluates it at the same points.
    fn ns_design(x: &[f64], df: usize) -> Array2<f64> {
        let basis = NsBasis::fit(x, df);
        basis.eval(x)
    }

    /// Least-squares projection of `y` onto the columns of `design` via an
    /// economy QR; returns the fitted values and the residual sum of squares.
    fn project(design: &Array2<f64>, y: &[f64]) -> (Vec<f64>, f64) {
        let target = Array1::from(y.to_vec());
        let (q, r) = qr_econ(design);
        let rotated = q.t().dot(&target);
        let coef = solve_upper(&r, &rotated);
        let fitted = design.dot(&coef);
        let rss = target
            .iter()
            .zip(fitted.iter())
            .map(|(obs, fit)| {
                let resid = obs - fit;
                resid * resid
            })
            .sum::<f64>();
        (fitted.to_vec(), rss)
    }

    #[test]
    fn ns_basis_is_finite_and_spans_constants() {
        let x: Vec<f64> = (0..50).map(|i| (i as f64) * 0.2 - 3.0).collect();
        let d = ns_design(&x, 4);
        assert_eq!(d.dim(), (50, 4));
        assert!(!d.iter().any(|v| !v.is_finite()));

        // A constant vector must be reproduced exactly by the basis.
        let ones = vec![1.0_f64; 50];
        let (_, rss) = project(&d, &ones);
        assert!(rss < 1e-18, "constant not in span, rss={rss}");
    }

    #[test]
    fn ns_reproduces_linear_functions() {
        let x: Vec<f64> = (0..40).map(|i| (i as f64).powf(1.3) * 0.1).collect();
        let d = ns_design(&x, 5);
        assert_eq!(d.ncols(), 5);

        let y: Vec<f64> = x.iter().map(|&v| 2.5 * v - 0.7).collect();
        let (fit, rss) = project(&d, &y);
        assert!(rss < 1e-16, "line not reproduced, rss={rss}");
        for (a, b) in fit.iter().zip(y.iter()) {
            assert!((a - b).abs() < 1e-8);
        }
    }

    #[test]
    fn spline_design_matches_r_reference() {
        let knots = [0.0, 0.0, 0.0, 0.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0];
        let b = spline_design(&knots, &[0.0, 1.5, 3.0], 4, 0);

        // Left boundary: only the first basis function is active.
        assert!((b[[0, 0]] - 1.0).abs() < 1e-12);
        assert!(b.row(0).iter().skip(1).all(|&v| v.abs() < 1e-12));

        // Right boundary: only the last basis function is active.
        assert!((b[[2, 5]] - 1.0).abs() < 1e-12);
        assert!(b.row(2).iter().take(5).all(|&v| v.abs() < 1e-12));

        // Every row is a partition of unity.
        for r in 0..3 {
            let s: f64 = b.row(r).sum();
            assert!((s - 1.0).abs() < 1e-12, "row {r} sums to {s}");
        }

        // Interior point x = 1.5.
        let expected_mid = [(1, 0.031_25), (2, 0.468_75), (3, 0.468_75), (4, 0.031_25)];
        for (col, want) in expected_mid {
            assert!((b[[1, col]] - want).abs() < 1e-12, "{}", b[[1, col]]);
        }
    }
}