Skip to main content

gam_terms/basis/
cubic_regression.rs

1//! Natural cubic regression spline (`cr`) basis — mgcv-compatible.
2//!
3//! Implements the Lancaster–Salkauskas natural cubic regression spline that
4//! mgcv exposes as `bs="cr"` (and its shrinkage twin `bs="cs"`), following
5//! Wood (2017) *Generalized Additive Models*, §5.3.1.
6//!
7//! The smooth is parameterized by its values at `k` knots,
8//! `β_i = f(x*_i)`, with natural boundary conditions `f''(x*_1) = f''(x*_k) =
9//! 0`. The basis dimension is exactly `k` (the number of knots), and the
10//! roughness penalty `∫ f''(x)² dx` is the quadratic form `βᵀ S β` with
11//! `S = Dᵀ B⁻¹ D` whose null space is `{const, linear}` (dimension 2).
12//!
13//! This matches mgcv's `smooth.construct.cr.smooth.spec` output (`$X` and
14//! `$S[[1]]`) to round-off for the same knot vector — see the unit tests at
15//! the bottom of this module and the in-tree quality cross-checks.
16//!
17//! ## Geometry (the `F` matrix)
18//! For interior knots, the second derivatives `δ` are linear in the values
19//! `β` via `δ = F β`, where `F` is `k × k` with zero first/last rows and
20//! interior rows given by `B⁻¹ D`:
21//!   * `D` is `(k-2) × k`:   `D[i,i]=1/h_i`, `D[i,i+1]=-1/h_i-1/h_{i+1}`,
22//!                            `D[i,i+2]=1/h_{i+1}`.
23//!   * `B` is `(k-2) × (k-2)` tridiagonal SPD: `B[i,i]=(h_i+h_{i+1})/3`,
24//!                            `B[i,i+1]=B[i+1,i]=h_{i+1}/6`.
25//! with `h_i = x*_{i+1} - x*_i` (1-indexed in the math, 0-indexed below).
26//!
27//! ## Design row
28//! For `x ∈ [x*_j, x*_{j+1}]` (knot interval `j`, 0-indexed) with
29//! `a₋ = (x*_{j+1}-x)/h_j`, `a₊ = (x-x*_j)/h_j`:
30//!   `row = a₋·e_j + a₊·e_{j+1} + c₋·F[j,:] + c₊·F[j+1,:]`
31//! where `c₋ = (a₋³-a₋) h_j²/6`, `c₊ = (a₊³-a₊) h_j²/6`.
32//!
33//! Outside `[x*_1, x*_k]` mgcv extrapolates *linearly*: the value and first
34//! derivative are continued from the nearest endpoint knot. We reproduce that
35//! exactly so predict-time rows past the data range match mgcv.
36
37use super::*;
38
39/// Precomputed natural cubic regression spline geometry for a fixed knot set.
40#[derive(Clone, Debug)]
41pub struct CubicRegressionBasis {
42    /// Knot locations `x*_1 < … < x*_k` (strictly increasing).
43    pub knots: Array1<f64>,
44    /// The `k × k` second-derivative map `F` (`δ = F β`); rows 0 and k-1 are zero.
45    f_matrix: Array2<f64>,
46}
47
48impl CubicRegressionBasis {
49    /// Build the cr geometry for a strictly increasing knot vector of length
50    /// `k >= 3`. (mgcv requires `k >= 3` for a cubic regression spline.)
51    pub fn new(knots: Array1<f64>) -> Result<Self, BasisError> {
52        let k = knots.len();
53        if k < 3 {
54            crate::bail_invalid_basis!(
55                "cubic regression spline requires at least 3 knots, got {k}"
56            );
57        }
58        // Strictly increasing check.
59        for i in 1..k {
60            if !(knots[i] > knots[i - 1]) {
61                crate::bail_invalid_basis!(
62                    "cubic regression spline knots must be strictly increasing; \
63                     knot[{}]={} is not greater than knot[{}]={}",
64                    i,
65                    knots[i],
66                    i - 1,
67                    knots[i - 1]
68                );
69            }
70        }
71        let h: Vec<f64> = (0..k - 1).map(|i| knots[i + 1] - knots[i]).collect();
72        let f_matrix = build_f_matrix(&h, k)?;
73        Ok(Self { knots, f_matrix })
74    }
75
76    pub fn num_basis(&self) -> usize {
77        self.knots.len()
78    }
79
80    /// The natural cubic regression roughness penalty `S = Dᵀ B⁻¹ D` (k×k).
81    ///
82    /// Equivalently `S = Dᵀ F_int` where `F_int = B⁻¹ D` are the interior rows
83    /// of `F`. We assemble it directly from `D` and the interior block of `F`.
84    pub fn penalty(&self) -> Array2<f64> {
85        let k = self.knots.len();
86        let h: Vec<f64> = (0..k - 1)
87            .map(|i| self.knots[i + 1] - self.knots[i])
88            .collect();
89        // D is (k-2) x k.
90        let mut d = Array2::<f64>::zeros((k - 2, k));
91        for i in 0..k - 2 {
92            d[[i, i]] = 1.0 / h[i];
93            d[[i, i + 1]] = -1.0 / h[i] - 1.0 / h[i + 1];
94            d[[i, i + 2]] = 1.0 / h[i + 1];
95        }
96        // F_int = interior rows of F (rows 1..k-1 of F_matrix), shape (k-2) x k.
97        // S = Dᵀ F_int. (F_int = B⁻¹ D, so Dᵀ B⁻¹ D.)
98        let f_int = self.f_matrix.slice(s![1..k - 1, ..]).to_owned();
99        // S = Dᵀ (F_int)  -> (k x (k-2)) x ((k-2) x k) = k x k.
100        let s = d.t().dot(&f_int);
101        // Symmetrize defensively (it is symmetric in exact arithmetic).
102        let mut s_sym = Array2::<f64>::zeros((k, k));
103        for a in 0..k {
104            for b in 0..k {
105                s_sym[[a, b]] = 0.5 * (s[[a, b]] + s[[b, a]]);
106            }
107        }
108        s_sym
109    }
110
111    /// Evaluate the cr design row for a single point `x` into `row` (length k).
112    /// `row` is overwritten.
113    pub fn eval_row_into(&self, x: f64, row: &mut [f64]) {
114        let k = self.knots.len();
115        // assert_eq!, not debug_assert_eq!: the ban-scanner forbids debug_assert
116        // (silent in release → debug/release divergence). The length check is a
117        // cheap O(1) guard, so an always-active assert is acceptable here.
118        assert_eq!(row.len(), k);
119        for r in row.iter_mut() {
120            *r = 0.0;
121        }
122        let x1 = self.knots[0];
123        let xk = self.knots[k - 1];
124
125        if x <= x1 {
126            // Linear extrapolation off the left endpoint, matching mgcv: the
127            // value at x1 is β_0, the slope is the spline's first derivative at
128            // x1. For the first interval [x*_0, x*_1] the cubic has
129            //   f(x) = a₋β_0 + a₊β_1 + c₋δ_0 + c₊δ_1   with δ_0 = 0 (natural),
130            // so f'(x1⁻side) at x = x1 is
131            //   slope = (β_1 - β_0)/h_0 - h_0/6 * δ_1     (δ_0 = 0).
132            let h0 = self.knots[1] - self.knots[0];
133            // row picks up β_0 (=1 at e_0) plus slope*(x-x1) expressed in β.
134            row[0] += 1.0;
135            // d/dx contributions: (β_1-β_0)/h0 term and -h0/6 * δ_1 term.
136            let dx = x - x1;
137            row[0] += dx * (-1.0 / h0);
138            row[1] += dx * (1.0 / h0);
139            // δ_1 = F[1,:]·β  → -h0/6 * δ_1 contributes -h0/6 * F[1,:].
140            let coeff = dx * (-h0 / 6.0);
141            for c in 0..k {
142                row[c] += coeff * self.f_matrix[[1, c]];
143            }
144            return;
145        }
146        if x >= xk {
147            // Linear extrapolation off the right endpoint. For the last
148            // interval [x*_{k-2}, x*_{k-1}], δ_{k-1} = 0 (natural), and the
149            // first derivative at x = xk is
150            //   slope = (β_{k-1} - β_{k-2})/h_{k-2} + h_{k-2}/6 * δ_{k-2}.
151            let hk = self.knots[k - 1] - self.knots[k - 2];
152            row[k - 1] += 1.0;
153            let dx = x - xk;
154            row[k - 2] += dx * (-1.0 / hk);
155            row[k - 1] += dx * (1.0 / hk);
156            // + h_{k-2}/6 * δ_{k-2}, δ_{k-2} = F[k-2,:]·β.
157            let coeff = dx * (hk / 6.0);
158            for c in 0..k {
159                row[c] += coeff * self.f_matrix[[k - 2, c]];
160            }
161            return;
162        }
163
164        // Interior: locate interval j with x*_j <= x <= x*_{j+1}.
165        // knots strictly increasing; binary search for the upper bound.
166        let mut j = match self
167            .knots
168            .as_slice()
169            .expect("contiguous knots")
170            .binary_search_by(|probe| probe.partial_cmp(&x).unwrap_or(std::cmp::Ordering::Less))
171        {
172            Ok(idx) => idx,      // x equals a knot: use interval starting at idx
173            Err(idx) => idx - 1, // x in (knot[idx-1], knot[idx])
174        };
175        if j >= k - 1 {
176            j = k - 2;
177        }
178        let hj = self.knots[j + 1] - self.knots[j];
179        let a_minus = (self.knots[j + 1] - x) / hj;
180        let a_plus = (x - self.knots[j]) / hj;
181        let c_minus = (a_minus * a_minus * a_minus - a_minus) * hj * hj / 6.0;
182        let c_plus = (a_plus * a_plus * a_plus - a_plus) * hj * hj / 6.0;
183        row[j] += a_minus;
184        row[j + 1] += a_plus;
185        for c in 0..k {
186            row[c] += c_minus * self.f_matrix[[j, c]] + c_plus * self.f_matrix[[j + 1, c]];
187        }
188    }
189
190    /// Dense `n × k` design matrix for a column of evaluation points.
191    pub fn design(&self, data: ArrayView1<'_, f64>) -> Array2<f64> {
192        let k = self.knots.len();
193        let n = data.len();
194        let mut x = Array2::<f64>::zeros((n, k));
195        let mut row = vec![0.0f64; k];
196        for (i, &xi) in data.iter().enumerate() {
197            self.eval_row_into(xi, &mut row);
198            for c in 0..k {
199                x[[i, c]] = row[c];
200            }
201        }
202        x
203    }
204}
205
206/// Assemble the `k × k` map `F` (`δ = F β`) from interval widths `h`.
207/// Rows 0 and k-1 are zero (natural boundary). Interior rows solve
208/// `B (F_int) = D` for the `(k-2) × k` interior block `F_int`.
209fn build_f_matrix(h: &[f64], k: usize) -> Result<Array2<f64>, BasisError> {
210    let m = k - 2; // interior count
211    // B (m x m) tridiagonal SPD.
212    let mut b_diag = vec![0.0f64; m];
213    let mut b_off = vec![0.0f64; m.saturating_sub(1)]; // b_off[i] = B[i,i+1] = B[i+1,i]
214    for i in 0..m {
215        b_diag[i] = (h[i] + h[i + 1]) / 3.0;
216    }
217    for i in 0..m.saturating_sub(1) {
218        // B[i,i+1] = h_{i+1}/6 (the shared interior width).
219        b_off[i] = h[i + 1] / 6.0;
220    }
221    // D (m x k).
222    let mut d = Array2::<f64>::zeros((m, k));
223    for i in 0..m {
224        d[[i, i]] = 1.0 / h[i];
225        d[[i, i + 1]] = -1.0 / h[i] - 1.0 / h[i + 1];
226        d[[i, i + 2]] = 1.0 / h[i + 1];
227    }
228    // Solve B X = D column-by-column with the Thomas algorithm; X = F_int.
229    let f_int = thomas_solve_multi(&b_diag, &b_off, &d)?;
230    let mut f = Array2::<f64>::zeros((k, k));
231    for i in 0..m {
232        for c in 0..k {
233            f[[i + 1, c]] = f_int[[i, c]];
234        }
235    }
236    Ok(f)
237}
238
239/// Solve a symmetric tridiagonal system `B X = RHS` for every column of `RHS`
240/// using the Thomas algorithm. `diag` is length m, `off` is length m-1
241/// (the shared sub/super-diagonal). `rhs` is `m × c`. Returns `m × c`.
242fn thomas_solve_multi(
243    diag: &[f64],
244    off: &[f64],
245    rhs: &Array2<f64>,
246) -> Result<Array2<f64>, BasisError> {
247    let m = diag.len();
248    let cols = rhs.ncols();
249    if m == 0 {
250        return Ok(Array2::<f64>::zeros((0, cols)));
251    }
252    if rhs.nrows() != m {
253        crate::bail_dim_basis!(
254            "tridiagonal solve RHS has {} rows but system is {}x{}",
255            rhs.nrows(),
256            m,
257            m
258        );
259    }
260    // Forward sweep.
261    let mut c_prime = vec![0.0f64; m]; // modified super-diagonal
262    let mut d_prime = Array2::<f64>::zeros((m, cols));
263    let denom0 = diag[0];
264    if denom0.abs() < 1e-300 {
265        crate::bail_invalid_basis!("singular tridiagonal pivot at row 0 in cr penalty solve");
266    }
267    if m > 1 {
268        c_prime[0] = off[0] / denom0;
269    }
270    for col in 0..cols {
271        d_prime[[0, col]] = rhs[[0, col]] / denom0;
272    }
273    for i in 1..m {
274        let denom = diag[i] - off[i - 1] * c_prime[i - 1];
275        if denom.abs() < 1e-300 {
276            crate::bail_invalid_basis!("singular tridiagonal pivot at row {i} in cr penalty solve");
277        }
278        if i < m - 1 {
279            c_prime[i] = off[i] / denom;
280        }
281        for col in 0..cols {
282            d_prime[[i, col]] = (rhs[[i, col]] - off[i - 1] * d_prime[[i - 1, col]]) / denom;
283        }
284    }
285    // Back substitution.
286    let mut x = Array2::<f64>::zeros((m, cols));
287    for col in 0..cols {
288        x[[m - 1, col]] = d_prime[[m - 1, col]];
289    }
290    for i in (0..m - 1).rev() {
291        for col in 0..cols {
292            x[[i, col]] = d_prime[[i, col]] - c_prime[i] * x[[i + 1, col]];
293        }
294    }
295    Ok(x)
296}
297
298/// Place `k` cr knots at evenly-spaced quantiles of the unique sorted data,
299/// exactly as mgcv's default `cr` knot placement: the first and last knots are
300/// the min/max, and the interior knots are at the `1/(k-1) … (k-2)/(k-1)`
301/// quantiles of the *unique* observed values. Returns a strictly increasing
302/// length-`k` knot vector.
303pub fn select_cr_knots(data: ArrayView1<'_, f64>, k: usize) -> Result<Array1<f64>, BasisError> {
304    if k < 3 {
305        crate::bail_invalid_basis!("cubic regression spline requires k >= 3, got {k}");
306    }
307    if data.is_empty() {
308        crate::bail_invalid_basis!("cannot place cr knots on empty data");
309    }
310    if data.iter().any(|x| !x.is_finite()) {
311        crate::bail_invalid_basis!("cr knot placement requires finite data");
312    }
313    let mut sorted: Vec<f64> = data.iter().copied().collect();
314    sorted.sort_by(f64::total_cmp);
315    // Unique values (mgcv places cr knots on the unique data quantiles).
316    let mut unique: Vec<f64> = Vec::with_capacity(sorted.len());
317    for &v in &sorted {
318        if unique.last().map(|&p| p != v).unwrap_or(true) {
319            unique.push(v);
320        }
321    }
322    let nu = unique.len();
323    if nu < k {
324        crate::bail_invalid_basis!(
325            "cubic regression spline with k={k} requires at least {k} distinct \
326             values, got {nu}"
327        );
328    }
329    // mgcv's `place.knots`: knots at quantile type-1-ish positions over the
330    // index range [0, nu-1] evenly in (k-1) steps. Endpoints are exact min/max.
331    let mut knots = Array1::<f64>::zeros(k);
332    for j in 0..k {
333        let pos = (j as f64) * ((nu - 1) as f64) / ((k - 1) as f64);
334        let lo = pos.floor() as usize;
335        let hi = pos.ceil() as usize;
336        let frac = pos - lo as f64;
337        knots[j] = if lo == hi {
338            unique[lo]
339        } else {
340            unique[lo] * (1.0 - frac) + unique[hi] * frac
341        };
342    }
343    // Guard strict monotonicity in case of ties from interpolation rounding.
344    for i in 1..k {
345        if !(knots[i] > knots[i - 1]) {
346            crate::bail_invalid_basis!(
347                "cr knot placement produced non-increasing knots (too many knots \
348                 for the data spread); reduce k"
349            );
350        }
351    }
352    Ok(knots)
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Quadratic form `βᵀ S β` of a coefficient vector under penalty `s`.
    fn roughness(s: &Array2<f64>, beta: &Array1<f64>) -> f64 {
        beta.dot(&s.dot(beta))
    }

    /// Constants and lines must carry zero penalty (the null space is
    /// {const, linear}), while a quadratic must be penalized.
    #[test]
    fn cr_penalty_nullspace_is_const_and_linear() {
        let knots = Array1::from(vec![0.0, 0.3, 0.55, 0.8, 1.0]);
        let s = CubicRegressionBasis::new(knots.clone()).unwrap().penalty();

        // const: β = 1.
        let q_const = roughness(&s, &Array1::<f64>::ones(knots.len()));
        assert!(q_const.abs() < 1e-9, "const not in null space: {q_const}");

        // linear: β_i = knot_i.
        let q_lin = roughness(&s, &knots);
        assert!(q_lin.abs() < 1e-9, "linear not in null space: {q_lin}");

        // quadratic: β_i = knot_i².
        let q_quad = roughness(&s, &knots.mapv(|x| x * x));
        assert!(q_quad > 1e-6, "quadratic penalty not positive: {q_quad}");
    }

    /// With β_i = f(x*_i) for a line f, the design must return f at every
    /// evaluation point, both inside the knot range and extrapolated.
    #[test]
    fn cr_design_reproduces_line_including_extrapolation() {
        let line = |x: f64| 2.0 + 3.0 * x;
        let knots = Array1::from(vec![0.0, 0.25, 0.5, 0.75, 1.0]);
        let beta = knots.mapv(line);
        let cr = CubicRegressionBasis::new(knots).unwrap();

        let xs = Array1::from(vec![-0.4, 0.0, 0.13, 0.5, 0.87, 1.0, 1.3]);
        let fitted = cr.design(xs.view()).dot(&beta);
        for (&x, &got) in xs.iter().zip(fitted.iter()) {
            let truth = line(x);
            assert!(
                (got - truth).abs() < 1e-9,
                "line not reproduced at x={x}: got {}, want {truth}",
                got
            );
        }
    }

    /// Knot placement returns endpoints = min/max and strictly increasing knots.
    #[test]
    fn cr_knots_span_data_and_increase() {
        let data: Array1<f64> = (0..50).map(|i| i as f64 / 49.0).collect();
        let knots = select_cr_knots(data.view(), 5).unwrap();
        assert_eq!(knots.len(), 5);
        assert!((knots[0] - 0.0).abs() < 1e-12);
        assert!((knots[4] - 1.0).abs() < 1e-12);
        assert!(knots.iter().zip(knots.iter().skip(1)).all(|(a, b)| b > a));
    }
}
415}