hyperopt/optimizer/
trial.rs

1use std::{
2    collections::{btree_set::Iter, BTreeSet},
3    fmt::Debug,
4    iter::Copied,
5};
6
/// Single trial in the optimizer: a parameter together with its metric.
///
/// Note that trials are ordered first by metric, and then by parameter
/// (the derived ordering compares the fields in declaration order).
#[derive(Clone, Copy, Debug, Ord, PartialOrd, Eq, PartialEq)]
pub struct Trial<P, M> {
    /// Metric of the trial; the primary sort key.
    pub metric: M,

    /// Parameter of the trial; breaks ties between equal metrics.
    pub parameter: P,
}
15
16/// Ordered collection of trials.
17///
18/// Here be dragons! 🐉 It basically maintains two inner collections:
19///
20/// - Set of trials (a pair of parameter and metric, ordered by metric): that allows tracking of
21///   the best (worst) trials
22/// - Set of parameters, ordered by parameter itself: that allows to estimate bandwidth for each trial
23///
24/// All this is for the sake of insertion and removal in `O(log n)` time.
25///
26/// The optimizer **should not** try the same parameter twice.
27#[derive(Debug)]
28pub struct Trials<P, M> {
29    by_metric: BTreeSet<Trial<P, M>>,
30    by_parameter: BTreeSet<P>,
31}
32
33impl<P, M> Trials<P, M> {
34    /// Instantiate a new empty trial collection.
35    pub const fn new() -> Self {
36        Self {
37            by_metric: BTreeSet::new(),
38            by_parameter: BTreeSet::new(),
39        }
40    }
41
42    pub fn len(&self) -> usize {
43        self.by_parameter.len()
44    }
45
46    pub fn contains(&self, parameter: &P) -> bool
47    where
48        P: Ord,
49    {
50        self.by_parameter.contains(parameter)
51    }
52
53    /// Iterate parameters of the trials in ascending order.
54    pub fn iter_parameters(&self) -> Copied<Iter<P>>
55    where
56        P: Copy,
57    {
58        self.by_parameter.iter().copied()
59    }
60
61    /// Push the trial to the collection.
62    ///
63    /// **Repetitive parameters will be ignored.**
64    ///
65    /// # Returns
66    ///
67    /// [`true`], if the trial was inserted, and [`false`] if it was ignored as repetitive.
68    pub fn insert(&mut self, trial: Trial<P, M>) -> bool
69    where
70        P: Copy + Ord,
71        M: Ord,
72    {
73        if self.by_parameter.insert(trial.parameter) {
74            assert!(self.by_metric.insert(trial));
75            assert_eq!(self.by_parameter.len(), self.by_metric.len());
76            true
77        } else {
78            false
79        }
80    }
81
82    /// Retrieve the best trial.
83    pub fn best(&self) -> Option<&Trial<P, M>>
84    where
85        P: Ord,
86        M: Ord,
87    {
88        self.by_metric.first()
89    }
90
91    /// Retrieve the worst trial.
92    pub fn worst(&self) -> Option<&Trial<P, M>>
93    where
94        P: Ord,
95        M: Ord,
96    {
97        self.by_metric.last()
98    }
99
100    /// Pop the best trial.
101    pub fn pop_best(&mut self) -> Option<Trial<P, M>>
102    where
103        P: Ord,
104        M: Ord,
105    {
106        let best_trial = self.by_metric.pop_first()?;
107        self.remove_parameter(&best_trial.parameter);
108        Some(best_trial)
109    }
110
111    /// Pop the worst trial.
112    pub fn pop_worst(&mut self) -> Option<Trial<P, M>>
113    where
114        P: Ord,
115        M: Ord,
116    {
117        let worst_trial = self.by_metric.pop_last()?;
118        self.remove_parameter(&worst_trial.parameter);
119        Some(worst_trial)
120    }
121
122    /// Remove the parameter and ensure the variants.
123    fn remove_parameter(&mut self, parameter: &P)
124    where
125        P: Ord,
126        M: Ord,
127    {
128        assert!(self.by_parameter.remove(parameter));
129        assert_eq!(self.by_parameter.len(), self.by_metric.len());
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::{Trial, Trials};

    /// Shorthand constructor to keep the assertions readable.
    const fn trial(metric: u32, parameter: u32) -> Trial<u32, u32> {
        Trial { metric, parameter }
    }

    #[test]
    fn ordering_ok() {
        assert!(trial(42, 1) < trial(43, 0));
    }

    #[test]
    fn ordering_ties_broken_by_parameter_ok() {
        assert!(trial(42, 0) < trial(42, 1));
    }

    #[test]
    fn trials_ok() {
        let mut trials = Trials::new();

        assert!(trials.insert(trial(42, 1)));
        assert_eq!(trials.len(), 1);
        assert_eq!(trials.iter_parameters().collect::<Vec<_>>(), [1]);

        assert!(trials.insert(trial(41, 2)));
        assert_eq!(trials.len(), 2);
        assert_eq!(trials.iter_parameters().collect::<Vec<_>>(), [1, 2]);

        assert!(!trials.insert(trial(41, 2)));

        assert_eq!(trials.pop_worst(), Some(trial(42, 1)));
        assert_eq!(trials.len(), 1);
        assert_eq!(trials.iter_parameters().collect::<Vec<_>>(), [2]);
    }

    #[test]
    fn empty_trials_ok() {
        let mut trials: Trials<u32, u32> = Trials::new();
        assert_eq!(trials.len(), 0);
        assert_eq!(trials.best(), None);
        assert_eq!(trials.worst(), None);
        assert_eq!(trials.pop_best(), None);
        assert_eq!(trials.pop_worst(), None);
        assert!(!trials.contains(&0));
    }

    #[test]
    fn best_and_worst_ok() {
        let mut trials = Trials::new();
        assert!(trials.insert(trial(3, 10)));
        assert!(trials.insert(trial(1, 30)));
        assert!(trials.insert(trial(2, 20)));
        assert!(trials.contains(&20));
        assert!(!trials.contains(&40));

        assert_eq!(trials.best(), Some(&trial(1, 30)));
        assert_eq!(trials.worst(), Some(&trial(3, 10)));

        assert_eq!(trials.pop_best(), Some(trial(1, 30)));
        assert_eq!(trials.len(), 2);
        assert!(!trials.contains(&30));
        assert_eq!(trials.iter_parameters().collect::<Vec<_>>(), [10, 20]);

        // A repeated parameter is ignored even when its metric would make it the best.
        assert!(!trials.insert(trial(0, 20)));
        assert_eq!(trials.best(), Some(&trial(2, 20)));
    }
}
183}