fast_steal/
split_task.rs

1use crate::task::Task;
2
3pub trait SplitTask {
4    fn split_task(&self, n: u64) -> impl Iterator<Item = Task>;
5    fn split_two(&self) -> (u64, u64);
6}
7
8impl SplitTask for Task {
9    fn split_task(&self, n: u64) -> impl Iterator<Item = Task> {
10        debug_assert!(n > 0, "n must be greater than 0");
11        let total = self.remain();
12        let offset = self.start();
13        let per_group = total / n;
14        let remainder = total % n;
15        (0..n).map(move |i| {
16            let start = offset + i * per_group + i.min(remainder);
17            let end = start + per_group + if i < remainder { 1 } else { 0 };
18            Task::new(start, end)
19        })
20    }
21
22    fn split_two(&self) -> (u64, u64) {
23        let start = self.start();
24        let end = self.end();
25        let mid = (start + end) / 2;
26        self.set_end(mid);
27        (mid, end)
28    }
29}
30
#[cfg(test)]
mod tests {
    extern crate alloc;
    use super::*;
    use alloc::vec::Vec;

    #[test]
    fn test_split_task() {
        let task = Task::new(1, 6); // 1, 2, 3, 4, 5
        let groups: Vec<_> = task.split_task(3).collect(); // 5 / 3 = 1 remainder 2
        assert_eq!(groups, [Task::new(1, 3), Task::new(3, 5), Task::new(5, 6)]);
    }

    #[test]
    fn test_split_task_single_group() {
        // Splitting into one group reproduces the whole remaining range.
        let task = Task::new(1, 6);
        let groups: Vec<_> = task.split_task(1).collect();
        assert_eq!(groups, [Task::new(1, 6)]);
    }

    #[test]
    fn test_split_task_more_groups_than_items() {
        // 2 items over 4 groups: the first two get one item each, the rest are empty.
        let task = Task::new(0, 2);
        let groups: Vec<_> = task.split_task(4).collect();
        assert_eq!(
            groups,
            [Task::new(0, 1), Task::new(1, 2), Task::new(2, 2), Task::new(2, 2)]
        );
    }

    #[test]
    fn test_split_task_covers_range_contiguously() {
        let task = Task::new(10, 110);
        let groups: Vec<_> = task.split_task(7).collect();
        assert_eq!(groups.len(), 7);
        assert_eq!(groups[0].start(), 10);
        assert_eq!(groups[6].end(), 110);
        // Adjacent groups must share a boundary: no gaps, no overlap.
        for pair in groups.windows(2) {
            assert_eq!(pair[0].end(), pair[1].start());
        }
        // 100 / 7 = 14 remainder 2, so the first two groups are one longer.
        let lens: Vec<u64> = groups.iter().map(|t| t.end() - t.start()).collect();
        assert_eq!(lens, [15, 15, 14, 14, 14, 14, 14]);
    }

    #[test]
    fn test_split_two() {
        let task = Task::new(1, 6); // 1, 2, 3, 4, 5
        let (mid, end) = task.split_two();
        assert_eq!(task.start(), 1);
        assert_eq!(task.end(), 3);
        assert_eq!(mid, 3);
        assert_eq!(end, 6);
    }

    #[test]
    fn test_split_two_single_item() {
        // With one item the midpoint rounds down, so the whole item is handed off.
        let task = Task::new(0, 1);
        let (mid, end) = task.split_two();
        assert_eq!((task.start(), task.end()), (0, 0));
        assert_eq!((mid, end), (0, 1));
    }

    #[test]
    fn test_split_two_empty() {
        let task = Task::new(4, 4);
        let (mid, end) = task.split_two();
        assert_eq!((task.start(), task.end()), (4, 4));
        assert_eq!((mid, end), (4, 4));
    }
}