use std::{
collections::{btree_set::Iter, BTreeSet},
fmt::Debug,
iter::Copied,
};
/// A single evaluated `parameter` paired with the `metric` it scored.
///
/// The derived ordering is lexicographic in field declaration order: `metric`
/// is compared first and `parameter` only breaks ties. Collections of trials
/// therefore sort by metric, so the field order must not be changed.
#[derive(Debug, Clone, Copy, Hash, Ord, PartialOrd, Eq, PartialEq)]
pub struct Trial<P, M> {
    /// Score recorded for `parameter`; compared first when ordering trials.
    pub metric: M,
    /// The parameter value that was evaluated; tie-breaker in the ordering.
    pub parameter: P,
}
#[derive(Debug)]
pub struct Trials<P, M> {
by_metric: BTreeSet<Trial<P, M>>,
by_parameter: BTreeSet<P>,
}
impl<P, M> Trials<P, M> {
pub const fn new() -> Self {
Self {
by_metric: BTreeSet::new(),
by_parameter: BTreeSet::new(),
}
}
pub fn len(&self) -> usize {
self.by_parameter.len()
}
pub fn contains(&self, parameter: &P) -> bool
where
P: Ord,
{
self.by_parameter.contains(parameter)
}
pub fn iter_parameters(&self) -> Copied<Iter<P>>
where
P: Copy,
{
self.by_parameter.iter().copied()
}
pub fn insert(&mut self, trial: Trial<P, M>) -> bool
where
P: Copy + Ord,
M: Ord,
{
if self.by_parameter.insert(trial.parameter) {
assert!(self.by_metric.insert(trial));
assert_eq!(self.by_parameter.len(), self.by_metric.len());
true
} else {
false
}
}
pub fn best(&self) -> Option<&Trial<P, M>>
where
P: Ord,
M: Ord,
{
self.by_metric.first()
}
pub fn worst(&self) -> Option<&Trial<P, M>>
where
P: Ord,
M: Ord,
{
self.by_metric.last()
}
pub fn pop_best(&mut self) -> Option<Trial<P, M>>
where
P: Ord,
M: Ord,
{
let best_trial = self.by_metric.pop_first()?;
self.remove_parameter(&best_trial.parameter);
Some(best_trial)
}
pub fn pop_worst(&mut self) -> Option<Trial<P, M>>
where
P: Ord,
M: Ord,
{
let worst_trial = self.by_metric.pop_last()?;
self.remove_parameter(&worst_trial.parameter);
Some(worst_trial)
}
fn remove_parameter(&mut self, parameter: &P)
where
P: Ord,
M: Ord,
{
assert!(self.by_parameter.remove(parameter));
assert_eq!(self.by_parameter.len(), self.by_metric.len());
}
}
#[cfg(test)]
mod tests {
    // `super` keeps the tests working if this module is moved or renamed.
    use super::{Trial, Trials};

    /// Trials order by metric first; the parameter only breaks ties.
    #[test]
    fn ordering_ok() {
        assert!(
            Trial {
                metric: 42,
                parameter: 1,
            } < Trial {
                metric: 43,
                parameter: 0,
            }
        );
        assert!(
            Trial {
                metric: 42,
                parameter: 0,
            } < Trial {
                metric: 42,
                parameter: 1,
            }
        );
    }

    /// Insertion, duplicate rejection and `pop_worst` keep both indices in sync.
    #[test]
    fn trials_ok() {
        let mut trials = Trials::new();
        assert!(trials.insert(Trial {
            metric: 42,
            parameter: 1,
        }));
        assert_eq!(trials.len(), 1);
        assert_eq!(trials.iter_parameters().collect::<Vec<_>>(), [1]);
        assert!(trials.insert(Trial {
            metric: 41,
            parameter: 2,
        }));
        assert_eq!(trials.len(), 2);
        assert_eq!(trials.iter_parameters().collect::<Vec<_>>(), [1, 2]);
        assert!(!trials.insert(Trial {
            metric: 41,
            parameter: 2,
        }));
        assert_eq!(
            trials.pop_worst(),
            Some(Trial {
                metric: 42,
                parameter: 1
            })
        );
        assert_eq!(trials.len(), 1);
        assert_eq!(trials.iter_parameters().collect::<Vec<_>>(), [2]);
    }

    /// `best`/`worst`/`pop_best`/`contains`, including the empty collection.
    #[test]
    fn best_worst_contains_ok() {
        let mut trials: Trials<i32, i32> = Trials::new();
        assert_eq!(trials.best(), None);
        assert_eq!(trials.worst(), None);
        assert_eq!(trials.pop_best(), None);
        assert_eq!(trials.pop_worst(), None);

        assert!(trials.insert(Trial {
            metric: 5,
            parameter: 10,
        }));
        assert!(trials.insert(Trial {
            metric: 3,
            parameter: 20,
        }));
        assert!(trials.contains(&10));
        assert!(!trials.contains(&30));
        assert_eq!(
            trials.best(),
            Some(&Trial {
                metric: 3,
                parameter: 20,
            })
        );
        assert_eq!(
            trials.worst(),
            Some(&Trial {
                metric: 5,
                parameter: 10,
            })
        );
        assert_eq!(
            trials.pop_best(),
            Some(Trial {
                metric: 3,
                parameter: 20,
            })
        );
        assert!(!trials.contains(&20));
        assert_eq!(trials.len(), 1);

        // A popped parameter can be inserted again.
        assert!(trials.insert(Trial {
            metric: 1,
            parameter: 20,
        }));
        assert_eq!(trials.len(), 2);
    }
}