use core::sync::atomic::{AtomicU64, Ordering};
use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::RwLock;
use crate::distribution::Distribution;
use crate::param::ParamValue;
use crate::parameter::{ParamId, Parameter};
use crate::pruner::NopPruner;
use crate::sampler::random::RandomMultiObjectiveSampler;
use crate::sampler::{CompletedTrial, Sampler};
use crate::trial::{AttrValue, Trial};
use crate::types::{Direction, TrialState};
/// Record of a single trial in a multi-objective study.
///
/// All fields are public; the accessor methods on this type read from them
/// directly.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct MultiObjectiveTrial {
/// Identifier of the trial.
pub id: u64,
/// Parameter values recorded for this trial, keyed by parameter id.
pub params: HashMap<ParamId, ParamValue>,
/// Distribution associated with each parameter id.
pub distributions: HashMap<ParamId, Distribution>,
/// Label associated with each parameter id.
pub param_labels: HashMap<ParamId, String>,
/// Objective values reported for this trial.
// NOTE(review): presumably one entry per study direction — the study
// rejects mismatched lengths before recording, but nothing on this type
// enforces it.
pub values: Vec<f64>,
/// State of the trial.
pub state: TrialState,
/// User-defined attributes, keyed by name.
pub user_attrs: HashMap<String, AttrValue>,
/// Constraint values; `is_feasible` treats the trial as feasible when every
/// entry is `<= 0.0`.
///
/// With the `serde` feature enabled, a missing field deserializes to an
/// empty vector.
#[cfg_attr(feature = "serde", serde(default))]
pub constraints: Vec<f64>,
}
impl MultiObjectiveTrial {
    /// Returns the value stored for `param`, converted to the parameter's
    /// own value type.
    ///
    /// Returns `None` when nothing is stored under `param.id()`.
    ///
    /// # Panics
    ///
    /// Panics if `param.cast_param_value` cannot convert the stored value,
    /// i.e. the stored value is incompatible with the parameter's type.
    pub fn get<P: Parameter>(&self, param: &P) -> Option<P::Value> {
        // The lookup key must be borrowed: `HashMap::get` takes `&ParamId`.
        self.params.get(&param.id()).map(|v| {
            param
                .cast_param_value(v)
                .expect("parameter type mismatch: stored value incompatible with parameter")
        })
    }

    /// Reports whether every constraint value is `<= 0.0`.
    ///
    /// A trial with no constraints is feasible. A `NaN` constraint makes the
    /// trial infeasible because the comparison evaluates to `false`.
    #[must_use]
    pub fn is_feasible(&self) -> bool {
        self.constraints.iter().all(|&c| c <= 0.0)
    }

    /// Returns the user attribute stored under `key`, if any.
    #[must_use]
    pub fn user_attr(&self, key: &str) -> Option<&AttrValue> {
        self.user_attrs.get(key)
    }

    /// Returns all user attributes of this trial.
    #[must_use]
    pub fn user_attrs(&self) -> &HashMap<String, AttrValue> {
        &self.user_attrs
    }
}
/// Strategy for choosing parameter values in a multi-objective study.
///
/// Implementations must be `Send + Sync`.
pub trait MultiObjectiveSampler: Send + Sync {
/// Produces a value for one parameter of the trial identified by `trial_id`.
///
/// * `distribution` - distribution of the parameter being sampled.
/// * `trial_id` - id of the trial requesting the value.
/// * `history` - multi-objective trials recorded so far.
/// * `directions` - the study's objective directions.
fn sample(
&self,
distribution: &Distribution,
trial_id: u64,
history: &[MultiObjectiveTrial],
directions: &[Direction],
) -> ParamValue;
}
/// Adapter that exposes a [`MultiObjectiveSampler`] through the [`Sampler`]
/// trait.
struct MoSamplerBridge {
/// The wrapped multi-objective sampler that performs the actual sampling.
inner: Arc<dyn MultiObjectiveSampler>,
/// Shared multi-objective trial history handed to `inner` on every call.
history: Arc<RwLock<Vec<MultiObjectiveTrial>>>,
/// Objective directions handed to `inner` on every call.
directions: Vec<Direction>,
}
impl Sampler for MoSamplerBridge {
fn sample(
&self,
distribution: &Distribution,
trial_id: u64,
_history: &[CompletedTrial],
) -> ParamValue {
let mo_history = self.history.read();
self.inner
.sample(distribution, trial_id, &mo_history, &self.directions)
}
}
/// A study that optimizes several objectives at once.
pub struct MultiObjectiveStudy {
/// Optimization direction of each objective; its length is the number of
/// objectives.
directions: Vec<Direction>,
/// Sampler used to pick parameter values for new trials.
sampler: Arc<dyn MultiObjectiveSampler>,
/// Trials recorded by the study; shared with the sampler bridge of every
/// trial the study creates.
completed_trials: Arc<RwLock<Vec<MultiObjectiveTrial>>>,
/// Id assigned to the next created trial; incremented on each creation.
next_trial_id: AtomicU64,
}
impl MultiObjectiveStudy {
#[must_use]
pub fn new(directions: Vec<Direction>) -> Self {
Self {
directions,
sampler: Arc::new(RandomMultiObjectiveSampler::new()),
completed_trials: Arc::new(RwLock::new(Vec::new())),
next_trial_id: AtomicU64::new(0),
}
}
#[must_use]
pub fn with_sampler(
directions: Vec<Direction>,
sampler: impl MultiObjectiveSampler + 'static,
) -> Self {
Self {
directions,
sampler: Arc::new(sampler),
completed_trials: Arc::new(RwLock::new(Vec::new())),
next_trial_id: AtomicU64::new(0),
}
}
#[must_use]
pub fn directions(&self) -> &[Direction] {
&self.directions
}
#[must_use]
pub fn n_objectives(&self) -> usize {
self.directions.len()
}
#[must_use]
pub fn n_trials(&self) -> usize {
self.completed_trials.read().len()
}
#[must_use]
pub fn trials(&self) -> Vec<MultiObjectiveTrial> {
self.completed_trials.read().clone()
}
#[must_use]
pub fn pareto_front(&self) -> Vec<MultiObjectiveTrial> {
let trials = self.completed_trials.read();
let complete: Vec<_> = trials
.iter()
.filter(|t| t.state == TrialState::Complete)
.collect();
if complete.is_empty() {
return Vec::new();
}
let values: Vec<Vec<f64>> = complete.iter().map(|t| t.values.clone()).collect();
let fronts = crate::pareto::fast_non_dominated_sort(&values, &self.directions);
if fronts.is_empty() {
return Vec::new();
}
fronts[0].iter().map(|&i| complete[i].clone()).collect()
}
fn create_trial(&self) -> Trial {
let id = self.next_trial_id.fetch_add(1, Ordering::SeqCst);
let bridge: Arc<dyn Sampler> = Arc::new(MoSamplerBridge {
inner: Arc::clone(&self.sampler),
history: Arc::clone(&self.completed_trials),
directions: self.directions.clone(),
});
let dummy_history: Arc<RwLock<Vec<CompletedTrial<f64>>>> =
Arc::new(RwLock::new(Vec::new()));
Trial::with_sampler(id, bridge, dummy_history, Arc::new(NopPruner))
}
fn complete_trial(&self, trial: Trial, values: Vec<f64>) {
let mo_trial = trial.into_multi_objective_trial(values, TrialState::Complete);
self.completed_trials.write().push(mo_trial);
}
fn fail_trial(trial: &mut Trial) {
trial.set_failed();
}
pub fn ask(&self) -> Trial {
self.create_trial()
}
pub fn tell(
&self,
mut trial: Trial,
result: core::result::Result<Vec<f64>, impl ToString>,
) -> crate::Result<()> {
if let Ok(values) = result {
if values.len() != self.directions.len() {
return Err(crate::Error::ObjectiveDimensionMismatch {
expected: self.directions.len(),
got: values.len(),
});
}
self.complete_trial(trial, values);
} else {
Self::fail_trial(&mut trial);
}
Ok(())
}
pub fn optimize<F, E>(&self, n_trials: usize, mut objective: F) -> crate::Result<()>
where
F: FnMut(&mut Trial) -> core::result::Result<Vec<f64>, E>,
E: ToString,
{
for _ in 0..n_trials {
let mut trial = self.create_trial();
match objective(&mut trial) {
Ok(values) => {
if values.len() != self.directions.len() {
return Err(crate::Error::ObjectiveDimensionMismatch {
expected: self.directions.len(),
got: values.len(),
});
}
self.complete_trial(trial, values);
}
Err(_) => {
Self::fail_trial(&mut trial);
}
}
}
let has_complete = self
.completed_trials
.read()
.iter()
.any(|t| t.state == TrialState::Complete);
if !has_complete {
return Err(crate::Error::NoCompletedTrials);
}
Ok(())
}
}