1use std::collections::HashMap;
8
9use crate::error::OptimizeError;
10use scirs2_core::random::{rngs::StdRng, Rng, RngExt, SeedableRng};
11
/// A search space describing the values a single hyperparameter may take.
///
/// Sampling semantics are defined by [`HyperparamSpace::sample`].
#[derive(Debug, Clone)]
pub enum HyperparamSpace {
    /// One of a fixed set of string choices, sampled uniformly.
    Categorical(Vec<String>),
    /// An integer sampled uniformly from the inclusive range `[lo, hi]`.
    IntRange(i64, i64),
    /// A float sampled uniformly between the two bounds.
    FloatRange(f64, f64),
    /// A float sampled log-uniformly between the two bounds; both bounds
    /// must be positive (non-positive bounds degenerate to the lower bound).
    LogFloatRange(f64, f64),
    /// A boolean sampled with probability 0.5.
    Bool,
}
26
27impl HyperparamSpace {
28 pub fn sample(&self, rng: &mut StdRng) -> HyperparamValue {
30 match self {
31 Self::Categorical(choices) => {
32 if choices.is_empty() {
33 return HyperparamValue::String(String::new());
34 }
35 let idx = rng.random_range(0..choices.len());
36 HyperparamValue::String(choices[idx].clone())
37 }
38 Self::IntRange(lo, hi) => {
39 if lo >= hi {
40 return HyperparamValue::Int(*lo);
41 }
42 HyperparamValue::Int(rng.random_range(*lo..=*hi))
43 }
44 Self::FloatRange(lo, hi) => {
45 if lo >= hi {
46 return HyperparamValue::Float(*lo);
47 }
48 let u = rng.random::<f64>();
49 HyperparamValue::Float(lo + u * (hi - lo))
50 }
51 Self::LogFloatRange(lo, hi) => {
52 if *lo <= 0.0 || *hi <= 0.0 || lo >= hi {
53 return HyperparamValue::Float(*lo);
54 }
55 let log_lo = lo.ln();
56 let log_hi = hi.ln();
57 let u = rng.random::<f64>();
58 let log_val = log_lo + u * (log_hi - log_lo);
59 HyperparamValue::Float(log_val.exp())
60 }
61 Self::Bool => HyperparamValue::Bool(rng.random_bool(0.5)),
62 }
63 }
64}
65
/// A single sampled hyperparameter value.
///
/// `PartialEq` is derived so sampled values and whole configurations can be
/// compared directly (e.g. in tests or when checking a returned best config);
/// `Eq` is intentionally omitted because the `Float` variant holds an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum HyperparamValue {
    /// A categorical (string) choice.
    String(String),
    /// An integer value.
    Int(i64),
    /// A floating-point value.
    Float(f64),
    /// A boolean flag.
    Bool(bool),
}

impl HyperparamValue {
    /// Returns the inner `f64` if this is a `Float`, otherwise `None`.
    pub fn as_float(&self) -> Option<f64> {
        match self {
            Self::Float(v) => Some(*v),
            _ => None,
        }
    }

    /// Returns the inner `i64` if this is an `Int`, otherwise `None`.
    pub fn as_int(&self) -> Option<i64> {
        match self {
            Self::Int(v) => Some(*v),
            _ => None,
        }
    }

    /// Returns the inner `bool` if this is a `Bool`, otherwise `None`.
    pub fn as_bool(&self) -> Option<bool> {
        match self {
            Self::Bool(v) => Some(*v),
            _ => None,
        }
    }

    /// Returns the inner string slice if this is a `String`, otherwise `None`.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            Self::String(s) => Some(s.as_str()),
            _ => None,
        }
    }
}
116
/// Configuration for an AutoML random hyperparameter search.
#[derive(Debug)]
pub struct AutoMLConfig {
    /// Named hyperparameter search spaces to sample each trial from.
    pub search_spaces: HashMap<String, HyperparamSpace>,
    /// Number of random trials to run (defaults to 50 via `new`).
    pub n_trials: usize,
    /// Label of the metric being optimized; stored for reporting only and
    /// not interpreted by the optimizer itself.
    pub optimization_target: String,
    /// If true, higher scores are better; otherwise lower scores are better.
    pub maximize: bool,
}
129
130impl AutoMLConfig {
131 pub fn new(target: &str, maximize: bool) -> Self {
133 Self {
134 search_spaces: HashMap::new(),
135 n_trials: 50,
136 optimization_target: target.to_string(),
137 maximize,
138 }
139 }
140
141 pub fn add_space(mut self, name: &str, space: HyperparamSpace) -> Self {
143 self.search_spaces.insert(name.to_string(), space);
144 self
145 }
146
147 pub fn with_n_trials(mut self, n: usize) -> Self {
149 self.n_trials = n;
150 self
151 }
152}
153
/// Outcome of a completed AutoML random search.
#[derive(Debug)]
pub struct AutoMLResult {
    /// The hyperparameter configuration that achieved the best score.
    pub best_config: HashMap<String, HyperparamValue>,
    /// The best score observed across all trials.
    pub best_score: f64,
    /// Every evaluated configuration paired with its score, in trial order.
    pub all_configs: Vec<(HashMap<String, HyperparamValue>, f64)>,
    /// The number of trials that were run.
    pub n_trials: usize,
}
166
167impl AutoMLResult {
168 pub fn scores(&self) -> impl Iterator<Item = f64> + '_ {
170 self.all_configs.iter().map(|(_, s)| *s)
171 }
172}
173
/// Random-search hyperparameter optimizer driven by an `AutoMLConfig`.
pub struct AutoMLOptimizer {
    // Search configuration: spaces, trial budget, and optimization direction.
    config: AutoMLConfig,
}
182
183impl AutoMLOptimizer {
184 pub fn new(config: AutoMLConfig) -> Self {
186 Self { config }
187 }
188
189 pub fn optimize<F>(&self, evaluate: F, seed: u64) -> Result<AutoMLResult, OptimizeError>
198 where
199 F: Fn(&HashMap<String, HyperparamValue>) -> Result<f64, OptimizeError>,
200 {
201 if self.config.n_trials == 0 {
202 return Err(OptimizeError::InvalidParameter(
203 "n_trials must be at least 1".to_string(),
204 ));
205 }
206
207 let mut rng = StdRng::seed_from_u64(seed);
208
209 let mut best_score = if self.config.maximize {
210 f64::NEG_INFINITY
211 } else {
212 f64::INFINITY
213 };
214 let mut best_config: HashMap<String, HyperparamValue> = HashMap::new();
215 let mut all_configs = Vec::with_capacity(self.config.n_trials);
216
217 for _ in 0..self.config.n_trials {
218 let trial_config: HashMap<String, HyperparamValue> = self
220 .config
221 .search_spaces
222 .iter()
223 .map(|(k, space)| (k.clone(), space.sample(&mut rng)))
224 .collect();
225
226 let score = evaluate(&trial_config)?;
227
228 let is_better = if self.config.maximize {
229 score > best_score
230 } else {
231 score < best_score
232 };
233
234 if is_better || best_config.is_empty() {
235 best_score = score;
236 best_config = trial_config.clone();
237 }
238
239 all_configs.push((trial_config, score));
240 }
241
242 Ok(AutoMLResult {
243 best_config,
244 best_score,
245 all_configs,
246 n_trials: self.config.n_trials,
247 })
248 }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Objective with a single optimum: negative relative squared error of
    /// `lr` against a target of 1e-3 (maximum value 0, attained at lr = 1e-3).
    fn lr_objective(cfg: &HashMap<String, HyperparamValue>) -> Result<f64, OptimizeError> {
        let lr = cfg
            .get("lr")
            .and_then(|v| v.as_float())
            .ok_or_else(|| OptimizeError::InvalidParameter("missing lr".into()))?;
        let target = 1e-3_f64;
        Ok(-((lr - target) / target).powi(2))
    }

    // Log-uniform sampling over [1e-5, 1e-1] with 200 trials should land at
    // least one sample close enough to 1e-3 to score above -1.0 (i.e. within
    // a factor-of-two relative error of the target).
    #[test]
    fn test_automl_random_search_finds_good_lr() {
        let config = AutoMLConfig::new("neg_mse", true)
            .add_space("lr", HyperparamSpace::LogFloatRange(1e-5, 1e-1))
            .with_n_trials(200);

        let opt = AutoMLOptimizer::new(config);
        let result = opt.optimize(lr_objective, 42).expect("optimize failed");

        assert_eq!(result.n_trials, 200);
        assert_eq!(result.all_configs.len(), 200);

        assert!(
            result.best_score > -1.0,
            "best_score too low: {}",
            result.best_score
        );
    }

    // With maximize = false the optimizer should track the smallest score.
    // The whole range maps to errors below 1.0, so this mostly checks that
    // minimize mode runs end-to-end.
    #[test]
    fn test_automl_minimize_mode() {
        let config = AutoMLConfig::new("mse", false)
            .add_space("lr", HyperparamSpace::LogFloatRange(1e-5, 1e-1))
            .with_n_trials(100);

        let opt = AutoMLOptimizer::new(config);
        let result = opt
            .optimize(
                |cfg| {
                    let lr = cfg["lr"].as_float().unwrap_or(1.0);
                    Ok((lr - 1e-3).powi(2))
                },
                7,
            )
            .expect("optimize failed");

        assert!(result.best_score < 1.0);
    }

    // Categorical sampling: scoring 1.0 only for "adam".
    // NOTE(review): `best_score >= 0.0` is a weak assertion — with 30 trials
    // over 3 choices, "adam" is almost certainly sampled; consider asserting
    // equality with 1.0 instead.
    #[test]
    fn test_automl_categorical_space() {
        let config = AutoMLConfig::new("score", true)
            .add_space(
                "optimizer",
                HyperparamSpace::Categorical(vec!["adam".into(), "sgd".into(), "rmsprop".into()]),
            )
            .with_n_trials(30);

        let opt = AutoMLOptimizer::new(config);
        let result = opt
            .optimize(
                |cfg| {
                    let name = cfg["optimizer"].as_str().unwrap_or("unknown");
                    Ok(if name == "adam" { 1.0 } else { 0.0 })
                },
                0,
            )
            .expect("optimize failed");

        assert!(result.best_score >= 0.0);
    }

    // Integer range: best layer count must stay inside the declared bounds,
    // confirming IntRange samples within [1, 10] inclusive.
    #[test]
    fn test_automl_int_range_space() {
        let config = AutoMLConfig::new("score", true)
            .add_space("n_layers", HyperparamSpace::IntRange(1, 10))
            .with_n_trials(50);

        let opt = AutoMLOptimizer::new(config);
        let result = opt
            .optimize(
                |cfg| {
                    let n = cfg["n_layers"].as_int().unwrap_or(1);
                    Ok(-(n as f64 - 5.0).powi(2))
                },
                5,
            )
            .expect("optimize failed");

        let best_n = result.best_config["n_layers"].as_int().unwrap_or(0);
        assert!((1..=10).contains(&best_n));
    }

    // Bool space: smoke test that boolean sampling runs and produces a score.
    // NOTE(review): like the categorical test, `>= 0.0` is weak; 20 trials
    // should almost surely sample `true` at least once, giving exactly 1.0.
    #[test]
    fn test_automl_bool_space_samples() {
        let config = AutoMLConfig::new("score", true)
            .add_space("use_bn", HyperparamSpace::Bool)
            .with_n_trials(20);

        let opt = AutoMLOptimizer::new(config);
        let result = opt
            .optimize(
                |cfg| {
                    let bn = cfg["use_bn"].as_bool().unwrap_or(false);
                    Ok(if bn { 1.0 } else { 0.0 })
                },
                3,
            )
            .expect("optimize failed");

        assert!(result.best_score >= 0.0);
    }

    // n_trials == 0 is rejected with an error before any evaluation happens.
    #[test]
    fn test_automl_zero_trials_errors() {
        let config = AutoMLConfig::new("score", true).with_n_trials(0);
        let opt = AutoMLOptimizer::new(config);
        assert!(opt.optimize(|_| Ok(0.0), 0).is_err());
    }

    // `scores()` yields exactly one entry per trial, in trial order.
    #[test]
    fn test_automl_result_scores_iter() {
        let config = AutoMLConfig::new("score", true)
            .add_space("lr", HyperparamSpace::FloatRange(0.0, 1.0))
            .with_n_trials(10);

        let opt = AutoMLOptimizer::new(config);
        let result = opt.optimize(|_| Ok(1.0), 0).expect("optimize failed");

        let scores: Vec<f64> = result.scores().collect();
        assert_eq!(scores.len(), 10);
    }
}