entrenar/optim/hpo/tpe/
optimizer.rs1use rand::Rng;
4use std::collections::HashMap;
5
6use crate::optim::hpo::error::{HPOError, Result};
7use crate::optim::hpo::types::{
8 HyperparameterSpace, ParameterDomain, ParameterValue, Trial, TrialStatus,
9};
10
11use super::sampling::{count_categorical, sample_ei_ratio_continuous, sample_ei_ratio_discrete};
12
13#[derive(Debug, Clone)]
20pub struct TPEOptimizer {
21 space: HyperparameterSpace,
23 pub(crate) gamma: f64,
25 n_startup: usize,
27 kde_bandwidth: f64,
29 trials: Vec<Trial>,
31 next_id: usize,
33}
34
35impl TPEOptimizer {
36 pub fn new(space: HyperparameterSpace) -> Self {
38 Self {
39 space,
40 gamma: 0.25,
41 n_startup: 10,
42 kde_bandwidth: 1.0,
43 trials: Vec::new(),
44 next_id: 0,
45 }
46 }
47
48 pub fn with_gamma(mut self, gamma: f64) -> Self {
50 self.gamma = gamma.clamp(0.01, 0.99);
51 self
52 }
53
54 pub fn with_startup(mut self, n: usize) -> Self {
56 self.n_startup = n.max(1);
57 self
58 }
59
60 pub fn n_trials(&self) -> usize {
62 self.trials.iter().filter(|t| t.status == TrialStatus::Completed).count()
63 }
64
65 pub fn best_trial(&self) -> Option<&Trial> {
67 self.trials
68 .iter()
69 .filter(|t| t.status == TrialStatus::Completed)
70 .min_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal))
71 }
72
73 pub fn suggest(&mut self) -> Result<Trial> {
75 if self.space.is_empty() {
76 return Err(HPOError::EmptySpace);
77 }
78
79 let mut rng = rand::rng();
80 let config = if self.n_trials() < self.n_startup {
81 self.space.sample_random(&mut rng)
83 } else {
84 self.tpe_sample(&mut rng)
86 };
87
88 let trial = Trial::new(self.next_id, config);
89 self.next_id += 1;
90 Ok(trial)
91 }
92
93 pub fn record(&mut self, mut trial: Trial, score: f64, iterations: usize) {
95 trial.complete(score, iterations);
96 self.trials.push(trial);
97 }
98
99 pub fn record_failed(&mut self, mut trial: Trial) {
101 trial.fail();
102 self.trials.push(trial);
103 }
104
105 fn tpe_sample<R: Rng>(&self, rng: &mut R) -> HashMap<String, ParameterValue> {
107 let completed: Vec<_> =
108 self.trials.iter().filter(|t| t.status == TrialStatus::Completed).collect();
109
110 if completed.is_empty() {
111 return self.space.sample_random(rng);
112 }
113
114 let n_good = ((completed.len() as f64) * self.gamma).ceil() as usize;
116 let n_good = n_good.max(1).min(completed.len() - 1);
117
118 let mut sorted: Vec<_> = completed.clone();
119 sorted.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal));
120
121 let (good_trials, bad_trials) = sorted.split_at(n_good);
122
123 let mut config = HashMap::new();
125 for (name, domain) in self.space.iter() {
126 let value = self.sample_parameter_tpe(name, domain, good_trials, bad_trials, rng);
127 config.insert(name.clone(), value);
128 }
129
130 config
131 }
132
133 fn sample_parameter_tpe<R: Rng>(
135 &self,
136 name: &str,
137 domain: &ParameterDomain,
138 good_trials: &[&Trial],
139 bad_trials: &[&Trial],
140 rng: &mut R,
141 ) -> ParameterValue {
142 match domain {
143 ParameterDomain::Continuous { low, high, log_scale } => {
144 let good_values: Vec<f64> = good_trials
146 .iter()
147 .filter_map(|t| t.config.get(name)?.as_float())
148 .map(|v| if *log_scale { v.max(f64::MIN_POSITIVE).ln() } else { v })
149 .collect();
150
151 let bad_values: Vec<f64> = bad_trials
152 .iter()
153 .filter_map(|t| t.config.get(name)?.as_float())
154 .map(|v| if *log_scale { v.max(f64::MIN_POSITIVE).ln() } else { v })
155 .collect();
156
157 let (effective_low, effective_high) = if *log_scale {
159 (low.max(f64::MIN_POSITIVE).ln(), high.max(f64::MIN_POSITIVE).ln())
160 } else {
161 (*low, *high)
162 };
163
164 let value = sample_ei_ratio_continuous(
165 &good_values,
166 &bad_values,
167 effective_low,
168 effective_high,
169 self.kde_bandwidth,
170 rng,
171 );
172
173 let final_value = if *log_scale { value.exp() } else { value };
174 ParameterValue::Float(final_value.clamp(*low, *high))
175 }
176 ParameterDomain::Discrete { low, high } => {
177 let good_values: Vec<i64> =
179 good_trials.iter().filter_map(|t| t.config.get(name)?.as_int()).collect();
180
181 let bad_values: Vec<i64> =
182 bad_trials.iter().filter_map(|t| t.config.get(name)?.as_int()).collect();
183
184 let value = sample_ei_ratio_discrete(&good_values, &bad_values, *low, *high, rng);
185 ParameterValue::Int(value)
186 }
187 ParameterDomain::Categorical { choices } => {
188 let good_counts = count_categorical(name, good_trials, choices);
190 let bad_counts = count_categorical(name, bad_trials, choices);
191
192 let mut weights: Vec<f64> = choices
194 .iter()
195 .enumerate()
196 .map(|(i, _)| {
197 let l = (good_counts[i] + 1) as f64; let g = (bad_counts[i] + 1) as f64;
199 l / g
200 })
201 .collect();
202
203 let total: f64 = weights.iter().sum();
205 for w in &mut weights {
206 *w /= total;
207 }
208
209 let r: f64 = rng.random();
211 let mut cumsum = 0.0;
212 for (i, &w) in weights.iter().enumerate() {
213 cumsum += w;
214 if r < cumsum {
215 return ParameterValue::Categorical(choices[i].clone());
216 }
217 }
218
219 ParameterValue::Categorical(
220 choices.last().expect("choices is non-empty per validate()").clone(),
221 )
222 }
223 }
224 }
225}