Skip to main content

ferrolearn_preprocess/
spline_transformer.rs

1//! Spline transformer: generate B-spline basis functions for each feature.
2//!
3//! [`SplineTransformer`] expands each input feature into a set of B-spline
4//! basis columns. This is a nonlinear feature expansion technique that
5//! represents each feature as a combination of piecewise polynomial functions.
6//!
7//! # Knot Placement
8//!
9//! - [`KnotStrategy::Uniform`] — knots are evenly spaced between min and max.
10//! - [`KnotStrategy::Quantile`] — knots are placed at quantiles of the data.
11//!
12//! ## REQ status
13//!
14//! Translation target: scikit-learn 1.5.2 `class SplineTransformer`
15//! (`sklearn/preprocessing/_polynomial.py:580`). Tracking: #1331.
16//! Each REQ is BINARY — SHIPPED (impl + non-test consumer + tests + green
17//! verification) or NOT-STARTED (with a concrete open blocker).
18//!
19//! | REQ | Scope | Status | Evidence / Blocker |
20//! |-----|-------|--------|--------------------|
21//! | REQ-1 | Output dimensions (`n_knots+degree-1` cols/feature) + B-spline structural properties (partition-of-unity, non-negativity) | SHIPPED | [`FittedSplineTransformer::transform`]; sklearn `n_splines` `_polynomial.py:875`; tests `green_guard_column_count_per_feature` / `_partition_of_unity` / `_non_negativity` |
22//! | REQ-2 | Uniform-knot basis VALUE parity — EXTENDED edge-spacing knots + scipy `BSpline` design matrix | SHIPPED | [`FittedSplineTransformer`] knot construction matches sklearn `_polynomial.py:908-923` + `:925-940`; verified across degree∈{1,2,3}, multi-feature, both base endpoints in `tests/divergence_spline_transformer.rs` (was DIV-1 #1332) |
23//! | REQ-3 | `extrapolation` param: DEFAULT `constant` (clamp out-of-range to boundary basis) + NaN/Inf reject at fit/transform | SHIPPED (Constant default + finiteness); other modes NOT-STARTED | [`Extrapolation::Constant`] is the default; [`FittedSplineTransformer::transform`] clamps each value to `[xmin, xmax]` before evaluating the basis (mirrors sklearn `_polynomial.py:721` default + `:1059-1087` constant clamp); fit/transform reject non-finite input (sklearn `_validate_data` `:833-839`). Tests `divergence_extrapolation_constant_default_degree{1,2,3}` + `divergence_nan_input_must_error` in `tests/divergence_spline_transformer_extrapolation.rs`. Modes `linear`/`continue`/`periodic`/`error` remain NOT-STARTED — blocker #1333 |
24//! | REQ-4 | `include_bias` param (drop one column when `false`) | NOT-STARTED | no param; sklearn `_polynomial.py:635,942` — blocker #1334 |
25//! | REQ-5 | Quantile knots via `np.percentile`-exact (ferrolearn uses linear-interp percentile) | NOT-STARTED | `spline_transformer.rs` Quantile path; sklearn `_polynomial.py:747-753` — blocker #1335 |
26//! | REQ-6 | Error/parameter contracts (`n_samples<2`, `n_knots<2`, transform ncols, unfitted) | SHIPPED | [`SplineTransformer::fit`]; `degree==0` is now ALLOWED (piecewise-constant), matching sklearn `_parameter_constraints` `degree: Interval(Integral, 0, None, closed="left")` (`_polynomial.py:705`). `n_knots<2` rejection matches `n_knots: Interval(Integral, 2, None, closed="left")` (`:704`). The `n_samples>=2` requirement also MATCHES sklearn (`_validate_data(..., ensure_min_samples=2)`, `_polynomial.py:830`) — NOT a divergence. (blocker #1336) |
27//! | REQ-7 | `sparse_output` + `order` params | NOT-STARTED | no params; sklearn `_polynomial.py:716-730` — blocker #1337 |
28//! | REQ-8 | `sample_weight` (weighted knot placement) | NOT-STARTED | sklearn `fit(X, y=None, sample_weight=None)` `_polynomial.py:811` — blocker #1338 |
29//! | REQ-9 | `get_feature_names_out` (`{feat}_sp_{j}`) + `bsplines_`/`n_features_out_` fitted attrs | NOT-STARTED | sklearn `_polynomial.py:781-809,942` — blocker #1339 |
30//! | REQ-10 | PyO3 binding | NOT-STARTED | no `ferrolearn-python` binding — blocker #1340 |
31//! | REQ-11 | ferray substrate | NOT-STARTED | dense `Array2` only — blocker #1341 |
32
33use ferrolearn_core::error::FerroError;
34use ferrolearn_core::traits::{Fit, FitTransform, Transform};
35use ndarray::Array2;
36use num_traits::Float;
37
38// ---------------------------------------------------------------------------
39// KnotStrategy
40// ---------------------------------------------------------------------------
41
/// Selects where the base knots of each feature's spline are placed.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum KnotStrategy {
    /// Equally spaced knots spanning the feature's observed minimum and maximum.
    Uniform,
    /// Knots located at evenly spaced quantiles of the feature's observed values.
    Quantile,
}
50
51// ---------------------------------------------------------------------------
52// Extrapolation
53// ---------------------------------------------------------------------------
54
/// Policy applied to inputs that fall outside the base knot interval
/// `[xmin, xmax]` learned at fit time.
///
/// Counterpart of scikit-learn's `extrapolation` parameter
/// (`sklearn/preprocessing/_polynomial.py:707-709`). As in sklearn, whose
/// `__init__` default is `extrapolation="constant"` (`_polynomial.py:721`),
/// the default here is [`Extrapolation::Constant`].
///
/// [`Extrapolation::Constant`] is the only mode that is implemented. The other
/// sklearn modes (`linear`, `continue`, `periodic`, `error`) are NOT-STARTED:
/// the transform reports a [`FerroError::InvalidParameter`] for them.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Extrapolation {
    /// Out-of-range values are clamped to the boundary spline basis: the basis
    /// is evaluated at `xmin` for `x < xmin` and at `xmax` for `x > xmax`.
    /// This is the DEFAULT and matches sklearn `extrapolation="constant"`
    /// (`_polynomial.py:721`). sklearn's constant branch (`:1059-1087`) writes
    /// the boundary basis values `f_min[:degree]` / `f_max[-degree:]` into the
    /// first/last `degree` columns of an out-of-range row; because the columns
    /// beyond `degree` are zero at the boundary, that is equivalent to
    /// clamping `x` to `[xmin, xmax]` before evaluating the basis.
    #[default]
    Constant,
    /// Continue the boundary splines linearly (sklearn `"linear"`,
    /// `_polynomial.py:1089-1123`). NOT-STARTED.
    Linear,
    /// Evaluate with scipy `extrapolate=True` (sklearn `"continue"`).
    /// NOT-STARTED.
    Continue,
    /// Periodic splines (sklearn `"periodic"`). NOT-STARTED.
    Periodic,
    /// Fail on out-of-range input (sklearn `"error"`,
    /// `_polynomial.py:1047-1058`). NOT-STARTED.
    Error,
}
88
89// ---------------------------------------------------------------------------
90// SplineTransformer (unfitted)
91// ---------------------------------------------------------------------------
92
93/// An unfitted spline transformer.
94///
95/// Calling [`Fit::fit`] computes the knot positions and returns a
96/// [`FittedSplineTransformer`] that generates B-spline basis functions.
97///
98/// # Parameters
99///
100/// - `n_knots` — number of interior knots (default 5).
101/// - `degree` — degree of the B-spline (default 3, i.e. cubic).
102/// - `knots` — knot placement strategy (default `Uniform`).
103///
104/// The number of output columns per feature is `n_knots + degree - 1`.
105///
106/// # Examples
107///
108/// ```
109/// use ferrolearn_preprocess::spline_transformer::{SplineTransformer, KnotStrategy};
110/// use ferrolearn_core::traits::{Fit, Transform};
111/// use ndarray::array;
112///
113/// let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
114/// let x = array![[0.0], [0.25], [0.5], [0.75], [1.0]];
115/// let fitted = st.fit(&x, &()).unwrap();
116/// let out = fitted.transform(&x).unwrap();
117/// // 5 + 3 - 1 = 7 basis columns per feature
118/// assert_eq!(out.ncols(), 7);
119/// ```
120#[must_use]
121#[derive(Debug, Clone)]
122pub struct SplineTransformer<F> {
123    /// Number of interior knots.
124    n_knots: usize,
125    /// Degree of the B-spline.
126    degree: usize,
127    /// Knot placement strategy.
128    knots: KnotStrategy,
129    /// Out-of-range extrapolation policy (default [`Extrapolation::Constant`]).
130    extrapolation: Extrapolation,
131    _marker: std::marker::PhantomData<F>,
132}
133
134impl<F: Float + Send + Sync + 'static> SplineTransformer<F> {
135    /// Create a new `SplineTransformer` with the DEFAULT extrapolation policy
136    /// ([`Extrapolation::Constant`], matching sklearn's `extrapolation="constant"`
137    /// default, `_polynomial.py:721`).
138    pub fn new(n_knots: usize, degree: usize, knots: KnotStrategy) -> Self {
139        Self::with_extrapolation(n_knots, degree, knots, Extrapolation::Constant)
140    }
141
142    /// Create a new `SplineTransformer` with an explicit extrapolation policy.
143    pub fn with_extrapolation(
144        n_knots: usize,
145        degree: usize,
146        knots: KnotStrategy,
147        extrapolation: Extrapolation,
148    ) -> Self {
149        Self {
150            n_knots,
151            degree,
152            knots,
153            extrapolation,
154            _marker: std::marker::PhantomData,
155        }
156    }
157
158    /// Return the number of interior knots.
159    #[must_use]
160    pub fn n_knots(&self) -> usize {
161        self.n_knots
162    }
163
164    /// Return the B-spline degree.
165    #[must_use]
166    pub fn degree(&self) -> usize {
167        self.degree
168    }
169
170    /// Return the knot placement strategy.
171    #[must_use]
172    pub fn knot_strategy(&self) -> KnotStrategy {
173        self.knots
174    }
175
176    /// Return the out-of-range extrapolation policy.
177    #[must_use]
178    pub fn extrapolation(&self) -> Extrapolation {
179        self.extrapolation
180    }
181}
182
183impl<F: Float + Send + Sync + 'static> Default for SplineTransformer<F> {
184    fn default() -> Self {
185        Self::new(5, 3, KnotStrategy::Uniform)
186    }
187}
188
189// ---------------------------------------------------------------------------
190// FittedSplineTransformer
191// ---------------------------------------------------------------------------
192
193/// A fitted spline transformer holding per-feature knot positions.
194///
195/// Created by calling [`Fit::fit`] on a [`SplineTransformer`].
196#[derive(Debug, Clone)]
197pub struct FittedSplineTransformer<F> {
198    /// Full knot vector per feature (including boundary knots with multiplicity).
199    knot_vectors: Vec<Vec<F>>,
200    /// Per-feature base-interval lower bound (`xmin = knots[degree]`, the fit min).
201    /// Used to clamp out-of-range values under [`Extrapolation::Constant`].
202    xmin: Vec<F>,
203    /// Per-feature base-interval upper bound (`xmax = knots[n_basis]`, the fit max).
204    xmax: Vec<F>,
205    /// Degree of the B-spline.
206    degree: usize,
207    /// Number of basis functions per feature.
208    n_basis: usize,
209    /// Out-of-range extrapolation policy.
210    extrapolation: Extrapolation,
211}
212
213impl<F: Float + Send + Sync + 'static> FittedSplineTransformer<F> {
214    /// Return the knot vectors.
215    #[must_use]
216    pub fn knot_vectors(&self) -> &[Vec<F>] {
217        &self.knot_vectors
218    }
219
220    /// Return the number of basis functions per feature.
221    #[must_use]
222    pub fn n_basis_per_feature(&self) -> usize {
223        self.n_basis
224    }
225
226    /// Return the total number of output columns.
227    #[must_use]
228    pub fn n_output_features(&self) -> usize {
229        self.knot_vectors.len() * self.n_basis
230    }
231
232    /// Return the out-of-range extrapolation policy.
233    #[must_use]
234    pub fn extrapolation(&self) -> Extrapolation {
235        self.extrapolation
236    }
237}
238
239/// Reject non-finite (NaN/Inf) entries in `x`, mirroring sklearn's
240/// `_validate_data(..., force_all_finite=True)` (`_polynomial.py:833-839`),
241/// which raises `ValueError("Input X contains NaN.")` / infinity.
242fn reject_non_finite<F: Float>(x: &Array2<F>, context: &str) -> Result<(), FerroError> {
243    if x.iter().any(|v| !v.is_finite()) {
244        return Err(FerroError::InvalidParameter {
245            name: "X".into(),
246            reason: format!("Input X contains NaN or infinity. ({context})"),
247        });
248    }
249    Ok(())
250}
251
252// ---------------------------------------------------------------------------
253// B-spline evaluation (Cox-de Boor recursion)
254// ---------------------------------------------------------------------------
255
256/// Evaluate all B-spline basis functions at a given value `x` using the
257/// Cox-de Boor recursion.
258///
259/// `knots` is the full knot vector of length `n_basis + degree + 1`.
260/// Returns a vector of length `n_basis` containing the basis values.
261fn bspline_basis<F: Float>(x: F, knots: &[F], degree: usize, n_basis: usize) -> Vec<F> {
262    // Start with degree-0 basis functions
263    let n_intervals = knots.len() - 1;
264    let mut basis = vec![F::zero(); n_intervals];
265
266    // Degree 0: indicator functions using half-open intervals [t_i, t_{i+1}).
267    // Special case: with sklearn's EXTENDED knot vector the base interval is
268    // `[knots[degree], knots[n_basis]]` (knots[n_basis] is the right end of the
269    // base support, NOT the rightmost extended knot). scipy's `design_matrix`
270    // includes the right endpoint of the base interval, so a value at
271    // `x == knots[n_basis]` must be evaluated as the limit from the left rather
272    // than returning all-zero under a naive half-open `t_i <= x < t_{i+1}`.
273    // Activate the last non-degenerate interval that LIES AT OR BEFORE the base
274    // right endpoint so the Cox-de Boor recursion propagates a non-zero value.
275    let base_right = knots[n_basis];
276    if x >= base_right {
277        // Find the last interval ending at the base right endpoint with
278        // non-zero width and activate it (the closed-right base span).
279        let mut found = false;
280        for i in (0..n_intervals).rev() {
281            if knots[i + 1] <= base_right && knots[i] < knots[i + 1] {
282                basis[i] = F::one();
283                found = true;
284                break;
285            }
286        }
287        // Fallback: if all such intervals are degenerate, activate the last one
288        if !found {
289            basis[n_intervals - 1] = F::one();
290        }
291    } else {
292        for i in 0..n_intervals {
293            // Half-open: [t_i, t_{i+1})
294            basis[i] = if x >= knots[i] && x < knots[i + 1] {
295                F::one()
296            } else {
297                F::zero()
298            };
299        }
300    }
301
302    // Build up to the desired degree
303    for d in 1..=degree {
304        let n_current = n_intervals - d;
305        let mut new_basis = vec![F::zero(); n_current];
306        for i in 0..n_current {
307            let denom1 = knots[i + d] - knots[i];
308            let denom2 = knots[i + d + 1] - knots[i + 1];
309
310            let left = if denom1 > F::zero() {
311                (x - knots[i]) / denom1 * basis[i]
312            } else {
313                F::zero()
314            };
315
316            let right = if denom2 > F::zero() {
317                (knots[i + d + 1] - x) / denom2 * basis[i + 1]
318            } else {
319                F::zero()
320            };
321
322            new_basis[i] = left + right;
323        }
324        basis = new_basis;
325    }
326
327    // Truncate or pad to n_basis
328    basis.truncate(n_basis);
329    while basis.len() < n_basis {
330        basis.push(F::zero());
331    }
332
333    basis
334}
335
336// ---------------------------------------------------------------------------
337// Trait implementations
338// ---------------------------------------------------------------------------
339
340impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for SplineTransformer<F> {
341    type Fitted = FittedSplineTransformer<F>;
342    type Error = FerroError;
343
344    /// Fit by computing knot positions for each feature.
345    ///
346    /// # Errors
347    ///
348    /// - [`FerroError::InsufficientSamples`] if the input has fewer than 2 rows.
349    /// - [`FerroError::InvalidParameter`] if `n_knots` < 2.
350    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedSplineTransformer<F>, FerroError> {
351        // sklearn `_validate_data(..., force_all_finite=True)` rejects NaN/Inf at
352        // fit (`_polynomial.py:833-839`). Match that contract.
353        reject_non_finite(x, "SplineTransformer::fit")?;
354
355        let n_samples = x.nrows();
356        if n_samples < 2 {
357            return Err(FerroError::InsufficientSamples {
358                required: 2,
359                actual: n_samples,
360                context: "SplineTransformer::fit".into(),
361            });
362        }
363        if self.n_knots < 2 {
364            return Err(FerroError::InvalidParameter {
365                name: "n_knots".into(),
366                reason: "n_knots must be at least 2".into(),
367            });
368        }
369
370        let n_features = x.ncols();
371        let n_basis = self.n_knots + self.degree - 1;
372        let mut knot_vectors = Vec::with_capacity(n_features);
373        let mut xmin = Vec::with_capacity(n_features);
374        let mut xmax = Vec::with_capacity(n_features);
375
376        for j in 0..n_features {
377            let mut col_vals: Vec<F> = x.column(j).iter().copied().collect();
378            col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
379
380            let min_val = col_vals[0];
381            let max_val = col_vals[col_vals.len() - 1];
382
383            // Base-interval boundaries used by `Extrapolation::Constant`: a value
384            // below `xmin`/above `xmax` is clamped to the boundary before the
385            // basis is evaluated (sklearn `_polynomial.py:1059-1087`). These are
386            // the fit min/max, equal to `knots[degree]`/`knots[n_basis]` in the
387            // extended knot vector.
388            xmin.push(min_val);
389            xmax.push(max_val);
390
391            // Compute interior knots
392            let interior_knots: Vec<F> = match self.knots {
393                KnotStrategy::Uniform => (0..self.n_knots)
394                    .map(|i| {
395                        min_val
396                            + (max_val - min_val) * F::from(i).unwrap()
397                                / F::from(self.n_knots - 1).unwrap()
398                    })
399                    .collect(),
400                KnotStrategy::Quantile => {
401                    let n = col_vals.len();
402                    (0..self.n_knots)
403                        .map(|i| {
404                            let frac = F::from(i).unwrap()
405                                / F::from(self.n_knots - 1).unwrap_or_else(F::one);
406                            let pos = frac * F::from(n.saturating_sub(1)).unwrap();
407                            let lo = pos.floor().to_usize().unwrap_or(0).min(n - 1);
408                            let hi = pos.ceil().to_usize().unwrap_or(0).min(n - 1);
409                            let f = pos - F::from(lo).unwrap();
410                            col_vals[lo] * (F::one() - f) + col_vals[hi] * f
411                        })
412                        .collect()
413                }
414            };
415
416            // Build full knot vector using sklearn's EXTENDED edge-spacing
417            // construction (`_polynomial.py:908-923`). sklearn explicitly
418            // REJECTS the clamped/`np.tile` repeated-boundary construction
419            // (`:898-906`, Eilers & Marx) in favour of reusing the spacing of
420            // the two first/last base knots:
421            //   dist_min = base[1] - base[0]; dist_max = base[-1] - base[-2]
422            //   left  = linspace(base[0] - degree*dist_min, base[0] - dist_min, degree)
423            //   right = linspace(base[-1] + dist_max, base[-1] + degree*dist_max, degree)
424            //   knots = [left, base, right]
425            // numpy `linspace(a, b, num)` is inclusive of both endpoints.
426            let base = &interior_knots;
427            let nb = base.len();
428            let dist_min = base[1] - base[0];
429            let dist_max = base[nb - 1] - base[nb - 2];
430            let degree = self.degree;
431            let deg_f = F::from(degree).unwrap_or_else(F::one);
432
433            // numpy linspace with `num` inclusive endpoints. For num == 0 numpy
434            // returns an empty array; for num == 1 just [a]; for num >= 2 it
435            // includes both a and b. num == 0 occurs for degree == 0 (no
436            // edge-extension knots — the knot vector is the base knots alone).
437            let linspace = |a: F, b: F, num: usize| -> Vec<F> {
438                if num == 0 {
439                    return Vec::new();
440                }
441                if num == 1 {
442                    return vec![a];
443                }
444                let denom = F::from(num - 1).unwrap_or_else(F::one);
445                (0..num)
446                    .map(|i| {
447                        let t = F::from(i).unwrap_or_else(F::zero) / denom;
448                        a + (b - a) * t
449                    })
450                    .collect()
451            };
452
453            let left = linspace(base[0] - deg_f * dist_min, base[0] - dist_min, degree);
454            let right = linspace(
455                base[nb - 1] + dist_max,
456                base[nb - 1] + deg_f * dist_max,
457                degree,
458            );
459
460            let mut full_knots = Vec::with_capacity(left.len() + nb + right.len());
461            full_knots.extend_from_slice(&left);
462            full_knots.extend_from_slice(base);
463            full_knots.extend_from_slice(&right);
464
465            knot_vectors.push(full_knots);
466        }
467
468        Ok(FittedSplineTransformer {
469            knot_vectors,
470            xmin,
471            xmax,
472            degree: self.degree,
473            n_basis,
474            extrapolation: self.extrapolation,
475        })
476    }
477}
478
479impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSplineTransformer<F> {
480    type Output = Array2<F>;
481    type Error = FerroError;
482
483    /// Generate B-spline basis functions for each feature.
484    ///
485    /// # Errors
486    ///
487    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
488    /// from the number of features seen during fitting.
489    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
490        let n_features = self.knot_vectors.len();
491        if x.ncols() != n_features {
492            return Err(FerroError::ShapeMismatch {
493                expected: vec![x.nrows(), n_features],
494                actual: vec![x.nrows(), x.ncols()],
495                context: "FittedSplineTransformer::transform".into(),
496            });
497        }
498
499        // sklearn validates the transform input too (`_validate_data` in
500        // `transform`), rejecting NaN/Inf.
501        reject_non_finite(x, "FittedSplineTransformer::transform")?;
502
503        // Only `Constant` extrapolation is implemented. The other sklearn modes
504        // are NOT-STARTED — surface a clear error rather than emit wrong values.
505        match self.extrapolation {
506            Extrapolation::Constant => {}
507            Extrapolation::Linear
508            | Extrapolation::Continue
509            | Extrapolation::Periodic
510            | Extrapolation::Error => {
511                return Err(FerroError::InvalidParameter {
512                    name: "extrapolation".into(),
513                    reason: "only Extrapolation::Constant is implemented; \
514                             linear/continue/periodic/error are NOT-STARTED (blocker #1333)"
515                        .into(),
516                });
517            }
518        }
519
520        let n_samples = x.nrows();
521        let n_out = n_features * self.n_basis;
522        let mut out = Array2::zeros((n_samples, n_out));
523
524        for j in 0..n_features {
525            let knots = &self.knot_vectors[j];
526            let col_offset = j * self.n_basis;
527            let lo = self.xmin[j];
528            let hi = self.xmax[j];
529
530            for i in 0..n_samples {
531                // `Extrapolation::Constant`: clamp the value to the base interval
532                // `[xmin, xmax]` before evaluating the basis. At the boundary,
533                // only the first/last `degree` basis columns are non-zero, so the
534                // clamp reproduces sklearn's `f_min[:degree]` / `f_max[-degree:]`
535                // assignment (`_polynomial.py:1059-1087`). The clamp is a no-op
536                // for in-range values, preserving the verified in-range basis.
537                let raw = x[[i, j]];
538                let val = if raw < lo {
539                    lo
540                } else if raw > hi {
541                    hi
542                } else {
543                    raw
544                };
545                let basis_vals = bspline_basis(val, knots, self.degree, self.n_basis);
546                for (k, &bv) in basis_vals.iter().enumerate() {
547                    out[[i, col_offset + k]] = bv;
548                }
549            }
550        }
551
552        Ok(out)
553    }
554}
555
556/// Implement `Transform` on the unfitted transformer.
557impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for SplineTransformer<F> {
558    type Output = Array2<F>;
559    type Error = FerroError;
560
561    /// Always returns an error — the transformer must be fitted first.
562    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
563        Err(FerroError::InvalidParameter {
564            name: "SplineTransformer".into(),
565            reason: "transformer must be fitted before calling transform; use fit() first".into(),
566        })
567    }
568}
569
570impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for SplineTransformer<F> {
571    type FitError = FerroError;
572
573    /// Fit and transform in one step.
574    ///
575    /// # Errors
576    ///
577    /// Returns an error if fitting fails.
578    fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
579        let fitted = self.fit(x, &())?;
580        fitted.transform(x)
581    }
582}
583
584// ---------------------------------------------------------------------------
585// Tests
586// ---------------------------------------------------------------------------
587
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Cubic transformer with five uniformly placed knots, used by most tests.
    fn cubic_uniform() -> SplineTransformer<f64> {
        SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform)
    }

    /// Fit `st` on `x` and transform the same data, panicking on any failure.
    fn expand(st: &SplineTransformer<f64>, x: &Array2<f64>) -> Array2<f64> {
        let fitted = st.fit(x, &()).unwrap();
        fitted.transform(x).unwrap()
    }

    /// Assert that every row of `out` sums to 1 (partition of unity).
    fn assert_rows_sum_to_one(out: &Array2<f64>) {
        for i in 0..out.nrows() {
            let row_sum: f64 = out.row(i).iter().sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_spline_output_dimensions() {
        let x = array![[0.0], [0.25], [0.5], [0.75], [1.0]];
        let out = expand(&cubic_uniform(), &x);
        // n_basis = n_knots + degree - 1 = 5 + 3 - 1 = 7
        assert_eq!(out.ncols(), 7);
        assert_eq!(out.nrows(), 5);
    }

    #[test]
    fn test_spline_partition_of_unity() {
        // B-spline basis functions should sum to 1 at any point of the range.
        let x = array![[0.0], [0.25], [0.5], [0.75], [1.0]];
        let out = expand(&cubic_uniform(), &x);
        assert_rows_sum_to_one(&out);
    }

    #[test]
    fn test_spline_non_negative() {
        let x = array![[0.0], [0.1], [0.5], [0.9], [1.0]];
        let out = expand(&cubic_uniform(), &x);
        for &v in out.iter() {
            assert!(v >= -1e-10, "Basis value should be non-negative, got {v}");
        }
    }

    #[test]
    fn test_spline_quantile_knots() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Quantile);
        let x = array![[0.0], [0.1], [0.2], [0.5], [1.0]];
        let out = expand(&st, &x);
        assert_eq!(out.ncols(), 7);
        // Partition of unity should still hold with unevenly spaced knots.
        assert_rows_sum_to_one(&out);
    }

    #[test]
    fn test_spline_multi_feature() {
        let st = SplineTransformer::<f64>::new(3, 2, KnotStrategy::Uniform);
        let x = array![[0.0, 10.0], [0.5, 15.0], [1.0, 20.0]];
        let out = expand(&st, &x);
        // n_basis per feature = 3 + 2 - 1 = 4, total = 2 * 4 = 8
        assert_eq!(out.ncols(), 8);
    }

    #[test]
    fn test_spline_fit_transform() {
        let x = array![[0.0], [0.5], [1.0]];
        let out = cubic_uniform().fit_transform(&x).unwrap();
        assert_eq!(out.ncols(), 7);
    }

    #[test]
    fn test_spline_insufficient_samples_error() {
        // A single row is below the two-sample minimum.
        let x = array![[1.0]];
        assert!(cubic_uniform().fit(&x, &()).is_err());
    }

    #[test]
    fn test_spline_too_few_knots_error() {
        let st = SplineTransformer::<f64>::new(1, 3, KnotStrategy::Uniform);
        let x = array![[0.0], [1.0]];
        assert!(st.fit(&x, &()).is_err());
    }

    #[test]
    fn test_spline_zero_degree_allowed() -> Result<(), FerroError> {
        // sklearn allows degree==0 (piecewise-constant B-spline):
        // `_parameter_constraints` `degree: Interval(Integral, 0, None,
        // closed="left")` (`_polynomial.py:705`). degree==0 must fit, not error.
        let st = SplineTransformer::<f64>::new(5, 0, KnotStrategy::Uniform);
        let x = array![[0.0], [1.0]];
        let out = st.fit(&x, &())?.transform(&x)?;
        // n_basis = n_knots + degree - 1 = 5 + 0 - 1 = 4
        assert_eq!(out.ncols(), 4);
        Ok(())
    }

    #[test]
    fn test_spline_shape_mismatch_error() {
        let x_train = array![[0.0, 1.0], [0.5, 1.5]];
        let fitted = cubic_uniform().fit(&x_train, &()).unwrap();
        // One column instead of the two seen at fit time.
        let x_bad = array![[0.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_spline_unfitted_error() {
        let x = array![[0.0]];
        assert!(cubic_uniform().transform(&x).is_err());
    }

    #[test]
    fn test_spline_default() {
        let st = SplineTransformer::<f64>::default();
        assert_eq!(st.n_knots(), 5);
        assert_eq!(st.degree(), 3);
        assert_eq!(st.knot_strategy(), KnotStrategy::Uniform);
    }

    #[test]
    fn test_spline_degree1() {
        // Linear splines: should produce a piecewise linear basis.
        let st = SplineTransformer::<f64>::new(3, 1, KnotStrategy::Uniform);
        let x = array![[0.0], [0.5], [1.0]];
        let out = expand(&st, &x);
        // n_basis = 3 + 1 - 1 = 3
        assert_eq!(out.ncols(), 3);
        // Partition of unity
        assert_rows_sum_to_one(&out);
    }
}