Skip to main content

anofox_ml_preprocessing/
select_from_model.rs

1use anofox_ml_core::{Float, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2};
3
4/// Parameters for `SelectFromModel` feature selector (unfitted state).
5///
6/// Selects features based on a pre-computed importance vector (e.g., from a
7/// fitted `RandomForestClassifier`'s
8/// `feature_importances()`).
9///
10/// Features can be selected in two ways:
11/// - **Threshold**: keep all features with importance >= threshold.
12/// - **Max features**: keep the top `max_features` features by importance.
13///
14/// If both are specified, threshold is applied first and then the result is
15/// capped to `max_features`.
16///
17/// Since this selector does not learn from raw data (it uses pre-computed
18/// importances), it exposes a custom `fit` method instead of implementing the
19/// standard [`FitUnsupervised`](anofox_ml_core::FitUnsupervised) trait.
20///
21/// # Example
22///
23/// ```
24/// use anofox_ml_preprocessing::SelectFromModel;
25/// use anofox_ml_core::Transform;
26/// use ndarray::array;
27///
28/// let importances = array![0.05, 0.40, 0.10, 0.45];
29///
30/// // Select features with importance >= 0.20
31/// let selector = SelectFromModel::new().with_threshold(0.20);
32/// let fitted = selector.fit(&importances).unwrap();
33///
34/// assert_eq!(fitted.selected_indices(), &[1, 3]);
35///
36/// let x = array![
37///     [1.0, 2.0, 3.0, 4.0],
38///     [5.0, 6.0, 7.0, 8.0],
39/// ];
40/// let x_selected = fitted.transform(&x).unwrap();
41/// assert_eq!(x_selected.ncols(), 2);
42/// ```
43#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
44pub struct SelectFromModel {
45    /// Features with importance >= threshold are selected.
46    /// If `None`, no threshold filtering is applied.
47    pub threshold: Option<f64>,
48    /// Maximum number of features to select (top by importance).
49    /// If `None`, no cap is applied.
50    pub max_features: Option<usize>,
51}
52
53impl SelectFromModel {
54    /// Create a new `SelectFromModel` with no threshold and no feature cap.
55    ///
56    /// At least one of `threshold` or `max_features` must be set before calling
57    /// [`fit`](Self::fit).
58    pub fn new() -> Self {
59        Self {
60            threshold: None,
61            max_features: None,
62        }
63    }
64
65    /// Set the minimum importance threshold.
66    pub fn with_threshold(mut self, threshold: f64) -> Self {
67        self.threshold = Some(threshold);
68        self
69    }
70
71    /// Set the maximum number of features to select.
72    pub fn with_max_features(mut self, max_features: usize) -> Self {
73        self.max_features = Some(max_features);
74        self
75    }
76
77    /// Fit the selector on a pre-computed feature importance vector.
78    ///
79    /// Returns an error if the importance vector is empty or if neither
80    /// `threshold` nor `max_features` is set.
81    pub fn fit(&self, importances: &Array1<f64>) -> Result<FittedSelectFromModel> {
82        let n_features = importances.len();
83
84        if n_features == 0 {
85            return Err(RustMlError::EmptyInput(
86                "importances vector is empty".into(),
87            ));
88        }
89
90        if self.threshold.is_none() && self.max_features.is_none() {
91            return Err(RustMlError::InvalidParameter(
92                "at least one of threshold or max_features must be set".into(),
93            ));
94        }
95
96        if let Some(max_f) = self.max_features {
97            if max_f == 0 {
98                return Err(RustMlError::InvalidParameter(
99                    "max_features must be at least 1".into(),
100                ));
101            }
102        }
103
104        // Step 1: Apply threshold filter.
105        let mut candidates: Vec<(usize, f64)> = if let Some(thresh) = self.threshold {
106            importances
107                .iter()
108                .copied()
109                .enumerate()
110                .filter(|&(_, imp)| imp >= thresh)
111                .collect()
112        } else {
113            importances.iter().copied().enumerate().collect()
114        };
115
116        // Step 2: If max_features is set, keep only the top-N by importance.
117        if let Some(max_f) = self.max_features {
118            if candidates.len() > max_f {
119                // Sort descending by importance; break ties by index (ascending).
120                candidates.sort_by(|a, b| {
121                    b.1.partial_cmp(&a.1)
122                        .unwrap_or(std::cmp::Ordering::Equal)
123                        .then(a.0.cmp(&b.0))
124                });
125                candidates.truncate(max_f);
126            }
127        }
128
129        if candidates.is_empty() {
130            return Err(RustMlError::InvalidParameter(
131                "no features meet the selection criteria".into(),
132            ));
133        }
134
135        // Sort by index for stable column ordering.
136        let mut selected_indices: Vec<usize> = candidates.iter().map(|&(idx, _)| idx).collect();
137        selected_indices.sort_unstable();
138
139        Ok(FittedSelectFromModel {
140            importances: importances.clone(),
141            selected_indices,
142            n_features_in: n_features,
143        })
144    }
145}
146
147impl Default for SelectFromModel {
148    fn default() -> Self {
149        Self::new()
150    }
151}
152
153/// Fitted `SelectFromModel` -- holds the original importances and the indices
154/// of the selected features.
155#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
156pub struct FittedSelectFromModel {
157    /// Original per-feature importance vector.
158    importances: Array1<f64>,
159    /// Indices of the selected features, sorted ascending.
160    selected_indices: Vec<usize>,
161    /// Total number of input features (before selection).
162    n_features_in: usize,
163}
164
165impl FittedSelectFromModel {
166    /// Per-feature importances supplied during fitting.
167    pub fn importances(&self) -> &Array1<f64> {
168        &self.importances
169    }
170
171    /// Indices of the selected features, sorted in ascending order.
172    pub fn selected_indices(&self) -> &[usize] {
173        &self.selected_indices
174    }
175
176    /// Number of features that survived selection.
177    pub fn n_features_selected(&self) -> usize {
178        self.selected_indices.len()
179    }
180}
181
182impl<F: Float> Transform<F> for FittedSelectFromModel {
183    fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
184        if x.ncols() != self.n_features_in {
185            return Err(RustMlError::ShapeMismatch(format!(
186                "expected {} features, got {}",
187                self.n_features_in,
188                x.ncols()
189            )));
190        }
191
192        let n_rows = x.nrows();
193        let n_selected = self.selected_indices.len();
194        let mut result = Array2::<F>::zeros((n_rows, n_selected));
195
196        for (i, row) in x.rows().into_iter().enumerate() {
197            for (out_j, &src_j) in self.selected_indices.iter().enumerate() {
198                result[[i, out_j]] = row[src_j];
199            }
200        }
201
202        Ok(result)
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// A threshold alone keeps exactly the features at or above it.
    #[test]
    fn test_threshold_selects_important_features() {
        let fitted = SelectFromModel::new()
            .with_threshold(0.20)
            .fit(&array![0.05, 0.40, 0.10, 0.45])
            .unwrap();

        assert_eq!(fitted.selected_indices(), &[1, 3]);
    }

    /// A cap alone keeps the N most important features.
    #[test]
    fn test_max_features_selects_top_n() {
        let fitted = SelectFromModel::new()
            .with_max_features(2)
            .fit(&array![0.1, 0.5, 0.3, 0.8, 0.2])
            .unwrap();

        // The two largest importances are 0.8 (index 3) and 0.5 (index 1);
        // reported in ascending index order that is [1, 3].
        assert_eq!(fitted.selected_indices(), &[1, 3]);
    }

    /// Threshold filtering happens before the cap is applied.
    #[test]
    fn test_threshold_and_max_features_combined() {
        // The 0.20 threshold leaves indices [1, 2, 3, 4]; the cap of 2 then
        // keeps 0.45 (index 3) and 0.40 (index 1).
        let fitted = SelectFromModel::new()
            .with_threshold(0.20)
            .with_max_features(2)
            .fit(&array![0.05, 0.40, 0.30, 0.45, 0.35])
            .unwrap();

        assert_eq!(fitted.selected_indices(), &[1, 3]);
    }

    /// `transform` copies exactly the selected columns, in index order.
    #[test]
    fn test_transform_selects_correct_columns() {
        let fitted = SelectFromModel::new()
            .with_max_features(2)
            .fit(&array![0.1, 0.9, 0.5])
            .unwrap();

        // 0.9 (index 1) and 0.5 (index 2) are the top two.
        assert_eq!(fitted.selected_indices(), &[1, 2]);

        let x = array![[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]];
        let projected = fitted.transform(&x).unwrap();

        assert_eq!(projected.dim(), (2, 2));
        assert_eq!(projected[[0, 0]], 20.0);
        assert_eq!(projected[[0, 1]], 30.0);
        assert_eq!(projected[[1, 0]], 50.0);
        assert_eq!(projected[[1, 1]], 60.0);
    }

    /// Fitting without any criterion is an `InvalidParameter` error.
    #[test]
    fn test_error_no_criteria_set() {
        // Neither threshold nor max_features is configured.
        let outcome = SelectFromModel::new().fit(&array![0.1, 0.2, 0.3]);
        assert!(outcome.is_err());

        let err = outcome.unwrap_err();
        match &err {
            RustMlError::InvalidParameter(msg) => {
                let mentions_criteria = msg.contains("threshold") || msg.contains("max_features");
                assert!(mentions_criteria, "unexpected message: {}", msg);
            }
            other => panic!("expected InvalidParameter, got {:?}", other),
        }
    }

    /// A threshold above every importance leaves nothing to select.
    #[test]
    fn test_error_no_features_survive_threshold() {
        let outcome = SelectFromModel::new()
            .with_threshold(0.50)
            .fit(&array![0.01, 0.02, 0.03]);
        assert!(outcome.is_err());

        let err = outcome.unwrap_err();
        match &err {
            RustMlError::InvalidParameter(msg) => {
                assert!(msg.contains("no features"), "unexpected message: {}", msg);
            }
            other => panic!("expected InvalidParameter, got {:?}", other),
        }
    }

    /// An empty importance vector is rejected.
    #[test]
    fn test_error_empty_importances() {
        let empty = Array1::<f64>::from_vec(Vec::new());

        let outcome = SelectFromModel::new().with_threshold(0.0).fit(&empty);
        assert!(outcome.is_err());
    }

    /// `transform` rejects input whose column count differs from fit time.
    #[test]
    fn test_shape_mismatch_on_transform() {
        let fitted = SelectFromModel::new()
            .with_threshold(0.0)
            .fit(&array![0.5, 0.5, 0.5])
            .unwrap();

        // Two columns where three are expected.
        let wrong = array![[1.0, 2.0]];
        let outcome = Transform::<f64>::transform(&fitted, &wrong);
        assert!(outcome.is_err());
    }

    /// The fitted selector can transform `f32` matrices as well.
    #[test]
    fn test_works_with_f32_transform() {
        let fitted = SelectFromModel::new()
            .with_max_features(1)
            .fit(&array![0.1, 0.9])
            .unwrap();

        assert_eq!(fitted.selected_indices(), &[1]);

        let x: Array2<f32> = array![[1.0_f32, 2.0], [3.0, 4.0]];
        let projected = Transform::<f32>::transform(&fitted, &x).unwrap();

        assert_eq!(projected.dim(), (2, 1));
        assert_eq!(projected[[0, 0]], 2.0_f32);
    }

    /// A cap of zero features is rejected at fit time.
    #[test]
    fn test_max_features_zero_is_error() {
        let outcome = SelectFromModel::new()
            .with_max_features(0)
            .fit(&array![0.1, 0.2]);

        assert!(outcome.is_err());
    }

    /// `n_features_selected` agrees with the selected index list.
    #[test]
    fn test_n_features_selected() {
        let fitted = SelectFromModel::new()
            .with_threshold(0.25)
            .fit(&array![0.1, 0.5, 0.3, 0.8])
            .unwrap();

        // Indices 1, 2 and 3 are at or above 0.25.
        assert_eq!(fitted.n_features_selected(), 3);
        assert_eq!(fitted.selected_indices(), &[1, 2, 3]);
    }
}