use std::collections::HashMap;
use crate::distribution::Distribution;
use crate::param::ParamValue;
use crate::parameter::ParamId;
use crate::sampler::{CompletedTrial, PendingTrial};
use super::{ConstantLiarStrategy, MultivariateTpeSampler};
impl MultivariateTpeSampler {
#[must_use]
pub fn impute_pending_trials(
&self,
pending_trials: &[PendingTrial],
completed_trials: &[CompletedTrial],
) -> Vec<CompletedTrial> {
let mut result: Vec<CompletedTrial> = completed_trials.to_vec();
if matches!(self.constant_liar, ConstantLiarStrategy::None) || pending_trials.is_empty() {
return result;
}
let imputed_value = self.compute_imputation_value(completed_trials);
for pending in pending_trials {
result.push(CompletedTrial::new(
pending.id,
pending.params.clone(),
pending.distributions.clone(),
HashMap::new(),
imputed_value,
));
}
result
}
#[allow(clippy::cast_precision_loss)]
pub(crate) fn compute_imputation_value(&self, completed_trials: &[CompletedTrial]) -> f64 {
match self.constant_liar {
ConstantLiarStrategy::None => 0.0, ConstantLiarStrategy::Mean => {
if completed_trials.is_empty() {
0.0
} else {
let sum: f64 = completed_trials.iter().map(|t| t.value).sum();
sum / completed_trials.len() as f64
}
}
ConstantLiarStrategy::Best => {
completed_trials
.iter()
.map(|t| t.value)
.fold(f64::INFINITY, f64::min)
}
ConstantLiarStrategy::Worst => {
completed_trials
.iter()
.map(|t| t.value)
.fold(f64::NEG_INFINITY, f64::max)
}
ConstantLiarStrategy::Custom(v) => v,
}
}
#[must_use]
pub fn filter_trials<'a>(
&self,
history: &'a [CompletedTrial],
search_space: &HashMap<ParamId, Distribution>,
) -> Vec<&'a CompletedTrial> {
history
.iter()
.filter(|trial| {
search_space
.keys()
.all(|param_id| trial.params.contains_key(param_id))
})
.collect()
}
#[allow(
clippy::cast_precision_loss,
clippy::cast_possible_truncation,
clippy::cast_sign_loss
)]
#[must_use]
pub fn split_trials<'a>(
&self,
trials: &[&'a CompletedTrial],
) -> (Vec<&'a CompletedTrial>, Vec<&'a CompletedTrial>) {
if trials.is_empty() {
return (vec![], vec![]);
}
let mut sorted_indices: Vec<usize> = (0..trials.len()).collect();
sorted_indices.sort_by(|&a, &b| {
trials[a]
.value
.partial_cmp(&trials[b].value)
.unwrap_or(core::cmp::Ordering::Equal)
});
let gamma = self
.gamma_strategy
.gamma(trials.len())
.clamp(f64::EPSILON, 1.0 - f64::EPSILON);
let n_good = ((trials.len() as f64 * gamma).ceil() as usize)
.max(1)
.min(trials.len().saturating_sub(1));
if trials.len() == 1 {
return (vec![trials[0]], vec![]);
}
let good: Vec<_> = sorted_indices[..n_good]
.iter()
.map(|&i| trials[i])
.collect();
let bad: Vec<_> = sorted_indices[n_good..]
.iter()
.map(|&i| trials[i])
.collect();
(good, bad)
}
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn extract_observations(
&self,
trials: &[&CompletedTrial],
param_order: &[ParamId],
) -> Vec<Vec<f64>> {
trials
.iter()
.map(|trial| {
param_order
.iter()
.filter_map(|param_id| {
trial.params.get(param_id).and_then(|value| match value {
crate::param::ParamValue::Float(f) => Some(*f),
crate::param::ParamValue::Int(i) => Some(*i as f64),
crate::param::ParamValue::Categorical(_) => None, })
})
.collect()
})
.collect()
}
pub(crate) fn extract_categorical_indices(
trials: &[&CompletedTrial],
param_id: ParamId,
) -> Vec<usize> {
trials
.iter()
.filter_map(|trial| {
trial.params.get(¶m_id).and_then(|value| {
if let ParamValue::Categorical(idx) = value {
Some(*idx)
} else {
None
}
})
})
.collect()
}
}