entrenar/optim/hpo/types/
parameter.rs1use rand::Rng;
4use serde::{Deserialize, Serialize};
5
6#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
8pub enum ParameterValue {
9 Float(f64),
10 Int(i64),
11 Categorical(String),
12}
13
14impl ParameterValue {
15 pub fn as_float(&self) -> Option<f64> {
17 match self {
18 ParameterValue::Float(v) => Some(*v),
19 ParameterValue::Int(v) => Some(*v as f64),
20 ParameterValue::Categorical(_) => None,
21 }
22 }
23
24 pub fn as_int(&self) -> Option<i64> {
26 match self {
27 ParameterValue::Int(v) => Some(*v),
28 ParameterValue::Float(v) => Some(*v as i64),
29 ParameterValue::Categorical(_) => None,
30 }
31 }
32
33 pub fn as_str(&self) -> Option<&str> {
35 match self {
36 ParameterValue::Categorical(s) => Some(s),
37 _ => None,
38 }
39 }
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub enum ParameterDomain {
45 Continuous { low: f64, high: f64, log_scale: bool },
47 Discrete { low: i64, high: i64 },
49 Categorical { choices: Vec<String> },
51}
52
53impl ParameterDomain {
54 pub fn sample<R: Rng>(&self, rng: &mut R) -> ParameterValue {
56 contract_pre_sample!();
57 match self {
58 ParameterDomain::Continuous { low, high, log_scale } => {
59 let value = if *log_scale {
60 let log_low = low.max(f64::MIN_POSITIVE).ln();
61 let log_high = high.max(f64::MIN_POSITIVE).ln();
62 let log_val = log_low + rng.random::<f64>() * (log_high - log_low);
63 log_val.exp()
64 } else {
65 low + rng.random::<f64>() * (high - low)
66 };
67 ParameterValue::Float(value)
68 }
69 ParameterDomain::Discrete { low, high } => {
70 let range = (*high - *low + 1) as usize;
71 let offset = (rng.random::<f64>() * range as f64).floor() as i64;
72 let value = (*low + offset).min(*high);
73 ParameterValue::Int(value)
74 }
75 ParameterDomain::Categorical { choices } => {
76 let idx = (rng.random::<f64>() * choices.len() as f64).floor() as usize;
77 let idx = idx.min(choices.len() - 1);
78 ParameterValue::Categorical(choices[idx].clone())
79 }
80 }
81 }
82
83 pub fn is_valid(&self, value: &ParameterValue) -> bool {
85 match (self, value) {
86 (ParameterDomain::Continuous { low, high, .. }, ParameterValue::Float(v)) => {
87 *v >= *low && *v <= *high
88 }
89 (ParameterDomain::Discrete { low, high }, ParameterValue::Int(v)) => {
90 *v >= *low && *v <= *high
91 }
92 (ParameterDomain::Categorical { choices }, ParameterValue::Categorical(s)) => {
93 choices.contains(s)
94 }
95 (
96 ParameterDomain::Continuous { .. } | ParameterDomain::Categorical { .. },
97 ParameterValue::Int(_),
98 )
99 | (
100 ParameterDomain::Continuous { .. } | ParameterDomain::Discrete { .. },
101 ParameterValue::Categorical(_),
102 )
103 | (
104 ParameterDomain::Discrete { .. } | ParameterDomain::Categorical { .. },
105 ParameterValue::Float(_),
106 ) => false,
107 }
108 }
109}