use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::RwLock;
use crate::distribution::Distribution;
use crate::error::{Error, Result};
use crate::multi_objective::MultiObjectiveTrial;
use crate::param::ParamValue;
use crate::parameter::{ParamId, Parameter};
use crate::pruner::Pruner;
use crate::sampler::{CompletedTrial, Sampler};
use crate::types::TrialState;
/// A user-defined attribute value that can be attached to a trial.
///
/// Values are normally produced through the `From` conversions below, for example
/// via `Trial::set_user_attr`, which accepts anything implementing `Into<AttrValue>`.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum AttrValue {
    /// A floating-point value.
    Float(f64),
    /// A signed integer value.
    Int(i64),
    /// An owned text value.
    String(String),
    /// A boolean flag.
    Bool(bool),
}

impl From<f64> for AttrValue {
    /// Wraps the float as [`AttrValue::Float`].
    fn from(value: f64) -> Self {
        AttrValue::Float(value)
    }
}

impl From<i64> for AttrValue {
    /// Wraps the integer as [`AttrValue::Int`].
    fn from(value: i64) -> Self {
        AttrValue::Int(value)
    }
}

impl From<String> for AttrValue {
    /// Takes ownership of the string as [`AttrValue::String`].
    fn from(value: String) -> Self {
        AttrValue::String(value)
    }
}

impl From<&str> for AttrValue {
    /// Copies the borrowed text into an owned [`AttrValue::String`].
    fn from(value: &str) -> Self {
        AttrValue::String(String::from(value))
    }
}

impl From<bool> for AttrValue {
    /// Wraps the flag as [`AttrValue::Bool`].
    fn from(value: bool) -> Self {
        AttrValue::Bool(value)
    }
}
/// A single evaluation of the objective.
///
/// Holds the parameters suggested so far, reported intermediate values, user
/// attributes and constraint values. A trial may optionally carry a sampler (with a
/// shared history of completed trials) and a pruner; without them, parameters are
/// drawn by a fallback random sampler and pruning is never recommended.
#[derive(Clone)]
pub struct Trial {
    /// Identifier of this trial, forwarded to the sampler and pruner.
    id: u64,
    /// Lifecycle state; starts as `TrialState::Running`.
    state: TrialState,
    /// Values suggested so far, keyed by parameter id.
    params: HashMap<ParamId, ParamValue>,
    /// Distribution each suggested parameter was drawn from; used to reject a
    /// re-suggestion of the same id with a different configuration.
    distributions: HashMap<ParamId, Distribution>,
    /// Label recorded for each suggested parameter.
    param_labels: HashMap<ParamId, String>,
    /// Sampler used for new suggestions; `None` means fall back to random sampling.
    sampler: Option<Arc<dyn Sampler>>,
    /// Completed trials shared with the sampler and pruner; read-locked while they run.
    history: Option<Arc<RwLock<Vec<CompletedTrial<f64>>>>>,
    /// `(step, value)` pairs from `report`, at most one entry per step.
    intermediate_values: Vec<(u64, f64)>,
    /// Pruner consulted by `should_prune`; `None` disables pruning.
    pruner: Option<Arc<dyn Pruner>>,
    /// Arbitrary user-attached attributes.
    user_attrs: HashMap<String, AttrValue>,
    /// Pre-assigned values used instead of sampling when a parameter is suggested.
    fixed_params: HashMap<ParamId, ParamValue>,
    /// Constraint values set via `set_constraints`.
    constraint_values: Vec<f64>,
}

impl core::fmt::Debug for Trial {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Exhaustive destructuring: adding a field to `Trial` fails to compile here
        // until the field is handled. The sampler, history and pruner are reported
        // only by presence, not by content.
        let Self {
            id,
            state,
            params,
            distributions,
            param_labels,
            sampler,
            history,
            intermediate_values,
            pruner,
            user_attrs,
            fixed_params,
            constraint_values,
        } = self;

        let mut out = f.debug_struct("Trial");
        out.field("id", id);
        out.field("state", state);
        out.field("params", params);
        out.field("distributions", distributions);
        out.field("param_labels", param_labels);
        out.field("has_sampler", &sampler.is_some());
        out.field("has_history", &history.is_some());
        out.field("intermediate_values", intermediate_values);
        out.field("has_pruner", &pruner.is_some());
        out.field("user_attrs", user_attrs);
        out.field("fixed_params", fixed_params);
        out.field("constraint_values", constraint_values);
        out.finish()
    }
}
impl Trial {
#[must_use]
pub fn new(id: u64) -> Self {
Self {
id,
state: TrialState::Running,
params: HashMap::new(),
distributions: HashMap::new(),
param_labels: HashMap::new(),
sampler: None,
history: None,
intermediate_values: Vec::new(),
pruner: None,
user_attrs: HashMap::new(),
fixed_params: HashMap::new(),
constraint_values: Vec::new(),
}
}
pub(crate) fn with_sampler(
id: u64,
sampler: Arc<dyn Sampler>,
history: Arc<RwLock<Vec<CompletedTrial<f64>>>>,
pruner: Arc<dyn Pruner>,
) -> Self {
Self {
id,
state: TrialState::Running,
params: HashMap::new(),
distributions: HashMap::new(),
param_labels: HashMap::new(),
sampler: Some(sampler),
history: Some(history),
intermediate_values: Vec::new(),
pruner: Some(pruner),
user_attrs: HashMap::new(),
fixed_params: HashMap::new(),
constraint_values: Vec::new(),
}
}
pub(crate) fn set_fixed_params(&mut self, params: HashMap<ParamId, ParamValue>) {
self.fixed_params = params;
}
fn sample_value(&self, distribution: &Distribution) -> ParamValue {
if let (Some(sampler), Some(history)) = (&self.sampler, &self.history) {
let history_guard = history.read();
sampler.sample(distribution, self.id, &history_guard)
} else {
use crate::sampler::random::RandomSampler;
let fallback = RandomSampler::new();
fallback.sample(distribution, self.id, &[])
}
}
#[must_use]
pub fn id(&self) -> u64 {
self.id
}
#[must_use]
pub fn state(&self) -> TrialState {
self.state
}
#[must_use]
pub fn params(&self) -> &HashMap<ParamId, ParamValue> {
&self.params
}
#[must_use]
pub fn distributions(&self) -> &HashMap<ParamId, Distribution> {
&self.distributions
}
#[must_use]
pub fn param_labels(&self) -> &HashMap<ParamId, String> {
&self.param_labels
}
pub fn report(&mut self, step: u64, value: f64) {
if let Some(entry) = self
.intermediate_values
.iter_mut()
.find(|(s, _)| *s == step)
{
entry.1 = value;
} else {
self.intermediate_values.push((step, value));
}
}
#[must_use]
pub fn should_prune(&self) -> bool {
let (Some(pruner), Some(history)) = (&self.pruner, &self.history) else {
return false;
};
let Some(&(step, _)) = self.intermediate_values.last() else {
return false;
};
let history_guard = history.read();
let prune = pruner.should_prune(self.id, step, &self.intermediate_values, &history_guard);
if prune {
trace_info!(trial_id = self.id, step, "pruner recommends stopping");
}
prune
}
#[must_use]
pub fn intermediate_values(&self) -> &[(u64, f64)] {
&self.intermediate_values
}
pub fn set_user_attr(&mut self, key: impl Into<String>, value: impl Into<AttrValue>) {
self.user_attrs.insert(key.into(), value.into());
}
#[must_use]
pub fn user_attr(&self, key: &str) -> Option<&AttrValue> {
self.user_attrs.get(key)
}
#[must_use]
pub fn user_attrs(&self) -> &HashMap<String, AttrValue> {
&self.user_attrs
}
pub fn set_constraints(&mut self, values: Vec<f64>) {
self.constraint_values = values;
}
#[must_use]
pub fn constraint_values(&self) -> &[f64] {
&self.constraint_values
}
pub(crate) fn set_failed(&mut self) {
self.state = TrialState::Failed;
}
pub fn suggest_param<P: Parameter>(&mut self, param: &P) -> Result<P::Value> {
param.validate()?;
let param_id = param.id();
let distribution = param.distribution();
if let Some(existing_dist) = self.distributions.get(¶m_id) {
if *existing_dist == distribution {
if let Some(value) = self.params.get(¶m_id) {
return param.cast_param_value(value);
}
}
return Err(Error::ParameterConflict {
name: param.label(),
reason: "parameter was previously sampled with different configuration or type"
.to_string(),
});
}
let value = if let Some(fixed_value) = self.fixed_params.remove(¶m_id) {
fixed_value
} else {
self.sample_value(&distribution)
};
let result = param.cast_param_value(&value)?;
trace_debug!(
trial_id = self.id,
param = %param.label(),
value = %value,
"parameter sampled"
);
self.distributions.insert(param_id, distribution);
self.params.insert(param_id, value);
self.param_labels.insert(param_id, param.label());
Ok(result)
}
pub(crate) fn into_completed<V>(self, value: V, state: TrialState) -> CompletedTrial<V> {
CompletedTrial {
id: self.id,
params: self.params,
distributions: self.distributions,
param_labels: self.param_labels,
value,
intermediate_values: self.intermediate_values,
state,
user_attrs: self.user_attrs,
constraints: self.constraint_values,
}
}
pub(crate) fn into_multi_objective_trial(
self,
values: Vec<f64>,
state: TrialState,
) -> MultiObjectiveTrial {
MultiObjectiveTrial {
id: self.id,
params: self.params,
distributions: self.distributions,
param_labels: self.param_labels,
values,
state,
user_attrs: self.user_attrs,
constraints: self.constraint_values,
}
}
}
#[cfg(test)]
// Exact float comparisons are deliberate: the assertions check values that must be
// returned verbatim (cached suggestions, single-point ranges, stored reports).
#[allow(clippy::float_cmp)]
mod tests {
    use crate::parameter::{BoolParam, CategoricalParam, FloatParam, IntParam, Parameter};
    use crate::types::TrialState;

    #[test]
    fn trial_state() {
        let trial = super::Trial::new(0);
        assert_eq!(trial.state(), TrialState::Running);
    }

    #[test]
    fn trial_params_access() {
        let x_param = FloatParam::new(0.0, 1.0);
        let n_param = IntParam::new(1, 10);
        let mut trial = super::Trial::new(0);
        x_param.suggest(&mut trial).unwrap();
        n_param.suggest(&mut trial).unwrap();
        let params = trial.params();
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn trial_debug_format() {
        let param = FloatParam::new(0.0, 1.0);
        let mut trial = super::Trial::new(42);
        param.suggest(&mut trial).unwrap();
        let debug_str = format!("{trial:?}");
        assert!(debug_str.contains("Trial"));
        assert!(debug_str.contains("42"));
        assert!(debug_str.contains("has_sampler"));
    }

    #[test]
    fn distributions_access() {
        let x_param = FloatParam::new(0.0, 1.0);
        let n_param = IntParam::new(1, 10);
        let opt_param = CategoricalParam::new(vec!["a", "b", "c"]);
        let mut trial = super::Trial::new(0);
        x_param.suggest(&mut trial).unwrap();
        n_param.suggest(&mut trial).unwrap();
        opt_param.suggest(&mut trial).unwrap();
        let dists = trial.distributions();
        assert_eq!(dists.len(), 3);
    }

    #[test]
    fn multiple_parameters_independent_caching() {
        let x_param = FloatParam::new(0.0, 1.0);
        let y_param = FloatParam::new(0.0, 1.0);
        let n_param = IntParam::new(1, 10);
        let opt_param = CategoricalParam::new(vec!["a", "b"]);
        let mut trial = super::Trial::new(0);
        let x = x_param.suggest(&mut trial).unwrap();
        let y = y_param.suggest(&mut trial).unwrap();
        let n = n_param.suggest(&mut trial).unwrap();
        let opt = opt_param.suggest(&mut trial).unwrap();
        assert_eq!(x, x_param.suggest(&mut trial).unwrap());
        assert_eq!(y, y_param.suggest(&mut trial).unwrap());
        assert_eq!(n, n_param.suggest(&mut trial).unwrap());
        assert_eq!(opt, opt_param.suggest(&mut trial).unwrap());
    }

    #[test]
    fn suggest_bool_multiple_parameters() {
        let dropout_param = BoolParam::new();
        let batchnorm_param = BoolParam::new();
        let skip_param = BoolParam::new();
        let mut trial = super::Trial::new(0);
        let a = dropout_param.suggest(&mut trial).unwrap();
        let b = batchnorm_param.suggest(&mut trial).unwrap();
        let c = skip_param.suggest(&mut trial).unwrap();
        assert_eq!(a, dropout_param.suggest(&mut trial).unwrap());
        assert_eq!(b, batchnorm_param.suggest(&mut trial).unwrap());
        assert_eq!(c, skip_param.suggest(&mut trial).unwrap());
    }

    #[test]
    fn param_name() {
        let param = FloatParam::new(0.0, 1.0).name("learning_rate");
        let mut trial = super::Trial::new(0);
        param.suggest(&mut trial).unwrap();
        let labels = trial.param_labels();
        let label = labels.values().next().unwrap();
        assert_eq!(label, "learning_rate");
    }

    #[test]
    fn step_float_snaps_to_grid() {
        let param = FloatParam::new(0.0, 1.0).step(0.25);
        let mut trial = super::Trial::new(0);
        let x = param.suggest(&mut trial).unwrap();
        let valid_values = [0.0, 0.25, 0.5, 0.75, 1.0];
        let is_valid = valid_values.iter().any(|&v| (x - v).abs() < 1e-10);
        assert!(is_valid, "stepped float {x} should snap to grid");
    }

    #[test]
    fn step_int_snaps_to_grid() {
        let param = IntParam::new(0, 100).step(25);
        let mut trial = super::Trial::new(0);
        let n = param.suggest(&mut trial).unwrap();
        assert!(
            n % 25 == 0 && (0..=100).contains(&n),
            "stepped int {n} should snap to grid"
        );
    }

    #[test]
    fn int_bounds_with_low_equals_high() {
        let mut trial = super::Trial::new(0);
        let n_param = IntParam::new(5, 5);
        let n = n_param.suggest(&mut trial).unwrap();
        assert_eq!(n, 5);
        let x_param = FloatParam::new(3.0, 3.0);
        let x = x_param.suggest(&mut trial).unwrap();
        assert_eq!(x, 3.0);
    }

    #[test]
    fn single_value_float_range() {
        let param = FloatParam::new(4.2, 4.2);
        let mut trial = super::Trial::new(0);
        let x = param.suggest(&mut trial).unwrap();
        assert!(
            (x - 4.2).abs() < f64::EPSILON,
            "single-value range should return that value"
        );
    }

    #[test]
    fn report_overwrites_existing_step_in_place() {
        let mut trial = super::Trial::new(0);
        trial.report(1, 0.5);
        trial.report(2, 0.25);
        trial.report(1, 0.75);
        assert_eq!(
            trial.intermediate_values(),
            [(1_u64, 0.75_f64), (2, 0.25)].as_slice()
        );
    }

    #[test]
    fn should_prune_is_false_without_pruner() {
        let mut trial = super::Trial::new(0);
        assert!(!trial.should_prune());
        trial.report(0, 1.0);
        assert!(!trial.should_prune());
    }

    #[test]
    fn user_attrs_store_and_overwrite() {
        let mut trial = super::Trial::new(0);
        trial.set_user_attr("lr", 0.1_f64);
        trial.set_user_attr("epochs", 10_i64);
        trial.set_user_attr("model", "resnet");
        trial.set_user_attr("augment", true);
        trial.set_user_attr("lr", 0.2_f64);
        assert_eq!(trial.user_attrs().len(), 4);
        assert_eq!(trial.user_attr("lr"), Some(&super::AttrValue::Float(0.2)));
        assert_eq!(trial.user_attr("epochs"), Some(&super::AttrValue::Int(10)));
        assert_eq!(
            trial.user_attr("model"),
            Some(&super::AttrValue::String("resnet".to_owned()))
        );
        assert_eq!(
            trial.user_attr("augment"),
            Some(&super::AttrValue::Bool(true))
        );
        assert_eq!(trial.user_attr("missing"), None);
    }

    #[test]
    fn constraints_round_trip() {
        let mut trial = super::Trial::new(0);
        assert!(trial.constraint_values().is_empty());
        trial.set_constraints(vec![-1.0, 0.5]);
        assert_eq!(trial.constraint_values(), [-1.0, 0.5].as_slice());
    }
}