1#![allow(clippy::needless_doctest_main)]
2#![cfg_attr(test, feature(unboxed_closures))]
3#![cfg_attr(test, feature(fn_traits))]
4#![warn(missing_debug_implementations)]
5#![warn(missing_docs)]
6
7use blanket::blanket;
18use streaming_iterator::StreamingIterator;
19
#[blanket(derive(Ref, Rc, Arc, Mut, Box))]
/// An optimizer whose current best point can be queried.
///
/// Optimizers are driven as [`StreamingIterator`]s whose items implement this
/// trait; [`OptimizerExt::argmin`] reads [`Optimizer::best_point`] from the
/// final item.
///
/// The `blanket` attribute derives forwarding implementations of this trait
/// for the wrapper kinds it lists (`&T`, `Rc<T>`, `Arc<T>`, `&mut T`,
/// `Box<T>`), so wrapped optimizers can be used wherever an `Optimizer` is
/// expected.
pub trait Optimizer {
    /// Type of the points (candidate solutions) this optimizer reports.
    type Point;

    /// Returns the point this optimizer currently considers best.
    fn best_point(&self) -> Self::Point;
}
31
32pub trait StreamingIteratorExt: StreamingIterator {
34 fn last(&mut self) -> Option<&Self::Item> {
39 while !self.is_done() {
40 self.advance()
41 }
42 (*self).get()
43 }
44}
45
46pub trait OptimizerExt: StreamingIteratorExt {
49 fn argmin(&mut self) -> Option<<Self::Item as Optimizer>::Point>
55 where
56 Self::Item: Optimizer,
57 {
58 self.last().map(|x| x.best_point())
59 }
60}
61
62impl<T> StreamingIteratorExt for T where T: StreamingIterator {}
63impl<T> OptimizerExt for T where T: StreamingIterator {}
64
65pub mod prelude {
70 pub use streaming_iterator::StreamingIterator;
71
72 pub use super::{Optimizer, OptimizerExt, StreamingIteratorExt};
73}
74
75#[cfg(test)]
76mod tests {
77 use std::fmt::Debug;
87
88 use replace_with::replace_with_or_abort;
89 use serde::{Deserialize, Serialize};
90 use static_assertions::assert_obj_safe;
91
92 use crate::prelude::*;
93
    // Compile-time check that `Optimizer` stays object-safe; the tests below
    // build `dyn Optimizer<Point = _>` trait objects.
    assert_obj_safe!(Optimizer<Point = ()>);
95
96 fn mock_obj_func(x: usize) -> usize {
97 x + 1
98 }
99
100 macro_rules! mock_optimizer {
101 ( $id:ident ) => {
102 paste::paste! {
103 #[derive(Clone, Debug, Serialize, Deserialize)]
104 struct [< MockOptimizer $id >]<F> {
105 obj_func: F,
106 state: usize,
107 }
108
109 impl<F> [< MockOptimizer $id >]<F> {
110 fn new(obj_func: F) -> Self {
111 Self { obj_func, state: 0 }
112 }
113
114 fn evaluation(&self) -> usize
115 where
116 F: Fn(usize) -> usize
117 {
118 (self.obj_func)(self.state)
119 }
120 }
121
122 impl<F> StreamingIterator for [< MockOptimizer $id >]<F>
123 where
124 F: Fn(usize) -> usize
125 {
126 type Item = Self;
127
128 fn advance(&mut self) {
129 self.state += self.evaluation()
130 }
131
132 fn get(&self) -> Option<&Self::Item> {
133 Some(self)
134 }
135 }
136
137 impl<P> Optimizer for [< MockOptimizer $id >]<P> {
138 type Point = usize;
139
140 fn best_point(&self) -> Self::Point {
141 self.state
142 }
143 }
144 }
145 };
146 }
147
148 mock_optimizer!(A);
149 mock_optimizer!(B);
150
151 #[derive(Clone, Debug, Serialize, Deserialize)]
152 struct MaxSteps<I> {
153 max_i: usize,
154 i: usize,
155 it: I,
156 }
157
158 #[derive(Clone, Debug, Serialize, Deserialize)]
159 struct MaxStepsConfig(usize);
160
161 impl MaxStepsConfig {
162 fn start<I>(self, it: I) -> MaxSteps<I> {
163 MaxSteps {
164 i: 0,
165 max_i: self.0,
166 it,
167 }
168 }
169 }
170
171 impl<I> StreamingIterator for MaxSteps<I>
172 where
173 I: StreamingIterator,
174 {
175 type Item = I::Item;
176
177 fn advance(&mut self) {
178 self.it.advance();
179 self.i += 1;
180 }
181
182 fn get(&self) -> Option<&Self::Item> {
183 self.it.get()
184 }
185
186 fn is_done(&self) -> bool {
187 self.it.is_done() || self.i >= self.max_i
188 }
189 }
190
191 #[test]
192 fn optimizers_should_be_easily_comparable() {
193 type BoxedOptimizer<A> = Box<dyn StreamingIterator<Item = dyn Optimizer<Point = A>>>;
194
195 fn best_optimizer<A, B, F, I>(obj_func: F, optimizers: I) -> usize
196 where
197 B: Ord,
198 F: Fn(A) -> B,
199 I: IntoIterator<Item = BoxedOptimizer<A>>,
200 {
201 optimizers
202 .into_iter()
203 .enumerate()
204 .map(|(i, mut o)| {
205 let o = o.nth(10).unwrap();
206 (obj_func(o.best_point()), i)
207 })
208 .min()
209 .expect("`optimizers` should be non-empty")
210 .1
211 }
212
213 best_optimizer(
214 mock_obj_func,
215 [
216 Box::new(
217 MockOptimizerA::new(mock_obj_func)
218 .map_ref(|x| x as &dyn Optimizer<Point = usize>),
219 ) as BoxedOptimizer<usize>,
220 Box::new(
221 MockOptimizerB::new(mock_obj_func)
222 .map_ref(|x| x as &dyn Optimizer<Point = usize>),
223 ) as BoxedOptimizer<usize>,
224 ],
225 );
226 }
227
228 #[test]
229 fn parallel_optimization_runs_should_be_easy() {
230 use std::thread::spawn;
231
232 fn parallel<A, O, F>(start: F)
233 where
234 A: Send + 'static,
235 O: StreamingIterator + Send + 'static,
236 O::Item: Optimizer<Point = A>,
237 F: Fn() -> O,
238 {
239 let o1 = start();
240 let o2 = start();
241 let handler1 = spawn(move || MaxStepsConfig(10).start(o1).argmin());
242 let handler2 = spawn(move || MaxStepsConfig(10).start(o2).argmin());
243 handler1.join().unwrap();
244 handler2.join().unwrap();
245 }
246
247 parallel(|| MockOptimizerA::new(mock_obj_func));
248 }
249
250 #[test]
251 fn examining_state_and_corresponding_evaluations_should_be_easy() {
252 MockOptimizerA::new(mock_obj_func)
257 .inspect(|o| println!("state: {:?}, evaluation: {:?}", o.state, o.evaluation()))
258 .nth(10);
259 }
260
261 #[test]
262 fn optimizers_should_be_able_to_restart_automatically() {
263 trait Restart {
268 fn restart(&mut self);
269 }
270
271 impl<P> Restart for MockOptimizerA<P> {
272 fn restart(&mut self) {
273 replace_with_or_abort(self, |this| MockOptimizerA::new(this.obj_func))
274 }
275 }
276
277 impl<I> Restart for MaxSteps<I>
278 where
279 I: Restart,
280 {
281 fn restart(&mut self) {
282 replace_with_or_abort(self, |this| {
283 let mut it = this.it;
284 it.restart();
285 MaxStepsConfig(this.max_i).start(it)
286 })
287 }
288 }
289
290 struct Restarter<I> {
291 max_restarts: usize,
292 restarts: usize,
293 it: I,
294 }
295
296 struct RestarterConfig {
297 max_restarts: usize,
298 }
299
300 impl RestarterConfig {
301 fn start<I>(self, it: I) -> Restarter<I> {
302 Restarter {
303 max_restarts: self.max_restarts,
304 restarts: 0,
305 it,
306 }
307 }
308 }
309
310 impl<I> StreamingIterator for Restarter<I>
311 where
312 I: StreamingIterator + Restart,
313 {
314 type Item = I::Item;
315
316 fn advance(&mut self) {
317 if self.restarts < self.max_restarts && self.it.is_done() {
318 self.restarts += 1;
319 self.it.restart();
320 } else {
321 self.it.advance()
322 }
323 }
324
325 fn get(&self) -> Option<&Self::Item> {
326 self.it.get()
327 }
328 }
329
330 let _ = RestarterConfig { max_restarts: 10 }
331 .start(MaxStepsConfig(10).start(MockOptimizerA::new(mock_obj_func)))
332 .nth(100);
333 }
334
335 #[test]
340 fn dynamic_optimizers_should_be_partially_runable() {
341 #[derive(Clone, Debug, Serialize, Deserialize)]
342 enum DynOptimizer<F> {
343 A(MockOptimizerA<F>),
344 B(MockOptimizerB<F>),
345 }
346
347 impl<F> StreamingIterator for DynOptimizer<F>
348 where
349 F: Fn(usize) -> usize,
350 {
351 type Item = Self;
352
353 fn advance(&mut self) {
354 match self {
355 Self::A(x) => x.advance(),
356 Self::B(x) => x.advance(),
357 }
358 }
359
360 fn get(&self) -> Option<&Self::Item> {
361 Some(self)
362 }
363 }
364
365 impl<F> Optimizer for DynOptimizer<F> {
366 type Point = usize;
367
368 fn best_point(&self) -> Self::Point {
369 match self {
370 Self::A(x) => x.best_point(),
371 Self::B(x) => x.best_point(),
372 }
373 }
374 }
375
376 #[derive(Clone, Debug, Serialize, Deserialize)]
377 struct MockObjFunc;
378
379 impl FnOnce<(usize,)> for MockObjFunc {
380 type Output = usize;
381 extern "rust-call" fn call_once(self, args: (usize,)) -> Self::Output {
382 mock_obj_func(args.0)
383 }
384 }
385
386 impl FnMut<(usize,)> for MockObjFunc {
387 extern "rust-call" fn call_mut(&mut self, args: (usize,)) -> Self::Output {
388 mock_obj_func(args.0)
389 }
390 }
391
392 impl Fn<(usize,)> for MockObjFunc {
393 extern "rust-call" fn call(&self, args: (usize,)) -> Self::Output {
394 mock_obj_func(args.0)
395 }
396 }
397
398 let mut o = MaxStepsConfig(10).start(DynOptimizer::A(MockOptimizerA::new(MockObjFunc)));
399 o.next();
400 let store = serde_json::to_string(&o).unwrap();
401 o = serde_json::from_str(&store).unwrap();
402 o.next();
403 o.get().unwrap().best_point();
404 }
405}