use core::any::Any;
use std::collections::{HashMap, VecDeque};
use std::sync::Arc;
use parking_lot::{Mutex, RwLock};
use crate::param::ParamValue;
use crate::parameter::ParamId;
use crate::pruner::{NopPruner, Pruner};
use crate::sampler::random::RandomSampler;
use crate::sampler::{CompletedTrial, Sampler};
use crate::trial::Trial;
use crate::types::{Direction, TrialState};
mod analysis;
mod builder;
mod export;
mod iter;
mod optimize;
mod persistence;
#[cfg(feature = "async")]
mod async_impl;
pub use builder::StudyBuilder;
#[cfg(feature = "serde")]
pub use persistence::StudySnapshot;
/// An optimization study: owns the sampler, pruner and trial storage, and hands out
/// new trials.
///
/// `V` is the objective value type and defaults to `f64`.
pub struct Study<V: PartialOrd = f64> {
    /// Whether objective values are minimized or maximized.
    pub(crate) direction: Direction,
    /// Sampler shared (via `Arc`) with the trial factory.
    pub(crate) sampler: Arc<dyn Sampler>,
    /// Pruner shared (via `Arc`) with the trial factory.
    pub(crate) pruner: Arc<dyn Pruner>,
    /// Storage backend that completed trials are pushed into.
    pub(crate) storage: Arc<dyn crate::storage::Storage<V>>,
    /// Optional constructor for new trials, keyed by trial id; when `None`, trials are
    /// built from their id alone.
    pub(crate) trial_factory: Option<Arc<dyn Fn(u64) -> Trial + Send + Sync>>,
    /// FIFO queue of parameter sets to pin onto upcoming trials.
    pub(crate) enqueued_params: Arc<Mutex<VecDeque<HashMap<ParamId, ParamValue>>>>,
}
impl<V> Study<V>
where
V: PartialOrd,
{
#[must_use]
pub fn new(direction: Direction) -> Self
where
V: Send + Sync + 'static,
{
Self::with_sampler(direction, RandomSampler::new())
}
#[must_use]
pub fn builder() -> StudyBuilder<V> {
StudyBuilder::new()
}
#[must_use]
pub fn minimize(sampler: impl Sampler + 'static) -> Self
where
V: Send + Sync + 'static,
{
Self::with_sampler(Direction::Minimize, sampler)
}
#[must_use]
pub fn maximize(sampler: impl Sampler + 'static) -> Self
where
V: Send + Sync + 'static,
{
Self::with_sampler(Direction::Maximize, sampler)
}
pub fn with_sampler(direction: Direction, sampler: impl Sampler + 'static) -> Self
where
V: Send + Sync + 'static,
{
Self::with_sampler_and_storage(
direction,
sampler,
crate::storage::MemoryStorage::<V>::new(),
)
}
pub(crate) fn make_trial_factory(
sampler: &Arc<dyn Sampler>,
storage: &Arc<dyn crate::storage::Storage<V>>,
pruner: &Arc<dyn Pruner>,
) -> Option<Arc<dyn Fn(u64) -> Trial + Send + Sync>>
where
V: 'static,
{
let trials_arc = storage.trials_arc();
let any_ref: &dyn Any = trials_arc;
let f64_trials: Option<&Arc<RwLock<Vec<CompletedTrial<f64>>>>> = any_ref.downcast_ref();
f64_trials.map(|trials| {
let sampler = Arc::clone(sampler);
let trials = Arc::clone(trials);
let pruner = Arc::clone(pruner);
let factory: Arc<dyn Fn(u64) -> Trial + Send + Sync> = Arc::new(move |id| {
Trial::with_sampler(
id,
Arc::clone(&sampler),
Arc::clone(&trials),
Arc::clone(&pruner),
)
});
factory
})
}
pub fn with_sampler_and_storage(
direction: Direction,
sampler: impl Sampler + 'static,
storage: impl crate::storage::Storage<V> + 'static,
) -> Self
where
V: 'static,
{
let sampler: Arc<dyn Sampler> = Arc::new(sampler);
let pruner: Arc<dyn Pruner> = Arc::new(NopPruner);
let storage: Arc<dyn crate::storage::Storage<V>> = Arc::new(storage);
let trial_factory = Self::make_trial_factory(&sampler, &storage, &pruner);
Self {
direction,
sampler,
pruner,
storage,
trial_factory,
enqueued_params: Arc::new(Mutex::new(VecDeque::new())),
}
}
#[must_use]
pub fn direction(&self) -> Direction {
self.direction
}
pub fn with_sampler_and_pruner(
direction: Direction,
sampler: impl Sampler + 'static,
pruner: impl Pruner + 'static,
) -> Self
where
V: Send + Sync + 'static,
{
let sampler: Arc<dyn Sampler> = Arc::new(sampler);
let pruner: Arc<dyn Pruner> = Arc::new(pruner);
let storage: Arc<dyn crate::storage::Storage<V>> =
Arc::new(crate::storage::MemoryStorage::<V>::new());
let trial_factory = Self::make_trial_factory(&sampler, &storage, &pruner);
Self {
direction,
sampler,
pruner,
storage,
trial_factory,
enqueued_params: Arc::new(Mutex::new(VecDeque::new())),
}
}
pub fn set_sampler(&mut self, sampler: impl Sampler + 'static)
where
V: 'static,
{
self.sampler = Arc::new(sampler);
self.trial_factory = Self::make_trial_factory(&self.sampler, &self.storage, &self.pruner);
}
pub fn set_pruner(&mut self, pruner: impl Pruner + 'static)
where
V: 'static,
{
self.pruner = Arc::new(pruner);
self.trial_factory = Self::make_trial_factory(&self.sampler, &self.storage, &self.pruner);
}
#[must_use]
pub fn pruner(&self) -> &dyn Pruner {
&*self.pruner
}
pub fn enqueue(&self, params: HashMap<ParamId, ParamValue>) {
self.enqueued_params.lock().push_back(params);
}
#[cfg(feature = "tracing")]
pub(crate) fn best_id(&self, trials: &[CompletedTrial<V>]) -> Option<u64> {
let direction = self.direction;
trials
.iter()
.filter(|t| t.state == TrialState::Complete)
.max_by(|a, b| Self::compare_trials(a, b, direction))
.map(|t| t.id)
}
#[must_use]
pub fn n_enqueued(&self) -> usize {
self.enqueued_params.lock().len()
}
pub(crate) fn next_trial_id(&self) -> u64 {
self.storage.next_trial_id()
}
#[must_use]
pub fn create_trial(&self) -> Trial {
self.storage.refresh();
let id = self.next_trial_id();
let mut trial = if let Some(factory) = &self.trial_factory {
factory(id)
} else {
Trial::new(id)
};
if let Some(fixed_params) = self.enqueued_params.lock().pop_front() {
trial.set_fixed_params(fixed_params);
}
trial
}
pub fn complete_trial(&self, trial: Trial, value: V) {
let completed = trial.into_completed(value, TrialState::Complete);
self.storage.push(completed);
}
pub fn fail_trial(&self, mut trial: Trial, _error: impl ToString) {
trial.set_failed();
}
#[must_use]
pub fn ask(&self) -> Trial {
self.create_trial()
}
pub fn tell(&self, trial: Trial, value: core::result::Result<V, impl ToString>) {
match value {
Ok(v) => self.complete_trial(trial, v),
Err(e) => self.fail_trial(trial, e),
}
}
pub fn prune_trial(&self, trial: Trial)
where
V: Default,
{
let completed = trial.into_completed(V::default(), TrialState::Pruned);
self.storage.push(completed);
}
#[must_use]
pub fn trials(&self) -> Vec<CompletedTrial<V>>
where
V: Clone,
{
self.storage.trials_arc().read().clone()
}
#[must_use]
pub fn n_trials(&self) -> usize {
self.storage
.trials_arc()
.read()
.iter()
.filter(|t| t.state == TrialState::Complete)
.count()
}
#[must_use]
pub fn n_pruned_trials(&self) -> usize {
self.storage
.trials_arc()
.read()
.iter()
.filter(|t| t.state == TrialState::Pruned)
.count()
}
pub(crate) fn compare_trials(
a: &CompletedTrial<V>,
b: &CompletedTrial<V>,
direction: Direction,
) -> core::cmp::Ordering {
match (a.is_feasible(), b.is_feasible()) {
(true, false) => core::cmp::Ordering::Greater,
(false, true) => core::cmp::Ordering::Less,
(false, false) => {
let va: f64 = a.constraints.iter().map(|c| c.max(0.0)).sum();
let vb: f64 = b.constraints.iter().map(|c| c.max(0.0)).sum();
vb.partial_cmp(&va).unwrap_or(core::cmp::Ordering::Equal)
}
(true, true) => {
let ordering = a.value.partial_cmp(&b.value);
match direction {
Direction::Minimize => {
ordering.map_or(core::cmp::Ordering::Equal, core::cmp::Ordering::reverse)
}
Direction::Maximize => ordering.unwrap_or(core::cmp::Ordering::Equal),
}
}
}
}
}
impl<V: PartialOrd + Send + Sync + 'static> Study<V> {
    /// Creates a study from an explicit sampler, pruner and storage backend.
    ///
    /// The trial factory is derived from the three components via
    /// `make_trial_factory`, and the enqueued-parameter queue starts empty.
    #[must_use]
    pub fn with_sampler_pruner_and_storage(
        direction: Direction,
        sampler: impl Sampler + 'static,
        pruner: impl Pruner + 'static,
        storage: impl crate::storage::Storage<V> + 'static,
    ) -> Self {
        let sampler: Arc<dyn Sampler> = Arc::new(sampler);
        let pruner: Arc<dyn Pruner> = Arc::new(pruner);
        let storage: Arc<dyn crate::storage::Storage<V>> = Arc::new(storage);
        let trial_factory = Self::make_trial_factory(&sampler, &storage, &pruner);
        Self {
            direction,
            sampler,
            pruner,
            storage,
            trial_factory,
            enqueued_params: Arc::new(Mutex::new(VecDeque::new())),
        }
    }
}
pub(super) fn is_trial_pruned<E: 'static>(e: &E) -> bool {
let any: &dyn Any = e;
if let Some(err) = any.downcast_ref::<crate::Error>() {
matches!(err, crate::Error::TrialPruned)
} else {
any.downcast_ref::<crate::error::TrialPruned>().is_some()
}
}