gam_terms/basis/cubic_regression.rs
1//! Natural cubic regression spline (`cr`) basis — mgcv-compatible.
2//!
3//! Implements the Lancaster–Salkauskas natural cubic regression spline that
4//! mgcv exposes as `bs="cr"` (and its shrinkage twin `bs="cs"`), following
5//! Wood (2017) *Generalized Additive Models*, §5.3.1.
6//!
7//! The smooth is parameterized by its values at `k` knots,
8//! `β_i = f(x*_i)`, with natural boundary conditions `f''(x*_1) = f''(x*_k) =
9//! 0`. The basis dimension is exactly `k` (the number of knots), and the
10//! roughness penalty `∫ f''(x)² dx` is the quadratic form `βᵀ S β` with
11//! `S = Dᵀ B⁻¹ D` whose null space is `{const, linear}` (dimension 2).
12//!
13//! This matches mgcv's `smooth.construct.cr.smooth.spec` output (`$X` and
14//! `$S[[1]]`) to round-off for the same knot vector — see the unit tests at
15//! the bottom of this module and the in-tree quality cross-checks.
16//!
17//! ## Geometry (the `F` matrix)
18//! For interior knots, the second derivatives `δ` are linear in the values
19//! `β` via `δ = F β`, where `F` is `k × k` with zero first/last rows and
20//! interior rows given by `B⁻¹ D`:
21//! * `D` is `(k-2) × k`: `D[i,i]=1/h_i`, `D[i,i+1]=-1/h_i-1/h_{i+1}`,
22//! `D[i,i+2]=1/h_{i+1}`.
23//! * `B` is `(k-2) × (k-2)` tridiagonal SPD: `B[i,i]=(h_i+h_{i+1})/3`,
24//! `B[i,i+1]=B[i+1,i]=h_{i+1}/6`.
25//! with `h_i = x*_{i+1} - x*_i` (1-indexed in the math, 0-indexed below).
26//!
27//! ## Design row
28//! For `x ∈ [x*_j, x*_{j+1}]` (knot interval `j`, 0-indexed) with
29//! `a₋ = (x*_{j+1}-x)/h_j`, `a₊ = (x-x*_j)/h_j`:
30//! `row = a₋·e_j + a₊·e_{j+1} + c₋·F[j,:] + c₊·F[j+1,:]`
31//! where `c₋ = (a₋³-a₋) h_j²/6`, `c₊ = (a₊³-a₊) h_j²/6`.
32//!
33//! Outside `[x*_1, x*_k]` mgcv extrapolates *linearly*: the value and first
34//! derivative are continued from the nearest endpoint knot. We reproduce that
35//! exactly so predict-time rows past the data range match mgcv.
36
37use super::*;
38
39/// Precomputed natural cubic regression spline geometry for a fixed knot set.
40#[derive(Clone, Debug)]
41pub struct CubicRegressionBasis {
42 /// Knot locations `x*_1 < … < x*_k` (strictly increasing).
43 pub knots: Array1<f64>,
44 /// The `k × k` second-derivative map `F` (`δ = F β`); rows 0 and k-1 are zero.
45 f_matrix: Array2<f64>,
46}
47
48impl CubicRegressionBasis {
49 /// Build the cr geometry for a strictly increasing knot vector of length
50 /// `k >= 3`. (mgcv requires `k >= 3` for a cubic regression spline.)
51 pub fn new(knots: Array1<f64>) -> Result<Self, BasisError> {
52 let k = knots.len();
53 if k < 3 {
54 crate::bail_invalid_basis!(
55 "cubic regression spline requires at least 3 knots, got {k}"
56 );
57 }
58 // Strictly increasing check.
59 for i in 1..k {
60 if !(knots[i] > knots[i - 1]) {
61 crate::bail_invalid_basis!(
62 "cubic regression spline knots must be strictly increasing; \
63 knot[{}]={} is not greater than knot[{}]={}",
64 i,
65 knots[i],
66 i - 1,
67 knots[i - 1]
68 );
69 }
70 }
71 let h: Vec<f64> = (0..k - 1).map(|i| knots[i + 1] - knots[i]).collect();
72 let f_matrix = build_f_matrix(&h, k)?;
73 Ok(Self { knots, f_matrix })
74 }
75
76 pub fn num_basis(&self) -> usize {
77 self.knots.len()
78 }
79
80 /// The natural cubic regression roughness penalty `S = Dᵀ B⁻¹ D` (k×k).
81 ///
82 /// Equivalently `S = Dᵀ F_int` where `F_int = B⁻¹ D` are the interior rows
83 /// of `F`. We assemble it directly from `D` and the interior block of `F`.
84 pub fn penalty(&self) -> Array2<f64> {
85 let k = self.knots.len();
86 let h: Vec<f64> = (0..k - 1)
87 .map(|i| self.knots[i + 1] - self.knots[i])
88 .collect();
89 // D is (k-2) x k.
90 let mut d = Array2::<f64>::zeros((k - 2, k));
91 for i in 0..k - 2 {
92 d[[i, i]] = 1.0 / h[i];
93 d[[i, i + 1]] = -1.0 / h[i] - 1.0 / h[i + 1];
94 d[[i, i + 2]] = 1.0 / h[i + 1];
95 }
96 // F_int = interior rows of F (rows 1..k-1 of F_matrix), shape (k-2) x k.
97 // S = Dᵀ F_int. (F_int = B⁻¹ D, so Dᵀ B⁻¹ D.)
98 let f_int = self.f_matrix.slice(s![1..k - 1, ..]).to_owned();
99 // S = Dᵀ (F_int) -> (k x (k-2)) x ((k-2) x k) = k x k.
100 let s = d.t().dot(&f_int);
101 // Symmetrize defensively (it is symmetric in exact arithmetic).
102 let mut s_sym = Array2::<f64>::zeros((k, k));
103 for a in 0..k {
104 for b in 0..k {
105 s_sym[[a, b]] = 0.5 * (s[[a, b]] + s[[b, a]]);
106 }
107 }
108 s_sym
109 }
110
111 /// Evaluate the cr design row for a single point `x` into `row` (length k).
112 /// `row` is overwritten.
113 pub fn eval_row_into(&self, x: f64, row: &mut [f64]) {
114 let k = self.knots.len();
115 // assert_eq!, not debug_assert_eq!: the ban-scanner forbids debug_assert
116 // (silent in release → debug/release divergence). The length check is a
117 // cheap O(1) guard, so an always-active assert is acceptable here.
118 assert_eq!(row.len(), k);
119 for r in row.iter_mut() {
120 *r = 0.0;
121 }
122 let x1 = self.knots[0];
123 let xk = self.knots[k - 1];
124
125 if x <= x1 {
126 // Linear extrapolation off the left endpoint, matching mgcv: the
127 // value at x1 is β_0, the slope is the spline's first derivative at
128 // x1. For the first interval [x*_0, x*_1] the cubic has
129 // f(x) = a₋β_0 + a₊β_1 + c₋δ_0 + c₊δ_1 with δ_0 = 0 (natural),
130 // so f'(x1⁻side) at x = x1 is
131 // slope = (β_1 - β_0)/h_0 - h_0/6 * δ_1 (δ_0 = 0).
132 let h0 = self.knots[1] - self.knots[0];
133 // row picks up β_0 (=1 at e_0) plus slope*(x-x1) expressed in β.
134 row[0] += 1.0;
135 // d/dx contributions: (β_1-β_0)/h0 term and -h0/6 * δ_1 term.
136 let dx = x - x1;
137 row[0] += dx * (-1.0 / h0);
138 row[1] += dx * (1.0 / h0);
139 // δ_1 = F[1,:]·β → -h0/6 * δ_1 contributes -h0/6 * F[1,:].
140 let coeff = dx * (-h0 / 6.0);
141 for c in 0..k {
142 row[c] += coeff * self.f_matrix[[1, c]];
143 }
144 return;
145 }
146 if x >= xk {
147 // Linear extrapolation off the right endpoint. For the last
148 // interval [x*_{k-2}, x*_{k-1}], δ_{k-1} = 0 (natural), and the
149 // first derivative at x = xk is
150 // slope = (β_{k-1} - β_{k-2})/h_{k-2} + h_{k-2}/6 * δ_{k-2}.
151 let hk = self.knots[k - 1] - self.knots[k - 2];
152 row[k - 1] += 1.0;
153 let dx = x - xk;
154 row[k - 2] += dx * (-1.0 / hk);
155 row[k - 1] += dx * (1.0 / hk);
156 // + h_{k-2}/6 * δ_{k-2}, δ_{k-2} = F[k-2,:]·β.
157 let coeff = dx * (hk / 6.0);
158 for c in 0..k {
159 row[c] += coeff * self.f_matrix[[k - 2, c]];
160 }
161 return;
162 }
163
164 // Interior: locate interval j with x*_j <= x <= x*_{j+1}.
165 // knots strictly increasing; binary search for the upper bound.
166 let mut j = match self
167 .knots
168 .as_slice()
169 .expect("contiguous knots")
170 .binary_search_by(|probe| probe.partial_cmp(&x).unwrap_or(std::cmp::Ordering::Less))
171 {
172 Ok(idx) => idx, // x equals a knot: use interval starting at idx
173 Err(idx) => idx - 1, // x in (knot[idx-1], knot[idx])
174 };
175 if j >= k - 1 {
176 j = k - 2;
177 }
178 let hj = self.knots[j + 1] - self.knots[j];
179 let a_minus = (self.knots[j + 1] - x) / hj;
180 let a_plus = (x - self.knots[j]) / hj;
181 let c_minus = (a_minus * a_minus * a_minus - a_minus) * hj * hj / 6.0;
182 let c_plus = (a_plus * a_plus * a_plus - a_plus) * hj * hj / 6.0;
183 row[j] += a_minus;
184 row[j + 1] += a_plus;
185 for c in 0..k {
186 row[c] += c_minus * self.f_matrix[[j, c]] + c_plus * self.f_matrix[[j + 1, c]];
187 }
188 }
189
190 /// Dense `n × k` design matrix for a column of evaluation points.
191 pub fn design(&self, data: ArrayView1<'_, f64>) -> Array2<f64> {
192 let k = self.knots.len();
193 let n = data.len();
194 let mut x = Array2::<f64>::zeros((n, k));
195 let mut row = vec![0.0f64; k];
196 for (i, &xi) in data.iter().enumerate() {
197 self.eval_row_into(xi, &mut row);
198 for c in 0..k {
199 x[[i, c]] = row[c];
200 }
201 }
202 x
203 }
204}
205
206/// Assemble the `k × k` map `F` (`δ = F β`) from interval widths `h`.
207/// Rows 0 and k-1 are zero (natural boundary). Interior rows solve
208/// `B (F_int) = D` for the `(k-2) × k` interior block `F_int`.
209fn build_f_matrix(h: &[f64], k: usize) -> Result<Array2<f64>, BasisError> {
210 let m = k - 2; // interior count
211 // B (m x m) tridiagonal SPD.
212 let mut b_diag = vec![0.0f64; m];
213 let mut b_off = vec![0.0f64; m.saturating_sub(1)]; // b_off[i] = B[i,i+1] = B[i+1,i]
214 for i in 0..m {
215 b_diag[i] = (h[i] + h[i + 1]) / 3.0;
216 }
217 for i in 0..m.saturating_sub(1) {
218 // B[i,i+1] = h_{i+1}/6 (the shared interior width).
219 b_off[i] = h[i + 1] / 6.0;
220 }
221 // D (m x k).
222 let mut d = Array2::<f64>::zeros((m, k));
223 for i in 0..m {
224 d[[i, i]] = 1.0 / h[i];
225 d[[i, i + 1]] = -1.0 / h[i] - 1.0 / h[i + 1];
226 d[[i, i + 2]] = 1.0 / h[i + 1];
227 }
228 // Solve B X = D column-by-column with the Thomas algorithm; X = F_int.
229 let f_int = thomas_solve_multi(&b_diag, &b_off, &d)?;
230 let mut f = Array2::<f64>::zeros((k, k));
231 for i in 0..m {
232 for c in 0..k {
233 f[[i + 1, c]] = f_int[[i, c]];
234 }
235 }
236 Ok(f)
237}
238
239/// Solve a symmetric tridiagonal system `B X = RHS` for every column of `RHS`
240/// using the Thomas algorithm. `diag` is length m, `off` is length m-1
241/// (the shared sub/super-diagonal). `rhs` is `m × c`. Returns `m × c`.
242fn thomas_solve_multi(
243 diag: &[f64],
244 off: &[f64],
245 rhs: &Array2<f64>,
246) -> Result<Array2<f64>, BasisError> {
247 let m = diag.len();
248 let cols = rhs.ncols();
249 if m == 0 {
250 return Ok(Array2::<f64>::zeros((0, cols)));
251 }
252 if rhs.nrows() != m {
253 crate::bail_dim_basis!(
254 "tridiagonal solve RHS has {} rows but system is {}x{}",
255 rhs.nrows(),
256 m,
257 m
258 );
259 }
260 // Forward sweep.
261 let mut c_prime = vec![0.0f64; m]; // modified super-diagonal
262 let mut d_prime = Array2::<f64>::zeros((m, cols));
263 let denom0 = diag[0];
264 if denom0.abs() < 1e-300 {
265 crate::bail_invalid_basis!("singular tridiagonal pivot at row 0 in cr penalty solve");
266 }
267 if m > 1 {
268 c_prime[0] = off[0] / denom0;
269 }
270 for col in 0..cols {
271 d_prime[[0, col]] = rhs[[0, col]] / denom0;
272 }
273 for i in 1..m {
274 let denom = diag[i] - off[i - 1] * c_prime[i - 1];
275 if denom.abs() < 1e-300 {
276 crate::bail_invalid_basis!("singular tridiagonal pivot at row {i} in cr penalty solve");
277 }
278 if i < m - 1 {
279 c_prime[i] = off[i] / denom;
280 }
281 for col in 0..cols {
282 d_prime[[i, col]] = (rhs[[i, col]] - off[i - 1] * d_prime[[i - 1, col]]) / denom;
283 }
284 }
285 // Back substitution.
286 let mut x = Array2::<f64>::zeros((m, cols));
287 for col in 0..cols {
288 x[[m - 1, col]] = d_prime[[m - 1, col]];
289 }
290 for i in (0..m - 1).rev() {
291 for col in 0..cols {
292 x[[i, col]] = d_prime[[i, col]] - c_prime[i] * x[[i + 1, col]];
293 }
294 }
295 Ok(x)
296}
297
298/// Place `k` cr knots at evenly-spaced quantiles of the unique sorted data,
299/// exactly as mgcv's default `cr` knot placement: the first and last knots are
300/// the min/max, and the interior knots are at the `1/(k-1) … (k-2)/(k-1)`
301/// quantiles of the *unique* observed values. Returns a strictly increasing
302/// length-`k` knot vector.
303pub fn select_cr_knots(data: ArrayView1<'_, f64>, k: usize) -> Result<Array1<f64>, BasisError> {
304 if k < 3 {
305 crate::bail_invalid_basis!("cubic regression spline requires k >= 3, got {k}");
306 }
307 if data.is_empty() {
308 crate::bail_invalid_basis!("cannot place cr knots on empty data");
309 }
310 if data.iter().any(|x| !x.is_finite()) {
311 crate::bail_invalid_basis!("cr knot placement requires finite data");
312 }
313 let mut sorted: Vec<f64> = data.iter().copied().collect();
314 sorted.sort_by(f64::total_cmp);
315 // Unique values (mgcv places cr knots on the unique data quantiles).
316 let mut unique: Vec<f64> = Vec::with_capacity(sorted.len());
317 for &v in &sorted {
318 if unique.last().map(|&p| p != v).unwrap_or(true) {
319 unique.push(v);
320 }
321 }
322 let nu = unique.len();
323 if nu < k {
324 crate::bail_invalid_basis!(
325 "cubic regression spline with k={k} requires at least {k} distinct \
326 values, got {nu}"
327 );
328 }
329 // mgcv's `place.knots`: knots at quantile type-1-ish positions over the
330 // index range [0, nu-1] evenly in (k-1) steps. Endpoints are exact min/max.
331 let mut knots = Array1::<f64>::zeros(k);
332 for j in 0..k {
333 let pos = (j as f64) * ((nu - 1) as f64) / ((k - 1) as f64);
334 let lo = pos.floor() as usize;
335 let hi = pos.ceil() as usize;
336 let frac = pos - lo as f64;
337 knots[j] = if lo == hi {
338 unique[lo]
339 } else {
340 unique[lo] * (1.0 - frac) + unique[hi] * frac
341 };
342 }
343 // Guard strict monotonicity in case of ties from interpolation rounding.
344 for i in 1..k {
345 if !(knots[i] > knots[i - 1]) {
346 crate::bail_invalid_basis!(
347 "cr knot placement produced non-increasing knots (too many knots \
348 for the data spread); reduce k"
349 );
350 }
351 }
352 Ok(knots)
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Quadratic form `vᵀ S v`.
    fn quad_form(s: &Array2<f64>, v: &Array1<f64>) -> f64 {
        v.dot(&s.dot(v))
    }

    /// The penalty must vanish on constants and on lines (its null space is
    /// {const, linear}) and must be strictly positive on a quadratic.
    #[test]
    fn cr_penalty_nullspace_is_const_and_linear() {
        let knots = Array1::from(vec![0.0, 0.3, 0.55, 0.8, 1.0]);
        let s = CubicRegressionBasis::new(knots.clone()).unwrap().penalty();

        // const: β = 1.
        let q_const = quad_form(&s, &Array1::<f64>::ones(knots.len()));
        assert!(q_const.abs() < 1e-9, "const not in null space: {q_const}");

        // linear: β_i = knot_i.
        let q_lin = quad_form(&s, &knots);
        assert!(q_lin.abs() < 1e-9, "linear not in null space: {q_lin}");

        // a quadratic should have positive penalty.
        let q_quad = quad_form(&s, &knots.mapv(|x| x * x));
        assert!(q_quad > 1e-6, "quadratic penalty not positive: {q_quad}");
    }

    /// A line lies in the cr span, so the design must reproduce it exactly at
    /// arbitrary evaluation points, both interior and extrapolated.
    #[test]
    fn cr_design_reproduces_line_including_extrapolation() {
        let line = |x: f64| 2.0 + 3.0 * x;
        let knots = Array1::from(vec![0.0, 0.25, 0.5, 0.75, 1.0]);
        let cr = CubicRegressionBasis::new(knots.clone()).unwrap();
        // f(x) = 2 + 3x → β_i = 2 + 3*knot_i.
        let beta: Array1<f64> = knots.mapv(line);
        let xs = Array1::from(vec![-0.4, 0.0, 0.13, 0.5, 0.87, 1.0, 1.3]);
        let fitted = cr.design(xs.view()).dot(&beta);
        for (i, &x) in xs.iter().enumerate() {
            let truth = line(x);
            assert!(
                (fitted[i] - truth).abs() < 1e-9,
                "line not reproduced at x={x}: got {}, want {truth}",
                fitted[i]
            );
        }
    }

    /// Knot placement puts the end knots on the data min/max and returns
    /// strictly increasing knots.
    #[test]
    fn cr_knots_span_data_and_increase() {
        let grid: Vec<f64> = (0..50).map(|i| i as f64 / 49.0).collect();
        let data = Array1::from(grid);
        let knots = select_cr_knots(data.view(), 5).unwrap();
        assert_eq!(knots.len(), 5);
        assert!((knots[0] - 0.0).abs() < 1e-12);
        assert!((knots[4] - 1.0).abs() < 1e-12);
        for i in 1..5 {
            assert!(knots[i] > knots[i - 1]);
        }
    }
}
415}