// quantrs2_ml/sklearn_compatibility/model_selection.rs
//
// Model selection utilities in the style of scikit-learn's `model_selection`
// module: cross-validation scoring, train/test splitting, grid search, and
// k-fold / stratified k-fold index generators.

use super::{SklearnClassifier, SklearnEstimator};
use crate::error::{MLError, Result};
use scirs2_core::ndarray::{Array1, Array2, Axis};
use scirs2_core::random::prelude::*;
use std::collections::HashMap;

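/// Scores `estimator` with k-fold cross-validation, returning one score per fold.
/// Samples are shuffled once, split into `cv` folds (the last fold absorbs any
/// remainder), and the estimator is refitted and scored on each held-out fold.
///
/// # Example
///
/// Illustrative sketch only: `QuantumClassifier` stands in for any type implementing
/// `SklearnClassifier`; it is not defined in this module.
///
/// ```ignore
/// let mut clf = QuantumClassifier::new();
/// // X: (n_samples, n_features); y: class labels stored as f64
/// let scores = cross_val_score(&mut clf, &X, &y, 5)?;
/// println!("mean CV score: {}", scores.mean().unwrap_or(0.0));
/// ```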
#[allow(non_snake_case)]
pub fn cross_val_score<E>(
    estimator: &mut E,
    X: &Array2<f64>,
    y: &Array1<f64>,
    cv: usize,
) -> Result<Array1<f64>>
where
    E: SklearnClassifier,
{
    let n_samples = X.nrows();
    // Assumes 1 <= cv <= n_samples; the last fold absorbs any remainder.
    let fold_size = n_samples / cv;
    let mut scores = Array1::zeros(cv);

    // Shuffle the sample indices once so folds are not biased by input order.
    let mut indices: Vec<usize> = (0..n_samples).collect();
    indices.shuffle(&mut thread_rng());

    for fold in 0..cv {
        let start_test = fold * fold_size;
        let end_test = if fold == cv - 1 {
            n_samples
        } else {
            (fold + 1) * fold_size
        };

        // Positions [start_test, end_test) form the test fold; the rest is training data.
        let test_indices = &indices[start_test..end_test];
        let train_indices: Vec<usize> = indices
            .iter()
            .enumerate()
            .filter(|(i, _)| *i < start_test || *i >= end_test)
            .map(|(_, &idx)| idx)
            .collect();

        let X_train = X.select(Axis(0), &train_indices);
        let y_train = y.select(Axis(0), &train_indices);
        let X_test = X.select(Axis(0), test_indices);
        let y_test = y.select(Axis(0), test_indices);

        // `score` expects integer class labels, so round the float targets.
        let y_test_int = y_test.mapv(|x| x.round() as i32);

        estimator.fit(&X_train, Some(&y_train))?;
        scores[fold] = estimator.score(&X_test, &y_test_int)?;
    }

    Ok(scores)
}

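/// Randomly splits `X` and `y` into train and test subsets. `test_size` is the
/// fraction of samples assigned to the test set; `random_state` makes the shuffle
/// reproducible when provided.
///
/// # Example
///
/// Illustrative sketch only; `X` and `y` are assumed to already exist.
///
/// ```ignore
/// let (X_train, X_test, y_train, y_test) = train_test_split(&X, &y, 0.2, Some(42))?;
/// assert_eq!(X_train.nrows() + X_test.nrows(), X.nrows());
/// ```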
#[allow(non_snake_case)]
pub fn train_test_split(
    X: &Array2<f64>,
    y: &Array1<f64>,
    test_size: f64,
    random_state: Option<u64>,
) -> Result<(Array2<f64>, Array2<f64>, Array1<f64>, Array1<f64>)> {
    let n_samples = X.nrows();
    let n_test = (n_samples as f64 * test_size).round() as usize;

    let mut indices: Vec<usize> = (0..n_samples).collect();

    // Shuffle deterministically when a seed is given, otherwise use the thread-local RNG.
    if let Some(seed) = random_state {
        let mut rng = StdRng::seed_from_u64(seed);
        indices.shuffle(&mut rng);
    } else {
        indices.shuffle(&mut thread_rng());
    }

    let test_indices = &indices[..n_test];
    let train_indices = &indices[n_test..];

    let X_train = X.select(Axis(0), train_indices);
    let X_test = X.select(Axis(0), test_indices);
    let y_train = y.select(Axis(0), train_indices);
    let y_test = y.select(Axis(0), test_indices);

    Ok((X_train, X_test, y_train, y_test))
}

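/// Exhaustive search over a grid of string-encoded hyperparameters, scored by
/// k-fold cross-validation; the best combination is refitted on the full data,
/// analogous to scikit-learn's `GridSearchCV` with `refit=True`.
///
/// # Example
///
/// Illustrative sketch only: `QuantumClassifier` and the parameter names are
/// placeholders, not items defined in this module.
///
/// ```ignore
/// let mut grid = HashMap::new();
/// grid.insert("n_layers".to_string(), vec!["2".to_string(), "4".to_string()]);
/// grid.insert("learning_rate".to_string(), vec!["0.01".to_string(), "0.1".to_string()]);
///
/// let mut search = GridSearchCV::new(QuantumClassifier::new(), grid, 3);
/// search.fit(&X, &y)?;
/// println!("best: {:?} ({})", search.best_params(), search.best_score());
/// let predictions = search.predict(&X_test)?;
/// ```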
pub struct GridSearchCV<E> {
    /// Base estimator cloned for every parameter combination.
    estimator: E,
    /// Candidate values for each hyperparameter, keyed by parameter name.
    param_grid: HashMap<String, Vec<String>>,
    /// Number of cross-validation folds.
    cv: usize,
    /// Parameter combination with the best mean cross-validation score.
    pub best_params_: HashMap<String, String>,
    /// Best mean cross-validation score found by `fit`.
    pub best_score_: f64,
    /// Estimator refitted on the full data with `best_params_`.
    pub best_estimator_: E,
    /// Whether `fit` has been called.
    fitted: bool,
}

impl<E> GridSearchCV<E>
where
    E: SklearnClassifier + Clone,
{
    pub fn new(estimator: E, param_grid: HashMap<String, Vec<String>>, cv: usize) -> Self {
        Self {
            // The best estimator starts as a plain clone and is replaced during `fit`.
            best_estimator_: estimator.clone(),
            estimator,
            param_grid,
            cv,
            best_params_: HashMap::new(),
            best_score_: f64::NEG_INFINITY,
            fitted: false,
        }
    }

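    /// Evaluates every combination in the parameter grid with `cross_val_score`,
    /// records the best mean score and parameters, and refits the best estimator
    /// on the full `(X, y)` data.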
    #[allow(non_snake_case)]
    pub fn fit(&mut self, X: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
        let param_combinations = self.generate_param_combinations();

        for params in param_combinations {
            // Evaluate each parameter combination on a fresh clone of the base estimator.
            let mut estimator = self.estimator.clone();
            estimator.set_params(params.clone())?;

            let scores = cross_val_score(&mut estimator, X, y, self.cv)?;
            let mean_score = scores.mean().unwrap_or(0.0);

            if mean_score > self.best_score_ {
                self.best_score_ = mean_score;
                self.best_params_ = params.clone();
                self.best_estimator_ = estimator;
            }
        }

        // Refit the winning configuration on the full data set.
        if !self.best_params_.is_empty() {
            self.best_estimator_.set_params(self.best_params_.clone())?;
            self.best_estimator_.fit(X, Some(y))?;
        }

        self.fitted = true;
        Ok(())
    }

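    /// Builds the Cartesian product of the parameter grid. For example, a grid of
    /// `{"a": ["1", "2"], "b": ["x"]}` expands to the two combinations
    /// `{a: "1", b: "x"}` and `{a: "2", b: "x"}`.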
    fn generate_param_combinations(&self) -> Vec<HashMap<String, String>> {
        let mut combinations = vec![HashMap::new()];

        for (param_name, param_values) in &self.param_grid {
            let mut new_combinations = Vec::new();

            for combination in &combinations {
                for value in param_values {
                    let mut new_combination = combination.clone();
                    new_combination.insert(param_name.clone(), value.clone());
                    new_combinations.push(new_combination);
                }
            }

            combinations = new_combinations;
        }

        combinations
    }

    pub fn best_params(&self) -> &HashMap<String, String> {
        &self.best_params_
    }

    pub fn best_score(&self) -> f64 {
        self.best_score_
    }

    #[allow(non_snake_case)]
    pub fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
        if !self.fitted {
            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
        }
        self.best_estimator_.predict(X)
    }
}

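/// K-fold cross-validation index generator, mirroring scikit-learn's `KFold`.
/// `split` returns `(train_indices, test_indices)` pairs over `0..n_samples`.
///
/// # Example
///
/// Illustrative sketch only; `X` is assumed to already exist.
///
/// ```ignore
/// let kfold = KFold::new(5).shuffle(true).random_state(42);
/// for (train_idx, test_idx) in kfold.split(100) {
///     let X_train = X.select(Axis(0), &train_idx);
///     let X_test = X.select(Axis(0), &test_idx);
///     // fit and evaluate on this fold...
/// }
/// ```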
pub struct KFold {
    /// Number of folds.
    n_splits: usize,
    /// Whether to shuffle the indices before splitting.
    shuffle: bool,
    /// Seed used when `shuffle` is enabled.
    random_state: Option<u64>,
}

impl KFold {
    pub fn new(n_splits: usize) -> Self {
        Self {
            n_splits,
            shuffle: false,
            random_state: None,
        }
    }

    pub fn shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    pub fn split(&self, n_samples: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
        let mut indices: Vec<usize> = (0..n_samples).collect();

        if self.shuffle {
            if let Some(seed) = self.random_state {
                fastrand::seed(seed);
            }
            // Fisher-Yates shuffle of the sample indices.
            for i in (1..indices.len()).rev() {
                let j = fastrand::usize(0..=i);
                indices.swap(i, j);
            }
        }

        let fold_size = n_samples / self.n_splits;
        let mut folds = Vec::with_capacity(self.n_splits);

        for fold in 0..self.n_splits {
            let start = fold * fold_size;
            // The last fold absorbs any remainder samples.
            let end = if fold == self.n_splits - 1 {
                n_samples
            } else {
                start + fold_size
            };

            let test_indices: Vec<usize> = indices[start..end].to_vec();
            let train_indices: Vec<usize> = indices[..start]
                .iter()
                .chain(indices[end..].iter())
                .copied()
                .collect();

            folds.push((train_indices, test_indices));
        }

        folds
    }
}

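/// Stratified k-fold index generator: each class (obtained by casting the `f64`
/// labels to `i64`) is distributed across the folds as evenly as possible, so
/// every fold roughly preserves the overall class proportions.
///
/// # Example
///
/// Illustrative sketch only; `y` is assumed to already exist.
///
/// ```ignore
/// let skf = StratifiedKFold::new(3).shuffle(true).random_state(7);
/// for (train_idx, test_idx) in skf.split(&y) {
///     let y_train = y.select(Axis(0), &train_idx);
///     let y_test = y.select(Axis(0), &test_idx);
///     // class ratios in y_train and y_test roughly match those of y
/// }
/// ```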
pub struct StratifiedKFold {
    /// Number of folds.
    n_splits: usize,
    /// Whether to shuffle indices within each class before splitting.
    shuffle: bool,
    /// Seed used when `shuffle` is enabled.
    random_state: Option<u64>,
}

impl StratifiedKFold {
    pub fn new(n_splits: usize) -> Self {
        Self {
            n_splits,
            shuffle: false,
            random_state: None,
        }
    }

    pub fn shuffle(mut self, shuffle: bool) -> Self {
        self.shuffle = shuffle;
        self
    }

    pub fn random_state(mut self, random_state: u64) -> Self {
        self.random_state = Some(random_state);
        self
    }

    pub fn split(&self, y: &Array1<f64>) -> Vec<(Vec<usize>, Vec<usize>)> {
        let n_samples = y.len();

        // Group sample indices by integer class label (the float target cast via `as i64`).
        let mut class_indices: std::collections::HashMap<i64, Vec<usize>> =
            std::collections::HashMap::new();
        for (i, &val) in y.iter().enumerate() {
            let class = val as i64;
            class_indices.entry(class).or_insert_with(Vec::new).push(i);
        }

        // Optionally shuffle within each class (Fisher-Yates).
        if self.shuffle {
            if let Some(seed) = self.random_state {
                fastrand::seed(seed);
            }
            for indices in class_indices.values_mut() {
                for i in (1..indices.len()).rev() {
                    let j = fastrand::usize(0..=i);
                    indices.swap(i, j);
                }
            }
        }

        let mut folds: Vec<(Vec<usize>, Vec<usize>)> = (0..self.n_splits)
            .map(|_| (Vec::new(), Vec::new()))
            .collect();

        // Distribute each class across the folds as evenly as possible;
        // the first `len % n_splits` folds receive one extra sample.
        for indices in class_indices.values() {
            let fold_sizes: Vec<usize> = (0..self.n_splits)
                .map(|f| {
                    let base = indices.len() / self.n_splits;
                    if f < indices.len() % self.n_splits {
                        base + 1
                    } else {
                        base
                    }
                })
                .collect();

            let mut current = 0;
            for (fold, &size) in fold_sizes.iter().enumerate() {
                for &idx in &indices[current..current + size] {
                    folds[fold].1.push(idx);
                }
                current += size;
            }
        }

        // Training indices for each fold are the complement of its test indices.
        for fold_idx in 0..self.n_splits {
            let test_set: std::collections::HashSet<usize> =
                folds[fold_idx].1.iter().copied().collect();
            folds[fold_idx].0 = (0..n_samples).filter(|i| !test_set.contains(i)).collect();
        }

        folds
    }
}