use crate::runtime::RuntimeState;
use crate::runtime::scheduler::ThreeLaneScheduler;
use crate::runtime::scheduler::local_queue::LocalQueue;
use crate::runtime::scheduler::stealing::steal_task;
use crate::sync::ContendedMutex;
use crate::types::TaskId;
use crate::util::DetRng;
use std::collections::HashSet;
use std::sync::Arc;
use proptest::prelude::*;
/// Builds the test task ids `TaskId::new_for_test(i, 0)` for `i` in
/// `0..count`, in ascending order of `i`.
fn seq_tasks(count: u32) -> Vec<TaskId> {
    let mut ids = Vec::with_capacity(count as usize);
    for index in 0..count {
        ids.push(TaskId::new_for_test(index, 0));
    }
    ids
}
/// Creates a `ThreeLaneScheduler` with `worker_count` workers over a fresh
/// `RuntimeState`, shared through an `Arc<ContendedMutex<_>>` constructed with
/// the name string `"ws_fairness.runtime_state"`.
fn scheduler(worker_count: usize) -> ThreeLaneScheduler {
    let shared_state = Arc::new(ContendedMutex::new(
        "ws_fairness.runtime_state",
        RuntimeState::new(),
    ));
    let sched = ThreeLaneScheduler::new(worker_count, &shared_state);
    sched
}
/// Drives `workers` in round-robin order, asking each for its next task once
/// per round ("tick"). It stops when any of these happens:
/// - `task_count` tasks have been collected;
/// - a whole round yields nothing;
/// - the tick budget runs out.
///
/// Returns the tasks in the order they were polled, plus the number of ticks
/// executed. Callers detect an incomplete drain by comparing the returned
/// vector's length against `task_count`.
fn drain_round_robin(
    workers: &mut [crate::runtime::scheduler::ThreeLaneWorker],
    task_count: usize,
) -> (Vec<TaskId>, usize) {
    // Generous safety budget so a misbehaving scheduler (e.g. one that keeps
    // handing out tasks) cannot hang the test. Every step saturates: a plain
    // `+ 16` after saturated multiplications would overflow (and panic in
    // debug builds) precisely when the saturation was needed.
    let max_ticks = task_count
        .saturating_mul(workers.len().max(1))
        .saturating_mul(8)
        .saturating_add(16);
    let mut polled = Vec::with_capacity(task_count);
    let mut ticks = 0_usize;
    while polled.len() < task_count && ticks < max_ticks {
        let mut progressed = false;
        for worker in workers.iter_mut() {
            if let Some(task) = worker.next_task() {
                polled.push(task);
                progressed = true;
                // Stop mid-round once the quota is met so later workers are
                // not asked for tasks that would go unrecorded.
                if polled.len() >= task_count {
                    break;
                }
            }
        }
        ticks += 1;
        if !progressed {
            break;
        }
    }
    (polled, ticks)
}
proptest! {
    // MR-WS1 (conservation): after injecting `task_count` tasks and draining
    // the workers round-robin, each injected task must have been polled
    // exactly once. No task may be lost, none polled twice, and no foreign
    // task may appear.
    #[test]
    fn mr_ws1_all_tasks_polled_exactly_once(
        task_count in 4usize..40,
        worker_count in 2usize..6,
    ) {
        let spawned = seq_tasks(task_count as u32);
        let mut sched = scheduler(worker_count);
        for task in spawned.iter().copied() {
            sched.inject_ready(task, 100);
        }
        let mut workers = sched.take_workers();
        let polled = drain_round_robin(&mut workers, task_count).0;

        // 1. Count: as many polls as injected tasks.
        prop_assert_eq!(
            polled.len(),
            task_count,
            "MR-WS1 VIOLATION: not all tasks polled ({} of {})",
            polled.len(),
            task_count,
        );

        // 2. Uniqueness: no task was polled twice.
        let distinct: HashSet<TaskId> = polled.into_iter().collect();
        prop_assert_eq!(
            distinct.len(),
            task_count,
            "MR-WS1 VIOLATION: duplicate polls detected ({} unique of {})",
            distinct.len(),
            task_count,
        );

        // 3. Identity: the polled set is exactly the injected set.
        let spawned_set: HashSet<TaskId> = spawned.into_iter().collect();
        prop_assert_eq!(
            distinct,
            spawned_set,
            "MR-WS1 VIOLATION: polled set differs from spawned set",
        );
    }
}
/// MR-WS2 (fixed size): fills two queues with the same 12 tasks. It drains
/// one with the owner's `pop` and the other through its stealer, then asserts:
/// - the owner order is LIFO;
/// - the stealer order is FIFO;
/// - reversing the owner order reproduces the stealer order.
#[test]
fn mr_ws2_owner_lifo_dual_to_stealer_fifo() {
    const N: u32 = 12;
    let tasks = seq_tasks(N);

    // Owner side: push everything, then pop until empty.
    let owner_q = LocalQueue::new_for_test(N - 1);
    for t in tasks.iter().copied() {
        owner_q.push(t);
    }
    let owner_order: Vec<TaskId> = std::iter::from_fn(|| owner_q.pop()).collect();

    // Thief side: identical fill, drained exclusively through a stealer.
    let thief_q = LocalQueue::new_for_test(N - 1);
    for t in tasks.iter().copied() {
        thief_q.push(t);
    }
    let stealer = thief_q.stealer();
    let stealer_order: Vec<TaskId> = std::iter::from_fn(|| stealer.steal()).collect();

    assert_eq!(
        owner_order.len(),
        N as usize,
        "MR-WS2: owner did not drain every task",
    );
    assert_eq!(
        stealer_order.len(),
        N as usize,
        "MR-WS2: stealer did not drain every task",
    );

    let expected_lifo: Vec<TaskId> = tasks.iter().rev().copied().collect();
    assert_eq!(
        owner_order, expected_lifo,
        "MR-WS2 VIOLATION: owner pop order is not LIFO",
    );
    assert_eq!(
        stealer_order, tasks,
        "MR-WS2 VIOLATION: stealer order is not FIFO",
    );

    let dual: Vec<TaskId> = owner_order.iter().rev().copied().collect();
    assert_eq!(
        dual, stealer_order,
        "MR-WS2 VIOLATION: reverse(owner_lifo) != stealer_fifo — the LIFO/FIFO duality is broken",
    );
}
proptest! {
    // MR-WS2 (property form): for any queue length in range, the reversed
    // owner pop order of one queue must equal the order in which a stealer
    // drains a second, identically filled queue.
    #[test]
    fn mr_ws2_prop_reverse_owner_equals_stealer(
        len in 2u32..30,
    ) {
        let tasks = seq_tasks(len);

        let owner_side = LocalQueue::new_for_test(len - 1);
        for t in tasks.iter().copied() {
            owner_side.push(t);
        }
        let owner_order: Vec<TaskId> = std::iter::from_fn(|| owner_side.pop()).collect();

        let thief_side = LocalQueue::new_for_test(len - 1);
        for t in tasks.iter().copied() {
            thief_side.push(t);
        }
        let stealer = thief_side.stealer();
        let stealer_order: Vec<TaskId> = std::iter::from_fn(|| stealer.steal()).collect();

        prop_assert_eq!(owner_order.len(), len as usize);
        prop_assert_eq!(stealer_order.len(), len as usize);

        let reversed_owner: Vec<TaskId> = owner_order.iter().rev().copied().collect();
        prop_assert_eq!(
            reversed_owner,
            stealer_order,
            "MR-WS2 VIOLATION: reverse(owner_lifo) != stealer_fifo for len={}",
            len,
        );
    }
}
proptest! {
    // MR-WS3 (no starvation): `ratio` tasks per worker are injected up front,
    // and every worker is polled once per round. The test asserts that:
    // - all tasks drain;
    // - every worker polls at least one task;
    // - no worker goes more than `4 * worker_count` consecutive rounds
    //   without a task.
    #[test]
    fn mr_ws3_no_worker_starves_under_contention(
        worker_count in 2usize..5,
        ratio in 3usize..8,
    ) {
        // Per-worker bookkeeping: polls so far, current idle streak, and the
        // longest idle streak observed.
        #[derive(Clone, Copy, Default)]
        struct Tally {
            polls: usize,
            idle_streak: usize,
            longest_idle: usize,
        }

        let task_count = worker_count * ratio;
        let tasks = seq_tasks(task_count as u32);
        let mut sched = scheduler(worker_count);
        for t in tasks.iter().copied() {
            sched.inject_ready(t, 100);
        }
        let mut workers = sched.take_workers();

        let mut tallies = vec![Tally::default(); worker_count];
        let round_cap = task_count * worker_count * 4 + 8;
        let mut remaining = task_count;
        let mut rounds = 0_usize;

        while remaining > 0 && rounds < round_cap {
            let mut any_polled = false;
            for (idx, worker) in workers.iter_mut().enumerate() {
                if remaining == 0 {
                    break;
                }
                let got_task = worker.next_task().is_some();
                let tally = &mut tallies[idx];
                if got_task {
                    tally.polls += 1;
                    tally.idle_streak = 0;
                    remaining -= 1;
                    any_polled = true;
                } else {
                    tally.idle_streak += 1;
                    tally.longest_idle = tally.longest_idle.max(tally.idle_streak);
                }
            }
            rounds += 1;
            if !any_polled {
                break;
            }
        }

        prop_assert_eq!(
            remaining,
            0,
            "MR-WS3 VIOLATION: scheduler left {} tasks undrained after {} rounds",
            remaining,
            rounds,
        );

        for (idx, tally) in tallies.iter().enumerate() {
            prop_assert!(
                tally.polls >= 1,
                "MR-WS3 VIOLATION: worker {} polled zero tasks (W={}, M={}); \
                 stealing appears disabled",
                idx,
                worker_count,
                task_count,
            );
        }

        let idle_bound = 4_usize.saturating_mul(worker_count);
        for (idx, tally) in tallies.iter().enumerate() {
            prop_assert!(
                tally.longest_idle <= idle_bound,
                "MR-WS3 VIOLATION: worker {} idled {} consecutive rounds \
                 while work remained (bound={}, W={}, M={})",
                idx,
                tally.longest_idle,
                idle_bound,
                worker_count,
                task_count,
            );
        }
    }
}
/// MR-WS4: one queue is drained within the same loop by its owner (`pop`)
/// and by a single thief going through `steal_task`. Every pushed task must
/// come out exactly once.
///
/// NOTE(review): despite the test name, the owner also pops on every
/// iteration, so this exercises the mixed owner+steal path rather than
/// stealers alone.
#[test]
fn mr_ws4_stealers_alone_drain_all_tasks() {
    const W: usize = 4;
    const M: u32 = 24;
    let q0 = LocalQueue::new_for_test(M - 1);
    let tasks = seq_tasks(M);
    for t in tasks.iter().copied() {
        q0.push(t);
    }
    let stealers = vec![q0.stealer()];
    let mut rng = DetRng::new(0xC0FFEE);
    let mut drained: HashSet<TaskId> = HashSet::new();
    // Iteration budget; `W` only serves to scale it.
    let iter_budget = (M as usize) * W * 4;

    for _ in 0..iter_budget {
        // Both sides take from the queue before anything is recorded.
        let popped = q0.pop();
        let stolen = steal_task(&stealers, &mut rng);
        let progressed = popped.is_some() || stolen.is_some();
        if let Some(t) = popped {
            assert!(drained.insert(t), "duplicate pop for {t:?}");
        }
        if let Some(t) = stolen {
            assert!(drained.insert(t), "duplicate steal for {t:?}");
        }
        if !progressed || drained.len() == M as usize {
            break;
        }
    }

    assert_eq!(
        drained.len(),
        M as usize,
        "MR-WS4 VIOLATION: {} of {} tasks drained via mixed owner+steal",
        drained.len(),
        M,
    );
    let expected: HashSet<TaskId> = tasks.iter().copied().collect();
    assert_eq!(
        drained, expected,
        "MR-WS4 VIOLATION: drained set differs from spawned set",
    );
}
proptest! {
    // MR-WS5 (identity): two `steal_task` drains use the same seed over
    // identically filled queues. They must produce the same sequence, and
    // that sequence must contain every task.
    #[test]
    fn mr_ws5_stealing_identity(
        task_count in 10usize..100,
        worker_count in 2usize..16,
        seed in any::<u64>(),
    ) {
        // Task `i` goes onto queue `i % worker_count`.
        let build_queues = || {
            let fresh: Vec<LocalQueue> = (0..worker_count)
                .map(|_| LocalQueue::new_for_test(128))
                .collect();
            for i in 0..task_count {
                fresh[i % worker_count].push(TaskId::new_for_test(i as u32, 0));
            }
            fresh
        };
        // Steal through every queue's stealer with a freshly seeded RNG until
        // `steal_task` reports nothing left.
        let drain_with_seed = |queues: &[LocalQueue]| -> Vec<TaskId> {
            let stealers: Vec<_> = queues.iter().map(|q| q.stealer()).collect();
            let mut rng = DetRng::new(seed);
            std::iter::from_fn(|| steal_task(&stealers, &mut rng)).collect()
        };

        let first = build_queues();
        let first_run = drain_with_seed(first.as_slice());
        let second = build_queues();
        let second_run = drain_with_seed(second.as_slice());

        prop_assert_eq!(
            first_run.len(),
            task_count,
            "MR-WS5 VIOLATION: steal path failed to drain all tasks",
        );
        prop_assert_eq!(first_run, second_run, "MR-WS5 VIOLATION: Identity failed");
    }
}
proptest! {
    // MR-WS6 (permutation invariance): the same task population is drained
    // twice through `steal_task`, the second time with the stealer list
    // reversed. Both runs must drain every task and yield the same *set* of
    // tasks (order may differ).
    #[test]
    fn mr_ws6_stealing_permutation_invariance(
        task_count in 10usize..100,
        worker_count in 2usize..16,
        seed in any::<u64>(),
    ) {
        let queues: Vec<_> = (0..worker_count)
            .map(|_| LocalQueue::new_for_test(128))
            .collect();
        // Task `i` goes onto queue `i % worker_count`; used for both runs.
        let fill = || {
            for i in 0..task_count {
                queues[i % worker_count].push(TaskId::new_for_test(i as u32, 0));
            }
        };

        fill();
        let stealers: Vec<_> = queues.iter().map(|q| q.stealer()).collect();
        let mut rng1 = DetRng::new(seed);
        let mut stolen1 = Vec::new();
        while let Some(t) = steal_task(&stealers, &mut rng1) {
            stolen1.push(t);
        }
        // The second run refills the *same* queues, so it is only meaningful
        // if the first run drained everything. Check that before refilling
        // rather than after leftovers have polluted the permuted run.
        prop_assert_eq!(
            stolen1.len(),
            task_count,
            "MR-WS6 VIOLATION: base steal path failed to drain all tasks",
        );

        fill();
        let mut permuted_stealers: Vec<_> = queues.iter().map(|q| q.stealer()).collect();
        permuted_stealers.reverse();
        let mut rng2 = DetRng::new(seed);
        let mut stolen2 = Vec::new();
        while let Some(t) = steal_task(&permuted_stealers, &mut rng2) {
            stolen2.push(t);
        }
        prop_assert_eq!(
            stolen2.len(),
            task_count,
            "MR-WS6 VIOLATION: permuted steal path failed to drain all tasks",
        );

        // Order legitimately differs between runs, so compare sorted contents.
        // The lengths are already checked, so sort in place instead of cloning.
        stolen1.sort_unstable();
        stolen2.sort_unstable();
        prop_assert_eq!(stolen1, stolen2, "MR-WS6 VIOLATION: Permutation invariance failed");
    }
}
proptest! {
    // MR-WS7 (load balance): `tasks_per_worker * worker_count` tasks are
    // injected and drained round-robin. Every task must be processed, and
    // with more than one worker no single worker may process more than
    // `max(4 * avg, avg + 40)` tasks.
    #[test]
    fn mr_ws7_scaling_load_balance(
        tasks_per_worker in 10usize..50,
        worker_count in 1usize..32,
    ) {
        let task_count = tasks_per_worker * worker_count;
        let mut sched = scheduler(worker_count);
        for i in 0..task_count {
            sched.inject_ready(TaskId::new_for_test(i as u32, 0), 100);
        }
        let mut workers = sched.take_workers();
        let mut process_counts = vec![0usize; worker_count];
        // Each productive round polls at least one task, so draining
        // `task_count` tasks takes at most `task_count + 1` rounds. This cap
        // is far above that. It exists only so that a scheduler which keeps
        // handing out tasks (e.g. duplicates) fails the total-count assertion
        // below instead of hanging the test run (same budget as MR-WS3).
        let max_rounds = task_count * worker_count * 4 + 8;
        let mut rounds = 0_usize;
        let mut progressed = true;
        while progressed && rounds < max_rounds {
            progressed = false;
            for (i, w) in workers.iter_mut().enumerate() {
                if let Some(_task) = w.next_task() {
                    process_counts[i] += 1;
                    progressed = true;
                }
            }
            rounds += 1;
        }
        let total_processed: usize = process_counts.iter().sum();
        prop_assert_eq!(total_processed, task_count, "MR-WS7 VIOLATION: Did not process all tasks");
        if worker_count > 1 {
            // Non-empty because worker_count > 1.
            let max_processed = *process_counts.iter().max().unwrap();
            let expected_avg = task_count / worker_count;
            let bound = (expected_avg * 4).max(expected_avg + 40);
            prop_assert!(
                max_processed <= bound,
                "MR-WS7 VIOLATION: Load balance failed. Max {} > Bound {} (Avg {}). Distribution: {:?}",
                max_processed, bound, expected_avg, process_counts
            );
        }
    }
}