ferrolearn_preprocess/select_from_model.rs
1//! Feature selection driven by a model's feature importance weights.
2//!
3//! [`SelectFromModel`](super::feature_selection::SelectFromModel) provides
4//! basic mean/explicit-threshold selection. This module provides a richer
5//! API via [`SelectFromModelExt`], which supports four threshold strategies
6//! (mean, median, explicit value, percentile) and an optional
7//! `max_features` cap.
8//!
9//! # Threshold Strategies
10//!
11//! | Variant | Description |
12//! |---------|-------------|
13//! | [`ThresholdStrategy::Mean`] | Threshold = arithmetic mean of importances |
14//! | [`ThresholdStrategy::Median`] | Threshold = median of importances |
15//! | [`ThresholdStrategy::Value`] | User-supplied explicit threshold |
16//! | [`ThresholdStrategy::Percentile`] | Keep features in the top *p*% by importance |
17//!
18//! When `max_features` is set, at most that many features are retained
19//! (in descending importance order) regardless of the threshold.
20//!
21//! ## REQ status
22//!
23//! Translation target: scikit-learn 1.5.2 `class SelectFromModel`
24//! (`sklearn/feature_selection/_from_model.py:256`). Tracking: #1352. Each REQ
25//! is BINARY — SHIPPED (impl + non-test consumer + tests + green verification)
26//! or NOT-STARTED (with a concrete open blocker). HONEST scope: this unit ships
27//! the threshold + selection-mask + `max_features` core GIVEN a static
28//! importance vector; sklearn wraps a fitted estimator and extracts its
29//! importances — that estimator machinery is NOT-STARTED.
30//!
31//! | REQ | Scope | Status | Evidence / Blocker |
32//! |-----|-------|--------|--------------------|
33//! | REQ-1 | Threshold (mean/median/value) + selection mask (`score >= threshold`) + `max_features` top-k cap, given a static importance vector | SHIPPED | [`SelectFromModelExt`] `fit` matches sklearn `_get_support_mask` `_from_model.py:299-312` + `_calculate_threshold` `:24-71` (mean=`np.mean`, median=`np.median`); threshold-then-cap is algebraically equivalent to sklearn cap-then-threshold (exhaustive-grid oracle-verified); 15 oracle value tests in `tests/divergence_select_from_model.rs`. Consumer: boundary re-export `lib.rs` (grandfathered S5/R-DEFER-1) + `PipelineTransformer` |
34//! | REQ-2 | Error/parameter contracts (empty importances, `Percentile` range, transform ncols mismatch) | SHIPPED (scoped) | [`SelectFromModelExt::fit`]/[`FittedSelectFromModelExt`] `transform`; in-module + divergence error tests |
35//! | REQ-3 | Estimator wrapping + `coef_`/`feature_importances_` extraction (`_get_feature_importances`) | NOT-STARTED | takes importances directly; sklearn `_from_model.py:299-304` — blocker #1353 |
36//! | REQ-4 | `norm_order` multi-output coef norm | NOT-STARTED | scalar importances only; sklearn `_from_model.py:303` — blocker #1354 |
37//! | REQ-5 | Scaled-string `scale*mean`/`scale*median` thresholds + default-from-estimator (l1→1e-5) | NOT-STARTED | sklearn `_from_model.py:30-55` — blocker #1355 |
38//! | REQ-6 | `prefit` + `importance_getter` params | NOT-STARTED | sklearn `_from_model.py:256-271,277-284` — blocker #1356 |
39//! | REQ-7 | `max_features` callable + `_check_max_features` range validation `[0, n_features]` | NOT-STARTED | int cap only; sklearn `_from_model.py:315-331` — blocker #1357 |
40//! | REQ-8 | `SelectorMixin` surface (`get_support`/`inverse_transform`/`get_feature_names_out`) | NOT-STARTED | sklearn `_base.py` `SelectorMixin` — blocker #1358 |
41//! | REQ-9 | PyO3 binding | NOT-STARTED | no `ferrolearn-python` registration — blocker #1359 |
42//! | REQ-10 | ferray substrate | NOT-STARTED | dense `Array2` + `num_traits::Float` only — blocker #1360 |
43//!
44//! NOTE: [`ThresholdStrategy::Percentile`] is a ferrolearn EXTENSION with NO
45//! sklearn `SelectFromModel` analog (sklearn supports only mean/median/`scale*ref`/
46//! float); it is not a sklearn-parity REQ and carries no blocker.
47
48use ferrolearn_core::error::FerroError;
49use ferrolearn_core::pipeline::{FittedPipelineTransformer, PipelineTransformer};
50use ferrolearn_core::traits::{Fit, Transform};
51use ndarray::{Array1, Array2};
52use num_traits::Float;
53
54// ---------------------------------------------------------------------------
55// ThresholdStrategy
56// ---------------------------------------------------------------------------
57
58/// Strategy for computing the importance threshold in [`SelectFromModelExt`].
59#[derive(Debug, Clone, Copy, PartialEq, Default)]
60pub enum ThresholdStrategy {
61 /// Threshold equals the arithmetic mean of all feature importances.
62 #[default]
63 Mean,
64 /// Threshold equals the median of all feature importances.
65 Median,
66 /// User-supplied explicit threshold value.
67 Value(f64),
68 /// Keep features in the top `p`% of importance scores (0 < p <= 100).
69 ///
70 /// For example, `Percentile(25.0)` retains features whose importance is
71 /// at or above the 75th-percentile value (i.e., the top 25%).
72 Percentile(f64),
73}
74
75// ---------------------------------------------------------------------------
76// SelectFromModelExt (unfitted)
77// ---------------------------------------------------------------------------
78
79/// An extended model-importance-based feature selector.
80///
81/// Like [`SelectFromModel`](super::feature_selection::SelectFromModel) but
82/// supports four threshold strategies and an optional `max_features` cap.
83///
84/// # Examples
85///
86/// ```
87/// use ferrolearn_preprocess::select_from_model::{SelectFromModelExt, ThresholdStrategy};
88/// use ferrolearn_core::traits::{Fit, Transform};
89/// use ndarray::array;
90///
91/// let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Mean, None);
92/// let importances = array![0.1, 0.5, 0.4];
93/// let fitted = sel.fit(&importances, &()).unwrap();
94/// let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
95/// let out = fitted.transform(&x).unwrap();
96/// // Mean importance = (0.1+0.5+0.4)/3 ≈ 0.333; columns 1 and 2 kept
97/// assert_eq!(out.ncols(), 2);
98/// ```
99#[must_use]
100#[derive(Debug, Clone)]
101pub struct SelectFromModelExt<F> {
102 /// The threshold strategy.
103 threshold: ThresholdStrategy,
104 /// Optional cap on number of features to select.
105 max_features: Option<usize>,
106 _marker: std::marker::PhantomData<F>,
107}
108
109impl<F: Float + Send + Sync + 'static> SelectFromModelExt<F> {
110 /// Create a new `SelectFromModelExt`.
111 ///
112 /// # Parameters
113 ///
114 /// - `threshold` — the strategy for computing the importance threshold.
115 /// - `max_features` — optional maximum number of features to retain.
116 pub fn new(threshold: ThresholdStrategy, max_features: Option<usize>) -> Self {
117 Self {
118 threshold,
119 max_features,
120 _marker: std::marker::PhantomData,
121 }
122 }
123
124 /// Return the threshold strategy.
125 #[must_use]
126 pub fn threshold_strategy(&self) -> ThresholdStrategy {
127 self.threshold
128 }
129
130 /// Return the maximum number of features (if set).
131 #[must_use]
132 pub fn max_features(&self) -> Option<usize> {
133 self.max_features
134 }
135}
136
137impl<F: Float + Send + Sync + 'static> Default for SelectFromModelExt<F> {
138 fn default() -> Self {
139 Self::new(ThresholdStrategy::Mean, None)
140 }
141}
142
143// ---------------------------------------------------------------------------
144// FittedSelectFromModelExt
145// ---------------------------------------------------------------------------
146
147/// A fitted model-importance selector produced by [`SelectFromModelExt::fit`].
148#[derive(Debug, Clone)]
149pub struct FittedSelectFromModelExt<F> {
150 /// Number of features seen during fitting.
151 n_features_in: usize,
152 /// The computed threshold value.
153 threshold_value: F,
154 /// Feature importances supplied during fitting.
155 importances: Array1<F>,
156 /// Indices of selected columns (sorted).
157 selected_indices: Vec<usize>,
158}
159
160impl<F: Float + Send + Sync + 'static> FittedSelectFromModelExt<F> {
161 /// Return the computed threshold value.
162 #[must_use]
163 pub fn threshold_value(&self) -> F {
164 self.threshold_value
165 }
166
167 /// Return the feature importances.
168 #[must_use]
169 pub fn importances(&self) -> &Array1<F> {
170 &self.importances
171 }
172
173 /// Return the indices of the selected columns.
174 #[must_use]
175 pub fn selected_indices(&self) -> &[usize] {
176 &self.selected_indices
177 }
178
179 /// Return the number of selected features.
180 #[must_use]
181 pub fn n_features_selected(&self) -> usize {
182 self.selected_indices.len()
183 }
184}
185
186// ---------------------------------------------------------------------------
187// Helpers
188// ---------------------------------------------------------------------------
189
190/// Compute the median of a slice of floats.
191fn compute_median<F: Float>(values: &[F]) -> F {
192 let mut sorted: Vec<F> = values.to_vec();
193 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
194 let n = sorted.len();
195 if n.is_multiple_of(2) {
196 let two = F::one() + F::one();
197 (sorted[n / 2 - 1] + sorted[n / 2]) / two
198 } else {
199 sorted[n / 2]
200 }
201}
202
203/// Compute the percentile threshold. `pct` is the percentage of features to
204/// keep (e.g., 25.0 means top 25%).
205fn compute_percentile_threshold<F: Float>(values: &[F], pct: f64) -> F {
206 let mut sorted: Vec<F> = values.to_vec();
207 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
208 let n = sorted.len();
209 // The threshold is set at the (100 - pct) percentile of the sorted values.
210 // E.g., for top 25% we want the value at the 75th percentile.
211 let rank = ((100.0 - pct) / 100.0) * (n.saturating_sub(1)) as f64;
212 let lower = rank.floor() as usize;
213 let upper = rank.ceil() as usize;
214 let lower = lower.min(n.saturating_sub(1));
215 let upper = upper.min(n.saturating_sub(1));
216 if lower == upper {
217 sorted[lower]
218 } else {
219 let frac = F::from(rank - rank.floor()).unwrap_or_else(F::zero);
220 sorted[lower] * (F::one() - frac) + sorted[upper] * frac
221 }
222}
223
224/// Build a new `Array2<F>` containing only the columns listed in `indices`.
225fn select_columns<F: Float>(x: &Array2<F>, indices: &[usize]) -> Array2<F> {
226 let nrows = x.nrows();
227 let ncols = indices.len();
228 if ncols == 0 {
229 return Array2::zeros((nrows, 0));
230 }
231 let mut out = Array2::zeros((nrows, ncols));
232 for (new_j, &old_j) in indices.iter().enumerate() {
233 for i in 0..nrows {
234 out[[i, new_j]] = x[[i, old_j]];
235 }
236 }
237 out
238}
239
240// ---------------------------------------------------------------------------
241// Trait implementations
242// ---------------------------------------------------------------------------
243
244impl<F: Float + Send + Sync + 'static> Fit<Array1<F>, ()> for SelectFromModelExt<F> {
245 type Fitted = FittedSelectFromModelExt<F>;
246 type Error = FerroError;
247
248 /// Fit by computing the threshold from the given feature importances.
249 ///
250 /// # Parameters
251 ///
252 /// - `x` — per-feature importance scores (one value per feature).
253 /// - `_y` — ignored (unsupervised).
254 ///
255 /// # Errors
256 ///
257 /// - [`FerroError::InvalidParameter`] if the importance vector is empty,
258 /// or if `Percentile` value is not in `(0, 100]`.
259 fn fit(&self, x: &Array1<F>, _y: &()) -> Result<FittedSelectFromModelExt<F>, FerroError> {
260 let n = x.len();
261 if n == 0 {
262 return Err(FerroError::InvalidParameter {
263 name: "importances".into(),
264 reason: "importance vector must not be empty".into(),
265 });
266 }
267
268 let values: Vec<F> = x.iter().copied().collect();
269
270 // Compute threshold
271 let threshold_value = match self.threshold {
272 ThresholdStrategy::Mean => {
273 values.iter().copied().fold(F::zero(), |acc, v| acc + v)
274 / F::from(n).unwrap_or_else(F::one)
275 }
276 ThresholdStrategy::Median => compute_median(&values),
277 ThresholdStrategy::Value(v) => F::from(v).unwrap_or_else(F::zero),
278 ThresholdStrategy::Percentile(pct) => {
279 if pct <= 0.0 || pct > 100.0 {
280 return Err(FerroError::InvalidParameter {
281 name: "percentile".into(),
282 reason: format!("percentile must be in (0, 100], got {}", pct),
283 });
284 }
285 compute_percentile_threshold(&values, pct)
286 }
287 };
288
289 // Select features whose importance >= threshold
290 let mut selected_indices: Vec<usize> = values
291 .iter()
292 .enumerate()
293 .filter(|&(_, &imp)| imp >= threshold_value)
294 .map(|(j, _)| j)
295 .collect();
296
297 // Apply max_features cap: keep only the top-k by importance
298 if let Some(max_f) = self.max_features
299 && selected_indices.len() > max_f
300 {
301 // Sort selected by importance descending, keep top max_f
302 selected_indices.sort_by(|&a, &b| {
303 values[b]
304 .partial_cmp(&values[a])
305 .unwrap_or(std::cmp::Ordering::Equal)
306 });
307 selected_indices.truncate(max_f);
308 // Re-sort in column order
309 selected_indices.sort_unstable();
310 }
311
312 Ok(FittedSelectFromModelExt {
313 n_features_in: n,
314 threshold_value,
315 importances: x.clone(),
316 selected_indices,
317 })
318 }
319}
320
321impl<F: Float + Send + Sync + 'static> Transform<Array2<F>> for FittedSelectFromModelExt<F> {
322 type Output = Array2<F>;
323 type Error = FerroError;
324
325 /// Return a matrix containing only the selected columns.
326 ///
327 /// # Errors
328 ///
329 /// Returns [`FerroError::ShapeMismatch`] if the number of columns differs
330 /// from the number of features seen during fitting.
331 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
332 if x.ncols() != self.n_features_in {
333 return Err(FerroError::ShapeMismatch {
334 expected: vec![x.nrows(), self.n_features_in],
335 actual: vec![x.nrows(), x.ncols()],
336 context: "FittedSelectFromModelExt::transform".into(),
337 });
338 }
339 Ok(select_columns(x, &self.selected_indices))
340 }
341}
342
343// ---------------------------------------------------------------------------
344// Pipeline integration
345// ---------------------------------------------------------------------------
346
347impl<F: Float + Send + Sync + 'static> PipelineTransformer<F> for FittedSelectFromModelExt<F> {
348 /// Clone the fitted selector and box it as a pipeline transformer.
349 ///
350 /// Because the selector is already fitted (importances supplied at fit
351 /// time), `fit_pipeline` simply boxes the existing fitted state.
352 ///
353 /// # Errors
354 ///
355 /// This implementation never fails.
356 fn fit_pipeline(
357 &self,
358 _x: &Array2<F>,
359 _y: &Array1<F>,
360 ) -> Result<Box<dyn FittedPipelineTransformer<F>>, FerroError> {
361 Ok(Box::new(self.clone()))
362 }
363}
364
365impl<F: Float + Send + Sync + 'static> FittedPipelineTransformer<F>
366 for FittedSelectFromModelExt<F>
367{
368 /// Transform using the pipeline interface.
369 ///
370 /// # Errors
371 ///
372 /// Propagates errors from [`Transform::transform`].
373 fn transform_pipeline(&self, x: &Array2<F>) -> Result<Array2<F>, FerroError> {
374 self.transform(x)
375 }
376}
377
378// ---------------------------------------------------------------------------
379// Tests
380// ---------------------------------------------------------------------------
381
382#[cfg(test)]
383mod tests {
384 use super::*;
385 use approx::assert_abs_diff_eq;
386 use ndarray::array;
387
388 #[test]
389 fn test_mean_threshold() {
390 let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Mean, None);
391 let importances = array![0.1, 0.5, 0.4];
392 let fitted = sel.fit(&importances, &()).unwrap();
393 // Mean = (0.1+0.5+0.4)/3 ≈ 0.333; cols 1 and 2 kept
394 assert_eq!(fitted.selected_indices(), &[1, 2]);
395 let x = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
396 let out = fitted.transform(&x).unwrap();
397 assert_eq!(out.ncols(), 2);
398 assert_abs_diff_eq!(out[[0, 0]], 2.0, epsilon = 1e-15);
399 assert_abs_diff_eq!(out[[0, 1]], 3.0, epsilon = 1e-15);
400 }
401
402 #[test]
403 fn test_median_threshold() {
404 let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Median, None);
405 // Sorted: [0.1, 0.3, 0.5] → median = 0.3
406 let importances = array![0.1, 0.5, 0.3];
407 let fitted = sel.fit(&importances, &()).unwrap();
408 // Features with importance >= 0.3: indices 1 (0.5) and 2 (0.3)
409 assert_eq!(fitted.selected_indices(), &[1, 2]);
410 }
411
412 #[test]
413 fn test_median_threshold_even() {
414 let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Median, None);
415 // Sorted: [0.1, 0.2, 0.5, 0.6] → median = (0.2+0.5)/2 = 0.35
416 let importances = array![0.1, 0.5, 0.2, 0.6];
417 let fitted = sel.fit(&importances, &()).unwrap();
418 // Features >= 0.35: 1 (0.5) and 3 (0.6)
419 assert_eq!(fitted.selected_indices(), &[1, 3]);
420 }
421
422 #[test]
423 fn test_explicit_value_threshold() {
424 let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Value(0.45), None);
425 let importances = array![0.1, 0.5, 0.4];
426 let fitted = sel.fit(&importances, &()).unwrap();
427 assert_eq!(fitted.selected_indices(), &[1]);
428 }
429
430 #[test]
431 fn test_percentile_threshold_top_50() {
432 let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Percentile(50.0), None);
433 // Sorted: [0.1, 0.3, 0.5, 0.7]
434 // Top 50% → threshold at 50th percentile = sorted[1.5] interp = 0.4
435 let importances = array![0.5, 0.1, 0.7, 0.3];
436 let fitted = sel.fit(&importances, &()).unwrap();
437 // Features >= threshold: 0 (0.5), 2 (0.7)
438 assert!(fitted.selected_indices().contains(&0));
439 assert!(fitted.selected_indices().contains(&2));
440 assert_eq!(fitted.n_features_selected(), 2);
441 }
442
443 #[test]
444 fn test_percentile_100_keeps_all() {
445 let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Percentile(100.0), None);
446 let importances = array![0.1, 0.5, 0.3];
447 let fitted = sel.fit(&importances, &()).unwrap();
448 assert_eq!(fitted.n_features_selected(), 3);
449 }
450
451 #[test]
452 fn test_percentile_invalid() {
453 let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Percentile(0.0), None);
454 let importances = array![0.1, 0.5, 0.3];
455 assert!(sel.fit(&importances, &()).is_err());
456
457 let sel2 = SelectFromModelExt::<f64>::new(ThresholdStrategy::Percentile(101.0), None);
458 assert!(sel2.fit(&importances, &()).is_err());
459 }
460
461 #[test]
462 fn test_max_features_cap() {
463 let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Value(0.0), Some(2));
464 // All features pass threshold=0, but max_features=2
465 let importances = array![0.3, 0.5, 0.1, 0.7];
466 let fitted = sel.fit(&importances, &()).unwrap();
467 assert_eq!(fitted.n_features_selected(), 2);
468 // Should keep top-2: indices 1 (0.5) and 3 (0.7)
469 assert_eq!(fitted.selected_indices(), &[1, 3]);
470 }
471
472 #[test]
473 fn test_max_features_not_needed() {
474 let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Value(0.4), Some(5));
475 let importances = array![0.1, 0.5, 0.4];
476 let fitted = sel.fit(&importances, &()).unwrap();
477 // Only 2 pass threshold, max_features=5 doesn't limit
478 assert_eq!(fitted.n_features_selected(), 2);
479 }
480
481 #[test]
482 fn test_empty_importances_error() {
483 let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Mean, None);
484 let importances: Array1<f64> = Array1::zeros(0);
485 assert!(sel.fit(&importances, &()).is_err());
486 }
487
488 #[test]
489 fn test_shape_mismatch_on_transform() {
490 let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Mean, None);
491 let importances = array![0.5, 0.5];
492 let fitted = sel.fit(&importances, &()).unwrap();
493 let x_bad = array![[1.0, 2.0, 3.0]]; // 3 cols, 2 expected
494 assert!(fitted.transform(&x_bad).is_err());
495 }
496
497 #[test]
498 fn test_threshold_value_accessor() {
499 let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Value(0.42), None);
500 let importances = array![0.1, 0.5];
501 let fitted = sel.fit(&importances, &()).unwrap();
502 assert_abs_diff_eq!(fitted.threshold_value(), 0.42, epsilon = 1e-15);
503 }
504
505 #[test]
506 fn test_default() {
507 let sel = SelectFromModelExt::<f64>::default();
508 assert_eq!(sel.threshold_strategy(), ThresholdStrategy::Mean);
509 assert_eq!(sel.max_features(), None);
510 }
511
512 #[test]
513 fn test_pipeline_integration() {
514 let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Mean, None);
515 let importances = array![0.1, 0.9];
516 let fitted = sel.fit(&importances, &()).unwrap();
517 let x = array![[1.0, 2.0], [3.0, 4.0]];
518 let y = array![0.0, 1.0];
519 let fitted_box = fitted.fit_pipeline(&x, &y).unwrap();
520 let out = fitted_box.transform_pipeline(&x).unwrap();
521 assert_eq!(out.ncols(), 1);
522 }
523
524 #[test]
525 fn test_f32() {
526 let sel = SelectFromModelExt::<f32>::new(ThresholdStrategy::Mean, None);
527 let importances: Array1<f32> = array![0.1f32, 0.5, 0.4];
528 let fitted = sel.fit(&importances, &()).unwrap();
529 assert_eq!(fitted.n_features_selected(), 2);
530 }
531
532 #[test]
533 fn test_none_selected_high_threshold() {
534 let sel = SelectFromModelExt::<f64>::new(ThresholdStrategy::Value(10.0), None);
535 let importances = array![0.1, 0.5, 0.4];
536 let fitted = sel.fit(&importances, &()).unwrap();
537 assert_eq!(fitted.n_features_selected(), 0);
538 let x = array![[1.0, 2.0, 3.0]];
539 let out = fitted.transform(&x).unwrap();
540 assert_eq!(out.ncols(), 0);
541 assert_eq!(out.nrows(), 1);
542 }
543}