1use anofox_ml_core::{Float, Result, RustMlError, Transform};
2use ndarray::{Array1, Array2};
3
4#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
44pub struct SelectFromModel {
45 pub threshold: Option<f64>,
48 pub max_features: Option<usize>,
51}
52
53impl SelectFromModel {
54 pub fn new() -> Self {
59 Self {
60 threshold: None,
61 max_features: None,
62 }
63 }
64
65 pub fn with_threshold(mut self, threshold: f64) -> Self {
67 self.threshold = Some(threshold);
68 self
69 }
70
71 pub fn with_max_features(mut self, max_features: usize) -> Self {
73 self.max_features = Some(max_features);
74 self
75 }
76
77 pub fn fit(&self, importances: &Array1<f64>) -> Result<FittedSelectFromModel> {
82 let n_features = importances.len();
83
84 if n_features == 0 {
85 return Err(RustMlError::EmptyInput(
86 "importances vector is empty".into(),
87 ));
88 }
89
90 if self.threshold.is_none() && self.max_features.is_none() {
91 return Err(RustMlError::InvalidParameter(
92 "at least one of threshold or max_features must be set".into(),
93 ));
94 }
95
96 if let Some(max_f) = self.max_features {
97 if max_f == 0 {
98 return Err(RustMlError::InvalidParameter(
99 "max_features must be at least 1".into(),
100 ));
101 }
102 }
103
104 let mut candidates: Vec<(usize, f64)> = if let Some(thresh) = self.threshold {
106 importances
107 .iter()
108 .copied()
109 .enumerate()
110 .filter(|&(_, imp)| imp >= thresh)
111 .collect()
112 } else {
113 importances.iter().copied().enumerate().collect()
114 };
115
116 if let Some(max_f) = self.max_features {
118 if candidates.len() > max_f {
119 candidates.sort_by(|a, b| {
121 b.1.partial_cmp(&a.1)
122 .unwrap_or(std::cmp::Ordering::Equal)
123 .then(a.0.cmp(&b.0))
124 });
125 candidates.truncate(max_f);
126 }
127 }
128
129 if candidates.is_empty() {
130 return Err(RustMlError::InvalidParameter(
131 "no features meet the selection criteria".into(),
132 ));
133 }
134
135 let mut selected_indices: Vec<usize> = candidates.iter().map(|&(idx, _)| idx).collect();
137 selected_indices.sort_unstable();
138
139 Ok(FittedSelectFromModel {
140 importances: importances.clone(),
141 selected_indices,
142 n_features_in: n_features,
143 })
144 }
145}
146
147impl Default for SelectFromModel {
148 fn default() -> Self {
149 Self::new()
150 }
151}
152
153#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
156pub struct FittedSelectFromModel {
157 importances: Array1<f64>,
159 selected_indices: Vec<usize>,
161 n_features_in: usize,
163}
164
165impl FittedSelectFromModel {
166 pub fn importances(&self) -> &Array1<f64> {
168 &self.importances
169 }
170
171 pub fn selected_indices(&self) -> &[usize] {
173 &self.selected_indices
174 }
175
176 pub fn n_features_selected(&self) -> usize {
178 self.selected_indices.len()
179 }
180}
181
182impl<F: Float> Transform<F> for FittedSelectFromModel {
183 fn transform(&self, x: &Array2<F>) -> Result<Array2<F>> {
184 if x.ncols() != self.n_features_in {
185 return Err(RustMlError::ShapeMismatch(format!(
186 "expected {} features, got {}",
187 self.n_features_in,
188 x.ncols()
189 )));
190 }
191
192 let n_rows = x.nrows();
193 let n_selected = self.selected_indices.len();
194 let mut result = Array2::<F>::zeros((n_rows, n_selected));
195
196 for (i, row) in x.rows().into_iter().enumerate() {
197 for (out_j, &src_j) in self.selected_indices.iter().enumerate() {
198 result[[i, out_j]] = row[src_j];
199 }
200 }
201
202 Ok(result)
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Fits `selector` on `importances` and returns the kept column indices.
    fn fit_indices(selector: SelectFromModel, importances: &Array1<f64>) -> Vec<usize> {
        selector
            .fit(importances)
            .unwrap()
            .selected_indices()
            .to_vec()
    }

    #[test]
    fn test_threshold_selects_important_features() {
        let kept = fit_indices(
            SelectFromModel::new().with_threshold(0.20),
            &array![0.05, 0.40, 0.10, 0.45],
        );

        assert_eq!(kept, vec![1, 3]);
    }

    #[test]
    fn test_max_features_selects_top_n() {
        let kept = fit_indices(
            SelectFromModel::new().with_max_features(2),
            &array![0.1, 0.5, 0.3, 0.8, 0.2],
        );

        // The two largest scores (0.8 and 0.5) sit at indices 3 and 1.
        assert_eq!(kept, vec![1, 3]);
    }

    #[test]
    fn test_threshold_and_max_features_combined() {
        let kept = fit_indices(
            SelectFromModel::new()
                .with_threshold(0.20)
                .with_max_features(2),
            &array![0.05, 0.40, 0.30, 0.45, 0.35],
        );

        // Four features pass the threshold; the cap keeps the top two.
        assert_eq!(kept, vec![1, 3]);
    }

    #[test]
    fn test_transform_selects_correct_columns() {
        let scores = array![0.1, 0.9, 0.5];
        let fitted = SelectFromModel::new()
            .with_max_features(2)
            .fit(&scores)
            .unwrap();

        assert_eq!(fitted.selected_indices(), &[1, 2]);

        let x = array![[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]];
        let reduced = fitted.transform(&x).unwrap();

        assert_eq!(reduced.dim(), (2, 2));
        // Row-major walk over the 2x2 output.
        let flat: Vec<f64> = reduced.iter().copied().collect();
        assert_eq!(flat, vec![20.0, 30.0, 50.0, 60.0]);
    }

    #[test]
    fn test_error_no_criteria_set() {
        let scores = array![0.1, 0.2, 0.3];

        let outcome = SelectFromModel::new().fit(&scores);
        assert!(outcome.is_err());

        match outcome.unwrap_err() {
            RustMlError::InvalidParameter(msg) => {
                let mentions_criterion = msg.contains("threshold") || msg.contains("max_features");
                assert!(mentions_criterion, "unexpected message: {}", msg);
            }
            other => panic!("expected InvalidParameter, got {:?}", other),
        }
    }

    #[test]
    fn test_error_no_features_survive_threshold() {
        let scores = array![0.01, 0.02, 0.03];

        let outcome = SelectFromModel::new().with_threshold(0.50).fit(&scores);
        assert!(outcome.is_err());

        match outcome.unwrap_err() {
            RustMlError::InvalidParameter(msg) => {
                assert!(msg.contains("no features"), "unexpected message: {}", msg);
            }
            other => panic!("expected InvalidParameter, got {:?}", other),
        }
    }

    #[test]
    fn test_error_empty_importances() {
        let scores = Array1::<f64>::zeros(0);

        let outcome = SelectFromModel::new().with_threshold(0.0).fit(&scores);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_shape_mismatch_on_transform() {
        let scores = array![0.5, 0.5, 0.5];
        let fitted = SelectFromModel::new()
            .with_threshold(0.0)
            .fit(&scores)
            .unwrap();

        // Fitted on three features, fed two.
        let wrong = array![[1.0, 2.0]];
        let outcome = Transform::<f64>::transform(&fitted, &wrong);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_works_with_f32_transform() {
        let scores = array![0.1, 0.9];
        let fitted = SelectFromModel::new()
            .with_max_features(1)
            .fit(&scores)
            .unwrap();

        assert_eq!(fitted.selected_indices(), &[1]);

        let x: Array2<f32> = array![[1.0_f32, 2.0], [3.0, 4.0]];
        let reduced = Transform::<f32>::transform(&fitted, &x).unwrap();

        assert_eq!(reduced.dim(), (2, 1));
        assert_eq!(reduced[[0, 0]], 2.0_f32);
    }

    #[test]
    fn test_max_features_zero_is_error() {
        let scores = array![0.1, 0.2];

        let outcome = SelectFromModel::new().with_max_features(0).fit(&scores);

        assert!(outcome.is_err());
    }

    #[test]
    fn test_n_features_selected() {
        let scores = array![0.1, 0.5, 0.3, 0.8];
        let fitted = SelectFromModel::new()
            .with_threshold(0.25)
            .fit(&scores)
            .unwrap();

        assert_eq!(fitted.n_features_selected(), 3);
        assert_eq!(fitted.selected_indices(), &[1, 2, 3]);
    }
}