use crate::record::task::TaskRecord;
use crate::runtime::{RuntimeState, TaskTable};
use crate::sync::ContendedMutex;
use crate::types::TaskId;
#[cfg(any(test, feature = "test-internals"))]
use crate::types::{Budget, RegionId};
use crate::util::Arena;
use parking_lot::Mutex;
use std::cell::RefCell;
use std::collections::VecDeque;
use std::sync::Arc;
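// The queue owned by the current worker thread, if any. Installed via
// `LocalQueue::set_current` and consulted by `LocalQueue::schedule_local`.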
thread_local! {
static CURRENT_QUEUE: RefCell<Option<LocalQueue>> = const { RefCell::new(None) };
}
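/// Where a queue resolves task records: either the full [`RuntimeState`] or a
/// standalone [`TaskTable`]. Both variants expose the same task arena.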
#[derive(Debug, Clone)]
enum TaskSource {
RuntimeState(Arc<ContendedMutex<RuntimeState>>),
TaskTable(Arc<ContendedMutex<TaskTable>>),
}
impl TaskSource {
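    /// Runs `f` with exclusive access to the task arena, recovering the guard
    /// from a poisoned lock instead of panicking.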
#[inline]
fn with_tasks_arena_mut<R, F>(&self, f: F) -> R
where
F: FnOnce(&mut Arena<TaskRecord>) -> R,
{
match self {
Self::RuntimeState(state) => {
let mut state = state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
f(state.tasks_arena_mut())
}
Self::TaskTable(tasks) => {
let mut tasks = tasks
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
f(tasks.tasks_arena_mut())
}
}
}
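    /// Returns `true` when both sources point at the same allocation (and
    /// therefore the same task arena); sources of different kinds never match.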
#[inline]
fn same_underlying_tasks(&self, other: &Self) -> bool {
match (self, other) {
(Self::RuntimeState(lhs), Self::RuntimeState(rhs)) => Arc::ptr_eq(lhs, rhs),
(Self::TaskTable(lhs), Self::TaskTable(rhs)) => Arc::ptr_eq(lhs, rhs),
_ => false,
}
}
}
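/// A single worker's run queue.
///
/// The owning worker pushes and pops at the back, so [`LocalQueue::pop`] is
/// LIFO and favors recently touched tasks; [`Stealer`]s created via
/// [`LocalQueue::stealer`] take from the front in FIFO order and skip tasks
/// whose records are marked local.
///
/// A minimal usage sketch (assumes the `test-internals` feature, which
/// provides the `new_for_test` constructors used here):
///
/// ```ignore
/// let queue = LocalQueue::new_for_test(1);
/// queue.push(TaskId::new_for_test(0, 0));
/// queue.push(TaskId::new_for_test(1, 0));
/// let stealer = queue.stealer();
/// assert_eq!(stealer.steal(), Some(TaskId::new_for_test(0, 0))); // FIFO for thieves
/// assert_eq!(queue.pop(), Some(TaskId::new_for_test(1, 0)));     // LIFO for the owner
/// ```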
#[derive(Debug, Clone)]
pub struct LocalQueue {
tasks: TaskSource,
inner: Arc<Mutex<VecDeque<TaskId>>>,
}
impl LocalQueue {
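    /// Creates a queue that resolves task records through the shared
    /// [`RuntimeState`].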
#[inline]
#[must_use]
pub fn new(state: Arc<ContendedMutex<RuntimeState>>) -> Self {
Self::new_with_source(TaskSource::RuntimeState(state))
}
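    /// Creates a queue that resolves task records through a standalone
    /// [`TaskTable`].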
#[inline]
#[must_use]
pub fn new_with_task_table(tasks: Arc<ContendedMutex<TaskTable>>) -> Self {
Self::new_with_source(TaskSource::TaskTable(tasks))
}
#[inline]
fn new_with_source(tasks: TaskSource) -> Self {
Self {
tasks,
inner: Arc::new(Mutex::new(VecDeque::with_capacity(256))),
}
}
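    /// Installs `queue` as the current thread's queue, returning a guard that
    /// restores the previous value when dropped.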
pub(crate) fn set_current(queue: Self) -> CurrentQueueGuard {
let prev = CURRENT_QUEUE.with(|slot| slot.replace(Some(queue)));
CurrentQueueGuard { prev }
}
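    /// Removes the current thread's queue, if one is installed.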
    #[allow(dead_code)]
    pub(crate) fn clear_current() {
CURRENT_QUEUE.with(|slot| {
slot.borrow_mut().take();
});
}
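    /// Pushes `task` onto the current thread's queue, if one is installed.
    ///
    /// Returns `false` when no queue is installed or the task's record no
    /// longer exists; a task that is already queued is not enqueued twice but
    /// still reports success.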
#[inline]
pub(crate) fn schedule_local(task: TaskId) -> bool {
CURRENT_QUEUE.with(|slot| {
slot.borrow()
.as_ref()
.is_some_and(|queue| queue.schedule_local_push(task))
})
}
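    /// Builds a [`RuntimeState`] pre-populated with task records for IDs
    /// `0..=max_task_id` (test-only helper).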
#[cfg(any(test, feature = "test-internals"))]
#[must_use]
pub fn test_state(max_task_id: u32) -> Arc<ContendedMutex<RuntimeState>> {
let mut state = RuntimeState::new();
for id in 0..=max_task_id {
let task_id = TaskId::new_for_test(id, 0);
let record = TaskRecord::new(task_id, RegionId::new_for_test(0, 0), Budget::INFINITE);
let idx = state.insert_task(record);
debug_assert_eq!(idx.index(), id);
}
Arc::new(ContendedMutex::new("runtime_state", state))
}
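    /// Builds a [`TaskTable`] pre-populated with task records for IDs
    /// `0..=max_task_id` (test-only helper).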
#[cfg(any(test, feature = "test-internals"))]
#[must_use]
pub fn test_task_table(max_task_id: u32) -> Arc<ContendedMutex<TaskTable>> {
let mut tasks = TaskTable::new();
for id in 0..=max_task_id {
let task_id = TaskId::new_for_test(id, 0);
let record = TaskRecord::new(task_id, RegionId::new_for_test(0, 0), Budget::INFINITE);
let idx = tasks.insert_task(record);
debug_assert_eq!(idx.index(), id);
}
Arc::new(ContendedMutex::new("task_table", tasks))
}
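    /// Builds a queue over a fresh test state; see [`Self::test_state`].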
#[cfg(any(test, feature = "test-internals"))]
#[must_use]
pub fn new_for_test(max_task_id: u32) -> Self {
Self::new(Self::test_state(max_task_id))
}
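    /// Enqueues `task` at the back of the queue.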
#[inline]
pub fn push(&self, task: TaskId) {
let mut queue = self.inner.lock();
queue.push_back(task);
}
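    /// Enqueues `task` if its record still exists, skipping the push when the
    /// task is already queued; returns `false` for missing records. Takes the
    /// arena lock before the queue lock, matching the order used by [`Stealer`].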
#[inline]
fn schedule_local_push(&self, task: TaskId) -> bool {
self.tasks.with_tasks_arena_mut(|arena| {
if arena.get(task.arena_index()).is_none() {
return false;
}
let mut queue = self.inner.lock();
if !queue.contains(&task) {
queue.push_back(task);
}
true
})
}
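    /// Enqueues all `tasks` in order under a single lock acquisition.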
#[inline]
pub fn push_many(&self, tasks: &[TaskId]) {
if tasks.is_empty() {
return;
}
let mut queue = self.inner.lock();
for &task in tasks {
queue.push_back(task);
}
}
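    /// Owner-side pop from the back of the queue (LIFO order).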
#[inline]
#[must_use]
pub fn pop(&self) -> Option<TaskId> {
let mut queue = self.inner.lock();
queue.pop_back()
}
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
        let queue = self.inner.lock();
        queue.is_empty()
}
#[inline]
#[must_use]
pub fn len(&self) -> usize {
        let queue = self.inner.lock();
        queue.len()
}
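    /// Returns the queued task IDs in front-to-back order without draining.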
#[inline]
#[must_use]
pub fn snapshot_tasks(&self) -> Vec<TaskId> {
let queue = self.inner.lock();
queue.iter().copied().collect()
}
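    /// Creates a [`Stealer`] that shares this queue's storage and task source.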
#[inline]
#[must_use]
pub fn stealer(&self) -> Stealer {
Stealer {
tasks: self.tasks.clone(),
inner: Arc::clone(&self.inner),
}
}
}
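/// Guard returned by [`LocalQueue::set_current`]; restores the previously
/// installed thread-local queue (if any) when dropped.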
pub(crate) struct CurrentQueueGuard {
prev: Option<LocalQueue>,
}
impl Drop for CurrentQueueGuard {
fn drop(&mut self) {
let prev = self.prev.take();
let _ = CURRENT_QUEUE.try_with(|slot| {
*slot.borrow_mut() = prev;
});
}
}
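/// A cloneable handle for taking tasks from the front of another worker's
/// [`LocalQueue`] in FIFO order. Steals skip tasks whose records are marked
/// local and discard stale entries whose records no longer exist.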
#[derive(Debug, Clone)]
pub struct Stealer {
tasks: TaskSource,
inner: Arc<Mutex<VecDeque<TaskId>>>,
}
impl Stealer {
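    /// How many local (unstealable) tasks a steal probe skips over before
    /// giving up on the queue.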
const SKIPPED_LOCALS_INLINE_CAP: usize = 8;
#[inline]
#[must_use]
pub fn len(&self) -> usize {
self.inner.lock().len()
}
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
self.inner.lock().is_empty()
}
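    /// Moves stealable tasks from the front of `src` to the back of `dest`
    /// while both queue locks are held.
    ///
    /// Takes up to half of `src` (clamped to 1..=256), skipping tasks marked
    /// local and discarding stale entries, and gives up after skipping
    /// `SKIPPED_LOCALS_INLINE_CAP` local tasks; returns whether anything moved.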
#[inline]
fn steal_batch_locked(
src: &mut VecDeque<TaskId>,
dest: &mut VecDeque<TaskId>,
arena: &Arena<TaskRecord>,
) -> bool {
let initial_len = src.len();
if initial_len == 0 {
return false;
}
let steal_limit = (initial_len / 2).clamp(1, 256);
let mut stolen = 0;
let mut i = 0;
        while i < src.len() && stolen < steal_limit && i < Self::SKIPPED_LOCALS_INLINE_CAP {
            let task_id = src[i];
            if let Some(record) = arena.get(task_id.arena_index()) {
                if !record.is_local() {
                    let task = src
                        .remove(i)
                        .expect("index is in bounds, so remove must succeed");
                    dest.push_back(task);
                    stolen += 1;
                    continue;
                }
            } else {
                // Stale entry: the record is gone, so drop the ID and retry
                // the same index.
                src.remove(i);
                continue;
            }
            i += 1;
        }
stolen > 0
}
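    /// Steals one task from the front of the queue (FIFO order).
    ///
    /// Skips tasks marked local and discards stale entries along the way;
    /// gives up after probing past `SKIPPED_LOCALS_INLINE_CAP` local tasks.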
#[inline]
#[must_use]
#[allow(clippy::significant_drop_tightening)]
    pub fn steal(&self) -> Option<TaskId> {
        self.tasks.with_tasks_arena_mut(|arena| {
            let mut queue = self.inner.lock();
            let mut i = 0;
            while i < queue.len() && i < Self::SKIPPED_LOCALS_INLINE_CAP {
                let task_id = queue[i];
                if let Some(record) = arena.get(task_id.arena_index()) {
                    if !record.is_local() {
                        let removed = queue.remove(i);
                        drop(queue);
                        return removed;
                    }
                } else {
                    // Stale entry: the record is gone, so drop the ID and
                    // retry the same index.
                    queue.remove(i);
                    continue;
                }
                i += 1;
            }
            drop(queue);
            None
        })
}
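    /// Steals up to half of the source queue's entries (clamped to 1..=256)
    /// into `dest`, skipping tasks marked local.
    ///
    /// Returns `false` without mutating either queue when `dest` shares the
    /// same storage, when the two queues are backed by different task sources,
    /// or when nothing was stealable. The two queue locks are taken in address
    /// order so opposite-direction batch steals cannot deadlock.
    ///
    /// A minimal sketch (assumes the `test-internals` feature for the test
    /// constructors):
    ///
    /// ```ignore
    /// let state = LocalQueue::test_state(7);
    /// let src = LocalQueue::new(Arc::clone(&state));
    /// let dest = LocalQueue::new(Arc::clone(&state));
    /// for id in 0..8 {
    ///     src.push(TaskId::new_for_test(id, 0));
    /// }
    /// assert!(src.stealer().steal_batch(&dest)); // moves roughly half of src
    /// ```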
#[inline]
#[must_use]
#[allow(clippy::significant_drop_tightening)]
pub fn steal_batch(&self, dest: &LocalQueue) -> bool {
if Arc::ptr_eq(&self.inner, &dest.inner) {
return false;
}
        if !self.tasks.same_underlying_tasks(&dest.tasks) {
            return false;
        }
        self.tasks.with_tasks_arena_mut(|arena| {
            // Lock the two queues in a stable (address) order so that two
            // concurrent steal_batch calls in opposite directions cannot
            // deadlock on each other's locks.
            let src_addr = Arc::as_ptr(&self.inner) as usize;
            let dest_addr = Arc::as_ptr(&dest.inner) as usize;
            if src_addr < dest_addr {
                let mut src = self.inner.lock();
                let mut dest_queue = dest.inner.lock();
                Self::steal_batch_locked(&mut src, &mut dest_queue, arena)
            } else {
                let mut dest_queue = dest.inner.lock();
                let mut src = self.inner.lock();
                Self::steal_batch_locked(&mut src, &mut dest_queue, arena)
            }
        })
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::types::TaskId;
use proptest::prelude::*;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Barrier};
use std::thread;
fn task(id: u32) -> TaskId {
TaskId::new_for_test(id, 0)
}
fn queue(max_task_id: u32) -> LocalQueue {
LocalQueue::new_for_test(max_task_id)
}
fn queue_with_task_table(max_task_id: u32) -> LocalQueue {
let tasks = LocalQueue::test_task_table(max_task_id);
LocalQueue::new_with_task_table(tasks)
}
fn run_repeated_steal_batch_schedule(layout: &[(u32, bool)]) -> (Vec<Vec<u32>>, Vec<u32>) {
let max_task_id = layout.iter().map(|(id, _)| *id).max().unwrap_or(0);
let state = LocalQueue::test_state(max_task_id);
let src = LocalQueue::new(Arc::clone(&state));
{
let mut guard = state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
for &(id, is_local) in layout {
if is_local {
let record = guard.task_mut(task(id)).expect("task record missing");
record.mark_local();
}
}
}
for &(id, _) in layout {
src.push(task(id));
}
let mut steal_rounds = Vec::new();
loop {
let dest = LocalQueue::new(Arc::clone(&state));
if !src.stealer().steal_batch(&dest) {
break;
}
steal_rounds.push(
dest.snapshot_tasks()
.into_iter()
.map(|task_id| task_id.0.index())
.collect(),
);
}
let mut owner_remaining = Vec::new();
while let Some(task_id) = src.pop() {
owner_remaining.push(task_id.0.index());
}
(steal_rounds, owner_remaining)
}
fn run_repeated_steal_batch_from_push_chunks(chunks: &[Vec<u32>]) -> (Vec<Vec<u32>>, Vec<u32>) {
let max_task_id = chunks
.iter()
.flat_map(|chunk| chunk.iter().copied())
.max()
.unwrap_or(0);
let state = LocalQueue::test_state(max_task_id);
let src = LocalQueue::new(Arc::clone(&state));
for chunk in chunks {
let task_ids: Vec<_> = chunk.iter().copied().map(task).collect();
src.push_many(&task_ids);
}
let mut steal_rounds = Vec::new();
loop {
let dest = LocalQueue::new(Arc::clone(&state));
if !src.stealer().steal_batch(&dest) {
break;
}
steal_rounds.push(
dest.snapshot_tasks()
.into_iter()
.map(|task_id| task_id.0.index())
.collect(),
);
}
let mut owner_remaining = Vec::new();
while let Some(task_id) = src.pop() {
owner_remaining.push(task_id.0.index());
}
(steal_rounds, owner_remaining)
}
fn normalize_task_ids(task_ids: Vec<u32>, layout: &[(u32, bool)]) -> Vec<usize> {
let order: HashMap<u32, usize> = layout
.iter()
.enumerate()
.map(|(idx, (task_id, _))| (*task_id, idx))
.collect();
task_ids
.into_iter()
.map(|task_id| order[&task_id])
.collect()
}
fn normalize_rounds(rounds: Vec<Vec<u32>>, layout: &[(u32, bool)]) -> Vec<Vec<usize>> {
rounds
.into_iter()
.map(|round| normalize_task_ids(round, layout))
.collect()
}
fn drain_owner(queue: &LocalQueue) -> Vec<TaskId> {
let mut drained = Vec::new();
while let Some(task_id) = queue.pop() {
drained.push(task_id);
}
drained
}
fn drain_thief(queue: &LocalQueue) -> Vec<TaskId> {
let stealer = queue.stealer();
let mut drained = Vec::new();
while let Some(task_id) = stealer.steal() {
drained.push(task_id);
}
drained
}
#[test]
fn owner_pop_is_lifo() {
let queue = queue(3);
queue.push(task(1));
queue.push(task(2));
queue.push(task(3));
assert_eq!(queue.pop(), Some(task(3)));
assert_eq!(queue.pop(), Some(task(2)));
assert_eq!(queue.pop(), Some(task(1)));
assert_eq!(queue.pop(), None);
}
#[test]
fn thief_steal_is_fifo() {
let queue = queue(3);
queue.push(task(1));
queue.push(task(2));
queue.push(task(3));
let stealer = queue.stealer();
assert_eq!(stealer.steal(), Some(task(1)));
assert_eq!(stealer.steal(), Some(task(2)));
assert_eq!(stealer.steal(), Some(task(3)));
assert_eq!(stealer.steal(), None);
}
#[test]
fn steal_skips_local_tasks() {
let state = LocalQueue::test_state(1);
let queue = LocalQueue::new(Arc::clone(&state));
{
let mut guard = state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let record = guard.task_mut(task(1)).expect("task record missing");
record.mark_local();
drop(guard);
}
queue.push(task(1));
let stealer = queue.stealer();
assert_eq!(stealer.steal(), None, "local task must not be stolen");
assert_eq!(queue.pop(), Some(task(1)), "local task remains queued");
assert_eq!(queue.pop(), None);
}
#[test]
fn failed_steal_probe_preserves_owner_local_order() {
let state = LocalQueue::test_state(3);
let queue = LocalQueue::new(Arc::clone(&state));
{
let mut guard = state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
for id in [1_u32, 2_u32, 3_u32] {
let record = guard.task_mut(task(id)).expect("task record missing");
record.mark_local();
}
drop(guard);
}
queue.push(task(1));
queue.push(task(2));
queue.push(task(3));
let stealer = queue.stealer();
assert_eq!(
stealer.steal(),
None,
"all-local queue should not be stealable"
);
assert_eq!(stealer.steal(), None, "repeated probes must be idempotent");
assert_eq!(queue.pop(), Some(task(3)));
assert_eq!(queue.pop(), Some(task(2)));
assert_eq!(queue.pop(), Some(task(1)));
assert_eq!(queue.pop(), None);
}
#[test]
fn steal_skips_local_tail_and_finds_remote() {
let state = LocalQueue::test_state(1);
let queue = LocalQueue::new(Arc::clone(&state));
{
let mut guard = state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let record = guard.task_mut(task(0)).expect("task record missing");
record.mark_local();
drop(guard);
}
queue.push(task(0));
queue.push(task(1));
let stealer = queue.stealer();
assert_eq!(
stealer.steal(),
Some(task(1)),
"stealer should skip local tail and still find remote task"
);
assert_eq!(queue.pop(), Some(task(0)), "local task remains queued");
assert_eq!(queue.pop(), None);
}
#[test]
fn steal_batch_moves_tasks_without_loss_or_dup() {
let state = LocalQueue::test_state(7);
let src = LocalQueue::new(Arc::clone(&state));
let dest = LocalQueue::new(Arc::clone(&state));
for id in 0..8 {
src.push(task(id));
}
assert!(src.stealer().steal_batch(&dest));
let mut seen = HashSet::new();
let mut remaining = Vec::new();
while let Some(task) = src.pop() {
remaining.push(task);
}
while let Some(task) = dest.pop() {
remaining.push(task);
}
for item in remaining {
assert!(seen.insert(item), "duplicate task found: {item:?}");
}
assert_eq!(seen.len(), 8);
}
#[test]
fn interleaved_owner_thief_operations_preserve_tasks() {
let queue = queue(3);
let stealer = queue.stealer();
queue.push(task(1));
assert_eq!(stealer.steal(), Some(task(1)));
queue.push(task(2));
queue.push(task(3));
assert_eq!(queue.pop(), Some(task(3)));
assert_eq!(stealer.steal(), Some(task(2)));
assert_eq!(queue.pop(), None);
}
#[test]
fn concurrent_owner_and_stealers_preserve_tasks() {
let total: usize = 512;
let queue = Arc::new(LocalQueue::new_for_test((total - 1) as u32));
for id in 0..total {
queue.push(task(id as u32));
}
let counts: Arc<Vec<AtomicUsize>> =
Arc::new((0..total).map(|_| AtomicUsize::new(0)).collect());
let stealer_threads = 4;
let barrier = Arc::new(Barrier::new(stealer_threads + 2));
let queue_owner = Arc::clone(&queue);
let counts_owner = Arc::clone(&counts);
let barrier_owner = Arc::clone(&barrier);
let owner = thread::spawn(move || {
barrier_owner.wait();
while let Some(task) = queue_owner.pop() {
let idx = task.0.index() as usize;
counts_owner[idx].fetch_add(1, Ordering::SeqCst);
thread::yield_now();
}
});
let mut stealers = Vec::new();
for _ in 0..stealer_threads {
let stealer = queue.stealer();
let counts = Arc::clone(&counts);
let barrier = Arc::clone(&barrier);
stealers.push(thread::spawn(move || {
barrier.wait();
while let Some(task) = stealer.steal() {
let idx = task.0.index() as usize;
counts[idx].fetch_add(1, Ordering::SeqCst);
thread::yield_now();
}
}));
}
barrier.wait();
owner.join().expect("owner join");
for handle in stealers {
handle.join().expect("stealer join");
}
let mut total_seen = 0usize;
for (idx, count) in counts.iter().enumerate() {
let value = count.load(Ordering::SeqCst);
assert_eq!(value, 1, "task {idx} seen {value} times");
total_seen += value;
}
assert_eq!(total_seen, total);
}
#[test]
fn test_local_queue_push_pop() {
let queue = queue(1);
queue.push(task(1));
assert_eq!(queue.pop(), Some(task(1)));
assert_eq!(queue.pop(), None);
}
#[test]
fn task_table_backed_push_pop() {
let queue = queue_with_task_table(1);
queue.push(task(1));
assert_eq!(queue.pop(), Some(task(1)));
assert_eq!(queue.pop(), None);
}
#[test]
fn test_local_queue_is_empty() {
let queue = queue(1);
assert!(queue.is_empty());
queue.push(task(1));
assert!(!queue.is_empty());
let _ = queue.pop();
assert!(queue.is_empty());
}
#[test]
fn test_local_queue_lifo_optimization() {
let queue = queue(5);
for i in 1..=5 {
queue.push(task(i));
}
assert_eq!(queue.pop(), Some(task(5)));
assert_eq!(queue.pop(), Some(task(4)));
assert_eq!(queue.pop(), Some(task(3)));
assert_eq!(queue.pop(), Some(task(2)));
assert_eq!(queue.pop(), Some(task(1)));
}
#[test]
fn test_steal_batch_steals_half() {
let state = LocalQueue::test_state(9);
let src = LocalQueue::new(Arc::clone(&state));
let dest = LocalQueue::new(Arc::clone(&state));
for i in 0..10 {
src.push(task(i));
}
let _ = src.stealer().steal_batch(&dest);
let mut src_count = 0;
while src.pop().is_some() {
src_count += 1;
}
let mut dest_count = 0;
while dest.pop().is_some() {
dest_count += 1;
}
assert_eq!(src_count + dest_count, 10, "no tasks should be lost");
assert!(
(4..=6).contains(&dest_count),
"should steal roughly half, got {dest_count}"
);
}
#[test]
fn test_steal_batch_steals_one() {
let state = LocalQueue::test_state(42);
let src = LocalQueue::new(Arc::clone(&state));
let dest = LocalQueue::new(Arc::clone(&state));
src.push(task(42));
let _ = src.stealer().steal_batch(&dest);
assert!(src.is_empty());
assert_eq!(dest.pop(), Some(task(42)));
}
#[test]
fn test_steal_batch_skips_local_tasks() {
let state = LocalQueue::test_state(4);
let src = LocalQueue::new(Arc::clone(&state));
let dest = LocalQueue::new(Arc::clone(&state));
{
let mut guard = state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
for id in [0, 1] {
if let Some(record) = guard.task_mut(task(id)) {
record.mark_local();
}
}
drop(guard);
}
for id in 0..=4 {
src.push(task(id));
}
let _ = src.stealer().steal_batch(&dest);
let mut stolen = Vec::new();
while let Some(task_id) = dest.pop() {
stolen.push(task_id);
}
assert!(
!stolen.contains(&task(0)) && !stolen.contains(&task(1)),
"local tasks must not be stolen"
);
let mut seen = HashSet::new();
for task_id in stolen {
assert!(seen.insert(task_id), "duplicate task found: {task_id:?}");
}
while let Some(task_id) = src.pop() {
assert!(seen.insert(task_id), "duplicate task found: {task_id:?}");
}
assert_eq!(seen.len(), 5, "no tasks should be lost");
}
#[test]
fn steal_batch_skips_local_without_reordering_owner_tasks() {
let state = LocalQueue::test_state(3);
let src = LocalQueue::new(Arc::clone(&state));
let dest = LocalQueue::new(Arc::clone(&state));
{
let mut guard = state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
for id in [1_u32, 2_u32] {
let record = guard.task_mut(task(id)).expect("task record missing");
record.mark_local();
}
drop(guard);
}
src.push(task(1));
src.push(task(2));
src.push(task(3));
assert!(
src.stealer().steal_batch(&dest),
"remote task should be stolen"
);
assert_eq!(dest.pop(), Some(task(3)));
assert_eq!(dest.pop(), None);
assert_eq!(src.pop(), Some(task(2)));
assert_eq!(src.pop(), Some(task(1)));
assert_eq!(src.pop(), None);
}
#[test]
fn task_table_backed_steal_skips_local_tasks() {
let tasks = LocalQueue::test_task_table(2);
let src = LocalQueue::new_with_task_table(Arc::clone(&tasks));
let dest = LocalQueue::new_with_task_table(Arc::clone(&tasks));
{
let mut guard = tasks
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
let record = guard.task_mut(task(1)).expect("task record missing");
record.mark_local();
drop(guard);
}
src.push(task(0));
src.push(task(1));
src.push(task(2));
let _ = src.stealer().steal_batch(&dest);
let mut stolen = Vec::new();
while let Some(task_id) = dest.pop() {
stolen.push(task_id);
}
assert!(
!stolen.contains(&task(1)),
"task table-backed queue must not steal local tasks"
);
}
#[test]
fn steal_batch_many_skipped_locals_preserves_owner_order() {
let state = LocalQueue::test_state(8);
let src = LocalQueue::new(Arc::clone(&state));
let dest = LocalQueue::new(Arc::clone(&state));
{
let mut guard = state
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
for id in 0..=6 {
let record = guard.task_mut(task(id)).expect("task record missing");
record.mark_local();
}
drop(guard);
}
for id in 0..=7 {
src.push(task(id));
}
assert!(
src.stealer().steal_batch(&dest),
"remote task should be stolen"
);
assert_eq!(dest.pop(), Some(task(7)));
assert_eq!(dest.pop(), None);
for expected in (0..=6).rev() {
assert_eq!(src.pop(), Some(task(expected)));
}
assert_eq!(src.pop(), None);
}
#[test]
fn mr_task_id_relabeling_preserves_repeated_steal_batch_schedule() {
let base_layout = [
(0, true),
(1, false),
(2, false),
(3, true),
(4, false),
(5, false),
(6, true),
(7, false),
];
let relabeled_layout = [
(100, true),
(101, false),
(102, false),
(103, true),
(104, false),
(105, false),
(106, true),
(107, false),
];
let (base_rounds, base_owner_remaining) = run_repeated_steal_batch_schedule(&base_layout);
let (relabeled_rounds, relabeled_owner_remaining) =
run_repeated_steal_batch_schedule(&relabeled_layout);
assert!(
base_rounds.len() >= 2,
"fixture should exercise multiple steal rounds"
);
assert_eq!(
normalize_rounds(base_rounds, &base_layout),
normalize_rounds(relabeled_rounds, &relabeled_layout),
"relabeling task IDs must not perturb repeated steal_batch partitions"
);
assert_eq!(
normalize_task_ids(base_owner_remaining, &base_layout),
normalize_task_ids(relabeled_owner_remaining, &relabeled_layout),
"relabeling task IDs must not perturb owner-visible remaining order"
);
}
proptest! {
#[test]
fn mr_chunked_push_equivalent_for_owner_lifo_mode(
total in 1usize..64,
split in 0usize..64,
) {
let split = split.min(total);
let max_task_id = total as u32;
let individual = queue(max_task_id);
let chunked = queue(max_task_id);
for id in 0..total {
individual.push(task(id as u32));
}
let prefix: Vec<_> = (0..split).map(|id| task(id as u32)).collect();
let suffix: Vec<_> = (split..total).map(|id| task(id as u32)).collect();
chunked.push_many(&prefix);
chunked.push_many(&suffix);
let baseline = drain_owner(&individual);
let variant = drain_owner(&chunked);
let expected: Vec<_> = (0..total).rev().map(|id| task(id as u32)).collect();
prop_assert_eq!(
baseline,
expected.clone(),
"owner LIFO drain should match reverse arrival order",
);
prop_assert_eq!(
variant,
expected.clone(),
"chunking pushes must not perturb owner-visible LIFO order",
);
}
#[test]
fn mr_chunked_push_equivalent_for_thief_fifo_mode(
total in 1usize..64,
split in 0usize..64,
) {
let split = split.min(total);
let max_task_id = total as u32;
let individual = queue(max_task_id);
let chunked = queue(max_task_id);
for id in 0..total {
individual.push(task(id as u32));
}
let prefix: Vec<_> = (0..split).map(|id| task(id as u32)).collect();
let suffix: Vec<_> = (split..total).map(|id| task(id as u32)).collect();
chunked.push_many(&prefix);
chunked.push_many(&suffix);
let baseline = drain_thief(&individual);
let variant = drain_thief(&chunked);
let expected: Vec<_> = (0..total).map(|id| task(id as u32)).collect();
prop_assert_eq!(
baseline,
expected.clone(),
"thief FIFO drain should match arrival order",
);
prop_assert_eq!(
variant,
expected.clone(),
"chunking pushes must not perturb thief-visible FIFO order",
);
}
#[test]
fn mr_chunked_push_equivalent_for_repeated_steal_fifo_mode(
total in 1usize..64,
split in 0usize..64,
) {
let split = split.min(total);
let baseline_chunks = vec![(0..total as u32).collect::<Vec<_>>()];
let variant_chunks = vec![
(0..split as u32).collect::<Vec<_>>(),
(split as u32..total as u32).collect::<Vec<_>>(),
];
let (baseline_rounds, baseline_owner_remaining) =
run_repeated_steal_batch_from_push_chunks(&baseline_chunks);
let (variant_rounds, variant_owner_remaining) =
run_repeated_steal_batch_from_push_chunks(&variant_chunks);
let baseline_flattened: Vec<_> = baseline_rounds.iter().flatten().copied().collect();
let variant_flattened: Vec<_> = variant_rounds.iter().flatten().copied().collect();
let expected: Vec<_> = (0..total as u32).collect();
prop_assert_eq!(
baseline_flattened,
expected.clone(),
"repeated steal_batch should drain remote tasks in FIFO arrival order",
);
prop_assert_eq!(
variant_flattened,
expected,
"chunking pushes must not perturb repeated steal_batch FIFO order",
);
prop_assert!(
baseline_owner_remaining.is_empty(),
"all-remote repeated steal_batch schedule should drain the owner queue",
);
prop_assert_eq!(
variant_owner_remaining,
baseline_owner_remaining,
"chunking pushes must not change the owner-visible remainder under repeated steals",
);
}
#[test]
fn mr_owner_thief_mode_switch_is_atomic_without_loss_or_duplication(
total in 1usize..64,
schedule in prop::collection::vec(any::<bool>(), 1..32),
) {
let queue = queue(total as u32);
for id in 0..total {
queue.push(task(id as u32));
}
let stealer = queue.stealer();
let mut owner_seen = Vec::new();
let mut thief_seen = Vec::new();
let mut all_seen = HashSet::new();
let mut step = 0usize;
while !queue.is_empty() {
let owner_turn = schedule[step % schedule.len()];
let next = if owner_turn { queue.pop() } else { stealer.steal() };
let next = next.expect("non-local task should be available to the selected mode");
prop_assert!(
all_seen.insert(next),
"mode switches must not duplicate tasks across owner/thief drains",
);
if owner_turn {
owner_seen.push(next.0.index());
} else {
thief_seen.push(next.0.index());
}
step += 1;
}
let mut normalized_all: Vec<_> = all_seen.into_iter().map(|task_id| task_id.0.index()).collect();
normalized_all.sort_unstable();
let expected_all: Vec<_> = (0..total as u32).collect();
prop_assert_eq!(
normalized_all,
expected_all,
"mode switches must preserve the exact task set",
);
prop_assert!(
owner_seen.windows(2).all(|pair| pair[0] > pair[1]),
"owner observations must remain strictly LIFO across switches",
);
prop_assert!(
thief_seen.windows(2).all(|pair| pair[0] < pair[1]),
"thief observations must remain strictly FIFO across switches",
);
}
}
#[test]
fn test_local_queue_stealer_clone() {
let queue = queue(2);
queue.push(task(1));
queue.push(task(2));
let stealer1 = queue.stealer();
let stealer2 = stealer1.clone();
let t1 = stealer1.steal();
let t2 = stealer2.steal();
assert!(t1.is_some());
assert!(t2.is_some());
assert_ne!(t1, t2, "stealers should get different tasks");
}
#[test]
fn concurrent_bidirectional_steal_batch_does_not_deadlock_or_lose_tasks() {
let state = LocalQueue::test_state(63);
let left = Arc::new(LocalQueue::new(Arc::clone(&state)));
let right = Arc::new(LocalQueue::new(Arc::clone(&state)));
for id in 0..32 {
left.push(task(id));
}
for id in 32..64 {
right.push(task(id));
}
let barrier = Arc::new(Barrier::new(3));
let left_for_t1 = Arc::clone(&left);
let right_for_t1 = Arc::clone(&right);
let barrier_t1 = Arc::clone(&barrier);
let t1 = thread::spawn(move || {
let stealer = right_for_t1.stealer();
barrier_t1.wait();
for _ in 0..64 {
let _ = stealer.steal_batch(&left_for_t1);
thread::yield_now();
}
});
let left_for_t2 = Arc::clone(&left);
let right_for_t2 = Arc::clone(&right);
let barrier_t2 = Arc::clone(&barrier);
let t2 = thread::spawn(move || {
let stealer = left_for_t2.stealer();
barrier_t2.wait();
for _ in 0..64 {
let _ = stealer.steal_batch(&right_for_t2);
thread::yield_now();
}
});
barrier.wait();
t1.join().expect("first steal-batch thread should complete");
t2.join()
.expect("second steal-batch thread should complete");
let mut seen = HashSet::new();
while let Some(task_id) = left.pop() {
assert!(seen.insert(task_id), "duplicate task found: {task_id:?}");
}
while let Some(task_id) = right.pop() {
assert!(seen.insert(task_id), "duplicate task found: {task_id:?}");
}
assert_eq!(seen.len(), 64, "all tasks should remain accounted for");
}
#[test]
fn steal_batch_rejects_different_task_sources_without_mutation() {
let src = queue(3);
let dest = queue_with_task_table(3);
src.push(task(1));
src.push(task(2));
assert!(
!src.stealer().steal_batch(&dest),
"steal_batch must reject cross-arena transfer"
);
assert_eq!(dest.pop(), None, "destination must remain unchanged");
assert_eq!(src.pop(), Some(task(2)));
assert_eq!(src.pop(), Some(task(1)));
assert_eq!(src.pop(), None);
}
#[test]
fn test_local_queue_high_volume() {
let count = 10_000;
let queue = queue(count - 1);
for i in 0..count {
queue.push(task(i));
}
let mut popped = 0;
while queue.pop().is_some() {
popped += 1;
}
assert_eq!(popped, count, "should pop exactly {count} tasks");
}
#[test]
fn test_local_queue_mixed_push_pop() {
let queue = queue(3);
queue.push(task(1));
queue.push(task(2));
assert_eq!(queue.pop(), Some(task(2)));
queue.push(task(3));
assert_eq!(queue.pop(), Some(task(3)));
assert_eq!(queue.pop(), Some(task(1)));
assert_eq!(queue.pop(), None);
}
#[test]
fn test_local_queue_push_many_lifo_order() {
let queue = queue(4);
queue.push_many(&[task(1), task(2), task(3), task(4)]);
assert_eq!(queue.pop(), Some(task(4)));
assert_eq!(queue.pop(), Some(task(3)));
assert_eq!(queue.pop(), Some(task(2)));
assert_eq!(queue.pop(), Some(task(1)));
assert_eq!(queue.pop(), None);
}
#[test]
fn test_steal_from_empty_is_idempotent() {
let queue = queue(0);
let stealer = queue.stealer();
for _ in 0..10 {
assert!(stealer.steal().is_none());
}
}
#[test]
fn test_steal_batch_from_empty() {
let state = LocalQueue::test_state(0);
let src = LocalQueue::new(Arc::clone(&state));
let dest = LocalQueue::new(Arc::clone(&state));
let result = src.stealer().steal_batch(&dest);
assert!(!result, "steal_batch from empty should return false");
assert!(dest.is_empty());
}
#[test]
fn schedule_local_returns_false_when_task_record_missing() {
let queue = queue(0);
let _guard = LocalQueue::set_current(queue.clone());
let scheduled = LocalQueue::schedule_local(task(1));
assert!(
!scheduled,
"schedule_local should report failure for missing task records"
);
assert!(queue.is_empty(), "queue should remain unchanged");
}
#[test]
fn schedule_local_duplicate_still_reports_success() {
let queue = queue(1);
queue.push(task(1));
let _guard = LocalQueue::set_current(queue.clone());
let scheduled = LocalQueue::schedule_local(task(1));
assert!(
scheduled,
"duplicate scheduling should still report success (already queued)"
);
assert_eq!(queue.pop(), Some(task(1)));
assert_eq!(
queue.pop(),
None,
"duplicate schedule must not enqueue twice"
);
}
#[test]
fn concurrent_steal_and_schedule_local_push_no_deadlock() {
let state = LocalQueue::test_state(99);
let queue = LocalQueue::new(Arc::clone(&state));
for id in 0..50 {
queue.push(task(id));
}
let stealer = queue.stealer();
let schedule_queue = queue;
let barrier = Arc::new(Barrier::new(3));
let done = Arc::new(AtomicBool::new(false));
let b1 = Arc::clone(&barrier);
let d1 = Arc::clone(&done);
let t1 = thread::spawn(move || {
b1.wait();
while !d1.load(Ordering::Relaxed) {
let _ = stealer.steal();
thread::yield_now();
}
});
let b2 = Arc::clone(&barrier);
let d2 = Arc::clone(&done);
let t2 = thread::spawn(move || {
let _guard = LocalQueue::set_current(schedule_queue);
b2.wait();
for round in 0..200 {
let id = 50 + (round % 50);
LocalQueue::schedule_local(task(id));
thread::yield_now();
}
d2.store(true, Ordering::Relaxed);
});
barrier.wait();
t1.join()
.expect("steal thread should complete without deadlock");
t2.join()
.expect("schedule_local thread should complete without deadlock");
}
}