Skip to main content

gam_terms/basis/
bspline_eval.rs

1use super::*;
2
/// Absolute lower bound on a B-spline knot span `t_{i+k} - t_i`.
///
/// Spans at or below this value are treated as degenerate: the matching
/// Cox–de Boor / derivative-recurrence denominator is skipped (its term
/// contributes zero), and a basis function with such a support is rejected.
/// The floor sits well above `f64::EPSILON`, so near-coincident knots are
/// caught before a division can amplify rounding noise. It is still far below
/// any meaningful knot spacing on a covariate scale.
pub(crate) const KNOT_SPAN_DEGENERACY_FLOOR: f64 = 1.0e-12;

/// Minimum absolute distance outside the clamped B-spline domain at which the
/// linear extrapolation correction kicks in.
///
/// Points closer to the boundary than this count as on-boundary, and no
/// extrapolation term is added for them.
pub(crate) const BSPLINE_EXTRAPOLATION_THRESHOLD: f64 = 1.0e-12;

/// Rows per block that the streaming design evaluators materialize when the
/// caller gives no explicit chunk size.
///
/// This caps the transient working set at one dense `chunk_rows × p` block.
/// It is still large enough to amortize the per-chunk kernel-column setup.
pub(crate) const DEFAULT_STREAMING_CHUNK_ROWS: usize = 2_048;
21/// Marker type for dense basis matrix output.
22pub struct Dense;
23
24/// Marker type for sparse basis matrix output.
25pub struct Sparse;
26
27/// Trait for selecting basis storage format at compile time.
28pub trait BasisOutput {
29    type Output;
30}
31
32impl BasisOutput for Dense {
33    type Output = Arc<Array2<f64>>;
34}
35
36impl BasisOutput for Sparse {
37    type Output = SparseColMat<usize, f64>;
38}
39
40/// Unified B-spline basis generation with configurable storage, knot source, and options.
41///
42/// This function consolidates various basis generation functions into a single entry point.
43/// Use type parameters to select output format:
44/// - `create_basis::<Dense>(...)` for dense `Array2<f64>` output
45/// - `create_basis::<Sparse>(...)` for sparse `SparseColMat` output
46///
47/// # Arguments
48/// * `data` - Data points to evaluate basis at
49/// * `knot_source` - Either pre-computed knots or parameters for uniform generation
50/// * `degree` - B-spline degree (e.g., 3 for cubic)
51/// * `options` - Derivative order and other options
52///
53/// # Returns
54/// Tuple of (basis matrix, knot vector used)
55pub fn create_basis<O: BasisOutputFormat>(
56    data: ArrayView1<f64>,
57    knot_source: KnotSource<'_>,
58    degree: usize,
59    options: BasisOptions,
60) -> Result<(O::Output, Array1<f64>), BasisError> {
61    if degree < 1 {
62        return Err(BasisError::InvalidDegree(degree));
63    }
64
65    if options.basis_family != BasisFamily::BSpline && options.derivative_order != 0 {
66        crate::bail_invalid_basis!("derivatives are only supported for BasisFamily::BSpline");
67    }
68
69    let eval_kind = match options.derivative_order {
70        0 => BasisEvalKind::Basis,
71        1 => BasisEvalKind::FirstDerivative,
72        2 => BasisEvalKind::SecondDerivative,
73        n => {
74            crate::bail_invalid_basis!(
75                "unsupported derivative order {n}; only 0, 1, 2 are supported"
76            );
77        }
78    };
79
80    let knot_degree = match options.basis_family {
81        BasisFamily::BSpline | BasisFamily::MSpline => degree,
82        BasisFamily::ISpline => degree
83            .checked_add(1)
84            .ok_or_else(|| BasisError::InvalidInput("I-spline degree overflow".to_string()))?,
85    };
86
87    let knotvec: Array1<f64> = match knot_source {
88        KnotSource::Provided(view) => view.to_owned(),
89        KnotSource::Generate {
90            data_range,
91            num_internal_knots,
92        } => {
93            if data_range.0 > data_range.1 {
94                return Err(BasisError::InvalidRange(data_range.0, data_range.1));
95            }
96            if data_range.0 == data_range.1 {
97                return Err(BasisError::DegenerateRange(num_internal_knots));
98            }
99            internal::generate_full_knot_vector(data_range, num_internal_knots, knot_degree)?
100        }
101    };
102    validate_knots_for_degree(knotvec.view(), knot_degree)?;
103    validate_knot_spans_nondegenerate(knotvec.view(), knot_degree)?;
104
105    match options.basis_family {
106        BasisFamily::BSpline => O::build_basis(data, degree, eval_kind, knotvec),
107        BasisFamily::MSpline => {
108            if O::LAYOUT.is_sparse() {
109                let sparse = create_mspline_sparse(data, knotvec.view(), degree)?;
110                Ok((O::from_sparse(sparse)?, knotvec))
111            } else {
112                let dense = create_mspline_dense(data, knotvec.view(), degree)?;
113                Ok((O::from_dense(dense)?, knotvec))
114            }
115        }
116        BasisFamily::ISpline => {
117            if O::LAYOUT.is_sparse() {
118                crate::bail_invalid_basis!(
119                    "BasisFamily::ISpline does not support sparse output; use Dense"
120                );
121            }
122            let dense = create_ispline_dense(data, knotvec.view(), degree)?;
123            Ok((O::from_dense(dense)?, knotvec))
124        }
125    }
126}
127
128/// Applies first-order linear extension outside a knot-domain interval to a basis matrix
129/// that was evaluated at clamped coordinates.
130///
131/// Given `z_raw` and `z_clamped = clamp(z_raw, left, right)`, this mutates
132/// `basisvalues` in-place as:
133/// `B_ext(z_raw) = B(z_clamped) + (z_raw - z_clamped) * B'(z_clamped)`.
134pub fn apply_linear_extension_from_first_derivative(
135    z_raw: ArrayView1<f64>,
136    z_clamped: ArrayView1<f64>,
137    knot_vector: ArrayView1<f64>,
138    degree: usize,
139    basisvalues: &mut Array2<f64>,
140) -> Result<(), BasisError> {
141    if z_raw.len() != z_clamped.len() {
142        crate::bail_dim_basis!("z_raw and z_clamped must have equal length");
143    }
144    if basisvalues.nrows() != z_raw.len() {
145        crate::bail_dim_basis!("basis row count must match z length");
146    }
147
148    let mut needs_ext = false;
149    for i in 0..z_raw.len() {
150        if (z_raw[i] - z_clamped[i]).abs() > BSPLINE_EXTRAPOLATION_THRESHOLD {
151            needs_ext = true;
152            break;
153        }
154    }
155    if !needs_ext {
156        return Ok(());
157    }
158
159    let (b_prime_arc, _) = create_basis::<Dense>(
160        z_clamped,
161        KnotSource::Provided(knot_vector),
162        degree,
163        BasisOptions::first_derivative(),
164    )?;
165    let b_prime = b_prime_arc.as_ref();
166    if b_prime.nrows() != basisvalues.nrows() || b_prime.ncols() != basisvalues.ncols() {
167        crate::bail_dim_basis!("basis derivative shape mismatch");
168    }
169
170    for i in 0..z_raw.len() {
171        let dz = z_raw[i] - z_clamped[i];
172        if dz.abs() <= BSPLINE_EXTRAPOLATION_THRESHOLD {
173            continue;
174        }
175        for j in 0..basisvalues.ncols() {
176            basisvalues[[i, j]] += dz * b_prime[[i, j]];
177        }
178    }
179    Ok(())
180}
181
/// Storage layout tag carried by every [`BasisOutputFormat`] implementation.
///
/// This is an enum rather than a `bool`, so call sites read as "Dense" or
/// "Sparse" instead of depending on the polarity of a flag.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BasisStorageLayout {
    /// Dense matrix output.
    Dense,
    /// Sparse matrix output.
    Sparse,
}

impl BasisStorageLayout {
    /// Returns `true` exactly for [`BasisStorageLayout::Sparse`].
    #[inline]
    pub const fn is_sparse(self) -> bool {
        match self {
            Self::Dense => false,
            Self::Sparse => true,
        }
    }
}
197
198/// Trait for building basis matrices with different storage formats.
199/// This is an implementation detail for the unified `create_basis` function.
200pub trait BasisOutputFormat {
201    type Output;
202    const LAYOUT: BasisStorageLayout;
203
204    fn build_basis(
205        data: ArrayView1<f64>,
206        degree: usize,
207        eval_kind: BasisEvalKind,
208        knotvec: Array1<f64>,
209    ) -> Result<(Self::Output, Array1<f64>), BasisError>;
210
211    fn from_dense(dense: Array2<f64>) -> Result<Self::Output, BasisError>;
212    fn from_sparse(sparse: SparseColMat<usize, f64>) -> Result<Self::Output, BasisError>;
213}
214
215impl BasisOutputFormat for Dense {
216    type Output = Arc<Array2<f64>>;
217    const LAYOUT: BasisStorageLayout = BasisStorageLayout::Dense;
218
219    fn build_basis(
220        data: ArrayView1<f64>,
221        degree: usize,
222        eval_kind: BasisEvalKind,
223        knotvec: Array1<f64>,
224    ) -> Result<(Self::Output, Array1<f64>), BasisError> {
225        let knotview = knotvec.view();
226
227        let num_basis_functions = knotview.len().saturating_sub(degree + 1);
228        let basis_matrix = if should_use_sparse_basis(num_basis_functions, degree, 1) {
229            let left = knotview[degree];
230            let right = knotview[num_basis_functions];
231            let data_clamped = data.mapv(|x| x.clamp(left, right));
232            let sparse = generate_basis_internal::<SparseStorage>(
233                data_clamped.view(),
234                knotview,
235                degree,
236                eval_kind,
237            )?;
238            let mut dense = Array2::<f64>::zeros((sparse.nrows(), sparse.ncols()));
239            let (symbolic, values) = sparse.parts();
240            let col_ptr = symbolic.col_ptr();
241            let row_idx = symbolic.row_idx();
242            for col in 0..sparse.ncols() {
243                let start = col_ptr[col];
244                let end = col_ptr[col + 1];
245                for idx in start..end {
246                    dense[[row_idx[idx], col]] += values[idx];
247                }
248            }
249            apply_dense_bspline_extrapolation(data, knotview, degree, eval_kind, &mut dense)?;
250            dense
251        } else {
252            generate_basis_internal::<DenseStorage>(data.view(), knotview, degree, eval_kind)?
253        };
254
255        Ok((Arc::new(basis_matrix), knotvec))
256    }
257
258    fn from_dense(dense: Array2<f64>) -> Result<Self::Output, BasisError> {
259        Ok(Arc::new(dense))
260    }
261
262    fn from_sparse(sparse: SparseColMat<usize, f64>) -> Result<Self::Output, BasisError> {
263        let mut dense = Array2::<f64>::zeros((sparse.nrows(), sparse.ncols()));
264        let (symbolic, values) = sparse.parts();
265        let col_ptr = symbolic.col_ptr();
266        let row_idx = symbolic.row_idx();
267        for col in 0..sparse.ncols() {
268            let start = col_ptr[col];
269            let end = col_ptr[col + 1];
270            for idx in start..end {
271                dense[[row_idx[idx], col]] += values[idx];
272            }
273        }
274        Ok(Arc::new(dense))
275    }
276}
277
278pub(crate) fn apply_dense_bspline_extrapolation(
279    data: ArrayView1<f64>,
280    knotview: ArrayView1<f64>,
281    degree: usize,
282    eval_kind: BasisEvalKind,
283    basis_matrix: &mut Array2<f64>,
284) -> Result<(), BasisError> {
285    let num_basis_functions = basis_matrix.ncols();
286    if num_basis_functions == 0 {
287        return Ok(());
288    }
289
290    let left = knotview[degree];
291    let right = knotview[num_basis_functions];
292    if !(left.is_finite() && right.is_finite() && left < right) {
293        return Ok(());
294    }
295
296    // Open (unclamped) knots: the value evaluator clamps the eval point to the
297    // modeling interval `[knots[degree], knots[num_basis]]` (constant extension —
298    // there is no linear extension because `has_clamped_bspline_boundaries` is
299    // false). A constant function has zero derivative, so BOTH the first and
300    // second derivative must be zero in the exterior spans. Without this, the
301    // dense derivative path leaves the raw mathematical B-spline derivative in the
302    // boundary spans (nonzero), which no longer matches a finite difference of the
303    // constant-extended value basis (gam#1348). The genuine cyclic basis never
304    // reaches here (it pre-wraps its input into the base period).
305    if !has_clamped_bspline_boundaries(knotview, degree) {
306        if matches!(
307            eval_kind,
308            BasisEvalKind::FirstDerivative | BasisEvalKind::SecondDerivative
309        ) {
310            for (i, &x) in data.iter().enumerate() {
311                if x < left || x > right {
312                    basis_matrix.row_mut(i).fill(0.0);
313                }
314            }
315        }
316        return Ok(());
317    }
318
319    if matches!(eval_kind, BasisEvalKind::FirstDerivative) {
320        let num_basis_lower = knotview.len().saturating_sub(degree);
321        let mut lower_basis = vec![0.0; num_basis_lower];
322        let mut lower_scratch = internal::BsplineScratch::new(degree.saturating_sub(1));
323        for (i, &x) in data.iter().enumerate() {
324            if x >= left && x <= right {
325                continue;
326            }
327            let x_c = x.clamp(left, right);
328            let mut row = basis_matrix.row_mut(i);
329            let row_slice = row
330                .as_slice_mut()
331                .expect("basis matrix rows should be contiguous");
332            evaluate_bspline_derivative_scalar_into(
333                x_c,
334                knotview,
335                degree,
336                row_slice,
337                &mut lower_basis,
338                &mut lower_scratch,
339            )?;
340        }
341    }
342
343    if matches!(eval_kind, BasisEvalKind::SecondDerivative) {
344        for (i, &x) in data.iter().enumerate() {
345            if x < left || x > right {
346                basis_matrix.row_mut(i).fill(0.0);
347            }
348        }
349    }
350
351    if matches!(eval_kind, BasisEvalKind::Basis) {
352        let z_clamped = data.mapv(|x| x.clamp(left, right));
353        apply_linear_extension_from_first_derivative(
354            data,
355            z_clamped.view(),
356            knotview,
357            degree,
358            basis_matrix,
359        )?;
360    }
361
362    Ok(())
363}
364
365#[inline]
366pub(crate) fn has_clamped_bspline_boundaries(knotview: ArrayView1<f64>, degree: usize) -> bool {
367    let clamp_count = degree + 1;
368    if knotview.len() < 2 * clamp_count {
369        return false;
370    }
371    let left = knotview[0];
372    let right = knotview[knotview.len() - 1];
373    let scale = (right - left).abs().max(1.0);
374    let tol = KNOT_SPAN_DEGENERACY_FLOOR * scale;
375    let left_clamped = knotview
376        .iter()
377        .take(clamp_count)
378        .all(|&k| (k - left).abs() <= tol);
379    let right_clamped = knotview
380        .iter()
381        .rev()
382        .take(clamp_count)
383        .all(|&k| (k - right).abs() <= tol);
384    left_clamped && right_clamped
385}
386
387/// Clamp a B-spline derivative evaluation point to the modeling interval
388/// `[knots[degree], knots[num_basis]]`, mirroring the value evaluator's clamp
389/// (`evaluate_splines_at_point_into`). Outside that interval the non-periodic
390/// value basis is a linear extension, so its derivative is the constant boundary
391/// derivative — which is exactly what evaluating at the clamped endpoint yields.
392/// Keeping the derivative's boundary semantics identical to the value's is what
393/// makes the analytic derivative equal a finite difference of the value (gam#1348).
394#[inline]
395pub(crate) fn clamp_eval_point_to_modeling_interval(
396    x: f64,
397    knotview: ArrayView1<f64>,
398    degree: usize,
399) -> f64 {
400    let num_basis = knotview.len().saturating_sub(degree + 1);
401    if num_basis == 0 {
402        return x;
403    }
404    let left = knotview[degree];
405    let right = knotview[num_basis];
406    if !left.is_finite() || !right.is_finite() || left >= right {
407        return x;
408    }
409    x.clamp(left, right)
410}
411
412/// True when `x` lies strictly outside the modeling interval of an *open*
413/// (non-clamped) knot vector, where the analytic B-spline derivative of every
414/// order must be zero.
415///
416/// The boundary extension differs by knot geometry, and the derivative has to
417/// follow whatever the value basis does so that it equals a finite difference of
418/// the value (gam#1348):
419///
420/// * **Open / unclamped** knots — the value evaluator clamps its argument to
421///   `[t[degree], t[num_basis]]` and holds the value *constant* outside it
422///   (`has_clamped_bspline_boundaries` is false, so the dense builder applies no
423///   linear extension). A constant has zero derivative, so the exterior
424///   derivative is zero — this returns `true`.
425/// * **Clamped** knots — the value is extended *linearly* past the boundary, so
426///   the exterior derivative is the nonzero boundary slope obtained by evaluating
427///   at the clamped endpoint. This returns `false`, leaving the existing
428///   clamp-and-evaluate path in charge.
429///
430/// The dense builder already zeroes the open-knot exterior in
431/// [`apply_dense_bspline_extrapolation`]; this predicate lets the *per-point*
432/// sparse, scalar, and recurrence evaluators do the same, so every derivative
433/// path agrees with the value basis — not just the dense one the public
434/// `bspline_basis_derivative` happens to use. (Genuinely cyclic bases pre-wrap
435/// their input into the base period and never reach here.)
436#[inline]
437pub(crate) fn open_knot_derivative_exterior_is_zero(
438    x: f64,
439    knotview: ArrayView1<f64>,
440    degree: usize,
441) -> bool {
442    let num_basis = knotview.len().saturating_sub(degree + 1);
443    if num_basis == 0 {
444        return false;
445    }
446    let left = knotview[degree];
447    let right = knotview[num_basis];
448    if !(left.is_finite() && right.is_finite() && left < right) {
449        return false;
450    }
451    (x < left || x > right) && !has_clamped_bspline_boundaries(knotview, degree)
452}
453
454/// True when the *linear* (clamped-knot) exterior extension forces the order-`k`
455/// derivative to vanish outside the modeling interval.
456///
457/// On a clamped knot vector the value basis is extended past the boundary as the
458/// affine function `B(x_b) + (x − x_b)·B'(x_b)` (see
459/// [`apply_dense_bspline_extrapolation`] and the value clamp in
460/// `clamp_eval_point_to_modeling_interval`). An affine function has a constant
461/// first derivative (the boundary slope) and **identically zero** second and
462/// higher derivatives, so for `derivative_order ≥ 2` the exterior derivative is
463/// zero — *not* the boundary's own `B^{(k)}(x_b)`, which is what naively
464/// clamping the eval point and evaluating the order-`k` recurrence returns.
465///
466/// This is the clamped-knot counterpart of
467/// [`open_knot_derivative_exterior_is_zero`] (which zeroes *every* order for the
468/// *constant* open-knot extension). The dense builder already enforces the
469/// affine-exterior contract in [`apply_dense_bspline_extrapolation`]; this
470/// predicate lets the per-point scalar/recurrence evaluators agree with it so a
471/// higher-derivative design equals a finite difference of the value in the
472/// boundary spans.
473#[inline]
474pub(crate) fn linear_extension_higher_derivative_is_zero(
475    x: f64,
476    knotview: ArrayView1<f64>,
477    degree: usize,
478    derivative_order: usize,
479) -> bool {
480    if derivative_order < 2 {
481        return false;
482    }
483    let num_basis = knotview.len().saturating_sub(degree + 1);
484    if num_basis == 0 {
485        return false;
486    }
487    let left = knotview[degree];
488    let right = knotview[num_basis];
489    if !(left.is_finite() && right.is_finite() && left < right) {
490        return false;
491    }
492    x < left || x > right
493}
494
495#[inline]
496pub(crate) fn one_sided_derivative_eval_point(
497    x: f64,
498    knotview: ArrayView1<f64>,
499    degree: usize,
500) -> f64 {
501    let num_basis = knotview.len().saturating_sub(degree + 1);
502    if num_basis == 0 {
503        return x;
504    }
505    let left = knotview[degree];
506    let right = knotview[num_basis];
507    if !left.is_finite() || !right.is_finite() || left >= right {
508        return x;
509    }
510    if x == left {
511        let next = left.next_up();
512        if next < right {
513            next
514        } else {
515            left + 0.5 * (right - left)
516        }
517    } else if x == right {
518        let prev = right.next_down();
519        if prev > left {
520            prev
521        } else {
522            left + 0.5 * (right - left)
523        }
524    } else {
525        x
526    }
527}
528
529impl BasisOutputFormat for Sparse {
530    type Output = SparseColMat<usize, f64>;
531    const LAYOUT: BasisStorageLayout = BasisStorageLayout::Sparse;
532
533    fn build_basis(
534        data: ArrayView1<f64>,
535        degree: usize,
536        eval_kind: BasisEvalKind,
537        knotvec: Array1<f64>,
538    ) -> Result<(Self::Output, Array1<f64>), BasisError> {
539        let knotview = knotvec.view();
540        let sparse =
541            generate_basis_internal::<SparseStorage>(data.view(), knotview, degree, eval_kind)?;
542        Ok((sparse, knotvec))
543    }
544
545    fn from_dense(dense: Array2<f64>) -> Result<Self::Output, BasisError> {
546        let (nrows, ncols) = dense.dim();
547        let mut triplets: Vec<Triplet<usize, usize, f64>> = Vec::new();
548        triplets.reserve(nrows.saturating_mul(ncols / 8));
549        for i in 0..nrows {
550            for j in 0..ncols {
551                let v = dense[[i, j]];
552                if v.abs() > 0.0 {
553                    triplets.push(Triplet::new(i, j, v));
554                }
555            }
556        }
557        SparseColMat::try_new_from_triplets(nrows, ncols, &triplets)
558            .map_err(|e| BasisError::SparseCreation(format!("{e:?}")))
559    }
560
561    fn from_sparse(sparse: SparseColMat<usize, f64>) -> Result<Self::Output, BasisError> {
562        Ok(sparse)
563    }
564}
565
566pub(crate) fn validate_knots_for_degree(
567    knot_vector: ArrayView1<f64>,
568    degree: usize,
569) -> Result<(), BasisError> {
570    if degree < 1 {
571        return Err(BasisError::InvalidDegree(degree));
572    }
573
574    let required_knots = 2 * (degree + 1);
575    if knot_vector.len() < required_knots {
576        return Err(BasisError::InsufficientKnotsForDegree {
577            degree,
578            required: required_knots,
579            provided: knot_vector.len(),
580        });
581    }
582
583    if knot_vector.iter().any(|&k| !k.is_finite()) {
584        return Err(BasisError::InvalidKnotVector(
585            "knot vector contains non-finite (NaN or Infinity) values".to_string(),
586        ));
587    }
588
589    if knot_vector.len() >= 2 {
590        for i in 0..(knot_vector.len() - 1) {
591            if knot_vector[i] > knot_vector[i + 1] {
592                return Err(BasisError::InvalidKnotVector(
593                    "knot vector is not non-decreasing".to_string(),
594                ));
595            }
596        }
597    }
598
599    Ok(())
600}
601
602/// Rejects knot vectors whose effective basis functions have zero support
603/// (i.e. `t[i+degree+1] == t[i]` for any `i`). This is stricter than the
604/// structural `validate_knots_for_degree` and is only appropriate at the
605/// user-facing top-level of basis construction — the recursive derivative
606/// evaluators repeatedly call `validate_knots_for_degree` with a reduced
607/// `degree` on the *same* (clamped) knot vector, where the outermost lower-
608/// degree "basis function" always collapses to zero support by construction
609/// and is harmless because the derivative recursion guards the matching
610/// `1/(t_{i+k}-t_i)` denominator with an absolute-value check.
611pub(crate) fn validate_knot_spans_nondegenerate(
612    knot_vector: ArrayView1<f64>,
613    degree: usize,
614) -> Result<(), BasisError> {
615    if knot_vector.len() <= degree + 1 {
616        return Ok(());
617    }
618    let num_basis = knot_vector.len() - degree - 1;
619    for i in 0..num_basis {
620        let span = knot_vector[i + degree + 1] - knot_vector[i];
621        if span <= KNOT_SPAN_DEGENERACY_FLOOR {
622            return Err(BasisError::InvalidKnotVector(format!(
623                "basis function {i} has zero support: t[i+degree+1]-t[i]={span:.3e} must be > 0"
624            )));
625        }
626    }
627    Ok(())
628}
629
/// Quantity produced when evaluating a B-spline basis at a point.
#[derive(Copy, Clone, Debug)]
pub enum BasisEvalKind {
    /// The basis function values.
    Basis,
    /// First derivative of each basis function.
    FirstDerivative,
    /// Second derivative of each basis function.
    SecondDerivative,
}
636
637pub(crate) struct BasisEvalScratch {
638    pub(crate) basis: internal::BsplineScratch,
639    pub(crate) lower_basis: Vec<f64>,
640    pub(crate) lower_scratch: internal::BsplineScratch,
641    pub(crate) derivative_workspace: BsplineDerivativeWorkspace,
642}
643
644impl BasisEvalScratch {
645    pub(crate) fn new(degree: usize) -> Self {
646        let lower_degree = degree.saturating_sub(1);
647        Self {
648            basis: internal::BsplineScratch::new(degree),
649            lower_basis: vec![0.0; lower_degree + 1],
650            lower_scratch: internal::BsplineScratch::new(lower_degree),
651            derivative_workspace: BsplineDerivativeWorkspace::new(),
652        }
653    }
654}
655
/// Copies the nonzero window of a full-length basis row into `values`.
///
/// The window starts at the first entry of `full` that compares unequal to
/// `0.0` (so NaN counts as nonzero) and covers `values.len()` columns. Slots
/// past the end of `full` are left at `0.0`. Returns the column index of
/// `values[0]`, or `0` (with `values` all zero) when `full` has no nonzero
/// entry.
#[inline]
pub(crate) fn copy_full_row_to_sparse_window(full: &[f64], values: &mut [f64]) -> usize {
    values.fill(0.0);
    match full.iter().position(|&v| v != 0.0) {
        None => 0,
        Some(start_col) => {
            let tail = &full[start_col..];
            let len = tail.len().min(values.len());
            values[..len].copy_from_slice(&tail[..len]);
            start_col
        }
    }
}
669
670pub(crate) fn evaluate_splines_derivative_sparse_intowith_lower(
671    x: f64,
672    degree: usize,
673    knotview: ArrayView1<f64>,
674    values: &mut [f64],
675    lowervalues: &mut [f64],
676    lower_scratch: &mut internal::BsplineScratch,
677) -> usize {
678    let num_basis = knotview.len().saturating_sub(degree + 1);
679    if degree == 0 {
680        values.fill(0.0);
681        return 0;
682    }
683
684    let num_basis_lower = knotview.len().saturating_sub(degree);
685    if lowervalues.len() < num_basis_lower {
686        values.fill(0.0);
687        return 0;
688    }
689    lowervalues[..num_basis_lower].fill(0.0);
690
691    // Non-periodic (open/clamped) B-spline derivative, kept consistent with the
692    // value basis so it equals a finite difference of the value (gam#1348). On an
693    // *open* knot vector the value is held constant outside the modeling interval,
694    // so the exterior derivative is zero — the dense builder enforces this in
695    // `apply_dense_bspline_extrapolation`, and the per-point sparse path (used
696    // directly for open knots, which never take the clamped extrapolation
697    // fallback in `SparseStorage::build`) must do the same or a P-spline
698    // derivative design disagrees with its own value in the boundary spans.
699    // Clamped knots extend linearly and keep their nonzero boundary slope, so the
700    // guard intentionally fires only for the open-knot exterior. No periodic wrap:
701    // wrapping moved boundary-span points onto unrelated interior columns;
702    // genuinely cyclic bases pre-wrap their input into the base period upstream.
703    if open_knot_derivative_exterior_is_zero(x, knotview, degree) {
704        values.fill(0.0);
705        return 0;
706    }
707    // Clamped knots extend the value linearly past the boundary, so the exterior
708    // first derivative is the constant boundary slope obtained by evaluating at
709    // the clamped endpoint — mirror the value clamp (and the scalar derivative
710    // path) here so the sparse derivative agrees with a finite difference of the
711    // value in the boundary spans. Without the clamp a far-exterior point lands
712    // outside the degree-(d−1) support and the recurrence reads zero, breaking
713    // the boundary-slope contract.
714    let x_clamped = clamp_eval_point_to_modeling_interval(x, knotview, degree);
715    let x_eval = one_sided_derivative_eval_point(x_clamped, knotview, degree);
716    internal::evaluate_splines_at_point_full_support_into(
717        x_eval,
718        degree - 1,
719        knotview,
720        &mut lowervalues[..num_basis_lower],
721        lower_scratch,
722    );
723
724    let mut full_derivative = vec![0.0; num_basis];
725    for i in 0..num_basis {
726        let denom_left = knotview[i + degree] - knotview[i];
727        let denom_right = knotview[i + degree + 1] - knotview[i + 1];
728        let left_term = if denom_left.abs() > KNOT_SPAN_DEGENERACY_FLOOR {
729            lowervalues[i] / denom_left
730        } else {
731            0.0
732        };
733        let right_term = if denom_right.abs() > KNOT_SPAN_DEGENERACY_FLOOR {
734            lowervalues[i + 1] / denom_right
735        } else {
736            0.0
737        };
738        let value = (degree as f64) * (left_term - right_term);
739        full_derivative[i] = value;
740    }
741
742    copy_full_row_to_sparse_window(&full_derivative, values)
743}
744
745#[inline]
746pub(crate) fn evaluate_splines_derivative_sparse_into(
747    x: f64,
748    degree: usize,
749    knotview: ArrayView1<f64>,
750    values: &mut [f64],
751    scratch: &mut BasisEvalScratch,
752) -> usize {
753    let num_basis_lower = knotview.len().saturating_sub(degree);
754    if scratch.lower_basis.len() != num_basis_lower {
755        scratch.lower_basis.resize(num_basis_lower, 0.0);
756        scratch
757            .lower_scratch
758            .ensure_degree(degree.saturating_sub(1));
759    }
760    evaluate_splines_derivative_sparse_intowith_lower(
761        x,
762        degree,
763        knotview,
764        values,
765        &mut scratch.lower_basis,
766        &mut scratch.lower_scratch,
767    )
768}
769
770pub(crate) fn evaluate_splinessecond_derivative_sparse_into(
771    x: f64,
772    degree: usize,
773    knotview: ArrayView1<f64>,
774    values: &mut [f64],
775    scratch: &mut BasisEvalScratch,
776) -> usize {
777    let num_basis = knotview.len().saturating_sub(degree + 1);
778    if degree < 2 {
779        values.fill(0.0);
780        return 0;
781    }
782
783    if scratch.lower_basis.len() != num_basis {
784        scratch.lower_basis.resize(num_basis, 0.0);
785    }
786    evaluate_bspline_derivative_recurrence_into(
787        2,
788        x,
789        knotview,
790        degree,
791        &mut scratch.lower_basis,
792        &mut scratch.derivative_workspace,
793        0,
794    )
795    .expect("validated B-spline second-derivative inputs");
796
797    copy_full_row_to_sparse_window(&scratch.lower_basis, values)
798}
799
800#[inline]
801pub(crate) fn evaluate_splines_sparsewith_kind(
802    x: f64,
803    degree: usize,
804    knotview: ArrayView1<f64>,
805    eval_kind: BasisEvalKind,
806    values: &mut [f64],
807    scratch: &mut BasisEvalScratch,
808) -> usize {
809    match eval_kind {
810        BasisEvalKind::Basis => {
811            internal::evaluate_splines_sparse_into(x, degree, knotview, values, &mut scratch.basis)
812        }
813        BasisEvalKind::FirstDerivative => {
814            evaluate_splines_derivative_sparse_into(x, degree, knotview, values, scratch)
815        }
816        BasisEvalKind::SecondDerivative => {
817            evaluate_splinessecond_derivative_sparse_into(x, degree, knotview, values, scratch)
818        }
819    }
820}
821
822#[inline]
823pub(crate) fn evaluate_bsplinerow_entries<F>(
824    x: f64,
825    degree: usize,
826    knotview: ArrayView1<f64>,
827    eval_kind: BasisEvalKind,
828    num_basis_functions: usize,
829    scratch: &mut BasisEvalScratch,
830    values: &mut [f64],
831    mut write_entry: F,
832) where
833    F: FnMut(usize, f64),
834{
835    let start_col =
836        evaluate_splines_sparsewith_kind(x, degree, knotview, eval_kind, values, scratch);
837    for (offset, &v) in values.iter().enumerate() {
838        if v == 0.0 {
839            continue;
840        }
841        let col_j = start_col + offset;
842        if col_j < num_basis_functions {
843            write_entry(col_j, v);
844        }
845    }
846}
847
848pub(crate) trait BasisStorage {
849    type Output;
850
851    fn build(
852        data: ArrayView1<f64>,
853        knotview: ArrayView1<f64>,
854        degree: usize,
855        eval_kind: BasisEvalKind,
856        num_basis_functions: usize,
857        support: usize,
858        use_parallel: bool,
859    ) -> Result<Self::Output, BasisError>;
860}
861
862pub(crate) struct DenseStorage;
863
864impl BasisStorage for DenseStorage {
865    type Output = Array2<f64>;
866
867    fn build(
868        data: ArrayView1<f64>,
869        knotview: ArrayView1<f64>,
870        degree: usize,
871        eval_kind: BasisEvalKind,
872        num_basis_functions: usize,
873        support: usize,
874        use_parallel: bool,
875    ) -> Result<Self::Output, BasisError> {
876        let mut basis_matrix = Array2::zeros((data.len(), num_basis_functions));
877
878        if let (true, Some(data_slice)) = (use_parallel, data.as_slice()) {
879            basis_matrix
880                .axis_iter_mut(Axis(0))
881                .into_par_iter()
882                .zip(data_slice.par_iter().copied())
883                .for_each_init(
884                    || (BasisEvalScratch::new(degree), vec![0.0; support]),
885                    |(scratch, values), (mut row, x)| {
886                        let row_slice = row
887                            .as_slice_mut()
888                            .expect("basis matrix rows should be contiguous");
889                        evaluate_bsplinerow_entries(
890                            x,
891                            degree,
892                            knotview,
893                            eval_kind,
894                            num_basis_functions,
895                            scratch,
896                            values,
897                            |col_j, v| row_slice[col_j] = v,
898                        );
899                    },
900                );
901        } else {
902            let mut scratch = BasisEvalScratch::new(degree);
903            let mut values = vec![0.0; support];
904            for (mut row, &x) in basis_matrix.axis_iter_mut(Axis(0)).zip(data.iter()) {
905                let row_slice = row
906                    .as_slice_mut()
907                    .expect("basis matrix rows should be contiguous");
908                evaluate_bsplinerow_entries(
909                    x,
910                    degree,
911                    knotview,
912                    eval_kind,
913                    num_basis_functions,
914                    &mut scratch,
915                    &mut values,
916                    |col_j, v| row_slice[col_j] = v,
917                );
918            }
919        }
920
921        apply_dense_bspline_extrapolation(data, knotview, degree, eval_kind, &mut basis_matrix)?;
922
923        Ok(basis_matrix)
924    }
925}
926
927pub(crate) struct SparseStorage;
928
929impl BasisStorage for SparseStorage {
930    type Output = SparseColMat<usize, f64>;
931
932    fn build(
933        data: ArrayView1<f64>,
934        knotview: ArrayView1<f64>,
935        degree: usize,
936        eval_kind: BasisEvalKind,
937        num_basis_functions: usize,
938        support: usize,
939        use_parallel: bool,
940    ) -> Result<Self::Output, BasisError> {
941        let nrows = data.len();
942        let left = knotview[degree];
943        let right = knotview[num_basis_functions];
944        let needs_extrapolation = has_clamped_bspline_boundaries(knotview, degree)
945            && data.iter().any(|&x| x < left || x > right);
946        if needs_extrapolation {
947            let dense = DenseStorage::build(
948                data,
949                knotview,
950                degree,
951                eval_kind,
952                num_basis_functions,
953                support,
954                use_parallel,
955            )?;
956            return Sparse::from_dense(dense);
957        }
958
959        let triplets: Vec<Triplet<usize, usize, f64>> =
960            if let (true, Some(data_slice)) = (use_parallel, data.as_slice()) {
961                const CHUNK_SIZE: usize = 1024;
962                let triplet_chunks: Vec<Vec<Triplet<usize, usize, f64>>> = data_slice
963                    .par_chunks(CHUNK_SIZE)
964                    .enumerate()
965                    .map_init(
966                        || (BasisEvalScratch::new(degree), vec![0.0; support]),
967                        |(scratch, values), (chunk_idx, chunk)| {
968                            let baserow = chunk_idx * CHUNK_SIZE;
969                            let mut local = Vec::with_capacity(chunk.len().saturating_mul(support));
970                            for (i, &x) in chunk.iter().enumerate() {
971                                let row_i = baserow + i;
972                                evaluate_bsplinerow_entries(
973                                    x,
974                                    degree,
975                                    knotview,
976                                    eval_kind,
977                                    num_basis_functions,
978                                    scratch,
979                                    values,
980                                    |col_j, v| local.push(Triplet::new(row_i, col_j, v)),
981                                );
982                            }
983                            local
984                        },
985                    )
986                    .collect();
987
988                let mut flattened = Vec::with_capacity(nrows.saturating_mul(support));
989                for mut chunk in triplet_chunks {
990                    flattened.append(&mut chunk);
991                }
992                flattened
993            } else {
994                let mut scratch = BasisEvalScratch::new(degree);
995                let mut values = vec![0.0; support];
996                let mut triplets = Vec::with_capacity(nrows.saturating_mul(support));
997
998                for (row_i, &x) in data.iter().enumerate() {
999                    evaluate_bsplinerow_entries(
1000                        x,
1001                        degree,
1002                        knotview,
1003                        eval_kind,
1004                        num_basis_functions,
1005                        &mut scratch,
1006                        &mut values,
1007                        |col_j, v| triplets.push(Triplet::new(row_i, col_j, v)),
1008                    );
1009                }
1010
1011                triplets
1012            };
1013
1014        SparseColMat::try_new_from_triplets(nrows, num_basis_functions, &triplets)
1015            .map_err(|err| BasisError::SparseCreation(format!("{err:?}")))
1016    }
1017}
1018
1019pub(crate) fn generate_basis_internal<S: BasisStorage>(
1020    data: ArrayView1<f64>,
1021    knotview: ArrayView1<f64>,
1022    degree: usize,
1023    eval_kind: BasisEvalKind,
1024) -> Result<S::Output, BasisError> {
1025    let num_basis_functions = knotview.len().saturating_sub(degree + 1);
1026    let support = degree + 1;
1027    // Parallel dispatch heuristic:
1028    // Lower degrees have cheaper per-row evaluation and need larger batches to
1029    // amortize Rayon scheduling overhead. Cubic+ rows are costlier, so parallel
1030    // wins earlier.
1031    let par_threshold = match degree {
1032        0 | 1 => 512,
1033        2 | 3 => 128,
1034        _ => 64,
1035    };
1036    let use_parallel = data.len() >= par_threshold && data.as_slice().is_some();
1037    S::build(
1038        data,
1039        knotview,
1040        degree,
1041        eval_kind,
1042        num_basis_functions,
1043        support,
1044        use_parallel,
1045    )
1046}
1047
/// Returns true if the B-spline basis should be built in sparse form based on density.
///
/// Each row is assumed to touch `(degree + 1)^dim` columns (the per-dimension
/// support raised to the number of combined dimensions). Sparse storage is chosen
/// when that fraction of `num_basis_cols` is below 20% and there are more than 32
/// columns; an empty basis is never sparse.
///
/// The per-row support saturates at `usize::MAX` instead of overflowing or
/// truncating. In particular, a `dim` that does not fit in `u32` no longer wraps
/// to a tiny exponent and misreports a dense basis as sparse.
pub fn should_use_sparse_basis(num_basis_cols: usize, degree: usize, dim: usize) -> bool {
    // Maximum nonzero fraction per row for which sparse storage pays off.
    const MAX_SPARSE_DENSITY: f64 = 0.20;
    // Below this many columns dense storage is always preferred.
    const MIN_SPARSE_COLS: usize = 32;

    if num_basis_cols == 0 {
        return false;
    }

    // `dim as u32` would silently truncate; saturate the exponent instead.
    let exponent = u32::try_from(dim).unwrap_or(u32::MAX);
    let support_per_row = degree.saturating_add(1).saturating_pow(exponent) as f64;
    let density = support_per_row / num_basis_cols as f64;

    density < MAX_SPARSE_DENSITY && num_basis_cols > MIN_SPARSE_COLS
}
1059
1060/// Creates a penalty matrix `S` for a B-spline basis from a difference matrix `D`.
1061/// The penalty is of the form `S = D' * D`, penalizing the squared `order`-th
1062/// differences of the spline coefficients. This is the core of P-splines.
1063///
1064/// This function supports both uniform knots (using ordinary differences) and
1065/// non-uniform knots (using divided differences), which is critical for
1066/// correctly penalizing curvature when knots are irregularly spaced (e.g. quantiles).
1067///
1068/// # Arguments
1069/// * `num_basis_functions`: The number of basis functions (i.e., columns in the basis matrix).
1070/// * `order`: The order of the difference penalty (e.g., 2 for second differences).
1071/// * `greville_abscissae`: Optional Greville abscissae for divided differences.
1072///   If `None`, assumes uniform knots and uses ordinary integer differences.
1073///   If `Some`, uses divided differences scaled by the inverse of the knot spans.
1074///
1075/// # Returns
1076/// A square `Array2<f64>` of shape `[num_basis, num_basis]` representing the penalty `S`.
1077pub fn create_difference_penalty_matrix(
1078    num_basis_functions: usize,
1079    order: usize,
1080    greville_abscissae: Option<ArrayView1<f64>>,
1081) -> Result<Array2<f64>, BasisError> {
1082    if order == 0 || order >= num_basis_functions {
1083        return Err(BasisError::InvalidPenaltyOrder {
1084            order,
1085            num_basis: num_basis_functions,
1086        });
1087    }
1088
1089    if let Some(g) = greville_abscissae
1090        && g.len() != num_basis_functions
1091    {
1092        crate::bail_dim_basis!(
1093            "Greville abscissae length {} does not match num_basis_functions {}",
1094            g.len(),
1095            num_basis_functions
1096        );
1097    }
1098
1099    // Start with the identity matrix
1100    let mut d = Array2::<f64>::eye(num_basis_functions);
1101
1102    // Apply the differencing operation `order` times.
1103    // Each `diff` reduces the number of rows by 1.
1104    for o in 1..=order {
1105        // Calculate the difference between adjacent rows: D^{(o)} = Delta * D^{(o-1)}
1106        d = &d.slice(s![1.., ..]) - &d.slice(s![..-1, ..]);
1107
1108        // If using non-uniform knots, apply divided difference scaling:
1109        // D^{(o)}_i = D^{(o)}_i / (xi_{i+o} - xi_i)
1110        //
1111        // The raw divided-difference divisor `g[i+o] - g[i]` carries the units
1112        // of the covariate, so a pure rescaling `g -> c*g` (a change of physical
1113        // units for `x`) would multiply every divisor by `c` and hence scale the
1114        // resulting penalty `S = DᵀD` by `c^(-2*order)`. That makes the smooth
1115        // NOT scale-equivariant: REML would select a different `lambda` and the
1116        // fit would drift purely from the abscissa magnitude (#1364). The
1117        // divided difference's *purpose* is to weight each row by the relative
1118        // local span (so non-uniform knots approximate a derivative penalty),
1119        // which is an inherently dimensionless notion. We therefore normalize
1120        // each order's spans by their geometric-mean span at that order, so the
1121        // divisor is a unitless local/typical-span ratio: invariant to a global
1122        // rescaling of `x` and identically `1` for uniform knots (recovering the
1123        // plain integer-difference penalty).
1124        if let Some(g) = greville_abscissae {
1125            let nrows = d.nrows();
1126            let mut log_span_sum = 0.0_f64;
1127            for i in 0..nrows {
1128                let span = g[i + o] - g[i];
1129                if span.abs() <= KNOT_SPAN_DEGENERACY_FLOOR {
1130                    return Err(BasisError::InvalidKnotVector(format!(
1131                        "singular divided-difference span at order {o}, row {i}: Greville abscissae g[{}]={:.6e} and g[{i}]={:.6e} collapse",
1132                        i + o,
1133                        g[i + o],
1134                        g[i]
1135                    )));
1136                }
1137                log_span_sum += span.abs().ln();
1138            }
1139            // Geometric mean of the spans at this order; scales as `c` under
1140            // `g -> c*g`, so dividing each span by it cancels the units exactly.
1141            let ref_span = (log_span_sum / nrows as f64).exp();
1142            for i in 0..nrows {
1143                let span = (g[i + o] - g[i]) / ref_span;
1144                let mut row = d.row_mut(i);
1145                row /= span;
1146            }
1147        }
1148    }
1149
1150    // The penalty matrix S = D' * D
1151    let s = fast_ata(&d);
1152    Ok(s)
1153}
1154
1155pub(crate) fn bspline_raw_column_count(
1156    knots: &Array1<f64>,
1157    degree: usize,
1158    periodic: Option<(f64, f64, usize)>,
1159) -> Result<usize, String> {
1160    if let Some((_, _, num_basis)) = periodic {
1161        if num_basis <= degree {
1162            return Err(format!(
1163                "streaming cyclic B-spline basis requires more basis functions ({num_basis}) than degree ({degree})"
1164            ));
1165        }
1166        return Ok(num_basis);
1167    }
1168    knots
1169        .len()
1170        .checked_sub(degree + 1)
1171        .filter(|&p| p > 0)
1172        .ok_or_else(|| {
1173            format!(
1174                "streaming B-spline knots length {} is too short for degree {}",
1175                knots.len(),
1176                degree
1177            )
1178        })
1179}
1180
1181pub(crate) fn bspline_raw_row_chunk(
1182    data: ArrayView1<'_, f64>,
1183    knots: ArrayView1<'_, f64>,
1184    degree: usize,
1185    periodic: Option<(f64, f64, usize)>,
1186    start: usize,
1187    end: usize,
1188) -> Result<Array2<f64>, BasisError> {
1189    if start > end || end > data.len() {
1190        crate::bail_dim_basis!(
1191            "B-spline row chunk [{start}, {end}) is out of bounds for {} rows",
1192            data.len()
1193        );
1194    }
1195    let chunk = data.slice(s![start..end]);
1196    if let Some((domain_start, period, num_basis)) = periodic {
1197        if period <= 0.0 {
1198            crate::bail_invalid_basis!("periodic B-spline period must be positive, got {period}");
1199        }
1200        let wrapped = chunk.mapv(|x| wrap_to_period(x, domain_start, period));
1201        let (extended, _) = create_basis::<Dense>(
1202            wrapped.view(),
1203            KnotSource::Provided(knots),
1204            degree,
1205            BasisOptions::value(),
1206        )?;
1207        let mut cyclic = Array2::<f64>::zeros((chunk.len(), num_basis));
1208        for i in 0..extended.nrows() {
1209            for j in 0..extended.ncols() {
1210                cyclic[[i, j % num_basis]] += extended[[i, j]];
1211            }
1212        }
1213        Ok(cyclic)
1214    } else {
1215        let (basis, _) = create_basis::<Dense>(
1216            chunk,
1217            KnotSource::Provided(knots),
1218            degree,
1219            BasisOptions::value(),
1220        )?;
1221        Ok((*basis).clone())
1222    }
1223}
1224
1225/// Selects Greville abscissae for difference-penalty scaling.
1226///
1227/// The classical P-spline integer-difference penalty `D'D` penalizes the squared
1228/// `m`-th differences of the *coefficients*. Those differences represent the
1229/// squared `m`-th derivative of the *function* — with the correct polynomial null
1230/// space `{1, x, …, x^{m-1}}` — only when the basis has *evenly spaced Greville
1231/// abscissae*, because a coefficient sequence that is linear in its index then
1232/// reproduces a function linear in `x` (B-spline linear precision, `Σ ξ_i B_i(x)
1233/// = x`).
1234///
1235/// Equally spaced *breakpoints* are **not** sufficient. gam's B-splines are
1236/// clamped (boundary knots repeated `degree + 1` times), so even on a uniform
1237/// interior grid the Greville abscissae `ξ_i = (1/m)·Σ_{k=1}^{degree} t_{i+k}`
1238/// cluster toward the ends. With such a basis the integer-difference null space
1239/// is a *rotated approximation* of the polynomial space rather than the exact
1240/// `{1, x, …}`. That tilts the direction REML shrinks toward as `λ → ∞`, so the
1241/// recovered surface is biased and the selected smoothing parameters land off
1242/// the true optimum (e.g. anisotropic tensor `te`/`ti` recovery degrades).
1243///
1244/// We therefore gate the integer-difference fast path on uniformity of the
1245/// **Greville abscissae** and otherwise return them, so divided-difference
1246/// scaling in [`create_difference_penalty_matrix`] restores the exact polynomial
1247/// null space for any knot geometry (clamped, quantile, or otherwise). When the
1248/// abscissae are already uniform (e.g. a non-clamped Eilers–Marx grid), the
1249/// divided differences reduce to the integer differences up to an overall scale,
1250/// so returning `None` there is exact and cheaper.
1251pub fn penalty_greville_abscissae_for_knots(
1252    knot_vector: &Array1<f64>,
1253    degree: usize,
1254) -> Result<Option<Array1<f64>>, BasisError> {
1255    // Degenerate / too-short knot vectors have no meaningful divided-difference
1256    // scaling (and `compute_greville_abscissae` rejects them); fall back to the
1257    // plain integer-difference penalty exactly as before.
1258    let greville = match compute_greville_abscissae(knot_vector, degree) {
1259        Ok(g) => g,
1260        Err(_) => return Ok(None),
1261    };
1262    if is_uniformly_spaced_sequence(greville.view()) {
1263        Ok(None)
1264    } else {
1265        Ok(Some(greville))
1266    }
1267}
1268
1269/// True when the entries of `values` are (numerically) evenly spaced. Used to
1270/// decide whether classical integer-difference penalties coincide with the
1271/// geometry-correct divided-difference penalty for a basis.
1272pub(crate) fn is_uniformly_spaced_sequence(values: ArrayView1<'_, f64>) -> bool {
1273    let n = values.len();
1274    if n <= 2 {
1275        return true;
1276    }
1277    let span = (values[n - 1] - values[0]).abs().max(1.0);
1278    let h0 = values[1] - values[0];
1279    for i in 2..n {
1280        let hi = values[i] - values[i - 1];
1281        if (hi - h0).abs() > 1e-8 * span {
1282            return false;
1283        }
1284    }
1285    true
1286}