use super::*;
/// Cubic regression spline basis with one basis function per knot.
///
/// Constructed via `CubicRegressionBasis::new`, which checks that there are at
/// least 3 strictly increasing knots and precomputes `f_matrix` from the knot
/// spacings.
#[derive(Clone, Debug)]
pub struct CubicRegressionBasis {
    /// Knot locations (strictly increasing, validated by `new`).
    /// NOTE(review): this field is public, but `f_matrix` is derived from it at
    /// construction time; mutating it afterwards leaves `f_matrix` stale, so
    /// treat it as read-only.
    pub knots: Array1<f64>,
    /// `k x k` matrix built by `build_f_matrix`. Rows 0 and `k - 1` are zero;
    /// the interior rows hold the tridiagonal solve `B^{-1} D`. It is read by
    /// `penalty` and supplies the cubic correction terms in `eval_row_into`
    /// (conventionally, the map from knot values to knot second derivatives).
    f_matrix: Array2<f64>,
}
impl CubicRegressionBasis {
    /// Builds a cubic regression spline basis on `knots`.
    ///
    /// The basis has one coefficient per knot. The cubic correction terms used
    /// by [`Self::eval_row_into`] come from a `k x k` matrix that is precomputed
    /// here (via `build_f_matrix`) from the knot spacings.
    ///
    /// # Errors
    ///
    /// Returns an invalid-basis error when:
    /// * fewer than 3 knots are given;
    /// * the knots are not strictly increasing, which also rejects NaN;
    /// * a knot is infinite.
    ///
    /// Errors from the tridiagonal solve are propagated.
    pub fn new(knots: Array1<f64>) -> Result<Self, BasisError> {
        let k = knots.len();
        if k < 3 {
            crate::bail_invalid_basis!(
                "cubic regression spline requires at least 3 knots, got {k}"
            );
        }
        for i in 1..k {
            // Written as `!(a > b)` so that NaN knots are rejected too.
            if !(knots[i] > knots[i - 1]) {
                crate::bail_invalid_basis!(
                    "cubic regression spline knots must be strictly increasing; \
                     knot[{}]={} is not greater than knot[{}]={}",
                    i,
                    knots[i],
                    i - 1,
                    knots[i - 1]
                );
            }
        }
        // +/-inf at either end survives the ordering test above. The resulting
        // infinite spacing makes the tridiagonal system degenerate and yields NaN
        // basis rows, so reject it up front.
        if let Some(i) = knots.iter().position(|v| v.is_infinite()) {
            crate::bail_invalid_basis!(
                "cubic regression spline knots must be finite; knot[{}]={}",
                i,
                knots[i]
            );
        }
        let h: Vec<f64> = knots
            .iter()
            .zip(knots.iter().skip(1))
            .map(|(lo, hi)| hi - lo)
            .collect();
        let f_matrix = build_f_matrix(&h, k)?;
        Ok(Self { knots, f_matrix })
    }

    /// Number of basis functions, which equals the number of knots.
    pub fn num_basis(&self) -> usize {
        self.knots.len()
    }

    /// Returns the symmetric `k x k` penalty matrix.
    ///
    /// The matrix is `D^T F_int`, symmetrised as `(S + S^T) / 2`. Here `D` is the
    /// `(k - 2) x k` second-difference matrix on the knot spacings and `F_int`
    /// is the interior rows of the cached `f_matrix`. Constants and linear
    /// functions of the knots lie in its null space (see the tests below).
    pub fn penalty(&self) -> Array2<f64> {
        let k = self.knots.len();
        let h: Vec<f64> = (0..k - 1)
            .map(|i| self.knots[i + 1] - self.knots[i])
            .collect();
        let mut d = Array2::<f64>::zeros((k - 2, k));
        for (i, w) in h.windows(2).enumerate() {
            d[[i, i]] = 1.0 / w[0];
            d[[i, i + 1]] = -1.0 / w[0] - 1.0 / w[1];
            d[[i, i + 2]] = 1.0 / w[1];
        }
        // A view is enough for the product; no need to copy the interior rows.
        let f_int = self.f_matrix.slice(s![1..k - 1, ..]);
        let s = d.t().dot(&f_int);
        // Symmetrise to remove round-off asymmetry of the product.
        Array2::<f64>::from_shape_fn((k, k), |(a, b)| 0.5 * (s[[a, b]] + s[[b, a]]))
    }

    /// Writes the basis functions evaluated at `x` into `row`, overwriting it.
    ///
    /// Between the first and last knot the row is the piecewise-cubic
    /// evaluation on the knot interval containing `x`. Outside that range the
    /// spline is continued linearly, using its value and slope at the nearest
    /// boundary knot. A NaN `x` yields NaN entries, and an infinite `x` yields
    /// non-finite entries.
    ///
    /// Works for knot arrays in any memory layout.
    ///
    /// # Panics
    ///
    /// Panics if `row.len()` differs from the number of knots.
    pub fn eval_row_into(&self, x: f64, row: &mut [f64]) {
        let knots = &self.knots;
        let fm = &self.f_matrix;
        let k = knots.len();
        assert_eq!(row.len(), k);
        row.fill(0.0);
        let x1 = knots[0];
        let xk = knots[k - 1];
        if x <= x1 {
            // Linear continuation below the first knot: value at x1 plus
            // (x - x1) times the slope of the first cubic piece at x1.
            let h0 = knots[1] - knots[0];
            row[0] += 1.0;
            let dx = x - x1;
            row[0] += dx * (-1.0 / h0);
            row[1] += dx * (1.0 / h0);
            let coeff = dx * (-h0 / 6.0);
            let boundary = fm.row(1);
            for (r, &fc) in row.iter_mut().zip(boundary.iter()) {
                *r += coeff * fc;
            }
            return;
        }
        if x >= xk {
            // Linear continuation above the last knot, using the slope of the
            // last cubic piece at xk.
            let hk = knots[k - 1] - knots[k - 2];
            row[k - 1] += 1.0;
            let dx = x - xk;
            row[k - 2] += dx * (-1.0 / hk);
            row[k - 1] += dx * (1.0 / hk);
            let coeff = dx * (hk / 6.0);
            let boundary = fm.row(k - 2);
            for (r, &fc) in row.iter_mut().zip(boundary.iter()) {
                *r += coeff * fc;
            }
            return;
        }
        // Bisection for the interval j with knots[j] <= x < knots[j + 1]. The
        // invariant knots[lo] <= x < knots[hi] holds initially because
        // x1 < x < xk here. Plain indexing avoids requiring a contiguous slice,
        // which `as_slice().binary_search` would need (it panicked otherwise).
        let (mut lo, mut hi) = (0usize, k - 1);
        while hi - lo > 1 {
            let mid = lo + (hi - lo) / 2;
            if knots[mid] <= x {
                lo = mid;
            } else {
                hi = mid;
            }
        }
        let j = lo;
        let hj = knots[j + 1] - knots[j];
        let a_minus = (knots[j + 1] - x) / hj;
        let a_plus = (x - knots[j]) / hj;
        let c_minus = (a_minus * a_minus * a_minus - a_minus) * hj * hj / 6.0;
        let c_plus = (a_plus * a_plus * a_plus - a_plus) * hj * hj / 6.0;
        row[j] += a_minus;
        row[j + 1] += a_plus;
        for (c, r) in row.iter_mut().enumerate() {
            *r += c_minus * fm[[j, c]] + c_plus * fm[[j + 1, c]];
        }
    }

    /// Builds the `n x k` design matrix whose row `i` is the basis evaluated
    /// at `data[i]` (see [`Self::eval_row_into`]).
    pub fn design(&self, data: ArrayView1<'_, f64>) -> Array2<f64> {
        let k = self.knots.len();
        let mut x = Array2::<f64>::zeros((data.len(), k));
        // One scratch row reused for every observation.
        let mut buf = vec![0.0f64; k];
        for (i, &xi) in data.iter().enumerate() {
            self.eval_row_into(xi, &mut buf);
            for (dst, &src) in x.row_mut(i).iter_mut().zip(buf.iter()) {
                *dst = src;
            }
        }
        x
    }
}
fn build_f_matrix(h: &[f64], k: usize) -> Result<Array2<f64>, BasisError> {
let m = k - 2; let mut b_diag = vec![0.0f64; m];
let mut b_off = vec![0.0f64; m.saturating_sub(1)]; for i in 0..m {
b_diag[i] = (h[i] + h[i + 1]) / 3.0;
}
for i in 0..m.saturating_sub(1) {
b_off[i] = h[i + 1] / 6.0;
}
let mut d = Array2::<f64>::zeros((m, k));
for i in 0..m {
d[[i, i]] = 1.0 / h[i];
d[[i, i + 1]] = -1.0 / h[i] - 1.0 / h[i + 1];
d[[i, i + 2]] = 1.0 / h[i + 1];
}
let f_int = thomas_solve_multi(&b_diag, &b_off, &d)?;
let mut f = Array2::<f64>::zeros((k, k));
for i in 0..m {
for c in 0..k {
f[[i + 1, c]] = f_int[[i, c]];
}
}
Ok(f)
}
/// Solves the tridiagonal system `T X = rhs` for every column of `rhs`.
///
/// `T` is symmetric, with main diagonal `diag` (length `m`) and sub-/super-diagonal
/// `off` (length `m - 1`). The solve uses the Thomas algorithm: forward
/// elimination followed by back substitution, with no pivoting.
/// NOTE(review): the lack of pivoting assumes a well-conditioned, diagonally
/// dominant `T`, as built by `build_f_matrix`.
///
/// # Errors
///
/// * Dimension error if `rhs` does not have `m` rows, or `off` is not of length `m - 1`.
/// * Invalid-basis error if an elimination pivot is numerically zero or not
///   finite (a NaN pivot would otherwise silently poison every later row).
fn thomas_solve_multi(
    diag: &[f64],
    off: &[f64],
    rhs: &Array2<f64>,
) -> Result<Array2<f64>, BasisError> {
    // Smallest pivot magnitude accepted as non-singular.
    const MIN_PIVOT: f64 = 1e-300;

    // Rejects non-finite pivots and pivots that are numerically zero.
    fn check_pivot(pivot: f64, row: usize) -> Result<(), BasisError> {
        if !pivot.is_finite() {
            crate::bail_invalid_basis!(
                "non-finite tridiagonal pivot at row {row} in cr penalty solve"
            );
        }
        if pivot.abs() < MIN_PIVOT {
            crate::bail_invalid_basis!("singular tridiagonal pivot at row {row} in cr penalty solve");
        }
        Ok(())
    }

    let m = diag.len();
    let cols = rhs.ncols();
    if m == 0 {
        return Ok(Array2::<f64>::zeros((0, cols)));
    }
    if rhs.nrows() != m {
        crate::bail_dim_basis!(
            "tridiagonal solve RHS has {} rows but system is {}x{}",
            rhs.nrows(),
            m,
            m
        );
    }
    if off.len() != m - 1 {
        crate::bail_dim_basis!(
            "tridiagonal solve off-diagonal has length {} but system is {}x{} (expected {})",
            off.len(),
            m,
            m,
            m - 1
        );
    }

    // Forward elimination: c_prime holds the normalised super-diagonal,
    // d_prime the normalised right-hand sides.
    let mut c_prime = vec![0.0f64; m];
    let mut d_prime = Array2::<f64>::zeros((m, cols));
    let denom0 = diag[0];
    check_pivot(denom0, 0)?;
    if m > 1 {
        c_prime[0] = off[0] / denom0;
    }
    for col in 0..cols {
        d_prime[[0, col]] = rhs[[0, col]] / denom0;
    }
    for i in 1..m {
        let denom = diag[i] - off[i - 1] * c_prime[i - 1];
        check_pivot(denom, i)?;
        if i < m - 1 {
            c_prime[i] = off[i] / denom;
        }
        for col in 0..cols {
            d_prime[[i, col]] = (rhs[[i, col]] - off[i - 1] * d_prime[[i - 1, col]]) / denom;
        }
    }

    // Back substitution, in place: the last row is already the solution.
    for i in (0..m - 1).rev() {
        for col in 0..cols {
            let next = d_prime[[i + 1, col]];
            d_prime[[i, col]] -= c_prime[i] * next;
        }
    }
    Ok(d_prime)
}
/// Places `k` knots at evenly spaced positions through the sorted distinct
/// values of `data`.
///
/// With `nu` distinct values, knot `j` sits at fractional index
/// `j * (nu - 1) / (k - 1)` in the sorted distinct values. Non-integer
/// positions are linearly interpolated between the two neighbours. The first
/// and last knots are therefore the minimum and maximum of `data`.
///
/// # Errors
///
/// Returns an invalid-basis error if any of these hold:
/// * `k < 3`;
/// * `data` is empty;
/// * `data` contains a non-finite value;
/// * there are fewer than `k` distinct values;
/// * the interpolated knots are not strictly increasing.
pub fn select_cr_knots(data: ArrayView1<'_, f64>, k: usize) -> Result<Array1<f64>, BasisError> {
    if k < 3 {
        crate::bail_invalid_basis!("cubic regression spline requires k >= 3, got {k}");
    }
    if data.is_empty() {
        crate::bail_invalid_basis!("cannot place cr knots on empty data");
    }
    if !data.iter().all(|v| v.is_finite()) {
        crate::bail_invalid_basis!("cr knot placement requires finite data");
    }

    let mut distinct: Vec<f64> = data.iter().copied().collect();
    distinct.sort_by(f64::total_cmp);
    // Sorted, so equal values are adjacent; keep the first of each run.
    distinct.dedup();
    let nu = distinct.len();
    if nu < k {
        crate::bail_invalid_basis!(
            "cubic regression spline with k={k} requires at least {k} distinct \
             values, got {nu}"
        );
    }

    let span = (nu - 1) as f64;
    let steps = (k - 1) as f64;
    let knots = Array1::<f64>::from_shape_fn(k, |j| {
        let pos = (j as f64) * span / steps;
        let lo = pos.floor() as usize;
        let hi = pos.ceil() as usize;
        if lo == hi {
            distinct[lo]
        } else {
            let frac = pos - lo as f64;
            distinct[lo] * (1.0 - frac) + distinct[hi] * frac
        }
    });

    let strictly_increasing = knots
        .iter()
        .zip(knots.iter().skip(1))
        .all(|(prev, next)| next > prev);
    if !strictly_increasing {
        crate::bail_invalid_basis!(
            "cr knot placement produced non-increasing knots (too many knots \
             for the data spread); reduce k"
        );
    }
    Ok(knots)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Quadratic form `v' S v`.
    fn quad_form(s: &Array2<f64>, v: &Array1<f64>) -> f64 {
        v.dot(&s.dot(v))
    }

    // The penalty must vanish on constants and on the identity (linear)
    // function of the knots, and be strictly positive on a curved function.
    #[test]
    fn cr_penalty_nullspace_is_const_and_linear() {
        let knots = Array1::from(vec![0.0, 0.3, 0.55, 0.8, 1.0]);
        let s = CubicRegressionBasis::new(knots.clone()).unwrap().penalty();

        let q_const = quad_form(&s, &Array1::<f64>::ones(knots.len()));
        assert!(q_const.abs() < 1e-9, "const not in null space: {q_const}");

        let q_lin = quad_form(&s, &knots);
        assert!(q_lin.abs() < 1e-9, "linear not in null space: {q_lin}");

        let q_quad = quad_form(&s, &knots.mapv(|x| x * x));
        assert!(q_quad > 1e-6, "quadratic penalty not positive: {q_quad}");
    }

    // Coefficients sampled from a straight line must reproduce that line
    // exactly, both inside the knot range and under linear extrapolation.
    #[test]
    fn cr_design_reproduces_line_including_extrapolation() {
        let knots = Array1::from(vec![0.0, 0.25, 0.5, 0.75, 1.0]);
        let basis = CubicRegressionBasis::new(knots.clone()).unwrap();
        let line = |x: f64| 2.0 + 3.0 * x;
        let beta = knots.mapv(line);
        let xs = Array1::from(vec![-0.4, 0.0, 0.13, 0.5, 0.87, 1.0, 1.3]);
        let fitted = basis.design(xs.view()).dot(&beta);
        for (&x, &got) in xs.iter().zip(fitted.iter()) {
            let truth = line(x);
            assert!(
                (got - truth).abs() < 1e-9,
                "line not reproduced at x={x}: got {got}, want {truth}"
            );
        }
    }

    // Knots chosen from a uniform grid on [0, 1] must hit both endpoints and
    // be strictly increasing.
    #[test]
    fn cr_knots_span_data_and_increase() {
        let grid: Vec<f64> = (0..50).map(|i| i as f64 / 49.0).collect();
        let data = Array1::from(grid);
        let knots = select_cr_knots(data.view(), 5).unwrap();
        assert_eq!(knots.len(), 5);
        assert!(knots[0].abs() < 1e-12);
        assert!((knots[4] - 1.0).abs() < 1e-12);
        assert!(knots.iter().zip(knots.iter().skip(1)).all(|(a, b)| b > a));
    }
}