ferrolearn_preprocess/column_transformer.rs
1//! Column transformer: apply different transformers to different column subsets.
2//!
3//! [`ColumnTransformer`] applies each registered transformer to its designated
4//! column subset, then horizontally concatenates the outputs into a single
5//! `Array2<f64>`. Columns not captured by any transformer can be dropped or
6//! passed through unchanged via the [`Remainder`] policy.
7//!
8//! # Examples
9//!
10//! ```
11//! use ferrolearn_preprocess::column_transformer::{
12//! ColumnSelector, ColumnTransformer, Remainder,
13//! };
14//! use ferrolearn_preprocess::{StandardScaler, MinMaxScaler};
15//! use ferrolearn_core::Fit;
16//! use ferrolearn_core::Transform;
17//! use ndarray::array;
18//!
19//! let x = array![
20//! [1.0_f64, 2.0, 10.0, 20.0],
21//! [3.0, 4.0, 30.0, 40.0],
22//! [5.0, 6.0, 50.0, 60.0],
23//! ];
24//!
25//! let ct = ColumnTransformer::new(
26//! vec![
27//! ("std".into(), Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1])),
28//! ("mm".into(), Box::new(MinMaxScaler::<f64>::new()), ColumnSelector::Indices(vec![2, 3])),
29//! ],
30//! Remainder::Drop,
31//! );
32//!
33//! let fitted = ct.fit(&x, &()).unwrap();
34//! let out = fitted.transform(&x).unwrap();
35//! assert_eq!(out.ncols(), 4);
36//! assert_eq!(out.nrows(), 3);
37//! ```
38
39use ferrolearn_core::error::FerroError;
40use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
41use ferrolearn_core::traits::{Fit, Transform};
42use ndarray::{Array1, Array2};
43
44// ---------------------------------------------------------------------------
45// ColumnSelector
46// ---------------------------------------------------------------------------
47
/// Specifies which columns a transformer should operate on.
///
/// Currently the only supported variant is [`Indices`](ColumnSelector::Indices),
/// which selects columns by their zero-based integer positions.
///
/// Selectors are plain values: they can be cloned and compared for equality
/// (e.g. to check a configuration in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ColumnSelector {
    /// Select columns by zero-based index.
    ///
    /// The indices do not need to be sorted, but every index must be strictly
    /// less than the number of columns in the input matrix. Duplicate indices
    /// are allowed; the same column will simply appear twice in the sub-matrix
    /// passed to the transformer.
    Indices(Vec<usize>),
}
62
63impl ColumnSelector {
64 /// Resolve the selector to a concrete list of column indices.
65 ///
66 /// # Errors
67 ///
68 /// Returns [`FerroError::InvalidParameter`] if any index is out of range
69 /// (i.e., `>= n_features`).
70 fn resolve(&self, n_features: usize) -> Result<Vec<usize>, FerroError> {
71 match self {
72 ColumnSelector::Indices(indices) => {
73 for &idx in indices {
74 if idx >= n_features {
75 return Err(FerroError::InvalidParameter {
76 name: "ColumnSelector::Indices".into(),
77 reason: format!(
78 "column index {idx} is out of range for input with {n_features} features"
79 ),
80 });
81 }
82 }
83 Ok(indices.clone())
84 }
85 }
86 }
87}
88
89// ---------------------------------------------------------------------------
90// Remainder
91// ---------------------------------------------------------------------------
92
/// Policy for columns that are not selected by any transformer.
///
/// When at least one column is not covered by any registered transformer,
/// `Remainder` determines what happens to those columns in the output.
///
/// The policy is a plain value and can be compared for equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Remainder {
    /// Discard remainder columns — they do not appear in the output.
    Drop,
    /// Pass remainder columns through unchanged, appended after all
    /// transformer outputs.
    Passthrough,
}
105
106// ---------------------------------------------------------------------------
107// Helper: extract a sub-matrix by column indices
108// ---------------------------------------------------------------------------
109
110/// Build a new `Array2<f64>` containing only the columns at `indices`.
111///
112/// Columns are emitted in the order they appear in `indices`.
113fn select_columns(x: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
114 let nrows = x.nrows();
115 let ncols = indices.len();
116 if ncols == 0 {
117 return Array2::zeros((nrows, 0));
118 }
119 let mut out = Array2::zeros((nrows, ncols));
120 for (new_j, &old_j) in indices.iter().enumerate() {
121 out.column_mut(new_j).assign(&x.column(old_j));
122 }
123 out
124}
125
126/// Horizontally concatenate a slice of `Array2<f64>` matrices.
127///
128/// All matrices must have the same number of rows.
129///
130/// # Errors
131///
132/// Returns [`FerroError::ShapeMismatch`] if row counts differ.
133fn hstack(matrices: &[Array2<f64>]) -> Result<Array2<f64>, FerroError> {
134 if matrices.is_empty() {
135 return Ok(Array2::zeros((0, 0)));
136 }
137 let nrows = matrices[0].nrows();
138 let total_cols: usize = matrices.iter().map(ndarray::ArrayBase::ncols).sum();
139
140 // Handle the case where the first matrix establishes nrows = 0 separately.
141 if total_cols == 0 {
142 return Ok(Array2::zeros((nrows, 0)));
143 }
144
145 let mut out = Array2::zeros((nrows, total_cols));
146 let mut col_offset = 0;
147 for m in matrices {
148 if m.nrows() != nrows {
149 return Err(FerroError::ShapeMismatch {
150 expected: vec![nrows, m.ncols()],
151 actual: vec![m.nrows(), m.ncols()],
152 context: "ColumnTransformer hstack: row count mismatch".into(),
153 });
154 }
155 let end = col_offset + m.ncols();
156 if m.ncols() > 0 {
157 out.slice_mut(ndarray::s![.., col_offset..end]).assign(m);
158 }
159 col_offset = end;
160 }
161 Ok(out)
162}
163
164// ---------------------------------------------------------------------------
165// ColumnTransformer (unfitted)
166// ---------------------------------------------------------------------------
167
168/// An unfitted column transformer.
169///
170/// Applies each registered transformer to its designated column subset, then
171/// horizontally concatenates all outputs. The [`Remainder`] policy controls
172/// what happens to columns not covered by any transformer.
173///
174/// # Transformer order
175///
176/// Transformers are applied and their outputs concatenated in the order they
177/// were registered. Remainder columns (when
178/// `remainder = `[`Remainder::Passthrough`]) are appended last.
179///
180/// # Overlapping selections
181///
182/// Each transformer receives its own copy of the selected columns, so
183/// overlapping `ColumnSelector`s are fully supported.
184///
185/// # Examples
186///
187/// ```
188/// use ferrolearn_preprocess::column_transformer::{
189/// ColumnSelector, ColumnTransformer, Remainder,
190/// };
191/// use ferrolearn_preprocess::StandardScaler;
192/// use ferrolearn_core::Fit;
193/// use ferrolearn_core::Transform;
194/// use ndarray::array;
195///
196/// let x = array![[1.0_f64, 10.0, 100.0], [2.0, 20.0, 200.0], [3.0, 30.0, 300.0]];
197/// let ct = ColumnTransformer::new(
198/// vec![("scaler".into(), Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1]))],
199/// Remainder::Passthrough,
200/// );
201/// let fitted = ct.fit(&x, &()).unwrap();
202/// let out = fitted.transform(&x).unwrap();
203/// // 2 scaled columns + 1 passthrough column
204/// assert_eq!(out.ncols(), 3);
205/// ```
206pub struct ColumnTransformer {
207 /// Named transformer steps with their column selectors.
208 transformers: Vec<(String, Box<dyn PipelineTransformer<f64>>, ColumnSelector)>,
209 /// Policy for columns not covered by any transformer.
210 remainder: Remainder,
211}
212
213impl ColumnTransformer {
214 /// Create a new `ColumnTransformer`.
215 ///
216 /// # Parameters
217 ///
218 /// - `transformers`: A list of `(name, transformer, selector)` triples.
219 /// - `remainder`: Policy for uncovered columns (`Drop` or `Passthrough`).
220 #[must_use]
221 pub fn new(
222 transformers: Vec<(String, Box<dyn PipelineTransformer<f64>>, ColumnSelector)>,
223 remainder: Remainder,
224 ) -> Self {
225 Self {
226 transformers,
227 remainder,
228 }
229 }
230}
231
232// ---------------------------------------------------------------------------
233// Fit implementation
234// ---------------------------------------------------------------------------
235
236impl Fit<Array2<f64>, ()> for ColumnTransformer {
237 type Fitted = FittedColumnTransformer;
238 type Error = FerroError;
239
240 /// Fit each transformer on its selected column subset.
241 ///
242 /// Validates that all selected column indices are within bounds before
243 /// fitting any transformer.
244 ///
245 /// # Errors
246 ///
247 /// - [`FerroError::InvalidParameter`] if any column index is out of range.
248 /// - Propagates any error returned by an individual transformer's
249 /// `fit_pipeline` call.
250 fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedColumnTransformer, FerroError> {
251 let n_features = x.ncols();
252 let n_rows = x.nrows();
253
254 // A dummy y vector required by PipelineTransformer::fit_pipeline.
255 let dummy_y = Array1::<f64>::zeros(n_rows);
256
257 // Resolve all selectors up front to validate indices eagerly.
258 let mut resolved_selectors: Vec<Vec<usize>> = Vec::with_capacity(self.transformers.len());
259 for (name, _, selector) in &self.transformers {
260 let indices = selector.resolve(n_features).map_err(|e| {
261 // Enrich the error with the transformer name.
262 FerroError::InvalidParameter {
263 name: format!("ColumnTransformer step '{name}'"),
264 reason: e.to_string(),
265 }
266 })?;
267 resolved_selectors.push(indices);
268 }
269
270 // Build the set of covered column indices (for remainder computation).
271 let covered: std::collections::HashSet<usize> = resolved_selectors
272 .iter()
273 .flat_map(|v| v.iter().copied())
274 .collect();
275
276 let remainder_indices: Vec<usize> =
277 (0..n_features).filter(|c| !covered.contains(c)).collect();
278
279 // Fit each transformer on its sub-matrix.
280 let mut fitted_transformers: Vec<FittedSubTransformer> =
281 Vec::with_capacity(self.transformers.len());
282
283 for ((name, transformer, _), indices) in self.transformers.iter().zip(resolved_selectors) {
284 let sub_x = select_columns(x, &indices);
285 let fitted = transformer.fit_pipeline(&sub_x, &dummy_y)?;
286 fitted_transformers.push((name.clone(), fitted, indices));
287 }
288
289 Ok(FittedColumnTransformer {
290 fitted_transformers,
291 remainder: self.remainder.clone(),
292 remainder_indices,
293 n_features_in: n_features,
294 })
295 }
296}
297
298// ---------------------------------------------------------------------------
299// PipelineTransformer implementation
300// ---------------------------------------------------------------------------
301
302impl PipelineTransformer<f64> for ColumnTransformer {
303 /// Fit the column transformer using the pipeline interface.
304 ///
305 /// The `y` argument is ignored; it exists only for API compatibility.
306 ///
307 /// # Errors
308 ///
309 /// Propagates errors from [`Fit::fit`].
310 fn fit_pipeline(
311 &self,
312 x: &Array2<f64>,
313 _y: &Array1<f64>,
314 ) -> Result<Box<dyn FittedPipelineTransformer<f64>>, FerroError> {
315 let fitted = self.fit(x, &())?;
316 Ok(Box::new(fitted))
317 }
318}
319
320// ---------------------------------------------------------------------------
321// FittedColumnTransformer
322// ---------------------------------------------------------------------------
323
324/// A named, fitted sub-transformer with its column indices.
325type FittedSubTransformer = (String, Box<dyn FittedPipelineTransformer<f64>>, Vec<usize>);
326
327/// A fitted column transformer holding fitted sub-transformers and metadata.
328///
329/// Created by calling [`Fit::fit`] on a [`ColumnTransformer`].
330/// Implements [`Transform<Array2<f64>>`] to apply the fitted transformers and
331/// concatenate their outputs, as well as [`FittedPipelineTransformer`] for use
332/// inside a [`ferrolearn_core::pipeline::Pipeline`].
333pub struct FittedColumnTransformer {
334 /// Fitted transformers with their associated column indices.
335 fitted_transformers: Vec<FittedSubTransformer>,
336 /// Remainder policy from the original [`ColumnTransformer`].
337 remainder: Remainder,
338 /// Column indices not covered by any transformer.
339 remainder_indices: Vec<usize>,
340 /// Number of input features seen during fitting.
341 n_features_in: usize,
342}
343
344impl FittedColumnTransformer {
345 /// Return the number of input features seen during fitting.
346 #[must_use]
347 pub fn n_features_in(&self) -> usize {
348 self.n_features_in
349 }
350
351 /// Return the names of all registered transformer steps.
352 #[must_use]
353 pub fn transformer_names(&self) -> Vec<&str> {
354 self.fitted_transformers
355 .iter()
356 .map(|(name, _, _)| name.as_str())
357 .collect()
358 }
359
360 /// Return the remainder column indices (columns not selected by any transformer).
361 #[must_use]
362 pub fn remainder_indices(&self) -> &[usize] {
363 &self.remainder_indices
364 }
365}
366
367// ---------------------------------------------------------------------------
368// Transform implementation
369// ---------------------------------------------------------------------------
370
371impl Transform<Array2<f64>> for FittedColumnTransformer {
372 type Output = Array2<f64>;
373 type Error = FerroError;
374
375 /// Transform data by applying each fitted transformer to its column subset,
376 /// then horizontally concatenating all outputs.
377 ///
378 /// When `remainder = Passthrough`, the unselected columns are appended
379 /// after all transformer outputs. When `remainder = Drop`, they are
380 /// discarded.
381 ///
382 /// # Errors
383 ///
384 /// - [`FerroError::ShapeMismatch`] if the input does not have
385 /// `n_features_in` columns.
386 /// - Propagates any error from individual transformer `transform_pipeline`
387 /// calls.
388 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
389 if x.ncols() != self.n_features_in {
390 return Err(FerroError::ShapeMismatch {
391 expected: vec![x.nrows(), self.n_features_in],
392 actual: vec![x.nrows(), x.ncols()],
393 context: "FittedColumnTransformer::transform".into(),
394 });
395 }
396
397 let mut parts: Vec<Array2<f64>> = Vec::with_capacity(self.fitted_transformers.len() + 1);
398
399 for (_, fitted, indices) in &self.fitted_transformers {
400 let sub_x = select_columns(x, indices);
401 let transformed = fitted.transform_pipeline(&sub_x)?;
402 parts.push(transformed);
403 }
404
405 // Append remainder columns if requested.
406 if matches!(self.remainder, Remainder::Passthrough) && !self.remainder_indices.is_empty() {
407 let remainder_sub = select_columns(x, &self.remainder_indices);
408 parts.push(remainder_sub);
409 }
410
411 hstack(&parts)
412 }
413}
414
415// ---------------------------------------------------------------------------
416// FittedPipelineTransformer implementation
417// ---------------------------------------------------------------------------
418
419impl FittedPipelineTransformer<f64> for FittedColumnTransformer {
420 /// Transform data using the pipeline interface.
421 ///
422 /// # Errors
423 ///
424 /// Propagates errors from [`Transform::transform`].
425 fn transform_pipeline(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
426 self.transform(x)
427 }
428}
429
430// ---------------------------------------------------------------------------
431// make_column_transformer convenience function
432// ---------------------------------------------------------------------------
433
434/// Convenience function to build a [`ColumnTransformer`] with auto-generated
435/// step names.
436///
437/// Steps are named `"transformer-0"`, `"transformer-1"`, etc.
438///
439/// # Parameters
440///
441/// - `transformers`: A list of `(transformer, selector)` pairs.
442/// - `remainder`: Policy for uncovered columns (`Drop` or `Passthrough`).
443///
444/// # Examples
445///
446/// ```
447/// use ferrolearn_preprocess::column_transformer::{
448/// make_column_transformer, ColumnSelector, Remainder,
449/// };
450/// use ferrolearn_preprocess::StandardScaler;
451/// use ferrolearn_core::Fit;
452/// use ferrolearn_core::Transform;
453/// use ndarray::array;
454///
455/// let x = array![[1.0_f64, 10.0], [2.0, 20.0], [3.0, 30.0]];
456/// let ct = make_column_transformer(
457/// vec![(Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1]))],
458/// Remainder::Drop,
459/// );
460/// let fitted = ct.fit(&x, &()).unwrap();
461/// let out = fitted.transform(&x).unwrap();
462/// assert_eq!(out.ncols(), 2);
463/// ```
464#[must_use]
465pub fn make_column_transformer(
466 transformers: Vec<(Box<dyn PipelineTransformer<f64>>, ColumnSelector)>,
467 remainder: Remainder,
468) -> ColumnTransformer {
469 let named: Vec<(String, Box<dyn PipelineTransformer<f64>>, ColumnSelector)> = transformers
470 .into_iter()
471 .enumerate()
472 .map(|(i, (t, s))| (format!("transformer-{i}"), t, s))
473 .collect();
474 ColumnTransformer::new(named, remainder)
475}
476
477// ---------------------------------------------------------------------------
478// Tests
479// ---------------------------------------------------------------------------
480
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ferrolearn_core::pipeline::{Pipeline, PipelineEstimator};
    use ndarray::{Array2, array};

    use crate::{MinMaxScaler, StandardScaler};

    // -----------------------------------------------------------------------
    // Shared fixtures and helpers
    // -----------------------------------------------------------------------

    /// A `(name, transformer, selector)` triple as taken by `ColumnTransformer::new`.
    type Step = (String, Box<dyn PipelineTransformer<f64>>, ColumnSelector);

    /// 4×4 matrix whose columns are distinct linear ramps.
    fn make_x() -> Array2<f64> {
        array![
            [1.0, 2.0, 10.0, 20.0],
            [2.0, 4.0, 20.0, 40.0],
            [3.0, 6.0, 30.0, 60.0],
            [4.0, 8.0, 40.0, 80.0],
        ]
    }

    /// Step that standardises the given columns.
    fn std_step(name: &str, cols: &[usize]) -> Step {
        (
            name.to_string(),
            Box::new(StandardScaler::<f64>::new()),
            ColumnSelector::Indices(cols.to_vec()),
        )
    }

    /// Step that min-max scales the given columns.
    fn mm_step(name: &str, cols: &[usize]) -> Step {
        (
            name.to_string(),
            Box::new(MinMaxScaler::<f64>::new()),
            ColumnSelector::Indices(cols.to_vec()),
        )
    }

    /// Fit `ct` on `x` and transform that same `x`, unwrapping both steps.
    fn fit_transform(ct: &ColumnTransformer, x: &Array2<f64>) -> Array2<f64> {
        ct.fit(x, &()).unwrap().transform(x).unwrap()
    }

    /// Assert that `out[:, out_col]` equals `x[:, x_col]` element-wise.
    fn assert_column_copied(out: &Array2<f64>, out_col: usize, x: &Array2<f64>, x_col: usize) {
        assert_eq!(out.nrows(), x.nrows());
        for (got, want) in out.column(out_col).iter().zip(x.column(x_col).iter()) {
            assert_abs_diff_eq!(*got, *want, epsilon = 1e-12);
        }
    }

    // -----------------------------------------------------------------------
    // Output shape and remainder handling
    // -----------------------------------------------------------------------

    #[test]
    fn test_basic_two_transformers_drop_remainder() {
        let x = make_x();
        let ct = ColumnTransformer::new(
            vec![std_step("std", &[0, 1]), mm_step("mm", &[2, 3])],
            Remainder::Drop,
        );
        let out = fit_transform(&ct, &x);
        // Every input column is claimed by a step, so nothing is left over.
        assert_eq!(out.dim(), (4, 4));
    }

    #[test]
    fn test_remainder_drop() {
        let x = make_x();
        // Columns 2 and 3 are unclaimed and must disappear.
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 1])], Remainder::Drop);
        let out = fit_transform(&ct, &x);
        assert_eq!(out.nrows(), 4);
        assert_eq!(out.ncols(), 2, "uncovered cols should be dropped");
    }

    #[test]
    fn test_remainder_passthrough() {
        let x = make_x();
        // Columns 2 and 3 are unclaimed and must be appended untouched.
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 1])], Remainder::Passthrough);
        let out = fit_transform(&ct, &x);
        assert_eq!(out.nrows(), 4);
        assert_eq!(out.ncols(), 4, "passthrough: 2 transformed + 2 remainder");
        assert_column_copied(&out, 2, &x, 2);
        assert_column_copied(&out, 3, &x, 3);
    }

    #[test]
    fn test_invalid_column_index_out_of_range() {
        // Valid indices for make_x() are 0..=3; 99 must be rejected at fit time.
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 99])], Remainder::Drop);
        assert!(
            ct.fit(&make_x(), &()).is_err(),
            "expected error for out-of-range index"
        );
    }

    #[test]
    fn test_empty_transformer_list_drop() {
        let ct = ColumnTransformer::new(vec![], Remainder::Drop);
        let out = fit_transform(&ct, &make_x());
        // No steps and nothing passed through: the concatenation is empty.
        assert_eq!(out.nrows(), 0, "hstack of nothing with no passthrough");
    }

    #[test]
    fn test_empty_transformer_list_passthrough() {
        let x = make_x();
        let ct = ColumnTransformer::new(vec![], Remainder::Passthrough);
        let out = fit_transform(&ct, &x);
        // No steps: the whole input is the remainder and passes through as-is.
        assert_eq!(out.nrows(), 4);
        assert_eq!(out.ncols(), 4);
        for j in 0..4 {
            assert_column_copied(&out, j, &x, j);
        }
    }

    #[test]
    fn test_overlapping_column_selections() {
        // Column 0 feeds both steps; each gets its own copy.
        let ct = ColumnTransformer::new(
            vec![std_step("std1", &[0, 1]), mm_step("mm1", &[0, 2])],
            Remainder::Drop,
        );
        let out = fit_transform(&ct, &make_x());
        // Two columns from each step.
        assert_eq!(out.nrows(), 4);
        assert_eq!(out.ncols(), 4);
    }

    #[test]
    fn test_single_transformer() {
        let x = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let ct = ColumnTransformer::new(vec![mm_step("mm", &[0, 1])], Remainder::Drop);
        let out = fit_transform(&ct, &x);

        assert_eq!(out.nrows(), 3);
        assert_eq!(out.ncols(), 2);

        // Each column is increasing, so min-max maps the first row to 0 and the last to 1.
        for j in 0..2 {
            assert_abs_diff_eq!(out[[0, j]], 0.0, epsilon = 1e-10);
            assert_abs_diff_eq!(out[[2, j]], 1.0, epsilon = 1e-10);
        }
    }

    // -----------------------------------------------------------------------
    // make_column_transformer
    // -----------------------------------------------------------------------

    #[test]
    fn test_make_column_transformer_auto_names() {
        let x = make_x();
        let ct = make_column_transformer(
            vec![
                (
                    Box::new(StandardScaler::<f64>::new()),
                    ColumnSelector::Indices(vec![0, 1]),
                ),
                (
                    Box::new(MinMaxScaler::<f64>::new()),
                    ColumnSelector::Indices(vec![2, 3]),
                ),
            ],
            Remainder::Drop,
        );

        let fitted = ct.fit(&x, &()).unwrap();
        assert_eq!(
            fitted.transformer_names(),
            vec!["transformer-0", "transformer-1"]
        );

        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.dim(), (4, 4));
    }

    #[test]
    fn test_make_column_transformer_single() {
        let x = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let ct = make_column_transformer(
            vec![(
                Box::new(StandardScaler::<f64>::new()),
                ColumnSelector::Indices(vec![0, 1]),
            )],
            Remainder::Drop,
        );
        let fitted = ct.fit(&x, &()).unwrap();
        assert_eq!(fitted.transformer_names(), vec!["transformer-0"]);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.shape(), &[2, 2]);
    }

    // -----------------------------------------------------------------------
    // Pipeline integration
    // -----------------------------------------------------------------------

    #[test]
    fn test_pipeline_integration() {
        use ferrolearn_core::Fit as _;
        use ferrolearn_core::Predict as _;

        /// Estimator whose fitted form predicts the sum of each row.
        struct SumEstimator;
        impl PipelineEstimator<f64> for SumEstimator {
            fn fit_pipeline(
                &self,
                _x: &Array2<f64>,
                _y: &Array1<f64>,
            ) -> Result<Box<dyn ferrolearn_core::pipeline::FittedPipelineEstimator<f64>>, FerroError>
            {
                Ok(Box::new(FittedSum))
            }
        }

        struct FittedSum;
        impl ferrolearn_core::pipeline::FittedPipelineEstimator<f64> for FittedSum {
            fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
                Ok(x.rows().into_iter().map(|row| row.sum()).collect())
            }
        }

        let x = make_x();
        let y = Array1::<f64>::zeros(4);
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 1, 2, 3])], Remainder::Drop);

        let pipeline = Pipeline::new()
            .transform_step("ct", Box::new(ct))
            .estimator_step("sum", Box::new(SumEstimator));

        let fitted_pipeline = pipeline.fit(&x, &y).unwrap();
        let preds = fitted_pipeline.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    #[test]
    fn test_pipeline_transformer_interface() {
        let x = make_x();
        let y = Array1::<f64>::zeros(4);
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 1])], Remainder::Passthrough);
        let fitted_box = ct.fit_pipeline(&x, &y).unwrap();
        let out = fitted_box.transform_pipeline(&x).unwrap();
        assert_eq!(out.dim(), (4, 4));
    }

    // -----------------------------------------------------------------------
    // Shape bookkeeping and accessors
    // -----------------------------------------------------------------------

    #[test]
    fn test_output_shape_all_selected_drop() {
        let x = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let ct = ColumnTransformer::new(
            vec![std_step("s", &[0]), mm_step("m", &[1, 2])],
            Remainder::Drop,
        );
        let out = fit_transform(&ct, &x);
        assert_eq!(out.shape(), &[2, 3]);
    }

    #[test]
    fn test_output_shape_partial_passthrough() {
        // 3×5 input: two scaled columns plus three passed through.
        let x = Array2::<f64>::from_shape_vec((3, 5), (1..=15).map(f64::from).collect()).unwrap();
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 1])], Remainder::Passthrough);
        let out = fit_transform(&ct, &x);
        assert_eq!(out.shape(), &[3, 5]);
    }

    #[test]
    fn test_n_features_in() {
        let ct = ColumnTransformer::new(vec![std_step("std", &[0])], Remainder::Drop);
        let fitted = ct.fit(&make_x(), &()).unwrap();
        assert_eq!(fitted.n_features_in(), 4);
    }

    #[test]
    fn test_shape_mismatch_on_transform() {
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 1])], Remainder::Drop);
        let fitted = ct.fit(&make_x(), &()).unwrap();

        // Fitted on 4 columns; a 2-column input must be rejected.
        let x_bad = array![[1.0_f64, 2.0], [3.0, 4.0]];
        assert!(
            fitted.transform(&x_bad).is_err(),
            "expected shape mismatch error"
        );
    }

    #[test]
    fn test_remainder_indices_accessor() {
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 2])], Remainder::Passthrough);
        let fitted = ct.fit(&make_x(), &()).unwrap();
        // Columns 1 and 3 are the ones nobody selected.
        assert_eq!(fitted.remainder_indices(), &[1, 3]);
    }

    #[test]
    fn test_transformer_names_explicit() {
        let ct = ColumnTransformer::new(
            vec![std_step("alpha", &[0]), mm_step("beta", &[1])],
            Remainder::Drop,
        );
        let fitted = ct.fit(&make_x(), &()).unwrap();
        assert_eq!(fitted.transformer_names(), vec!["alpha", "beta"]);
    }

    // -----------------------------------------------------------------------
    // Numeric content of the output
    // -----------------------------------------------------------------------

    #[test]
    fn test_standard_scaler_zero_mean_in_output() {
        let x = array![[1.0_f64, 100.0, 0.5], [2.0, 200.0, 1.5], [3.0, 300.0, 2.5]];
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 1])], Remainder::Drop);
        let out = fit_transform(&ct, &x);

        // Both standardised output columns should be centred on zero.
        for j in 0..2 {
            let mean = out.column(j).sum() / 3.0;
            assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_min_max_values_in_range() {
        let ct = ColumnTransformer::new(vec![mm_step("mm", &[0, 1, 2, 3])], Remainder::Drop);
        let out = fit_transform(&ct, &make_x());

        for j in 0..4 {
            let (lo, hi) = out
                .column(j)
                .iter()
                .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
                    (lo.min(v), hi.max(v))
                });
            assert_abs_diff_eq!(lo, 0.0, epsilon = 1e-10);
            assert_abs_diff_eq!(hi, 1.0, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_passthrough_values_are_exact() {
        let x = array![[10.0_f64, 20.0, 30.0], [40.0, 50.0, 60.0]];
        // Only column 0 is scaled; columns 1 and 2 ride along unchanged.
        let ct = ColumnTransformer::new(vec![mm_step("mm", &[0])], Remainder::Passthrough);
        let out = fit_transform(&ct, &x);

        let expected: [(usize, usize, f64); 4] =
            [(0, 1, 20.0), (1, 1, 50.0), (0, 2, 30.0), (1, 2, 60.0)];
        for &(i, j, want) in &expected {
            assert_abs_diff_eq!(out[[i, j]], want, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_all_remainder_passthrough_unchanged() {
        let x = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let ct = ColumnTransformer::new(vec![], Remainder::Passthrough);
        let out = fit_transform(&ct, &x);
        assert_eq!(out.shape(), &[2, 3]);
        for j in 0..3 {
            assert_column_copied(&out, j, &x, j);
        }
    }
}