param_opt/selector/
grid_search.rs

1use std::{num::NonZeroUsize, ops::Range};
2
3use itertools::{Itertools, MultiProduct};
4
5#[derive(Debug, Clone)]
6pub struct GridSearch {
7    cartesian_product: MultiProduct<Range<usize>>,
8}
9impl GridSearch {
10    pub fn new(parameter_spaces: impl Iterator<Item = NonZeroUsize>) -> Self {
11        let mut iterators = vec![];
12        for space in parameter_spaces {
13            let range = 0..space.get();
14            iterators.push(range);
15        }
16        let cartesian_product = iterators.into_iter().multi_cartesian_product();
17
18        Self { cartesian_product }
19    }
20}
21impl Iterator for GridSearch {
22    type Item = Vec<usize>;
23    fn next(&mut self) -> Option<Self::Item> {
24        self.cartesian_product.next()
25    }
26    fn count(self) -> usize {
27        self.cartesian_product.count()
28    }
29    fn size_hint(&self) -> (usize, Option<usize>) {
30        self.cartesian_product.size_hint()
31    }
32    fn last(self) -> Option<Self::Item> {
33        self.cartesian_product.last()
34    }
35}
36
#[cfg(test)]
mod tests {
    use std::num::NonZeroUsize;

    use super::*;

    /// Spaces of size 1, 2 and 3 must yield all six index combinations,
    /// with the last parameter varying fastest, and then stop.
    #[test]
    fn test_grid_search() {
        let sizes = [1, 2, 3]
            .into_iter()
            .map(|n| NonZeroUsize::new(n).unwrap());
        let mut grid_search = GridSearch::new(sizes);

        let expected: [[usize; 3]; 6] = [
            [0, 0, 0],
            [0, 0, 1],
            [0, 0, 2],
            [0, 1, 0],
            [0, 1, 1],
            [0, 1, 2],
        ];
        for combination in expected {
            assert_eq!(grid_search.next(), Some(combination.to_vec()));
        }
        assert_eq!(grid_search.next(), None);
    }
}
56}