Skip to main content

ferrolearn_preprocess/
spline_transformer.rs

1//! Spline transformer: generate B-spline basis functions for each feature.
2//!
3//! [`SplineTransformer`] expands each input feature into a set of B-spline
4//! basis columns. This is a nonlinear feature expansion technique that
5//! represents each feature as a combination of piecewise polynomial functions.
6//!
7//! # Knot Placement
8//!
9//! - [`KnotStrategy::Uniform`] — knots are evenly spaced between min and max.
10//! - [`KnotStrategy::Quantile`] — knots are placed at quantiles of the data.
11
12use ferrolearn_core::error::FerroError;
13use ferrolearn_core::traits::{Fit, FitTransform, Transform};
14use ndarray::Array2;
15use num_traits::Float;
16
17// ---------------------------------------------------------------------------
18// KnotStrategy
19// ---------------------------------------------------------------------------
20
/// How knot positions are chosen for each feature when fitting a
/// [`SplineTransformer`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum KnotStrategy {
    /// Equally spaced knots running from the feature's minimum to its maximum.
    Uniform,
    /// Knots taken at evenly spaced quantiles of the feature's observed values.
    Quantile,
}
29
30// ---------------------------------------------------------------------------
31// SplineTransformer (unfitted)
32// ---------------------------------------------------------------------------
33
34/// An unfitted spline transformer.
35///
36/// Calling [`Fit::fit`] computes the knot positions and returns a
37/// [`FittedSplineTransformer`] that generates B-spline basis functions.
38///
39/// # Parameters
40///
41/// - `n_knots` — number of interior knots (default 5).
42/// - `degree` — degree of the B-spline (default 3, i.e. cubic).
43/// - `knots` — knot placement strategy (default `Uniform`).
44///
45/// The number of output columns per feature is `n_knots + degree - 1`.
46///
47/// # Examples
48///
49/// ```
50/// use ferrolearn_preprocess::spline_transformer::{SplineTransformer, KnotStrategy};
51/// use ferrolearn_core::traits::{Fit, Transform};
52/// use ndarray::array;
53///
54/// let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
55/// let x = array![[0.0], [0.25], [0.5], [0.75], [1.0]];
56/// let fitted = st.fit(&x, &()).unwrap();
57/// let out = fitted.transform(&x).unwrap();
58/// // 5 + 3 - 1 = 7 basis columns per feature
59/// assert_eq!(out.ncols(), 7);
60/// ```
61#[must_use]
62#[derive(Debug, Clone)]
63pub struct SplineTransformer<F> {
64    /// Number of interior knots.
65    n_knots: usize,
66    /// Degree of the B-spline.
67    degree: usize,
68    /// Knot placement strategy.
69    knots: KnotStrategy,
70    _marker: std::marker::PhantomData<F>,
71}
72
73impl<F: Float + Send + Sync + 'static> SplineTransformer<F> {
74    /// Create a new `SplineTransformer`.
75    pub fn new(n_knots: usize, degree: usize, knots: KnotStrategy) -> Self {
76        Self {
77            n_knots,
78            degree,
79            knots,
80            _marker: std::marker::PhantomData,
81        }
82    }
83
84    /// Return the number of interior knots.
85    #[must_use]
86    pub fn n_knots(&self) -> usize {
87        self.n_knots
88    }
89
90    /// Return the B-spline degree.
91    #[must_use]
92    pub fn degree(&self) -> usize {
93        self.degree
94    }
95
96    /// Return the knot placement strategy.
97    #[must_use]
98    pub fn knot_strategy(&self) -> KnotStrategy {
99        self.knots
100    }
101}
102
103impl<F: Float + Send + Sync + 'static> Default for SplineTransformer<F> {
104    fn default() -> Self {
105        Self::new(5, 3, KnotStrategy::Uniform)
106    }
107}
108
109// ---------------------------------------------------------------------------
110// FittedSplineTransformer
111// ---------------------------------------------------------------------------
112
/// Per-feature knot vectors learned by [`Fit::fit`] on a
/// [`SplineTransformer`]; transforming with it emits the B-spline features.
#[derive(Clone, Debug)]
pub struct FittedSplineTransformer<F> {
    /// One full knot vector per input feature: `degree` copies of the feature
    /// minimum, the placed knots, then `degree` copies of the feature maximum.
    knot_vectors: Vec<Vec<F>>,
    /// Polynomial degree shared by every basis function.
    degree: usize,
    /// Basis columns emitted for each feature (`n_knots + degree - 1`).
    n_basis: usize,
}
125
126impl<F: Float + Send + Sync + 'static> FittedSplineTransformer<F> {
127    /// Return the knot vectors.
128    #[must_use]
129    pub fn knot_vectors(&self) -> &[Vec<F>] {
130        &self.knot_vectors
131    }
132
133    /// Return the number of basis functions per feature.
134    #[must_use]
135    pub fn n_basis_per_feature(&self) -> usize {
136        self.n_basis
137    }
138
139    /// Return the total number of output columns.
140    #[must_use]
141    pub fn n_output_features(&self) -> usize {
142        self.knot_vectors.len() * self.n_basis
143    }
144}
145
146// ---------------------------------------------------------------------------
147// B-spline evaluation (Cox-de Boor recursion)
148// ---------------------------------------------------------------------------
149
150/// Evaluate all B-spline basis functions at a given value `x` using the
151/// Cox-de Boor recursion.
152///
153/// `knots` is the full knot vector of length `n_basis + degree + 1`.
154/// Returns a vector of length `n_basis` containing the basis values.
155fn bspline_basis<F: Float>(x: F, knots: &[F], degree: usize, n_basis: usize) -> Vec<F> {
156    // Start with degree-0 basis functions
157    let n_intervals = knots.len() - 1;
158    let mut basis = vec![F::zero(); n_intervals];
159
160    // Degree 0: indicator functions using half-open intervals [t_i, t_{i+1}).
161    // Special case: when x equals the rightmost knot, assign it to the last
162    // non-degenerate interval so the Cox-de Boor recursion can propagate
163    // the value up through knot spans with non-zero width.
164    let last_knot = knots[knots.len() - 1];
165    if x >= last_knot {
166        // Find the last interval with non-zero width and activate it
167        let mut found = false;
168        for i in (0..n_intervals).rev() {
169            if knots[i] < knots[i + 1] {
170                basis[i] = F::one();
171                found = true;
172                break;
173            }
174        }
175        // Fallback: if all intervals are degenerate, activate the last one
176        if !found {
177            basis[n_intervals - 1] = F::one();
178        }
179    } else {
180        for i in 0..n_intervals {
181            // Half-open: [t_i, t_{i+1})
182            basis[i] = if x >= knots[i] && x < knots[i + 1] {
183                F::one()
184            } else {
185                F::zero()
186            };
187        }
188    }
189
190    // Build up to the desired degree
191    for d in 1..=degree {
192        let n_current = n_intervals - d;
193        let mut new_basis = vec![F::zero(); n_current];
194        for i in 0..n_current {
195            let denom1 = knots[i + d] - knots[i];
196            let denom2 = knots[i + d + 1] - knots[i + 1];
197
198            let left = if denom1 > F::zero() {
199                (x - knots[i]) / denom1 * basis[i]
200            } else {
201                F::zero()
202            };
203
204            let right = if denom2 > F::zero() {
205                (knots[i + d + 1] - x) / denom2 * basis[i + 1]
206            } else {
207                F::zero()
208            };
209
210            new_basis[i] = left + right;
211        }
212        basis = new_basis;
213    }
214
215    // Truncate or pad to n_basis
216    basis.truncate(n_basis);
217    while basis.len() < n_basis {
218        basis.push(F::zero());
219    }
220
221    basis
222}
223
224// ---------------------------------------------------------------------------
225// Trait implementations
226// ---------------------------------------------------------------------------
227
228impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for SplineTransformer<F> {
229    type Fitted = FittedSplineTransformer<F>;
230    type Error = FerroError;
231
232    /// Fit by computing knot positions for each feature.
233    ///
234    /// # Errors
235    ///
236    /// - [`FerroError::InsufficientSamples`] if the input has fewer than 2 rows.
237    /// - [`FerroError::InvalidParameter`] if `n_knots` < 2 or `degree` is 0.
238    fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedSplineTransformer<F>, FerroError> {
239        let n_samples = x.nrows();
240        if n_samples < 2 {
241            return Err(FerroError::InsufficientSamples {
242                required: 2,
243                actual: n_samples,
244                context: "SplineTransformer::fit".into(),
245            });
246        }
247        if self.n_knots < 2 {
248            return Err(FerroError::InvalidParameter {
249                name: "n_knots".into(),
250                reason: "n_knots must be at least 2".into(),
251            });
252        }
253        if self.degree == 0 {
254            return Err(FerroError::InvalidParameter {
255                name: "degree".into(),
256                reason: "degree must be at least 1".into(),
257            });
258        }
259
260        let n_features = x.ncols();
261        let n_basis = self.n_knots + self.degree - 1;
262        let mut knot_vectors = Vec::with_capacity(n_features);
263
264        for j in 0..n_features {
265            let mut col_vals: Vec<F> = x.column(j).iter().copied().collect();
266            col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
267
268            let min_val = col_vals[0];
269            let max_val = col_vals[col_vals.len() - 1];
270
271            // Compute interior knots
272            let interior_knots: Vec<F> = match self.knots {
273                KnotStrategy::Uniform => (0..self.n_knots)
274                    .map(|i| {
275                        min_val
276                            + (max_val - min_val) * F::from(i).unwrap()
277                                / F::from(self.n_knots - 1).unwrap()
278                    })
279                    .collect(),
280                KnotStrategy::Quantile => {
281                    let n = col_vals.len();
282                    (0..self.n_knots)
283                        .map(|i| {
284                            let frac =
285                                F::from(i).unwrap() / F::from(self.n_knots - 1).unwrap_or(F::one());
286                            let pos = frac * F::from(n.saturating_sub(1)).unwrap();
287                            let lo = pos.floor().to_usize().unwrap_or(0).min(n - 1);
288                            let hi = pos.ceil().to_usize().unwrap_or(0).min(n - 1);
289                            let f = pos - F::from(lo).unwrap();
290                            col_vals[lo] * (F::one() - f) + col_vals[hi] * f
291                        })
292                        .collect()
293                }
294            };
295
296            // Build full knot vector with boundary knots of multiplicity (degree + 1)
297            let mut full_knots = Vec::new();
298            // Left boundary knots
299            for _ in 0..self.degree {
300                full_knots.push(min_val);
301            }
302            // Interior knots
303            full_knots.extend_from_slice(&interior_knots);
304            // Right boundary knots
305            for _ in 0..self.degree {
306                full_knots.push(max_val);
307            }
308
309            knot_vectors.push(full_knots);
310        }
311
312        Ok(FittedSplineTransformer {
313            knot_vectors,
314            degree: self.degree,
315            n_basis,
316        })
317    }
318}
319
320impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSplineTransformer<F> {
321    type Output = Array2<F>;
322    type Error = FerroError;
323
324    /// Generate B-spline basis functions for each feature.
325    ///
326    /// # Errors
327    ///
328    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
329    /// from the number of features seen during fitting.
330    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
331        let n_features = self.knot_vectors.len();
332        if x.ncols() != n_features {
333            return Err(FerroError::ShapeMismatch {
334                expected: vec![x.nrows(), n_features],
335                actual: vec![x.nrows(), x.ncols()],
336                context: "FittedSplineTransformer::transform".into(),
337            });
338        }
339
340        let n_samples = x.nrows();
341        let n_out = n_features * self.n_basis;
342        let mut out = Array2::zeros((n_samples, n_out));
343
344        for j in 0..n_features {
345            let knots = &self.knot_vectors[j];
346            let col_offset = j * self.n_basis;
347
348            for i in 0..n_samples {
349                let val = x[[i, j]];
350                let basis_vals = bspline_basis(val, knots, self.degree, self.n_basis);
351                for (k, &bv) in basis_vals.iter().enumerate() {
352                    out[[i, col_offset + k]] = bv;
353                }
354            }
355        }
356
357        Ok(out)
358    }
359}
360
361/// Implement `Transform` on the unfitted transformer.
362impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for SplineTransformer<F> {
363    type Output = Array2<F>;
364    type Error = FerroError;
365
366    /// Always returns an error — the transformer must be fitted first.
367    fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
368        Err(FerroError::InvalidParameter {
369            name: "SplineTransformer".into(),
370            reason: "transformer must be fitted before calling transform; use fit() first".into(),
371        })
372    }
373}
374
375impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for SplineTransformer<F> {
376    type FitError = FerroError;
377
378    /// Fit and transform in one step.
379    ///
380    /// # Errors
381    ///
382    /// Returns an error if fitting fails.
383    fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
384        let fitted = self.fit(x, &())?;
385        fitted.transform(x)
386    }
387}
388
389// ---------------------------------------------------------------------------
390// Tests
391// ---------------------------------------------------------------------------
392
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Fit `st` on `x` and return the transformed matrix.
    fn fit_and_transform(st: &SplineTransformer<f64>, x: &Array2<f64>) -> Array2<f64> {
        st.fit(x, &()).unwrap().transform(x).unwrap()
    }

    /// Assert every row sums to one (B-spline partition of unity).
    fn assert_rows_sum_to_one(out: &Array2<f64>) {
        for i in 0..out.nrows() {
            let total: f64 = out.row(i).iter().sum();
            assert_abs_diff_eq!(total, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_spline_output_dimensions() {
        let x = array![[0.0], [0.25], [0.5], [0.75], [1.0]];
        let out = fit_and_transform(&SplineTransformer::new(5, 3, KnotStrategy::Uniform), &x);
        // n_knots + degree - 1 = 7 columns, one row per sample.
        assert_eq!(out.dim(), (5, 7));
    }

    #[test]
    fn test_spline_partition_of_unity() {
        let x = array![[0.0], [0.25], [0.5], [0.75], [1.0]];
        let out = fit_and_transform(&SplineTransformer::new(5, 3, KnotStrategy::Uniform), &x);
        assert_rows_sum_to_one(&out);
    }

    #[test]
    fn test_spline_non_negative() {
        let x = array![[0.0], [0.1], [0.5], [0.9], [1.0]];
        let out = fit_and_transform(&SplineTransformer::new(5, 3, KnotStrategy::Uniform), &x);
        for &v in out.iter() {
            assert!(v >= -1e-10, "Basis value should be non-negative, got {}", v);
        }
    }

    #[test]
    fn test_spline_quantile_knots() {
        let x = array![[0.0], [0.1], [0.2], [0.5], [1.0]];
        let out = fit_and_transform(&SplineTransformer::new(5, 3, KnotStrategy::Quantile), &x);
        assert_eq!(out.ncols(), 7);
        assert_rows_sum_to_one(&out);
    }

    #[test]
    fn test_spline_multi_feature() {
        let x = array![[0.0, 10.0], [0.5, 15.0], [1.0, 20.0]];
        let out = fit_and_transform(&SplineTransformer::new(3, 2, KnotStrategy::Uniform), &x);
        // Two features, each with 3 + 2 - 1 = 4 basis columns.
        assert_eq!(out.ncols(), 2 * 4);
    }

    #[test]
    fn test_spline_fit_transform() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        let out = st.fit_transform(&array![[0.0], [0.5], [1.0]]).unwrap();
        assert_eq!(out.ncols(), 7);
    }

    #[test]
    fn test_spline_insufficient_samples_error() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        assert!(st.fit(&array![[1.0]], &()).is_err());
    }

    #[test]
    fn test_spline_too_few_knots_error() {
        let st = SplineTransformer::<f64>::new(1, 3, KnotStrategy::Uniform);
        assert!(st.fit(&array![[0.0], [1.0]], &()).is_err());
    }

    #[test]
    fn test_spline_zero_degree_error() {
        let st = SplineTransformer::<f64>::new(5, 0, KnotStrategy::Uniform);
        assert!(st.fit(&array![[0.0], [1.0]], &()).is_err());
    }

    #[test]
    fn test_spline_shape_mismatch_error() {
        let fitted = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform)
            .fit(&array![[0.0, 1.0], [0.5, 1.5]], &())
            .unwrap();
        assert!(fitted.transform(&array![[0.0]]).is_err());
    }

    #[test]
    fn test_spline_unfitted_error() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        assert!(st.transform(&array![[0.0]]).is_err());
    }

    #[test]
    fn test_spline_default() {
        let st = SplineTransformer::<f64>::default();
        assert_eq!(
            (st.n_knots(), st.degree(), st.knot_strategy()),
            (5, 3, KnotStrategy::Uniform)
        );
    }

    #[test]
    fn test_spline_degree1() {
        // Linear splines: 3 + 1 - 1 = 3 piecewise-linear basis columns.
        let x = array![[0.0], [0.5], [1.0]];
        let out = fit_and_transform(&SplineTransformer::new(3, 1, KnotStrategy::Uniform), &x);
        assert_eq!(out.ncols(), 3);
        assert_rows_sum_to_one(&out);
    }
}