// entrenar/optim/hpo/types/space.rs
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};

use crate::optim::hpo::error::{HPOError, Result};

use super::parameter::{ParameterDomain, ParameterValue};

11#[derive(Debug, Clone, Default, Serialize, Deserialize)]
13pub struct HyperparameterSpace {
14 params: HashMap<String, ParameterDomain>,
16}
17
18impl HyperparameterSpace {
19 pub fn new() -> Self {
21 Self::default()
22 }
23
24 pub fn add(&mut self, name: &str, domain: ParameterDomain) {
26 self.params.insert(name.to_string(), domain);
27 }
28
29 pub fn get(&self, name: &str) -> Option<&ParameterDomain> {
31 self.params.get(name)
32 }
33
34 pub fn is_empty(&self) -> bool {
36 self.params.is_empty()
37 }
38
39 pub fn len(&self) -> usize {
41 self.params.len()
42 }
43
44 pub fn iter(&self) -> impl Iterator<Item = (&String, &ParameterDomain)> {
46 self.params.iter()
47 }
48
49 pub fn sample_random<R: Rng>(&self, rng: &mut R) -> HashMap<String, ParameterValue> {
51 self.params.iter().map(|(name, domain)| (name.clone(), domain.sample(rng))).collect()
52 }
53
54 pub fn validate(&self, config: &HashMap<String, ParameterValue>) -> Result<()> {
56 for (name, domain) in &self.params {
57 match config.get(name) {
58 Some(value) if domain.is_valid(value) => {}
59 Some(value) => {
60 return Err(HPOError::InvalidValue(name.clone(), format!("{value:?}")))
61 }
62 None => return Err(HPOError::ParameterNotFound(name.clone())),
63 }
64 }
65 Ok(())
66 }
67}