ferrolearn_preprocess/column_transformer.rs
1//! Column transformer: apply different transformers to different column subsets.
2//!
3//! [`ColumnTransformer`] applies each registered transformer to its designated
4//! column subset, then horizontally concatenates the outputs into a single
5//! `Array2<f64>`. Columns not captured by any transformer can be dropped or
6//! passed through unchanged via the [`Remainder`] policy.
7//!
8//! # Examples
9//!
10//! ```
11//! use ferrolearn_preprocess::column_transformer::{
12//! ColumnSelector, ColumnTransformer, Remainder,
13//! };
14//! use ferrolearn_preprocess::{StandardScaler, MinMaxScaler};
15//! use ferrolearn_core::Fit;
16//! use ferrolearn_core::Transform;
17//! use ndarray::array;
18//!
19//! let x = array![
20//! [1.0_f64, 2.0, 10.0, 20.0],
21//! [3.0, 4.0, 30.0, 40.0],
22//! [5.0, 6.0, 50.0, 60.0],
23//! ];
24//!
25//! let ct = ColumnTransformer::new(
26//! vec![
27//! ("std".into(), Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1])),
28//! ("mm".into(), Box::new(MinMaxScaler::<f64>::new()), ColumnSelector::Indices(vec![2, 3])),
29//! ],
30//! Remainder::Drop,
31//! );
32//!
33//! let fitted = ct.fit(&x, &()).unwrap();
34//! let out = fitted.transform(&x).unwrap();
35//! assert_eq!(out.ncols(), 4);
36//! assert_eq!(out.nrows(), 3);
37//! ```
38
39use ferrolearn_core::error::FerroError;
40use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
41use ferrolearn_core::traits::{Fit, Transform};
42use ndarray::{Array1, Array2};
43
44// ---------------------------------------------------------------------------
45// ColumnSelector
46// ---------------------------------------------------------------------------
47
/// Specifies which columns of the input matrix a transformer operates on.
///
/// [`Indices`](ColumnSelector::Indices) is currently the only variant; it
/// selects columns by their zero-based integer positions.
///
/// Selectors compare equal when they hold the same indices in the same order.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ColumnSelector {
    /// Select columns by zero-based index.
    ///
    /// The indices may be given in any order, but every index must be
    /// strictly less than the number of columns in the input matrix; this is
    /// checked when the selector is resolved during fitting. Duplicate
    /// indices are allowed: the same column then appears more than once in
    /// the sub-matrix passed to the transformer.
    Indices(Vec<usize>),
}
62
63impl ColumnSelector {
64 /// Resolve the selector to a concrete list of column indices.
65 ///
66 /// # Errors
67 ///
68 /// Returns [`FerroError::InvalidParameter`] if any index is out of range
69 /// (i.e., `>= n_features`).
70 fn resolve(&self, n_features: usize) -> Result<Vec<usize>, FerroError> {
71 match self {
72 ColumnSelector::Indices(indices) => {
73 for &idx in indices {
74 if idx >= n_features {
75 return Err(FerroError::InvalidParameter {
76 name: "ColumnSelector::Indices".into(),
77 reason: format!(
78 "column index {idx} is out of range for input with {n_features} features"
79 ),
80 });
81 }
82 }
83 Ok(indices.clone())
84 }
85 }
86 }
87}
88
89// ---------------------------------------------------------------------------
90// Remainder
91// ---------------------------------------------------------------------------
92
/// Policy for columns that are not selected by any transformer.
///
/// When at least one input column is not covered by any registered
/// transformer, `Remainder` decides what happens to those columns in the
/// output.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Remainder {
    /// Discard remainder columns — they do not appear in the output.
    Drop,
    /// Pass remainder columns through unchanged, appended after all
    /// transformer outputs.
    Passthrough,
}
105
106// ---------------------------------------------------------------------------
107// Helper: extract a sub-matrix by column indices
108// ---------------------------------------------------------------------------
109
110/// Build a new `Array2<f64>` containing only the columns at `indices`.
111///
112/// Columns are emitted in the order they appear in `indices`.
113fn select_columns(x: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
114 let nrows = x.nrows();
115 let ncols = indices.len();
116 if ncols == 0 {
117 return Array2::zeros((nrows, 0));
118 }
119 let mut out = Array2::zeros((nrows, ncols));
120 for (new_j, &old_j) in indices.iter().enumerate() {
121 out.column_mut(new_j).assign(&x.column(old_j));
122 }
123 out
124}
125
126/// Horizontally concatenate a slice of `Array2<f64>` matrices.
127///
128/// All matrices must have the same number of rows.
129///
130/// # Errors
131///
132/// Returns [`FerroError::ShapeMismatch`] if row counts differ.
133fn hstack(matrices: &[Array2<f64>]) -> Result<Array2<f64>, FerroError> {
134 if matrices.is_empty() {
135 return Ok(Array2::zeros((0, 0)));
136 }
137 let nrows = matrices[0].nrows();
138 let total_cols: usize = matrices.iter().map(|m| m.ncols()).sum();
139
140 // Handle the case where the first matrix establishes nrows = 0 separately.
141 if total_cols == 0 {
142 return Ok(Array2::zeros((nrows, 0)));
143 }
144
145 let mut out = Array2::zeros((nrows, total_cols));
146 let mut col_offset = 0;
147 for m in matrices {
148 if m.nrows() != nrows {
149 return Err(FerroError::ShapeMismatch {
150 expected: vec![nrows, m.ncols()],
151 actual: vec![m.nrows(), m.ncols()],
152 context: "ColumnTransformer hstack: row count mismatch".into(),
153 });
154 }
155 let end = col_offset + m.ncols();
156 if m.ncols() > 0 {
157 out.slice_mut(ndarray::s![.., col_offset..end]).assign(m);
158 }
159 col_offset = end;
160 }
161 Ok(out)
162}
163
164// ---------------------------------------------------------------------------
165// ColumnTransformer (unfitted)
166// ---------------------------------------------------------------------------
167
168/// An unfitted column transformer.
169///
170/// Applies each registered transformer to its designated column subset, then
171/// horizontally concatenates all outputs. The [`Remainder`] policy controls
172/// what happens to columns not covered by any transformer.
173///
174/// # Transformer order
175///
176/// Transformers are applied and their outputs concatenated in the order they
177/// were registered. Remainder columns (when
178/// `remainder = `[`Remainder::Passthrough`]) are appended last.
179///
180/// # Overlapping selections
181///
182/// Each transformer receives its own copy of the selected columns, so
183/// overlapping `ColumnSelector`s are fully supported.
184///
185/// # Examples
186///
187/// ```
188/// use ferrolearn_preprocess::column_transformer::{
189/// ColumnSelector, ColumnTransformer, Remainder,
190/// };
191/// use ferrolearn_preprocess::StandardScaler;
192/// use ferrolearn_core::Fit;
193/// use ferrolearn_core::Transform;
194/// use ndarray::array;
195///
196/// let x = array![[1.0_f64, 10.0, 100.0], [2.0, 20.0, 200.0], [3.0, 30.0, 300.0]];
197/// let ct = ColumnTransformer::new(
198/// vec![("scaler".into(), Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1]))],
199/// Remainder::Passthrough,
200/// );
201/// let fitted = ct.fit(&x, &()).unwrap();
202/// let out = fitted.transform(&x).unwrap();
203/// // 2 scaled columns + 1 passthrough column
204/// assert_eq!(out.ncols(), 3);
205/// ```
206pub struct ColumnTransformer {
207 /// Named transformer steps with their column selectors.
208 transformers: Vec<(String, Box<dyn PipelineTransformer<f64>>, ColumnSelector)>,
209 /// Policy for columns not covered by any transformer.
210 remainder: Remainder,
211}
212
213impl ColumnTransformer {
214 /// Create a new `ColumnTransformer`.
215 ///
216 /// # Parameters
217 ///
218 /// - `transformers`: A list of `(name, transformer, selector)` triples.
219 /// - `remainder`: Policy for uncovered columns (`Drop` or `Passthrough`).
220 #[must_use]
221 pub fn new(
222 transformers: Vec<(String, Box<dyn PipelineTransformer<f64>>, ColumnSelector)>,
223 remainder: Remainder,
224 ) -> Self {
225 Self {
226 transformers,
227 remainder,
228 }
229 }
230}
231
232// ---------------------------------------------------------------------------
233// Fit implementation
234// ---------------------------------------------------------------------------
235
236impl Fit<Array2<f64>, ()> for ColumnTransformer {
237 type Fitted = FittedColumnTransformer;
238 type Error = FerroError;
239
240 /// Fit each transformer on its selected column subset.
241 ///
242 /// Validates that all selected column indices are within bounds before
243 /// fitting any transformer.
244 ///
245 /// # Errors
246 ///
247 /// - [`FerroError::InvalidParameter`] if any column index is out of range.
248 /// - Propagates any error returned by an individual transformer's
249 /// `fit_pipeline` call.
250 fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedColumnTransformer, FerroError> {
251 let n_features = x.ncols();
252 let n_rows = x.nrows();
253
254 // A dummy y vector required by PipelineTransformer::fit_pipeline.
255 let dummy_y = Array1::<f64>::zeros(n_rows);
256
257 // Resolve all selectors up front to validate indices eagerly.
258 let mut resolved_selectors: Vec<Vec<usize>> = Vec::with_capacity(self.transformers.len());
259 for (name, _, selector) in &self.transformers {
260 let indices = selector.resolve(n_features).map_err(|e| {
261 // Enrich the error with the transformer name.
262 FerroError::InvalidParameter {
263 name: format!("ColumnTransformer step '{name}'"),
264 reason: e.to_string(),
265 }
266 })?;
267 resolved_selectors.push(indices);
268 }
269
270 // Build the set of covered column indices (for remainder computation).
271 let covered: std::collections::HashSet<usize> = resolved_selectors
272 .iter()
273 .flat_map(|v| v.iter().copied())
274 .collect();
275
276 let remainder_indices: Vec<usize> =
277 (0..n_features).filter(|c| !covered.contains(c)).collect();
278
279 // Fit each transformer on its sub-matrix.
280 let mut fitted_transformers: Vec<FittedSubTransformer> =
281 Vec::with_capacity(self.transformers.len());
282
283 for ((name, transformer, _), indices) in
284 self.transformers.iter().zip(resolved_selectors.into_iter())
285 {
286 let sub_x = select_columns(x, &indices);
287 let fitted = transformer.fit_pipeline(&sub_x, &dummy_y)?;
288 fitted_transformers.push((name.clone(), fitted, indices));
289 }
290
291 Ok(FittedColumnTransformer {
292 fitted_transformers,
293 remainder: self.remainder.clone(),
294 remainder_indices,
295 n_features_in: n_features,
296 })
297 }
298}
299
300// ---------------------------------------------------------------------------
301// PipelineTransformer implementation
302// ---------------------------------------------------------------------------
303
304impl PipelineTransformer<f64> for ColumnTransformer {
305 /// Fit the column transformer using the pipeline interface.
306 ///
307 /// The `y` argument is ignored; it exists only for API compatibility.
308 ///
309 /// # Errors
310 ///
311 /// Propagates errors from [`Fit::fit`].
312 fn fit_pipeline(
313 &self,
314 x: &Array2<f64>,
315 _y: &Array1<f64>,
316 ) -> Result<Box<dyn FittedPipelineTransformer<f64>>, FerroError> {
317 let fitted = self.fit(x, &())?;
318 Ok(Box::new(fitted))
319 }
320}
321
322// ---------------------------------------------------------------------------
323// FittedColumnTransformer
324// ---------------------------------------------------------------------------
325
326/// A named, fitted sub-transformer with its column indices.
327type FittedSubTransformer = (String, Box<dyn FittedPipelineTransformer<f64>>, Vec<usize>);
328
329/// A fitted column transformer holding fitted sub-transformers and metadata.
330///
331/// Created by calling [`Fit::fit`] on a [`ColumnTransformer`].
332/// Implements [`Transform<Array2<f64>>`] to apply the fitted transformers and
333/// concatenate their outputs, as well as [`FittedPipelineTransformer`] for use
334/// inside a [`ferrolearn_core::pipeline::Pipeline`].
335pub struct FittedColumnTransformer {
336 /// Fitted transformers with their associated column indices.
337 fitted_transformers: Vec<FittedSubTransformer>,
338 /// Remainder policy from the original [`ColumnTransformer`].
339 remainder: Remainder,
340 /// Column indices not covered by any transformer.
341 remainder_indices: Vec<usize>,
342 /// Number of input features seen during fitting.
343 n_features_in: usize,
344}
345
346impl FittedColumnTransformer {
347 /// Return the number of input features seen during fitting.
348 #[must_use]
349 pub fn n_features_in(&self) -> usize {
350 self.n_features_in
351 }
352
353 /// Return the names of all registered transformer steps.
354 #[must_use]
355 pub fn transformer_names(&self) -> Vec<&str> {
356 self.fitted_transformers
357 .iter()
358 .map(|(name, _, _)| name.as_str())
359 .collect()
360 }
361
362 /// Return the remainder column indices (columns not selected by any transformer).
363 #[must_use]
364 pub fn remainder_indices(&self) -> &[usize] {
365 &self.remainder_indices
366 }
367}
368
369// ---------------------------------------------------------------------------
370// Transform implementation
371// ---------------------------------------------------------------------------
372
373impl Transform<Array2<f64>> for FittedColumnTransformer {
374 type Output = Array2<f64>;
375 type Error = FerroError;
376
377 /// Transform data by applying each fitted transformer to its column subset,
378 /// then horizontally concatenating all outputs.
379 ///
380 /// When `remainder = Passthrough`, the unselected columns are appended
381 /// after all transformer outputs. When `remainder = Drop`, they are
382 /// discarded.
383 ///
384 /// # Errors
385 ///
386 /// - [`FerroError::ShapeMismatch`] if the input does not have
387 /// `n_features_in` columns.
388 /// - Propagates any error from individual transformer `transform_pipeline`
389 /// calls.
390 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
391 if x.ncols() != self.n_features_in {
392 return Err(FerroError::ShapeMismatch {
393 expected: vec![x.nrows(), self.n_features_in],
394 actual: vec![x.nrows(), x.ncols()],
395 context: "FittedColumnTransformer::transform".into(),
396 });
397 }
398
399 let mut parts: Vec<Array2<f64>> = Vec::with_capacity(self.fitted_transformers.len() + 1);
400
401 for (_, fitted, indices) in &self.fitted_transformers {
402 let sub_x = select_columns(x, indices);
403 let transformed = fitted.transform_pipeline(&sub_x)?;
404 parts.push(transformed);
405 }
406
407 // Append remainder columns if requested.
408 if matches!(self.remainder, Remainder::Passthrough) && !self.remainder_indices.is_empty() {
409 let remainder_sub = select_columns(x, &self.remainder_indices);
410 parts.push(remainder_sub);
411 }
412
413 hstack(&parts)
414 }
415}
416
417// ---------------------------------------------------------------------------
418// FittedPipelineTransformer implementation
419// ---------------------------------------------------------------------------
420
421impl FittedPipelineTransformer<f64> for FittedColumnTransformer {
422 /// Transform data using the pipeline interface.
423 ///
424 /// # Errors
425 ///
426 /// Propagates errors from [`Transform::transform`].
427 fn transform_pipeline(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
428 self.transform(x)
429 }
430}
431
432// ---------------------------------------------------------------------------
433// make_column_transformer convenience function
434// ---------------------------------------------------------------------------
435
436/// Convenience function to build a [`ColumnTransformer`] with auto-generated
437/// step names.
438///
439/// Steps are named `"transformer-0"`, `"transformer-1"`, etc.
440///
441/// # Parameters
442///
443/// - `transformers`: A list of `(transformer, selector)` pairs.
444/// - `remainder`: Policy for uncovered columns (`Drop` or `Passthrough`).
445///
446/// # Examples
447///
448/// ```
449/// use ferrolearn_preprocess::column_transformer::{
450/// make_column_transformer, ColumnSelector, Remainder,
451/// };
452/// use ferrolearn_preprocess::StandardScaler;
453/// use ferrolearn_core::Fit;
454/// use ferrolearn_core::Transform;
455/// use ndarray::array;
456///
457/// let x = array![[1.0_f64, 10.0], [2.0, 20.0], [3.0, 30.0]];
458/// let ct = make_column_transformer(
459/// vec![(Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1]))],
460/// Remainder::Drop,
461/// );
462/// let fitted = ct.fit(&x, &()).unwrap();
463/// let out = fitted.transform(&x).unwrap();
464/// assert_eq!(out.ncols(), 2);
465/// ```
466#[must_use]
467pub fn make_column_transformer(
468 transformers: Vec<(Box<dyn PipelineTransformer<f64>>, ColumnSelector)>,
469 remainder: Remainder,
470) -> ColumnTransformer {
471 let named: Vec<(String, Box<dyn PipelineTransformer<f64>>, ColumnSelector)> = transformers
472 .into_iter()
473 .enumerate()
474 .map(|(i, (t, s))| (format!("transformer-{i}"), t, s))
475 .collect();
476 ColumnTransformer::new(named, remainder)
477}
478
479// ---------------------------------------------------------------------------
480// Tests
481// ---------------------------------------------------------------------------
482
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ferrolearn_core::pipeline::{FittedPipelineEstimator, Pipeline, PipelineEstimator};
    use ndarray::{Array2, array};

    use crate::{MinMaxScaler, StandardScaler};

    // -----------------------------------------------------------------------
    // Helpers
    // -----------------------------------------------------------------------

    /// A named `(name, transformer, selector)` step as accepted by
    /// `ColumnTransformer::new`.
    type Step = (String, Box<dyn PipelineTransformer<f64>>, ColumnSelector);

    /// Build a simple 4-column test matrix (rows = 4, cols = 4).
    fn make_x() -> Array2<f64> {
        array![
            [1.0, 2.0, 10.0, 20.0],
            [2.0, 4.0, 20.0, 40.0],
            [3.0, 6.0, 30.0, 60.0],
            [4.0, 8.0, 40.0, 80.0],
        ]
    }

    /// A `StandardScaler` step called `name` operating on `cols`.
    fn std_step(name: &str, cols: &[usize]) -> Step {
        (
            name.to_owned(),
            Box::new(StandardScaler::<f64>::new()),
            ColumnSelector::Indices(cols.to_vec()),
        )
    }

    /// A `MinMaxScaler` step called `name` operating on `cols`.
    fn mm_step(name: &str, cols: &[usize]) -> Step {
        (
            name.to_owned(),
            Box::new(MinMaxScaler::<f64>::new()),
            ColumnSelector::Indices(cols.to_vec()),
        )
    }

    /// Fit `ct` on `x` and transform the same `x`, panicking on any error.
    fn fit_transform(ct: &ColumnTransformer, x: &Array2<f64>) -> Array2<f64> {
        ct.fit(x, &()).unwrap().transform(x).unwrap()
    }

    // -----------------------------------------------------------------------
    // 1. Basic 2-transformer usage
    // -----------------------------------------------------------------------

    #[test]
    fn test_basic_two_transformers_drop_remainder() {
        let x = make_x(); // 4×4
        let ct = ColumnTransformer::new(
            vec![std_step("std", &[0, 1]), mm_step("mm", &[2, 3])],
            Remainder::Drop,
        );

        let out = fit_transform(&ct, &x);

        // All 4 columns covered → no remainder; output is 4 cols
        assert_eq!(out.nrows(), 4);
        assert_eq!(out.ncols(), 4);
    }

    // -----------------------------------------------------------------------
    // 2. Remainder::Drop drops uncovered columns
    // -----------------------------------------------------------------------

    #[test]
    fn test_remainder_drop() {
        let x = make_x(); // 4×4
        // Only cover cols 0 and 1 — cols 2 and 3 should be dropped.
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 1])], Remainder::Drop);

        let out = fit_transform(&ct, &x);

        assert_eq!(out.nrows(), 4);
        assert_eq!(out.ncols(), 2, "uncovered cols should be dropped");
    }

    // -----------------------------------------------------------------------
    // 3. Remainder::Passthrough passes uncovered columns through unchanged
    // -----------------------------------------------------------------------

    #[test]
    fn test_remainder_passthrough() {
        let x = make_x(); // 4×4
        // Only cover cols 0 and 1 — cols 2 and 3 should pass through.
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 1])], Remainder::Passthrough);

        let out = fit_transform(&ct, &x);

        assert_eq!(out.nrows(), 4);
        assert_eq!(out.ncols(), 4, "passthrough: 2 transformed + 2 remainder");

        // The last 2 columns should be the original cols 2 and 3.
        for i in 0..4 {
            assert_abs_diff_eq!(out[[i, 2]], x[[i, 2]], epsilon = 1e-12);
            assert_abs_diff_eq!(out[[i, 3]], x[[i, 3]], epsilon = 1e-12);
        }
    }

    // -----------------------------------------------------------------------
    // 4. Invalid column index (out of range)
    // -----------------------------------------------------------------------

    #[test]
    fn test_invalid_column_index_out_of_range() {
        let x = make_x(); // 4×4 — valid indices are 0..=3
        // 99 is out of range.
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 99])], Remainder::Drop);

        let result = ct.fit(&x, &());

        // Pin the error kind, not merely the fact that something failed.
        assert!(
            matches!(result, Err(FerroError::InvalidParameter { .. })),
            "expected InvalidParameter for out-of-range index"
        );
    }

    // -----------------------------------------------------------------------
    // 5. Empty transformer list with Remainder::Drop
    // -----------------------------------------------------------------------

    #[test]
    fn test_empty_transformer_list_drop() {
        let x = make_x();
        let ct = ColumnTransformer::new(vec![], Remainder::Drop);

        let out = fit_transform(&ct, &x);

        // No transformers, remainder dropped → empty output
        assert_eq!(out.nrows(), 0, "hstack of nothing with no passthrough");
    }

    // -----------------------------------------------------------------------
    // 6. Empty transformer list with Remainder::Passthrough
    // -----------------------------------------------------------------------

    #[test]
    fn test_empty_transformer_list_passthrough() {
        let x = make_x(); // 4×4
        let ct = ColumnTransformer::new(vec![], Remainder::Passthrough);

        let out = fit_transform(&ct, &x);

        // No transformers, all columns pass through unchanged.
        assert_eq!(out.nrows(), 4);
        assert_eq!(out.ncols(), 4);
        for i in 0..4 {
            for j in 0..4 {
                assert_abs_diff_eq!(out[[i, j]], x[[i, j]], epsilon = 1e-12);
            }
        }
    }

    // -----------------------------------------------------------------------
    // 7. Overlapping column selections
    // -----------------------------------------------------------------------

    #[test]
    fn test_overlapping_column_selections() {
        let x = make_x(); // 4×4
        // Both transformers select col 0 (overlapping is allowed).
        let ct = ColumnTransformer::new(
            vec![std_step("std1", &[0, 1]), mm_step("mm1", &[0, 2])],
            Remainder::Drop,
        );

        let out = fit_transform(&ct, &x);

        // Output: 2 cols from std1 + 2 cols from mm1 = 4 cols
        assert_eq!(out.nrows(), 4);
        assert_eq!(out.ncols(), 4);
    }

    // -----------------------------------------------------------------------
    // 8. Single transformer
    // -----------------------------------------------------------------------

    #[test]
    fn test_single_transformer() {
        let x = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let ct = ColumnTransformer::new(vec![mm_step("mm", &[0, 1])], Remainder::Drop);

        let out = fit_transform(&ct, &x);

        assert_eq!(out.nrows(), 3);
        assert_eq!(out.ncols(), 2);

        // MinMax on cols 0 and 1: first row → 0.0, last row → 1.0
        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[2, 0]], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[0, 1]], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(out[[2, 1]], 1.0, epsilon = 1e-10);
    }

    // -----------------------------------------------------------------------
    // 9. make_column_transformer convenience function
    // -----------------------------------------------------------------------

    #[test]
    fn test_make_column_transformer_auto_names() {
        let x = make_x();
        let ct = make_column_transformer(
            vec![
                (
                    Box::new(StandardScaler::<f64>::new()),
                    ColumnSelector::Indices(vec![0, 1]),
                ),
                (
                    Box::new(MinMaxScaler::<f64>::new()),
                    ColumnSelector::Indices(vec![2, 3]),
                ),
            ],
            Remainder::Drop,
        );

        let fitted = ct.fit(&x, &()).unwrap();
        assert_eq!(
            fitted.transformer_names(),
            vec!["transformer-0", "transformer-1"]
        );

        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.nrows(), 4);
        assert_eq!(out.ncols(), 4);
    }

    // -----------------------------------------------------------------------
    // 10. Pipeline integration
    // -----------------------------------------------------------------------

    #[test]
    fn test_pipeline_integration() {
        use ferrolearn_core::Fit as _;
        use ferrolearn_core::Predict as _;

        // Use a trivial estimator that sums rows.
        struct SumEstimator;
        impl PipelineEstimator<f64> for SumEstimator {
            fn fit_pipeline(
                &self,
                _x: &Array2<f64>,
                _y: &Array1<f64>,
            ) -> Result<Box<dyn FittedPipelineEstimator<f64>>, FerroError> {
                Ok(Box::new(FittedSum))
            }
        }
        struct FittedSum;
        impl FittedPipelineEstimator<f64> for FittedSum {
            fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
                let sums: Vec<f64> = x.rows().into_iter().map(|r| r.sum()).collect();
                Ok(Array1::from_vec(sums))
            }
        }

        // Wrap a ColumnTransformer as a pipeline step.
        let x = make_x();
        let y = Array1::<f64>::zeros(4);
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 1, 2, 3])], Remainder::Drop);

        let pipeline = Pipeline::new()
            .transform_step("ct", Box::new(ct))
            .estimator_step("sum", Box::new(SumEstimator));

        let fitted_pipeline = pipeline.fit(&x, &y).unwrap();
        let preds = fitted_pipeline.predict(&x).unwrap();
        assert_eq!(preds.len(), 4);
    }

    // -----------------------------------------------------------------------
    // 11. Transform shape correctness — number of output columns
    // -----------------------------------------------------------------------

    #[test]
    fn test_output_shape_all_selected_drop() {
        let x = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let ct = ColumnTransformer::new(
            vec![std_step("s", &[0]), mm_step("m", &[1, 2])],
            Remainder::Drop,
        );

        let out = fit_transform(&ct, &x);
        assert_eq!(out.shape(), &[2, 3]);
    }

    // -----------------------------------------------------------------------
    // 12. Transform shape — partial selection + passthrough
    // -----------------------------------------------------------------------

    #[test]
    fn test_output_shape_partial_passthrough() {
        // 5-column input, transform 2 cols, passthrough 3
        let x =
            Array2::<f64>::from_shape_vec((3, 5), (1..=15).map(|v| v as f64).collect()).unwrap();
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 1])], Remainder::Passthrough);

        let out = fit_transform(&ct, &x);
        assert_eq!(out.shape(), &[3, 5]);
    }

    // -----------------------------------------------------------------------
    // 13. n_features_in accessor
    // -----------------------------------------------------------------------

    #[test]
    fn test_n_features_in() {
        let x = make_x(); // 4×4
        let ct = ColumnTransformer::new(vec![std_step("std", &[0])], Remainder::Drop);

        let fitted = ct.fit(&x, &()).unwrap();
        assert_eq!(fitted.n_features_in(), 4);
    }

    // -----------------------------------------------------------------------
    // 14. Shape mismatch on transform (wrong number of columns)
    // -----------------------------------------------------------------------

    #[test]
    fn test_shape_mismatch_on_transform() {
        let x = make_x(); // 4×4
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 1])], Remainder::Drop);
        let fitted = ct.fit(&x, &()).unwrap();

        // Now pass a matrix with only 2 columns — should fail.
        let x_bad = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let result = fitted.transform(&x_bad);

        // Pin the error kind, not merely the fact that something failed.
        assert!(
            matches!(result, Err(FerroError::ShapeMismatch { .. })),
            "expected shape mismatch error"
        );
    }

    // -----------------------------------------------------------------------
    // 15. remainder_indices accessor
    // -----------------------------------------------------------------------

    #[test]
    fn test_remainder_indices_accessor() {
        let x = make_x(); // 4×4
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 2])], Remainder::Passthrough);

        let fitted = ct.fit(&x, &()).unwrap();
        // Remainder should be cols 1 and 3.
        assert_eq!(fitted.remainder_indices(), &[1, 3]);
    }

    // -----------------------------------------------------------------------
    // 16. StandardScaler output values are correct (zero-mean)
    // -----------------------------------------------------------------------

    #[test]
    fn test_standard_scaler_zero_mean_in_output() {
        let x = array![[1.0_f64, 100.0, 0.5], [2.0, 200.0, 1.5], [3.0, 300.0, 2.5]];
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 1])], Remainder::Drop);

        let out = fit_transform(&ct, &x);

        // Cols 0 and 1 of output should have mean ≈ 0.
        for j in 0..2 {
            let mean: f64 = out.column(j).iter().sum::<f64>() / 3.0;
            assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-10);
        }
    }

    // -----------------------------------------------------------------------
    // 17. MinMaxScaler output values are in [0, 1]
    // -----------------------------------------------------------------------

    #[test]
    fn test_min_max_values_in_range() {
        let x = make_x();
        let ct = ColumnTransformer::new(vec![mm_step("mm", &[0, 1, 2, 3])], Remainder::Drop);

        let out = fit_transform(&ct, &x);

        for j in 0..4 {
            let col_min = out.column(j).iter().copied().fold(f64::INFINITY, f64::min);
            let col_max = out
                .column(j)
                .iter()
                .copied()
                .fold(f64::NEG_INFINITY, f64::max);
            assert_abs_diff_eq!(col_min, 0.0, epsilon = 1e-10);
            assert_abs_diff_eq!(col_max, 1.0, epsilon = 1e-10);
        }
    }

    // -----------------------------------------------------------------------
    // 18. Pipeline transformer interface (fit_pipeline / transform_pipeline)
    // -----------------------------------------------------------------------

    #[test]
    fn test_pipeline_transformer_interface() {
        let x = make_x();
        let y = Array1::<f64>::zeros(4);
        let ct = ColumnTransformer::new(vec![std_step("std", &[0, 1])], Remainder::Passthrough);

        let fitted_box = ct.fit_pipeline(&x, &y).unwrap();
        let out = fitted_box.transform_pipeline(&x).unwrap();
        assert_eq!(out.nrows(), 4);
        assert_eq!(out.ncols(), 4);
    }

    // -----------------------------------------------------------------------
    // 19. Remainder passthrough values are identical to input values
    // -----------------------------------------------------------------------

    #[test]
    fn test_passthrough_values_are_exact() {
        let x = array![[10.0_f64, 20.0, 30.0], [40.0, 50.0, 60.0]];
        // Only transform col 0; cols 1 and 2 pass through.
        let ct = ColumnTransformer::new(vec![mm_step("mm", &[0])], Remainder::Passthrough);

        let out = fit_transform(&ct, &x);

        // out[:, 1] == x[:, 1] and out[:, 2] == x[:, 2]
        assert_abs_diff_eq!(out[[0, 1]], 20.0, epsilon = 1e-12);
        assert_abs_diff_eq!(out[[1, 1]], 50.0, epsilon = 1e-12);
        assert_abs_diff_eq!(out[[0, 2]], 30.0, epsilon = 1e-12);
        assert_abs_diff_eq!(out[[1, 2]], 60.0, epsilon = 1e-12);
    }

    // -----------------------------------------------------------------------
    // 20. Transformer names from explicit ColumnTransformer::new
    // -----------------------------------------------------------------------

    #[test]
    fn test_transformer_names_explicit() {
        let x = make_x();
        let ct = ColumnTransformer::new(
            vec![std_step("alpha", &[0]), mm_step("beta", &[1])],
            Remainder::Drop,
        );

        let fitted = ct.fit(&x, &()).unwrap();
        assert_eq!(fitted.transformer_names(), vec!["alpha", "beta"]);
    }

    // -----------------------------------------------------------------------
    // 21. make_column_transformer with single step
    // -----------------------------------------------------------------------

    #[test]
    fn test_make_column_transformer_single() {
        let x = array![[1.0_f64, 2.0], [3.0, 4.0]];
        let ct = make_column_transformer(
            vec![(
                Box::new(StandardScaler::<f64>::new()),
                ColumnSelector::Indices(vec![0, 1]),
            )],
            Remainder::Drop,
        );

        let fitted = ct.fit(&x, &()).unwrap();
        assert_eq!(fitted.transformer_names(), vec!["transformer-0"]);
        let out = fitted.transform(&x).unwrap();
        assert_eq!(out.shape(), &[2, 2]);
    }

    // -----------------------------------------------------------------------
    // 22. Edge case: all columns as remainder with Passthrough
    // -----------------------------------------------------------------------

    #[test]
    fn test_all_remainder_passthrough_unchanged() {
        let x = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let ct = ColumnTransformer::new(vec![], Remainder::Passthrough);

        let out = fit_transform(&ct, &x);

        assert_eq!(out.shape(), &[2, 3]);
        for i in 0..2 {
            for j in 0..3 {
                assert_abs_diff_eq!(out[[i, j]], x[[i, j]], epsilon = 1e-12);
            }
        }
    }
}