1use ferrolearn_core::error::FerroError;
13use ferrolearn_core::traits::{Fit, FitTransform, Transform};
14use ndarray::Array2;
15use num_traits::Float;
16
/// How the knot positions of a [`SplineTransformer`] are chosen for each feature.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KnotStrategy {
    /// Knots evenly spaced between the column minimum and the column maximum.
    Uniform,
    /// Knots placed at evenly spaced quantiles of the column values, linearly
    /// interpolating between neighbouring sorted samples.
    Quantile,
}
29
30#[must_use]
62#[derive(Debug, Clone)]
63pub struct SplineTransformer<F> {
64 n_knots: usize,
66 degree: usize,
68 knots: KnotStrategy,
70 _marker: std::marker::PhantomData<F>,
71}
72
73impl<F: Float + Send + Sync + 'static> SplineTransformer<F> {
74 pub fn new(n_knots: usize, degree: usize, knots: KnotStrategy) -> Self {
76 Self {
77 n_knots,
78 degree,
79 knots,
80 _marker: std::marker::PhantomData,
81 }
82 }
83
84 #[must_use]
86 pub fn n_knots(&self) -> usize {
87 self.n_knots
88 }
89
90 #[must_use]
92 pub fn degree(&self) -> usize {
93 self.degree
94 }
95
96 #[must_use]
98 pub fn knot_strategy(&self) -> KnotStrategy {
99 self.knots
100 }
101}
102
103impl<F: Float + Send + Sync + 'static> Default for SplineTransformer<F> {
104 fn default() -> Self {
105 Self::new(5, 3, KnotStrategy::Uniform)
106 }
107}
108
/// Per-feature knot vectors learned by `SplineTransformer::fit`, ready to transform data.
#[derive(Debug, Clone)]
pub struct FittedSplineTransformer<F> {
    /// One full knot vector per input feature, in column order: `degree` copies of the
    /// column minimum, the `n_knots` placed knots, then `degree` copies of the maximum.
    knot_vectors: Vec<Vec<F>>,
    /// Polynomial degree of the spline pieces.
    degree: usize,
    /// Number of basis functions (output columns) generated per input feature.
    n_basis: usize,
}
125
126impl<F: Float + Send + Sync + 'static> FittedSplineTransformer<F> {
127 #[must_use]
129 pub fn knot_vectors(&self) -> &[Vec<F>] {
130 &self.knot_vectors
131 }
132
133 #[must_use]
135 pub fn n_basis_per_feature(&self) -> usize {
136 self.n_basis
137 }
138
139 #[must_use]
141 pub fn n_output_features(&self) -> usize {
142 self.knot_vectors.len() * self.n_basis
143 }
144}
145
146fn bspline_basis<F: Float>(x: F, knots: &[F], degree: usize, n_basis: usize) -> Vec<F> {
156 let n_intervals = knots.len() - 1;
158 let mut basis = vec![F::zero(); n_intervals];
159
160 let last_knot = knots[knots.len() - 1];
165 if x >= last_knot {
166 let mut found = false;
168 for i in (0..n_intervals).rev() {
169 if knots[i] < knots[i + 1] {
170 basis[i] = F::one();
171 found = true;
172 break;
173 }
174 }
175 if !found {
177 basis[n_intervals - 1] = F::one();
178 }
179 } else {
180 for i in 0..n_intervals {
181 basis[i] = if x >= knots[i] && x < knots[i + 1] {
183 F::one()
184 } else {
185 F::zero()
186 };
187 }
188 }
189
190 for d in 1..=degree {
192 let n_current = n_intervals - d;
193 let mut new_basis = vec![F::zero(); n_current];
194 for i in 0..n_current {
195 let denom1 = knots[i + d] - knots[i];
196 let denom2 = knots[i + d + 1] - knots[i + 1];
197
198 let left = if denom1 > F::zero() {
199 (x - knots[i]) / denom1 * basis[i]
200 } else {
201 F::zero()
202 };
203
204 let right = if denom2 > F::zero() {
205 (knots[i + d + 1] - x) / denom2 * basis[i + 1]
206 } else {
207 F::zero()
208 };
209
210 new_basis[i] = left + right;
211 }
212 basis = new_basis;
213 }
214
215 basis.truncate(n_basis);
217 while basis.len() < n_basis {
218 basis.push(F::zero());
219 }
220
221 basis
222}
223
224impl<F: Float + Send + Sync + 'static> Fit<Array2<F>, ()> for SplineTransformer<F> {
229 type Fitted = FittedSplineTransformer<F>;
230 type Error = FerroError;
231
232 fn fit(&self, x: &Array2<F>, _y: &()) -> Result<FittedSplineTransformer<F>, FerroError> {
239 let n_samples = x.nrows();
240 if n_samples < 2 {
241 return Err(FerroError::InsufficientSamples {
242 required: 2,
243 actual: n_samples,
244 context: "SplineTransformer::fit".into(),
245 });
246 }
247 if self.n_knots < 2 {
248 return Err(FerroError::InvalidParameter {
249 name: "n_knots".into(),
250 reason: "n_knots must be at least 2".into(),
251 });
252 }
253 if self.degree == 0 {
254 return Err(FerroError::InvalidParameter {
255 name: "degree".into(),
256 reason: "degree must be at least 1".into(),
257 });
258 }
259
260 let n_features = x.ncols();
261 let n_basis = self.n_knots + self.degree - 1;
262 let mut knot_vectors = Vec::with_capacity(n_features);
263
264 for j in 0..n_features {
265 let mut col_vals: Vec<F> = x.column(j).iter().copied().collect();
266 col_vals.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
267
268 let min_val = col_vals[0];
269 let max_val = col_vals[col_vals.len() - 1];
270
271 let interior_knots: Vec<F> = match self.knots {
273 KnotStrategy::Uniform => (0..self.n_knots)
274 .map(|i| {
275 min_val
276 + (max_val - min_val) * F::from(i).unwrap()
277 / F::from(self.n_knots - 1).unwrap()
278 })
279 .collect(),
280 KnotStrategy::Quantile => {
281 let n = col_vals.len();
282 (0..self.n_knots)
283 .map(|i| {
284 let frac =
285 F::from(i).unwrap() / F::from(self.n_knots - 1).unwrap_or(F::one());
286 let pos = frac * F::from(n.saturating_sub(1)).unwrap();
287 let lo = pos.floor().to_usize().unwrap_or(0).min(n - 1);
288 let hi = pos.ceil().to_usize().unwrap_or(0).min(n - 1);
289 let f = pos - F::from(lo).unwrap();
290 col_vals[lo] * (F::one() - f) + col_vals[hi] * f
291 })
292 .collect()
293 }
294 };
295
296 let mut full_knots = Vec::new();
298 for _ in 0..self.degree {
300 full_knots.push(min_val);
301 }
302 full_knots.extend_from_slice(&interior_knots);
304 for _ in 0..self.degree {
306 full_knots.push(max_val);
307 }
308
309 knot_vectors.push(full_knots);
310 }
311
312 Ok(FittedSplineTransformer {
313 knot_vectors,
314 degree: self.degree,
315 n_basis,
316 })
317 }
318}
319
320impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSplineTransformer<F> {
321 type Output = Array2<F>;
322 type Error = FerroError;
323
324 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
331 let n_features = self.knot_vectors.len();
332 if x.ncols() != n_features {
333 return Err(FerroError::ShapeMismatch {
334 expected: vec![x.nrows(), n_features],
335 actual: vec![x.nrows(), x.ncols()],
336 context: "FittedSplineTransformer::transform".into(),
337 });
338 }
339
340 let n_samples = x.nrows();
341 let n_out = n_features * self.n_basis;
342 let mut out = Array2::zeros((n_samples, n_out));
343
344 for j in 0..n_features {
345 let knots = &self.knot_vectors[j];
346 let col_offset = j * self.n_basis;
347
348 for i in 0..n_samples {
349 let val = x[[i, j]];
350 let basis_vals = bspline_basis(val, knots, self.degree, self.n_basis);
351 for (k, &bv) in basis_vals.iter().enumerate() {
352 out[[i, col_offset + k]] = bv;
353 }
354 }
355 }
356
357 Ok(out)
358 }
359}
360
361impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for SplineTransformer<F> {
363 type Output = Array2<F>;
364 type Error = FerroError;
365
366 fn transform(&self, _x: &Array2<F>) -> Result<Array2<F>, FerroError> {
368 Err(FerroError::InvalidParameter {
369 name: "SplineTransformer".into(),
370 reason: "transformer must be fitted before calling transform; use fit() first".into(),
371 })
372 }
373}
374
375impl<F: Float + Send + Sync + 'static> FitTransform<Array2<F>> for SplineTransformer<F> {
376 type FitError = FerroError;
377
378 fn fit_transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
384 let fitted = self.fit(x, &())?;
385 fitted.transform(x)
386 }
387}
388
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn test_spline_output_dimensions() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        let x = array![[0.0], [0.25], [0.5], [0.75], [1.0]];
        let fitted = st.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // n_knots + degree - 1 = 5 + 3 - 1 basis columns for the single feature.
        assert_eq!(out.ncols(), 7);
        assert_eq!(out.nrows(), 5);
    }

    #[test]
    fn test_spline_partition_of_unity() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        // Includes both endpoints of the fitted range.
        let x = array![[0.0], [0.25], [0.5], [0.75], [1.0]];
        let fitted = st.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        for i in 0..out.nrows() {
            let row_sum: f64 = out.row(i).iter().sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_spline_non_negative() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        let x = array![[0.0], [0.1], [0.5], [0.9], [1.0]];
        let fitted = st.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        for v in out.iter() {
            assert!(
                *v >= -1e-10,
                "Basis value should be non-negative, got {}",
                v
            );
        }
    }

    #[test]
    fn test_spline_quantile_knots() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Quantile);
        let x = array![[0.0], [0.1], [0.2], [0.5], [1.0]];
        let fitted = st.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 7);
        for i in 0..out.nrows() {
            // Quantile knots must also yield a partition of unity.
            let row_sum: f64 = out.row(i).iter().sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_spline_multi_feature() {
        let st = SplineTransformer::<f64>::new(3, 2, KnotStrategy::Uniform);
        let x = array![[0.0, 10.0], [0.5, 15.0], [1.0, 20.0]];
        let fitted = st.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        // Two features, each with 3 + 2 - 1 = 4 basis columns.
        assert_eq!(out.ncols(), 8);
    }

    #[test]
    fn test_spline_fit_transform() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        let x = array![[0.0], [0.5], [1.0]];
        let out = st.fit_transform(&x).unwrap();
        assert_eq!(out.ncols(), 7);
    }

    #[test]
    fn test_spline_insufficient_samples_error() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        let x = array![[1.0]];
        assert!(st.fit(&x, &()).is_err());
    }

    #[test]
    fn test_spline_too_few_knots_error() {
        let st = SplineTransformer::<f64>::new(1, 3, KnotStrategy::Uniform);
        let x = array![[0.0], [1.0]];
        assert!(st.fit(&x, &()).is_err());
    }

    #[test]
    fn test_spline_zero_degree_error() {
        let st = SplineTransformer::<f64>::new(5, 0, KnotStrategy::Uniform);
        let x = array![[0.0], [1.0]];
        assert!(st.fit(&x, &()).is_err());
    }

    #[test]
    fn test_spline_shape_mismatch_error() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        let x_train = array![[0.0, 1.0], [0.5, 1.5]];
        let fitted = st.fit(&x_train, &()).unwrap();
        let x_bad = array![[0.0]];
        assert!(fitted.transform(&x_bad).is_err());
    }

    #[test]
    fn test_spline_unfitted_error() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Uniform);
        let x = array![[0.0]];
        assert!(st.transform(&x).is_err());
    }

    #[test]
    fn test_spline_default() {
        let st = SplineTransformer::<f64>::default();
        assert_eq!(st.n_knots(), 5);
        assert_eq!(st.degree(), 3);
        assert_eq!(st.knot_strategy(), KnotStrategy::Uniform);
    }

    #[test]
    fn test_spline_degree1() {
        let st = SplineTransformer::<f64>::new(3, 1, KnotStrategy::Uniform);
        // Degree-1 splines are piecewise-linear "hat" functions.
        let x = array![[0.0], [0.5], [1.0]];
        let fitted = st.fit(&x, &()).unwrap();
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.ncols(), 3);
        for i in 0..out.nrows() {
            let row_sum: f64 = out.row(i).iter().sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_fitted_accessors_and_clamped_knots() {
        let st = SplineTransformer::<f64>::new(4, 2, KnotStrategy::Uniform);
        let x = array![[0.0, 5.0], [1.0, 10.0], [2.0, 7.0]];
        let fitted = st.fit(&x, &()).unwrap();

        // 4 + 2 - 1 = 5 basis columns per feature, two features.
        assert_eq!(fitted.n_basis_per_feature(), 5);
        assert_eq!(fitted.n_output_features(), 10);
        assert_eq!(fitted.knot_vectors().len(), 2);

        for (kv, (lo, hi)) in fitted.knot_vectors().iter().zip([(0.0, 2.0), (5.0, 10.0)]) {
            // degree padding + n_knots placed knots + degree padding.
            assert_eq!(kv.len(), 4 + 2 * 2);
            // Each boundary appears degree + 1 times (clamped knot vector).
            for &k in &kv[..3] {
                assert_abs_diff_eq!(k, lo, epsilon = 1e-12);
            }
            for &k in &kv[kv.len() - 3..] {
                assert_abs_diff_eq!(k, hi, epsilon = 1e-12);
            }
            assert!(kv.windows(2).all(|w| w[0] <= w[1]), "knots must be non-decreasing");
        }
    }

    #[test]
    fn test_spline_degree1_hat_values() {
        let st = SplineTransformer::<f64>::new(3, 1, KnotStrategy::Uniform);
        let fitted = st.fit(&array![[0.0], [0.5], [1.0]], &()).unwrap();
        let out = fitted.transform(&array![[0.0], [0.25], [1.0]]).unwrap();
        // Knots [0, 0, 0.5, 1, 1]: hats peak at 0, 0.5 and 1.
        let expected = array![[1.0, 0.0, 0.0], [0.5, 0.5, 0.0], [0.0, 0.0, 1.0]];
        for (a, e) in out.iter().zip(expected.iter()) {
            assert_abs_diff_eq!(*a, *e, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_spline_partition_of_unity_on_unseen_in_range_data() {
        let st = SplineTransformer::<f64>::new(5, 3, KnotStrategy::Quantile);
        let fitted = st.fit(&array![[0.0], [1.0], [2.0], [3.0]], &()).unwrap();
        let out = fitted.transform(&array![[0.3], [1.7], [2.9]]).unwrap();
        for i in 0..out.nrows() {
            let row_sum: f64 = out.row(i).iter().sum();
            assert_abs_diff_eq!(row_sum, 1.0, epsilon = 1e-10);
        }
    }
}
530}