Skip to main content

ferrolearn_preprocess/
column_transformer.rs

1//! Column transformer: apply different transformers to different column subsets.
2//!
3//! [`ColumnTransformer`] applies each registered transformer to its designated
4//! column subset, then horizontally concatenates the outputs into a single
5//! `Array2<f64>`. Columns not captured by any transformer can be dropped or
6//! passed through unchanged via the [`Remainder`] policy.
7//!
8//! # Examples
9//!
10//! ```
11//! use ferrolearn_preprocess::column_transformer::{
12//!     ColumnSelector, ColumnTransformer, Remainder,
13//! };
14//! use ferrolearn_preprocess::{StandardScaler, MinMaxScaler};
15//! use ferrolearn_core::Fit;
16//! use ferrolearn_core::Transform;
17//! use ndarray::array;
18//!
19//! let x = array![
20//!     [1.0_f64, 2.0, 10.0, 20.0],
21//!     [3.0, 4.0, 30.0, 40.0],
22//!     [5.0, 6.0, 50.0, 60.0],
23//! ];
24//!
25//! let ct = ColumnTransformer::new(
26//!     vec![
27//!         ("std".into(),  Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1])),
28//!         ("mm".into(),   Box::new(MinMaxScaler::<f64>::new()),   ColumnSelector::Indices(vec![2, 3])),
29//!     ],
30//!     Remainder::Drop,
31//! );
32//!
33//! let fitted = ct.fit(&x, &()).unwrap();
34//! let out    = fitted.transform(&x).unwrap();
35//! assert_eq!(out.ncols(), 4);
36//! assert_eq!(out.nrows(), 3);
37//! ```
38//!
39//! ## REQ status
40//!
41//! Translation target: scikit-learn 1.5.2 `ColumnTransformer` +
42//! `make_column_transformer` (`sklearn/compose/_column_transformer.py:59`).
43//! Tracking: #1434. Each REQ is BINARY — SHIPPED (impl + non-test consumer +
44//! tests + green verification) or NOT-STARTED (with a concrete open blocker).
45//! DETERMINISTIC composition meta-transformer: this unit owns the column
46//! routing / output ordering / remainder; sub-transformer VALUES come from
47//! their own routed units (StandardScaler, MinMaxScaler, …).
48//!
49//! | REQ | Scope | Status | Evidence / Blocker |
50//! |-----|-------|--------|--------------------|
51//! | REQ-1 | Column routing + output ORDERING (transformer outputs in registration/list order, remainder appended LAST in ascending column order) + `Drop`/`Passthrough` + OVERLAPPING columns + combined output VALUES (index selectors) | SHIPPED | [`ColumnTransformer`] `fit`/`transform` (hstack in list order + remainder `(0..n).filter(!covered)` ascending) matches sklearn `_hstack` `_column_transformer.py:976-1006,1091` + `_validate_remainder` `sorted(set(range)-cols)` `:546`; 8 oracle tests in `tests/divergence_column_transformer.rs`. Consumer: re-export `lib.rs:128` + `PipelineTransformer` |
52//! | REQ-2 | [`make_column_transformer`] builds a composing `ColumnTransformer` | SHIPPED (scoped) | auto-names `transformer-N` (sklearn uses lowercased class names — naming gap folded into REQ-7); oracle value test |
53//! | REQ-3 | Error/parameter contracts (out-of-range column index, transform ncols mismatch) | SHIPPED (scoped) | [`ColumnTransformer::fit`]/[`FittedColumnTransformer`] `transform`; divergence error tests |
54//! | REQ-4 | Non-index `ColumnSelector`s (str/slice/bool-mask/callable) + `make_column_selector` | NOT-STARTED | `Indices` only; sklearn `_column_transformer.py:1468` — blocker #1435 |
55//! | REQ-5 | `remainder` as an estimator + `'drop'`/`'passthrough'` as a STEP transformer | NOT-STARTED | sklearn `_column_transformer.py:277-281,460-462` — blocker #1436 |
56//! | REQ-6 | `sparse_threshold` + sparse output + `transformer_weights` + `n_jobs`/`verbose` | NOT-STARTED | dense only; sklearn `_column_transformer.py:282,998,1091-1200` — blocker #1437 |
57//! | REQ-7 | `get_feature_names_out` + `verbose_feature_names_out` (`name__feature`) + class-name auto-naming | NOT-STARTED | sklearn `_column_transformer.py:599,662,1456-1465` — blocker #1438 |
58//! | REQ-8 | `named_transformers_`/`transformers_`/`n_features_in_`/`feature_names_in_` fitted-attr surface | NOT-STARTED | sklearn `_column_transformer.py:574-582,694-726` — blocker #1439 |
59//! | REQ-9 | Generic `F` (currently `f64`-only) | NOT-STARTED | R-DEV-3 — blocker #1440 |
60//! | REQ-10 | PyO3 binding | NOT-STARTED | no `ferrolearn-python` registration — blocker #1441 |
61//! | REQ-11 | ferray substrate | NOT-STARTED | dense `Array2<f64>` only — blocker #1442 |
62
63use ferrolearn_core::error::FerroError;
64use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
65use ferrolearn_core::traits::{Fit, Transform};
66use ndarray::{Array1, Array2};
67
68// ---------------------------------------------------------------------------
69// ColumnSelector
70// ---------------------------------------------------------------------------
71
72/// Specifies which columns a transformer should operate on.
73///
74/// Currently the only supported variant is [`Indices`](ColumnSelector::Indices),
75/// which selects columns by their zero-based integer positions.
76#[derive(Debug, Clone)]
77pub enum ColumnSelector {
78    /// Select columns by zero-based index.
79    ///
80    /// The indices do not need to be sorted, but every index must be strictly
81    /// less than the number of columns in the input matrix. Duplicate indices
82    /// are allowed; the same column will simply appear twice in the sub-matrix
83    /// passed to the transformer.
84    Indices(Vec<usize>),
85}
86
87impl ColumnSelector {
88    /// Resolve the selector to a concrete list of column indices.
89    ///
90    /// # Errors
91    ///
92    /// Returns [`FerroError::InvalidParameter`] if any index is out of range
93    /// (i.e., `>= n_features`).
94    fn resolve(&self, n_features: usize) -> Result<Vec<usize>, FerroError> {
95        match self {
96            ColumnSelector::Indices(indices) => {
97                for &idx in indices {
98                    if idx >= n_features {
99                        return Err(FerroError::InvalidParameter {
100                            name: "ColumnSelector::Indices".into(),
101                            reason: format!(
102                                "column index {idx} is out of range for input with {n_features} features"
103                            ),
104                        });
105                    }
106                }
107                Ok(indices.clone())
108            }
109        }
110    }
111}
112
113// ---------------------------------------------------------------------------
114// Remainder
115// ---------------------------------------------------------------------------
116
/// Policy for columns that are not selected by any transformer.
///
/// When at least one column is not covered by any registered transformer,
/// `Remainder` determines what happens to those columns in the output.
///
/// `Remainder::default()` is [`Remainder::Drop`], matching scikit-learn's
/// `remainder='drop'` default. The policy also implements `PartialEq`/`Eq`,
/// so it can be compared directly (for example in configuration checks or
/// tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum Remainder {
    /// Discard remainder columns — they do not appear in the output.
    #[default]
    Drop,
    /// Pass remainder columns through unchanged, appended after all
    /// transformer outputs.
    Passthrough,
}
129
130// ---------------------------------------------------------------------------
131// Helper: extract a sub-matrix by column indices
132// ---------------------------------------------------------------------------
133
134/// Build a new `Array2<f64>` containing only the columns at `indices`.
135///
136/// Columns are emitted in the order they appear in `indices`.
137fn select_columns(x: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
138    let nrows = x.nrows();
139    let ncols = indices.len();
140    if ncols == 0 {
141        return Array2::zeros((nrows, 0));
142    }
143    let mut out = Array2::zeros((nrows, ncols));
144    for (new_j, &old_j) in indices.iter().enumerate() {
145        out.column_mut(new_j).assign(&x.column(old_j));
146    }
147    out
148}
149
150/// Horizontally concatenate a slice of `Array2<f64>` matrices.
151///
152/// All matrices must have the same number of rows.
153///
154/// # Errors
155///
156/// Returns [`FerroError::ShapeMismatch`] if row counts differ.
157fn hstack(matrices: &[Array2<f64>]) -> Result<Array2<f64>, FerroError> {
158    if matrices.is_empty() {
159        return Ok(Array2::zeros((0, 0)));
160    }
161    let nrows = matrices[0].nrows();
162    let total_cols: usize = matrices.iter().map(ndarray::ArrayBase::ncols).sum();
163
164    // Handle the case where the first matrix establishes nrows = 0 separately.
165    if total_cols == 0 {
166        return Ok(Array2::zeros((nrows, 0)));
167    }
168
169    let mut out = Array2::zeros((nrows, total_cols));
170    let mut col_offset = 0;
171    for m in matrices {
172        if m.nrows() != nrows {
173            return Err(FerroError::ShapeMismatch {
174                expected: vec![nrows, m.ncols()],
175                actual: vec![m.nrows(), m.ncols()],
176                context: "ColumnTransformer hstack: row count mismatch".into(),
177            });
178        }
179        let end = col_offset + m.ncols();
180        if m.ncols() > 0 {
181            out.slice_mut(ndarray::s![.., col_offset..end]).assign(m);
182        }
183        col_offset = end;
184    }
185    Ok(out)
186}
187
188// ---------------------------------------------------------------------------
189// ColumnTransformer (unfitted)
190// ---------------------------------------------------------------------------
191
192/// An unfitted column transformer.
193///
194/// Applies each registered transformer to its designated column subset, then
195/// horizontally concatenates all outputs. The [`Remainder`] policy controls
196/// what happens to columns not covered by any transformer.
197///
198/// # Transformer order
199///
200/// Transformers are applied and their outputs concatenated in the order they
201/// were registered. Remainder columns (when
202/// `remainder = `[`Remainder::Passthrough`]) are appended last.
203///
204/// # Overlapping selections
205///
206/// Each transformer receives its own copy of the selected columns, so
207/// overlapping `ColumnSelector`s are fully supported.
208///
209/// # Examples
210///
211/// ```
212/// use ferrolearn_preprocess::column_transformer::{
213///     ColumnSelector, ColumnTransformer, Remainder,
214/// };
215/// use ferrolearn_preprocess::StandardScaler;
216/// use ferrolearn_core::Fit;
217/// use ferrolearn_core::Transform;
218/// use ndarray::array;
219///
220/// let x = array![[1.0_f64, 10.0, 100.0], [2.0, 20.0, 200.0], [3.0, 30.0, 300.0]];
221/// let ct = ColumnTransformer::new(
222///     vec![("scaler".into(), Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1]))],
223///     Remainder::Passthrough,
224/// );
225/// let fitted = ct.fit(&x, &()).unwrap();
226/// let out = fitted.transform(&x).unwrap();
227/// // 2 scaled columns + 1 passthrough column
228/// assert_eq!(out.ncols(), 3);
229/// ```
230pub struct ColumnTransformer {
231    /// Named transformer steps with their column selectors.
232    transformers: Vec<(String, Box<dyn PipelineTransformer<f64>>, ColumnSelector)>,
233    /// Policy for columns not covered by any transformer.
234    remainder: Remainder,
235}
236
237impl ColumnTransformer {
238    /// Create a new `ColumnTransformer`.
239    ///
240    /// # Parameters
241    ///
242    /// - `transformers`: A list of `(name, transformer, selector)` triples.
243    /// - `remainder`: Policy for uncovered columns (`Drop` or `Passthrough`).
244    #[must_use]
245    pub fn new(
246        transformers: Vec<(String, Box<dyn PipelineTransformer<f64>>, ColumnSelector)>,
247        remainder: Remainder,
248    ) -> Self {
249        Self {
250            transformers,
251            remainder,
252        }
253    }
254}
255
256// ---------------------------------------------------------------------------
257// Fit implementation
258// ---------------------------------------------------------------------------
259
260impl Fit<Array2<f64>, ()> for ColumnTransformer {
261    type Fitted = FittedColumnTransformer;
262    type Error = FerroError;
263
264    /// Fit each transformer on its selected column subset.
265    ///
266    /// Validates that all selected column indices are within bounds before
267    /// fitting any transformer.
268    ///
269    /// # Errors
270    ///
271    /// - [`FerroError::InvalidParameter`] if any column index is out of range.
272    /// - Propagates any error returned by an individual transformer's
273    ///   `fit_pipeline` call.
274    fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedColumnTransformer, FerroError> {
275        let n_features = x.ncols();
276        let n_rows = x.nrows();
277
278        // A dummy y vector required by PipelineTransformer::fit_pipeline.
279        let dummy_y = Array1::<f64>::zeros(n_rows);
280
281        // Resolve all selectors up front to validate indices eagerly.
282        let mut resolved_selectors: Vec<Vec<usize>> = Vec::with_capacity(self.transformers.len());
283        for (name, _, selector) in &self.transformers {
284            let indices = selector.resolve(n_features).map_err(|e| {
285                // Enrich the error with the transformer name.
286                FerroError::InvalidParameter {
287                    name: format!("ColumnTransformer step '{name}'"),
288                    reason: e.to_string(),
289                }
290            })?;
291            resolved_selectors.push(indices);
292        }
293
294        // Build the set of covered column indices (for remainder computation).
295        let covered: std::collections::HashSet<usize> = resolved_selectors
296            .iter()
297            .flat_map(|v| v.iter().copied())
298            .collect();
299
300        let remainder_indices: Vec<usize> =
301            (0..n_features).filter(|c| !covered.contains(c)).collect();
302
303        // Fit each transformer on its sub-matrix.
304        let mut fitted_transformers: Vec<FittedSubTransformer> =
305            Vec::with_capacity(self.transformers.len());
306
307        for ((name, transformer, _), indices) in self.transformers.iter().zip(resolved_selectors) {
308            let sub_x = select_columns(x, &indices);
309            let fitted = transformer.fit_pipeline(&sub_x, &dummy_y)?;
310            fitted_transformers.push((name.clone(), fitted, indices));
311        }
312
313        Ok(FittedColumnTransformer {
314            fitted_transformers,
315            remainder: self.remainder.clone(),
316            remainder_indices,
317            n_features_in: n_features,
318        })
319    }
320}
321
322// ---------------------------------------------------------------------------
323// PipelineTransformer implementation
324// ---------------------------------------------------------------------------
325
326impl PipelineTransformer<f64> for ColumnTransformer {
327    /// Fit the column transformer using the pipeline interface.
328    ///
329    /// The `y` argument is ignored; it exists only for API compatibility.
330    ///
331    /// # Errors
332    ///
333    /// Propagates errors from [`Fit::fit`].
334    fn fit_pipeline(
335        &self,
336        x: &Array2<f64>,
337        _y: &Array1<f64>,
338    ) -> Result<Box<dyn FittedPipelineTransformer<f64>>, FerroError> {
339        let fitted = self.fit(x, &())?;
340        Ok(Box::new(fitted))
341    }
342}
343
344// ---------------------------------------------------------------------------
345// FittedColumnTransformer
346// ---------------------------------------------------------------------------
347
348/// A named, fitted sub-transformer with its column indices.
349type FittedSubTransformer = (String, Box<dyn FittedPipelineTransformer<f64>>, Vec<usize>);
350
351/// A fitted column transformer holding fitted sub-transformers and metadata.
352///
353/// Created by calling [`Fit::fit`] on a [`ColumnTransformer`].
354/// Implements [`Transform<Array2<f64>>`] to apply the fitted transformers and
355/// concatenate their outputs, as well as [`FittedPipelineTransformer`] for use
356/// inside a [`ferrolearn_core::pipeline::Pipeline`].
357pub struct FittedColumnTransformer {
358    /// Fitted transformers with their associated column indices.
359    fitted_transformers: Vec<FittedSubTransformer>,
360    /// Remainder policy from the original [`ColumnTransformer`].
361    remainder: Remainder,
362    /// Column indices not covered by any transformer.
363    remainder_indices: Vec<usize>,
364    /// Number of input features seen during fitting.
365    n_features_in: usize,
366}
367
368impl FittedColumnTransformer {
369    /// Return the number of input features seen during fitting.
370    #[must_use]
371    pub fn n_features_in(&self) -> usize {
372        self.n_features_in
373    }
374
375    /// Return the names of all registered transformer steps.
376    #[must_use]
377    pub fn transformer_names(&self) -> Vec<&str> {
378        self.fitted_transformers
379            .iter()
380            .map(|(name, _, _)| name.as_str())
381            .collect()
382    }
383
384    /// Return the remainder column indices (columns not selected by any transformer).
385    #[must_use]
386    pub fn remainder_indices(&self) -> &[usize] {
387        &self.remainder_indices
388    }
389}
390
391// ---------------------------------------------------------------------------
392// Transform implementation
393// ---------------------------------------------------------------------------
394
395impl Transform<Array2<f64>> for FittedColumnTransformer {
396    type Output = Array2<f64>;
397    type Error = FerroError;
398
399    /// Transform data by applying each fitted transformer to its column subset,
400    /// then horizontally concatenating all outputs.
401    ///
402    /// When `remainder = Passthrough`, the unselected columns are appended
403    /// after all transformer outputs. When `remainder = Drop`, they are
404    /// discarded.
405    ///
406    /// # Errors
407    ///
408    /// - [`FerroError::ShapeMismatch`] if the input does not have
409    ///   `n_features_in` columns.
410    /// - Propagates any error from individual transformer `transform_pipeline`
411    ///   calls.
412    fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
413        if x.ncols() != self.n_features_in {
414            return Err(FerroError::ShapeMismatch {
415                expected: vec![x.nrows(), self.n_features_in],
416                actual: vec![x.nrows(), x.ncols()],
417                context: "FittedColumnTransformer::transform".into(),
418            });
419        }
420
421        let mut parts: Vec<Array2<f64>> = Vec::with_capacity(self.fitted_transformers.len() + 1);
422
423        for (_, fitted, indices) in &self.fitted_transformers {
424            let sub_x = select_columns(x, indices);
425            let transformed = fitted.transform_pipeline(&sub_x)?;
426            parts.push(transformed);
427        }
428
429        // Append remainder columns if requested.
430        if matches!(self.remainder, Remainder::Passthrough) && !self.remainder_indices.is_empty() {
431            let remainder_sub = select_columns(x, &self.remainder_indices);
432            parts.push(remainder_sub);
433        }
434
435        hstack(&parts)
436    }
437}
438
439// ---------------------------------------------------------------------------
440// FittedPipelineTransformer implementation
441// ---------------------------------------------------------------------------
442
443impl FittedPipelineTransformer<f64> for FittedColumnTransformer {
444    /// Transform data using the pipeline interface.
445    ///
446    /// # Errors
447    ///
448    /// Propagates errors from [`Transform::transform`].
449    fn transform_pipeline(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
450        self.transform(x)
451    }
452}
453
454// ---------------------------------------------------------------------------
455// make_column_transformer convenience function
456// ---------------------------------------------------------------------------
457
458/// Convenience function to build a [`ColumnTransformer`] with auto-generated
459/// step names.
460///
461/// Steps are named `"transformer-0"`, `"transformer-1"`, etc.
462///
463/// # Parameters
464///
465/// - `transformers`: A list of `(transformer, selector)` pairs.
466/// - `remainder`: Policy for uncovered columns (`Drop` or `Passthrough`).
467///
468/// # Examples
469///
470/// ```
471/// use ferrolearn_preprocess::column_transformer::{
472///     make_column_transformer, ColumnSelector, Remainder,
473/// };
474/// use ferrolearn_preprocess::StandardScaler;
475/// use ferrolearn_core::Fit;
476/// use ferrolearn_core::Transform;
477/// use ndarray::array;
478///
479/// let x = array![[1.0_f64, 10.0], [2.0, 20.0], [3.0, 30.0]];
480/// let ct = make_column_transformer(
481///     vec![(Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1]))],
482///     Remainder::Drop,
483/// );
484/// let fitted = ct.fit(&x, &()).unwrap();
485/// let out = fitted.transform(&x).unwrap();
486/// assert_eq!(out.ncols(), 2);
487/// ```
488#[must_use]
489pub fn make_column_transformer(
490    transformers: Vec<(Box<dyn PipelineTransformer<f64>>, ColumnSelector)>,
491    remainder: Remainder,
492) -> ColumnTransformer {
493    let named: Vec<(String, Box<dyn PipelineTransformer<f64>>, ColumnSelector)> = transformers
494        .into_iter()
495        .enumerate()
496        .map(|(i, (t, s))| (format!("transformer-{i}"), t, s))
497        .collect();
498    ColumnTransformer::new(named, remainder)
499}
500
501// ---------------------------------------------------------------------------
502// Tests
503// ---------------------------------------------------------------------------
504
505#[cfg(test)]
506mod tests {
507    use super::*;
508    use approx::assert_abs_diff_eq;
509    use ferrolearn_core::pipeline::{Pipeline, PipelineEstimator};
510    use ndarray::{Array2, array};
511
512    use crate::{MinMaxScaler, StandardScaler};
513
514    // -----------------------------------------------------------------------
515    // Helpers
516    // -----------------------------------------------------------------------
517
518    /// Build a simple 4-column test matrix (rows = 4, cols = 4).
519    fn make_x() -> Array2<f64> {
520        array![
521            [1.0, 2.0, 10.0, 20.0],
522            [2.0, 4.0, 20.0, 40.0],
523            [3.0, 6.0, 30.0, 60.0],
524            [4.0, 8.0, 40.0, 80.0],
525        ]
526    }
527
528    // -----------------------------------------------------------------------
529    // 1. Basic 2-transformer usage
530    // -----------------------------------------------------------------------
531
532    #[test]
533    fn test_basic_two_transformers_drop_remainder() {
534        let x = make_x(); // 4×4
535        let ct = ColumnTransformer::new(
536            vec![
537                (
538                    "std".into(),
539                    Box::new(StandardScaler::<f64>::new()),
540                    ColumnSelector::Indices(vec![0, 1]),
541                ),
542                (
543                    "mm".into(),
544                    Box::new(MinMaxScaler::<f64>::new()),
545                    ColumnSelector::Indices(vec![2, 3]),
546                ),
547            ],
548            Remainder::Drop,
549        );
550
551        let fitted = ct.fit(&x, &()).unwrap();
552        let out = fitted.transform(&x).unwrap();
553
554        // All 4 columns covered → no remainder; output is 4 cols
555        assert_eq!(out.nrows(), 4);
556        assert_eq!(out.ncols(), 4);
557    }
558
559    // -----------------------------------------------------------------------
560    // 2. Remainder::Drop drops uncovered columns
561    // -----------------------------------------------------------------------
562
563    #[test]
564    fn test_remainder_drop() {
565        let x = make_x(); // 4×4
566        // Only cover cols 0 and 1 — cols 2 and 3 should be dropped.
567        let ct = ColumnTransformer::new(
568            vec![(
569                "std".into(),
570                Box::new(StandardScaler::<f64>::new()),
571                ColumnSelector::Indices(vec![0, 1]),
572            )],
573            Remainder::Drop,
574        );
575
576        let fitted = ct.fit(&x, &()).unwrap();
577        let out = fitted.transform(&x).unwrap();
578
579        assert_eq!(out.nrows(), 4);
580        assert_eq!(out.ncols(), 2, "uncovered cols should be dropped");
581    }
582
583    // -----------------------------------------------------------------------
584    // 3. Remainder::Passthrough passes uncovered columns through unchanged
585    // -----------------------------------------------------------------------
586
587    #[test]
588    fn test_remainder_passthrough() {
589        let x = make_x(); // 4×4
590        // Only cover cols 0 and 1 — cols 2 and 3 should pass through.
591        let ct = ColumnTransformer::new(
592            vec![(
593                "std".into(),
594                Box::new(StandardScaler::<f64>::new()),
595                ColumnSelector::Indices(vec![0, 1]),
596            )],
597            Remainder::Passthrough,
598        );
599
600        let fitted = ct.fit(&x, &()).unwrap();
601        let out = fitted.transform(&x).unwrap();
602
603        assert_eq!(out.nrows(), 4);
604        assert_eq!(out.ncols(), 4, "passthrough: 2 transformed + 2 remainder");
605
606        // The last 2 columns should be the original cols 2 and 3.
607        for i in 0..4 {
608            assert_abs_diff_eq!(out[[i, 2]], x[[i, 2]], epsilon = 1e-12);
609            assert_abs_diff_eq!(out[[i, 3]], x[[i, 3]], epsilon = 1e-12);
610        }
611    }
612
613    // -----------------------------------------------------------------------
614    // 4. Invalid column index (out of range)
615    // -----------------------------------------------------------------------
616
617    #[test]
618    fn test_invalid_column_index_out_of_range() {
619        let x = make_x(); // 4×4 — valid indices are 0..3
620        let ct = ColumnTransformer::new(
621            vec![(
622                "std".into(),
623                Box::new(StandardScaler::<f64>::new()),
624                ColumnSelector::Indices(vec![0, 99]), // 99 is out of range
625            )],
626            Remainder::Drop,
627        );
628        let result = ct.fit(&x, &());
629        assert!(result.is_err(), "expected error for out-of-range index");
630    }
631
632    // -----------------------------------------------------------------------
633    // 5. Empty transformer list with Remainder::Drop
634    // -----------------------------------------------------------------------
635
636    #[test]
637    fn test_empty_transformer_list_drop() {
638        let x = make_x();
639        let ct = ColumnTransformer::new(vec![], Remainder::Drop);
640        let fitted = ct.fit(&x, &()).unwrap();
641        let out = fitted.transform(&x).unwrap();
642        // No transformers, remainder dropped → empty output
643        assert_eq!(out.nrows(), 0, "hstack of nothing with no passthrough");
644    }
645
646    // -----------------------------------------------------------------------
647    // 6. Empty transformer list with Remainder::Passthrough
648    // -----------------------------------------------------------------------
649
650    #[test]
651    fn test_empty_transformer_list_passthrough() {
652        let x = make_x(); // 4×4
653        let ct = ColumnTransformer::new(vec![], Remainder::Passthrough);
654        let fitted = ct.fit(&x, &()).unwrap();
655        let out = fitted.transform(&x).unwrap();
656        // No transformers, all columns pass through unchanged.
657        assert_eq!(out.nrows(), 4);
658        assert_eq!(out.ncols(), 4);
659        for i in 0..4 {
660            for j in 0..4 {
661                assert_abs_diff_eq!(out[[i, j]], x[[i, j]], epsilon = 1e-12);
662            }
663        }
664    }
665
666    // -----------------------------------------------------------------------
667    // 7. Overlapping column selections
668    // -----------------------------------------------------------------------
669
670    #[test]
671    fn test_overlapping_column_selections() {
672        let x = make_x(); // 4×4
673        // Both transformers select col 0 (overlapping is allowed).
674        let ct = ColumnTransformer::new(
675            vec![
676                (
677                    "std1".into(),
678                    Box::new(StandardScaler::<f64>::new()),
679                    ColumnSelector::Indices(vec![0, 1]),
680                ),
681                (
682                    "mm1".into(),
683                    Box::new(MinMaxScaler::<f64>::new()),
684                    ColumnSelector::Indices(vec![0, 2]), // col 0 also used here
685                ),
686            ],
687            Remainder::Drop,
688        );
689
690        let fitted = ct.fit(&x, &()).unwrap();
691        let out = fitted.transform(&x).unwrap();
692
693        // Output: 2 cols from std1 + 2 cols from mm1 = 4 cols
694        assert_eq!(out.nrows(), 4);
695        assert_eq!(out.ncols(), 4);
696    }
697
698    // -----------------------------------------------------------------------
699    // 8. Single transformer
700    // -----------------------------------------------------------------------
701
702    #[test]
703    fn test_single_transformer() {
704        let x = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
705        let ct = ColumnTransformer::new(
706            vec![(
707                "mm".into(),
708                Box::new(MinMaxScaler::<f64>::new()),
709                ColumnSelector::Indices(vec![0, 1]),
710            )],
711            Remainder::Drop,
712        );
713
714        let fitted = ct.fit(&x, &()).unwrap();
715        let out = fitted.transform(&x).unwrap();
716
717        assert_eq!(out.nrows(), 3);
718        assert_eq!(out.ncols(), 2);
719
720        // MinMax on cols 0 and 1: first row → 0.0, last row → 1.0
721        assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
722        assert_abs_diff_eq!(out[[2, 0]], 1.0, epsilon = 1e-10);
723        assert_abs_diff_eq!(out[[0, 1]], 0.0, epsilon = 1e-10);
724        assert_abs_diff_eq!(out[[2, 1]], 1.0, epsilon = 1e-10);
725    }
726
727    // -----------------------------------------------------------------------
728    // 9. make_column_transformer convenience function
729    // -----------------------------------------------------------------------
730
731    #[test]
732    fn test_make_column_transformer_auto_names() {
733        let x = make_x();
734        let ct = make_column_transformer(
735            vec![
736                (
737                    Box::new(StandardScaler::<f64>::new()),
738                    ColumnSelector::Indices(vec![0, 1]),
739                ),
740                (
741                    Box::new(MinMaxScaler::<f64>::new()),
742                    ColumnSelector::Indices(vec![2, 3]),
743                ),
744            ],
745            Remainder::Drop,
746        );
747
748        let fitted = ct.fit(&x, &()).unwrap();
749        assert_eq!(
750            fitted.transformer_names(),
751            vec!["transformer-0", "transformer-1"]
752        );
753
754        let out = fitted.transform(&x).unwrap();
755        assert_eq!(out.nrows(), 4);
756        assert_eq!(out.ncols(), 4);
757    }
758
759    // -----------------------------------------------------------------------
760    // 10. Pipeline integration
761    // -----------------------------------------------------------------------
762
763    #[test]
764    fn test_pipeline_integration() {
765        // Wrap a ColumnTransformer as a pipeline step.
766        let x = make_x();
767        let y = Array1::<f64>::zeros(4);
768
769        let ct = ColumnTransformer::new(
770            vec![(
771                "std".into(),
772                Box::new(StandardScaler::<f64>::new()),
773                ColumnSelector::Indices(vec![0, 1, 2, 3]),
774            )],
775            Remainder::Drop,
776        );
777
778        // Use a trivial estimator that sums rows.
779        struct SumEstimator;
780        impl PipelineEstimator<f64> for SumEstimator {
781            fn fit_pipeline(
782                &self,
783                _x: &Array2<f64>,
784                _y: &Array1<f64>,
785            ) -> Result<Box<dyn ferrolearn_core::pipeline::FittedPipelineEstimator<f64>>, FerroError>
786            {
787                Ok(Box::new(FittedSum))
788            }
789        }
790        struct FittedSum;
791        impl ferrolearn_core::pipeline::FittedPipelineEstimator<f64> for FittedSum {
792            fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
793                let sums: Vec<f64> = x.rows().into_iter().map(|r| r.sum()).collect();
794                Ok(Array1::from_vec(sums))
795            }
796        }
797
798        let pipeline = Pipeline::new()
799            .transform_step("ct", Box::new(ct))
800            .estimator_step("sum", Box::new(SumEstimator));
801
802        use ferrolearn_core::Fit as _;
803        let fitted_pipeline = pipeline.fit(&x, &y).unwrap();
804
805        use ferrolearn_core::Predict as _;
806        let preds = fitted_pipeline.predict(&x).unwrap();
807        assert_eq!(preds.len(), 4);
808    }
809
810    // -----------------------------------------------------------------------
811    // 11. Transform shape correctness — number of output columns
812    // -----------------------------------------------------------------------
813
814    #[test]
815    fn test_output_shape_all_selected_drop() {
816        let x = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
817        let ct = ColumnTransformer::new(
818            vec![
819                (
820                    "s".into(),
821                    Box::new(StandardScaler::<f64>::new()),
822                    ColumnSelector::Indices(vec![0]),
823                ),
824                (
825                    "m".into(),
826                    Box::new(MinMaxScaler::<f64>::new()),
827                    ColumnSelector::Indices(vec![1, 2]),
828                ),
829            ],
830            Remainder::Drop,
831        );
832        let fitted = ct.fit(&x, &()).unwrap();
833        let out = fitted.transform(&x).unwrap();
834        assert_eq!(out.shape(), &[2, 3]);
835    }
836
837    // -----------------------------------------------------------------------
838    // 12. Transform shape — partial selection + passthrough
839    // -----------------------------------------------------------------------
840
841    #[test]
842    fn test_output_shape_partial_passthrough() {
843        // 5-column input, transform 2 cols, passthrough 3
844        let x = Array2::<f64>::from_shape_vec((3, 5), (1..=15).map(f64::from).collect()).unwrap();
845        let ct = ColumnTransformer::new(
846            vec![(
847                "std".into(),
848                Box::new(StandardScaler::<f64>::new()),
849                ColumnSelector::Indices(vec![0, 1]),
850            )],
851            Remainder::Passthrough,
852        );
853        let fitted = ct.fit(&x, &()).unwrap();
854        let out = fitted.transform(&x).unwrap();
855        assert_eq!(out.shape(), &[3, 5]);
856    }
857
858    // -----------------------------------------------------------------------
859    // 13. n_features_in accessor
860    // -----------------------------------------------------------------------
861
862    #[test]
863    fn test_n_features_in() {
864        let x = make_x(); // 4×4
865        let ct = ColumnTransformer::new(
866            vec![(
867                "std".into(),
868                Box::new(StandardScaler::<f64>::new()),
869                ColumnSelector::Indices(vec![0]),
870            )],
871            Remainder::Drop,
872        );
873        let fitted = ct.fit(&x, &()).unwrap();
874        assert_eq!(fitted.n_features_in(), 4);
875    }
876
877    // -----------------------------------------------------------------------
878    // 14. Shape mismatch on transform (wrong number of columns)
879    // -----------------------------------------------------------------------
880
881    #[test]
882    fn test_shape_mismatch_on_transform() {
883        let x = make_x(); // 4×4
884        let ct = ColumnTransformer::new(
885            vec![(
886                "std".into(),
887                Box::new(StandardScaler::<f64>::new()),
888                ColumnSelector::Indices(vec![0, 1]),
889            )],
890            Remainder::Drop,
891        );
892        let fitted = ct.fit(&x, &()).unwrap();
893
894        // Now pass a matrix with only 2 columns — should fail.
895        let x_bad = array![[1.0_f64, 2.0], [3.0, 4.0]];
896        let result = fitted.transform(&x_bad);
897        assert!(result.is_err(), "expected shape mismatch error");
898    }
899
900    // -----------------------------------------------------------------------
901    // 15. remainder_indices accessor
902    // -----------------------------------------------------------------------
903
904    #[test]
905    fn test_remainder_indices_accessor() {
906        let x = make_x(); // 4×4
907        let ct = ColumnTransformer::new(
908            vec![(
909                "std".into(),
910                Box::new(StandardScaler::<f64>::new()),
911                ColumnSelector::Indices(vec![0, 2]),
912            )],
913            Remainder::Passthrough,
914        );
915        let fitted = ct.fit(&x, &()).unwrap();
916        // Remainder should be cols 1 and 3.
917        assert_eq!(fitted.remainder_indices(), &[1, 3]);
918    }
919
920    // -----------------------------------------------------------------------
921    // 16. StandardScaler output values are correct (zero-mean)
922    // -----------------------------------------------------------------------
923
924    #[test]
925    fn test_standard_scaler_zero_mean_in_output() {
926        let x = array![[1.0_f64, 100.0, 0.5], [2.0, 200.0, 1.5], [3.0, 300.0, 2.5],];
927        let ct = ColumnTransformer::new(
928            vec![(
929                "std".into(),
930                Box::new(StandardScaler::<f64>::new()),
931                ColumnSelector::Indices(vec![0, 1]),
932            )],
933            Remainder::Drop,
934        );
935        let fitted = ct.fit(&x, &()).unwrap();
936        let out = fitted.transform(&x).unwrap();
937
938        // Cols 0 and 1 of output should have mean ≈ 0.
939        for j in 0..2 {
940            let mean: f64 = out.column(j).iter().sum::<f64>() / 3.0;
941            assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-10);
942        }
943    }
944
945    // -----------------------------------------------------------------------
946    // 17. MinMaxScaler output values are in [0, 1]
947    // -----------------------------------------------------------------------
948
949    #[test]
950    fn test_min_max_values_in_range() {
951        let x = make_x();
952        let ct = ColumnTransformer::new(
953            vec![(
954                "mm".into(),
955                Box::new(MinMaxScaler::<f64>::new()),
956                ColumnSelector::Indices(vec![0, 1, 2, 3]),
957            )],
958            Remainder::Drop,
959        );
960        let fitted = ct.fit(&x, &()).unwrap();
961        let out = fitted.transform(&x).unwrap();
962
963        for j in 0..4 {
964            let col_min = out.column(j).iter().copied().fold(f64::INFINITY, f64::min);
965            let col_max = out
966                .column(j)
967                .iter()
968                .copied()
969                .fold(f64::NEG_INFINITY, f64::max);
970            assert_abs_diff_eq!(col_min, 0.0, epsilon = 1e-10);
971            assert_abs_diff_eq!(col_max, 1.0, epsilon = 1e-10);
972        }
973    }
974
975    // -----------------------------------------------------------------------
976    // 18. Pipeline transformer interface (fit_pipeline / transform_pipeline)
977    // -----------------------------------------------------------------------
978
979    #[test]
980    fn test_pipeline_transformer_interface() {
981        let x = make_x();
982        let y = Array1::<f64>::zeros(4);
983        let ct = ColumnTransformer::new(
984            vec![(
985                "std".into(),
986                Box::new(StandardScaler::<f64>::new()),
987                ColumnSelector::Indices(vec![0, 1]),
988            )],
989            Remainder::Passthrough,
990        );
991        let fitted_box = ct.fit_pipeline(&x, &y).unwrap();
992        let out = fitted_box.transform_pipeline(&x).unwrap();
993        assert_eq!(out.nrows(), 4);
994        assert_eq!(out.ncols(), 4);
995    }
996
997    // -----------------------------------------------------------------------
998    // 19. Remainder passthrough values are identical to input values
999    // -----------------------------------------------------------------------
1000
1001    #[test]
1002    fn test_passthrough_values_are_exact() {
1003        let x = array![[10.0_f64, 20.0, 30.0], [40.0, 50.0, 60.0],];
1004        // Only transform col 0; cols 1 and 2 pass through.
1005        let ct = ColumnTransformer::new(
1006            vec![(
1007                "mm".into(),
1008                Box::new(MinMaxScaler::<f64>::new()),
1009                ColumnSelector::Indices(vec![0]),
1010            )],
1011            Remainder::Passthrough,
1012        );
1013        let fitted = ct.fit(&x, &()).unwrap();
1014        let out = fitted.transform(&x).unwrap();
1015        // out[:, 1] == x[:, 1] and out[:, 2] == x[:, 2]
1016        assert_abs_diff_eq!(out[[0, 1]], 20.0, epsilon = 1e-12);
1017        assert_abs_diff_eq!(out[[1, 1]], 50.0, epsilon = 1e-12);
1018        assert_abs_diff_eq!(out[[0, 2]], 30.0, epsilon = 1e-12);
1019        assert_abs_diff_eq!(out[[1, 2]], 60.0, epsilon = 1e-12);
1020    }
1021
1022    // -----------------------------------------------------------------------
1023    // 20. Transformer names from explicit ColumnTransformer::new
1024    // -----------------------------------------------------------------------
1025
1026    #[test]
1027    fn test_transformer_names_explicit() {
1028        let x = make_x();
1029        let ct = ColumnTransformer::new(
1030            vec![
1031                (
1032                    "alpha".into(),
1033                    Box::new(StandardScaler::<f64>::new()),
1034                    ColumnSelector::Indices(vec![0]),
1035                ),
1036                (
1037                    "beta".into(),
1038                    Box::new(MinMaxScaler::<f64>::new()),
1039                    ColumnSelector::Indices(vec![1]),
1040                ),
1041            ],
1042            Remainder::Drop,
1043        );
1044        let fitted = ct.fit(&x, &()).unwrap();
1045        assert_eq!(fitted.transformer_names(), vec!["alpha", "beta"]);
1046    }
1047
1048    // -----------------------------------------------------------------------
1049    // 21. make_column_transformer with single step
1050    // -----------------------------------------------------------------------
1051
1052    #[test]
1053    fn test_make_column_transformer_single() {
1054        let x = array![[1.0_f64, 2.0], [3.0, 4.0]];
1055        let ct = make_column_transformer(
1056            vec![(
1057                Box::new(StandardScaler::<f64>::new()),
1058                ColumnSelector::Indices(vec![0, 1]),
1059            )],
1060            Remainder::Drop,
1061        );
1062        let fitted = ct.fit(&x, &()).unwrap();
1063        assert_eq!(fitted.transformer_names(), vec!["transformer-0"]);
1064        let out = fitted.transform(&x).unwrap();
1065        assert_eq!(out.shape(), &[2, 2]);
1066    }
1067
1068    // -----------------------------------------------------------------------
1069    // 22. Edge case: all columns as remainder with Passthrough
1070    // -----------------------------------------------------------------------
1071
1072    #[test]
1073    fn test_all_remainder_passthrough_unchanged() {
1074        let x = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
1075        let ct = ColumnTransformer::new(vec![], Remainder::Passthrough);
1076        let fitted = ct.fit(&x, &()).unwrap();
1077        let out = fitted.transform(&x).unwrap();
1078        assert_eq!(out.shape(), &[2, 3]);
1079        for i in 0..2 {
1080            for j in 0..3 {
1081                assert_abs_diff_eq!(out[[i, j]], x[[i, j]], epsilon = 1e-12);
1082            }
1083        }
1084    }
1085}