/// Partitions `n` items into `nthreads` contiguous `(offset, count)` slices.
///
/// The first `n % nthreads` slices receive one item more than the rest, so
/// slice sizes never differ by more than one. Slices are laid out back to
/// back starting at offset `0`; when there are fewer items than threads the
/// trailing slices have a count of `0`.
///
/// Returns an empty vector when `nthreads` is `0`.
#[must_use]
pub fn balance211(n: usize, nthreads: usize) -> Vec<(usize, usize)> {
    match nthreads {
        // Nothing to hand work to (and avoids dividing by zero below).
        0 => Vec::new(),
        _ => {
            let base = n / nthreads;
            let extra = n % nthreads;

            (0..nthreads)
                .map(|worker| {
                    // Every earlier slice contributed `base` items, and each of
                    // the first `extra` slices contributed one additional item.
                    let start = base * worker + worker.min(extra);
                    let len = base + usize::from(worker < extra);
                    (start, len)
                })
                .collect()
        }
    }
}
41
/// Iterator state over a precomputed list of `(offset, count)` work slices.
///
/// Derives `Debug` and `Clone` so the iterator can be logged and duplicated
/// like the standard library's iterator types.
#[derive(Debug, Clone)]
pub struct Balance211Iter {
    /// `(offset, count)` pairs, one per slice, in the order they are yielded.
    ranges: Vec<(usize, usize)>,
    /// Index into `ranges` of the next pair to yield.
    current: usize,
}
47
48impl Balance211Iter {
49 pub fn new(n: usize, nthreads: usize) -> Self {
51 Self { ranges: balance211(n, nthreads), current: 0 }
52 }
53}
54
55impl Iterator for Balance211Iter {
56 type Item = std::ops::Range<usize>;
57
58 fn next(&mut self) -> Option<Self::Item> {
59 if self.current >= self.ranges.len() {
60 return None;
61 }
62 let (offset, count) = self.ranges[self.current];
63 self.current += 1;
64 Some(offset..offset + count)
65 }
66}
67
68impl ExactSizeIterator for Balance211Iter {
69 fn len(&self) -> usize {
70 self.ranges.len() - self.current
71 }
72}
73
/// Selects how [`split_batch`] divides a batch among workers.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum BatchSplitStrategy {
    /// Every worker gets `total / num_workers` items and the last worker
    /// additionally takes the whole remainder. This is the default.
    #[default]
    Simple,
    /// Per-worker counts come from `balance211`, so they differ by at most
    /// one, with the larger shares going to the first workers.
    Equal,
    /// As implemented in `split_batch`, currently yields the same counts as
    /// [`BatchSplitStrategy::Equal`].
    SequenceAware,
}
89
90#[must_use]
101pub fn split_batch(total: usize, num_workers: usize, strategy: BatchSplitStrategy) -> Vec<usize> {
102 if num_workers == 0 || total == 0 {
103 return vec![];
104 }
105
106 match strategy {
107 BatchSplitStrategy::Simple => {
108 let chunk_size = total / num_workers;
109 let mut chunks = vec![chunk_size; num_workers];
110 if let Some(last) = chunks.last_mut() {
112 *last += total % num_workers;
113 }
114 chunks
115 }
116 BatchSplitStrategy::Equal => {
117 balance211(total, num_workers).iter().map(|(_, count)| *count).collect()
119 }
120 BatchSplitStrategy::SequenceAware => {
121 balance211(total, num_workers).iter().map(|(_, count)| *count).collect()
123 }
124 }
125}
126
#[cfg(test)]
mod tests {
    //! Unit tests for the partitioning helpers, plus "falsification" sweeps
    //! that check structural invariants over a grid of inputs.

    use super::*;

    #[test]
    fn test_balance211_basic() {
        let ranges = balance211(10, 3);
        assert_eq!(ranges.len(), 3);
        assert_eq!(ranges[0], (0, 4));
        assert_eq!(ranges[1], (4, 3));
        assert_eq!(ranges[2], (7, 3));
    }

    #[test]
    fn test_balance211_even_division() {
        let ranges = balance211(12, 4);
        // Guard against the loop below passing vacuously on an empty result.
        assert_eq!(ranges.len(), 4);
        for (i, &(offset, count)) in ranges.iter().enumerate() {
            assert_eq!(count, 3);
            assert_eq!(offset, i * 3);
        }
    }

    #[test]
    fn test_balance211_empty() {
        // Zero items still yields one (empty) slice per thread; comparing the
        // whole vector avoids an `all(..)` that would also accept `[]`.
        assert_eq!(balance211(0, 4), vec![(0, 0); 4]);
        assert!(balance211(10, 0).is_empty());
    }

    #[test]
    fn test_balance211_single_thread() {
        let ranges = balance211(100, 1);
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0], (0, 100));
    }

    #[test]
    fn test_balance211_more_threads_than_items() {
        let ranges = balance211(3, 5);
        assert_eq!(ranges.len(), 5);
        let items: Vec<_> = ranges.iter().map(|(_, c)| *c).collect();
        assert_eq!(items, vec![1, 1, 1, 0, 0]);
    }

    #[test]
    fn test_balance211_iter_basic() {
        let mut iter = Balance211Iter::new(10, 3);
        assert_eq!(iter.len(), 3);

        assert_eq!(iter.next(), Some(0..4));
        assert_eq!(iter.next(), Some(4..7));
        assert_eq!(iter.next(), Some(7..10));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_balance211_iter_exact_size() {
        let iter = Balance211Iter::new(10, 3);
        assert_eq!(iter.len(), 3);

        let mut iter2 = Balance211Iter::new(10, 3);
        iter2.next();
        assert_eq!(iter2.len(), 2);
    }

    #[test]
    fn test_batch_split_strategy_default() {
        assert_eq!(BatchSplitStrategy::default(), BatchSplitStrategy::Simple);
    }

    #[test]
    fn test_split_batch_simple() {
        let chunks = split_batch(100, 4, BatchSplitStrategy::Simple);
        assert_eq!(chunks.len(), 4);
        assert_eq!(chunks, vec![25, 25, 25, 25]);
    }

    #[test]
    fn test_split_batch_simple_with_remainder() {
        let chunks = split_batch(10, 3, BatchSplitStrategy::Simple);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks, vec![3, 3, 4]);
        assert_eq!(chunks.iter().sum::<usize>(), 10);
    }

    #[test]
    fn test_split_batch_equal() {
        let chunks = split_batch(10, 3, BatchSplitStrategy::Equal);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks, vec![4, 3, 3]);
        assert_eq!(chunks.iter().sum::<usize>(), 10);
    }

    #[test]
    fn test_split_batch_sequence_aware() {
        let chunks = split_batch(10, 3, BatchSplitStrategy::SequenceAware);
        assert_eq!(chunks, vec![4, 3, 3]);
    }

    #[test]
    fn test_split_batch_empty() {
        // The early return applies regardless of strategy, so check them all.
        for strategy in [
            BatchSplitStrategy::Simple,
            BatchSplitStrategy::Equal,
            BatchSplitStrategy::SequenceAware,
        ] {
            assert!(split_batch(0, 4, strategy).is_empty());
            assert!(split_batch(100, 0, strategy).is_empty());
        }
    }

    #[test]
    fn test_split_batch_single_worker() {
        let chunks = split_batch(100, 1, BatchSplitStrategy::Simple);
        assert_eq!(chunks, vec![100]);
    }

    #[test]
    fn test_falsify_split_batch_preserves_total() {
        for total in [1, 10, 100, 997, 1000, 10000] {
            for workers in [1, 2, 3, 4, 7, 16, 100] {
                for strategy in [
                    BatchSplitStrategy::Simple,
                    BatchSplitStrategy::Equal,
                    BatchSplitStrategy::SequenceAware,
                ] {
                    let chunks = split_batch(total, workers, strategy);
                    // One chunk per worker whenever both inputs are non-zero.
                    assert_eq!(chunks.len(), workers);
                    let sum: usize = chunks.iter().sum();
                    assert_eq!(
                        sum, total,
                        "FALSIFICATION FAILED: split_batch({}, {}, {:?}) sum {} != {}",
                        total, workers, strategy, sum, total
                    );
                }
            }
        }
    }

    #[test]
    fn test_falsify_balance211_max_diff_one() {
        for n in [1, 10, 100, 997, 1000] {
            for nthreads in [1, 2, 3, 4, 7, 16, 100] {
                let ranges = balance211(n, nthreads);
                // `nthreads >= 1` here, so an empty result would be a bug and
                // must fail the test rather than be skipped.
                assert_eq!(ranges.len(), nthreads);
                let counts: Vec<_> = ranges.iter().map(|(_, c)| *c).collect();
                let max_count = *counts.iter().max().expect("non-empty by the assert above");
                let min_count = *counts.iter().min().expect("non-empty by the assert above");
                assert!(
                    max_count - min_count <= 1,
                    "FALSIFICATION FAILED: balance211({}, {}) has diff {} (max={}, min={})",
                    n,
                    nthreads,
                    max_count - min_count,
                    max_count,
                    min_count
                );
            }
        }
    }

    #[test]
    fn test_falsify_balance211_contiguous() {
        for n in [10, 100, 1000] {
            for nthreads in [2, 3, 4, 7] {
                let ranges = balance211(n, nthreads);
                let mut expected_offset = 0;
                for (i, &(offset, count)) in ranges.iter().enumerate() {
                    assert_eq!(
                        offset, expected_offset,
                        "FALSIFICATION FAILED: balance211({}, {}) range {} offset {} != expected {}",
                        n, nthreads, i, offset, expected_offset
                    );
                    expected_offset += count;
                }
                assert_eq!(
                    expected_offset, n,
                    "FALSIFICATION FAILED: balance211({}, {}) total {} != {}",
                    n, nthreads, expected_offset, n
                );
            }
        }
    }
}