multivariate_optimization/
splitter.rs

1//! Randomly assign indices to a fixed number of groups (to split work into
2//! smaller parts).
3
4use rand::{seq::SliceRandom, Rng};
5
6use std::borrow::BorrowMut;
7use std::ops::Deref;
8
/// Remembers a random partition of the indices `0..source_len` into groups.
///
/// Both directions of the mapping are stored, so that looking up the group of
/// an index and listing the indices of a group are each a plain slice access.
#[derive(Debug)]
pub struct Splitter {
    /// For every original index, the index of the group it was assigned to.
    assignments: Box<[usize]>,
    /// For every group, the original indices that were assigned to it.
    groups: Box<[Vec<usize>]>,
}
15
16impl Splitter {
17    /// Randomly assign indices to fixed number of groups.
18    ///
19    /// Create a `Splitter` struct, by randomly assigning indices (from
20    /// `0..source_len`) to a fixed number of groups (`group_count`).
21    /// The returned struct provides access to the created groups (containing
22    /// their assigned indices, see [`groups`]) and allows merging of iterators
23    /// (see [`merge`]).
24    /// If `group_count` is smaller than `source_len`, the number of groups is
25    /// set (i.e. limited) to `input_len`.
26    ///
27    /// [`groups`]: Self::groups
28    /// [`merge`]: Self::merge
29    pub fn new<R: Rng + ?Sized>(rng: &mut R, source_len: usize, group_count: usize) -> Self {
30        assert_ne!(group_count, 0, "group_count must be positive");
31        let group_count = group_count.min(source_len);
32        let mut assignments: Box<[usize]> = (0..source_len)
33            .into_iter()
34            .map(|i| i % group_count)
35            .collect();
36        let mut groups: Box<[Vec<usize>]> = (0..group_count).map(|_| Vec::<usize>::new()).collect();
37        if source_len > 0 {
38            assignments.shuffle(rng);
39            for group in groups.iter_mut() {
40                group.reserve((source_len - 1) / group_count + 1);
41            }
42            for (source_idx, assignment) in assignments.iter().copied().enumerate() {
43                groups[assignment].push(source_idx);
44            }
45        }
46        Self {
47            assignments,
48            groups,
49        }
50    }
51    /// Return slice containing index of assigned group for each original index.
52    pub fn assignments(&self) -> &[usize] {
53        &self.assignments
54    }
55    /// Return slice containing indices of created groups.
56    pub fn groups(&self) -> &[impl Deref<Target = [usize]>] {
57        &self.groups
58    }
59    /// Merge iterators returning results for each group to a single iterator.
60    pub fn merge<'a, T, I, R>(&'a self, mut results: R) -> impl 'a + Iterator<Item = T>
61    where
62        I: Iterator<Item = T> + 'a,
63        R: BorrowMut<[I]> + 'a,
64    {
65        self.assignments.iter().copied().map(move |assignment| {
66            results.borrow_mut()[assignment]
67                .next()
68                .unwrap_or_else(|| panic!("iterator for group #{} exhausted", assignment))
69        })
70    }
71}
72
#[cfg(test)]
mod tests {
    use super::Splitter;
    use rand::thread_rng;

    /// Groups are balanced and consistent with the per-index assignments.
    #[test]
    fn test_internal() {
        let mut rng = thread_rng();
        let splitter = Splitter::new(&mut rng, 100, 3);
        assert_eq!(splitter.groups.len(), 3);
        for (group_idx, members) in splitter.groups.iter().enumerate() {
            // 100 indices over 3 groups: every group holds 33 or 34 of them.
            assert!((33..=34).contains(&members.len()));
            for &source_idx in members {
                assert_eq!(splitter.assignments[source_idx], group_idx);
            }
        }
    }

    /// Splitting into groups and merging again restores the original order.
    #[test]
    fn test_run() {
        let mut rng = thread_rng();
        let specimens: Vec<char> = vec!['A', 'B', 'C', 'D', 'E'];
        let splitter = Splitter::new(&mut rng, specimens.len(), 2);
        // One iterator per group, yielding that group's specimens.
        let per_group: Box<[_]> = splitter
            .groups()
            .iter()
            .map(|members| members.iter().map(|&idx| specimens[idx]))
            .collect();
        let merged: Vec<char> = splitter.merge(per_group).collect();
        assert_eq!(specimens, merged);
    }
}