Skip to main content

trueno/brick/
batch.rs

1//! Batch Splitting and Work Distribution
2//!
3//! LCP-05: Balance211 Work Distribution (Intel MKL pattern)
4//! LCP-09: Batch Splitting Strategies
5
6// ============================================================================
7// LCP-05: Balance211 Work Distribution (Intel MKL pattern)
8// ============================================================================
9
/// Splits `n` items into `nthreads` contiguous ranges (Balance211, the Intel MKL pattern).
///
/// The result holds one `(offset, count)` pair per thread, in thread order.
/// The first `n % nthreads` threads receive `n / nthreads + 1` items and the
/// remaining threads receive `n / nthreads`, so no thread has more than
/// 1 extra item compared to any other. Each offset is the sum of the counts
/// before it, which keeps the ranges contiguous and non-overlapping.
///
/// Returns an empty vector when `nthreads` is 0. When `n` is smaller than
/// `nthreads`, the trailing threads get a count of 0.
///
/// # Example
/// ```rust
/// use trueno::brick::balance211;
///
/// let ranges = balance211(10, 3);
/// // Thread 0: (0, 4) - 4 items
/// // Thread 1: (4, 3) - 3 items
/// // Thread 2: (7, 3) - 3 items
/// assert_eq!(ranges.len(), 3);
/// ```
#[must_use]
pub fn balance211(n: usize, nthreads: usize) -> Vec<(usize, usize)> {
    if nthreads == 0 {
        return Vec::new();
    }
    let base = n / nthreads;
    let extra = n % nthreads;

    let mut ranges = Vec::with_capacity(nthreads);
    let mut start = 0;
    for thread in 0..nthreads {
        // The first `extra` threads absorb one leftover item each.
        let len = if thread < extra { base + 1 } else { base };
        ranges.push((start, len));
        start += len;
    }
    ranges
}
41
42/// Iterator adapter for balanced work distribution.
43pub struct Balance211Iter {
44    ranges: Vec<(usize, usize)>,
45    current: usize,
46}
47
48impl Balance211Iter {
49    /// Create a new balanced work iterator.
50    pub fn new(n: usize, nthreads: usize) -> Self {
51        Self { ranges: balance211(n, nthreads), current: 0 }
52    }
53}
54
55impl Iterator for Balance211Iter {
56    type Item = std::ops::Range<usize>;
57
58    fn next(&mut self) -> Option<Self::Item> {
59        if self.current >= self.ranges.len() {
60            return None;
61        }
62        let (offset, count) = self.ranges[self.current];
63        self.current += 1;
64        Some(offset..offset + count)
65    }
66}
67
68impl ExactSizeIterator for Balance211Iter {
69    fn len(&self) -> usize {
70        self.ranges.len() - self.current
71    }
72}
73
74// ============================================================================
75// LCP-09: Batch Splitting Strategies
76// ============================================================================
77
/// Strategy for splitting batches across workers.
///
/// Selects how [`split_batch`] sizes the per-worker chunks.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum BatchSplitStrategy {
    /// Simple integer division: every worker gets `total / num_workers` items
    /// and the last worker additionally takes the whole remainder.
    #[default]
    Simple,
    /// Equal distribution using Balance211: chunk sizes differ by at most one.
    Equal,
    /// Sequence-aware (intended to keep sequences together).
    ///
    /// Currently handled exactly like [`Equal`](Self::Equal), because
    /// `split_batch` receives no sequence-boundary information.
    SequenceAware,
}
89
90/// Split a batch into chunks according to strategy.
91///
92/// # Example
93/// ```rust
94/// use trueno::brick::{split_batch, BatchSplitStrategy};
95///
96/// let chunks = split_batch(100, 4, BatchSplitStrategy::Equal);
97/// assert_eq!(chunks.len(), 4);
98/// assert_eq!(chunks.iter().sum::<usize>(), 100);
99/// ```
100#[must_use]
101pub fn split_batch(total: usize, num_workers: usize, strategy: BatchSplitStrategy) -> Vec<usize> {
102    if num_workers == 0 || total == 0 {
103        return vec![];
104    }
105
106    match strategy {
107        BatchSplitStrategy::Simple => {
108            let chunk_size = total / num_workers;
109            let mut chunks = vec![chunk_size; num_workers];
110            // Last worker gets remainder
111            if let Some(last) = chunks.last_mut() {
112                *last += total % num_workers;
113            }
114            chunks
115        }
116        BatchSplitStrategy::Equal => {
117            // Use Balance211 for even distribution
118            balance211(total, num_workers).iter().map(|(_, count)| *count).collect()
119        }
120        BatchSplitStrategy::SequenceAware => {
121            // For now, same as Equal (sequence boundaries would need external info)
122            balance211(total, num_workers).iter().map(|(_, count)| *count).collect()
123        }
124    }
125}
126
#[cfg(test)]
mod tests {
    //! Unit and falsification tests for `balance211`, `Balance211Iter` and
    //! `split_batch`.

    use super::*;

    #[test]
    fn test_balance211_basic() {
        let ranges = balance211(10, 3);
        assert_eq!(ranges.len(), 3);
        // First thread gets 4 items, others get 3
        assert_eq!(ranges[0], (0, 4));
        assert_eq!(ranges[1], (4, 3));
        assert_eq!(ranges[2], (7, 3));
    }

    #[test]
    fn test_balance211_even_division() {
        let ranges = balance211(12, 4);
        // Without the length check the loop below would pass on an empty result.
        assert_eq!(ranges.len(), 4);
        // 12 / 4 = 3 items each, no remainder
        for (i, &(offset, count)) in ranges.iter().enumerate() {
            assert_eq!(count, 3);
            assert_eq!(offset, i * 3);
        }
    }

    #[test]
    fn test_balance211_empty() {
        // Zero items still produce one (empty) range per thread; comparing the
        // whole vector avoids a vacuous `all(..)` over an empty result.
        assert_eq!(balance211(0, 4), vec![(0, 0); 4]);
        assert!(balance211(10, 0).is_empty());
    }

    #[test]
    fn test_balance211_single_thread() {
        let ranges = balance211(100, 1);
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0], (0, 100));
    }

    #[test]
    fn test_balance211_more_threads_than_items() {
        let ranges = balance211(3, 5);
        assert_eq!(ranges.len(), 5);
        // First 3 threads get 1 item each, last 2 get 0
        let items: Vec<_> = ranges.iter().map(|&(_, c)| c).collect();
        assert_eq!(items, vec![1, 1, 1, 0, 0]);
    }

    #[test]
    fn test_balance211_iter_basic() {
        let mut iter = Balance211Iter::new(10, 3);
        assert_eq!(iter.len(), 3);

        assert_eq!(iter.next(), Some(0..4));
        assert_eq!(iter.next(), Some(4..7));
        assert_eq!(iter.next(), Some(7..10));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_balance211_iter_exact_size() {
        let iter = Balance211Iter::new(10, 3);
        assert_eq!(iter.len(), 3);

        let mut iter2 = Balance211Iter::new(10, 3);
        iter2.next();
        assert_eq!(iter2.len(), 2);
    }

    #[test]
    fn test_batch_split_strategy_default() {
        assert_eq!(BatchSplitStrategy::default(), BatchSplitStrategy::Simple);
    }

    #[test]
    fn test_split_batch_simple() {
        let chunks = split_batch(100, 4, BatchSplitStrategy::Simple);
        assert_eq!(chunks.len(), 4);
        // First 3 get 25, last gets 25 + 0 = 25
        assert_eq!(chunks, vec![25, 25, 25, 25]);
    }

    #[test]
    fn test_split_batch_simple_with_remainder() {
        let chunks = split_batch(10, 3, BatchSplitStrategy::Simple);
        assert_eq!(chunks.len(), 3);
        // 10 / 3 = 3, remainder = 1 goes to last
        assert_eq!(chunks, vec![3, 3, 4]);
        assert_eq!(chunks.iter().sum::<usize>(), 10);
    }

    #[test]
    fn test_split_batch_equal() {
        let chunks = split_batch(10, 3, BatchSplitStrategy::Equal);
        assert_eq!(chunks.len(), 3);
        // Balance211 gives first worker more: 4, 3, 3
        assert_eq!(chunks, vec![4, 3, 3]);
        assert_eq!(chunks.iter().sum::<usize>(), 10);
    }

    #[test]
    fn test_split_batch_sequence_aware() {
        let chunks = split_batch(10, 3, BatchSplitStrategy::SequenceAware);
        // Currently same as Equal
        assert_eq!(chunks, vec![4, 3, 3]);
    }

    #[test]
    fn test_split_batch_empty() {
        assert!(split_batch(0, 4, BatchSplitStrategy::Simple).is_empty());
        assert!(split_batch(100, 0, BatchSplitStrategy::Simple).is_empty());
    }

    #[test]
    fn test_split_batch_single_worker() {
        let chunks = split_batch(100, 1, BatchSplitStrategy::Simple);
        assert_eq!(chunks, vec![100]);
    }

    /// FALSIFICATION TEST: Verify total items preserved after split
    ///
    /// The sum of all chunks must equal the original total.
    #[test]
    fn test_falsify_split_batch_preserves_total() {
        for total in [1, 10, 100, 997, 1000, 10000] {
            for workers in [1, 2, 3, 4, 7, 16, 100] {
                for strategy in [
                    BatchSplitStrategy::Simple,
                    BatchSplitStrategy::Equal,
                    BatchSplitStrategy::SequenceAware,
                ] {
                    let chunks = split_batch(total, workers, strategy);
                    let sum: usize = chunks.iter().sum();
                    assert_eq!(
                        sum, total,
                        "FALSIFICATION FAILED: split_batch({}, {}, {:?}) sum {} != {}",
                        total, workers, strategy, sum, total
                    );
                }
            }
        }
    }

    /// FALSIFICATION TEST: Balance211 never gives more than 1 extra item
    ///
    /// The maximum difference between any two thread counts must be <= 1.
    #[test]
    fn test_falsify_balance211_max_diff_one() {
        for n in [1, 10, 100, 997, 1000] {
            for nthreads in [1, 2, 3, 4, 7, 16, 100] {
                let ranges = balance211(n, nthreads);
                // Every tested `nthreads` is >= 1, so a missing range is a
                // failure to report rather than a case to skip.
                assert_eq!(ranges.len(), nthreads);
                let counts: Vec<_> = ranges.iter().map(|&(_, c)| c).collect();
                let max_count =
                    counts.iter().copied().max().expect("at least one range per thread");
                let min_count =
                    counts.iter().copied().min().expect("at least one range per thread");
                assert!(
                    max_count - min_count <= 1,
                    "FALSIFICATION FAILED: balance211({}, {}) has diff {} (max={}, min={})",
                    n,
                    nthreads,
                    max_count - min_count,
                    max_count,
                    min_count
                );
            }
        }
    }

    /// FALSIFICATION TEST: Balance211 ranges are contiguous and non-overlapping
    #[test]
    fn test_falsify_balance211_contiguous() {
        for n in [10, 100, 1000] {
            for nthreads in [2, 3, 4, 7] {
                let ranges = balance211(n, nthreads);
                assert_eq!(ranges.len(), nthreads);
                let mut expected_offset = 0;
                for (i, &(offset, count)) in ranges.iter().enumerate() {
                    assert_eq!(
                        offset, expected_offset,
                        "FALSIFICATION FAILED: balance211({}, {}) range {} offset {} != expected {}",
                        n, nthreads, i, offset, expected_offset
                    );
                    expected_offset += count;
                }
                assert_eq!(
                    expected_offset, n,
                    "FALSIFICATION FAILED: balance211({}, {}) total {} != {}",
                    n, nthreads, expected_offset, n
                );
            }
        }
    }
}