hyperopt/optimizer/
trial.rs1use std::{
2 collections::{btree_set::Iter, BTreeSet},
3 fmt::Debug,
4 iter::Copied,
5};
6
/// A single evaluated point: the `parameter` that was tried and the `metric` it produced.
///
/// The derived ordering compares fields in declaration order, so trials are ordered by
/// `metric` first and only fall back to `parameter` to break ties. Do not reorder the
/// fields: [`Trials`] relies on this ordering to rank trials by metric.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Trial<P, M> {
    pub metric: M,
    pub parameter: P,
}

/// A collection of [`Trial`]s with unique parameters, indexed both by metric and by parameter.
///
/// Both internal sets always hold the same number of elements: every parameter stored in
/// `by_parameter` belongs to exactly one trial in `by_metric`. Every mutation re-checks this
/// with assertions.
#[derive(Debug)]
pub struct Trials<P, M> {
    /// Trials ordered by metric (ties broken by parameter); the first one is the best.
    by_metric: BTreeSet<Trial<P, M>>,

    /// Parameters of all stored trials, used for membership checks and ordered iteration.
    by_parameter: BTreeSet<P>,
}

impl<P, M> Trials<P, M> {
    /// Creates an empty collection.
    pub const fn new() -> Self {
        Self {
            by_metric: BTreeSet::new(),
            by_parameter: BTreeSet::new(),
        }
    }

    /// Returns the number of stored trials.
    pub fn len(&self) -> usize {
        self.by_parameter.len()
    }

    /// Returns `true` if no trials are stored.
    pub fn is_empty(&self) -> bool {
        self.by_parameter.is_empty()
    }

    /// Returns `true` if a trial with this `parameter` is already stored.
    pub fn contains(&self, parameter: &P) -> bool
    where
        P: Ord,
    {
        self.by_parameter.contains(parameter)
    }

    /// Iterates over the stored parameters in ascending order.
    pub fn iter_parameters(&self) -> Copied<Iter<'_, P>>
    where
        P: Copy,
    {
        self.by_parameter.iter().copied()
    }

    /// Stores `trial` unless a trial with the same parameter is already present.
    ///
    /// Returns `true` if the trial was stored. Returns `false` if its parameter was a
    /// duplicate; in that case the collection is left unchanged and the previously stored
    /// trial (and its metric) is kept.
    ///
    /// # Panics
    ///
    /// Panics if the two internal sets disagree, which would indicate a bug in this type.
    pub fn insert(&mut self, trial: Trial<P, M>) -> bool
    where
        P: Copy + Ord,
        M: Ord,
    {
        if !self.by_parameter.insert(trial.parameter) {
            return false;
        }
        // The parameter was new, so no trial carrying it can already be in `by_metric`.
        assert!(self.by_metric.insert(trial));
        assert_eq!(self.by_parameter.len(), self.by_metric.len());
        true
    }

    /// Returns the trial with the smallest metric, or `None` if the collection is empty.
    ///
    /// Ties on the metric are resolved in favour of the smaller parameter.
    pub fn best(&self) -> Option<&Trial<P, M>>
    where
        P: Ord,
        M: Ord,
    {
        self.by_metric.first()
    }

    /// Returns the trial with the largest metric, or `None` if the collection is empty.
    ///
    /// Ties on the metric are resolved in favour of the larger parameter.
    pub fn worst(&self) -> Option<&Trial<P, M>>
    where
        P: Ord,
        M: Ord,
    {
        self.by_metric.last()
    }

    /// Removes and returns the trial with the smallest metric, or `None` if empty.
    ///
    /// # Panics
    ///
    /// Panics if the two internal sets disagree, which would indicate a bug in this type.
    pub fn pop_best(&mut self) -> Option<Trial<P, M>>
    where
        P: Ord,
        M: Ord,
    {
        let best_trial = self.by_metric.pop_first()?;
        self.remove_parameter(&best_trial.parameter);
        Some(best_trial)
    }

    /// Removes and returns the trial with the largest metric, or `None` if empty.
    ///
    /// # Panics
    ///
    /// Panics if the two internal sets disagree, which would indicate a bug in this type.
    pub fn pop_worst(&mut self) -> Option<Trial<P, M>>
    where
        P: Ord,
        M: Ord,
    {
        let worst_trial = self.by_metric.pop_last()?;
        self.remove_parameter(&worst_trial.parameter);
        Some(worst_trial)
    }

    /// Drops `parameter` from the parameter index after its trial was removed from `by_metric`.
    ///
    /// Only the parameter index is touched here, so no bound on `M` is needed.
    fn remove_parameter(&mut self, parameter: &P)
    where
        P: Ord,
    {
        assert!(self.by_parameter.remove(parameter));
        assert_eq!(self.by_parameter.len(), self.by_metric.len());
    }
}

impl<P, M> Default for Trials<P, M> {
    /// Creates an empty collection, same as [`Trials::new`].
    fn default() -> Self {
        Self::new()
    }
}
132
#[cfg(test)]
mod tests {
    use crate::optimizer::trial::{Trial, Trials};

    /// Shorthand for an integer-valued trial, keeping the assertions below compact.
    const fn trial(metric: i32, parameter: i32) -> Trial<i32, i32> {
        Trial { metric, parameter }
    }

    #[test]
    fn ordering_ok() {
        // A smaller metric wins regardless of the parameter value.
        let lower = trial(42, 1);
        let higher = trial(43, 0);
        assert!(lower < higher);
    }

    #[test]
    fn trials_ok() {
        let mut trials = Trials::new();

        // First insertion.
        assert!(trials.insert(trial(42, 1)));
        assert_eq!(trials.len(), 1);
        assert_eq!(trials.iter_parameters().collect::<Vec<_>>(), [1]);

        // Second insertion with a distinct parameter.
        assert!(trials.insert(trial(41, 2)));
        assert_eq!(trials.len(), 2);
        assert_eq!(trials.iter_parameters().collect::<Vec<_>>(), [1, 2]);

        // A parameter that was already seen is rejected.
        assert!(!trials.insert(trial(41, 2)));

        // The worst trial is the one with the largest metric.
        assert_eq!(trials.pop_worst(), Some(trial(42, 1)));
        assert_eq!(trials.len(), 1);
        assert_eq!(trials.iter_parameters().collect::<Vec<_>>(), [2]);
    }
}