use crate::runtime::scheduler::local_queue::Stealer;
use crate::types::TaskId;
use crate::util::DetRng;
#[inline]
pub fn steal_task(stealers: &[Stealer], rng: &mut DetRng) -> Option<TaskId> {
let len = stealers.len();
if len == 0 {
return None;
}
if len == 1 {
return stealers[0].steal();
}
let idx1 = rng.next_usize(len);
let mut idx2 = rng.next_usize(len);
if idx1 == idx2 {
idx2 = (idx1 + 1) % len;
}
let len1 = stealers[idx1].stealable_len_hint();
let len2 = stealers[idx2].stealable_len_hint();
let (primary, primary_hint, secondary, secondary_hint) = if len1 >= len2 {
(idx1, len1, idx2, len2)
} else {
(idx2, len2, idx1, len1)
};
if primary_hint > 0
&& let Some(task) = stealers[primary].steal()
{
return Some(task);
}
if secondary_hint > 0
&& let Some(task) = stealers[secondary].steal()
{
return Some(task);
}
let start = rng.next_usize(len);
for i in 0..len {
let idx = circular_index(start, i, len);
if idx == primary || idx == secondary {
continue; }
if let Some(task) = stealers[idx].steal() {
return Some(task);
}
}
None
}
#[inline]
fn circular_index(start: usize, offset: usize, len: usize) -> usize {
debug_assert!(len > 0);
(start + offset) % len
}
#[cfg(test)]
mod tests {
#![allow(
clippy::pedantic,
clippy::nursery,
clippy::expect_fun_call,
clippy::map_unwrap_or,
clippy::cast_possible_wrap,
clippy::future_not_send
)]
use super::*;
use crate::runtime::scheduler::local_queue::LocalQueue;
use std::collections::HashSet;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Barrier};
use std::thread;
fn task(id: u32) -> TaskId {
TaskId::new_for_test(id, 0)
}
#[test]
fn test_steal_from_busy_worker_succeeds() {
let queue = LocalQueue::new_for_test(9);
for i in 0..10 {
queue.push(task(i));
}
let stealers = vec![queue.stealer()];
let mut rng = DetRng::new(42);
let stolen = steal_task(&stealers, &mut rng);
assert!(stolen.is_some(), "should steal from busy queue");
}
#[test]
fn test_steal_from_empty_returns_none() {
let queue = LocalQueue::new_for_test(0);
let stealers = vec![queue.stealer()];
let mut rng = DetRng::new(42);
let stolen = steal_task(&stealers, &mut rng);
assert!(stolen.is_none(), "empty queue should return None");
}
#[test]
fn test_steal_empty_stealers_list() {
let stealers: Vec<Stealer> = vec![];
let mut rng = DetRng::new(42);
let stolen = steal_task(&stealers, &mut rng);
assert!(stolen.is_none(), "empty stealers list should return None");
}
#[test]
fn test_steal_skips_empty_queues() {
let q1 = LocalQueue::new_for_test(0);
let q2 = LocalQueue::new_for_test(0);
let q3 = LocalQueue::new_for_test(99);
q3.push(task(99));
let stealers = vec![q1.stealer(), q2.stealer(), q3.stealer()];
let mut found = false;
for seed in 0..10 {
let mut rng = DetRng::new(seed);
let stolen = steal_task(&stealers, &mut rng);
if let Some(t) = stolen {
assert_eq!(t, task(99));
found = true;
break;
}
}
assert!(
found,
"should have found task in q3 with at least one deterministic seed in [0, 10)"
);
}
#[test]
fn test_steal_visits_all_queues() {
let queues: Vec<_> = (0..5).map(|_| LocalQueue::new_for_test(4)).collect();
for (i, q) in queues.iter().enumerate() {
q.push(task(i as u32));
}
let stealers: Vec<_> = queues.iter().map(LocalQueue::stealer).collect();
let mut seen = HashSet::new();
let mut rng = DetRng::new(0);
for _ in 0..10 {
if let Some(t) = steal_task(&stealers, &mut rng) {
seen.insert(t);
}
}
assert_eq!(seen.len(), 5, "should visit all queues");
}
#[test]
fn test_steal_contention_no_deadlock() {
let queue = Arc::new(LocalQueue::new_for_test(99));
for i in 0..100 {
queue.push(task(i));
}
let stealer = queue.stealer();
let stolen_count = Arc::new(AtomicUsize::new(0));
let barrier = Arc::new(Barrier::new(5));
let handles: Vec<_> = (0_u64..5)
.map(|i| {
let s = stealer.clone();
let count = stolen_count.clone();
let b = barrier.clone();
thread::spawn(move || {
let stealers = vec![s];
let mut rng = DetRng::new(i);
b.wait();
let mut local_count = 0;
while steal_task(&stealers, &mut rng).is_some() {
local_count += 1;
thread::yield_now();
}
count.fetch_add(local_count, Ordering::SeqCst);
})
})
.collect();
for h in handles {
h.join().expect("thread should complete without deadlock");
}
assert_eq!(
stolen_count.load(Ordering::SeqCst),
100,
"all tasks should be stolen exactly once"
);
}
#[test]
fn test_steal_task_concurrent_owner_pop_on_resized_queue_preserves_task_set() {
let queues: Vec<_> = (0..3)
.map(|_| Arc::new(LocalQueue::new_for_test(383)))
.collect();
let per_queue = 128usize;
let total = per_queue * queues.len();
for (queue_idx, queue) in queues.iter().enumerate() {
let base = queue_idx * per_queue;
for task_id in base..base + per_queue {
queue.push(task(task_id as u32));
}
}
let counts: Arc<Vec<AtomicUsize>> =
Arc::new((0..total).map(|_| AtomicUsize::new(0)).collect());
let observed_total = Arc::new(AtomicUsize::new(0));
let stealer_threads = 4;
let barrier = Arc::new(Barrier::new(stealer_threads + 2));
let owner_queue = Arc::clone(&queues[0]);
let owner_counts = Arc::clone(&counts);
let owner_observed_total = Arc::clone(&observed_total);
let owner_barrier = Arc::clone(&barrier);
let owner = thread::spawn(move || {
owner_barrier.wait();
while let Some(task_id) = owner_queue.pop() {
let idx = task_id.0.index() as usize;
owner_counts[idx].fetch_add(1, Ordering::SeqCst);
owner_observed_total.fetch_add(1, Ordering::SeqCst);
thread::yield_now();
}
});
let mut stealer_handles = Vec::new();
for seed in 0_u64..stealer_threads as u64 {
let stealers: Vec<_> = queues.iter().map(|queue| queue.stealer()).collect();
let counts = Arc::clone(&counts);
let observed_total = Arc::clone(&observed_total);
let barrier = Arc::clone(&barrier);
stealer_handles.push(thread::spawn(move || {
let mut rng = DetRng::new(seed + 1);
barrier.wait();
while observed_total.load(Ordering::SeqCst) < total {
if let Some(task_id) = steal_task(&stealers, &mut rng) {
let idx = task_id.0.index() as usize;
counts[idx].fetch_add(1, Ordering::SeqCst);
observed_total.fetch_add(1, Ordering::SeqCst);
} else {
thread::yield_now();
}
}
}));
}
barrier.wait();
owner
.join()
.expect("owner thread should complete without losing tasks");
for handle in stealer_handles {
handle
.join()
.expect("stealer thread should complete without deadlock");
}
let mut total_seen = 0usize;
for (idx, count) in counts.iter().enumerate() {
let seen = count.load(Ordering::SeqCst);
assert_eq!(seen, 1, "task {idx} seen {seen} times");
total_seen += seen;
}
assert_eq!(
total_seen, total,
"owner pops and steal_task victim selection must preserve the exact task set"
);
}
#[test]
fn test_steal_deterministic_with_same_seed() {
let q1a = LocalQueue::new_for_test(3);
let q2a = LocalQueue::new_for_test(3);
let q3a = LocalQueue::new_for_test(3);
q1a.push(task(1));
q2a.push(task(2));
q3a.push(task(3));
let stealers_a = vec![q1a.stealer(), q2a.stealer(), q3a.stealer()];
let q1b = LocalQueue::new_for_test(3);
let q2b = LocalQueue::new_for_test(3);
let q3b = LocalQueue::new_for_test(3);
q1b.push(task(1));
q2b.push(task(2));
q3b.push(task(3));
let stealers_b = vec![q1b.stealer(), q2b.stealer(), q3b.stealer()];
let mut rng1 = DetRng::new(12345);
let mut rng2 = DetRng::new(12345);
let result1 = steal_task(&stealers_a, &mut rng1);
let result2 = steal_task(&stealers_b, &mut rng2);
assert_eq!(result1, result2, "same seed should give same steal target");
}
#[test]
fn test_power_of_two_prefers_heavier_queue() {
let heavy = LocalQueue::new_for_test(20);
let light = LocalQueue::new_for_test(20);
heavy.push(task(10));
heavy.push(task(11));
light.push(task(19));
let stealers = vec![heavy.stealer(), light.stealer()];
let mut rng = DetRng::new(7);
let stolen = steal_task(&stealers, &mut rng);
assert!(
matches!(stolen, Some(t) if t == task(10) || t == task(11)),
"power-of-two choice should prefer the heavier queue"
);
}
#[test]
fn test_power_of_two_falls_back_when_primary_is_local_only() {
let state = LocalQueue::test_state(10);
let local_only = LocalQueue::new(Arc::clone(&state));
let remote = LocalQueue::new(Arc::clone(&state));
{
let mut guard = state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let first = guard.task_mut(task(0)).expect("task record missing");
first.mark_local();
let second = guard.task_mut(task(1)).expect("task record missing");
second.mark_local();
drop(guard);
}
local_only.push(task(0));
local_only.push(task(1));
remote.push(task(2));
let stealers = vec![local_only.stealer(), remote.stealer()];
let mut rng = DetRng::new(99);
let stolen = steal_task(&stealers, &mut rng);
assert_eq!(
stolen,
Some(task(2)),
"steal should fall back to the secondary queue when primary has only local tasks"
);
}
#[test]
fn test_power_of_two_prefers_stealable_hint_over_total_queue_len() {
let state = LocalQueue::test_state(20);
let local_heavy = LocalQueue::new(Arc::clone(&state));
let remote_light = LocalQueue::new(Arc::clone(&state));
{
let mut guard = state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
for id in 0..8 {
let record = guard.task_mut(task(id)).expect("task record missing");
record.mark_local();
}
}
for id in 0..8 {
local_heavy.push(task(id));
}
local_heavy.push(task(9));
remote_light.push(task(10));
let stealers = vec![local_heavy.stealer(), remote_light.stealer()];
let mut rng = DetRng::new(7);
let stolen = steal_task(&stealers, &mut rng);
assert_eq!(
stolen,
Some(task(10)),
"victim ranking should prefer actually stealable work over a larger local-only backlog"
);
}
#[test]
fn test_circular_index_math_correct() {
let len = 5;
let start = 3;
let offset = 4;
let idx = circular_index(start, offset, len);
assert_eq!(idx, 2); }
}