use crate::distribution::{FloatDistribution, IntDistribution};
use crate::kde::KernelDensityEstimator;
use crate::rng_util;
/// Samples a value for a continuous parameter using the Tree-structured Parzen Estimator (TPE).
///
/// All modelling happens in an internal space. That space is the natural log of the values when
/// `dist.log_scale` is set, and the raw values otherwise.
///
/// Two kernel density estimators are fit:
/// - `l(x)` is fit to `good_values`.
/// - `g(x)` is fit to `bad_values`.
///
/// Both use `kde_bandwidth` when given, otherwise the estimator's own default.
/// `n_ei_candidates` points are drawn from `l(x)` and clamped to the internal bounds. The first
/// point with the largest `l(x) / g(x)` is kept. When `g(x) < f64::EPSILON`, the ratio is taken
/// as `+inf` if `l(x) > f64::EPSILON`, and as `0.0` otherwise. With `n_ei_candidates == 0`, the
/// lower bound is used.
///
/// If either estimator cannot be built (its constructor returns `Err`), a point is drawn with
/// `rng_util::f64_range` over the *internal* bounds instead. As a result, a log-scaled parameter
/// is handled in log space on this path too.
///
/// On both paths, the chosen point is mapped back to the original space. It is then rounded to
/// the nearest `dist.low + k * step` when `dist.step` is set, and clamped to
/// `[dist.low, dist.high]`. This guarantees the step grid is honoured even on the fallback path.
///
/// NOTE(review): log scaling takes `ln` of the bounds and observations, so it presumably relies
/// on them being positive. Confirm that distribution validation enforces this.
pub(crate) fn sample_tpe_float(
    dist: &FloatDistribution,
    good_values: Vec<f64>,
    bad_values: Vec<f64>,
    n_ei_candidates: usize,
    kde_bandwidth: Option<f64>,
    rng: &mut fastrand::Rng,
) -> f64 {
    let (low, high, log_scale, step) = (dist.low, dist.high, dist.log_scale, dist.step);

    // Map original-space values into the space the KDEs operate in.
    let to_internal = move |v: f64| if log_scale { v.ln() } else { v };
    let (internal_low, internal_high) = (to_internal(low), to_internal(high));
    // `into_iter().map().collect()` on a `Vec<f64>` reuses the original allocation.
    let good_internal: Vec<f64> = good_values.into_iter().map(to_internal).collect();
    let bad_internal: Vec<f64> = bad_values.into_iter().map(to_internal).collect();

    let build_kde = |values: Vec<f64>| match kde_bandwidth {
        Some(bw) => KernelDensityEstimator::with_bandwidth(values, bw),
        None => KernelDensityEstimator::new(values),
    };

    let best_internal = match (build_kde(good_internal), build_kde(bad_internal)) {
        (Ok(l_kde), Ok(g_kde)) => {
            let mut best_candidate = internal_low;
            let mut best_ratio = f64::NEG_INFINITY;
            for _ in 0..n_ei_candidates {
                let candidate = l_kde.sample(rng).clamp(internal_low, internal_high);
                let l_density = l_kde.pdf(candidate);
                let g_density = g_kde.pdf(candidate);
                // Guard against dividing by a (near-)zero `g` density.
                let ratio = if g_density < f64::EPSILON {
                    if l_density > f64::EPSILON {
                        f64::INFINITY
                    } else {
                        0.0
                    }
                } else {
                    l_density / g_density
                };
                // Strict `>` keeps the earliest candidate on ties.
                if ratio > best_ratio {
                    best_ratio = ratio;
                    best_candidate = candidate;
                }
            }
            best_candidate
        }
        // No usable model: draw over the internal range so that the log-scale mapping and
        // step snapping below still apply to the fallback value.
        _ => rng_util::f64_range(rng, internal_low, internal_high),
    };

    let value = if log_scale {
        best_internal.exp()
    } else {
        best_internal
    };
    let value = match step {
        Some(step) => low + ((value - low) / step).round() * step,
        None => value,
    };
    value.clamp(low, high)
}
/// Samples a value for an integer parameter using TPE.
///
/// The observed integers are converted to `f64`. They are passed to [`sample_tpe_float`]
/// together with a floating-point copy of `dist`: same bounds, same `log_scale`, and the step
/// converted to `f64`.
///
/// The returned float is processed in three steps:
/// 1. It is rounded to the nearest integer.
/// 2. When `dist.step` is set, it is snapped to the grid `dist.low + k * step`.
/// 3. It is clamped to `[dist.low, dist.high]`.
///
/// Snapping never picks a grid index beyond the last point inside `[dist.low, dist.high]`. This
/// keeps the result on the grid even when `dist.high - dist.low` is not a multiple of the step.
/// Previously, plain clamping returned the off-grid `dist.high` in that case.
#[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
pub(crate) fn sample_tpe_int(
    dist: &IntDistribution,
    good_values: Vec<i64>,
    bad_values: Vec<i64>,
    n_ei_candidates: usize,
    kde_bandwidth: Option<f64>,
    rng: &mut fastrand::Rng,
) -> i64 {
    let (low, high, log_scale, step) = (dist.low, dist.high, dist.log_scale, dist.step);

    // These `as f64` casts are lossy for magnitudes above 2^53, hence the clippy allowance.
    let to_floats = |values: Vec<i64>| -> Vec<f64> {
        values.into_iter().map(|v| v as f64).collect()
    };
    let float_dist = FloatDistribution {
        low: low as f64,
        high: high as f64,
        log_scale,
        step: step.map(|s| s as f64),
    };
    let float_value = sample_tpe_float(
        &float_dist,
        to_floats(good_values),
        to_floats(bad_values),
        n_ei_candidates,
        kde_bandwidth,
        rng,
    );

    let rounded = float_value.round() as i64;
    let snapped = match step {
        Some(step) => {
            let k = ((rounded - low) as f64 / step as f64).round() as i64;
            // Largest grid index whose point does not exceed `high`. With a zero step, every
            // `k` maps to `low` anyway, so `checked_div` falling back to 0 avoids a panic
            // without changing the result.
            let max_k = (high - low).checked_div(step).unwrap_or(0);
            // `min`/`max` rather than `clamp`, so degenerate bounds cannot panic here.
            low + k.min(max_k).max(0) * step
        }
        None => rounded,
    };
    snapped.clamp(low, high)
}
/// Samples a category index in `0..n_choices` using TPE.
///
/// Occurrences of each index in `good_indices` and `bad_indices` are counted. Indices
/// `>= n_choices` are skipped. The counts are turned into add-one-smoothed probabilities:
/// - `l(i) = (good_count[i] + 1) / (good_indices.len() + n_choices)`
/// - `g(i)` is defined the same way from `bad_indices`.
///
/// An index is then drawn with probability proportional to `l(i) / g(i)`. The draw uses a
/// single `rng.f64()` value and a cumulative scan.
///
/// The per-choice scratch buffers live on the stack when `n_choices <= 32`. Otherwise they are
/// heap-allocated.
///
/// # Panics
///
/// Panics if `n_choices` is zero, because there is no index to return.
#[allow(clippy::cast_precision_loss)]
pub(crate) fn sample_tpe_categorical(
    n_choices: usize,
    good_indices: &[usize],
    bad_indices: &[usize],
    rng: &mut fastrand::Rng,
) -> usize {
    // Up to this many choices, the scratch buffers avoid a heap allocation.
    const STACK_CHOICES: usize = 32;

    // Without this check, the trailing `n_choices - 1` fallback underflows. That is a panic in
    // debug builds and a bogus `usize::MAX` index in release builds.
    assert!(
        n_choices > 0,
        "sample_tpe_categorical requires at least one choice"
    );

    let mut good_buf = [0usize; STACK_CHOICES];
    let mut bad_buf = [0usize; STACK_CHOICES];
    let mut weight_buf = [0.0f64; STACK_CHOICES];
    let mut good_vec;
    let mut bad_vec;
    let mut weight_vec;
    let (good_counts, bad_counts, weights): (&mut [usize], &mut [usize], &mut [f64]) =
        if n_choices <= STACK_CHOICES {
            (
                &mut good_buf[..n_choices],
                &mut bad_buf[..n_choices],
                &mut weight_buf[..n_choices],
            )
        } else {
            good_vec = vec![0usize; n_choices];
            bad_vec = vec![0usize; n_choices];
            weight_vec = vec![0.0f64; n_choices];
            (&mut good_vec, &mut bad_vec, &mut weight_vec)
        };

    // Out-of-range indices are ignored (`get_mut` returns `None` for them).
    for &idx in good_indices {
        if let Some(count) = good_counts.get_mut(idx) {
            *count += 1;
        }
    }
    for &idx in bad_indices {
        if let Some(count) = bad_counts.get_mut(idx) {
            *count += 1;
        }
    }

    let good_total = good_indices.len() as f64 + n_choices as f64;
    let bad_total = bad_indices.len() as f64 + n_choices as f64;
    for ((weight, &good), &bad) in weights
        .iter_mut()
        .zip(good_counts.iter())
        .zip(bad_counts.iter())
    {
        let l_prob = (good as f64 + 1.0) / good_total;
        let g_prob = (bad as f64 + 1.0) / bad_total;
        *weight = l_prob / g_prob;
    }

    let total_weight: f64 = weights.iter().sum();
    let threshold = rng.f64() * total_weight;
    let mut cumulative = 0.0;
    for (i, &w) in weights.iter().enumerate() {
        cumulative += w;
        if cumulative >= threshold {
            return i;
        }
    }
    // Defensive fallback in case floating-point accumulation leaves `cumulative` below
    // `threshold`. `n_choices >= 1` is guaranteed by the assert above.
    n_choices - 1
}