/// Partitions `n` work items across `nthreads` workers as evenly as possible.
///
/// Returns one `(offset, count)` pair per worker, in worker order. The first
/// `n % nthreads` workers each receive one extra item, so counts differ by at
/// most one. The ranges are contiguous: every offset is the sum of all
/// preceding counts, and the counts add up to `n`.
///
/// Returns an empty vector when `nthreads` is zero.
#[must_use]
pub fn balance211(n: usize, nthreads: usize) -> Vec<(usize, usize)> {
    if nthreads == 0 {
        return Vec::new();
    }
    let (base, extra) = (n / nthreads, n % nthreads);
    (0..nthreads)
        .scan(0usize, |start, worker| {
            // Workers with an index below `extra` absorb the remainder, one item each.
            let len = base + usize::from(worker < extra);
            let range = (*start, len);
            *start += len;
            Some(range)
        })
        .collect()
}
/// Iterator over the per-worker index ranges produced by [`balance211`].
///
/// Yields one `offset..offset + count` range per worker, in worker order.
#[derive(Debug, Clone)]
pub struct Balance211Iter {
    /// Precomputed `(offset, count)` pairs, one per worker.
    ranges: Vec<(usize, usize)>,
    /// Index into `ranges` of the next pair to yield.
    current: usize,
}
impl Balance211Iter {
    /// Creates an iterator over the ranges computed by `balance211(n, nthreads)`.
    ///
    /// The iterator yields nothing when `nthreads` is zero.
    #[must_use]
    pub fn new(n: usize, nthreads: usize) -> Self {
        Self {
            ranges: balance211(n, nthreads),
            current: 0,
        }
    }
}
impl Iterator for Balance211Iter {
    type Item = std::ops::Range<usize>;

    /// Returns the next worker's `offset..offset + count` range, or `None` once
    /// every range has been yielded.
    fn next(&mut self) -> Option<Self::Item> {
        let &(offset, count) = self.ranges.get(self.current)?;
        self.current += 1;
        Some(offset..offset + count)
    }

    /// Reports the exact number of remaining ranges.
    ///
    /// `ExactSizeIterator` requires `size_hint` to be exact; the default
    /// `(0, None)` would violate that contract and defeat preallocation in
    /// `collect` and other consumers.
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `current` never exceeds `ranges.len()`: `next` only increments it
        // after a successful bounds-checked lookup.
        let remaining = self.ranges.len() - self.current;
        (remaining, Some(remaining))
    }
}
impl ExactSizeIterator for Balance211Iter {
fn len(&self) -> usize {
self.ranges.len() - self.current
}
}
/// Strategy used by [`split_batch`] to divide a batch among workers.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum BatchSplitStrategy {
    /// Every worker gets `total / num_workers` items and the last worker also
    /// takes the whole remainder. This is the default strategy.
    #[default]
    Simple,
    /// Balanced split via [`balance211`]: chunk sizes differ by at most one,
    /// with the larger chunks first.
    Equal,
    /// Currently produces the same split as `Equal`.
    SequenceAware,
}
/// Splits `total` items into per-worker chunk sizes according to `strategy`.
///
/// Returns an empty vector when either `total` or `num_workers` is zero.
/// Otherwise the result has exactly `num_workers` entries summing to `total`:
///
/// * [`BatchSplitStrategy::Simple`]: each worker gets `total / num_workers`,
///   and the last worker also takes the entire remainder.
/// * [`BatchSplitStrategy::Equal`] / [`BatchSplitStrategy::SequenceAware`]:
///   the balanced split from [`balance211`], where sizes differ by at most one.
#[must_use]
pub fn split_batch(
    total: usize,
    num_workers: usize,
    strategy: BatchSplitStrategy,
) -> Vec<usize> {
    if num_workers == 0 || total == 0 {
        return Vec::new();
    }
    match strategy {
        BatchSplitStrategy::Simple => {
            let mut chunks = vec![total / num_workers; num_workers];
            // `num_workers > 0` here, so there is always a last chunk.
            if let Some(last) = chunks.last_mut() {
                *last += total % num_workers;
            }
            chunks
        }
        // NOTE(review): this function receives no per-sequence information, so
        // SequenceAware currently shares Equal's balanced split.
        BatchSplitStrategy::Equal | BatchSplitStrategy::SequenceAware => {
            balance211(total, num_workers)
                .into_iter()
                .map(|(_, count)| count)
                .collect()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_balance211_basic() {
        let ranges = balance211(10, 3);
        assert_eq!(ranges.len(), 3);
        assert_eq!(ranges[0], (0, 4));
        assert_eq!(ranges[1], (4, 3));
        assert_eq!(ranges[2], (7, 3));
    }

    #[test]
    fn test_balance211_even_division() {
        let ranges = balance211(12, 4);
        assert_eq!(ranges.len(), 4);
        for (i, &(offset, count)) in ranges.iter().enumerate() {
            assert_eq!(count, 3);
            assert_eq!(offset, i * 3);
        }
    }

    #[test]
    fn test_balance211_empty() {
        // Zero items still yields one (empty) range per thread.
        let ranges = balance211(0, 4);
        assert_eq!(ranges.len(), 4);
        assert!(ranges.iter().all(|&(_, c)| c == 0));
        // Zero threads yields no ranges at all.
        assert!(balance211(10, 0).is_empty());
        assert!(balance211(0, 0).is_empty());
    }

    #[test]
    fn test_balance211_single_thread() {
        let ranges = balance211(100, 1);
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0], (0, 100));
    }

    #[test]
    fn test_balance211_more_threads_than_items() {
        let ranges = balance211(3, 5);
        assert_eq!(ranges.len(), 5);
        let items: Vec<_> = ranges.iter().map(|(_, c)| *c).collect();
        assert_eq!(items, vec![1, 1, 1, 0, 0]);
    }

    #[test]
    fn test_balance211_iter_basic() {
        let mut iter = Balance211Iter::new(10, 3);
        assert_eq!(iter.len(), 3);
        assert_eq!(iter.next(), Some(0..4));
        assert_eq!(iter.next(), Some(4..7));
        assert_eq!(iter.next(), Some(7..10));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_balance211_iter_exact_size() {
        let iter = Balance211Iter::new(10, 3);
        assert_eq!(iter.len(), 3);
        let mut iter2 = Balance211Iter::new(10, 3);
        iter2.next();
        assert_eq!(iter2.len(), 2);
    }

    #[test]
    fn test_balance211_iter_matches_balance211() {
        // The iterator must yield exactly the ranges described by balance211.
        for (n, nthreads) in [(0, 3), (3, 5), (10, 3), (100, 7), (10, 0)] {
            let expected: Vec<_> = balance211(n, nthreads)
                .into_iter()
                .map(|(offset, count)| offset..offset + count)
                .collect();
            let actual: Vec<_> = Balance211Iter::new(n, nthreads).collect();
            assert_eq!(
                actual, expected,
                "Balance211Iter::new({}, {}) disagrees with balance211",
                n, nthreads
            );
        }
    }

    #[test]
    fn test_balance211_iter_stays_exhausted() {
        let mut iter = Balance211Iter::new(2, 2);
        assert_eq!(iter.next(), Some(0..1));
        assert_eq!(iter.next(), Some(1..2));
        assert_eq!(iter.len(), 0);
        // Repeated calls after exhaustion keep returning None.
        assert_eq!(iter.next(), None);
        assert_eq!(iter.next(), None);
        assert_eq!(iter.len(), 0);
    }

    #[test]
    fn test_batch_split_strategy_default() {
        assert_eq!(BatchSplitStrategy::default(), BatchSplitStrategy::Simple);
    }

    #[test]
    fn test_split_batch_simple() {
        let chunks = split_batch(100, 4, BatchSplitStrategy::Simple);
        assert_eq!(chunks.len(), 4);
        assert_eq!(chunks, vec![25, 25, 25, 25]);
    }

    #[test]
    fn test_split_batch_simple_with_remainder() {
        let chunks = split_batch(10, 3, BatchSplitStrategy::Simple);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks, vec![3, 3, 4]);
        assert_eq!(chunks.iter().sum::<usize>(), 10);
    }

    #[test]
    fn test_split_batch_equal() {
        let chunks = split_batch(10, 3, BatchSplitStrategy::Equal);
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks, vec![4, 3, 3]);
        assert_eq!(chunks.iter().sum::<usize>(), 10);
    }

    #[test]
    fn test_split_batch_sequence_aware() {
        let chunks = split_batch(10, 3, BatchSplitStrategy::SequenceAware);
        assert_eq!(chunks, vec![4, 3, 3]);
    }

    #[test]
    fn test_split_batch_fewer_items_than_workers() {
        // Simple leaves the whole remainder on the last worker; Equal spreads it.
        assert_eq!(
            split_batch(3, 5, BatchSplitStrategy::Simple),
            vec![0, 0, 0, 0, 3]
        );
        assert_eq!(
            split_batch(3, 5, BatchSplitStrategy::Equal),
            vec![1, 1, 1, 0, 0]
        );
    }

    #[test]
    fn test_split_batch_empty() {
        for strategy in [
            BatchSplitStrategy::Simple,
            BatchSplitStrategy::Equal,
            BatchSplitStrategy::SequenceAware,
        ] {
            assert!(split_batch(0, 4, strategy).is_empty());
            assert!(split_batch(100, 0, strategy).is_empty());
        }
    }

    #[test]
    fn test_split_batch_single_worker() {
        let chunks = split_batch(100, 1, BatchSplitStrategy::Simple);
        assert_eq!(chunks, vec![100]);
    }

    #[test]
    fn test_falsify_split_batch_preserves_total() {
        for total in [1, 10, 100, 997, 1000, 10000] {
            for workers in [1, 2, 3, 4, 7, 16, 100] {
                for strategy in [
                    BatchSplitStrategy::Simple,
                    BatchSplitStrategy::Equal,
                    BatchSplitStrategy::SequenceAware,
                ] {
                    let chunks = split_batch(total, workers, strategy);
                    // Every worker gets exactly one entry, even if it is zero.
                    assert_eq!(
                        chunks.len(),
                        workers,
                        "FALSIFICATION FAILED: split_batch({}, {}, {:?}) returned {} chunks",
                        total,
                        workers,
                        strategy,
                        chunks.len()
                    );
                    let sum: usize = chunks.iter().sum();
                    assert_eq!(
                        sum, total,
                        "FALSIFICATION FAILED: split_batch({}, {}, {:?}) sum {} != {}",
                        total, workers, strategy, sum, total
                    );
                }
            }
        }
    }

    #[test]
    fn test_falsify_balance211_max_diff_one() {
        for n in [0, 1, 10, 100, 997, 1000] {
            for nthreads in [1, 2, 3, 4, 7, 16, 100] {
                let ranges = balance211(n, nthreads);
                if ranges.is_empty() {
                    continue;
                }
                let counts: Vec<_> = ranges.iter().map(|(_, c)| *c).collect();
                let max_count = *counts.iter().max().unwrap_or(&0);
                let min_count = *counts.iter().min().unwrap_or(&0);
                assert!(
                    max_count - min_count <= 1,
                    "FALSIFICATION FAILED: balance211({}, {}) has diff {} (max={}, min={})",
                    n,
                    nthreads,
                    max_count - min_count,
                    max_count,
                    min_count
                );
            }
        }
    }

    #[test]
    fn test_falsify_balance211_contiguous() {
        // Includes n == 0, n < nthreads and a single thread as boundary cases.
        for n in [0, 1, 3, 10, 100, 1000] {
            for nthreads in [1, 2, 3, 4, 7, 16] {
                let ranges = balance211(n, nthreads);
                let mut expected_offset = 0;
                for (i, &(offset, count)) in ranges.iter().enumerate() {
                    assert_eq!(
                        offset, expected_offset,
                        "FALSIFICATION FAILED: balance211({}, {}) range {} offset {} != expected {}",
                        n, nthreads, i, offset, expected_offset
                    );
                    expected_offset += count;
                }
                assert_eq!(
                    expected_offset, n,
                    "FALSIFICATION FAILED: balance211({}, {}) total {} != {}",
                    n, nthreads, expected_offset, n
                );
            }
        }
    }
}