entrenar/optim/hpo/tpe/
sampling.rs1use rand::Rng;
4
5use crate::optim::hpo::types::{ParameterValue, Trial};
6
7pub fn sample_ei_ratio_continuous<R: Rng>(
9 good_values: &[f64],
10 bad_values: &[f64],
11 low: f64,
12 high: f64,
13 kde_bandwidth: f64,
14 rng: &mut R,
15) -> f64 {
16 if good_values.is_empty() {
17 return low + rng.random::<f64>() * (high - low);
18 }
19
20 let n_candidates = 24;
22 let mut best_value = low;
23 let mut best_ei = f64::NEG_INFINITY;
24
25 let bandwidth = kde_bandwidth * (high - low) / 10.0;
26
27 for _ in 0..n_candidates {
28 let idx = (rng.random::<f64>() * good_values.len() as f64).floor() as usize;
30 let idx = idx.min(good_values.len() - 1);
31 let base = good_values[idx];
32 let u1: f64 = rng.random::<f64>().max(1e-10);
34 let u2: f64 = rng.random::<f64>();
35 let noise = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos() * bandwidth;
36 let candidate = (base + noise).clamp(low, high);
37
38 let l_score = kde_score(candidate, good_values, bandwidth);
40 let g_score = kde_score(candidate, bad_values, bandwidth);
41 let ei = l_score / (g_score + 1e-10);
42
43 if ei > best_ei {
44 best_ei = ei;
45 best_value = candidate;
46 }
47 }
48
49 best_value
50}
51
/// Unnormalised Gaussian kernel density score of `x` against `values`.
///
/// Returns the mean of `exp(-(x - v)^2 / (2 * bandwidth^2))` over `values`. For finite
/// inputs this lies in `[0, 1]`. The normalising constant is omitted, so scores are only
/// comparable when they are computed with the same `bandwidth`.
///
/// An empty `values` slice scores `1.0`, the maximum value of a single kernel term.
///
/// A zero `bandwidth` is handled as the zero-width limit of the kernel: an exact match
/// contributes `1` and any other value contributes `0`. The naive formula would give
/// `exp(-0 / 0) = NaN` for an exact match and poison the whole mean.
fn kde_score(x: f64, values: &[f64], bandwidth: f64) -> f64 {
    if values.is_empty() {
        return 1.0;
    }

    let two_var = 2.0 * bandwidth.powi(2);
    let total: f64 = values
        .iter()
        .map(|&v| {
            let sq_dist = (x - v).powi(2);
            if sq_dist == 0.0 {
                // exp(-0 / two_var) is 1 for any positive width; it is also the zero-width limit.
                1.0
            } else {
                (-sq_dist / two_var).exp()
            }
        })
        .sum();

    total / values.len() as f64
}
60
61pub fn sample_ei_ratio_discrete<R: Rng>(
63 good_values: &[i64],
64 bad_values: &[i64],
65 low: i64,
66 high: i64,
67 rng: &mut R,
68) -> i64 {
69 if good_values.is_empty() {
70 let range = (high - low + 1) as usize;
71 let offset = (rng.random::<f64>() * range as f64).floor() as i64;
72 return (low + offset).min(high);
73 }
74
75 let range = (high - low + 1) as usize;
77 let mut good_counts = vec![1.0; range]; let mut bad_counts = vec![1.0; range];
79
80 for &v in good_values {
81 good_counts[(v - low) as usize] += 1.0;
82 }
83 for &v in bad_values {
84 bad_counts[(v - low) as usize] += 1.0;
85 }
86
87 let mut weights: Vec<f64> =
89 good_counts.iter().zip(bad_counts.iter()).map(|(l, g)| l / g).collect();
90
91 let total: f64 = weights.iter().sum();
93 for w in &mut weights {
94 *w /= total;
95 }
96
97 let r: f64 = rng.random();
99 let mut cumsum = 0.0;
100 for (i, &w) in weights.iter().enumerate() {
101 cumsum += w;
102 if r < cumsum {
103 return low + i as i64;
104 }
105 }
106
107 high
108}
109
110pub fn count_categorical(name: &str, trials: &[&Trial], choices: &[String]) -> Vec<usize> {
112 let mut counts = vec![0usize; choices.len()];
113 for trial in trials {
114 if let Some(ParameterValue::Categorical(s)) = trial.config.get(name) {
115 if let Some(idx) = choices.iter().position(|c| c == s) {
116 counts[idx] += 1;
117 }
118 }
119 }
120 counts
121}