ferrolearn_preprocess/column_transformer.rs
1//! Column transformer: apply different transformers to different column subsets.
2//!
3//! [`ColumnTransformer`] applies each registered transformer to its designated
4//! column subset, then horizontally concatenates the outputs into a single
5//! `Array2<f64>`. Columns not captured by any transformer can be dropped or
6//! passed through unchanged via the [`Remainder`] policy.
7//!
8//! # Examples
9//!
10//! ```
11//! use ferrolearn_preprocess::column_transformer::{
12//! ColumnSelector, ColumnTransformer, Remainder,
13//! };
14//! use ferrolearn_preprocess::{StandardScaler, MinMaxScaler};
15//! use ferrolearn_core::Fit;
16//! use ferrolearn_core::Transform;
17//! use ndarray::array;
18//!
19//! let x = array![
20//! [1.0_f64, 2.0, 10.0, 20.0],
21//! [3.0, 4.0, 30.0, 40.0],
22//! [5.0, 6.0, 50.0, 60.0],
23//! ];
24//!
25//! let ct = ColumnTransformer::new(
26//! vec![
27//! ("std".into(), Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1])),
28//! ("mm".into(), Box::new(MinMaxScaler::<f64>::new()), ColumnSelector::Indices(vec![2, 3])),
29//! ],
30//! Remainder::Drop,
31//! );
32//!
33//! let fitted = ct.fit(&x, &()).unwrap();
34//! let out = fitted.transform(&x).unwrap();
35//! assert_eq!(out.ncols(), 4);
36//! assert_eq!(out.nrows(), 3);
37//! ```
38//!
39//! ## REQ status
40//!
41//! Translation target: scikit-learn 1.5.2 `ColumnTransformer` +
42//! `make_column_transformer` (`sklearn/compose/_column_transformer.py:59`).
43//! Tracking: #1434. Each REQ is BINARY — SHIPPED (impl + non-test consumer +
44//! tests + green verification) or NOT-STARTED (with a concrete open blocker).
45//! DETERMINISTIC composition meta-transformer: this unit owns the column
46//! routing / output ordering / remainder; sub-transformer VALUES come from
47//! their own routed units (StandardScaler, MinMaxScaler, …).
48//!
49//! | REQ | Scope | Status | Evidence / Blocker |
50//! |-----|-------|--------|--------------------|
51//! | REQ-1 | Column routing + output ORDERING (transformer outputs in registration/list order, remainder appended LAST in ascending column order) + `Drop`/`Passthrough` + OVERLAPPING columns + combined output VALUES (index selectors) | SHIPPED | [`ColumnTransformer`] `fit`/`transform` (hstack in list order + remainder `(0..n).filter(!covered)` ascending) matches sklearn `_hstack` `_column_transformer.py:976-1006,1091` + `_validate_remainder` `sorted(set(range)-cols)` `:546`; 8 oracle tests in `tests/divergence_column_transformer.rs`. Consumer: re-export `lib.rs:128` + `PipelineTransformer` |
52//! | REQ-2 | [`make_column_transformer`] builds a composing `ColumnTransformer` | SHIPPED (scoped) | auto-names `transformer-N` (sklearn uses lowercased class names — naming gap folded into REQ-7); oracle value test |
53//! | REQ-3 | Error/parameter contracts (out-of-range column index, transform ncols mismatch) | SHIPPED (scoped) | [`ColumnTransformer::fit`]/[`FittedColumnTransformer`] `transform`; divergence error tests |
54//! | REQ-4 | Non-index `ColumnSelector`s (str/slice/bool-mask/callable) + `make_column_selector` | NOT-STARTED | `Indices` only; sklearn `_column_transformer.py:1468` — blocker #1435 |
55//! | REQ-5 | `remainder` as an estimator + `'drop'`/`'passthrough'` as a STEP transformer | NOT-STARTED | sklearn `_column_transformer.py:277-281,460-462` — blocker #1436 |
56//! | REQ-6 | `sparse_threshold` + sparse output + `transformer_weights` + `n_jobs`/`verbose` | NOT-STARTED | dense only; sklearn `_column_transformer.py:282,998,1091-1200` — blocker #1437 |
57//! | REQ-7 | `get_feature_names_out` + `verbose_feature_names_out` (`name__feature`) + class-name auto-naming | NOT-STARTED | sklearn `_column_transformer.py:599,662,1456-1465` — blocker #1438 |
58//! | REQ-8 | `named_transformers_`/`transformers_`/`n_features_in_`/`feature_names_in_` fitted-attr surface | NOT-STARTED | sklearn `_column_transformer.py:574-582,694-726` — blocker #1439 |
59//! | REQ-9 | Generic `F` (currently `f64`-only) | NOT-STARTED | R-DEV-3 — blocker #1440 |
60//! | REQ-10 | PyO3 binding | NOT-STARTED | no `ferrolearn-python` registration — blocker #1441 |
61//! | REQ-11 | ferray substrate | NOT-STARTED | dense `Array2<f64>` only — blocker #1442 |
62
63use ferrolearn_core::error::FerroError;
64use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
65use ferrolearn_core::traits::{Fit, Transform};
66use ndarray::{Array1, Array2};
67
68// ---------------------------------------------------------------------------
69// ColumnSelector
70// ---------------------------------------------------------------------------
71
/// Describes the set of input columns handed to one transformer.
///
/// [`Indices`](ColumnSelector::Indices) — selection by zero-based integer
/// position — is the only variant available at the moment.
///
/// Selectors compare by value (`PartialEq`/`Eq`): two `Indices` selectors are
/// equal only when they list the same positions in the same order.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ColumnSelector {
    /// Pick columns by zero-based position.
    ///
    /// The list may be unsorted, but each entry has to be strictly smaller
    /// than the column count of the matrix being fitted. Repeating an index
    /// is permitted; the column then shows up once per occurrence in the
    /// sub-matrix given to the transformer.
    Indices(Vec<usize>),
}
86
87impl ColumnSelector {
88 /// Resolve the selector to a concrete list of column indices.
89 ///
90 /// # Errors
91 ///
92 /// Returns [`FerroError::InvalidParameter`] if any index is out of range
93 /// (i.e., `>= n_features`).
94 fn resolve(&self, n_features: usize) -> Result<Vec<usize>, FerroError> {
95 match self {
96 ColumnSelector::Indices(indices) => {
97 for &idx in indices {
98 if idx >= n_features {
99 return Err(FerroError::InvalidParameter {
100 name: "ColumnSelector::Indices".into(),
101 reason: format!(
102 "column index {idx} is out of range for input with {n_features} features"
103 ),
104 });
105 }
106 }
107 Ok(indices.clone())
108 }
109 }
110 }
111}
112
113// ---------------------------------------------------------------------------
114// Remainder
115// ---------------------------------------------------------------------------
116
/// What to do with input columns that no transformer selected.
///
/// The policy only matters when at least one column is left uncovered by
/// every registered transformer; it then decides whether those columns
/// appear in the output.
///
/// Policies compare by value (`PartialEq`/`Eq`), so callers can write
/// `policy == Remainder::Passthrough`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Remainder {
    /// Leave the uncovered columns out of the output entirely.
    Drop,
    /// Copy the uncovered columns through untouched, placing them after
    /// every transformer's output.
    Passthrough,
}
129
130// ---------------------------------------------------------------------------
131// Helper: extract a sub-matrix by column indices
132// ---------------------------------------------------------------------------
133
134/// Build a new `Array2<f64>` containing only the columns at `indices`.
135///
136/// Columns are emitted in the order they appear in `indices`.
137fn select_columns(x: &Array2<f64>, indices: &[usize]) -> Array2<f64> {
138 let nrows = x.nrows();
139 let ncols = indices.len();
140 if ncols == 0 {
141 return Array2::zeros((nrows, 0));
142 }
143 let mut out = Array2::zeros((nrows, ncols));
144 for (new_j, &old_j) in indices.iter().enumerate() {
145 out.column_mut(new_j).assign(&x.column(old_j));
146 }
147 out
148}
149
150/// Horizontally concatenate a slice of `Array2<f64>` matrices.
151///
152/// All matrices must have the same number of rows.
153///
154/// # Errors
155///
156/// Returns [`FerroError::ShapeMismatch`] if row counts differ.
157fn hstack(matrices: &[Array2<f64>]) -> Result<Array2<f64>, FerroError> {
158 if matrices.is_empty() {
159 return Ok(Array2::zeros((0, 0)));
160 }
161 let nrows = matrices[0].nrows();
162 let total_cols: usize = matrices.iter().map(ndarray::ArrayBase::ncols).sum();
163
164 // Handle the case where the first matrix establishes nrows = 0 separately.
165 if total_cols == 0 {
166 return Ok(Array2::zeros((nrows, 0)));
167 }
168
169 let mut out = Array2::zeros((nrows, total_cols));
170 let mut col_offset = 0;
171 for m in matrices {
172 if m.nrows() != nrows {
173 return Err(FerroError::ShapeMismatch {
174 expected: vec![nrows, m.ncols()],
175 actual: vec![m.nrows(), m.ncols()],
176 context: "ColumnTransformer hstack: row count mismatch".into(),
177 });
178 }
179 let end = col_offset + m.ncols();
180 if m.ncols() > 0 {
181 out.slice_mut(ndarray::s![.., col_offset..end]).assign(m);
182 }
183 col_offset = end;
184 }
185 Ok(out)
186}
187
188// ---------------------------------------------------------------------------
189// ColumnTransformer (unfitted)
190// ---------------------------------------------------------------------------
191
192/// An unfitted column transformer.
193///
194/// Applies each registered transformer to its designated column subset, then
195/// horizontally concatenates all outputs. The [`Remainder`] policy controls
196/// what happens to columns not covered by any transformer.
197///
198/// # Transformer order
199///
200/// Transformers are applied and their outputs concatenated in the order they
201/// were registered. Remainder columns (when
202/// `remainder = `[`Remainder::Passthrough`]) are appended last.
203///
204/// # Overlapping selections
205///
206/// Each transformer receives its own copy of the selected columns, so
207/// overlapping `ColumnSelector`s are fully supported.
208///
209/// # Examples
210///
211/// ```
212/// use ferrolearn_preprocess::column_transformer::{
213/// ColumnSelector, ColumnTransformer, Remainder,
214/// };
215/// use ferrolearn_preprocess::StandardScaler;
216/// use ferrolearn_core::Fit;
217/// use ferrolearn_core::Transform;
218/// use ndarray::array;
219///
220/// let x = array![[1.0_f64, 10.0, 100.0], [2.0, 20.0, 200.0], [3.0, 30.0, 300.0]];
221/// let ct = ColumnTransformer::new(
222/// vec![("scaler".into(), Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1]))],
223/// Remainder::Passthrough,
224/// );
225/// let fitted = ct.fit(&x, &()).unwrap();
226/// let out = fitted.transform(&x).unwrap();
227/// // 2 scaled columns + 1 passthrough column
228/// assert_eq!(out.ncols(), 3);
229/// ```
230pub struct ColumnTransformer {
231 /// Named transformer steps with their column selectors.
232 transformers: Vec<(String, Box<dyn PipelineTransformer<f64>>, ColumnSelector)>,
233 /// Policy for columns not covered by any transformer.
234 remainder: Remainder,
235}
236
237impl ColumnTransformer {
238 /// Create a new `ColumnTransformer`.
239 ///
240 /// # Parameters
241 ///
242 /// - `transformers`: A list of `(name, transformer, selector)` triples.
243 /// - `remainder`: Policy for uncovered columns (`Drop` or `Passthrough`).
244 #[must_use]
245 pub fn new(
246 transformers: Vec<(String, Box<dyn PipelineTransformer<f64>>, ColumnSelector)>,
247 remainder: Remainder,
248 ) -> Self {
249 Self {
250 transformers,
251 remainder,
252 }
253 }
254}
255
256// ---------------------------------------------------------------------------
257// Fit implementation
258// ---------------------------------------------------------------------------
259
260impl Fit<Array2<f64>, ()> for ColumnTransformer {
261 type Fitted = FittedColumnTransformer;
262 type Error = FerroError;
263
264 /// Fit each transformer on its selected column subset.
265 ///
266 /// Validates that all selected column indices are within bounds before
267 /// fitting any transformer.
268 ///
269 /// # Errors
270 ///
271 /// - [`FerroError::InvalidParameter`] if any column index is out of range.
272 /// - Propagates any error returned by an individual transformer's
273 /// `fit_pipeline` call.
274 fn fit(&self, x: &Array2<f64>, _y: &()) -> Result<FittedColumnTransformer, FerroError> {
275 let n_features = x.ncols();
276 let n_rows = x.nrows();
277
278 // A dummy y vector required by PipelineTransformer::fit_pipeline.
279 let dummy_y = Array1::<f64>::zeros(n_rows);
280
281 // Resolve all selectors up front to validate indices eagerly.
282 let mut resolved_selectors: Vec<Vec<usize>> = Vec::with_capacity(self.transformers.len());
283 for (name, _, selector) in &self.transformers {
284 let indices = selector.resolve(n_features).map_err(|e| {
285 // Enrich the error with the transformer name.
286 FerroError::InvalidParameter {
287 name: format!("ColumnTransformer step '{name}'"),
288 reason: e.to_string(),
289 }
290 })?;
291 resolved_selectors.push(indices);
292 }
293
294 // Build the set of covered column indices (for remainder computation).
295 let covered: std::collections::HashSet<usize> = resolved_selectors
296 .iter()
297 .flat_map(|v| v.iter().copied())
298 .collect();
299
300 let remainder_indices: Vec<usize> =
301 (0..n_features).filter(|c| !covered.contains(c)).collect();
302
303 // Fit each transformer on its sub-matrix.
304 let mut fitted_transformers: Vec<FittedSubTransformer> =
305 Vec::with_capacity(self.transformers.len());
306
307 for ((name, transformer, _), indices) in self.transformers.iter().zip(resolved_selectors) {
308 let sub_x = select_columns(x, &indices);
309 let fitted = transformer.fit_pipeline(&sub_x, &dummy_y)?;
310 fitted_transformers.push((name.clone(), fitted, indices));
311 }
312
313 Ok(FittedColumnTransformer {
314 fitted_transformers,
315 remainder: self.remainder.clone(),
316 remainder_indices,
317 n_features_in: n_features,
318 })
319 }
320}
321
322// ---------------------------------------------------------------------------
323// PipelineTransformer implementation
324// ---------------------------------------------------------------------------
325
326impl PipelineTransformer<f64> for ColumnTransformer {
327 /// Fit the column transformer using the pipeline interface.
328 ///
329 /// The `y` argument is ignored; it exists only for API compatibility.
330 ///
331 /// # Errors
332 ///
333 /// Propagates errors from [`Fit::fit`].
334 fn fit_pipeline(
335 &self,
336 x: &Array2<f64>,
337 _y: &Array1<f64>,
338 ) -> Result<Box<dyn FittedPipelineTransformer<f64>>, FerroError> {
339 let fitted = self.fit(x, &())?;
340 Ok(Box::new(fitted))
341 }
342}
343
344// ---------------------------------------------------------------------------
345// FittedColumnTransformer
346// ---------------------------------------------------------------------------
347
348/// A named, fitted sub-transformer with its column indices.
349type FittedSubTransformer = (String, Box<dyn FittedPipelineTransformer<f64>>, Vec<usize>);
350
351/// A fitted column transformer holding fitted sub-transformers and metadata.
352///
353/// Created by calling [`Fit::fit`] on a [`ColumnTransformer`].
354/// Implements [`Transform<Array2<f64>>`] to apply the fitted transformers and
355/// concatenate their outputs, as well as [`FittedPipelineTransformer`] for use
356/// inside a [`ferrolearn_core::pipeline::Pipeline`].
357pub struct FittedColumnTransformer {
358 /// Fitted transformers with their associated column indices.
359 fitted_transformers: Vec<FittedSubTransformer>,
360 /// Remainder policy from the original [`ColumnTransformer`].
361 remainder: Remainder,
362 /// Column indices not covered by any transformer.
363 remainder_indices: Vec<usize>,
364 /// Number of input features seen during fitting.
365 n_features_in: usize,
366}
367
368impl FittedColumnTransformer {
369 /// Return the number of input features seen during fitting.
370 #[must_use]
371 pub fn n_features_in(&self) -> usize {
372 self.n_features_in
373 }
374
375 /// Return the names of all registered transformer steps.
376 #[must_use]
377 pub fn transformer_names(&self) -> Vec<&str> {
378 self.fitted_transformers
379 .iter()
380 .map(|(name, _, _)| name.as_str())
381 .collect()
382 }
383
384 /// Return the remainder column indices (columns not selected by any transformer).
385 #[must_use]
386 pub fn remainder_indices(&self) -> &[usize] {
387 &self.remainder_indices
388 }
389}
390
391// ---------------------------------------------------------------------------
392// Transform implementation
393// ---------------------------------------------------------------------------
394
395impl Transform<Array2<f64>> for FittedColumnTransformer {
396 type Output = Array2<f64>;
397 type Error = FerroError;
398
399 /// Transform data by applying each fitted transformer to its column subset,
400 /// then horizontally concatenating all outputs.
401 ///
402 /// When `remainder = Passthrough`, the unselected columns are appended
403 /// after all transformer outputs. When `remainder = Drop`, they are
404 /// discarded.
405 ///
406 /// # Errors
407 ///
408 /// - [`FerroError::ShapeMismatch`] if the input does not have
409 /// `n_features_in` columns.
410 /// - Propagates any error from individual transformer `transform_pipeline`
411 /// calls.
412 fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
413 if x.ncols() != self.n_features_in {
414 return Err(FerroError::ShapeMismatch {
415 expected: vec![x.nrows(), self.n_features_in],
416 actual: vec![x.nrows(), x.ncols()],
417 context: "FittedColumnTransformer::transform".into(),
418 });
419 }
420
421 let mut parts: Vec<Array2<f64>> = Vec::with_capacity(self.fitted_transformers.len() + 1);
422
423 for (_, fitted, indices) in &self.fitted_transformers {
424 let sub_x = select_columns(x, indices);
425 let transformed = fitted.transform_pipeline(&sub_x)?;
426 parts.push(transformed);
427 }
428
429 // Append remainder columns if requested.
430 if matches!(self.remainder, Remainder::Passthrough) && !self.remainder_indices.is_empty() {
431 let remainder_sub = select_columns(x, &self.remainder_indices);
432 parts.push(remainder_sub);
433 }
434
435 hstack(&parts)
436 }
437}
438
439// ---------------------------------------------------------------------------
440// FittedPipelineTransformer implementation
441// ---------------------------------------------------------------------------
442
443impl FittedPipelineTransformer<f64> for FittedColumnTransformer {
444 /// Transform data using the pipeline interface.
445 ///
446 /// # Errors
447 ///
448 /// Propagates errors from [`Transform::transform`].
449 fn transform_pipeline(&self, x: &Array2<f64>) -> Result<Array2<f64>, FerroError> {
450 self.transform(x)
451 }
452}
453
454// ---------------------------------------------------------------------------
455// make_column_transformer convenience function
456// ---------------------------------------------------------------------------
457
458/// Convenience function to build a [`ColumnTransformer`] with auto-generated
459/// step names.
460///
461/// Steps are named `"transformer-0"`, `"transformer-1"`, etc.
462///
463/// # Parameters
464///
465/// - `transformers`: A list of `(transformer, selector)` pairs.
466/// - `remainder`: Policy for uncovered columns (`Drop` or `Passthrough`).
467///
468/// # Examples
469///
470/// ```
471/// use ferrolearn_preprocess::column_transformer::{
472/// make_column_transformer, ColumnSelector, Remainder,
473/// };
474/// use ferrolearn_preprocess::StandardScaler;
475/// use ferrolearn_core::Fit;
476/// use ferrolearn_core::Transform;
477/// use ndarray::array;
478///
479/// let x = array![[1.0_f64, 10.0], [2.0, 20.0], [3.0, 30.0]];
480/// let ct = make_column_transformer(
481/// vec![(Box::new(StandardScaler::<f64>::new()), ColumnSelector::Indices(vec![0, 1]))],
482/// Remainder::Drop,
483/// );
484/// let fitted = ct.fit(&x, &()).unwrap();
485/// let out = fitted.transform(&x).unwrap();
486/// assert_eq!(out.ncols(), 2);
487/// ```
488#[must_use]
489pub fn make_column_transformer(
490 transformers: Vec<(Box<dyn PipelineTransformer<f64>>, ColumnSelector)>,
491 remainder: Remainder,
492) -> ColumnTransformer {
493 let named: Vec<(String, Box<dyn PipelineTransformer<f64>>, ColumnSelector)> = transformers
494 .into_iter()
495 .enumerate()
496 .map(|(i, (t, s))| (format!("transformer-{i}"), t, s))
497 .collect();
498 ColumnTransformer::new(named, remainder)
499}
500
501// ---------------------------------------------------------------------------
502// Tests
503// ---------------------------------------------------------------------------
504
505#[cfg(test)]
506mod tests {
507 use super::*;
508 use approx::assert_abs_diff_eq;
509 use ferrolearn_core::pipeline::{Pipeline, PipelineEstimator};
510 use ndarray::{Array2, array};
511
512 use crate::{MinMaxScaler, StandardScaler};
513
514 // -----------------------------------------------------------------------
515 // Helpers
516 // -----------------------------------------------------------------------
517
518 /// Build a simple 4-column test matrix (rows = 4, cols = 4).
519 fn make_x() -> Array2<f64> {
520 array![
521 [1.0, 2.0, 10.0, 20.0],
522 [2.0, 4.0, 20.0, 40.0],
523 [3.0, 6.0, 30.0, 60.0],
524 [4.0, 8.0, 40.0, 80.0],
525 ]
526 }
527
528 // -----------------------------------------------------------------------
529 // 1. Basic 2-transformer usage
530 // -----------------------------------------------------------------------
531
532 #[test]
533 fn test_basic_two_transformers_drop_remainder() {
534 let x = make_x(); // 4×4
535 let ct = ColumnTransformer::new(
536 vec![
537 (
538 "std".into(),
539 Box::new(StandardScaler::<f64>::new()),
540 ColumnSelector::Indices(vec![0, 1]),
541 ),
542 (
543 "mm".into(),
544 Box::new(MinMaxScaler::<f64>::new()),
545 ColumnSelector::Indices(vec![2, 3]),
546 ),
547 ],
548 Remainder::Drop,
549 );
550
551 let fitted = ct.fit(&x, &()).unwrap();
552 let out = fitted.transform(&x).unwrap();
553
554 // All 4 columns covered → no remainder; output is 4 cols
555 assert_eq!(out.nrows(), 4);
556 assert_eq!(out.ncols(), 4);
557 }
558
559 // -----------------------------------------------------------------------
560 // 2. Remainder::Drop drops uncovered columns
561 // -----------------------------------------------------------------------
562
563 #[test]
564 fn test_remainder_drop() {
565 let x = make_x(); // 4×4
566 // Only cover cols 0 and 1 — cols 2 and 3 should be dropped.
567 let ct = ColumnTransformer::new(
568 vec![(
569 "std".into(),
570 Box::new(StandardScaler::<f64>::new()),
571 ColumnSelector::Indices(vec![0, 1]),
572 )],
573 Remainder::Drop,
574 );
575
576 let fitted = ct.fit(&x, &()).unwrap();
577 let out = fitted.transform(&x).unwrap();
578
579 assert_eq!(out.nrows(), 4);
580 assert_eq!(out.ncols(), 2, "uncovered cols should be dropped");
581 }
582
583 // -----------------------------------------------------------------------
584 // 3. Remainder::Passthrough passes uncovered columns through unchanged
585 // -----------------------------------------------------------------------
586
587 #[test]
588 fn test_remainder_passthrough() {
589 let x = make_x(); // 4×4
590 // Only cover cols 0 and 1 — cols 2 and 3 should pass through.
591 let ct = ColumnTransformer::new(
592 vec![(
593 "std".into(),
594 Box::new(StandardScaler::<f64>::new()),
595 ColumnSelector::Indices(vec![0, 1]),
596 )],
597 Remainder::Passthrough,
598 );
599
600 let fitted = ct.fit(&x, &()).unwrap();
601 let out = fitted.transform(&x).unwrap();
602
603 assert_eq!(out.nrows(), 4);
604 assert_eq!(out.ncols(), 4, "passthrough: 2 transformed + 2 remainder");
605
606 // The last 2 columns should be the original cols 2 and 3.
607 for i in 0..4 {
608 assert_abs_diff_eq!(out[[i, 2]], x[[i, 2]], epsilon = 1e-12);
609 assert_abs_diff_eq!(out[[i, 3]], x[[i, 3]], epsilon = 1e-12);
610 }
611 }
612
613 // -----------------------------------------------------------------------
614 // 4. Invalid column index (out of range)
615 // -----------------------------------------------------------------------
616
617 #[test]
618 fn test_invalid_column_index_out_of_range() {
619 let x = make_x(); // 4×4 — valid indices are 0..3
620 let ct = ColumnTransformer::new(
621 vec![(
622 "std".into(),
623 Box::new(StandardScaler::<f64>::new()),
624 ColumnSelector::Indices(vec![0, 99]), // 99 is out of range
625 )],
626 Remainder::Drop,
627 );
628 let result = ct.fit(&x, &());
629 assert!(result.is_err(), "expected error for out-of-range index");
630 }
631
632 // -----------------------------------------------------------------------
633 // 5. Empty transformer list with Remainder::Drop
634 // -----------------------------------------------------------------------
635
636 #[test]
637 fn test_empty_transformer_list_drop() {
638 let x = make_x();
639 let ct = ColumnTransformer::new(vec![], Remainder::Drop);
640 let fitted = ct.fit(&x, &()).unwrap();
641 let out = fitted.transform(&x).unwrap();
642 // No transformers, remainder dropped → empty output
643 assert_eq!(out.nrows(), 0, "hstack of nothing with no passthrough");
644 }
645
646 // -----------------------------------------------------------------------
647 // 6. Empty transformer list with Remainder::Passthrough
648 // -----------------------------------------------------------------------
649
650 #[test]
651 fn test_empty_transformer_list_passthrough() {
652 let x = make_x(); // 4×4
653 let ct = ColumnTransformer::new(vec![], Remainder::Passthrough);
654 let fitted = ct.fit(&x, &()).unwrap();
655 let out = fitted.transform(&x).unwrap();
656 // No transformers, all columns pass through unchanged.
657 assert_eq!(out.nrows(), 4);
658 assert_eq!(out.ncols(), 4);
659 for i in 0..4 {
660 for j in 0..4 {
661 assert_abs_diff_eq!(out[[i, j]], x[[i, j]], epsilon = 1e-12);
662 }
663 }
664 }
665
666 // -----------------------------------------------------------------------
667 // 7. Overlapping column selections
668 // -----------------------------------------------------------------------
669
670 #[test]
671 fn test_overlapping_column_selections() {
672 let x = make_x(); // 4×4
673 // Both transformers select col 0 (overlapping is allowed).
674 let ct = ColumnTransformer::new(
675 vec![
676 (
677 "std1".into(),
678 Box::new(StandardScaler::<f64>::new()),
679 ColumnSelector::Indices(vec![0, 1]),
680 ),
681 (
682 "mm1".into(),
683 Box::new(MinMaxScaler::<f64>::new()),
684 ColumnSelector::Indices(vec![0, 2]), // col 0 also used here
685 ),
686 ],
687 Remainder::Drop,
688 );
689
690 let fitted = ct.fit(&x, &()).unwrap();
691 let out = fitted.transform(&x).unwrap();
692
693 // Output: 2 cols from std1 + 2 cols from mm1 = 4 cols
694 assert_eq!(out.nrows(), 4);
695 assert_eq!(out.ncols(), 4);
696 }
697
698 // -----------------------------------------------------------------------
699 // 8. Single transformer
700 // -----------------------------------------------------------------------
701
702 #[test]
703 fn test_single_transformer() {
704 let x = array![[1.0_f64, 2.0], [3.0, 4.0], [5.0, 6.0]];
705 let ct = ColumnTransformer::new(
706 vec![(
707 "mm".into(),
708 Box::new(MinMaxScaler::<f64>::new()),
709 ColumnSelector::Indices(vec![0, 1]),
710 )],
711 Remainder::Drop,
712 );
713
714 let fitted = ct.fit(&x, &()).unwrap();
715 let out = fitted.transform(&x).unwrap();
716
717 assert_eq!(out.nrows(), 3);
718 assert_eq!(out.ncols(), 2);
719
720 // MinMax on cols 0 and 1: first row → 0.0, last row → 1.0
721 assert_abs_diff_eq!(out[[0, 0]], 0.0, epsilon = 1e-10);
722 assert_abs_diff_eq!(out[[2, 0]], 1.0, epsilon = 1e-10);
723 assert_abs_diff_eq!(out[[0, 1]], 0.0, epsilon = 1e-10);
724 assert_abs_diff_eq!(out[[2, 1]], 1.0, epsilon = 1e-10);
725 }
726
727 // -----------------------------------------------------------------------
728 // 9. make_column_transformer convenience function
729 // -----------------------------------------------------------------------
730
731 #[test]
732 fn test_make_column_transformer_auto_names() {
733 let x = make_x();
734 let ct = make_column_transformer(
735 vec![
736 (
737 Box::new(StandardScaler::<f64>::new()),
738 ColumnSelector::Indices(vec![0, 1]),
739 ),
740 (
741 Box::new(MinMaxScaler::<f64>::new()),
742 ColumnSelector::Indices(vec![2, 3]),
743 ),
744 ],
745 Remainder::Drop,
746 );
747
748 let fitted = ct.fit(&x, &()).unwrap();
749 assert_eq!(
750 fitted.transformer_names(),
751 vec!["transformer-0", "transformer-1"]
752 );
753
754 let out = fitted.transform(&x).unwrap();
755 assert_eq!(out.nrows(), 4);
756 assert_eq!(out.ncols(), 4);
757 }
758
759 // -----------------------------------------------------------------------
760 // 10. Pipeline integration
761 // -----------------------------------------------------------------------
762
763 #[test]
764 fn test_pipeline_integration() {
765 // Wrap a ColumnTransformer as a pipeline step.
766 let x = make_x();
767 let y = Array1::<f64>::zeros(4);
768
769 let ct = ColumnTransformer::new(
770 vec![(
771 "std".into(),
772 Box::new(StandardScaler::<f64>::new()),
773 ColumnSelector::Indices(vec![0, 1, 2, 3]),
774 )],
775 Remainder::Drop,
776 );
777
778 // Use a trivial estimator that sums rows.
779 struct SumEstimator;
780 impl PipelineEstimator<f64> for SumEstimator {
781 fn fit_pipeline(
782 &self,
783 _x: &Array2<f64>,
784 _y: &Array1<f64>,
785 ) -> Result<Box<dyn ferrolearn_core::pipeline::FittedPipelineEstimator<f64>>, FerroError>
786 {
787 Ok(Box::new(FittedSum))
788 }
789 }
790 struct FittedSum;
791 impl ferrolearn_core::pipeline::FittedPipelineEstimator<f64> for FittedSum {
792 fn predict_pipeline(&self, x: &Array2<f64>) -> Result<Array1<f64>, FerroError> {
793 let sums: Vec<f64> = x.rows().into_iter().map(|r| r.sum()).collect();
794 Ok(Array1::from_vec(sums))
795 }
796 }
797
798 let pipeline = Pipeline::new()
799 .transform_step("ct", Box::new(ct))
800 .estimator_step("sum", Box::new(SumEstimator));
801
802 use ferrolearn_core::Fit as _;
803 let fitted_pipeline = pipeline.fit(&x, &y).unwrap();
804
805 use ferrolearn_core::Predict as _;
806 let preds = fitted_pipeline.predict(&x).unwrap();
807 assert_eq!(preds.len(), 4);
808 }
809
810 // -----------------------------------------------------------------------
811 // 11. Transform shape correctness — number of output columns
812 // -----------------------------------------------------------------------
813
814 #[test]
815 fn test_output_shape_all_selected_drop() {
816 let x = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
817 let ct = ColumnTransformer::new(
818 vec![
819 (
820 "s".into(),
821 Box::new(StandardScaler::<f64>::new()),
822 ColumnSelector::Indices(vec![0]),
823 ),
824 (
825 "m".into(),
826 Box::new(MinMaxScaler::<f64>::new()),
827 ColumnSelector::Indices(vec![1, 2]),
828 ),
829 ],
830 Remainder::Drop,
831 );
832 let fitted = ct.fit(&x, &()).unwrap();
833 let out = fitted.transform(&x).unwrap();
834 assert_eq!(out.shape(), &[2, 3]);
835 }
836
837 // -----------------------------------------------------------------------
838 // 12. Transform shape — partial selection + passthrough
839 // -----------------------------------------------------------------------
840
841 #[test]
842 fn test_output_shape_partial_passthrough() {
843 // 5-column input, transform 2 cols, passthrough 3
844 let x = Array2::<f64>::from_shape_vec((3, 5), (1..=15).map(f64::from).collect()).unwrap();
845 let ct = ColumnTransformer::new(
846 vec![(
847 "std".into(),
848 Box::new(StandardScaler::<f64>::new()),
849 ColumnSelector::Indices(vec![0, 1]),
850 )],
851 Remainder::Passthrough,
852 );
853 let fitted = ct.fit(&x, &()).unwrap();
854 let out = fitted.transform(&x).unwrap();
855 assert_eq!(out.shape(), &[3, 5]);
856 }
857
858 // -----------------------------------------------------------------------
859 // 13. n_features_in accessor
860 // -----------------------------------------------------------------------
861
862 #[test]
863 fn test_n_features_in() {
864 let x = make_x(); // 4×4
865 let ct = ColumnTransformer::new(
866 vec![(
867 "std".into(),
868 Box::new(StandardScaler::<f64>::new()),
869 ColumnSelector::Indices(vec![0]),
870 )],
871 Remainder::Drop,
872 );
873 let fitted = ct.fit(&x, &()).unwrap();
874 assert_eq!(fitted.n_features_in(), 4);
875 }
876
877 // -----------------------------------------------------------------------
878 // 14. Shape mismatch on transform (wrong number of columns)
879 // -----------------------------------------------------------------------
880
#[test]
fn test_shape_mismatch_on_transform() {
    // Fit on the 4×4 fixture.
    let train = make_x();
    let transformer = ColumnTransformer::new(
        vec![(
            "std".into(),
            Box::new(StandardScaler::<f64>::new()),
            ColumnSelector::Indices(vec![0, 1]),
        )],
        Remainder::Drop,
    );
    let fitted = transformer.fit(&train, &()).unwrap();

    // Transforming a matrix with only 2 columns must be rejected.
    let too_narrow = array![[1.0_f64, 2.0], [3.0, 4.0]];
    assert!(
        fitted.transform(&too_narrow).is_err(),
        "expected shape mismatch error"
    );
}
899
900 // -----------------------------------------------------------------------
901 // 15. remainder_indices accessor
902 // -----------------------------------------------------------------------
903
#[test]
fn test_remainder_indices_accessor() {
    let input = make_x(); // 4×4 fixture
    let transformer = ColumnTransformer::new(
        vec![(
            "std".into(),
            Box::new(StandardScaler::<f64>::new()),
            ColumnSelector::Indices(vec![0, 2]),
        )],
        Remainder::Passthrough,
    );
    let fitted = transformer.fit(&input, &()).unwrap();
    // Columns 0 and 2 are claimed by the scaler, so 1 and 3 are left over.
    let leftover = fitted.remainder_indices();
    assert_eq!(leftover, &[1, 3]);
}
919
920 // -----------------------------------------------------------------------
921 // 16. StandardScaler output values are correct (zero-mean)
922 // -----------------------------------------------------------------------
923
#[test]
fn test_standard_scaler_zero_mean_in_output() {
    let input = array![[1.0_f64, 100.0, 0.5], [2.0, 200.0, 1.5], [3.0, 300.0, 2.5],];
    let transformer = ColumnTransformer::new(
        vec![(
            "std".into(),
            Box::new(StandardScaler::<f64>::new()),
            ColumnSelector::Indices(vec![0, 1]),
        )],
        Remainder::Drop,
    );
    let fitted = transformer.fit(&input, &()).unwrap();
    let scaled = fitted.transform(&input).unwrap();

    // Average each of the first two output columns over the three rows.
    let column_means: Vec<f64> = (0..2)
        .map(|j| scaled.column(j).iter().sum::<f64>() / 3.0)
        .collect();

    // Standardized columns should be centred on zero.
    for mean in column_means {
        assert_abs_diff_eq!(mean, 0.0, epsilon = 1e-10);
    }
}
944
945 // -----------------------------------------------------------------------
946 // 17. MinMaxScaler output values are in [0, 1]
947 // -----------------------------------------------------------------------
948
#[test]
fn test_min_max_values_in_range() {
    let input = make_x();
    let transformer = ColumnTransformer::new(
        vec![(
            "mm".into(),
            Box::new(MinMaxScaler::<f64>::new()),
            ColumnSelector::Indices(vec![0, 1, 2, 3]),
        )],
        Remainder::Drop,
    );
    let fitted = transformer.fit(&input, &()).unwrap();
    let scaled = fitted.transform(&input).unwrap();

    for j in 0..4 {
        // Track the smallest and largest value of the column in one pass.
        let (lowest, highest) = scaled.column(j).iter().fold(
            (f64::INFINITY, f64::NEG_INFINITY),
            |(lo, hi), &value| (lo.min(value), hi.max(value)),
        );
        // Each scaled column is expected to span exactly [0, 1].
        assert_abs_diff_eq!(lowest, 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(highest, 1.0, epsilon = 1e-10);
    }
}
974
975 // -----------------------------------------------------------------------
976 // 18. Pipeline transformer interface (fit_pipeline / transform_pipeline)
977 // -----------------------------------------------------------------------
978
#[test]
fn test_pipeline_transformer_interface() {
    let features = make_x();
    // The pipeline entry point takes a target vector; zeros are used here.
    let targets = Array1::<f64>::zeros(4);
    let transformer = ColumnTransformer::new(
        vec![(
            "std".into(),
            Box::new(StandardScaler::<f64>::new()),
            ColumnSelector::Indices(vec![0, 1]),
        )],
        Remainder::Passthrough,
    );
    let fitted_step = transformer.fit_pipeline(&features, &targets).unwrap();
    let transformed = fitted_step.transform_pipeline(&features).unwrap();
    // Two scaled columns plus two passthrough columns, over four rows.
    assert_eq!(transformed.nrows(), 4);
    assert_eq!(transformed.ncols(), 4);
}
996
997 // -----------------------------------------------------------------------
998 // 19. Remainder passthrough values are identical to input values
999 // -----------------------------------------------------------------------
1000
#[test]
fn test_passthrough_values_are_exact() {
    let input = array![[10.0_f64, 20.0, 30.0], [40.0, 50.0, 60.0],];
    // Column 0 goes through the scaler; columns 1 and 2 are passed through.
    let transformer = ColumnTransformer::new(
        vec![(
            "mm".into(),
            Box::new(MinMaxScaler::<f64>::new()),
            ColumnSelector::Indices(vec![0]),
        )],
        Remainder::Passthrough,
    );
    let fitted = transformer.fit(&input, &()).unwrap();
    let out = fitted.transform(&input).unwrap();

    // (row, column, expected) for every passthrough cell; the expected value
    // is the corresponding entry of the input matrix.
    let expected_cells: [(usize, usize, f64); 4] =
        [(0, 1, 20.0), (1, 1, 50.0), (0, 2, 30.0), (1, 2, 60.0)];
    for (row, col, expected) in expected_cells {
        assert_abs_diff_eq!(out[[row, col]], expected, epsilon = 1e-12);
    }
}
1021
1022 // -----------------------------------------------------------------------
1023 // 20. Transformer names from explicit ColumnTransformer::new
1024 // -----------------------------------------------------------------------
1025
#[test]
fn test_transformer_names_explicit() {
    let input = make_x();
    // Two steps registered under caller-chosen names.
    let transformer = ColumnTransformer::new(
        vec![
            (
                "alpha".into(),
                Box::new(StandardScaler::<f64>::new()),
                ColumnSelector::Indices(vec![0]),
            ),
            (
                "beta".into(),
                Box::new(MinMaxScaler::<f64>::new()),
                ColumnSelector::Indices(vec![1]),
            ),
        ],
        Remainder::Drop,
    );
    let fitted = transformer.fit(&input, &()).unwrap();
    // Names are expected back in registration order.
    let names = fitted.transformer_names();
    assert_eq!(names, vec!["alpha", "beta"]);
}
1047
1048 // -----------------------------------------------------------------------
1049 // 21. make_column_transformer with single step
1050 // -----------------------------------------------------------------------
1051
#[test]
fn test_make_column_transformer_single() {
    let input = array![[1.0_f64, 2.0], [3.0, 4.0]];
    // The helper takes unnamed (transformer, selector) pairs.
    let transformer = make_column_transformer(
        vec![(
            Box::new(StandardScaler::<f64>::new()),
            ColumnSelector::Indices(vec![0, 1]),
        )],
        Remainder::Drop,
    );
    let fitted = transformer.fit(&input, &()).unwrap();

    // The single step should be reported under the generated name.
    let names = fitted.transformer_names();
    assert_eq!(names, vec!["transformer-0"]);

    let transformed = fitted.transform(&input).unwrap();
    assert_eq!(transformed.shape(), &[2, 2]);
}
1067
1068 // -----------------------------------------------------------------------
1069 // 22. Edge case: all columns as remainder with Passthrough
1070 // -----------------------------------------------------------------------
1071
#[test]
fn test_all_remainder_passthrough_unchanged() {
    let input = array![[1.0_f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
    // No transformers are registered, so every column is remainder.
    let transformer = ColumnTransformer::new(vec![], Remainder::Passthrough);
    let fitted = transformer.fit(&input, &()).unwrap();
    let out = fitted.transform(&input).unwrap();
    assert_eq!(out.shape(), &[2, 3]);
    // Walk the 2×3 input cell by cell and compare with the same output cell.
    for ((row, col), &expected) in input.indexed_iter() {
        assert_abs_diff_eq!(out[[row, col]], expected, epsilon = 1e-12);
    }
}
1085}