quantrs2_ml/sklearn_compatibility/
feature_selection.rs

1//! Sklearn-compatible feature selection algorithms
2
3use super::SklearnEstimator;
4use crate::error::{MLError, Result};
5use scirs2_core::ndarray::{Array1, Array2};
6use std::collections::HashMap;
7
/// Select K Best features (sklearn-compatible).
///
/// Holds the configuration (`score_func`, `k`) together with the state
/// recorded by fitting (`fitted`, `selected_features_`).
#[derive(Debug, Clone)]
pub struct SelectKBest {
    /// Scoring function identifier, stored exactly as supplied by the caller.
    score_func: String,
    /// Maximum number of features to keep.
    k: usize,
    /// `true` once `fit` has completed.
    fitted: bool,
    /// Indices of the selected feature columns; `None` until `fit` has run.
    selected_features_: Option<Vec<usize>>,
}
15
16impl SelectKBest {
17    pub fn new(score_func: &str, k: usize) -> Self {
18        Self {
19            score_func: score_func.to_string(),
20            k,
21            fitted: false,
22            selected_features_: None,
23        }
24    }
25
26    /// Get selected features
27    pub fn get_support(&self) -> Option<&Vec<usize>> {
28        self.selected_features_.as_ref()
29    }
30}
31
32impl SklearnEstimator for SelectKBest {
33    #[allow(non_snake_case)]
34    fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
35        // Mock implementation - select first k features
36        let features: Vec<usize> = (0..self.k.min(X.ncols())).collect();
37        self.selected_features_ = Some(features);
38        self.fitted = true;
39        Ok(())
40    }
41
42    fn get_params(&self) -> HashMap<String, String> {
43        let mut params = HashMap::new();
44        params.insert("score_func".to_string(), self.score_func.clone());
45        params.insert("k".to_string(), self.k.to_string());
46        params
47    }
48
49    fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
50        for (key, value) in params {
51            match key.as_str() {
52                "k" => {
53                    self.k = value.parse().map_err(|_| {
54                        MLError::InvalidConfiguration(format!("Invalid k parameter: {}", value))
55                    })?;
56                }
57                "score_func" => {
58                    self.score_func = value;
59                }
60                _ => {}
61            }
62        }
63        Ok(())
64    }
65
66    fn is_fitted(&self) -> bool {
67        self.fitted
68    }
69}
70
71/// Variance threshold feature selector
72pub struct VarianceThreshold {
73    /// Threshold
74    threshold: f64,
75    /// Variances
76    variances: Option<Array1<f64>>,
77    /// Mask of selected features
78    mask: Option<Vec<bool>>,
79}
80
81impl VarianceThreshold {
82    /// Create new VarianceThreshold
83    pub fn new(threshold: f64) -> Self {
84        Self {
85            threshold,
86            variances: None,
87            mask: None,
88        }
89    }
90
91    /// Fit the selector
92    #[allow(non_snake_case)]
93    pub fn fit(&mut self, X: &Array2<f64>) -> Result<()> {
94        let n_features = X.ncols();
95        let n_samples = X.nrows() as f64;
96        let mut variances = Array1::zeros(n_features);
97        let mut mask = vec![false; n_features];
98
99        for j in 0..n_features {
100            // Compute mean
101            let mean = X.column(j).sum() / n_samples;
102            // Compute variance
103            let var = X.column(j).mapv(|x| (x - mean).powi(2)).sum() / n_samples;
104            variances[j] = var;
105            mask[j] = var > self.threshold;
106        }
107
108        self.variances = Some(variances);
109        self.mask = Some(mask);
110        Ok(())
111    }
112
113    /// Transform the data
114    #[allow(non_snake_case)]
115    pub fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
116        let mask = self
117            .mask
118            .as_ref()
119            .ok_or_else(|| MLError::ModelNotTrained("VarianceThreshold not fitted".to_string()))?;
120
121        let selected_cols: Vec<usize> = mask
122            .iter()
123            .enumerate()
124            .filter_map(|(i, &m)| if m { Some(i) } else { None })
125            .collect();
126
127        if selected_cols.is_empty() {
128            return Err(MLError::InvalidConfiguration(
129                "No features selected".to_string(),
130            ));
131        }
132
133        let mut result = Array2::zeros((X.nrows(), selected_cols.len()));
134        for (new_j, &old_j) in selected_cols.iter().enumerate() {
135            for i in 0..X.nrows() {
136                result[[i, new_j]] = X[[i, old_j]];
137            }
138        }
139
140        Ok(result)
141    }
142
143    /// Fit and transform
144    #[allow(non_snake_case)]
145    pub fn fit_transform(&mut self, X: &Array2<f64>) -> Result<Array2<f64>> {
146        self.fit(X)?;
147        self.transform(X)
148    }
149
150    /// Get variances
151    pub fn variances(&self) -> Option<&Array1<f64>> {
152        self.variances.as_ref()
153    }
154
155    /// Get feature mask
156    pub fn get_support(&self) -> Option<&Vec<bool>> {
157        self.mask.as_ref()
158    }
159}