ferrolearn_preprocess/spline_transformer.rs
1//! Spline transformer: generate B-spline basis functions for each feature.
2//!
3//! [`SplineTransformer`] expands each input feature into a set of B-spline
4//! basis columns. This is a nonlinear feature expansion technique that
5//! represents each feature as a combination of piecewise polynomial functions.
6//!
7//! # Knot Placement
8//!
9//! - [`KnotStrategy::Uniform`] — knots are evenly spaced between min and max.
10//! - [`KnotStrategy::Quantile`] — knots are placed at quantiles of the data.
11//!
12//! ## REQ status
13//!
14//! Translation target: scikit-learn 1.5.2 `class SplineTransformer`
15//! (`sklearn/preprocessing/_polynomial.py:580`). Tracking: #1331.
16//! Each REQ is BINARY — SHIPPED (impl + non-test consumer + tests + green
17//! verification) or NOT-STARTED (with a concrete open blocker).
18//!
19//! | REQ | Scope | Status | Evidence / Blocker |
20//! |-----|-------|--------|--------------------|
21//! | REQ-1 | Output dimensions (`n_knots+degree-1` cols/feature) + B-spline structural properties (partition-of-unity, non-negativity) | SHIPPED | [`FittedSplineTransformer::transform`]; sklearn `n_splines` `_polynomial.py:875`; tests `green_guard_column_count_per_feature` / `_partition_of_unity` / `_non_negativity` |
22//! | REQ-2 | Uniform-knot basis VALUE parity — EXTENDED edge-spacing knots + scipy `BSpline` design matrix | SHIPPED | [`FittedSplineTransformer`] knot construction matches sklearn `_polynomial.py:908-923` + `:925-940`; verified across degree∈{1,2,3}, multi-feature, both base endpoints in `tests/divergence_spline_transformer.rs` (was DIV-1 #1332) |
23//! | REQ-3 | `extrapolation` param: DEFAULT `constant` (clamp out-of-range to boundary basis) + NaN/Inf reject at fit/transform | SHIPPED (Constant default + finiteness); other modes NOT-STARTED | [`Extrapolation::Constant`] is the default; [`FittedSplineTransformer::transform`] clamps each value to `[xmin, xmax]` before evaluating the basis (mirrors sklearn `_polynomial.py:721` default + `:1059-1087` constant clamp); fit/transform reject non-finite input (sklearn `_validate_data` `:833-839`). Tests `divergence_extrapolation_constant_default_degree{1,2,3}` + `divergence_nan_input_must_error` in `tests/divergence_spline_transformer_extrapolation.rs`. Modes `linear`/`continue`/`periodic`/`error` remain NOT-STARTED — blocker #1333 |
24//! | REQ-4 | `include_bias` param (drop one column when `false`) | NOT-STARTED | no param; sklearn `_polynomial.py:635,942` — blocker #1334 |
25//! | REQ-5 | Quantile knots via `np.percentile`-exact (ferrolearn uses linear-interp percentile) | NOT-STARTED | `spline_transformer.rs` Quantile path; sklearn `_polynomial.py:747-753` — blocker #1335 |
26//! | REQ-6 | Error/parameter contracts (`n_samples<2`, `n_knots<2`, transform ncols, unfitted) | SHIPPED | [`SplineTransformer::fit`]; `degree==0` is now ALLOWED (piecewise-constant), matching sklearn `_parameter_constraints` `degree: Interval(Integral, 0, None, closed="left")` (`_polynomial.py:705`). `n_knots<2` rejection matches `n_knots: Interval(Integral, 2, None, closed="left")` (`:704`). The `n_samples>=2` requirement also MATCHES sklearn (`_validate_data(..., ensure_min_samples=2)`, `_polynomial.py:830`) — NOT a divergence. (blocker #1336) |
27//! | REQ-7 | `sparse_output` + `order` params | NOT-STARTED | no params; sklearn `_polynomial.py:716-730` — blocker #1337 |
28//! | REQ-8 | `sample_weight` (weighted knot placement) | NOT-STARTED | sklearn `fit(X, y=None, sample_weight=None)` `_polynomial.py:811` — blocker #1338 |
29//! | REQ-9 | `get_feature_names_out` (`{feat}_sp_{j}`) + `bsplines_`/`n_features_out_` fitted attrs | NOT-STARTED | sklearn `_polynomial.py:781-809,942` — blocker #1339 |
30//! | REQ-10 | PyO3 binding | NOT-STARTED | no `ferrolearn-python` binding — blocker #1340 |
31//! | REQ-11 | ferray substrate | NOT-STARTED | dense `Array2` only — blocker #1341 |
32
33use ferrolearn_core::error::FerroError;
34use ferrolearn_core::traits::{Fit, FitTransform, Transform};
35use ndarray::Array2;
36use num_traits::Float;
37
38// ---------------------------------------------------------------------------
39// KnotStrategy
40// ---------------------------------------------------------------------------
41
/// Rule used during fitting to position the base knots of every feature.
///
/// Both strategies place their knots inside the range observed for the
/// feature in the training data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KnotStrategy {
    /// Equidistant knots spanning the feature's training minimum to maximum.
    Uniform,
    /// Knots located at evenly spaced quantiles of the feature's training values.
    Quantile,
}
50
51// ---------------------------------------------------------------------------
52// Extrapolation
53// ---------------------------------------------------------------------------
54
/// Policy for input values that fall outside a feature's fitted base
/// interval `[xmin, xmax]`.
///
/// This is the counterpart of scikit-learn's `extrapolation` parameter
/// (`sklearn/preprocessing/_polynomial.py:707-709`, `:721`). As in sklearn's
/// `__init__` (`extrapolation="constant"`, `_polynomial.py:721`), the default
/// is [`Extrapolation::Constant`].
///
/// Only [`Extrapolation::Constant`] is implemented today. Selecting any other
/// mode (`linear`, `continue`, `periodic`, `error`, all NOT-STARTED) makes the
/// fitted transform return [`FerroError::InvalidParameter`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Extrapolation {
    /// DEFAULT, sklearn `extrapolation="constant"`. Values below `xmin` are
    /// evaluated as if they were `xmin`, and values above `xmax` as if they
    /// were `xmax`.
    ///
    /// sklearn's constant branch (`_polynomial.py:1059-1087`) writes the
    /// boundary basis values `f_min[:degree]` / `f_max[-degree:]` into
    /// out-of-range rows. Clamping `x` before evaluation yields the same row,
    /// because every other column is zero at the boundary.
    #[default]
    Constant,
    /// Continue the boundary splines linearly (sklearn `"linear"`,
    /// `_polynomial.py:1089-1123`). NOT-STARTED.
    Linear,
    /// Let scipy extrapolate the polynomial pieces (`extrapolate=True`,
    /// sklearn `"continue"`). NOT-STARTED.
    Continue,
    /// Periodic splines (sklearn `"periodic"`). NOT-STARTED.
    Periodic,
    /// Reject out-of-range input (sklearn `"error"`,
    /// `_polynomial.py:1047-1058`). NOT-STARTED.
    Error,
}
88
89// ---------------------------------------------------------------------------
90// SplineTransformer (unfitted)
91// ---------------------------------------------------------------------------
92
93/// An unfitted spline transformer.
94///
95/// Calling [`Fit::fit`] computes the knot positions and returns a
96/// [`FittedSplineTransformer`] that generates B-spline basis functions.
97///
98/// # Parameters
99///
100/// - `n_knots` — number of interior knots (default 5).
101/// - `degree` — degree of the B-spline (default 3, i.e. cubic).
102/// - `knots` — knot placement strategy (default `Uniform`).
103///
104/// The number of output columns per feature is `n_knots + degree - 1`.
105///
106/// # Examples
107///
108/// ```
109/// use ferrolearn_preprocess::spline_transformer::{SplineTransformer, KnotStrategy};
110/// use ferrolearn_core::traits::{Fit, Transform};
111/// use ndarray::array;
112///
113/// let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
114/// let x = array![[0.0], [0.25], [0.5], [0.75], [1.0]];
115/// let fitted = st.fit(&x, &()).unwrap();
116/// let out = fitted.transform(&x).unwrap();
117/// // 5 + 3 - 1 = 7 basis columns per feature
118/// assert_eq!(out.ncols(), 7);
119/// ```
120#[must_use]
121#[derive(Debug, Clone)]
122pub struct SplineTransformer<F> {
123 /// Number of interior knots.
124 n_knots: usize,
125 /// Degree of the B-spline.
126 degree: usize,
127 /// Knot placement strategy.
128 knots: KnotStrategy,
129 /// Out-of-range extrapolation policy (default [`Extrapolation::Constant`]).
130 extrapolation: Extrapolation,
131 _marker: std::marker::PhantomData<F>,
132}
133
134impl<F: Float + Send + Sync + 'static> SplineTransformer<F> {
135 /// Create a new `SplineTransformer` with the DEFAULT extrapolation policy
136 /// ([`Extrapolation::Constant`], matching sklearn's `extrapolation="constant"`
137 /// default, `_polynomial.py:721`).
138 pub fn new(n_knots: usize, degree: usize, knots: KnotStrategy) -> Self {
139 Self::with_extrapolation(n_knots, degree, knots, Extrapolation::Constant)
140 }
141
142 /// Create a new `SplineTransformer` with an explicit extrapolation policy.
143 pub fn with_extrapolation(
144 n_knots: usize,
145 degree: usize,
146 knots: KnotStrategy,
147 extrapolation: Extrapolation,
148 ) -> Self {
149 Self {
150 n_knots,
151 degree,
152 knots,
153 extrapolation,
154 _marker: std::marker::PhantomData,
155 }
156 }
157
158 /// Return the number of interior knots.
159 #[must_use]
160 pub fn n_knots(&self) -> usize {
161 self.n_knots
162 }
163
164 /// Return the B-spline degree.
165 #[must_use]
166 pub fn degree(&self) -> usize {
167 self.degree
168 }
169
170 /// Return the knot placement strategy.
171 #[must_use]
172 pub fn knot_strategy(&self) -> KnotStrategy {
173 self.knots
174 }
175
176 /// Return the out-of-range extrapolation policy.
177 #[must_use]
178 pub fn extrapolation(&self) -> Extrapolation {
179 self.extrapolation
180 }
181}
182
183impl<F: Float + Send + Sync + 'static> Default for SplineTransformer<F> {
184 fn default() -> Self {
185 Self::new(5, 3, KnotStrategy::Uniform)
186 }
187}
188
189// ---------------------------------------------------------------------------
190// FittedSplineTransformer
191// ---------------------------------------------------------------------------
192
193/// A fitted spline transformer holding per-feature knot positions.
194///
195/// Created by calling [`Fit::fit`] on a [`SplineTransformer`].
196#[derive(Debug, Clone)]
197pub struct FittedSplineTransformer<F> {
198 /// Full knot vector per feature (including boundary knots with multiplicity).
199 knot_vectors: Vec<Vec<F>>,
200 /// Per-feature base-interval lower bound (`xmin = knots[degree]`, the fit min).
201 /// Used to clamp out-of-range values under [`Extrapolation::Constant`].
202 xmin: Vec<F>,
203 /// Per-feature base-interval upper bound (`xmax = knots[n_basis]`, the fit max).
204 xmax: Vec<F>,
205 /// Degree of the B-spline.
206 degree: usize,
207 /// Number of basis functions per feature.
208 n_basis: usize,
209 /// Out-of-range extrapolation policy.
210 extrapolation: Extrapolation,
211}
212
213impl<F: Float + Send + Sync + 'static> FittedSplineTransformer<F> {
214 /// Return the knot vectors.
215 #[must_use]
216 pub fn knot_vectors(&self) -> &[Vec<F>] {
217 &self.knot_vectors
218 }
219
220 /// Return the number of basis functions per feature.
221 #[must_use]
222 pub fn n_basis_per_feature(&self) -> usize {
223 self.n_basis
224 }
225
226 /// Return the total number of output columns.
227 #[must_use]
228 pub fn n_output_features(&self) -> usize {
229 self.knot_vectors.len() * self.n_basis
230 }
231
232 /// Return the out-of-range extrapolation policy.
233 #[must_use]
234 pub fn extrapolation(&self) -> Extrapolation {
235 self.extrapolation
236 }
237}
238
239/// Reject non-finite (NaN/Inf) entries in `x`, mirroring sklearn's
240/// `_validate_data(..., force_all_finite=True)` (`_polynomial.py:833-839`),
241/// which raises `ValueError("Input X contains NaN.")` / infinity.
242fn reject_non_finite<F: Float>(x: &Array2<F>, context: &str) -> Result<(), FerroError> {
243 if x.iter().any(|v| !v.is_finite()) {
244 return Err(FerroError::InvalidParameter {
245 name: "X".into(),
246 reason: format!("Input X contains NaN or infinity. ({context})"),
247 });
248 }
249 Ok(())
250}
251
252// ---------------------------------------------------------------------------
253// B-spline evaluation (Cox-de Boor recursion)
254// ---------------------------------------------------------------------------
255
256/// Evaluate all B-spline basis functions at a given value `x` using the
257/// Cox-de Boor recursion.
258///
259/// `knots` is the full knot vector of length `n_basis + degree + 1`.
260/// Returns a vector of length `n_basis` containing the basis values.
261fn bspline_basis<F: Float>(x: F, knots: &[F], degree: usize, n_basis: usize) -> Vec<F> {
262 // Start with degree-0 basis functions
263 let n_intervals = knots.len() - 1;
264 let mut basis = vec![F::zero(); n_intervals];
265
266 // Degree 0: indicator functions using half-open intervals [t_i, t_{i+1}).
267 // Special case: with sklearn's EXTENDED knot vector the base interval is
268 // `[knots[degree], knots[n_basis]]` (knots[n_basis] is the right end of the
269 // base support, NOT the rightmost extended knot). scipy's `design_matrix`
270 // includes the right endpoint of the base interval, so a value at
271 // `x == knots[n_basis]` must be evaluated as the limit from the left rather
272 // than returning all-zero under a naive half-open `t_i <= x < t_{i+1}`.
273 // Activate the last non-degenerate interval that LIES AT OR BEFORE the base
274 // right endpoint so the Cox-de Boor recursion propagates a non-zero value.
275 let base_right = knots[n_basis];
276 if x >= base_right {
277 // Find the last interval ending at the base right endpoint with
278 // non-zero width and activate it (the closed-right base span).
279 let mut found = false;
280 for i in (0..n_intervals).rev() {
281 if knots[i + 1] <= base_right && knots[i] < knots[i + 1] {
282 basis[i] = F::one();
283 found = true;
284 break;
285 }
286 }
287 // Fallback: if all such intervals are degenerate, activate the last one
288 if !found {
289 basis[n_intervals - 1] = F::one();
290 }
291 } else {
292 for i in 0..n_intervals {
293 // Half-open: [t_i, t_{i+1})
294 basis[i] = if x >= knots[i] && x < knots[i + 1] {
295 F::one()
296 } else {
297 F::zero()
298 };
299 }
300 }
301
302 // Build up to the desired degree
303 for d in 1..=degree {
304 let n_current = n_intervals - d;
305 let mut new_basis = vec![F::zero(); n_current];
306 for i in 0..n_current {
307 let denom1 = knots[i + d] - knots[i];
308 let denom2 = knots[i + d + 1] - knots[i + 1];
309
310 let left = if denom1 > F::zero() {
311 (x - knots[i]) / denom1 * basis[i]
312 } else {
313 F::zero()
314 };
315
316 let right = if denom2 > F::zero() {
317 (knots[i + d + 1] - x) / denom2 * basis[i + 1]
318 } else {
319 F::zero()
320 };
321
322 new_basis[i] = left + right;
323 }
324 basis = new_basis;
325 }
326
327 // Truncate or pad to n_basis
328 basis.truncate(n_basis);
329 while basis.len() < n_basis {
330 basis.push(F::zero());
331 }
332
333 basis
334}
335
336// ---------------------------------------------------------------------------
337// Trait implementations
338// ---------------------------------------------------------------------------
339
340impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for SplineTransformer<F> {
341 type Fitted = FittedSplineTransformer<F>;
342 type Error = FerroError;
343
344 /// Fit by computing knot positions for each feature.
345 ///
346 /// # Errors
347 ///
348 /// - [`FerroError::InsufficientSamples`] if the input has fewer than 2 rows.
349 /// - [`FerroError::InvalidParameter`] if `n_knots` < 2.
350 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedSplineTransformer<F>, FerroError> {
351 // sklearn `_validate_data(..., force_all_finite=True)` rejects NaN/Inf at
352 // fit (`_polynomial.py:833-839`). Match that contract.
353 reject_non_finite(x, "SplineTransformer::fit")?;
354
355 let n_samples = x.nrows();
356 if n_samples < 2 {
357 return Err(FerroError::InsufficientSamples {
358 required: 2,
359 actual: n_samples,
360 context: "SplineTransformer::fit".into(),
361 });
362 }
363 if self.n_knots < 2 {
364 return Err(FerroError::InvalidParameter {
365 name: "n_knots".into(),
366 reason: "n_knots must be at least 2".into(),
367 });
368 }
369
370 let n_features = x.ncols();
371 let n_basis = self.n_knots + self.degree - 1;
372 let mut knot_vectors = Vec::with_capacity(n_features);
373 let mut xmin = Vec::with_capacity(n_features);
374 let mut xmax = Vec::with_capacity(n_features);
375
376 for j in 0..n_features {
377 let mut col_vals: Vec<F> = x.column(j).iter().copied().collect();
378 col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
379
380 let min_val = col_vals[0];
381 let max_val = col_vals[col_vals.len() - 1];
382
383 // Base-interval boundaries used by `Extrapolation::Constant`: a value
384 // below `xmin`/above `xmax` is clamped to the boundary before the
385 // basis is evaluated (sklearn `_polynomial.py:1059-1087`). These are
386 // the fit min/max, equal to `knots[degree]`/`knots[n_basis]` in the
387 // extended knot vector.
388 xmin.push(min_val);
389 xmax.push(max_val);
390
391 // Compute interior knots
392 let interior_knots: Vec<F> = match self.knots {
393 KnotStrategy::Uniform => (0..self.n_knots)
394 .map(|i| {
395 min_val
396 + (max_val - min_val) * F::from(i).unwrap()
397 / F::from(self.n_knots - 1).unwrap()
398 })
399 .collect(),
400 KnotStrategy::Quantile => {
401 let n = col_vals.len();
402 (0..self.n_knots)
403 .map(|i| {
404 let frac = F::from(i).unwrap()
405 / F::from(self.n_knots - 1).unwrap_or_else(F::one);
406 let pos = frac * F::from(n.saturating_sub(1)).unwrap();
407 let lo = pos.floor().to_usize().unwrap_or(0).min(n - 1);
408 let hi = pos.ceil().to_usize().unwrap_or(0).min(n - 1);
409 let f = pos - F::from(lo).unwrap();
410 col_vals[lo] * (F::one() - f) + col_vals[hi] * f
411 })
412 .collect()
413 }
414 };
415
416 // Build full knot vector using sklearn's EXTENDED edge-spacing
417 // construction (`_polynomial.py:908-923`). sklearn explicitly
418 // REJECTS the clamped/`np.tile` repeated-boundary construction
419 // (`:898-906`, Eilers & Marx) in favour of reusing the spacing of
420 // the two first/last base knots:
421 // dist_min = base[1] - base[0]; dist_max = base[-1] - base[-2]
422 // left = linspace(base[0] - degree*dist_min, base[0] - dist_min, degree)
423 // right = linspace(base[-1] + dist_max, base[-1] + degree*dist_max, degree)
424 // knots = [left, base, right]
425 // numpy `linspace(a, b, num)` is inclusive of both endpoints.
426 let base = &interior_knots;
427 let nb = base.len();
428 let dist_min = base[1] - base[0];
429 let dist_max = base[nb - 1] - base[nb - 2];
430 let degree = self.degree;
431 let deg_f = F::from(degree).unwrap_or_else(F::one);
432
433 // numpy linspace with `num` inclusive endpoints. For num == 0 numpy
434 // returns an empty array; for num == 1 just [a]; for num >= 2 it
435 // includes both a and b. num == 0 occurs for degree == 0 (no
436 // edge-extension knots — the knot vector is the base knots alone).
437 let linspace = |a: F, b: F, num: usize| -> Vec<F> {
438 if num == 0 {
439 return Vec::new();
440 }
441 if num == 1 {
442 return vec![a];
443 }
444 let denom = F::from(num - 1).unwrap_or_else(F::one);
445 (0..num)
446 .map(|i| {
447 let t = F::from(i).unwrap_or_else(F::zero) / denom;
448 a + (b - a) * t
449 })
450 .collect()
451 };
452
453 let left = linspace(base[0] - deg_f * dist_min, base[0] - dist_min, degree);
454 let right = linspace(
455 base[nb - 1] + dist_max,
456 base[nb - 1] + deg_f * dist_max,
457 degree,
458 );
459
460 let mut full_knots = Vec::with_capacity(left.len() + nb + right.len());
461 full_knots.extend_from_slice(&left);
462 full_knots.extend_from_slice(base);
463 full_knots.extend_from_slice(&right);
464
465 knot_vectors.push(full_knots);
466 }
467
468 Ok(FittedSplineTransformer {
469 knot_vectors,
470 xmin,
471 xmax,
472 degree: self.degree,
473 n_basis,
474 extrapolation: self.extrapolation,
475 })
476 }
477}
478
479impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSplineTransformer<F> {
480 type Output = Array2<F>;
481 type Error = FerroError;
482
483 /// Generate B-spline basis functions for each feature.
484 ///
485 /// # Errors
486 ///
487 /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
488 /// from the number of features seen during fitting.
489 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
490 let n_features = self.knot_vectors.len();
491 if x.ncols() != n_features {
492 return Err(FerroError::ShapeMismatch {
493 expected: vec![x.nrows(), n_features],
494 actual: vec![x.nrows(), x.ncols()],
495 context: "FittedSplineTransformer::transform".into(),
496 });
497 }
498
499 // sklearn validates the transform input too (`_validate_data` in
500 // `transform`), rejecting NaN/Inf.
501 reject_non_finite(x, "FittedSplineTransformer::transform")?;
502
503 // Only `Constant` extrapolation is implemented. The other sklearn modes
504 // are NOT-STARTED — surface a clear error rather than emit wrong values.
505 match self.extrapolation {
506 Extrapolation::Constant => {}
507 Extrapolation::Linear
508 | Extrapolation::Continue
509 | Extrapolation::Periodic
510 | Extrapolation::Error => {
511 return Err(FerroError::InvalidParameter {
512 name: "extrapolation".into(),
513 reason: "only Extrapolation::Constant is implemented; \
514 linear/continue/periodic/error are NOT-STARTED (blocker #1333)"
515 .into(),
516 });
517 }
518 }
519
520 let n_samples = x.nrows();
521 let n_out = n_features * self.n_basis;
522 let mut out = Array2::zeros((n_samples, n_out));
523
524 for j in 0..n_features {
525 let knots = &self.knot_vectors[j];
526 let col_offset = j * self.n_basis;
527 let lo = self.xmin[j];
528 let hi = self.xmax[j];
529
530 for i in 0..n_samples {
531 // `Extrapolation::Constant`: clamp the value to the base interval
532 // `[xmin, xmax]` before evaluating the basis. At the boundary,
533 // only the first/last `degree` basis columns are non-zero, so the
534 // clamp reproduces sklearn's `f_min[:degree]` / `f_max[-degree:]`
535 // assignment (`_polynomial.py:1059-1087`). The clamp is a no-op
536 // for in-range values, preserving the verified in-range basis.
537 let raw = x[[i, j]];
538 let val = if raw < lo {
539 lo
540 } else if raw > hi {
541 hi
542 } else {
543 raw
544 };
545 let basis_vals = bspline_basis(val, knots, self.degree, self.n_basis);
546 for (k, &bv) in basis_vals.iter().enumerate() {
547 out[[i, col_offset + k]] = bv;
548 }
549 }
550 }
551
552 Ok(out)
553 }
554}
555
556/// Implement `Transform` on the unfitted transformer.
557impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for SplineTransformer<F> {
558 type Output = Array2<F>;
559 type Error = FerroError;
560
561 /// Always returns an error — the transformer must be fitted first.
562 fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
563 Err(FerroError::InvalidParameter {
564 name: "SplineTransformer".into(),
565 reason: "transformer must be fitted before calling transform; use fit() first".into(),
566 })
567 }
568}
569
570impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for SplineTransformer<F> {
571 type FitError = FerroError;
572
573 /// Fit and transform in one step.
574 ///
575 /// # Errors
576 ///
577 /// Returns an error if fitting fails.
578 fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
579 let fitted = self.fit(x, &())?;
580 fitted.transform(x)
581 }
582}
583
584// ---------------------------------------------------------------------------
585// Tests
586// ---------------------------------------------------------------------------
587
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_spline_output_dimensions() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        let x = array![[0.0], [0.25], [0.5], [0.75], [1.0]];
        let fitted = st.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // n_basis = n_knots + degree - 1 = 5 + 3 - 1 = 7
        assert_eq!(out.ncols(), 7);
        assert_eq!(out.nrows(), 5);
    }

    #[test]
    fn test_spline_partition_of_unity() {
        // B-spline basis functions should sum to 1 at any interior point
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        let x = array![[0.0], [0.25], [0.5], [0.75], [1.0]];
        let fitted = st.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        for i in 0..out.nrows() {
            let row_sum: f64 = out.row(i).iter().sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_spline_non_negative() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        let x = array![[0.0], [0.1], [0.5], [0.9], [1.0]];
        let fitted = st.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        for v in &out {
            assert!(*v >= -1e-10, "Basis value should be non-negative, got {v}");
        }
    }

    #[test]
    fn test_spline_quantile_knots() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Quantile);
        let x = array![[0.0], [0.1], [0.2], [0.5], [1.0]];
        let fitted = st.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 7);
        // Partition of unity should still hold
        for i in 0..out.nrows() {
            let row_sum: f64 = out.row(i).iter().sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_spline_multi_feature() {
        let st = SplineTransformer::<f64>::new(3, 2, KnotStrategy::Uniform);
        let x = array![[0.0, 10.0], [0.5, 15.0], [1.0, 20.0]];
        let fitted = st.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // n_basis per feature = 3 + 2 - 1 = 4, total = 2 * 4 = 8
        assert_eq!(out.ncols(), 8);
    }

    #[test]
    fn test_spline_fit_transform() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        let x = array![[0.0], [0.5], [1.0]];
        let out = st.fit_transform(&x).unwrap();
        assert_eq!(out.ncols(), 7);
    }

    #[test]
    fn test_spline_insufficient_samples_error() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        let x = array![[1.0]];
        assert!(st.fit(&x, &()).is_err());
    }

    #[test]
    fn test_spline_too_few_knots_error() {
        let st = SplineTransformer::<f64>::new(1, 3, KnotStrategy::Uniform);
        let x = array![[0.0], [1.0]];
        assert!(st.fit(&x, &()).is_err());
    }

    #[test]
    fn test_spline_zero_degree_allowed() -> Result<(), FerroError> {
        // sklearn allows degree==0 (piecewise-constant B-spline):
        // `_parameter_constraints` `degree: Interval(Integral, 0, None,
        // closed="left")` (`_polynomial.py:705`). degree==0 must fit, not error.
        let st = SplineTransformer::<f64>::new(5, 0, KnotStrategy::Uniform);
        let x = array![[0.0], [1.0]];
        let fitted = st.fit(&x, &())?;
        // n_basis = n_knots + degree - 1 = 5 + 0 - 1 = 4
        let out = fitted.transform(&x)?;
        assert_eq!(out.ncols(), 4);
        Ok(())
    }

    #[test]
    fn test_spline_shape_mismatch_error() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        let x_train = array![[0.0, 1.0], [0.5, 1.5]];
        let fitted = st.fit(&x_train, &()).unwrap();
        let x_bad = array![[0.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_spline_unfitted_error() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        let x = array![[0.0]];
        assert!(st.transform(&x).is_err());
    }

    #[test]
    fn test_spline_default() {
        let st = SplineTransformer::<f64>::default();
        assert_eq!(st.n_knots(), 5);
        assert_eq!(st.degree(), 3);
        assert_eq!(st.knot_strategy(), KnotStrategy::Uniform);
    }

    #[test]
    fn test_spline_degree1() {
        // Linear splines: should produce piecewise linear basis
        let st = SplineTransformer::<f64>::new(3, 1, KnotStrategy::Uniform);
        let x = array![[0.0], [0.5], [1.0]];
        let fitted = st.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // n_basis = 3 + 1 - 1 = 3
        assert_eq!(out.ncols(), 3);
        // Partition of unity
        for i in 0..out.nrows() {
            let row_sum: f64 = out.row(i).iter().sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_spline_degree1_identity_at_base_knots() {
        // Degree-1 basis functions are hats peaking at one base knot each, so
        // evaluating exactly at the base knots gives the identity matrix
        // (including the closed right end of the base interval).
        let st = SplineTransformer::<f64>::new(3, 1, KnotStrategy::Uniform);
        let x = array![[0.0], [0.5], [1.0]];
        let out = st.fit_transform(&x).unwrap();
        for i in 0..3 {
            for k in 0..3 {
                let expected = if i == k { 1.0 } else { 0.0 };
                assert_abs_diff_eq!(out[[i, k]], expected, epsilon = 1e-12);
            }
        }
    }

    #[test]
    fn test_spline_knot_vector_layout() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        let x = array![[0.0], [0.25], [0.5], [0.75], [1.0]];
        let fitted = st.fit(&x, &()).unwrap();
        let knots = &fitted.knot_vectors()[0];
        // n_knots base knots plus `degree` extension knots on each side.
        assert_eq!(knots.len(), 5 + 2 * 3);
        // Base interval ends coincide with the fit min/max.
        assert_eq!(knots[3], 0.0);
        assert_eq!(knots[7], 1.0);
        // Extended (not clamped) construction: strictly increasing knots.
        assert!(knots.windows(2).all(|w| w[0] < w[1]));
        assert_eq!(fitted.n_basis_per_feature(), 7);
        assert_eq!(fitted.n_output_features(), 7);
    }

    #[test]
    fn test_spline_constant_extrapolation_clamps_to_boundary() {
        let st = SplineTransformer::<f64>::new(4, 3, KnotStrategy::Uniform);
        let fitted = st.fit(&array![[0.0], [1.0], [2.0]], &()).unwrap();
        assert_eq!(fitted.extrapolation(), Extrapolation::Constant);
        let out = fitted
            .transform(&array![[-3.0], [0.0], [2.0], [7.5]])
            .unwrap();
        // Below-range rows equal the row at xmin; above-range rows the row at xmax.
        assert_eq!(out.row(0), out.row(1));
        assert_eq!(out.row(2), out.row(3));
    }

    #[test]
    fn test_spline_rejects_non_finite_input() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        assert!(st.fit(&array![[0.0], [f64::NAN]], &()).is_err());
        let fitted = st.fit(&array![[0.0], [1.0]], &()).unwrap();
        assert!(fitted.transform(&array![[f64::INFINITY]]).is_err());
        assert!(fitted.transform(&array![[f64::NAN]]).is_err());
    }

    #[test]
    fn test_spline_unimplemented_extrapolation_errors() {
        let x = array![[0.0], [0.5], [1.0]];
        for mode in [
            Extrapolation::Linear,
            Extrapolation::Continue,
            Extrapolation::Periodic,
            Extrapolation::Error,
        ] {
            let st =
                SplineTransformer::<f64>::with_extrapolation(5, 3, KnotStrategy::Uniform, mode);
            assert_eq!(st.extrapolation(), mode);
            let fitted = st.fit(&x, &()).unwrap();
            assert_eq!(fitted.extrapolation(), mode);
            assert!(fitted.transform(&x).is_err());
        }
    }
}