quantrs2_ml/sklearn_compatibility/
feature_selection.rs1use super::SklearnEstimator;
4use crate::error::{MLError, Result};
5use scirs2_core::ndarray::{Array1, Array2};
6use std::collections::HashMap;
7
/// Univariate feature selector modelled on scikit-learn's `SelectKBest`.
///
/// Stores the name of a scoring function and the number of features `k` to keep. After a
/// successful `fit`, the indices of the selected columns are available through
/// [`SelectKBest::get_support`].
#[derive(Debug, Clone)]
pub struct SelectKBest {
    /// Name of the scoring function, as given to `new` or `set_params`.
    score_func: String,
    /// Requested number of features to keep.
    k: usize,
    /// Set to `true` once `fit` has completed.
    fitted: bool,
    /// Column indices chosen by the last `fit`; `None` before fitting.
    selected_features_: Option<Vec<usize>>,
}

impl SelectKBest {
    /// Creates an unfitted selector configured with the scoring-function name `score_func`
    /// and the number of features `k` to keep.
    pub fn new(score_func: &str, k: usize) -> Self {
        Self {
            score_func: score_func.to_string(),
            k,
            fitted: false,
            selected_features_: None,
        }
    }

    /// Returns the column indices selected by the last `fit`, or `None` if not yet fitted.
    ///
    /// Unlike scikit-learn's `get_support()`, whose default return is a boolean mask, this
    /// returns indices (the equivalent of `get_support(indices=True)`).
    pub fn get_support(&self) -> Option<&Vec<usize>> {
        self.selected_features_.as_ref()
    }
}
31
32impl SklearnEstimator for SelectKBest {
33 #[allow(non_snake_case)]
34 fn fit(&mut self, X: &Array2<f64>, _y: Option<&Array1<f64>>) -> Result<()> {
35 let features: Vec<usize> = (0..self.k.min(X.ncols())).collect();
37 self.selected_features_ = Some(features);
38 self.fitted = true;
39 Ok(())
40 }
41
42 fn get_params(&self) -> HashMap<String, String> {
43 let mut params = HashMap::new();
44 params.insert("score_func".to_string(), self.score_func.clone());
45 params.insert("k".to_string(), self.k.to_string());
46 params
47 }
48
49 fn set_params(&mut self, params: HashMap<String, String>) -> Result<()> {
50 for (key, value) in params {
51 match key.as_str() {
52 "k" => {
53 self.k = value.parse().map_err(|_| {
54 MLError::InvalidConfiguration(format!("Invalid k parameter: {}", value))
55 })?;
56 }
57 "score_func" => {
58 self.score_func = value;
59 }
60 _ => {}
61 }
62 }
63 Ok(())
64 }
65
66 fn is_fitted(&self) -> bool {
67 self.fitted
68 }
69}
70
71pub struct VarianceThreshold {
73 threshold: f64,
75 variances: Option<Array1<f64>>,
77 mask: Option<Vec<bool>>,
79}
80
81impl VarianceThreshold {
82 pub fn new(threshold: f64) -> Self {
84 Self {
85 threshold,
86 variances: None,
87 mask: None,
88 }
89 }
90
91 #[allow(non_snake_case)]
93 pub fn fit(&mut self, X: &Array2<f64>) -> Result<()> {
94 let n_features = X.ncols();
95 let n_samples = X.nrows() as f64;
96 let mut variances = Array1::zeros(n_features);
97 let mut mask = vec![false; n_features];
98
99 for j in 0..n_features {
100 let mean = X.column(j).sum() / n_samples;
102 let var = X.column(j).mapv(|x| (x - mean).powi(2)).sum() / n_samples;
104 variances[j] = var;
105 mask[j] = var > self.threshold;
106 }
107
108 self.variances = Some(variances);
109 self.mask = Some(mask);
110 Ok(())
111 }
112
113 #[allow(non_snake_case)]
115 pub fn transform(&self, X: &Array2<f64>) -> Result<Array2<f64>> {
116 let mask = self
117 .mask
118 .as_ref()
119 .ok_or_else(|| MLError::ModelNotTrained("VarianceThreshold not fitted".to_string()))?;
120
121 let selected_cols: Vec<usize> = mask
122 .iter()
123 .enumerate()
124 .filter_map(|(i, &m)| if m { Some(i) } else { None })
125 .collect();
126
127 if selected_cols.is_empty() {
128 return Err(MLError::InvalidConfiguration(
129 "No features selected".to_string(),
130 ));
131 }
132
133 let mut result = Array2::zeros((X.nrows(), selected_cols.len()));
134 for (new_j, &old_j) in selected_cols.iter().enumerate() {
135 for i in 0..X.nrows() {
136 result[[i, new_j]] = X[[i, old_j]];
137 }
138 }
139
140 Ok(result)
141 }
142
143 #[allow(non_snake_case)]
145 pub fn fit_transform(&mut self, X: &Array2<f64>) -> Result<Array2<f64>> {
146 self.fit(X)?;
147 self.transform(X)
148 }
149
150 pub fn variances(&self) -> Option<&Array1<f64>> {
152 self.variances.as_ref()
153 }
154
155 pub fn get_support(&self) -> Option<&Vec<bool>> {
157 self.mask.as_ref()
158 }
159}