Skip to main content

scirs2_optimize/nas/
automl.rs

1//! AutoML: Automated machine learning pipeline optimisation.
2//!
3//! Provides random search over user-defined hyperparameter spaces,
4//! returning the configuration that maximises (or minimises) a
5//! user-supplied objective function.
6
7use std::collections::HashMap;
8
9use crate::error::OptimizeError;
10use scirs2_core::random::{rngs::StdRng, Rng, RngExt, SeedableRng};
11
/// A single hyperparameter's search domain.
///
/// Degenerate domains (empty categorical set, `lo >= hi`, non-positive
/// log bounds) are tolerated by [`HyperparamSpace::sample`], which falls
/// back to a fixed value instead of panicking.
#[derive(Debug, Clone)]
pub enum HyperparamSpace {
    /// Choose uniformly from a finite set of string values
    Categorical(Vec<String>),
    /// Uniform integer in `[lo, hi]` (inclusive)
    IntRange(i64, i64),
    /// Uniform float in `[lo, hi)`
    FloatRange(f64, f64),
    /// Log-uniform float in `[lo, hi)` (useful for learning rates)
    LogFloatRange(f64, f64),
    /// Bernoulli sample with p = 0.5
    Bool,
}
26
27impl HyperparamSpace {
28    /// Draw one sample from this domain.
29    pub fn sample(&self, rng: &mut StdRng) -> HyperparamValue {
30        match self {
31            Self::Categorical(choices) => {
32                if choices.is_empty() {
33                    return HyperparamValue::String(String::new());
34                }
35                let idx = rng.random_range(0..choices.len());
36                HyperparamValue::String(choices[idx].clone())
37            }
38            Self::IntRange(lo, hi) => {
39                if lo >= hi {
40                    return HyperparamValue::Int(*lo);
41                }
42                HyperparamValue::Int(rng.random_range(*lo..=*hi))
43            }
44            Self::FloatRange(lo, hi) => {
45                if lo >= hi {
46                    return HyperparamValue::Float(*lo);
47                }
48                let u = rng.random::<f64>();
49                HyperparamValue::Float(lo + u * (hi - lo))
50            }
51            Self::LogFloatRange(lo, hi) => {
52                if *lo <= 0.0 || *hi <= 0.0 || lo >= hi {
53                    return HyperparamValue::Float(*lo);
54                }
55                let log_lo = lo.ln();
56                let log_hi = hi.ln();
57                let u = rng.random::<f64>();
58                let log_val = log_lo + u * (log_hi - log_lo);
59                HyperparamValue::Float(log_val.exp())
60            }
61            Self::Bool => HyperparamValue::Bool(rng.random_bool(0.5)),
62        }
63    }
64}
65
/// A sampled hyperparameter value.
///
/// One variant per [`HyperparamSpace`] domain; use the `as_*` accessors
/// to recover the typed value.
#[derive(Debug, Clone)]
pub enum HyperparamValue {
    /// String (from Categorical)
    String(String),
    /// Integer
    Int(i64),
    /// Floating-point scalar
    Float(f64),
    /// Boolean
    Bool(bool),
}
78
79impl HyperparamValue {
80    /// Extract float, or `None` if this is not a `Float` variant.
81    pub fn as_float(&self) -> Option<f64> {
82        if let Self::Float(v) = self {
83            Some(*v)
84        } else {
85            None
86        }
87    }
88
89    /// Extract integer, or `None`.
90    pub fn as_int(&self) -> Option<i64> {
91        if let Self::Int(v) = self {
92            Some(*v)
93        } else {
94            None
95        }
96    }
97
98    /// Extract bool, or `None`.
99    pub fn as_bool(&self) -> Option<bool> {
100        if let Self::Bool(v) = self {
101            Some(*v)
102        } else {
103            None
104        }
105    }
106
107    /// Extract string reference, or `None`.
108    pub fn as_str(&self) -> Option<&str> {
109        if let Self::String(s) = self {
110            Some(s.as_str())
111        } else {
112            None
113        }
114    }
115}
116
/// Configuration for an AutoML random search run.
///
/// Built via [`AutoMLConfig::new`] plus the `add_space` / `with_n_trials`
/// builder methods; `n_trials` defaults to 50.
#[derive(Debug)]
pub struct AutoMLConfig {
    /// Per-hyperparameter search domains
    pub search_spaces: HashMap<String, HyperparamSpace>,
    /// Number of random trials to evaluate
    pub n_trials: usize,
    /// Name of the optimisation target metric
    pub optimization_target: String,
    /// If `true`, maximise the metric; if `false`, minimise it
    pub maximize: bool,
}
129
130impl AutoMLConfig {
131    /// Create a new config with no search spaces.
132    pub fn new(target: &str, maximize: bool) -> Self {
133        Self {
134            search_spaces: HashMap::new(),
135            n_trials: 50,
136            optimization_target: target.to_string(),
137            maximize,
138        }
139    }
140
141    /// Add a named hyperparameter with its search domain (builder pattern).
142    pub fn add_space(mut self, name: &str, space: HyperparamSpace) -> Self {
143        self.search_spaces.insert(name.to_string(), space);
144        self
145    }
146
147    /// Set the trial budget (builder pattern).
148    pub fn with_n_trials(mut self, n: usize) -> Self {
149        self.n_trials = n;
150        self
151    }
152}
153
/// Summary of a completed AutoML search.
#[derive(Debug)]
pub struct AutoMLResult {
    /// Best hyperparameter configuration found
    pub best_config: HashMap<String, HyperparamValue>,
    /// Score of the best configuration
    pub best_score: f64,
    /// All (config, score) pairs in trial order
    pub all_configs: Vec<(HashMap<String, HyperparamValue>, f64)>,
    /// Total number of trials run (equals `all_configs.len()`)
    pub n_trials: usize,
}
166
167impl AutoMLResult {
168    /// Iterate over all trial scores.
169    pub fn scores(&self) -> impl Iterator<Item = f64> + '_ {
170        self.all_configs.iter().map(|(_, s)| *s)
171    }
172}
173
/// AutoML optimiser using random search.
///
/// For each trial a configuration is sampled uniformly from the product
/// of all hyperparameter domains and evaluated by the user-provided
/// closure.  The best configuration (by `maximize` criterion) is returned.
pub struct AutoMLOptimizer {
    // Immutable search configuration; the optimiser itself holds no run state.
    config: AutoMLConfig,
}
182
183impl AutoMLOptimizer {
184    /// Create a new optimiser with the given configuration.
185    pub fn new(config: AutoMLConfig) -> Self {
186        Self { config }
187    }
188
189    /// Run random search.
190    ///
191    /// # Arguments
192    /// - `evaluate`: Closure mapping a config to a scalar metric.
193    /// - `seed`: Random seed for reproducibility.
194    ///
195    /// # Errors
196    /// Propagates any error returned by `evaluate`.
197    pub fn optimize<F>(&self, evaluate: F, seed: u64) -> Result<AutoMLResult, OptimizeError>
198    where
199        F: Fn(&HashMap<String, HyperparamValue>) -> Result<f64, OptimizeError>,
200    {
201        if self.config.n_trials == 0 {
202            return Err(OptimizeError::InvalidParameter(
203                "n_trials must be at least 1".to_string(),
204            ));
205        }
206
207        let mut rng = StdRng::seed_from_u64(seed);
208
209        let mut best_score = if self.config.maximize {
210            f64::NEG_INFINITY
211        } else {
212            f64::INFINITY
213        };
214        let mut best_config: HashMap<String, HyperparamValue> = HashMap::new();
215        let mut all_configs = Vec::with_capacity(self.config.n_trials);
216
217        for _ in 0..self.config.n_trials {
218            // Sample one config from each hyperparameter's domain
219            let trial_config: HashMap<String, HyperparamValue> = self
220                .config
221                .search_spaces
222                .iter()
223                .map(|(k, space)| (k.clone(), space.sample(&mut rng)))
224                .collect();
225
226            let score = evaluate(&trial_config)?;
227
228            let is_better = if self.config.maximize {
229                score > best_score
230            } else {
231                score < best_score
232            };
233
234            if is_better || best_config.is_empty() {
235                best_score = score;
236                best_config = trial_config.clone();
237            }
238
239            all_configs.push((trial_config, score));
240        }
241
242        Ok(AutoMLResult {
243            best_config,
244            best_score,
245            all_configs,
246            n_trials: self.config.n_trials,
247        })
248    }
249}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// Quadratic objective in log-learning-rate: minimise (lr - 1e-3)^2
    fn lr_objective(cfg: &HashMap<String, HyperparamValue>) -> Result<f64, OptimizeError> {
        let lr = match cfg.get("lr").and_then(|v| v.as_float()) {
            Some(v) => v,
            None => return Err(OptimizeError::InvalidParameter("missing lr".into())),
        };
        let target = 1e-3_f64;
        Ok(-((lr - target) / target).powi(2))
    }

    #[test]
    fn test_automl_random_search_finds_good_lr() {
        let search = AutoMLConfig::new("neg_mse", true)
            .add_space("lr", HyperparamSpace::LogFloatRange(1e-5, 1e-1))
            .with_n_trials(200);

        let result = AutoMLOptimizer::new(search)
            .optimize(lr_objective, 42)
            .expect("optimize failed");

        assert_eq!(result.n_trials, 200);
        assert_eq!(result.all_configs.len(), 200);

        // Best score should be close to 0.0 with enough trials
        assert!(
            result.best_score > -1.0,
            "best_score too low: {}",
            result.best_score
        );
    }

    #[test]
    fn test_automl_minimize_mode() {
        let search = AutoMLConfig::new("mse", false)
            .add_space("lr", HyperparamSpace::LogFloatRange(1e-5, 1e-1))
            .with_n_trials(100);

        // Squared-error objective; minimised MSE should be very small.
        let objective = |cfg: &HashMap<String, HyperparamValue>| {
            let lr = cfg["lr"].as_float().unwrap_or(1.0);
            Ok((lr - 1e-3).powi(2))
        };

        let result = AutoMLOptimizer::new(search)
            .optimize(objective, 7)
            .expect("optimize failed");

        assert!(result.best_score < 1.0);
    }

    #[test]
    fn test_automl_categorical_space() {
        let choices = vec!["adam".into(), "sgd".into(), "rmsprop".into()];
        let search = AutoMLConfig::new("score", true)
            .add_space("optimizer", HyperparamSpace::Categorical(choices))
            .with_n_trials(30);

        let result = AutoMLOptimizer::new(search)
            .optimize(
                |cfg| match cfg["optimizer"].as_str().unwrap_or("unknown") {
                    "adam" => Ok(1.0),
                    _ => Ok(0.0),
                },
                0,
            )
            .expect("optimize failed");

        assert!(result.best_score >= 0.0);
    }

    #[test]
    fn test_automl_int_range_space() {
        let search = AutoMLConfig::new("score", true)
            .add_space("n_layers", HyperparamSpace::IntRange(1, 10))
            .with_n_trials(50);

        // Peak at n_layers == 5.
        let result = AutoMLOptimizer::new(search)
            .optimize(
                |cfg| {
                    let n = cfg["n_layers"].as_int().unwrap_or(1);
                    Ok(-(n as f64 - 5.0).powi(2))
                },
                5,
            )
            .expect("optimize failed");

        let best_n = result.best_config["n_layers"].as_int().unwrap_or(0);
        assert!((1..=10).contains(&best_n));
    }

    #[test]
    fn test_automl_bool_space_samples() {
        let search = AutoMLConfig::new("score", true)
            .add_space("use_bn", HyperparamSpace::Bool)
            .with_n_trials(20);

        let result = AutoMLOptimizer::new(search)
            .optimize(
                |cfg| {
                    let bn = cfg["use_bn"].as_bool().unwrap_or(false);
                    Ok(if bn { 1.0 } else { 0.0 })
                },
                3,
            )
            .expect("optimize failed");

        assert!(result.best_score >= 0.0);
    }

    #[test]
    fn test_automl_zero_trials_errors() {
        // A zero trial budget is rejected up front.
        let search = AutoMLConfig::new("score", true).with_n_trials(0);
        let outcome = AutoMLOptimizer::new(search).optimize(|_| Ok(0.0), 0);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_automl_result_scores_iter() {
        let search = AutoMLConfig::new("score", true)
            .add_space("lr", HyperparamSpace::FloatRange(0.0, 1.0))
            .with_n_trials(10);

        let result = AutoMLOptimizer::new(search)
            .optimize(|_| Ok(1.0), 0)
            .expect("optimize failed");

        let collected: Vec<f64> = result.scores().collect();
        assert_eq!(collected.len(), 10);
    }
}
389}