multivariate_optimization/
splitter.rs1use rand::{seq::SliceRandom, Rng};
5
6use std::borrow::BorrowMut;
7use std::ops::Deref;
8
/// A partition of the source indices `0..source_len` into groups.
///
/// The same partition is stored in two directions so that lookups are cheap
/// both from a source index to its group and from a group to its members.
#[derive(Debug)]
pub struct Splitter {
    /// For every source index, the index of the group it belongs to.
    assignments: Box<[usize]>,
    /// For every group, the source indices that belong to it.
    groups: Box<[Vec<usize>]>,
}
15
16impl Splitter {
17 pub fn new<R: Rng + ?Sized>(rng: &mut R, source_len: usize, group_count: usize) -> Self {
30 assert_ne!(group_count, 0, "group_count must be positive");
31 let group_count = group_count.min(source_len);
32 let mut assignments: Box<[usize]> = (0..source_len)
33 .into_iter()
34 .map(|i| i % group_count)
35 .collect();
36 let mut groups: Box<[Vec<usize>]> = (0..group_count).map(|_| Vec::<usize>::new()).collect();
37 if source_len > 0 {
38 assignments.shuffle(rng);
39 for group in groups.iter_mut() {
40 group.reserve((source_len - 1) / group_count + 1);
41 }
42 for (source_idx, assignment) in assignments.iter().copied().enumerate() {
43 groups[assignment].push(source_idx);
44 }
45 }
46 Self {
47 assignments,
48 groups,
49 }
50 }
51 pub fn assignments(&self) -> &[usize] {
53 &self.assignments
54 }
55 pub fn groups(&self) -> &[impl Deref<Target = [usize]>] {
57 &self.groups
58 }
59 pub fn merge<'a, T, I, R>(&'a self, mut results: R) -> impl 'a + Iterator<Item = T>
61 where
62 I: Iterator<Item = T> + 'a,
63 R: BorrowMut<[I]> + 'a,
64 {
65 self.assignments.iter().copied().map(move |assignment| {
66 results.borrow_mut()[assignment]
67 .next()
68 .unwrap_or_else(|| panic!("iterator for group #{} exhausted", assignment))
69 })
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::Splitter;
    use rand::thread_rng;

    /// Both internal views must describe the same balanced partition.
    #[test]
    fn test_internal() {
        let mut rng = thread_rng();
        let splitter = Splitter::new(&mut rng, 100, 3);
        assert_eq!(splitter.groups.len(), 3);
        for (group_idx, members) in splitter.groups.iter().enumerate() {
            // 100 elements over 3 groups: every group holds 33 or 34 of them.
            assert!((33..=34).contains(&members.len()));
            for &source_idx in members.iter() {
                assert_eq!(splitter.assignments[source_idx], group_idx);
            }
        }
    }

    /// Splitting into groups and merging again must restore the source order.
    #[test]
    fn test_run() {
        let mut rng = thread_rng();
        let specimens = vec!['A', 'B', 'C', 'D', 'E'];
        let splitter = Splitter::new(&mut rng, specimens.len(), 2);
        let per_group: Box<[_]> = splitter
            .groups()
            .iter()
            .map(|members| members.iter().map(|&idx| specimens[idx]))
            .collect();
        let merged: Vec<char> = splitter.merge(per_group).collect();
        assert_eq!(specimens, merged);
    }
}