Skip to main content

entrenar/optim/hpo/tpe/
optimizer.rs

1//! TPE optimizer core implementation
2
3use rand::Rng;
4use std::collections::HashMap;
5
6use crate::optim::hpo::error::{HPOError, Result};
7use crate::optim::hpo::types::{
8    HyperparameterSpace, ParameterDomain, ParameterValue, Trial, TrialStatus,
9};
10
11use super::sampling::{count_categorical, sample_ei_ratio_continuous, sample_ei_ratio_discrete};
12
13/// Tree-structured Parzen Estimator optimizer
14///
15/// # Toyota Way: Kaizen
16///
17/// Uses accumulated knowledge from trials to make increasingly better suggestions.
18/// Splits trials by quantile to model "good" vs "bad" configurations.
19#[derive(Debug, Clone)]
20pub struct TPEOptimizer {
21    /// Search space
22    space: HyperparameterSpace,
23    /// Quantile for splitting good/bad (default: 0.25)
24    pub(crate) gamma: f64,
25    /// Number of startup trials (random sampling)
26    n_startup: usize,
27    /// KDE bandwidth
28    kde_bandwidth: f64,
29    /// Completed trials
30    trials: Vec<Trial>,
31    /// Next trial ID
32    next_id: usize,
33}
34
35impl TPEOptimizer {
36    /// Create a new TPE optimizer
37    pub fn new(space: HyperparameterSpace) -> Self {
38        Self {
39            space,
40            gamma: 0.25,
41            n_startup: 10,
42            kde_bandwidth: 1.0,
43            trials: Vec::new(),
44            next_id: 0,
45        }
46    }
47
48    /// Set gamma (quantile for splitting)
49    pub fn with_gamma(mut self, gamma: f64) -> Self {
50        self.gamma = gamma.clamp(0.01, 0.99);
51        self
52    }
53
54    /// Set number of startup trials
55    pub fn with_startup(mut self, n: usize) -> Self {
56        self.n_startup = n.max(1);
57        self
58    }
59
60    /// Get number of completed trials
61    pub fn n_trials(&self) -> usize {
62        self.trials.iter().filter(|t| t.status == TrialStatus::Completed).count()
63    }
64
65    /// Get best trial so far
66    pub fn best_trial(&self) -> Option<&Trial> {
67        self.trials
68            .iter()
69            .filter(|t| t.status == TrialStatus::Completed)
70            .min_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal))
71    }
72
73    /// Suggest next configuration to try
74    pub fn suggest(&mut self) -> Result<Trial> {
75        if self.space.is_empty() {
76            return Err(HPOError::EmptySpace);
77        }
78
79        let mut rng = rand::rng();
80        let config = if self.n_trials() < self.n_startup {
81            // Random sampling during startup phase
82            self.space.sample_random(&mut rng)
83        } else {
84            // TPE-guided sampling
85            self.tpe_sample(&mut rng)
86        };
87
88        let trial = Trial::new(self.next_id, config);
89        self.next_id += 1;
90        Ok(trial)
91    }
92
93    /// Record trial result
94    pub fn record(&mut self, mut trial: Trial, score: f64, iterations: usize) {
95        trial.complete(score, iterations);
96        self.trials.push(trial);
97    }
98
99    /// Record failed trial
100    pub fn record_failed(&mut self, mut trial: Trial) {
101        trial.fail();
102        self.trials.push(trial);
103    }
104
105    /// TPE sampling (internal)
106    fn tpe_sample<R: Rng>(&self, rng: &mut R) -> HashMap<String, ParameterValue> {
107        let completed: Vec<_> =
108            self.trials.iter().filter(|t| t.status == TrialStatus::Completed).collect();
109
110        if completed.is_empty() {
111            return self.space.sample_random(rng);
112        }
113
114        // Split trials into good (l) and bad (g) by gamma quantile
115        let n_good = ((completed.len() as f64) * self.gamma).ceil() as usize;
116        let n_good = n_good.max(1).min(completed.len() - 1);
117
118        let mut sorted: Vec<_> = completed.clone();
119        sorted.sort_by(|a, b| a.score.partial_cmp(&b.score).unwrap_or(std::cmp::Ordering::Equal));
120
121        let (good_trials, bad_trials) = sorted.split_at(n_good);
122
123        // Sample each parameter using TPE
124        let mut config = HashMap::new();
125        for (name, domain) in self.space.iter() {
126            let value = self.sample_parameter_tpe(name, domain, good_trials, bad_trials, rng);
127            config.insert(name.clone(), value);
128        }
129
130        config
131    }
132
133    /// Sample a single parameter using TPE
134    fn sample_parameter_tpe<R: Rng>(
135        &self,
136        name: &str,
137        domain: &ParameterDomain,
138        good_trials: &[&Trial],
139        bad_trials: &[&Trial],
140        rng: &mut R,
141    ) -> ParameterValue {
142        match domain {
143            ParameterDomain::Continuous { low, high, log_scale } => {
144                // Extract values from trials
145                let good_values: Vec<f64> = good_trials
146                    .iter()
147                    .filter_map(|t| t.config.get(name)?.as_float())
148                    .map(|v| if *log_scale { v.max(f64::MIN_POSITIVE).ln() } else { v })
149                    .collect();
150
151                let bad_values: Vec<f64> = bad_trials
152                    .iter()
153                    .filter_map(|t| t.config.get(name)?.as_float())
154                    .map(|v| if *log_scale { v.max(f64::MIN_POSITIVE).ln() } else { v })
155                    .collect();
156
157                // Sample from l(x) / g(x) using simple KDE approximation
158                let (effective_low, effective_high) = if *log_scale {
159                    (low.max(f64::MIN_POSITIVE).ln(), high.max(f64::MIN_POSITIVE).ln())
160                } else {
161                    (*low, *high)
162                };
163
164                let value = sample_ei_ratio_continuous(
165                    &good_values,
166                    &bad_values,
167                    effective_low,
168                    effective_high,
169                    self.kde_bandwidth,
170                    rng,
171                );
172
173                let final_value = if *log_scale { value.exp() } else { value };
174                ParameterValue::Float(final_value.clamp(*low, *high))
175            }
176            ParameterDomain::Discrete { low, high } => {
177                // Extract values
178                let good_values: Vec<i64> =
179                    good_trials.iter().filter_map(|t| t.config.get(name)?.as_int()).collect();
180
181                let bad_values: Vec<i64> =
182                    bad_trials.iter().filter_map(|t| t.config.get(name)?.as_int()).collect();
183
184                let value = sample_ei_ratio_discrete(&good_values, &bad_values, *low, *high, rng);
185                ParameterValue::Int(value)
186            }
187            ParameterDomain::Categorical { choices } => {
188                // Count occurrences
189                let good_counts = count_categorical(name, good_trials, choices);
190                let bad_counts = count_categorical(name, bad_trials, choices);
191
192                // Sample based on l(x) / g(x)
193                let mut weights: Vec<f64> = choices
194                    .iter()
195                    .enumerate()
196                    .map(|(i, _)| {
197                        let l = (good_counts[i] + 1) as f64; // Laplace smoothing
198                        let g = (bad_counts[i] + 1) as f64;
199                        l / g
200                    })
201                    .collect();
202
203                // Normalize
204                let total: f64 = weights.iter().sum();
205                for w in &mut weights {
206                    *w /= total;
207                }
208
209                // Sample
210                let r: f64 = rng.random();
211                let mut cumsum = 0.0;
212                for (i, &w) in weights.iter().enumerate() {
213                    cumsum += w;
214                    if r < cumsum {
215                        return ParameterValue::Categorical(choices[i].clone());
216                    }
217                }
218
219                ParameterValue::Categorical(
220                    choices.last().expect("choices is non-empty per validate()").clone(),
221                )
222            }
223        }
224    }
225}