scry_learn/search/
random.rs1use std::collections::HashMap;
5
6use crate::dataset::Dataset;
7use crate::error::{Result, ScryLearnError};
8use crate::metrics::accuracy;
9use crate::split::{k_fold, stratified_k_fold, ScoringFn};
10
11use super::{cartesian_product, evaluate_combo, CvResult, ParamGrid, ParamValue, Tunable};
12
13#[non_exhaustive]
37pub struct RandomizedSearchCV {
38 base_model: Box<dyn Tunable>,
39 param_grid: ParamGrid,
40 n_iter: usize,
41 cv: usize,
42 scorer: ScoringFn,
43 seed: u64,
44 stratified: bool,
45 best_params_: Option<HashMap<String, ParamValue>>,
46 best_score_: f64,
47 cv_results_: Vec<CvResult>,
48}
49
50impl RandomizedSearchCV {
51 pub fn new(model: impl Tunable + 'static, grid: ParamGrid) -> Self {
55 Self {
56 base_model: Box::new(model),
57 param_grid: grid,
58 n_iter: 10,
59 cv: 5,
60 scorer: accuracy,
61 seed: 42,
62 stratified: false,
63 best_params_: None,
64 best_score_: f64::NEG_INFINITY,
65 cv_results_: Vec::new(),
66 }
67 }
68
69 pub fn n_iter(mut self, n: usize) -> Self {
71 self.n_iter = n;
72 self
73 }
74
75 pub fn cv(mut self, k: usize) -> Self {
77 self.cv = k;
78 self
79 }
80
81 pub fn scoring(mut self, scorer: ScoringFn) -> Self {
83 self.scorer = scorer;
84 self
85 }
86
87 pub fn seed(mut self, seed: u64) -> Self {
89 self.seed = seed;
90 self
91 }
92
93 pub fn stratified(mut self, stratified: bool) -> Self {
98 self.stratified = stratified;
99 self
100 }
101
102 pub fn fit(mut self, data: &Dataset) -> Result<Self> {
106 if self.cv < 2 {
107 return Err(ScryLearnError::InvalidParameter(format!(
108 "cv must be >= 2, got {}",
109 self.cv
110 )));
111 }
112 let all_combos = cartesian_product(&self.param_grid);
113 if all_combos.is_empty() {
114 return Err(ScryLearnError::InvalidParameter(
115 "parameter grid is empty".into(),
116 ));
117 }
118
119 let folds = if self.stratified {
120 stratified_k_fold(data, self.cv, self.seed)
121 } else {
122 k_fold(data, self.cv, self.seed)
123 };
124 let mut rng = crate::rng::FastRng::new(self.seed);
125
126 let n = self.n_iter.min(all_combos.len());
128 let mut indices: Vec<usize> = (0..all_combos.len()).collect();
129 for i in (1..indices.len()).rev() {
131 let j = rng.usize(0..=i);
132 indices.swap(i, j);
133 }
134
135 for &idx in &indices[..n] {
136 let combo = &all_combos[idx];
137 let result = evaluate_combo(&*self.base_model, combo, &folds, self.scorer)?;
138
139 if result.mean_score.is_finite()
140 && (self.best_params_.is_none() || result.mean_score > self.best_score_)
141 {
142 self.best_score_ = result.mean_score;
143 self.best_params_ = Some(result.params.clone());
144 }
145 self.cv_results_.push(result);
146 }
147
148 if self.best_params_.is_none() {
149 return Err(ScryLearnError::InvalidParameter(
150 "all parameter combinations produced NaN scores".into(),
151 ));
152 }
153
154 Ok(self)
155 }
156
157 pub fn best_params(&self) -> &HashMap<String, ParamValue> {
163 self.best_params_.as_ref().expect("call fit() first")
164 }
165
166 pub fn best_score(&self) -> f64 {
168 self.best_score_
169 }
170
171 pub fn cv_results(&self) -> &[CvResult] {
173 &self.cv_results_
174 }
175}