Skip to main content

entrenar/optim/hpo/types/
space.rs

1//! Hyperparameter search space
2
3use rand::Rng;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7use crate::optim::hpo::error::{HPOError, Result};
8
9use super::parameter::{ParameterDomain, ParameterValue};
10
11/// Hyperparameter search space
12#[derive(Debug, Clone, Default, Serialize, Deserialize)]
13pub struct HyperparameterSpace {
14    /// Parameter name -> domain mapping
15    params: HashMap<String, ParameterDomain>,
16}
17
18impl HyperparameterSpace {
19    /// Create an empty search space
20    pub fn new() -> Self {
21        Self::default()
22    }
23
24    /// Add a parameter to the search space
25    pub fn add(&mut self, name: &str, domain: ParameterDomain) {
26        self.params.insert(name.to_string(), domain);
27    }
28
29    /// Get a parameter domain
30    pub fn get(&self, name: &str) -> Option<&ParameterDomain> {
31        self.params.get(name)
32    }
33
34    /// Check if space is empty
35    pub fn is_empty(&self) -> bool {
36        self.params.is_empty()
37    }
38
39    /// Get number of parameters
40    pub fn len(&self) -> usize {
41        self.params.len()
42    }
43
44    /// Iterate over parameters
45    pub fn iter(&self) -> impl Iterator<Item = (&String, &ParameterDomain)> {
46        self.params.iter()
47    }
48
49    /// Sample a random configuration
50    pub fn sample_random<R: Rng>(&self, rng: &mut R) -> HashMap<String, ParameterValue> {
51        self.params.iter().map(|(name, domain)| (name.clone(), domain.sample(rng))).collect()
52    }
53
54    /// Validate a configuration
55    pub fn validate(&self, config: &HashMap<String, ParameterValue>) -> Result<()> {
56        for (name, domain) in &self.params {
57            match config.get(name) {
58                Some(value) if domain.is_valid(value) => {}
59                Some(value) => {
60                    return Err(HPOError::InvalidValue(name.clone(), format!("{value:?}")))
61                }
62                None => return Err(HPOError::ParameterNotFound(name.clone())),
63            }
64        }
65        Ok(())
66    }
67}