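// quantrs2_ml/sklearn_compatibility/model_selection.rs

//! Sklearn-compatible model selection utilities: cross-validation scoring,
//! train/test splitting, grid search, and (stratified) K-fold splitters.
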
3use super::{SklearnClassifier, SklearnEstimator};
4use crate::error::{MLError, Result};
5use scirs2_core::ndarray::{Array1, Array2, Axis};
6use scirs2_core::random::prelude::*;
7use std::collections::HashMap;
8
9/// Cross-validation score
10#[allow(non_snake_case)]
11pub fn cross_val_score<E>(
12    estimator: &mut E,
13    X: &Array2<f64>,
14    y: &Array1<f64>,
15    cv: usize,
16) -> Result<Array1<f64>>
17where
18    E: SklearnClassifier,
19{
20    let n_samples = X.nrows();
21    let fold_size = n_samples / cv;
22    let mut scores = Array1::zeros(cv);
23
24    // Create fold indices
25    let mut indices: Vec<usize> = (0..n_samples).collect();
26    indices.shuffle(&mut thread_rng());
27
28    for fold in 0..cv {
29        let start_test = fold * fold_size;
30        let end_test = if fold == cv - 1 {
31            n_samples
32        } else {
33            (fold + 1) * fold_size
34        };
35
36        // Create train/test splits
37        let test_indices = &indices[start_test..end_test];
38        let train_indices: Vec<usize> = indices
39            .iter()
40            .enumerate()
41            .filter(|(i, _)| *i < start_test || *i >= end_test)
42            .map(|(_, &idx)| idx)
43            .collect();
44
45        // Extract train/test data
46        let X_train = X.select(Axis(0), &train_indices);
47        let y_train = y.select(Axis(0), &train_indices);
48        let X_test = X.select(Axis(0), test_indices);
49        let y_test = y.select(Axis(0), test_indices);
50
51        // Convert to i32 for classification
52        let y_test_int = y_test.mapv(|x| x.round() as i32);
53
54        // Train and evaluate
55        estimator.fit(&X_train, Some(&y_train))?;
56        scores[fold] = estimator.score(&X_test, &y_test_int)?;
57    }
58
59    Ok(scores)
60}
61
62/// Train-test split
63#[allow(non_snake_case)]
64pub fn train_test_split(
65    X: &Array2<f64>,
66    y: &Array1<f64>,
67    test_size: f64,
68    random_state: Option<u64>,
69) -> Result<(Array2<f64>, Array2<f64>, Array1<f64>, Array1<f64>)> {
70    let n_samples = X.nrows();
71    let n_test = (n_samples as f64 * test_size).round() as usize;
72
73    // Create indices
74    let mut indices: Vec<usize> = (0..n_samples).collect();
75
76    if let Some(seed) = random_state {
77        let mut rng = StdRng::seed_from_u64(seed);
78        indices.shuffle(&mut rng);
79    } else {
80        indices.shuffle(&mut thread_rng());
81    }
82
83    let test_indices = &indices[..n_test];
84    let train_indices = &indices[n_test..];
85
86    let X_train = X.select(Axis(0), train_indices);
87    let X_test = X.select(Axis(0), test_indices);
88    let y_train = y.select(Axis(0), train_indices);
89    let y_test = y.select(Axis(0), test_indices);
90
91    Ok((X_train, X_test, y_train, y_test))
92}
93
94/// Grid search for hyperparameter tuning
95pub struct GridSearchCV<E> {
96    /// Base estimator
97    estimator: E,
98    /// Parameter grid
99    param_grid: HashMap<String, Vec<String>>,
100    /// Cross-validation folds
101    cv: usize,
102    /// Best parameters
103    pub best_params_: HashMap<String, String>,
104    /// Best score
105    pub best_score_: f64,
106    /// Best estimator
107    pub best_estimator_: E,
108    /// Fitted flag
109    fitted: bool,
110}
111
112impl<E> GridSearchCV<E>
113where
114    E: SklearnClassifier + Clone,
115{
116    /// Create new grid search
117    pub fn new(estimator: E, param_grid: HashMap<String, Vec<String>>, cv: usize) -> Self {
118        Self {
119            best_estimator_: estimator.clone(),
120            estimator,
121            param_grid,
122            cv,
123            best_params_: HashMap::new(),
124            best_score_: f64::NEG_INFINITY,
125            fitted: false,
126        }
127    }
128
129    /// Fit grid search
130    #[allow(non_snake_case)]
131    pub fn fit(&mut self, X: &Array2<f64>, y: &Array1<f64>) -> Result<()> {
132        let param_combinations = self.generate_param_combinations();
133
134        for params in param_combinations {
135            let mut estimator = self.estimator.clone();
136            estimator.set_params(params.clone())?;
137
138            let scores = cross_val_score(&mut estimator, X, y, self.cv)?;
139            let mean_score = scores.mean().unwrap_or(0.0);
140
141            if mean_score > self.best_score_ {
142                self.best_score_ = mean_score;
143                self.best_params_ = params.clone();
144                self.best_estimator_ = estimator;
145            }
146        }
147
148        // Fit best estimator
149        if !self.best_params_.is_empty() {
150            self.best_estimator_.set_params(self.best_params_.clone())?;
151            self.best_estimator_.fit(X, Some(y))?;
152        }
153
154        self.fitted = true;
155        Ok(())
156    }
157
158    /// Generate all parameter combinations
159    fn generate_param_combinations(&self) -> Vec<HashMap<String, String>> {
160        let mut combinations = vec![HashMap::new()];
161
162        for (param_name, param_values) in &self.param_grid {
163            let mut new_combinations = Vec::new();
164
165            for combination in &combinations {
166                for value in param_values {
167                    let mut new_combination = combination.clone();
168                    new_combination.insert(param_name.clone(), value.clone());
169                    new_combinations.push(new_combination);
170                }
171            }
172
173            combinations = new_combinations;
174        }
175
176        combinations
177    }
178
179    /// Get best parameters
180    pub fn best_params(&self) -> &HashMap<String, String> {
181        &self.best_params_
182    }
183
184    /// Get best score
185    pub fn best_score(&self) -> f64 {
186        self.best_score_
187    }
188
189    /// Predict with best estimator
190    #[allow(non_snake_case)]
191    pub fn predict(&self, X: &Array2<f64>) -> Result<Array1<i32>> {
192        if !self.fitted {
193            return Err(MLError::ModelNotTrained("Model not trained".to_string()));
194        }
195        self.best_estimator_.predict(X)
196    }
197}
198
199/// K-Fold cross-validation
200pub struct KFold {
201    /// Number of folds
202    n_splits: usize,
203    /// Whether to shuffle
204    shuffle: bool,
205    /// Random state
206    random_state: Option<u64>,
207}
208
209impl KFold {
210    /// Create new KFold
211    pub fn new(n_splits: usize) -> Self {
212        Self {
213            n_splits,
214            shuffle: false,
215            random_state: None,
216        }
217    }
218
219    /// Set shuffle
220    pub fn shuffle(mut self, shuffle: bool) -> Self {
221        self.shuffle = shuffle;
222        self
223    }
224
225    /// Set random state
226    pub fn random_state(mut self, random_state: u64) -> Self {
227        self.random_state = Some(random_state);
228        self
229    }
230
231    /// Split data into folds
232    pub fn split(&self, n_samples: usize) -> Vec<(Vec<usize>, Vec<usize>)> {
233        let mut indices: Vec<usize> = (0..n_samples).collect();
234
235        if self.shuffle {
236            if let Some(seed) = self.random_state {
237                fastrand::seed(seed);
238            }
239            for i in (1..indices.len()).rev() {
240                let j = fastrand::usize(0..=i);
241                indices.swap(i, j);
242            }
243        }
244
245        let fold_size = n_samples / self.n_splits;
246        let mut folds = Vec::with_capacity(self.n_splits);
247
248        for fold in 0..self.n_splits {
249            let start = fold * fold_size;
250            let end = if fold == self.n_splits - 1 {
251                n_samples
252            } else {
253                start + fold_size
254            };
255
256            let test_indices: Vec<usize> = indices[start..end].to_vec();
257            let train_indices: Vec<usize> = indices[..start]
258                .iter()
259                .chain(indices[end..].iter())
260                .copied()
261                .collect();
262
263            folds.push((train_indices, test_indices));
264        }
265
266        folds
267    }
268}
269
270/// Stratified K-Fold cross-validation
271pub struct StratifiedKFold {
272    /// Number of folds
273    n_splits: usize,
274    /// Whether to shuffle
275    shuffle: bool,
276    /// Random state
277    random_state: Option<u64>,
278}
279
280impl StratifiedKFold {
281    /// Create new StratifiedKFold
282    pub fn new(n_splits: usize) -> Self {
283        Self {
284            n_splits,
285            shuffle: false,
286            random_state: None,
287        }
288    }
289
290    /// Set shuffle
291    pub fn shuffle(mut self, shuffle: bool) -> Self {
292        self.shuffle = shuffle;
293        self
294    }
295
296    /// Set random state
297    pub fn random_state(mut self, random_state: u64) -> Self {
298        self.random_state = Some(random_state);
299        self
300    }
301
302    /// Split data into stratified folds
303    pub fn split(&self, y: &Array1<f64>) -> Vec<(Vec<usize>, Vec<usize>)> {
304        let n_samples = y.len();
305
306        // Group indices by class
307        let mut class_indices: std::collections::HashMap<i64, Vec<usize>> =
308            std::collections::HashMap::new();
309        for (i, &val) in y.iter().enumerate() {
310            let class = val as i64;
311            class_indices.entry(class).or_insert_with(Vec::new).push(i);
312        }
313
314        // Shuffle within each class if requested
315        if self.shuffle {
316            if let Some(seed) = self.random_state {
317                fastrand::seed(seed);
318            }
319            for indices in class_indices.values_mut() {
320                for i in (1..indices.len()).rev() {
321                    let j = fastrand::usize(0..=i);
322                    indices.swap(i, j);
323                }
324            }
325        }
326
327        let mut folds: Vec<(Vec<usize>, Vec<usize>)> = (0..self.n_splits)
328            .map(|_| (Vec::new(), Vec::new()))
329            .collect();
330
331        // Distribute samples from each class across folds
332        for indices in class_indices.values() {
333            let fold_sizes: Vec<usize> = (0..self.n_splits)
334                .map(|f| {
335                    let base = indices.len() / self.n_splits;
336                    if f < indices.len() % self.n_splits {
337                        base + 1
338                    } else {
339                        base
340                    }
341                })
342                .collect();
343
344            let mut current = 0;
345            for (fold, &size) in fold_sizes.iter().enumerate() {
346                for &idx in &indices[current..current + size] {
347                    folds[fold].1.push(idx); // Test indices for this fold
348                }
349                current += size;
350            }
351        }
352
353        // Create train indices as complement of test indices
354        for fold_idx in 0..self.n_splits {
355            let test_set: std::collections::HashSet<usize> =
356                folds[fold_idx].1.iter().copied().collect();
357            folds[fold_idx].0 = (0..n_samples).filter(|i| !test_set.contains(i)).collect();
358        }
359
360        folds
361    }
362}