ferrolearn_preprocess/column_transformer.rs
1//! Column transformer: apply different transformers to different column subsets.
2//!
3//! [`ColumnTransformer`] applies each registered transformer to its designated
4//! column subset, then horizontally concatenates the outputs into a single
5//! `Array2<f64>`. Columns not captured by any transformer can be dropped or
6//! passed through unchanged via the [`Remainder`] policy.
7//!
8//! # Examples
9//!
10//! ```
11//! use ferrolearn_preprocess::column_transformer::{
12//! ColumnSelector, ColumnTransformer, Remainder,
13//! };
14//! use ferrolearn_preprocess::{StandardScaler, MinMaxScaler};
15//! use ferrolearn_core::Fit;
16//! use ferrolearn_core::Transform;
17//! use ndarray::array;
18//!
19//! let x = array![
20//! [1.0_f64, 2.0, 10.0, 20.0],
21//! [3.0, 4.0, 30.0, 40.0],
22//! [5.0, 6.0, 50.0, 60.0],
23//! ];
24//!
25//! let ct = ColumnTransformer::new(
26//! vec![
27//! ("std".into(), Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1])),
28//! ("mm".into(), Box::new(MinMaxScaler::<f64>::new()), ColumnSelector::Indices(vec![2, 3])),
29//! ],
30//! Remainder::Drop,
31//! );
32//!
33//! let fitted = ct.fit(&x, &()).unwrap();
34//! let out = fitted.transform(&x).unwrap();
35//! assert_eq!(out.ncols(), 4);
36//! assert_eq!(out.nrows(), 3);
37//! ```
38
39use ferrolearn_core::error::FerroError;
40use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
41use ferrolearn_core::traits::{Fit, Transform};
42use ndarray::{Array1, Array2};
43
44// ---------------------------------------------------------------------------
45// ColumnSelector
46// ---------------------------------------------------------------------------
47
48/// Specifies which columns a transformer should operate on.
49///
50/// Currently the only supported variant is [`Indices`](ColumnSelector::Indices),
51/// which selects columns by their zero-based integer positions.
52#[derive(Debug, Clone)]
53pub enum ColumnSelector {
54 /// Select columns by zero-based index.
55 ///
56 /// The indices do not need to be sorted, but every index must be strictly
57 /// less than the number of columns in the input matrix. Duplicate indices
58 /// are allowed; the same column will simply appear twice in the sub-matrix
59 /// passed to the transformer.
60 Indices(Vec<usize>),
61}
62
63impl ColumnSelector {
64 /// Resolve the selector to a concrete list of column indices.
65 ///
66 /// # Errors
67 ///
68 /// Returns [`FerroError::InvalidParameter`] if any index is out of range
69 /// (i.e., `>= n_features`).
70 fn resolve(&self, n_features: usize) -> Result<Vec<usize>, FerroError> {
71 match self {
72 ColumnSelector::Indices(indices) => {
73 for &idx in indices {
74 if idx >= n_features {
75 return Err(FerroError::InvalidParameter {
76 name: "ColumnSelector::Indices".into(),
77 reason: format!(
78 "column index {idx} is out of range for input with {n_features} features"
79 ),
80 });
81 }
82 }
83 Ok(indices.clone())
84 }
85 }
86 }
87}
88
89// ---------------------------------------------------------------------------
90// Remainder
91// ---------------------------------------------------------------------------
92
/// Policy for columns that are not selected by any transformer.
///
/// When at least one column is not covered by any registered transformer,
/// `Remainder` determines what happens to those columns in the output.
/// A column selected by at least one transformer is never part of the
/// remainder.
///
/// The enum derives `PartialEq`, `Eq` and `Hash` so a policy can be compared
/// with `==` / `assert_eq!` or used as a key in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Remainder {
    /// Discard remainder columns — they do not appear in the output.
    Drop,
    /// Pass remainder columns through unchanged, appended after all
    /// transformer outputs.
    Passthrough,
}
105
106// ---------------------------------------------------------------------------
107// Helper: extract a sub-matrix by column indices
108// ---------------------------------------------------------------------------
109
110/// Build a new `Array2<f64>` containing only the columns at `indices`.
111///
112/// Columns are emitted in the order they appear in `indices`.
113fn select_columns(x: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
114 let nrows = x.nrows();
115 let ncols = indices.len();
116 if ncols == 0 {
117 return Array2::zeros((nrows, 0));
118 }
119 let mut out = Array2::zeros((nrows, ncols));
120 for (new_j, &old_j) in indices.iter().enumerate() {
121 out.column_mut(new_j).assign(&x.column(old_j));
122 }
123 out
124}
125
126/// Horizontally concatenate a slice of `Array2<f64>` matrices.
127///
128/// All matrices must have the same number of rows.
129///
130/// # Errors
131///
132/// Returns [`FerroError::ShapeMismatch`] if row counts differ.
133fn hstack(matrices: &[Array2<f64>]) -> Result<Array2<f64>, FerroError> {
134 if matrices.is_empty() {
135 return Ok(Array2::zeros((0, 0)));
136 }
137 let nrows = matrices[0].nrows();
138 let total_cols: usize = matrices.iter().map(|m| m.ncols()).sum();
139
140 // Handle the case where the first matrix establishes nrows = 0 separately.
141 if total_cols == 0 {
142 return Ok(Array2::zeros((nrows, 0)));
143 }
144
145 let mut out = Array2::zeros((nrows, total_cols));
146 let mut col_offset = 0;
147 for m in matrices {
148 if m.nrows() != nrows {
149 return Err(FerroError::ShapeMismatch {
150 expected: vec![nrows, m.ncols()],
151 actual: vec![m.nrows(), m.ncols()],
152 context: "ColumnTransformer hstack: row count mismatch".into(),
153 });
154 }
155 let end = col_offset + m.ncols();
156 if m.ncols() > 0 {
157 out.slice_mut(ndarray::s![.., col_offset..end]).assign(m);
158 }
159 col_offset = end;
160 }
161 Ok(out)
162}
163
164// ---------------------------------------------------------------------------
165// ColumnTransformer (unfitted)
166// ---------------------------------------------------------------------------
167
168/// An unfitted column transformer.
169///
170/// Applies each registered transformer to its designated column subset, then
171/// horizontally concatenates all outputs. The [`Remainder`] policy controls
172/// what happens to columns not covered by any transformer.
173///
174/// # Transformer order
175///
176/// Transformers are applied and their outputs concatenated in the order they
177/// were registered. Remainder columns (when
178/// `remainder = `[`Remainder::Passthrough`]) are appended last.
179///
180/// # Overlapping selections
181///
182/// Each transformer receives its own copy of the selected columns, so
183/// overlapping `ColumnSelector`s are fully supported.
184///
185/// # Examples
186///
187/// ```
188/// use ferrolearn_preprocess::column_transformer::{
189/// ColumnSelector, ColumnTransformer, Remainder,
190/// };
191/// use ferrolearn_preprocess::StandardScaler;
192/// use ferrolearn_core::Fit;
193/// use ferrolearn_core::Transform;
194/// use ndarray::array;
195///
196/// let x = array![[1.0_f64, 10.0, 100.0], [2.0, 20.0, 200.0], [3.0, 30.0, 300.0]];
197/// let ct = ColumnTransformer::new(
198/// vec![("scaler".into(), Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1]))],
199/// Remainder::Passthrough,
200/// );
201/// let fitted = ct.fit(&x, &()).unwrap();
202/// let out = fitted.transform(&x).unwrap();
203/// // 2 scaled columns + 1 passthrough column
204/// assert_eq!(out.ncols(), 3);
205/// ```
206pub struct ColumnTransformer {
207 /// Named transformer steps with their column selectors.
208 transformers: Vec<(String, Box<dyn PipelineTransformer>, ColumnSelector)>,
209 /// Policy for columns not covered by any transformer.
210 remainder: Remainder,
211}
212
213impl ColumnTransformer {
214 /// Create a new `ColumnTransformer`.
215 ///
216 /// # Parameters
217 ///
218 /// - `transformers`: A list of `(name, transformer, selector)` triples.
219 /// - `remainder`: Policy for uncovered columns (`Drop` or `Passthrough`).
220 #[must_use]
221 pub fn new(
222 transformers: Vec<(String, Box<dyn PipelineTransformer>, ColumnSelector)>,
223 remainder: Remainder,
224 ) -> Self {
225 Self {
226 transformers,
227 remainder,
228 }
229 }
230}
231
232// ---------------------------------------------------------------------------
233// Fit implementation
234// ---------------------------------------------------------------------------
235
236impl Fit<Array2<f64>, ()> for ColumnTransformer {
237 type Fitted = FittedColumnTransformer;
238 type Error = FerroError;
239
240 /// Fit each transformer on its selected column subset.
241 ///
242 /// Validates that all selected column indices are within bounds before
243 /// fitting any transformer.
244 ///
245 /// # Errors
246 ///
247 /// - [`FerroError::InvalidParameter`] if any column index is out of range.
248 /// - Propagates any error returned by an individual transformer's
249 /// `fit_pipeline` call.
250 fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedColumnTransformer, FerroError> {
251 let n_features = x.ncols();
252 let n_rows = x.nrows();
253
254 // A dummy y vector required by PipelineTransformer::fit_pipeline.
255 let dummy_y = Array1::<f64>::zeros(n_rows);
256
257 // Resolve all selectors up front to validate indices eagerly.
258 let mut resolved_selectors: Vec<Vec<usize>> = Vec::with_capacity(self.transformers.len());
259 for (name, _, selector) in &self.transformers {
260 let indices = selector.resolve(n_features).map_err(|e| {
261 // Enrich the error with the transformer name.
262 FerroError::InvalidParameter {
263 name: format!("ColumnTransformer step '{name}'"),
264 reason: e.to_string(),
265 }
266 })?;
267 resolved_selectors.push(indices);
268 }
269
270 // Build the set of covered column indices (for remainder computation).
271 let covered: std::collections::HashSet<usize> = resolved_selectors
272 .iter()
273 .flat_map(|v| v.iter().copied())
274 .collect();
275
276 let remainder_indices: Vec<usize> =
277 (0..n_features).filter(|c| !covered.contains(c)).collect();
278
279 // Fit each transformer on its sub-matrix.
280 let mut fitted_transformers: Vec<(String, Box<dyn FittedPipelineTransformer>, Vec<usize>)> =
281 Vec::with_capacity(self.transformers.len());
282
283 for ((name, transformer, _), indices) in
284 self.transformers.iter().zip(resolved_selectors.into_iter())
285 {
286 let sub_x = select_columns(x, &indices);
287 let fitted = transformer.fit_pipeline(&sub_x, &dummy_y)?;
288 fitted_transformers.push((name.clone(), fitted, indices));
289 }
290
291 Ok(FittedColumnTransformer {
292 fitted_transformers,
293 remainder: self.remainder.clone(),
294 remainder_indices,
295 n_features_in: n_features,
296 })
297 }
298}
299
300// ---------------------------------------------------------------------------
301// PipelineTransformer implementation
302// ---------------------------------------------------------------------------
303
304impl PipelineTransformer for ColumnTransformer {
305 /// Fit the column transformer using the pipeline interface.
306 ///
307 /// The `y` argument is ignored; it exists only for API compatibility.
308 ///
309 /// # Errors
310 ///
311 /// Propagates errors from [`Fit::fit`].
312 fn fit_pipeline(
313 &self,
314 x: &Array2<f64>,
315 _y: &Array1<f64>,
316 ) -> Result<Box<dyn FittedPipelineTransformer>, FerroError> {
317 let fitted = self.fit(x, &())?;
318 Ok(Box::new(fitted))
319 }
320}
321
322// ---------------------------------------------------------------------------
323// FittedColumnTransformer
324// ---------------------------------------------------------------------------
325
326/// A fitted column transformer holding fitted sub-transformers and metadata.
327///
328/// Created by calling [`Fit::fit`] on a [`ColumnTransformer`].
329/// Implements [`Transform<Array2<f64>>`] to apply the fitted transformers and
330/// concatenate their outputs, as well as [`FittedPipelineTransformer`] for use
331/// inside a [`ferrolearn_core::pipeline::Pipeline`].
332pub struct FittedColumnTransformer {
333 /// Fitted transformers with their associated column indices.
334 fitted_transformers: Vec<(String, Box<dyn FittedPipelineTransformer>, Vec<usize>)>,
335 /// Remainder policy from the original [`ColumnTransformer`].
336 remainder: Remainder,
337 /// Column indices not covered by any transformer.
338 remainder_indices: Vec<usize>,
339 /// Number of input features seen during fitting.
340 n_features_in: usize,
341}
342
343impl FittedColumnTransformer {
344 /// Return the number of input features seen during fitting.
345 #[must_use]
346 pub fn n_features_in(&self) -> usize {
347 self.n_features_in
348 }
349
350 /// Return the names of all registered transformer steps.
351 #[must_use]
352 pub fn transformer_names(&self) -> Vec<&str> {
353 self.fitted_transformers
354 .iter()
355 .map(|(name, _, _)| name.as_str())
356 .collect()
357 }
358
359 /// Return the remainder column indices (columns not selected by any transformer).
360 #[must_use]
361 pub fn remainder_indices(&self) -> &[usize] {
362 &self.remainder_indices
363 }
364}
365
366// ---------------------------------------------------------------------------
367// Transform implementation
368// ---------------------------------------------------------------------------
369
370impl Transform<Array2<f64>> for FittedColumnTransformer {
371 type Output = Array2<f64>;
372 type Error = FerroError;
373
374 /// Transform data by applying each fitted transformer to its column subset,
375 /// then horizontally concatenating all outputs.
376 ///
377 /// When `remainder = Passthrough`, the unselected columns are appended
378 /// after all transformer outputs. When `remainder = Drop`, they are
379 /// discarded.
380 ///
381 /// # Errors
382 ///
383 /// - [`FerroError::ShapeMismatch`] if the input does not have
384 /// `n_features_in` columns.
385 /// - Propagates any error from individual transformer `transform_pipeline`
386 /// calls.
387 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
388 if x.ncols() != self.n_features_in {
389 return Err(FerroError::ShapeMismatch {
390 expected: vec![x.nrows(), self.n_features_in],
391 actual: vec![x.nrows(), x.ncols()],
392 context: "FittedColumnTransformer::transform".into(),
393 });
394 }
395
396 let mut parts: Vec<Array2<f64>> = Vec::with_capacity(self.fitted_transformers.len() + 1);
397
398 for (_, fitted, indices) in &self.fitted_transformers {
399 let sub_x = select_columns(x, indices);
400 let transformed = fitted.transform_pipeline(&sub_x)?;
401 parts.push(transformed);
402 }
403
404 // Append remainder columns if requested.
405 if matches!(self.remainder, Remainder::Passthrough) && !self.remainder_indices.is_empty() {
406 let remainder_sub = select_columns(x, &self.remainder_indices);
407 parts.push(remainder_sub);
408 }
409
410 hstack(&parts)
411 }
412}
413
414// ---------------------------------------------------------------------------
415// FittedPipelineTransformer implementation
416// ---------------------------------------------------------------------------
417
418impl FittedPipelineTransformer for FittedColumnTransformer {
419 /// Transform data using the pipeline interface.
420 ///
421 /// # Errors
422 ///
423 /// Propagates errors from [`Transform::transform`].
424 fn transform_pipeline(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
425 self.transform(x)
426 }
427}
428
429// ---------------------------------------------------------------------------
430// make_column_transformer convenience function
431// ---------------------------------------------------------------------------
432
433/// Convenience function to build a [`ColumnTransformer`] with auto-generated
434/// step names.
435///
436/// Steps are named `"transformer-0"`, `"transformer-1"`, etc.
437///
438/// # Parameters
439///
440/// - `transformers`: A list of `(transformer, selector)` pairs.
441/// - `remainder`: Policy for uncovered columns (`Drop` or `Passthrough`).
442///
443/// # Examples
444///
445/// ```
446/// use ferrolearn_preprocess::column_transformer::{
447/// make_column_transformer, ColumnSelector, Remainder,
448/// };
449/// use ferrolearn_preprocess::StandardScaler;
450/// use ferrolearn_core::Fit;
451/// use ferrolearn_core::Transform;
452/// use ndarray::array;
453///
454/// let x = array![[1.0_f64, 10.0], [2.0, 20.0], [3.0, 30.0]];
455/// let ct = make_column_transformer(
456/// vec![(Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1]))],
457/// Remainder::Drop,
458/// );
459/// let fitted = ct.fit(&x, &()).unwrap();
460/// let out = fitted.transform(&x).unwrap();
461/// assert_eq!(out.ncols(), 2);
462/// ```
463#[must_use]
464pub fn make_column_transformer(
465 transformers: Vec<(Box<dyn PipelineTransformer>, ColumnSelector)>,
466 remainder: Remainder,
467) -> ColumnTransformer {
468 let named: Vec<(String, Box<dyn PipelineTransformer>, ColumnSelector)> = transformers
469 .into_iter()
470 .enumerate()
471 .map(|(i, (t, s))| (format!("transformer-{i}"), t, s))
472 .collect();
473 ColumnTransformer::new(named, remainder)
474}
475
476// ---------------------------------------------------------------------------
477// Tests
478// ---------------------------------------------------------------------------
479
480#[cfg(test)]
481mod tests {
482 use super::*;
483 use approx::assert_abs_diff_eq;
484 use ferrolearn_core::pipeline::{Pipeline, PipelineEstimator};
485 use ndarray::{Array2, array};
486
487 use crate::{MinMaxScaler, StandardScaler};
488
489 // -----------------------------------------------------------------------
490 // Helpers
491 // -----------------------------------------------------------------------
492
493 /// Build a simple 4-column test matrix (rows = 4, cols = 4).
494 fn make_x() -> Array2<f64> {
495 array![
496 [1.0, 2.0, 10.0, 20.0],
497 [2.0, 4.0, 20.0, 40.0],
498 [3.0, 6.0, 30.0, 60.0],
499 [4.0, 8.0, 40.0, 80.0],
500 ]
501 }
502
503 // -----------------------------------------------------------------------
504 // 1. Basic 2-transformer usage
505 // -----------------------------------------------------------------------
506
507 #[test]
508 fn test_basic_two_transformers_drop_remainder() {
509 let x = make_x(); // 4×4
510 let ct = ColumnTransformer::new(
511 vec![
512 (
513 "std".into(),
514 Box::new(StandardScaler::<f64>::new()),
515 ColumnSelector::Indices(vec![0, 1]),
516 ),
517 (
518 "mm".into(),
519 Box::new(MinMaxScaler::<f64>::new()),
520 ColumnSelector::Indices(vec![2, 3]),
521 ),
522 ],
523 Remainder::Drop,
524 );
525
526 let fitted = ct.fit(&x, &()).unwrap();
527 let out = fitted.transform(&x).unwrap();
528
529 // All 4 columns covered → no remainder; output is 4 cols
530 assert_eq!(out.nrows(), 4);
531 assert_eq!(out.ncols(), 4);
532 }
533
534 // -----------------------------------------------------------------------
535 // 2. Remainder::Drop drops uncovered columns
536 // -----------------------------------------------------------------------
537
538 #[test]
539 fn test_remainder_drop() {
540 let x = make_x(); // 4×4
541 // Only cover cols 0 and 1 — cols 2 and 3 should be dropped.
542 let ct = ColumnTransformer::new(
543 vec![(
544 "std".into(),
545 Box::new(StandardScaler::<f64>::new()),
546 ColumnSelector::Indices(vec![0, 1]),
547 )],
548 Remainder::Drop,
549 );
550
551 let fitted = ct.fit(&x, &()).unwrap();
552 let out = fitted.transform(&x).unwrap();
553
554 assert_eq!(out.nrows(), 4);
555 assert_eq!(out.ncols(), 2, "uncovered cols should be dropped");
556 }
557
558 // -----------------------------------------------------------------------
559 // 3. Remainder::Passthrough passes uncovered columns through unchanged
560 // -----------------------------------------------------------------------
561
562 #[test]
563 fn test_remainder_passthrough() {
564 let x = make_x(); // 4×4
565 // Only cover cols 0 and 1 — cols 2 and 3 should pass through.
566 let ct = ColumnTransformer::new(
567 vec![(
568 "std".into(),
569 Box::new(StandardScaler::<f64>::new()),
570 ColumnSelector::Indices(vec![0, 1]),
571 )],
572 Remainder::Passthrough,
573 );
574
575 let fitted = ct.fit(&x, &()).unwrap();
576 let out = fitted.transform(&x).unwrap();
577
578 assert_eq!(out.nrows(), 4);
579 assert_eq!(out.ncols(), 4, "passthrough: 2 transformed + 2 remainder");
580
581 // The last 2 columns should be the original cols 2 and 3.
582 for i in 0..4 {
583 assert_abs_diff_eq!(out[[i, 2]], x[[i, 2]], epsilon = 1e-12);
584 assert_abs_diff_eq!(out[[i, 3]], x[[i, 3]], epsilon = 1e-12);
585 }
586 }
587
588 // -----------------------------------------------------------------------
589 // 4. Invalid column index (out of range)
590 // -----------------------------------------------------------------------
591
592 #[test]
593 fn test_invalid_column_index_out_of_range() {
594 let x = make_x(); // 4×4 — valid indices are 0..3
595 let ct = ColumnTransformer::new(
596 vec![(
597 "std".into(),
598 Box::new(StandardScaler::<f64>::new()),
599 ColumnSelector::Indices(vec![0, 99]), // 99 is out of range
600 )],
601 Remainder::Drop,
602 );
603 let result = ct.fit(&x, &());
604 assert!(result.is_err(), "expected error for out-of-range index");
605 }
606
607 // -----------------------------------------------------------------------
608 // 5. Empty transformer list with Remainder::Drop
609 // -----------------------------------------------------------------------
610
611 #[test]
612 fn test_empty_transformer_list_drop() {
613 let x = make_x();
614 let ct = ColumnTransformer::new(vec![], Remainder::Drop);
615 let fitted = ct.fit(&x, &()).unwrap();
616 let out = fitted.transform(&x).unwrap();
617 // No transformers, remainder dropped → empty output
618 assert_eq!(out.nrows(), 0, "hstack of nothing with no passthrough");
619 }
620
621 // -----------------------------------------------------------------------
622 // 6. Empty transformer list with Remainder::Passthrough
623 // -----------------------------------------------------------------------
624
625 #[test]
626 fn test_empty_transformer_list_passthrough() {
627 let x = make_x(); // 4×4
628 let ct = ColumnTransformer::new(vec![], Remainder::Passthrough);
629 let fitted = ct.fit(&x, &()).unwrap();
630 let out = fitted.transform(&x).unwrap();
631 // No transformers, all columns pass through unchanged.
632 assert_eq!(out.nrows(), 4);
633 assert_eq!(out.ncols(), 4);
634 for i in 0..4 {
635 for j in 0..4 {
636 assert_abs_diff_eq!(out[[i, j]], x[[i, j]], epsilon = 1e-12);
637 }
638 }
639 }
640
641 // -----------------------------------------------------------------------
642 // 7. Overlapping column selections
643 // -----------------------------------------------------------------------
644
645 #[test]
646 fn test_overlapping_column_selections() {
647 let x = make_x(); // 4×4
648 // Both transformers select col 0 (overlapping is allowed).
649 let ct = ColumnTransformer::new(
650 vec![
651 (
652 "std1".into(),
653 Box::new(StandardScaler::<f64>::new()),
654 ColumnSelector::Indices(vec![0, 1]),
655 ),
656 (
657 "mm1".into(),
658 Box::new(MinMaxScaler::<f64>::new()),
659 ColumnSelector::Indices(vec![0, 2]), // col 0 also used here
660 ),
661 ],
662 Remainder::Drop,
663 );
664
665 let fitted = ct.fit(&x, &()).unwrap();
666 let out = fitted.transform(&x).unwrap();
667
668 // Output: 2 cols from std1 + 2 cols from mm1 = 4 cols
669 assert_eq!(out.nrows(), 4);
670 assert_eq!(out.ncols(), 4);
671 }
672
673 // -----------------------------------------------------------------------
674 // 8. Single transformer
675 // -----------------------------------------------------------------------
676
677 #[test]
678 fn test_single_transformer() {
679 let x = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
680 let ct = ColumnTransformer::new(
681 vec![(
682 "mm".into(),
683 Box::new(MinMaxScaler::<f64>::new()),
684 ColumnSelector::Indices(vec![0, 1]),
685 )],
686 Remainder::Drop,
687 );
688
689 let fitted = ct.fit(&x, &()).unwrap();
690 let out = fitted.transform(&x).unwrap();
691
692 assert_eq!(out.nrows(), 3);
693 assert_eq!(out.ncols(), 2);
694
695 // MinMax on cols 0 and 1: first row → 0.0, last row → 1.0
696 assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
697 assert_abs_diff_eq!(out[[2, 0]], 1.0, epsilon = 1e-10);
698 assert_abs_diff_eq!(out[[0, 1]], 0.0, epsilon = 1e-10);
699 assert_abs_diff_eq!(out[[2, 1]], 1.0, epsilon = 1e-10);
700 }
701
702 // -----------------------------------------------------------------------
703 // 9. make_column_transformer convenience function
704 // -----------------------------------------------------------------------
705
706 #[test]
707 fn test_make_column_transformer_auto_names() {
708 let x = make_x();
709 let ct = make_column_transformer(
710 vec![
711 (
712 Box::new(StandardScaler::<f64>::new()),
713 ColumnSelector::Indices(vec![0, 1]),
714 ),
715 (
716 Box::new(MinMaxScaler::<f64>::new()),
717 ColumnSelector::Indices(vec![2, 3]),
718 ),
719 ],
720 Remainder::Drop,
721 );
722
723 let fitted = ct.fit(&x, &()).unwrap();
724 assert_eq!(
725 fitted.transformer_names(),
726 vec!["transformer-0", "transformer-1"]
727 );
728
729 let out = fitted.transform(&x).unwrap();
730 assert_eq!(out.nrows(), 4);
731 assert_eq!(out.ncols(), 4);
732 }
733
734 // -----------------------------------------------------------------------
735 // 10. Pipeline integration
736 // -----------------------------------------------------------------------
737
738 #[test]
739 fn test_pipeline_integration() {
740 // Wrap a ColumnTransformer as a pipeline step.
741 let x = make_x();
742 let y = Array1::<f64>::zeros(4);
743
744 let ct = ColumnTransformer::new(
745 vec![(
746 "std".into(),
747 Box::new(StandardScaler::<f64>::new()),
748 ColumnSelector::Indices(vec![0, 1, 2, 3]),
749 )],
750 Remainder::Drop,
751 );
752
753 // Use a trivial estimator that sums rows.
754 struct SumEstimator;
755 impl PipelineEstimator for SumEstimator {
756 fn fit_pipeline(
757 &self,
758 _x: &Array2<f64>,
759 _y: &Array1<f64>,
760 ) -> Result<Box<dyn ferrolearn_core::pipeline::FittedPipelineEstimator>, FerroError>
761 {
762 Ok(Box::new(FittedSum))
763 }
764 }
765 struct FittedSum;
766 impl ferrolearn_core::pipeline::FittedPipelineEstimator for FittedSum {
767 fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
768 let sums: Vec<f64> = x.rows().into_iter().map(|r| r.sum()).collect();
769 Ok(Array1::from_vec(sums))
770 }
771 }
772
773 let pipeline = Pipeline::new()
774 .transform_step("ct", Box::new(ct))
775 .estimator_step("sum", Box::new(SumEstimator));
776
777 use ferrolearn_core::Fit as _;
778 let fitted_pipeline = pipeline.fit(&x, &y).unwrap();
779
780 use ferrolearn_core::Predict as _;
781 let preds = fitted_pipeline.predict(&x).unwrap();
782 assert_eq!(preds.len(), 4);
783 }
784
785 // -----------------------------------------------------------------------
786 // 11. Transform shape correctness — number of output columns
787 // -----------------------------------------------------------------------
788
789 #[test]
790 fn test_output_shape_all_selected_drop() {
791 let x = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
792 let ct = ColumnTransformer::new(
793 vec![
794 (
795 "s".into(),
796 Box::new(StandardScaler::<f64>::new()),
797 ColumnSelector::Indices(vec![0]),
798 ),
799 (
800 "m".into(),
801 Box::new(MinMaxScaler::<f64>::new()),
802 ColumnSelector::Indices(vec![1, 2]),
803 ),
804 ],
805 Remainder::Drop,
806 );
807 let fitted = ct.fit(&x, &()).unwrap();
808 let out = fitted.transform(&x).unwrap();
809 assert_eq!(out.shape(), &[2, 3]);
810 }
811
812 // -----------------------------------------------------------------------
813 // 12. Transform shape — partial selection + passthrough
814 // -----------------------------------------------------------------------
815
816 #[test]
817 fn test_output_shape_partial_passthrough() {
818 // 5-column input, transform 2 cols, passthrough 3
819 let x =
820 Array2::<f64>::from_shape_vec((3, 5), (1..=15).map(|v| v as f64).collect()).unwrap();
821 let ct = ColumnTransformer::new(
822 vec![(
823 "std".into(),
824 Box::new(StandardScaler::<f64>::new()),
825 ColumnSelector::Indices(vec![0, 1]),
826 )],
827 Remainder::Passthrough,
828 );
829 let fitted = ct.fit(&x, &()).unwrap();
830 let out = fitted.transform(&x).unwrap();
831 assert_eq!(out.shape(), &[3, 5]);
832 }
833
834 // -----------------------------------------------------------------------
835 // 13. n_features_in accessor
836 // -----------------------------------------------------------------------
837
838 #[test]
839 fn test_n_features_in() {
840 let x = make_x(); // 4×4
841 let ct = ColumnTransformer::new(
842 vec![(
843 "std".into(),
844 Box::new(StandardScaler::<f64>::new()),
845 ColumnSelector::Indices(vec![0]),
846 )],
847 Remainder::Drop,
848 );
849 let fitted = ct.fit(&x, &()).unwrap();
850 assert_eq!(fitted.n_features_in(), 4);
851 }
852
853 // -----------------------------------------------------------------------
854 // 14. Shape mismatch on transform (wrong number of columns)
855 // -----------------------------------------------------------------------
856
857 #[test]
858 fn test_shape_mismatch_on_transform() {
859 let x = make_x(); // 4×4
860 let ct = ColumnTransformer::new(
861 vec![(
862 "std".into(),
863 Box::new(StandardScaler::<f64>::new()),
864 ColumnSelector::Indices(vec![0, 1]),
865 )],
866 Remainder::Drop,
867 );
868 let fitted = ct.fit(&x, &()).unwrap();
869
870 // Now pass a matrix with only 2 columns — should fail.
871 let x_bad = array![[1.0_f64, 2.0], [3.0, 4.0]];
872 let result = fitted.transform(&x_bad);
873 assert!(result.is_err(), "expected shape mismatch error");
874 }
875
876 // -----------------------------------------------------------------------
877 // 15. remainder_indices accessor
878 // -----------------------------------------------------------------------
879
880 #[test]
881 fn test_remainder_indices_accessor() {
882 let x = make_x(); // 4×4
883 let ct = ColumnTransformer::new(
884 vec![(
885 "std".into(),
886 Box::new(StandardScaler::<f64>::new()),
887 ColumnSelector::Indices(vec![0, 2]),
888 )],
889 Remainder::Passthrough,
890 );
891 let fitted = ct.fit(&x, &()).unwrap();
892 // Remainder should be cols 1 and 3.
893 assert_eq!(fitted.remainder_indices(), &[1, 3]);
894 }
895
896 // -----------------------------------------------------------------------
897 // 16. StandardScaler output values are correct (zero-mean)
898 // -----------------------------------------------------------------------
899
900 #[test]
901 fn test_standard_scaler_zero_mean_in_output() {
902 let x = array![[1.0_f64, 100.0, 0.5], [2.0, 200.0, 1.5], [3.0, 300.0, 2.5],];
903 let ct = ColumnTransformer::new(
904 vec![(
905 "std".into(),
906 Box::new(StandardScaler::<f64>::new()),
907 ColumnSelector::Indices(vec![0, 1]),
908 )],
909 Remainder::Drop,
910 );
911 let fitted = ct.fit(&x, &()).unwrap();
912 let out = fitted.transform(&x).unwrap();
913
914 // Cols 0 and 1 of output should have mean ≈ 0.
915 for j in 0..2 {
916 let mean: f64 = out.column(j).iter().sum::<f64>() / 3.0;
917 assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-10);
918 }
919 }
920
921 // -----------------------------------------------------------------------
922 // 17. MinMaxScaler output values are in [0, 1]
923 // -----------------------------------------------------------------------
924
925 #[test]
926 fn test_min_max_values_in_range() {
927 let x = make_x();
928 let ct = ColumnTransformer::new(
929 vec![(
930 "mm".into(),
931 Box::new(MinMaxScaler::<f64>::new()),
932 ColumnSelector::Indices(vec![0, 1, 2, 3]),
933 )],
934 Remainder::Drop,
935 );
936 let fitted = ct.fit(&x, &()).unwrap();
937 let out = fitted.transform(&x).unwrap();
938
939 for j in 0..4 {
940 let col_min = out.column(j).iter().copied().fold(f64::INFINITY, f64::min);
941 let col_max = out
942 .column(j)
943 .iter()
944 .copied()
945 .fold(f64::NEG_INFINITY, f64::max);
946 assert_abs_diff_eq!(col_min, 0.0, epsilon = 1e-10);
947 assert_abs_diff_eq!(col_max, 1.0, epsilon = 1e-10);
948 }
949 }
950
951 // -----------------------------------------------------------------------
952 // 18. Pipeline transformer interface (fit_pipeline / transform_pipeline)
953 // -----------------------------------------------------------------------
954
955 #[test]
956 fn test_pipeline_transformer_interface() {
957 let x = make_x();
958 let y = Array1::<f64>::zeros(4);
959 let ct = ColumnTransformer::new(
960 vec![(
961 "std".into(),
962 Box::new(StandardScaler::<f64>::new()),
963 ColumnSelector::Indices(vec![0, 1]),
964 )],
965 Remainder::Passthrough,
966 );
967 let fitted_box = ct.fit_pipeline(&x, &y).unwrap();
968 let out = fitted_box.transform_pipeline(&x).unwrap();
969 assert_eq!(out.nrows(), 4);
970 assert_eq!(out.ncols(), 4);
971 }
972
973 // -----------------------------------------------------------------------
974 // 19. Remainder passthrough values are identical to input values
975 // -----------------------------------------------------------------------
976
977 #[test]
978 fn test_passthrough_values_are_exact() {
979 let x = array![[10.0_f64, 20.0, 30.0], [40.0, 50.0, 60.0],];
980 // Only transform col 0; cols 1 and 2 pass through.
981 let ct = ColumnTransformer::new(
982 vec![(
983 "mm".into(),
984 Box::new(MinMaxScaler::<f64>::new()),
985 ColumnSelector::Indices(vec![0]),
986 )],
987 Remainder::Passthrough,
988 );
989 let fitted = ct.fit(&x, &()).unwrap();
990 let out = fitted.transform(&x).unwrap();
991 // out[:, 1] == x[:, 1] and out[:, 2] == x[:, 2]
992 assert_abs_diff_eq!(out[[0, 1]], 20.0, epsilon = 1e-12);
993 assert_abs_diff_eq!(out[[1, 1]], 50.0, epsilon = 1e-12);
994 assert_abs_diff_eq!(out[[0, 2]], 30.0, epsilon = 1e-12);
995 assert_abs_diff_eq!(out[[1, 2]], 60.0, epsilon = 1e-12);
996 }
997
998 // -----------------------------------------------------------------------
999 // 20. Transformer names from explicit ColumnTransformer::new
1000 // -----------------------------------------------------------------------
1001
1002 #[test]
1003 fn test_transformer_names_explicit() {
1004 let x = make_x();
1005 let ct = ColumnTransformer::new(
1006 vec![
1007 (
1008 "alpha".into(),
1009 Box::new(StandardScaler::<f64>::new()),
1010 ColumnSelector::Indices(vec![0]),
1011 ),
1012 (
1013 "beta".into(),
1014 Box::new(MinMaxScaler::<f64>::new()),
1015 ColumnSelector::Indices(vec![1]),
1016 ),
1017 ],
1018 Remainder::Drop,
1019 );
1020 let fitted = ct.fit(&x, &()).unwrap();
1021 assert_eq!(fitted.transformer_names(), vec!["alpha", "beta"]);
1022 }
1023
1024 // -----------------------------------------------------------------------
1025 // 21. make_column_transformer with single step
1026 // -----------------------------------------------------------------------
1027
1028 #[test]
1029 fn test_make_column_transformer_single() {
1030 let x = array![[1.0_f64, 2.0], [3.0, 4.0]];
1031 let ct = make_column_transformer(
1032 vec![(
1033 Box::new(StandardScaler::<f64>::new()),
1034 ColumnSelector::Indices(vec![0, 1]),
1035 )],
1036 Remainder::Drop,
1037 );
1038 let fitted = ct.fit(&x, &()).unwrap();
1039 assert_eq!(fitted.transformer_names(), vec!["transformer-0"]);
1040 let out = fitted.transform(&x).unwrap();
1041 assert_eq!(out.shape(), &[2, 2]);
1042 }
1043
1044 // -----------------------------------------------------------------------
1045 // 22. Edge case: all columns as remainder with Passthrough
1046 // -----------------------------------------------------------------------
1047
1048 #[test]
1049 fn test_all_remainder_passthrough_unchanged() {
1050 let x = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
1051 let ct = ColumnTransformer::new(vec![], Remainder::Passthrough);
1052 let fitted = ct.fit(&x, &()).unwrap();
1053 let out = fitted.transform(&x).unwrap();
1054 assert_eq!(out.shape(), &[2, 3]);
1055 for i in 0..2 {
1056 for j in 0..3 {
1057 assert_abs_diff_eq!(out[[i, j]], x[[i, j]], epsilon = 1e-12);
1058 }
1059 }
1060 }
1061}