Skip to main content

ferrolearn_preprocess/
select_from_model.rs

1//! Feature selection driven by a model's feature importance weights.
2//!
3//! [`SelectFromModel`](super::feature_selection::SelectFromModel) provides
4//! basic mean/explicit-threshold selection.  This module provides a richer
5//! API via [`SelectFromModelExt`], which supports four threshold strategies
6//! (mean, median, explicit value, percentile) and an optional
7//! `max_features` cap.
8//!
9//! # Threshold Strategies
10//!
11//! | Variant | Description |
12//! |---------|-------------|
13//! | [`ThresholdStrategy::Mean`] | Threshold = arithmetic mean of importances |
14//! | [`ThresholdStrategy::Median`] | Threshold = median of importances |
15//! | [`ThresholdStrategy::Value`] | User-supplied explicit threshold |
16//! | [`ThresholdStrategy::Percentile`] | Keep features in the top *p*% by importance |
17//!
//! When `max_features` is set, at most that many of the features that pass
//! the threshold are retained: the most important ones are kept, and the
//! selected column indices are reported in ascending column order.
20//!
21//! ## REQ status
22//!
23//! Translation target: scikit-learn 1.5.2 `class SelectFromModel`
24//! (`sklearn/feature_selection/_from_model.py:256`). Tracking: #1352. Each REQ
25//! is BINARY — SHIPPED (impl + non-test consumer + tests + green verification)
26//! or NOT-STARTED (with a concrete open blocker). HONEST scope: this unit ships
27//! the threshold + selection-mask + `max_features` core GIVEN a static
28//! importance vector; sklearn wraps a fitted estimator and extracts its
29//! importances — that estimator machinery is NOT-STARTED.
30//!
31//! | REQ | Scope | Status | Evidence / Blocker |
32//! |-----|-------|--------|--------------------|
33//! | REQ-1 | Threshold (mean/median/value) + selection mask (`score >= threshold`) + `max_features` top-k cap, given a static importance vector | SHIPPED | [`SelectFromModelExt`] `fit` matches sklearn `_get_support_mask` `_from_model.py:299-312` + `_calculate_threshold` `:24-71` (mean=`np.mean`, median=`np.median`); threshold-then-cap is algebraically equivalent to sklearn cap-then-threshold (exhaustive-grid oracle-verified); 15 oracle value tests in `tests/divergence_select_from_model.rs`. Consumer: boundary re-export `lib.rs` (grandfathered S5/R-DEFER-1) + `PipelineTransformer` |
34//! | REQ-2 | Error/parameter contracts (empty importances, `Percentile` range, transform ncols mismatch) | SHIPPED (scoped) | [`SelectFromModelExt::fit`]/[`FittedSelectFromModelExt`] `transform`; in-module + divergence error tests |
35//! | REQ-3 | Estimator wrapping + `coef_`/`feature_importances_` extraction (`_get_feature_importances`) | NOT-STARTED | takes importances directly; sklearn `_from_model.py:299-304` — blocker #1353 |
36//! | REQ-4 | `norm_order` multi-output coef norm | NOT-STARTED | scalar importances only; sklearn `_from_model.py:303` — blocker #1354 |
37//! | REQ-5 | Scaled-string `scale*mean`/`scale*median` thresholds + default-from-estimator (l1→1e-5) | NOT-STARTED | sklearn `_from_model.py:30-55` — blocker #1355 |
38//! | REQ-6 | `prefit` + `importance_getter` params | NOT-STARTED | sklearn `_from_model.py:256-271,277-284` — blocker #1356 |
39//! | REQ-7 | `max_features` callable + `_check_max_features` range validation `[0, n_features]` | NOT-STARTED | int cap only; sklearn `_from_model.py:315-331` — blocker #1357 |
40//! | REQ-8 | `SelectorMixin` surface (`get_support`/`inverse_transform`/`get_feature_names_out`) | NOT-STARTED | sklearn `_base.py` `SelectorMixin` — blocker #1358 |
41//! | REQ-9 | PyO3 binding | NOT-STARTED | no `ferrolearn-python` registration — blocker #1359 |
42//! | REQ-10 | ferray substrate | NOT-STARTED | dense `Array2` + `num_traits::Float` only — blocker #1360 |
43//!
44//! NOTE: [`ThresholdStrategy::Percentile`] is a ferrolearn EXTENSION with NO
45//! sklearn `SelectFromModel` analog (sklearn supports only mean/median/`scale*ref`/
46//! float); it is not a sklearn-parity REQ and carries no blocker.
47
48use ferrolearn_core::error::FerroError;
49use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
50use ferrolearn_core::traits::{Fit, Transform};
51use ndarray::{Array1, Array2};
52use num_traits::Float;
53
54// ---------------------------------------------------------------------------
55// ThresholdStrategy
56// ---------------------------------------------------------------------------
57
/// Rule that [`SelectFromModelExt`] uses to turn importances into a cut-off.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum ThresholdStrategy {
    /// Cut-off is the arithmetic mean of the feature importances.
    #[default]
    Mean,
    /// Cut-off is the median of the feature importances.
    Median,
    /// Cut-off is the wrapped value, supplied directly by the caller.
    Value(f64),
    /// Retain the top `p` percent of features by importance (0 < p <= 100).
    ///
    /// `Percentile(25.0)`, for instance, places the cut-off at the
    /// 75th-percentile importance, so the top quarter is retained.
    Percentile(f64),
}
74
75// ---------------------------------------------------------------------------
76// SelectFromModelExt (unfitted)
77// ---------------------------------------------------------------------------
78
79/// An extended model-importance-based feature selector.
80///
81/// Like [`SelectFromModel`](super::feature_selection::SelectFromModel) but
82/// supports four threshold strategies and an optional `max_features` cap.
83///
84/// # Examples
85///
86/// ```
87/// use ferrolearn_preprocess::select_from_model::{SelectFromModelExt, ThresholdStrategy};
88/// use ferrolearn_core::traits::{Fit, Transform};
89/// use ndarray::array;
90///
91/// let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Mean, None);
92/// let importances = array![0.1, 0.5, 0.4];
93/// let fitted = sel.fit(&importances, &()).unwrap();
94/// let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
95/// let out = fitted.transform(&x).unwrap();
96/// // Mean importance = (0.1+0.5+0.4)/3 ≈ 0.333; columns 1 and 2 kept
97/// assert_eq!(out.ncols(), 2);
98/// ```
99#[must_use]
100#[derive(Debug, Clone)]
101pub struct SelectFromModelExt<F> {
102    /// The threshold strategy.
103    threshold: ThresholdStrategy,
104    /// Optional cap on number of features to select.
105    max_features: Option<usize>,
106    _marker: std::marker::PhantomData<F>,
107}
108
109impl<F: Float + Send + Sync + 'static> SelectFromModelExt<F> {
110    /// Create a new `SelectFromModelExt`.
111    ///
112    /// # Parameters
113    ///
114    /// - `threshold` — the strategy for computing the importance threshold.
115    /// - `max_features` — optional maximum number of features to retain.
116    pub fn new(threshold: ThresholdStrategy, max_features: Option<usize>) -> Self {
117        Self {
118            threshold,
119            max_features,
120            _marker: std::marker::PhantomData,
121        }
122    }
123
124    /// Return the threshold strategy.
125    #[must_use]
126    pub fn threshold_strategy(&self) -> ThresholdStrategy {
127        self.threshold
128    }
129
130    /// Return the maximum number of features (if set).
131    #[must_use]
132    pub fn max_features(&self) -> Option<usize> {
133        self.max_features
134    }
135}
136
137impl<F: Float + Send + Sync + 'static> Default for SelectFromModelExt<F> {
138    fn default() -> Self {
139        Self::new(ThresholdStrategy::Mean, None)
140    }
141}
142
143// ---------------------------------------------------------------------------
144// FittedSelectFromModelExt
145// ---------------------------------------------------------------------------
146
147/// A fitted model-importance selector produced by [`SelectFromModelExt::fit`].
148#[derive(Debug, Clone)]
149pub struct FittedSelectFromModelExt<F> {
150    /// Number of features seen during fitting.
151    n_features_in: usize,
152    /// The computed threshold value.
153    threshold_value: F,
154    /// Feature importances supplied during fitting.
155    importances: Array1<F>,
156    /// Indices of selected columns (sorted).
157    selected_indices: Vec<usize>,
158}
159
160impl<F: Float + Send + Sync + 'static> FittedSelectFromModelExt<F> {
161    /// Return the computed threshold value.
162    #[must_use]
163    pub fn threshold_value(&self) -> F {
164        self.threshold_value
165    }
166
167    /// Return the feature importances.
168    #[must_use]
169    pub fn importances(&self) -> &Array1<F> {
170        &self.importances
171    }
172
173    /// Return the indices of the selected columns.
174    #[must_use]
175    pub fn selected_indices(&self) -> &[usize] {
176        &self.selected_indices
177    }
178
179    /// Return the number of selected features.
180    #[must_use]
181    pub fn n_features_selected(&self) -> usize {
182        self.selected_indices.len()
183    }
184}
185
186// ---------------------------------------------------------------------------
187// Helpers
188// ---------------------------------------------------------------------------
189
190/// Compute the median of a slice of floats.
191fn compute_median<F: Float>(values: &[F]) -> F {
192    let mut sorted: Vec<F> = values.to_vec();
193    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
194    let n = sorted.len();
195    if n.is_multiple_of(2) {
196        let two = F::one() + F::one();
197        (sorted[n / 2 - 1] + sorted[n / 2]) / two
198    } else {
199        sorted[n / 2]
200    }
201}
202
203/// Compute the percentile threshold. `pct` is the percentage of features to
204/// keep (e.g., 25.0 means top 25%).
205fn compute_percentile_threshold<F: Float>(values: &[F], pct: f64) -> F {
206    let mut sorted: Vec<F> = values.to_vec();
207    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
208    let n = sorted.len();
209    // The threshold is set at the (100 - pct) percentile of the sorted values.
210    // E.g., for top 25% we want the value at the 75th percentile.
211    let rank = ((100.0 - pct) / 100.0) * (n.saturating_sub(1)) as f64;
212    let lower = rank.floor() as usize;
213    let upper = rank.ceil() as usize;
214    let lower = lower.min(n.saturating_sub(1));
215    let upper = upper.min(n.saturating_sub(1));
216    if lower == upper {
217        sorted[lower]
218    } else {
219        let frac = F::from(rank - rank.floor()).unwrap_or_else(F::zero);
220        sorted[lower] * (F::one() - frac) + sorted[upper] * frac
221    }
222}
223
224/// Build a new `Array2<F>` containing only the columns listed in `indices`.
225fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
226    let nrows = x.nrows();
227    let ncols = indices.len();
228    if ncols == 0 {
229        return Array2::zeros((nrows, 0));
230    }
231    let mut out = Array2::zeros((nrows, ncols));
232    for (new_j, &old_j) in indices.iter().enumerate() {
233        for i in 0..nrows {
234            out[[i, new_j]] = x[[i, old_j]];
235        }
236    }
237    out
238}
239
240// ---------------------------------------------------------------------------
241// Trait implementations
242// ---------------------------------------------------------------------------
243
244impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFromModelExt<F> {
245    type Fitted = FittedSelectFromModelExt<F>;
246    type Error = FerroError;
247
248    /// Fit by computing the threshold from the given feature importances.
249    ///
250    /// # Parameters
251    ///
252    /// - `x` — per-feature importance scores (one value per feature).
253    /// - `_y` — ignored (unsupervised).
254    ///
255    /// # Errors
256    ///
257    /// - [`FerroError::InvalidParameter`] if the importance vector is empty,
258    ///   or if `Percentile` value is not in `(0, 100]`.
259    fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFromModelExt<F>, FerroError> {
260        let n = x.len();
261        if n == 0 {
262            return Err(FerroError::InvalidParameter {
263                name: "importances".into(),
264                reason: "importance vector must not be empty".into(),
265            });
266        }
267
268        let values: Vec<F> = x.iter().copied().collect();
269
270        // Compute threshold
271        let threshold_value = match self.threshold {
272            ThresholdStrategy::Mean => {
273                values.iter().copied().fold(F::zero(), |acc, v| acc + v)
274                    / F::from(n).unwrap_or_else(F::one)
275            }
276            ThresholdStrategy::Median => compute_median(&values),
277            ThresholdStrategy::Value(v) => F::from(v).unwrap_or_else(F::zero),
278            ThresholdStrategy::Percentile(pct) => {
279                if pct <= 0.0 || pct > 100.0 {
280                    return Err(FerroError::InvalidParameter {
281                        name: "percentile".into(),
282                        reason: format!("percentile must be in (0, 100], got {}", pct),
283                    });
284                }
285                compute_percentile_threshold(&values, pct)
286            }
287        };
288
289        // Select features whose importance >= threshold
290        let mut selected_indices: Vec<usize> = values
291            .iter()
292            .enumerate()
293            .filter(|&(_, &imp)| imp >= threshold_value)
294            .map(|(j, _)| j)
295            .collect();
296
297        // Apply max_features cap: keep only the top-k by importance
298        if let Some(max_f) = self.max_features
299            && selected_indices.len() > max_f
300        {
301            // Sort selected by importance descending, keep top max_f
302            selected_indices.sort_by(|&a, &b| {
303                values[b]
304                    .partial_cmp(&values[a])
305                    .unwrap_or(std::cmp::Ordering::Equal)
306            });
307            selected_indices.truncate(max_f);
308            // Re-sort in column order
309            selected_indices.sort_unstable();
310        }
311
312        Ok(FittedSelectFromModelExt {
313            n_features_in: n,
314            threshold_value,
315            importances: x.clone(),
316            selected_indices,
317        })
318    }
319}
320
321impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFromModelExt<F> {
322    type Output = Array2<F>;
323    type Error = FerroError;
324
325    /// Return a matrix containing only the selected columns.
326    ///
327    /// # Errors
328    ///
329    /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
330    /// from the number of features seen during fitting.
331    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
332        if x.ncols() != self.n_features_in {
333            return Err(FerroError::ShapeMismatch {
334                expected: vec![x.nrows(), self.n_features_in],
335                actual: vec![x.nrows(), x.ncols()],
336                context: "FittedSelectFromModelExt::transform".into(),
337            });
338        }
339        Ok(select_columns(x, &self.selected_indices))
340    }
341}
342
343// ---------------------------------------------------------------------------
344// Pipeline integration
345// ---------------------------------------------------------------------------
346
347impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for FittedSelectFromModelExt<F> {
348    /// Clone the fitted selector and box it as a pipeline transformer.
349    ///
350    /// Because the selector is already fitted (importances supplied at fit
351    /// time), `fit_pipeline` simply boxes the existing fitted state.
352    ///
353    /// # Errors
354    ///
355    /// This implementation never fails.
356    fn fit_pipeline(
357        &self,
358        _x: &Array2<F>,
359        _y: &Array1<F>,
360    ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
361        Ok(Box::new(self.clone()))
362    }
363}
364
365impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F>
366    for FittedSelectFromModelExt<F>
367{
368    /// Transform using the pipeline interface.
369    ///
370    /// # Errors
371    ///
372    /// Propagates errors from [`Transform::transform`].
373    fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
374        self.transform(x)
375    }
376}
377
378// ---------------------------------------------------------------------------
379// Tests
380// ---------------------------------------------------------------------------
381
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    /// Fit an `f64` selector with the given strategy and cap, panicking if
    /// fitting fails.
    fn fit_f64(
        strategy: ThresholdStrategy,
        cap: Option<usize>,
        importances: &Array1<f64>,
    ) -> FittedSelectFromModelExt<f64> {
        SelectFromModelExt::<f64>::new(strategy, cap)
            .fit(importances, &())
            .unwrap()
    }

    #[test]
    fn test_mean_threshold() {
        // Mean = (0.1+0.5+0.4)/3 ≈ 0.333, so columns 1 and 2 survive.
        let fitted = fit_f64(ThresholdStrategy::Mean, None, &array![0.1, 0.5, 0.4]);
        assert_eq!(fitted.selected_indices(), &[1, 2]);

        let out = fitted
            .transform(&array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
            .unwrap();
        assert_eq!(out.ncols(), 2);
        assert_abs_diff_eq!(out[[0, 0]], 2.0, epsilon = 1e-15);
        assert_abs_diff_eq!(out[[0, 1]], 3.0, epsilon = 1e-15);
    }

    #[test]
    fn test_median_threshold() {
        // Sorted importances are [0.1, 0.3, 0.5], so the median is 0.3 and
        // indices 1 (0.5) and 2 (0.3) reach it.
        let fitted = fit_f64(ThresholdStrategy::Median, None, &array![0.1, 0.5, 0.3]);
        assert_eq!(fitted.selected_indices(), &[1, 2]);
    }

    #[test]
    fn test_median_threshold_even() {
        // Sorted importances are [0.1, 0.2, 0.5, 0.6]; the median is
        // (0.2+0.5)/2 = 0.35, reached by indices 1 (0.5) and 3 (0.6).
        let fitted = fit_f64(
            ThresholdStrategy::Median,
            None,
            &array![0.1, 0.5, 0.2, 0.6],
        );
        assert_eq!(fitted.selected_indices(), &[1, 3]);
    }

    #[test]
    fn test_explicit_value_threshold() {
        let fitted = fit_f64(ThresholdStrategy::Value(0.45), None, &array![0.1, 0.5, 0.4]);
        assert_eq!(fitted.selected_indices(), &[1]);
    }

    #[test]
    fn test_percentile_threshold_top_50() {
        // Sorted importances are [0.1, 0.3, 0.5, 0.7]; the 50th percentile is
        // interpolated at rank 1.5, giving 0.4. Indices 0 (0.5) and 2 (0.7)
        // reach it.
        let fitted = fit_f64(
            ThresholdStrategy::Percentile(50.0),
            None,
            &array![0.5, 0.1, 0.7, 0.3],
        );
        let kept = fitted.selected_indices();
        assert!(kept.contains(&0));
        assert!(kept.contains(&2));
        assert_eq!(fitted.n_features_selected(), 2);
    }

    #[test]
    fn test_percentile_100_keeps_all() {
        let fitted = fit_f64(
            ThresholdStrategy::Percentile(100.0),
            None,
            &array![0.1, 0.5, 0.3],
        );
        assert_eq!(fitted.n_features_selected(), 3);
    }

    #[test]
    fn test_percentile_invalid() {
        let importances = array![0.1, 0.5, 0.3];

        let too_low = SelectFromModelExt::<f64>::new(ThresholdStrategy::Percentile(0.0), None);
        assert!(too_low.fit(&importances, &()).is_err());

        let too_high = SelectFromModelExt::<f64>::new(ThresholdStrategy::Percentile(101.0), None);
        assert!(too_high.fit(&importances, &()).is_err());
    }

    #[test]
    fn test_max_features_cap() {
        // Every feature passes threshold 0, but the cap allows only two.
        let fitted = fit_f64(
            ThresholdStrategy::Value(0.0),
            Some(2),
            &array![0.3, 0.5, 0.1, 0.7],
        );
        assert_eq!(fitted.n_features_selected(), 2);
        // The two most important are indices 1 (0.5) and 3 (0.7).
        assert_eq!(fitted.selected_indices(), &[1, 3]);
    }

    #[test]
    fn test_max_features_not_needed() {
        // Only two features pass the threshold, so a cap of 5 has no effect.
        let fitted = fit_f64(
            ThresholdStrategy::Value(0.4),
            Some(5),
            &array![0.1, 0.5, 0.4],
        );
        assert_eq!(fitted.n_features_selected(), 2);
    }

    #[test]
    fn test_empty_importances_error() {
        let selector = SelectFromModelExt::<f64>::new(ThresholdStrategy::Mean, None);
        let empty: Array1<f64> = Array1::zeros(0);
        assert!(selector.fit(&empty, &()).is_err());
    }

    #[test]
    fn test_shape_mismatch_on_transform() {
        let fitted = fit_f64(ThresholdStrategy::Mean, None, &array![0.5, 0.5]);
        // Three columns are supplied where two are expected.
        let too_wide = array![[1.0, 2.0, 3.0]];
        assert!(fitted.transform(&too_wide).is_err());
    }

    #[test]
    fn test_threshold_value_accessor() {
        let fitted = fit_f64(ThresholdStrategy::Value(0.42), None, &array![0.1, 0.5]);
        assert_abs_diff_eq!(fitted.threshold_value(), 0.42, epsilon = 1e-15);
    }

    #[test]
    fn test_default() {
        let selector = SelectFromModelExt::<f64>::default();
        assert_eq!(selector.threshold_strategy(), ThresholdStrategy::Mean);
        assert_eq!(selector.max_features(), None);
    }

    #[test]
    fn test_pipeline_integration() {
        let fitted = fit_f64(ThresholdStrategy::Mean, None, &array![0.1, 0.9]);
        let x = array![[1.0, 2.0], [3.0, 4.0]];
        let y = array![0.0, 1.0];
        let boxed = fitted.fit_pipeline(&x, &y).unwrap();
        let out = boxed.transform_pipeline(&x).unwrap();
        assert_eq!(out.ncols(), 1);
    }

    #[test]
    fn test_f32() {
        let selector = SelectFromModelExt::<f32>::new(ThresholdStrategy::Mean, None);
        let importances: Array1<f32> = array![0.1f32, 0.5, 0.4];
        let fitted = selector.fit(&importances, &()).unwrap();
        assert_eq!(fitted.n_features_selected(), 2);
    }

    #[test]
    fn test_none_selected_high_threshold() {
        let fitted = fit_f64(ThresholdStrategy::Value(10.0), None, &array![0.1, 0.5, 0.4]);
        assert_eq!(fitted.n_features_selected(), 0);

        let out = fitted.transform(&array![[1.0, 2.0, 3.0]]).unwrap();
        assert_eq!(out.ncols(), 0);
        assert_eq!(out.nrows(), 1);
    }
}