param_opt/selector/
grid_search.rs1use std::{num::NonZeroUsize, ops::Range};
2
3use itertools::{Itertools, MultiProduct};
4
5#[derive(Debug, Clone)]
6pub struct GridSearch {
7 cartesian_product: MultiProduct<Range<usize>>,
8}
9impl GridSearch {
10 pub fn new(parameter_spaces: impl Iterator<Item = NonZeroUsize>) -> Self {
11 let mut iterators = vec![];
12 for space in parameter_spaces {
13 let range = 0..space.get();
14 iterators.push(range);
15 }
16 let cartesian_product = iterators.into_iter().multi_cartesian_product();
17
18 Self { cartesian_product }
19 }
20}
21impl Iterator for GridSearch {
22 type Item = Vec<usize>;
23 fn next(&mut self) -> Option<Self::Item> {
24 self.cartesian_product.next()
25 }
26 fn count(self) -> usize {
27 self.cartesian_product.count()
28 }
29 fn size_hint(&self) -> (usize, Option<usize>) {
30 self.cartesian_product.size_hint()
31 }
32 fn last(self) -> Option<Self::Item> {
33 self.cartesian_product.last()
34 }
35}
36
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use super::*;

    /// Converts plain sizes into the `NonZeroUsize` iterator `GridSearch::new` expects.
    fn spaces(sizes: &[usize]) -> impl Iterator<Item = NonZeroUsize> + '_ {
        sizes
            .iter()
            .map(|&x| NonZeroUsize::new(x).expect("test sizes must be non-zero"))
    }

    /// The grid is enumerated in order, with the last parameter varying fastest.
    #[test]
    fn test_grid_search() {
        let mut grid_search = GridSearch::new(spaces(&[1, 2, 3]));
        assert_eq!(grid_search.next(), Some(vec![0, 0, 0]));
        assert_eq!(grid_search.next(), Some(vec![0, 0, 1]));
        assert_eq!(grid_search.next(), Some(vec![0, 0, 2]));
        assert_eq!(grid_search.next(), Some(vec![0, 1, 0]));
        assert_eq!(grid_search.next(), Some(vec![0, 1, 1]));
        assert_eq!(grid_search.next(), Some(vec![0, 1, 2]));
        assert_eq!(grid_search.next(), None);
    }

    /// The overridden `count`, `last` and `size_hint` must agree with plain iteration.
    #[test]
    fn test_grid_search_forwarded_methods() {
        let sizes = [2, 3, 4];
        let all: Vec<Vec<usize>> = GridSearch::new(spaces(&sizes)).collect();
        assert_eq!(all.len(), 2 * 3 * 4);
        assert_eq!(all.first(), Some(&vec![0, 0, 0]));

        assert_eq!(GridSearch::new(spaces(&sizes)).count(), all.len());
        assert_eq!(GridSearch::new(spaces(&sizes)).last(), all.last().cloned());

        // `Iterator` contract: lower <= actual length <= upper (when an upper bound exists).
        let (lower, upper) = GridSearch::new(spaces(&sizes)).size_hint();
        assert!(lower <= all.len());
        if let Some(upper) = upper {
            assert!(upper >= all.len());
        }
    }
}