Skip to main content

ferrolearn_preprocess/
column_transformer.rs

1//! Column transformer: apply different transformers to different column subsets.
2//!
3//! [`ColumnTransformer`] applies each registered transformer to its designated
4//! column subset, then horizontally concatenates the outputs into a single
5//! `Array2<f64>`. Columns not captured by any transformer can be dropped or
6//! passed through unchanged via the [`Remainder`] policy.
7//!
8//! # Examples
9//!
10//! ```
11//! use ferrolearn_preprocess::column_transformer::{
12//!     ColumnSelector, ColumnTransformer, Remainder,
13//! };
14//! use ferrolearn_preprocess::{StandardScaler, MinMaxScaler};
15//! use ferrolearn_core::Fit;
16//! use ferrolearn_core::Transform;
17//! use ndarray::array;
18//!
19//! let x = array![
20//!     [1.0_f64, 2.0, 10.0, 20.0],
21//!     [3.0, 4.0, 30.0, 40.0],
22//!     [5.0, 6.0, 50.0, 60.0],
23//! ];
24//!
25//! let ct = ColumnTransformer::new(
26//!     vec![
27//!         ("std".into(),  Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1])),
28//!         ("mm".into(),   Box::new(MinMaxScaler::<f64>::new()),   ColumnSelector::Indices(vec![2, 3])),
29//!     ],
30//!     Remainder::Drop,
31//! );
32//!
33//! let fitted = ct.fit(&x, &()).unwrap();
34//! let out    = fitted.transform(&x).unwrap();
35//! assert_eq!(out.ncols(), 4);
36//! assert_eq!(out.nrows(), 3);
37//! ```
38
39use ferrolearn_core::error::FerroError;
40use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
41use ferrolearn_core::traits::{Fit, Transform};
42use ndarray::{Array1, Array2};
43
44// ---------------------------------------------------------------------------
45// ColumnSelector
46// ---------------------------------------------------------------------------
47
/// Specifies which columns a transformer should operate on.
///
/// Currently the only supported variant is [`Indices`](ColumnSelector::Indices),
/// which selects columns by their zero-based integer positions.
///
/// Two selectors compare equal when they list the same indices in the same
/// order; selectors are also hashable, so they can be stored in sets or used
/// as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ColumnSelector {
    /// Select columns by zero-based index.
    ///
    /// The indices do not need to be sorted, but every index must be strictly
    /// less than the number of columns in the input matrix. This is checked
    /// when the owning [`ColumnTransformer`] is fitted, not at construction.
    /// Duplicate indices are allowed; the same column will simply appear twice
    /// in the sub-matrix passed to the transformer.
    Indices(Vec<usize>),
}
62
63impl ColumnSelector {
64    /// Resolve the selector to a concrete list of column indices.
65    ///
66    /// # Errors
67    ///
68    /// Returns [`FerroError::InvalidParameter`] if any index is out of range
69    /// (i.e., `>= n_features`).
70    fn resolve(&self, n_features: usize) -> Result<Vec<usize>, FerroError> {
71        match self {
72            ColumnSelector::Indices(indices) => {
73                for &idx in indices {
74                    if idx >= n_features {
75                        return Err(FerroError::InvalidParameter {
76                            name: "ColumnSelector::Indices".into(),
77                            reason: format!(
78                                "column index {idx} is out of range for input with {n_features} features"
79                            ),
80                        });
81                    }
82                }
83                Ok(indices.clone())
84            }
85        }
86    }
87}
88
89// ---------------------------------------------------------------------------
90// Remainder
91// ---------------------------------------------------------------------------
92
/// Policy for columns that are not selected by any transformer.
///
/// When at least one column is not covered by any registered transformer,
/// `Remainder` determines what happens to those columns in the output.
/// Policies can be compared with `==` and hashed.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum Remainder {
    /// Discard remainder columns — they do not appear in the output.
    Drop,
    /// Pass remainder columns through unchanged, appended after all
    /// transformer outputs.
    Passthrough,
}
105
106// ---------------------------------------------------------------------------
107// Helper: extract a sub-matrix by column indices
108// ---------------------------------------------------------------------------
109
110/// Build a new `Array2<f64>` containing only the columns at `indices`.
111///
112/// Columns are emitted in the order they appear in `indices`.
113fn select_columns(x: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
114    let nrows = x.nrows();
115    let ncols = indices.len();
116    if ncols == 0 {
117        return Array2::zeros((nrows, 0));
118    }
119    let mut out = Array2::zeros((nrows, ncols));
120    for (new_j, &old_j) in indices.iter().enumerate() {
121        out.column_mut(new_j).assign(&x.column(old_j));
122    }
123    out
124}
125
126/// Horizontally concatenate a slice of `Array2<f64>` matrices.
127///
128/// All matrices must have the same number of rows.
129///
130/// # Errors
131///
132/// Returns [`FerroError::ShapeMismatch`] if row counts differ.
133fn hstack(matrices: &[Array2<f64>]) -> Result<Array2<f64>, FerroError> {
134    if matrices.is_empty() {
135        return Ok(Array2::zeros((0, 0)));
136    }
137    let nrows = matrices[0].nrows();
138    let total_cols: usize = matrices.iter().map(|m| m.ncols()).sum();
139
140    // Handle the case where the first matrix establishes nrows = 0 separately.
141    if total_cols == 0 {
142        return Ok(Array2::zeros((nrows, 0)));
143    }
144
145    let mut out = Array2::zeros((nrows, total_cols));
146    let mut col_offset = 0;
147    for m in matrices {
148        if m.nrows() != nrows {
149            return Err(FerroError::ShapeMismatch {
150                expected: vec![nrows, m.ncols()],
151                actual: vec![m.nrows(), m.ncols()],
152                context: "ColumnTransformer hstack: row count mismatch".into(),
153            });
154        }
155        let end = col_offset + m.ncols();
156        if m.ncols() > 0 {
157            out.slice_mut(ndarray::s![.., col_offset..end]).assign(m);
158        }
159        col_offset = end;
160    }
161    Ok(out)
162}
163
164// ---------------------------------------------------------------------------
165// ColumnTransformer (unfitted)
166// ---------------------------------------------------------------------------
167
168/// An unfitted column transformer.
169///
170/// Applies each registered transformer to its designated column subset, then
171/// horizontally concatenates all outputs. The [`Remainder`] policy controls
172/// what happens to columns not covered by any transformer.
173///
174/// # Transformer order
175///
176/// Transformers are applied and their outputs concatenated in the order they
177/// were registered. Remainder columns (when
178/// `remainder = `[`Remainder::Passthrough`]) are appended last.
179///
180/// # Overlapping selections
181///
182/// Each transformer receives its own copy of the selected columns, so
183/// overlapping `ColumnSelector`s are fully supported.
184///
185/// # Examples
186///
187/// ```
188/// use ferrolearn_preprocess::column_transformer::{
189///     ColumnSelector, ColumnTransformer, Remainder,
190/// };
191/// use ferrolearn_preprocess::StandardScaler;
192/// use ferrolearn_core::Fit;
193/// use ferrolearn_core::Transform;
194/// use ndarray::array;
195///
196/// let x = array![[1.0_f64, 10.0, 100.0], [2.0, 20.0, 200.0], [3.0, 30.0, 300.0]];
197/// let ct = ColumnTransformer::new(
198///     vec![("scaler".into(), Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1]))],
199///     Remainder::Passthrough,
200/// );
201/// let fitted = ct.fit(&x, &()).unwrap();
202/// let out = fitted.transform(&x).unwrap();
203/// // 2 scaled columns + 1 passthrough column
204/// assert_eq!(out.ncols(), 3);
205/// ```
206pub struct ColumnTransformer {
207    /// Named transformer steps with their column selectors.
208    transformers: Vec<(String, Box<dyn PipelineTransformer<f64>>, ColumnSelector)>,
209    /// Policy for columns not covered by any transformer.
210    remainder: Remainder,
211}
212
213impl ColumnTransformer {
214    /// Create a new `ColumnTransformer`.
215    ///
216    /// # Parameters
217    ///
218    /// - `transformers`: A list of `(name, transformer, selector)` triples.
219    /// - `remainder`: Policy for uncovered columns (`Drop` or `Passthrough`).
220    #[must_use]
221    pub fn new(
222        transformers: Vec<(String, Box<dyn PipelineTransformer<f64>>, ColumnSelector)>,
223        remainder: Remainder,
224    ) -> Self {
225        Self {
226            transformers,
227            remainder,
228        }
229    }
230}
231
232// ---------------------------------------------------------------------------
233// Fit implementation
234// ---------------------------------------------------------------------------
235
236impl Fit<Array2<f64>, ()> for ColumnTransformer {
237    type Fitted = FittedColumnTransformer;
238    type Error = FerroError;
239
240    /// Fit each transformer on its selected column subset.
241    ///
242    /// Validates that all selected column indices are within bounds before
243    /// fitting any transformer.
244    ///
245    /// # Errors
246    ///
247    /// - [`FerroError::InvalidParameter`] if any column index is out of range.
248    /// - Propagates any error returned by an individual transformer's
249    ///   `fit_pipeline` call.
250    fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedColumnTransformer, FerroError> {
251        let n_features = x.ncols();
252        let n_rows = x.nrows();
253
254        // A dummy y vector required by PipelineTransformer::fit_pipeline.
255        let dummy_y = Array1::<f64>::zeros(n_rows);
256
257        // Resolve all selectors up front to validate indices eagerly.
258        let mut resolved_selectors: Vec<Vec<usize>> = Vec::with_capacity(self.transformers.len());
259        for (name, _, selector) in &self.transformers {
260            let indices = selector.resolve(n_features).map_err(|e| {
261                // Enrich the error with the transformer name.
262                FerroError::InvalidParameter {
263                    name: format!("ColumnTransformer step '{name}'"),
264                    reason: e.to_string(),
265                }
266            })?;
267            resolved_selectors.push(indices);
268        }
269
270        // Build the set of covered column indices (for remainder computation).
271        let covered: std::collections::HashSet<usize> = resolved_selectors
272            .iter()
273            .flat_map(|v| v.iter().copied())
274            .collect();
275
276        let remainder_indices: Vec<usize> =
277            (0..n_features).filter(|c| !covered.contains(c)).collect();
278
279        // Fit each transformer on its sub-matrix.
280        let mut fitted_transformers: Vec<FittedSubTransformer> =
281            Vec::with_capacity(self.transformers.len());
282
283        for ((name, transformer, _), indices) in
284            self.transformers.iter().zip(resolved_selectors.into_iter())
285        {
286            let sub_x = select_columns(x, &indices);
287            let fitted = transformer.fit_pipeline(&sub_x, &dummy_y)?;
288            fitted_transformers.push((name.clone(), fitted, indices));
289        }
290
291        Ok(FittedColumnTransformer {
292            fitted_transformers,
293            remainder: self.remainder.clone(),
294            remainder_indices,
295            n_features_in: n_features,
296        })
297    }
298}
299
300// ---------------------------------------------------------------------------
301// PipelineTransformer implementation
302// ---------------------------------------------------------------------------
303
304impl PipelineTransformer<f64> for ColumnTransformer {
305    /// Fit the column transformer using the pipeline interface.
306    ///
307    /// The `y` argument is ignored; it exists only for API compatibility.
308    ///
309    /// # Errors
310    ///
311    /// Propagates errors from [`Fit::fit`].
312    fn fit_pipeline(
313        &self,
314        x: &Array2<f64>,
315        _y: &Array1<f64>,
316    ) -> Result<Box<dyn FittedPipelineTransformer<f64>>, FerroError> {
317        let fitted = self.fit(x, &())?;
318        Ok(Box::new(fitted))
319    }
320}
321
322// ---------------------------------------------------------------------------
323// FittedColumnTransformer
324// ---------------------------------------------------------------------------
325
326/// A named, fitted sub-transformer with its column indices.
327type FittedSubTransformer = (String, Box<dyn FittedPipelineTransformer<f64>>, Vec<usize>);
328
329/// A fitted column transformer holding fitted sub-transformers and metadata.
330///
331/// Created by calling [`Fit::fit`] on a [`ColumnTransformer`].
332/// Implements [`Transform<Array2<f64>>`] to apply the fitted transformers and
333/// concatenate their outputs, as well as [`FittedPipelineTransformer`] for use
334/// inside a [`ferrolearn_core::pipeline::Pipeline`].
335pub struct FittedColumnTransformer {
336    /// Fitted transformers with their associated column indices.
337    fitted_transformers: Vec<FittedSubTransformer>,
338    /// Remainder policy from the original [`ColumnTransformer`].
339    remainder: Remainder,
340    /// Column indices not covered by any transformer.
341    remainder_indices: Vec<usize>,
342    /// Number of input features seen during fitting.
343    n_features_in: usize,
344}
345
346impl FittedColumnTransformer {
347    /// Return the number of input features seen during fitting.
348    #[must_use]
349    pub fn n_features_in(&self) -> usize {
350        self.n_features_in
351    }
352
353    /// Return the names of all registered transformer steps.
354    #[must_use]
355    pub fn transformer_names(&self) -> Vec<&str> {
356        self.fitted_transformers
357            .iter()
358            .map(|(name, _, _)| name.as_str())
359            .collect()
360    }
361
362    /// Return the remainder column indices (columns not selected by any transformer).
363    #[must_use]
364    pub fn remainder_indices(&self) -> &[usize] {
365        &self.remainder_indices
366    }
367}
368
369// ---------------------------------------------------------------------------
370// Transform implementation
371// ---------------------------------------------------------------------------
372
373impl Transform<Array2<f64>> for FittedColumnTransformer {
374    type Output = Array2<f64>;
375    type Error = FerroError;
376
377    /// Transform data by applying each fitted transformer to its column subset,
378    /// then horizontally concatenating all outputs.
379    ///
380    /// When `remainder = Passthrough`, the unselected columns are appended
381    /// after all transformer outputs. When `remainder = Drop`, they are
382    /// discarded.
383    ///
384    /// # Errors
385    ///
386    /// - [`FerroError::ShapeMismatch`] if the input does not have
387    ///   `n_features_in` columns.
388    /// - Propagates any error from individual transformer `transform_pipeline`
389    ///   calls.
390    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
391        if x.ncols() != self.n_features_in {
392            return Err(FerroError::ShapeMismatch {
393                expected: vec![x.nrows(), self.n_features_in],
394                actual: vec![x.nrows(), x.ncols()],
395                context: "FittedColumnTransformer::transform".into(),
396            });
397        }
398
399        let mut parts: Vec<Array2<f64>> = Vec::with_capacity(self.fitted_transformers.len() + 1);
400
401        for (_, fitted, indices) in &self.fitted_transformers {
402            let sub_x = select_columns(x, indices);
403            let transformed = fitted.transform_pipeline(&sub_x)?;
404            parts.push(transformed);
405        }
406
407        // Append remainder columns if requested.
408        if matches!(self.remainder, Remainder::Passthrough) && !self.remainder_indices.is_empty() {
409            let remainder_sub = select_columns(x, &self.remainder_indices);
410            parts.push(remainder_sub);
411        }
412
413        hstack(&parts)
414    }
415}
416
417// ---------------------------------------------------------------------------
418// FittedPipelineTransformer implementation
419// ---------------------------------------------------------------------------
420
421impl FittedPipelineTransformer<f64> for FittedColumnTransformer {
422    /// Transform data using the pipeline interface.
423    ///
424    /// # Errors
425    ///
426    /// Propagates errors from [`Transform::transform`].
427    fn transform_pipeline(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
428        self.transform(x)
429    }
430}
431
432// ---------------------------------------------------------------------------
433// make_column_transformer convenience function
434// ---------------------------------------------------------------------------
435
436/// Convenience function to build a [`ColumnTransformer`] with auto-generated
437/// step names.
438///
439/// Steps are named `"transformer-0"`, `"transformer-1"`, etc.
440///
441/// # Parameters
442///
443/// - `transformers`: A list of `(transformer, selector)` pairs.
444/// - `remainder`: Policy for uncovered columns (`Drop` or `Passthrough`).
445///
446/// # Examples
447///
448/// ```
449/// use ferrolearn_preprocess::column_transformer::{
450///     make_column_transformer, ColumnSelector, Remainder,
451/// };
452/// use ferrolearn_preprocess::StandardScaler;
453/// use ferrolearn_core::Fit;
454/// use ferrolearn_core::Transform;
455/// use ndarray::array;
456///
457/// let x = array![[1.0_f64, 10.0], [2.0, 20.0], [3.0, 30.0]];
458/// let ct = make_column_transformer(
459///     vec![(Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1]))],
460///     Remainder::Drop,
461/// );
462/// let fitted = ct.fit(&x, &()).unwrap();
463/// let out = fitted.transform(&x).unwrap();
464/// assert_eq!(out.ncols(), 2);
465/// ```
466#[must_use]
467pub fn make_column_transformer(
468    transformers: Vec<(Box<dyn PipelineTransformer<f64>>, ColumnSelector)>,
469    remainder: Remainder,
470) -> ColumnTransformer {
471    let named: Vec<(String, Box<dyn PipelineTransformer<f64>>, ColumnSelector)> = transformers
472        .into_iter()
473        .enumerate()
474        .map(|(i, (t, s))| (format!("transformer-{i}"), t, s))
475        .collect();
476    ColumnTransformer::new(named, remainder)
477}
478
479// ---------------------------------------------------------------------------
480// Tests
481// ---------------------------------------------------------------------------
482
483#[cfg(test)]
484mod tests {
485    use super::*;
486    use approx::assert_abs_diff_eq;
487    use ferrolearn_core::pipeline::{Pipeline, PipelineEstimator};
488    use ndarray::{Array2, array};
489
490    use crate::{MinMaxScaler, StandardScaler};
491
492    // -----------------------------------------------------------------------
493    // Helpers
494    // -----------------------------------------------------------------------
495
496    /// Build a simple 4-column test matrix (rows = 4, cols = 4).
497    fn make_x() -> Array2<f64> {
498        array![
499            [1.0, 2.0, 10.0, 20.0],
500            [2.0, 4.0, 20.0, 40.0],
501            [3.0, 6.0, 30.0, 60.0],
502            [4.0, 8.0, 40.0, 80.0],
503        ]
504    }
505
506    // -----------------------------------------------------------------------
507    // 1. Basic 2-transformer usage
508    // -----------------------------------------------------------------------
509
510    #[test]
511    fn test_basic_two_transformers_drop_remainder() {
512        let x = make_x(); // 4×4
513        let ct = ColumnTransformer::new(
514            vec![
515                (
516                    "std".into(),
517                    Box::new(StandardScaler::<f64>::new()),
518                    ColumnSelector::Indices(vec![0, 1]),
519                ),
520                (
521                    "mm".into(),
522                    Box::new(MinMaxScaler::<f64>::new()),
523                    ColumnSelector::Indices(vec![2, 3]),
524                ),
525            ],
526            Remainder::Drop,
527        );
528
529        let fitted = ct.fit(&x, &()).unwrap();
530        let out = fitted.transform(&x).unwrap();
531
532        // All 4 columns covered → no remainder; output is 4 cols
533        assert_eq!(out.nrows(), 4);
534        assert_eq!(out.ncols(), 4);
535    }
536
537    // -----------------------------------------------------------------------
538    // 2. Remainder::Drop drops uncovered columns
539    // -----------------------------------------------------------------------
540
541    #[test]
542    fn test_remainder_drop() {
543        let x = make_x(); // 4×4
544        // Only cover cols 0 and 1 — cols 2 and 3 should be dropped.
545        let ct = ColumnTransformer::new(
546            vec![(
547                "std".into(),
548                Box::new(StandardScaler::<f64>::new()),
549                ColumnSelector::Indices(vec![0, 1]),
550            )],
551            Remainder::Drop,
552        );
553
554        let fitted = ct.fit(&x, &()).unwrap();
555        let out = fitted.transform(&x).unwrap();
556
557        assert_eq!(out.nrows(), 4);
558        assert_eq!(out.ncols(), 2, "uncovered cols should be dropped");
559    }
560
561    // -----------------------------------------------------------------------
562    // 3. Remainder::Passthrough passes uncovered columns through unchanged
563    // -----------------------------------------------------------------------
564
565    #[test]
566    fn test_remainder_passthrough() {
567        let x = make_x(); // 4×4
568        // Only cover cols 0 and 1 — cols 2 and 3 should pass through.
569        let ct = ColumnTransformer::new(
570            vec![(
571                "std".into(),
572                Box::new(StandardScaler::<f64>::new()),
573                ColumnSelector::Indices(vec![0, 1]),
574            )],
575            Remainder::Passthrough,
576        );
577
578        let fitted = ct.fit(&x, &()).unwrap();
579        let out = fitted.transform(&x).unwrap();
580
581        assert_eq!(out.nrows(), 4);
582        assert_eq!(out.ncols(), 4, "passthrough: 2 transformed + 2 remainder");
583
584        // The last 2 columns should be the original cols 2 and 3.
585        for i in 0..4 {
586            assert_abs_diff_eq!(out[[i, 2]], x[[i, 2]], epsilon = 1e-12);
587            assert_abs_diff_eq!(out[[i, 3]], x[[i, 3]], epsilon = 1e-12);
588        }
589    }
590
591    // -----------------------------------------------------------------------
592    // 4. Invalid column index (out of range)
593    // -----------------------------------------------------------------------
594
595    #[test]
596    fn test_invalid_column_index_out_of_range() {
597        let x = make_x(); // 4×4 — valid indices are 0..3
598        let ct = ColumnTransformer::new(
599            vec![(
600                "std".into(),
601                Box::new(StandardScaler::<f64>::new()),
602                ColumnSelector::Indices(vec![0, 99]), // 99 is out of range
603            )],
604            Remainder::Drop,
605        );
606        let result = ct.fit(&x, &());
607        assert!(result.is_err(), "expected error for out-of-range index");
608    }
609
610    // -----------------------------------------------------------------------
611    // 5. Empty transformer list with Remainder::Drop
612    // -----------------------------------------------------------------------
613
614    #[test]
615    fn test_empty_transformer_list_drop() {
616        let x = make_x();
617        let ct = ColumnTransformer::new(vec![], Remainder::Drop);
618        let fitted = ct.fit(&x, &()).unwrap();
619        let out = fitted.transform(&x).unwrap();
620        // No transformers, remainder dropped → empty output
621        assert_eq!(out.nrows(), 0, "hstack of nothing with no passthrough");
622    }
623
624    // -----------------------------------------------------------------------
625    // 6. Empty transformer list with Remainder::Passthrough
626    // -----------------------------------------------------------------------
627
628    #[test]
629    fn test_empty_transformer_list_passthrough() {
630        let x = make_x(); // 4×4
631        let ct = ColumnTransformer::new(vec![], Remainder::Passthrough);
632        let fitted = ct.fit(&x, &()).unwrap();
633        let out = fitted.transform(&x).unwrap();
634        // No transformers, all columns pass through unchanged.
635        assert_eq!(out.nrows(), 4);
636        assert_eq!(out.ncols(), 4);
637        for i in 0..4 {
638            for j in 0..4 {
639                assert_abs_diff_eq!(out[[i, j]], x[[i, j]], epsilon = 1e-12);
640            }
641        }
642    }
643
644    // -----------------------------------------------------------------------
645    // 7. Overlapping column selections
646    // -----------------------------------------------------------------------
647
648    #[test]
649    fn test_overlapping_column_selections() {
650        let x = make_x(); // 4×4
651        // Both transformers select col 0 (overlapping is allowed).
652        let ct = ColumnTransformer::new(
653            vec![
654                (
655                    "std1".into(),
656                    Box::new(StandardScaler::<f64>::new()),
657                    ColumnSelector::Indices(vec![0, 1]),
658                ),
659                (
660                    "mm1".into(),
661                    Box::new(MinMaxScaler::<f64>::new()),
662                    ColumnSelector::Indices(vec![0, 2]), // col 0 also used here
663                ),
664            ],
665            Remainder::Drop,
666        );
667
668        let fitted = ct.fit(&x, &()).unwrap();
669        let out = fitted.transform(&x).unwrap();
670
671        // Output: 2 cols from std1 + 2 cols from mm1 = 4 cols
672        assert_eq!(out.nrows(), 4);
673        assert_eq!(out.ncols(), 4);
674    }
675
676    // -----------------------------------------------------------------------
677    // 8. Single transformer
678    // -----------------------------------------------------------------------
679
680    #[test]
681    fn test_single_transformer() {
682        let x = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
683        let ct = ColumnTransformer::new(
684            vec![(
685                "mm".into(),
686                Box::new(MinMaxScaler::<f64>::new()),
687                ColumnSelector::Indices(vec![0, 1]),
688            )],
689            Remainder::Drop,
690        );
691
692        let fitted = ct.fit(&x, &()).unwrap();
693        let out = fitted.transform(&x).unwrap();
694
695        assert_eq!(out.nrows(), 3);
696        assert_eq!(out.ncols(), 2);
697
698        // MinMax on cols 0 and 1: first row → 0.0, last row → 1.0
699        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
700        assert_abs_diff_eq!(out[[2, 0]], 1.0, epsilon = 1e-10);
701        assert_abs_diff_eq!(out[[0, 1]], 0.0, epsilon = 1e-10);
702        assert_abs_diff_eq!(out[[2, 1]], 1.0, epsilon = 1e-10);
703    }
704
705    // -----------------------------------------------------------------------
706    // 9. make_column_transformer convenience function
707    // -----------------------------------------------------------------------
708
709    #[test]
710    fn test_make_column_transformer_auto_names() {
711        let x = make_x();
712        let ct = make_column_transformer(
713            vec![
714                (
715                    Box::new(StandardScaler::<f64>::new()),
716                    ColumnSelector::Indices(vec![0, 1]),
717                ),
718                (
719                    Box::new(MinMaxScaler::<f64>::new()),
720                    ColumnSelector::Indices(vec![2, 3]),
721                ),
722            ],
723            Remainder::Drop,
724        );
725
726        let fitted = ct.fit(&x, &()).unwrap();
727        assert_eq!(
728            fitted.transformer_names(),
729            vec!["transformer-0", "transformer-1"]
730        );
731
732        let out = fitted.transform(&x).unwrap();
733        assert_eq!(out.nrows(), 4);
734        assert_eq!(out.ncols(), 4);
735    }
736
737    // -----------------------------------------------------------------------
738    // 10. Pipeline integration
739    // -----------------------------------------------------------------------
740
741    #[test]
742    fn test_pipeline_integration() {
743        // Wrap a ColumnTransformer as a pipeline step.
744        let x = make_x();
745        let y = Array1::<f64>::zeros(4);
746
747        let ct = ColumnTransformer::new(
748            vec![(
749                "std".into(),
750                Box::new(StandardScaler::<f64>::new()),
751                ColumnSelector::Indices(vec![0, 1, 2, 3]),
752            )],
753            Remainder::Drop,
754        );
755
756        // Use a trivial estimator that sums rows.
757        struct SumEstimator;
758        impl PipelineEstimator<f64> for SumEstimator {
759            fn fit_pipeline(
760                &self,
761                _x: &Array2<f64>,
762                _y: &Array1<f64>,
763            ) -> Result<Box<dyn ferrolearn_core::pipeline::FittedPipelineEstimator<f64>>, FerroError>
764            {
765                Ok(Box::new(FittedSum))
766            }
767        }
768        struct FittedSum;
769        impl ferrolearn_core::pipeline::FittedPipelineEstimator<f64> for FittedSum {
770            fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
771                let sums: Vec<f64> = x.rows().into_iter().map(|r| r.sum()).collect();
772                Ok(Array1::from_vec(sums))
773            }
774        }
775
776        let pipeline = Pipeline::new()
777            .transform_step("ct", Box::new(ct))
778            .estimator_step("sum", Box::new(SumEstimator));
779
780        use ferrolearn_core::Fit as _;
781        let fitted_pipeline = pipeline.fit(&x, &y).unwrap();
782
783        use ferrolearn_core::Predict as _;
784        let preds = fitted_pipeline.predict(&x).unwrap();
785        assert_eq!(preds.len(), 4);
786    }
787
788    // -----------------------------------------------------------------------
789    // 11. Transform shape correctness — number of output columns
790    // -----------------------------------------------------------------------
791
792    #[test]
793    fn test_output_shape_all_selected_drop() {
794        let x = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
795        let ct = ColumnTransformer::new(
796            vec![
797                (
798                    "s".into(),
799                    Box::new(StandardScaler::<f64>::new()),
800                    ColumnSelector::Indices(vec![0]),
801                ),
802                (
803                    "m".into(),
804                    Box::new(MinMaxScaler::<f64>::new()),
805                    ColumnSelector::Indices(vec![1, 2]),
806                ),
807            ],
808            Remainder::Drop,
809        );
810        let fitted = ct.fit(&x, &()).unwrap();
811        let out = fitted.transform(&x).unwrap();
812        assert_eq!(out.shape(), &[2, 3]);
813    }
814
815    // -----------------------------------------------------------------------
816    // 12. Transform shape — partial selection + passthrough
817    // -----------------------------------------------------------------------
818
819    #[test]
820    fn test_output_shape_partial_passthrough() {
821        // 5-column input, transform 2 cols, passthrough 3
822        let x =
823            Array2::<f64>::from_shape_vec((3, 5), (1..=15).map(|v| v as f64).collect()).unwrap();
824        let ct = ColumnTransformer::new(
825            vec![(
826                "std".into(),
827                Box::new(StandardScaler::<f64>::new()),
828                ColumnSelector::Indices(vec![0, 1]),
829            )],
830            Remainder::Passthrough,
831        );
832        let fitted = ct.fit(&x, &()).unwrap();
833        let out = fitted.transform(&x).unwrap();
834        assert_eq!(out.shape(), &[3, 5]);
835    }
836
837    // -----------------------------------------------------------------------
838    // 13. n_features_in accessor
839    // -----------------------------------------------------------------------
840
841    #[test]
842    fn test_n_features_in() {
843        let x = make_x(); // 4×4
844        let ct = ColumnTransformer::new(
845            vec![(
846                "std".into(),
847                Box::new(StandardScaler::<f64>::new()),
848                ColumnSelector::Indices(vec![0]),
849            )],
850            Remainder::Drop,
851        );
852        let fitted = ct.fit(&x, &()).unwrap();
853        assert_eq!(fitted.n_features_in(), 4);
854    }
855
856    // -----------------------------------------------------------------------
857    // 14. Shape mismatch on transform (wrong number of columns)
858    // -----------------------------------------------------------------------
859
860    #[test]
861    fn test_shape_mismatch_on_transform() {
862        let x = make_x(); // 4×4
863        let ct = ColumnTransformer::new(
864            vec![(
865                "std".into(),
866                Box::new(StandardScaler::<f64>::new()),
867                ColumnSelector::Indices(vec![0, 1]),
868            )],
869            Remainder::Drop,
870        );
871        let fitted = ct.fit(&x, &()).unwrap();
872
873        // Now pass a matrix with only 2 columns — should fail.
874        let x_bad = array![[1.0_f64, 2.0], [3.0, 4.0]];
875        let result = fitted.transform(&x_bad);
876        assert!(result.is_err(), "expected shape mismatch error");
877    }
878
879    // -----------------------------------------------------------------------
880    // 15. remainder_indices accessor
881    // -----------------------------------------------------------------------
882
883    #[test]
884    fn test_remainder_indices_accessor() {
885        let x = make_x(); // 4×4
886        let ct = ColumnTransformer::new(
887            vec![(
888                "std".into(),
889                Box::new(StandardScaler::<f64>::new()),
890                ColumnSelector::Indices(vec![0, 2]),
891            )],
892            Remainder::Passthrough,
893        );
894        let fitted = ct.fit(&x, &()).unwrap();
895        // Remainder should be cols 1 and 3.
896        assert_eq!(fitted.remainder_indices(), &[1, 3]);
897    }
898
899    // -----------------------------------------------------------------------
900    // 16. StandardScaler output values are correct (zero-mean)
901    // -----------------------------------------------------------------------
902
903    #[test]
904    fn test_standard_scaler_zero_mean_in_output() {
905        let x = array![[1.0_f64, 100.0, 0.5], [2.0, 200.0, 1.5], [3.0, 300.0, 2.5],];
906        let ct = ColumnTransformer::new(
907            vec![(
908                "std".into(),
909                Box::new(StandardScaler::<f64>::new()),
910                ColumnSelector::Indices(vec![0, 1]),
911            )],
912            Remainder::Drop,
913        );
914        let fitted = ct.fit(&x, &()).unwrap();
915        let out = fitted.transform(&x).unwrap();
916
917        // Cols 0 and 1 of output should have mean ≈ 0.
918        for j in 0..2 {
919            let mean: f64 = out.column(j).iter().sum::<f64>() / 3.0;
920            assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-10);
921        }
922    }
923
924    // -----------------------------------------------------------------------
    // 17. MinMaxScaler output spans exactly [0, 1] in every column
926    // -----------------------------------------------------------------------
927
928    #[test]
929    fn test_min_max_values_in_range() {
930        let x = make_x();
931        let ct = ColumnTransformer::new(
932            vec![(
933                "mm".into(),
934                Box::new(MinMaxScaler::<f64>::new()),
935                ColumnSelector::Indices(vec![0, 1, 2, 3]),
936            )],
937            Remainder::Drop,
938        );
939        let fitted = ct.fit(&x, &()).unwrap();
940        let out = fitted.transform(&x).unwrap();
941
942        for j in 0..4 {
943            let col_min = out.column(j).iter().copied().fold(f64::INFINITY, f64::min);
944            let col_max = out
945                .column(j)
946                .iter()
947                .copied()
948                .fold(f64::NEG_INFINITY, f64::max);
949            assert_abs_diff_eq!(col_min, 0.0, epsilon = 1e-10);
950            assert_abs_diff_eq!(col_max, 1.0, epsilon = 1e-10);
951        }
952    }
953
954    // -----------------------------------------------------------------------
955    // 18. Pipeline transformer interface (fit_pipeline / transform_pipeline)
956    // -----------------------------------------------------------------------
957
958    #[test]
959    fn test_pipeline_transformer_interface() {
960        let x = make_x();
961        let y = Array1::<f64>::zeros(4);
962        let ct = ColumnTransformer::new(
963            vec![(
964                "std".into(),
965                Box::new(StandardScaler::<f64>::new()),
966                ColumnSelector::Indices(vec![0, 1]),
967            )],
968            Remainder::Passthrough,
969        );
970        let fitted_box = ct.fit_pipeline(&x, &y).unwrap();
971        let out = fitted_box.transform_pipeline(&x).unwrap();
972        assert_eq!(out.nrows(), 4);
973        assert_eq!(out.ncols(), 4);
974    }
975
976    // -----------------------------------------------------------------------
977    // 19. Remainder passthrough values are identical to input values
978    // -----------------------------------------------------------------------
979
980    #[test]
981    fn test_passthrough_values_are_exact() {
982        let x = array![[10.0_f64, 20.0, 30.0], [40.0, 50.0, 60.0],];
983        // Only transform col 0; cols 1 and 2 pass through.
984        let ct = ColumnTransformer::new(
985            vec![(
986                "mm".into(),
987                Box::new(MinMaxScaler::<f64>::new()),
988                ColumnSelector::Indices(vec![0]),
989            )],
990            Remainder::Passthrough,
991        );
992        let fitted = ct.fit(&x, &()).unwrap();
993        let out = fitted.transform(&x).unwrap();
994        // out[:, 1] == x[:, 1] and out[:, 2] == x[:, 2]
995        assert_abs_diff_eq!(out[[0, 1]], 20.0, epsilon = 1e-12);
996        assert_abs_diff_eq!(out[[1, 1]], 50.0, epsilon = 1e-12);
997        assert_abs_diff_eq!(out[[0, 2]], 30.0, epsilon = 1e-12);
998        assert_abs_diff_eq!(out[[1, 2]], 60.0, epsilon = 1e-12);
999    }
1000
1001    // -----------------------------------------------------------------------
1002    // 20. Transformer names from explicit ColumnTransformer::new
1003    // -----------------------------------------------------------------------
1004
1005    #[test]
1006    fn test_transformer_names_explicit() {
1007        let x = make_x();
1008        let ct = ColumnTransformer::new(
1009            vec![
1010                (
1011                    "alpha".into(),
1012                    Box::new(StandardScaler::<f64>::new()),
1013                    ColumnSelector::Indices(vec![0]),
1014                ),
1015                (
1016                    "beta".into(),
1017                    Box::new(MinMaxScaler::<f64>::new()),
1018                    ColumnSelector::Indices(vec![1]),
1019                ),
1020            ],
1021            Remainder::Drop,
1022        );
1023        let fitted = ct.fit(&x, &()).unwrap();
1024        assert_eq!(fitted.transformer_names(), vec!["alpha", "beta"]);
1025    }
1026
1027    // -----------------------------------------------------------------------
1028    // 21. make_column_transformer with single step
1029    // -----------------------------------------------------------------------
1030
1031    #[test]
1032    fn test_make_column_transformer_single() {
1033        let x = array![[1.0_f64, 2.0], [3.0, 4.0]];
1034        let ct = make_column_transformer(
1035            vec![(
1036                Box::new(StandardScaler::<f64>::new()),
1037                ColumnSelector::Indices(vec![0, 1]),
1038            )],
1039            Remainder::Drop,
1040        );
1041        let fitted = ct.fit(&x, &()).unwrap();
1042        assert_eq!(fitted.transformer_names(), vec!["transformer-0"]);
1043        let out = fitted.transform(&x).unwrap();
1044        assert_eq!(out.shape(), &[2, 2]);
1045    }
1046
1047    // -----------------------------------------------------------------------
1048    // 22. Edge case: all columns as remainder with Passthrough
1049    // -----------------------------------------------------------------------
1050
1051    #[test]
1052    fn test_all_remainder_passthrough_unchanged() {
1053        let x = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
1054        let ct = ColumnTransformer::new(vec![], Remainder::Passthrough);
1055        let fitted = ct.fit(&x, &()).unwrap();
1056        let out = fitted.transform(&x).unwrap();
1057        assert_eq!(out.shape(), &[2, 3]);
1058        for i in 0..2 {
1059            for j in 0..3 {
1060                assert_abs_diff_eq!(out[[i, j]], x[[i, j]], epsilon = 1e-12);
1061            }
1062        }
1063    }
1064}