use log::debug;
#[cfg(feature = "log_parallelism")]
use log::{info, trace};
#[cfg(feature = "log_parallelism")]
use std::ops::AddAssign;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
#[cfg(feature = "log_parallelism")]
use std::sync::Mutex;

/// Factory for creating ranges of indices to process, one per worker thread,
/// together with an orchestrator that coordinates them.
pub trait RangeFactory {
    /// Type of range handed out to each worker thread.
    type Rn: Range;
    /// Type of orchestrator associated with this factory's ranges.
    type Orchestrator: RangeOrchestrator;

    /// Creates a factory for the given number of elements and worker threads.
    fn new(num_elements: usize, num_threads: usize) -> Self;

    /// Consumes this factory and returns the orchestrator for its ranges.
    fn orchestrator(self) -> Self::Orchestrator;

    /// Returns the range for the given thread.
    fn range(&self, thread_id: usize) -> Self::Rn;
}

/// Object that coordinates a set of ranges created by a common factory.
pub trait RangeOrchestrator {
    /// Resets the ranges to their initial state, ready for a new round of
    /// processing.
    fn reset_ranges(&self);

    /// Prints statistics about how the work was shared between the ranges.
    #[cfg(feature = "log_parallelism")]
    fn print_statistics(&self) {}
}

/// A range of indices assigned to a worker thread.
pub trait Range {
    /// Type of iterator over this range's indices.
    type Iter: Iterator<Item = usize>;

    /// Iterates over the indices of this range. Depending on the
    /// implementation, the set of yielded indices may be adjusted dynamically
    /// as other threads make progress.
    fn iter(&self) -> Self::Iter;
}

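// Typical usage of these traits (illustrative sketch, not taken from the
// original documentation): a factory is created up front, each worker thread
// receives its own range, and the factory is then turned into an orchestrator
// that resets the ranges before every parallel round. Roughly:
//
//     let factory = WorkStealingRangeFactory::new(num_elements, num_threads);
//     let ranges: Vec<_> = (0..num_threads).map(|i| factory.range(i)).collect();
//     let orchestrator = factory.orchestrator();
//     orchestrator.reset_ranges();
//     // On each worker thread `i`:
//     for index in ranges[i].iter() {
//         // process element `index`
//     }
//
// See the tests at the bottom of this file for complete, runnable versions.
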
/// Factory for fixed ranges: each thread is statically assigned a contiguous
/// chunk of the elements.
pub struct FixedRangeFactory {
    /// Total number of elements to distribute.
    num_elements: usize,
    /// Number of worker threads.
    num_threads: usize,
}

impl RangeFactory for FixedRangeFactory {
    type Rn = FixedRange;
    type Orchestrator = FixedRangeOrchestrator;

    fn new(num_elements: usize, num_threads: usize) -> Self {
        Self {
            num_elements,
            num_threads,
        }
    }

    fn orchestrator(self) -> FixedRangeOrchestrator {
        FixedRangeOrchestrator {}
    }

    fn range(&self, thread_id: usize) -> FixedRange {
        // Each thread gets a contiguous chunk; the rounding keeps all chunks
        // within one element of the same length.
        let start = (thread_id * self.num_elements) / self.num_threads;
        let end = ((thread_id + 1) * self.num_elements) / self.num_threads;
        FixedRange(start..end)
    }
}
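
// For example (derived from the arithmetic above and checked in the tests
// below), 100 elements split over 7 threads yields the chunks
// 0..14, 14..28, 28..42, 42..57, 57..71, 71..85 and 85..100.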

/// Orchestrator for [`FixedRange`]s. Fixed ranges never change, so there is
/// nothing to coordinate.
pub struct FixedRangeOrchestrator {}

impl RangeOrchestrator for FixedRangeOrchestrator {
    fn reset_ranges(&self) {
        // Fixed ranges are immutable, so there is nothing to reset.
    }
}

/// A fixed, contiguous range of indices.
#[derive(Debug, PartialEq, Eq)]
pub struct FixedRange(std::ops::Range<usize>);

impl Range for FixedRange {
    type Iter = std::ops::Range<usize>;

    fn iter(&self) -> Self::Iter {
        self.0.clone()
    }
}

/// Factory for work-stealing ranges: threads start with even chunks, but a
/// thread that runs out of work steals indices from other threads' ranges.
pub struct WorkStealingRangeFactory {
    /// Total number of elements to distribute.
    num_elements: usize,
    /// Shared array of atomically-updated ranges, one per thread.
    ranges: Arc<Vec<AtomicRange>>,
    /// Global work-stealing statistics, aggregated over all threads.
    #[cfg(feature = "log_parallelism")]
    stats: Arc<Mutex<WorkStealingStats>>,
}

impl RangeFactory for WorkStealingRangeFactory {
    type Rn = WorkStealingRange;
    type Orchestrator = WorkStealingRangeOrchestrator;

    fn new(num_elements: usize, num_threads: usize) -> Self {
        Self {
            num_elements,
            ranges: Arc::new((0..num_threads).map(|_| AtomicRange::default()).collect()),
            #[cfg(feature = "log_parallelism")]
            stats: Arc::new(Mutex::new(WorkStealingStats::default())),
        }
    }

    fn orchestrator(self) -> WorkStealingRangeOrchestrator {
        WorkStealingRangeOrchestrator {
            num_elements: self.num_elements,
            ranges: self.ranges,
            #[cfg(feature = "log_parallelism")]
            stats: self.stats,
        }
    }

    fn range(&self, thread_id: usize) -> WorkStealingRange {
        WorkStealingRange {
            id: thread_id,
            ranges: self.ranges.clone(),
            #[cfg(feature = "log_parallelism")]
            stats: self.stats.clone(),
        }
    }
}

/// Orchestrator for [`WorkStealingRange`]s.
pub struct WorkStealingRangeOrchestrator {
    /// Total number of elements to distribute.
    num_elements: usize,
    /// Shared array of atomically-updated ranges, one per thread.
    ranges: Arc<Vec<AtomicRange>>,
    /// Global work-stealing statistics, aggregated over all threads.
    #[cfg(feature = "log_parallelism")]
    stats: Arc<Mutex<WorkStealingStats>>,
}

impl RangeOrchestrator for WorkStealingRangeOrchestrator {
    fn reset_ranges(&self) {
        debug!("Resetting ranges.");
        let num_threads = self.ranges.len();
        // Re-split the elements evenly among the threads. The bounds are
        // packed as `u32`s, so the number of elements must fit in a `u32`.
        for (i, range) in self.ranges.iter().enumerate() {
            let start = (i * self.num_elements) / num_threads;
            let end = ((i + 1) * self.num_elements) / num_threads;
            range.store(PackedRange::new(start as u32, end as u32));
        }
    }

    #[cfg(feature = "log_parallelism")]
    fn print_statistics(&self) {
        let stats = self.stats.lock().unwrap();
        info!("Work-stealing statistics:");
        info!("- increments: {}", stats.increments);
        info!("- failed_increments: {}", stats.failed_increments);
        info!("- other_loads: {}", stats.other_loads);
        info!("- thefts: {}", stats.thefts);
        info!("- failed_thefts: {}", stats.failed_thefts);
        info!("- increments + thefts: {}", stats.increments + stats.thefts);
    }
}

/// A work-stealing range attached to a given worker thread.
pub struct WorkStealingRange {
    /// Index of the thread that owns this range.
    id: usize,
    /// Shared array of atomically-updated ranges, one per thread.
    ranges: Arc<Vec<AtomicRange>>,
    /// Global work-stealing statistics, aggregated over all threads.
    #[cfg(feature = "log_parallelism")]
    stats: Arc<Mutex<WorkStealingStats>>,
}

impl Range for WorkStealingRange {
    type Iter = WorkStealingRangeIterator;

    fn iter(&self) -> Self::Iter {
        WorkStealingRangeIterator {
            id: self.id,
            ranges: self.ranges.clone(),
            #[cfg(feature = "log_parallelism")]
            stats: WorkStealingStats::default(),
            #[cfg(feature = "log_parallelism")]
            global_stats: self.stats.clone(),
        }
    }
}

/// A [`PackedRange`] that can be updated atomically. The 64-byte alignment
/// puts each range on its own cache line (on typical hardware), avoiding
/// false sharing between threads.
#[repr(align(64))]
struct AtomicRange(AtomicU64);

impl Default for AtomicRange {
    #[inline(always)]
    fn default() -> Self {
        AtomicRange::new(PackedRange::default())
    }
}

impl AtomicRange {
    /// Creates an atomic range with the given initial value.
    #[inline(always)]
    fn new(range: PackedRange) -> Self {
        AtomicRange(AtomicU64::new(range.0))
    }

    /// Atomically loads the current range.
    #[inline(always)]
    fn load(&self) -> PackedRange {
        PackedRange(self.0.load(Ordering::SeqCst))
    }

    /// Atomically stores a new range.
    #[inline(always)]
    fn store(&self, range: PackedRange) {
        self.0.store(range.0, Ordering::SeqCst)
    }

    /// Atomically replaces `before` by `after`. On failure, returns the value
    /// that was actually stored.
    #[inline(always)]
    fn compare_exchange(&self, before: PackedRange, after: PackedRange) -> Result<(), PackedRange> {
        match self
            .0
            .compare_exchange(before.0, after.0, Ordering::SeqCst, Ordering::SeqCst)
        {
            Ok(_) => Ok(()),
            Err(e) => Err(PackedRange(e)),
        }
    }
}

/// A `start..end` range of `u32` indices packed into a single `u64`, with the
/// start in the low 32 bits and the end in the high 32 bits.
#[derive(Clone, Copy, Default)]
struct PackedRange(u64);

impl PackedRange {
    /// Packs the given bounds into a range.
    #[inline(always)]
    fn new(start: u32, end: u32) -> Self {
        Self((start as u64) | ((end as u64) << 32))
    }

    /// Start of the range (inclusive).
    #[inline(always)]
    fn start(self) -> u32 {
        self.0 as u32
    }

    /// End of the range (exclusive).
    #[inline(always)]
    fn end(self) -> u32 {
        (self.0 >> 32) as u32
    }

    /// Number of indices in the range.
    #[inline(always)]
    fn len(self) -> u32 {
        self.end() - self.start()
    }

    /// Takes the first index out of the range, returning that index together
    /// with the incremented range. Panics if the range is empty.
    #[inline(always)]
    fn increment_start(self) -> (u32, Self) {
        assert!(self.start() < self.end());
        (self.start(), PackedRange::new(self.start() + 1, self.end()))
    }

    /// Splits the range into two halves around its midpoint. If the length is
    /// odd, the second half gets the extra element.
    #[inline(always)]
    fn split(self) -> (Self, Self) {
        let start = self.start();
        let end = self.end();
        // Computed as `start + len / 2` rather than `(start + end) / 2` to
        // avoid overflowing `u32` for large bounds.
        let middle = start + (end - start) / 2;
        (
            PackedRange::new(start, middle),
            PackedRange::new(middle, end),
        )
    }

    /// Returns whether the range is empty.
    #[inline(always)]
    fn is_empty(self) -> bool {
        self.start() == self.end()
    }
}
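
// Example of the packing (illustrative): `PackedRange::new(3, 7)` stores
// 0x0000_0007_0000_0003, so `start() == 3`, `end() == 7`, `len() == 4`, and
// `split()` yields the ranges 3..5 and 5..7.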

/// Statistics about the work-stealing process, for debugging purposes.
#[cfg(feature = "log_parallelism")]
#[derive(Default)]
pub struct WorkStealingStats {
    /// Number of times a thread took an index from its own range.
    increments: u64,
    /// Number of times incrementing the own range failed and had to be
    /// retried, because another thread modified it concurrently.
    failed_increments: u64,
    /// Number of times another thread's range was loaded while looking for
    /// work to steal.
    other_loads: u64,
    /// Number of successful thefts from another thread's range.
    thefts: u64,
    /// Number of failed theft attempts.
    failed_thefts: u64,
}

#[cfg(feature = "log_parallelism")]
impl AddAssign<&WorkStealingStats> for WorkStealingStats {
    fn add_assign(&mut self, other: &WorkStealingStats) {
        self.increments += other.increments;
        self.failed_increments += other.failed_increments;
        self.other_loads += other.other_loads;
        self.thefts += other.thefts;
        self.failed_thefts += other.failed_thefts;
    }
}

/// Iterator over a [`WorkStealingRange`].
pub struct WorkStealingRangeIterator {
    /// Index of the thread that owns this iterator.
    id: usize,
    /// Shared array of atomically-updated ranges, one per thread.
    ranges: Arc<Vec<AtomicRange>>,
    /// Work-stealing statistics local to this iterator.
    #[cfg(feature = "log_parallelism")]
    stats: WorkStealingStats,
    /// Global statistics, into which the local ones are merged when the
    /// iterator finishes.
    #[cfg(feature = "log_parallelism")]
    global_stats: Arc<Mutex<WorkStealingStats>>,
}

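// Overview of the loop below: each call to `next()` first tries to take the
// first index of the caller's own range with a compare-exchange, retrying if
// another thread raced with it. Once the own range is empty, it loads every
// other thread's range, picks the largest one, and tries to steal the second
// half of it with a compare-exchange, again retrying against refreshed values.
// When no other thread has work left either, the iterator ends.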
impl Iterator for WorkStealingRangeIterator {
    type Item = usize;

    fn next(&mut self) -> Option<usize> {
        let my_atomic_range: &AtomicRange = &self.ranges[self.id];
        let mut my_range: PackedRange = my_atomic_range.load();
        loop {
            if !my_range.is_empty() {
                // Fast path: take the first index of our own range.
                let (taken, my_new_range) = my_range.increment_start();
                match my_atomic_range.compare_exchange(my_range, my_new_range) {
                    Ok(()) => {
                        #[cfg(feature = "log_parallelism")]
                        {
                            self.stats.increments += 1;
                            trace!(
                                "[thread {}] Incremented range to {}..{}.",
                                self.id,
                                my_new_range.start(),
                                my_new_range.end()
                            );
                        }
                        return Some(taken as usize);
                    }
                    Err(range) => {
                        // Another thread stole from our range in the meantime;
                        // retry with the refreshed value.
                        my_range = range;
                        #[cfg(feature = "log_parallelism")]
                        {
                            self.stats.failed_increments += 1;
                            debug!(
                                "[thread {}] Failed to increment range, new range is {}..{}.",
                                self.id,
                                range.start(),
                                range.end()
                            );
                        }
                        continue;
                    }
                }
            } else {
                // Our own range is exhausted: look for another thread to
                // steal from.
                #[cfg(feature = "log_parallelism")]
                debug!(
                    "[thread {}] Range {}..{} is empty, scanning other threads.",
                    self.id,
                    my_range.start(),
                    my_range.end()
                );
                let range_count = self.ranges.len();

                #[cfg(feature = "log_parallelism")]
                {
                    self.stats.other_loads += range_count as u64 - 1;
                }
                // Snapshot the other threads' ranges.
                let mut other_ranges = vec![PackedRange::default(); range_count];
                for (i, range) in other_ranges.iter_mut().enumerate() {
                    if i == self.id {
                        continue;
                    }
                    *range = self.ranges[i].load();
                }

                // Pick the largest range as the theft target.
                let mut max_index = 0;
                let mut max_range = PackedRange::default();
                for (i, range) in other_ranges.iter().enumerate() {
                    if i == self.id {
                        continue;
                    }
                    if range.len() > max_range.len() {
                        max_index = i;
                        max_range = *range;
                    }
                }

                while !max_range.is_empty() {
                    // Try to steal the second half of the target range.
                    let (remaining, stolen) = max_range.split();
                    match self.ranges[max_index].compare_exchange(max_range, remaining) {
                        Ok(()) => {
                            // Theft succeeded: take the first stolen index and
                            // make the rest our new range.
                            let (taken, my_new_range) = stolen.increment_start();
                            my_atomic_range.store(my_new_range);
                            #[cfg(feature = "log_parallelism")]
                            {
                                self.stats.thefts += 1;
                            }
                            return Some(taken as usize);
                        }
                        Err(range) => {
                            // The target changed concurrently: refresh our
                            // snapshot and pick the largest range again.
                            other_ranges[max_index] = range;
                            #[cfg(feature = "log_parallelism")]
                            {
                                self.stats.failed_thefts += 1;
                            }
                            max_range = range;
                            for (i, range) in other_ranges.iter().enumerate() {
                                if i == self.id {
                                    continue;
                                }
                                if range.len() > max_range.len() {
                                    max_index = i;
                                    max_range = *range;
                                }
                            }
                        }
                    }
                }

                // Nothing left to steal anywhere: the iteration is over.
                #[cfg(feature = "log_parallelism")]
                {
                    debug!("[thread {}] Didn't find anything to steal", self.id);
                    *self.global_stats.lock().unwrap() += &self.stats;
                }
                return None;
            }
        }
    }
}


#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_fixed_range_factory_splits_evenly() {
        let factory = FixedRangeFactory::new(100, 4);
        assert_eq!(factory.range(0), FixedRange(0..25));
        assert_eq!(factory.range(1), FixedRange(25..50));
        assert_eq!(factory.range(2), FixedRange(50..75));
        assert_eq!(factory.range(3), FixedRange(75..100));

        let factory = FixedRangeFactory::new(100, 7);
        assert_eq!(factory.range(0), FixedRange(0..14));
        assert_eq!(factory.range(1), FixedRange(14..28));
        assert_eq!(factory.range(2), FixedRange(28..42));
        assert_eq!(factory.range(3), FixedRange(42..57));
        assert_eq!(factory.range(4), FixedRange(57..71));
        assert_eq!(factory.range(5), FixedRange(71..85));
        assert_eq!(factory.range(6), FixedRange(85..100));
    }

    #[test]
    fn test_fixed_range() {
        let factory = FixedRangeFactory::new(100, 4);
        let ranges: [_; 4] = std::array::from_fn(|i| factory.range(i));
        let orchestrator = factory.orchestrator();

        std::thread::scope(|s| {
            for _ in 0..10 {
                orchestrator.reset_ranges();
                let handles = ranges
                    .each_ref()
                    .map(|range| s.spawn(move || range.iter().collect::<Vec<_>>()));
                let values: [Vec<usize>; 4] = handles.map(|handle| handle.join().unwrap());

                for (i, set) in values.iter().enumerate() {
                    assert_eq!(*set, (i * 25..(i + 1) * 25).collect::<Vec<_>>());
                }
            }
        });
    }

    #[test]
    fn test_work_stealing_range() {
        const NUM_THREADS: usize = 4;
        const NUM_ELEMENTS: usize = 10000;

        let factory = WorkStealingRangeFactory::new(NUM_ELEMENTS, NUM_THREADS);
        let ranges: [_; NUM_THREADS] = std::array::from_fn(|i| factory.range(i));
        let orchestrator = factory.orchestrator();

        std::thread::scope(|s| {
            for _ in 0..10 {
                orchestrator.reset_ranges();
                let handles = ranges
                    .each_ref()
                    .map(|range| s.spawn(move || range.iter().collect::<Vec<_>>()));
                let values: [Vec<usize>; NUM_THREADS] =
                    handles.map(|handle| handle.join().unwrap());

                // Each index must be yielded exactly once across all threads.
                let mut all_values = vec![false; NUM_ELEMENTS];
                for set in values {
                    println!("Values: {set:?}");
                    for x in set {
                        assert!(!all_values[x]);
                        all_values[x] = true;
                    }
                }
                assert!(all_values.iter().all(|x| *x));
            }
        });
    }

    #[test]
    fn test_default_packed_range_is_empty() {
        let range = PackedRange::default();
        assert!(range.is_empty());
        assert_eq!(range.start(), 0);
        assert_eq!(range.end(), 0);
    }

    #[test]
    fn test_packed_range_is_consistent() {
        for i in 0..30 {
            for j in i..30 {
                let range = PackedRange::new(i, j);
                assert_eq!(range.start(), i);
                assert_eq!(range.end(), j);
            }
        }
    }

    #[test]
    fn test_packed_range_increment_start() {
        let mut range = PackedRange::new(0, 10);

        for i in 1..=10 {
            let (j, new_range) = range.increment_start();
            range = new_range;
            assert_eq!(j, i - 1);
            assert_eq!((range.start(), range.end()), (i, 10));
        }
    }

    #[test]
    fn test_packed_range_split() {
        let (left, right) = PackedRange::new(0, 0).split();
        assert!(left.is_empty());
        assert_eq!((left.start(), left.end()), (0, 0));
        assert!(right.is_empty());
        assert_eq!((right.start(), right.end()), (0, 0));

        let (left, right) = PackedRange::new(0, 1).split();
        assert!(left.is_empty());
        assert_eq!((left.start(), left.end()), (0, 0));
        assert!(!right.is_empty());
        assert_eq!((right.start(), right.end()), (0, 1));
    }

    #[test]
    fn test_packed_range_split_is_exhaustive() {
        for i in 0..100 {
            for j in i..100 {
                let (left, right) = PackedRange::new(i, j).split();
                assert!(left.start() <= left.end());
                assert!(right.start() <= right.end());
                assert_eq!(left.start(), i);
                assert_eq!(left.end(), right.start());
                assert_eq!(right.end(), j);
            }
        }
    }

    #[test]
    fn test_packed_range_split_is_fair() {
        for i in 0..100 {
            for j in i..100 {
                let (left, right) = PackedRange::new(i, j).split();
                assert!(left.end() - left.start() <= right.end() - right.start());
                assert!(right.end() - right.start() <= left.end() - left.start() + 1);
                if i != j {
                    assert!(!right.is_empty());
                }
            }
        }
    }
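
    // Illustrative extra test (not part of the original suite): with a single
    // thread there is nobody to steal from, so a work-stealing range simply
    // drains its indices in order and then terminates.
    #[test]
    fn test_work_stealing_range_single_thread() {
        let factory = WorkStealingRangeFactory::new(10, 1);
        let range = factory.range(0);
        let orchestrator = factory.orchestrator();
        orchestrator.reset_ranges();
        assert_eq!(
            range.iter().collect::<Vec<usize>>(),
            (0..10).collect::<Vec<usize>>()
        );
    }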
}