Skip to main content

entrenar/optim/hpo/types/
parameter.rs

1//! Parameter value and domain types
2
3use rand::Rng;
4use serde::{Deserialize, Serialize};
5
6/// Parameter value (sampled from domain)
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub enum ParameterValue {
9    Float(f64),
10    Int(i64),
11    Categorical(String),
12}
13
14impl ParameterValue {
15    /// Get as float (converts int to float if needed)
16    pub fn as_float(&self) -> Option<f64> {
17        match self {
18            ParameterValue::Float(v) => Some(*v),
19            ParameterValue::Int(v) => Some(*v as f64),
20            ParameterValue::Categorical(_) => None,
21        }
22    }
23
24    /// Get as int
25    pub fn as_int(&self) -> Option<i64> {
26        match self {
27            ParameterValue::Int(v) => Some(*v),
28            ParameterValue::Float(v) => Some(*v as i64),
29            ParameterValue::Categorical(_) => None,
30        }
31    }
32
33    /// Get as string
34    pub fn as_str(&self) -> Option<&str> {
35        match self {
36            ParameterValue::Categorical(s) => Some(s),
37            _ => None,
38        }
39    }
40}
41
42/// Parameter domain (search space)
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub enum ParameterDomain {
45    /// Continuous range [low, high], optionally log-scaled
46    Continuous { low: f64, high: f64, log_scale: bool },
47    /// Discrete integer range [low, high]
48    Discrete { low: i64, high: i64 },
49    /// Categorical choices
50    Categorical { choices: Vec<String> },
51}
52
53impl ParameterDomain {
54    /// Sample a random value from this domain
55    pub fn sample<R: Rng>(&self, rng: &mut R) -> ParameterValue {
56        contract_pre_sample!();
57        match self {
58            ParameterDomain::Continuous { low, high, log_scale } => {
59                let value = if *log_scale {
60                    let log_low = low.max(f64::MIN_POSITIVE).ln();
61                    let log_high = high.max(f64::MIN_POSITIVE).ln();
62                    let log_val = log_low + rng.random::<f64>() * (log_high - log_low);
63                    log_val.exp()
64                } else {
65                    low + rng.random::<f64>() * (high - low)
66                };
67                ParameterValue::Float(value)
68            }
69            ParameterDomain::Discrete { low, high } => {
70                let range = (*high - *low + 1) as usize;
71                let offset = (rng.random::<f64>() * range as f64).floor() as i64;
72                let value = (*low + offset).min(*high);
73                ParameterValue::Int(value)
74            }
75            ParameterDomain::Categorical { choices } => {
76                let idx = (rng.random::<f64>() * choices.len() as f64).floor() as usize;
77                let idx = idx.min(choices.len() - 1);
78                ParameterValue::Categorical(choices[idx].clone())
79            }
80        }
81    }
82
83    /// Check if a value is valid for this domain
84    pub fn is_valid(&self, value: &ParameterValue) -> bool {
85        match (self, value) {
86            (ParameterDomain::Continuous { low, high, .. }, ParameterValue::Float(v)) => {
87                *v >= *low && *v <= *high
88            }
89            (ParameterDomain::Discrete { low, high }, ParameterValue::Int(v)) => {
90                *v >= *low && *v <= *high
91            }
92            (ParameterDomain::Categorical { choices }, ParameterValue::Categorical(s)) => {
93                choices.contains(s)
94            }
95            (
96                ParameterDomain::Continuous { .. } | ParameterDomain::Categorical { .. },
97                ParameterValue::Int(_),
98            )
99            | (
100                ParameterDomain::Continuous { .. } | ParameterDomain::Discrete { .. },
101                ParameterValue::Categorical(_),
102            )
103            | (
104                ParameterDomain::Discrete { .. } | ParameterDomain::Categorical { .. },
105                ParameterValue::Float(_),
106            ) => false,
107        }
108    }
109}