use std::cell::UnsafeCell;
use std::sync::atomic::{fence, AtomicIsize, AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::thread;
use crate::error::{CoreError, CoreResult, ErrorContext, ErrorLocation};
const INITIAL_CAPACITY: usize = 64;

/// Fixed-capacity ring buffer of `Option<T>` slots used as backing storage
/// for the work-stealing deque. Logical indices wrap via a power-of-two
/// bitmask, so the capacity is always rounded up to a power of two.
struct CircularBuf<T> {
    /// Always a power of two (>= 1); invariant required by `mask()`.
    cap: usize,
    data: Box<[UnsafeCell<Option<T>>]>,
}

impl<T> CircularBuf<T> {
    /// Allocates a buffer with every slot empty.
    ///
    /// `cap` is rounded up to the next power of two (`0` maps to `1`) so
    /// that `mask()` is a valid modulus. The previous implementation used
    /// the requested value verbatim, which silently corrupted indexing for
    /// non-power-of-two capacities.
    fn new(cap: usize) -> Self {
        let cap = cap.next_power_of_two();
        let data = (0..cap)
            .map(|_| UnsafeCell::new(None))
            .collect::<Vec<_>>()
            .into_boxed_slice();
        Self { cap, data }
    }

    /// Bitmask equivalent to `% cap`; relies on `cap` being a power of two.
    fn mask(&self) -> usize {
        self.cap - 1
    }

    /// Stores `val` at logical index `i` (wrapping), overwriting the slot.
    ///
    /// # Safety
    /// The caller must guarantee no other thread accesses slot
    /// `i & mask()` concurrently.
    unsafe fn write(&self, i: usize, val: T) {
        let slot = self.data[i & self.mask()].get();
        unsafe { (*slot) = Some(val) };
    }

    /// Takes the value (if any) at logical index `i`, leaving `None`.
    ///
    /// # Safety
    /// The caller must guarantee no other thread accesses slot
    /// `i & mask()` concurrently.
    unsafe fn read(&self, i: usize) -> Option<T> {
        let slot = self.data[i & self.mask()].get();
        unsafe { (*slot).take() }
    }
}

// SAFETY: the buffer only exposes raw slot pointers through the `unsafe`
// read/write methods above; callers are responsible for synchronizing slot
// access, so sharing or moving the buffer across threads is sound whenever
// the element type itself is `Send`.
unsafe impl<T: Send> Send for CircularBuf<T> {}
unsafe impl<T: Send> Sync for CircularBuf<T> {}
/// Chase-Lev-style work-stealing deque: the owning worker pushes and pops at
/// the bottom; other threads steal from the top.
pub struct WorkStealingDeque<T: Send + 'static> {
    /// Owner end; only the owning worker thread mutates it.
    bottom: AtomicIsize,
    /// Steal end; advanced by stealers via compare-exchange.
    top: AtomicIsize,
    /// Current backing storage; replaced with a larger buffer on grow.
    buf: Mutex<Arc<CircularBuf<T>>>,
}
/// Outcome of a [`WorkStealingDeque::steal`] attempt.
#[derive(Debug)]
pub enum StealResult<T> {
    /// A task was taken from the victim deque.
    Success(T),
    /// The victim deque had no tasks.
    Empty,
    /// A race was lost (or the buffer lock was unusable); caller may retry.
    Retry,
}
impl<T: Send + 'static> WorkStealingDeque<T> {
pub fn new() -> Self {
Self {
bottom: AtomicIsize::new(0),
top: AtomicIsize::new(0),
buf: Mutex::new(Arc::new(CircularBuf::new(INITIAL_CAPACITY))),
}
}
pub fn len(&self) -> usize {
let b = self.bottom.load(Ordering::Relaxed);
let t = self.top.load(Ordering::Relaxed);
(b - t).max(0) as usize
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn push(&self, task: T) -> CoreResult<()> {
let b = self.bottom.load(Ordering::Relaxed);
let t = self.top.load(Ordering::Acquire);
let size = (b - t) as usize;
let buf: Arc<CircularBuf<T>> = {
let guard = self.buf.lock().map_err(|e| {
CoreError::SchedulerError(
ErrorContext::new(format!("deque buf lock poisoned: {e}"))
.with_location(ErrorLocation::new(file!(), line!())),
)
})?;
Arc::clone(&*guard)
};
let buf: Arc<CircularBuf<T>> = if size >= buf.cap - 1 {
let new_cap = buf.cap * 2;
let new_buf = Arc::new(CircularBuf::new(new_cap));
for i in t..b {
unsafe {
let val = buf.read(i as usize);
if let Some(v) = val {
new_buf.write(i as usize, v);
}
}
}
let mut guard = self.buf.lock().map_err(|e| {
CoreError::SchedulerError(
ErrorContext::new(format!("deque buf lock poisoned during grow: {e}"))
.with_location(ErrorLocation::new(file!(), line!())),
)
})?;
*guard = Arc::clone(&new_buf);
new_buf
} else {
buf
};
unsafe { buf.write(b as usize, task) };
fence(Ordering::Release);
self.bottom.store(b + 1, Ordering::Relaxed);
Ok(())
}
pub fn pop(&self) -> CoreResult<Option<T>> {
let b = self.bottom.load(Ordering::Relaxed) - 1;
let buf: Arc<CircularBuf<T>> = {
let guard = self.buf.lock().map_err(|e| {
CoreError::SchedulerError(
ErrorContext::new(format!("deque buf lock poisoned on pop: {e}"))
.with_location(ErrorLocation::new(file!(), line!())),
)
})?;
Arc::clone(&*guard)
};
self.bottom.store(b, Ordering::Relaxed);
fence(Ordering::SeqCst);
let t = self.top.load(Ordering::Relaxed);
if t > b {
self.bottom.store(b + 1, Ordering::Relaxed);
return Ok(None);
}
let task = unsafe { buf.read(b as usize) };
if t == b {
let stolen = self
.top
.compare_exchange(t, t + 1, Ordering::SeqCst, Ordering::Relaxed)
.is_err();
self.bottom.store(b + 1, Ordering::Relaxed);
if stolen {
return Ok(None);
}
}
Ok(task)
}
pub fn steal(&self) -> StealResult<T> {
let t = self.top.load(Ordering::Acquire);
fence(Ordering::SeqCst);
let b = self.bottom.load(Ordering::Acquire);
if t >= b {
return StealResult::Empty;
}
let buf = match self.buf.lock() {
Ok(g) => Arc::clone(&*g),
Err(_) => return StealResult::Retry,
};
let task = unsafe { buf.read(t as usize) };
match self
.top
.compare_exchange(t, t + 1, Ordering::SeqCst, Ordering::Relaxed)
{
Ok(_) => match task {
Some(v) => StealResult::Success(v),
None => StealResult::Retry,
},
Err(_) => StealResult::Retry,
}
}
}
impl<T: Send + 'static> Default for WorkStealingDeque<T> {
    /// Equivalent to [`WorkStealingDeque::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Task priority level. The explicit discriminants drive the derived `Ord`
/// (`High > Normal > Low`), which the priority heap relies on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Priority {
    High = 2,
    Normal = 1,
    Low = 0,
}
/// Type-erased, heap-allocated task closure.
type BoxTask = Box<dyn FnOnce() + Send + 'static>;

/// Heap entry pairing a task with its priority and a monotonically
/// increasing submission sequence number (used to keep FIFO order among
/// tasks of equal priority).
struct PriorityItem {
    priority: Priority,
    seq: u64,
    task: BoxTask,
}
impl PartialEq for PriorityItem {
    // Equality ignores the task closure: two items are equal exactly when
    // both the priority and the submission sequence number match. This is
    // consistent with `Ord::cmp` below returning `Equal` for the same pair.
    fn eq(&self, other: &Self) -> bool {
        self.priority == other.priority && self.seq == other.seq
    }
}
impl Eq for PriorityItem {}
impl PartialOrd for PriorityItem {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for PriorityItem {
    // `BinaryHeap` is a max-heap: higher priority compares greater and pops
    // first. Within one priority level the `seq` comparison is reversed
    // (`other` vs `self`) so the smallest — i.e. oldest — sequence number is
    // the maximum, yielding FIFO order per level.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.priority
            .cmp(&other.priority)
            .then_with(|| other.seq.cmp(&self.seq))
    }
}
/// Bounded, blocking, multi-producer/multi-consumer priority queue of
/// boxed tasks. Higher-priority tasks are dequeued first; equal priorities
/// are served FIFO via a sequence counter.
pub struct PriorityTaskQueue {
    /// Heap plus the closed flag, guarded by one mutex.
    inner: Mutex<PriorityQueueInner>,
    /// Signalled after a push; `dequeue` waits on it.
    not_empty: Condvar,
    /// Signalled after a pop; blocking `submit` waits on it.
    not_full: Condvar,
    /// Maximum number of queued tasks (>= 1).
    capacity: usize,
    /// Source of submission sequence numbers.
    seq: AtomicUsize,
}
struct PriorityQueueInner {
    heap: std::collections::BinaryHeap<PriorityItem>,
    /// Once set, submissions are rejected; dequeues drain and then end.
    closed: bool,
}
impl PriorityTaskQueue {
    /// Creates a queue holding at most `capacity` tasks; a capacity of 0 is
    /// bumped to 1 so the queue is always usable.
    pub fn new(capacity: usize) -> Self {
        let cap = capacity.max(1);
        Self {
            inner: Mutex::new(PriorityQueueInner {
                heap: std::collections::BinaryHeap::with_capacity(cap),
                closed: false,
            }),
            not_empty: Condvar::new(),
            not_full: Condvar::new(),
            capacity: cap,
            seq: AtomicUsize::new(0),
        }
    }

    /// Blocking submit: waits while the queue is full.
    ///
    /// # Errors
    /// `InvalidInput` if the queue is (or becomes) closed;
    /// `SchedulerError` if the lock or a condvar wait is poisoned.
    pub fn submit<F>(&self, priority: Priority, f: F) -> CoreResult<()>
    where
        F: FnOnce() + Send + 'static,
    {
        // Sequence number establishes FIFO order among equal priorities.
        let seq = self.seq.fetch_add(1, Ordering::Relaxed) as u64;
        let item = PriorityItem {
            priority,
            seq,
            task: Box::new(f),
        };
        let mut guard = self.inner.lock().map_err(|e| {
            CoreError::SchedulerError(
                ErrorContext::new(format!("priority queue lock poisoned on submit: {e}"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            )
        })?;
        loop {
            // Re-check both conditions after every wakeup: condvar waits may
            // wake spuriously, and the queue may have been closed meanwhile.
            if guard.closed {
                return Err(CoreError::InvalidInput(ErrorContext::new(
                    "PriorityTaskQueue: queue is closed",
                )));
            }
            if guard.heap.len() < self.capacity {
                break;
            }
            guard = self.not_full.wait(guard).map_err(|e| {
                CoreError::SchedulerError(
                    ErrorContext::new(format!("condvar wait poisoned: {e}"))
                        .with_location(ErrorLocation::new(file!(), line!())),
                )
            })?;
        }
        guard.heap.push(item);
        self.not_empty.notify_one();
        Ok(())
    }

    /// Non-blocking submit: fails immediately when closed or full.
    ///
    /// # Errors
    /// `InvalidInput` when closed or full; `SchedulerError` on a poisoned
    /// lock.
    pub fn try_submit<F>(&self, priority: Priority, f: F) -> CoreResult<()>
    where
        F: FnOnce() + Send + 'static,
    {
        let seq = self.seq.fetch_add(1, Ordering::Relaxed) as u64;
        let item = PriorityItem {
            priority,
            seq,
            task: Box::new(f),
        };
        let mut guard = self.inner.lock().map_err(|e| {
            CoreError::SchedulerError(
                ErrorContext::new(format!("priority queue lock poisoned on try_submit: {e}"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            )
        })?;
        if guard.closed {
            return Err(CoreError::InvalidInput(ErrorContext::new(
                "PriorityTaskQueue: queue is closed",
            )));
        }
        if guard.heap.len() >= self.capacity {
            return Err(CoreError::InvalidInput(ErrorContext::new(
                "PriorityTaskQueue: queue is full",
            )));
        }
        guard.heap.push(item);
        self.not_empty.notify_one();
        Ok(())
    }

    /// Blocking dequeue of the highest-priority task.
    ///
    /// Returns `Ok(None)` only once the queue is closed *and* drained;
    /// otherwise waits for a task to arrive.
    ///
    /// # Errors
    /// `SchedulerError` if the lock or a condvar wait is poisoned.
    pub fn dequeue(&self) -> CoreResult<Option<BoxTask>> {
        let mut guard = self.inner.lock().map_err(|e| {
            CoreError::SchedulerError(
                ErrorContext::new(format!("priority queue lock poisoned on dequeue: {e}"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            )
        })?;
        loop {
            if let Some(item) = guard.heap.pop() {
                self.not_full.notify_one();
                return Ok(Some(item.task));
            }
            if guard.closed {
                return Ok(None);
            }
            guard = self.not_empty.wait(guard).map_err(|e| {
                CoreError::SchedulerError(
                    ErrorContext::new(format!("condvar wait poisoned on dequeue: {e}"))
                        .with_location(ErrorLocation::new(file!(), line!())),
                )
            })?;
        }
    }

    /// Non-blocking dequeue; `Ok(None)` when the heap is currently empty
    /// (whether or not the queue is closed).
    ///
    /// # Errors
    /// `SchedulerError` on a poisoned lock.
    pub fn try_dequeue(&self) -> CoreResult<Option<BoxTask>> {
        let mut guard = self.inner.lock().map_err(|e| {
            CoreError::SchedulerError(
                ErrorContext::new(format!("priority queue lock poisoned on try_dequeue: {e}"))
                    .with_location(ErrorLocation::new(file!(), line!())),
            )
        })?;
        match guard.heap.pop() {
            Some(item) => {
                self.not_full.notify_one();
                Ok(Some(item.task))
            }
            None => Ok(None),
        }
    }

    /// Closes the queue: subsequent submits fail, blocked waiters wake up,
    /// and `dequeue` returns `None` once the heap is drained.
    pub fn close(&self) {
        // A poisoned mutex still yields its data through the poison error;
        // previously a poisoned lock silently skipped setting `closed`,
        // leaving blocked dequeuers/submitters waiting forever.
        match self.inner.lock() {
            Ok(mut g) => g.closed = true,
            Err(poisoned) => poisoned.into_inner().closed = true,
        }
        self.not_empty.notify_all();
        self.not_full.notify_all();
    }

    /// Number of queued tasks right now (0 when the lock is poisoned).
    pub fn pending(&self) -> usize {
        self.inner.lock().map(|g| g.heap.len()).unwrap_or(0)
    }
}
/// Tuning knobs for [`WorkStealingScheduler`].
#[derive(Debug, Clone)]
pub struct SchedulerConfig {
    /// Number of worker threads; `0` means "use available parallelism".
    pub num_workers: usize,
    /// How many victim probes a worker makes per round before going idle.
    pub steal_attempts: usize,
    /// Sleep length in microseconds when a worker finds no work at all.
    pub idle_sleep_us: u64,
}
impl Default for SchedulerConfig {
    fn default() -> Self {
        Self {
            num_workers: 0, // auto-detect via available_parallelism
            steal_attempts: 32,
            idle_sleep_us: 100,
        }
    }
}
/// Counters each worker accumulates locally and folds into the shared cell
/// when its loop exits.
#[derive(Debug, Default, Clone)]
pub struct SchedulerStats {
    /// Total tasks executed (own deque plus stolen).
    pub tasks_completed: u64,
    /// Steal attempts against other workers that yielded a task.
    pub steal_successes: u64,
    /// Steal attempts that came back as `StealResult::Retry`.
    pub steal_failures: u64,
}
/// Shared, mutex-protected stats cell updated by workers on exit.
type StatsCell = Arc<Mutex<SchedulerStats>>;
/// Thread-pool scheduler: one work-stealing deque per worker thread, with
/// round-robin task placement on submit.
pub struct WorkStealingScheduler {
    /// One deque per worker; shared with all worker threads for stealing.
    deques: Arc<Vec<Arc<WorkStealingDeque<BoxTask>>>>,
    /// Join handles for the spawned workers, consumed by `shutdown`.
    handles: Vec<thread::JoinHandle<()>>,
    /// Cooperative stop flag polled by the worker loops.
    stop: Arc<std::sync::atomic::AtomicBool>,
    stats: StatsCell,
    /// Round-robin cursor for `submit`.
    next_push: AtomicUsize,
}
impl WorkStealingScheduler {
    /// Builds the scheduler and spawns the worker threads.
    ///
    /// `cfg.num_workers == 0` selects the machine's available parallelism
    /// (falling back to 4 when that cannot be determined).
    ///
    /// # Errors
    /// `InvalidInput` if the resolved worker count is zero (defensive; both
    /// branches currently yield >= 1), or `SchedulerError` if a worker
    /// thread fails to spawn. On spawn failure the already-running workers
    /// are signalled to stop and joined before the error is returned —
    /// previously they were leaked, spinning forever with `stop` never set.
    pub fn new(cfg: SchedulerConfig) -> CoreResult<Self> {
        let n = if cfg.num_workers == 0 {
            thread::available_parallelism()
                .map(|p| p.get())
                .unwrap_or(4)
        } else {
            cfg.num_workers
        };
        if n == 0 {
            return Err(CoreError::InvalidInput(ErrorContext::new(
                "WorkStealingScheduler: num_workers must be >= 1",
            )));
        }
        let stop = Arc::new(std::sync::atomic::AtomicBool::new(false));
        let stats: StatsCell = Arc::new(Mutex::new(SchedulerStats::default()));
        let deques: Arc<Vec<Arc<WorkStealingDeque<BoxTask>>>> =
            Arc::new((0..n).map(|_| Arc::new(WorkStealingDeque::new())).collect());
        let mut handles = Vec::with_capacity(n);
        for id in 0..n {
            let deques2 = Arc::clone(&deques);
            let stop2 = Arc::clone(&stop);
            let stats2 = Arc::clone(&stats);
            let steal_attempts = cfg.steal_attempts;
            let idle_sleep_us = cfg.idle_sleep_us;
            let spawned = thread::Builder::new()
                .name(format!("ws-worker-{id}"))
                .spawn(move || {
                    worker_loop(id, n, deques2, stop2, stats2, steal_attempts, idle_sleep_us);
                });
            match spawned {
                Ok(handle) => handles.push(handle),
                Err(e) => {
                    // Don't leak the workers spawned so far: tell them to
                    // stop and join them before reporting the failure.
                    stop.store(true, Ordering::SeqCst);
                    for h in handles {
                        let _ = h.join();
                    }
                    return Err(CoreError::SchedulerError(
                        ErrorContext::new(format!("failed to spawn worker {id}: {e}"))
                            .with_location(ErrorLocation::new(file!(), line!())),
                    ));
                }
            }
        }
        Ok(Self {
            deques,
            handles,
            stop,
            stats,
            next_push: AtomicUsize::new(0),
        })
    }

    /// Submits a task, placed round-robin across the worker deques.
    ///
    /// # Errors
    /// Propagates the deque's push error (poisoned buffer lock).
    pub fn submit<F>(&self, f: F) -> CoreResult<()>
    where
        F: FnOnce() + Send + 'static,
    {
        let idx = self.next_push.fetch_add(1, Ordering::Relaxed) % self.deques.len();
        self.deques[idx].push(Box::new(f))
    }

    /// Number of worker threads (one deque per worker).
    pub fn num_workers(&self) -> usize {
        self.deques.len()
    }

    /// Snapshot of the aggregated stats (default values if poisoned).
    /// Note: workers fold their counters in only when their loop exits.
    pub fn stats(&self) -> SchedulerStats {
        self.stats.lock().map(|g| g.clone()).unwrap_or_default()
    }

    /// Signals all workers to stop and joins every one of them.
    ///
    /// Workers exit only once they find no runnable task, so work already
    /// queued is still executed. All handles are joined even when some
    /// worker panicked — previously the loop aborted on the first panic and
    /// dropped the remaining handles without joining.
    ///
    /// # Errors
    /// `SchedulerError` if any worker thread panicked.
    pub fn shutdown(self) -> CoreResult<()> {
        self.stop.store(true, Ordering::SeqCst);
        let mut any_panicked = false;
        for h in self.handles {
            any_panicked |= h.join().is_err();
        }
        if any_panicked {
            return Err(CoreError::SchedulerError(
                ErrorContext::new("worker thread panicked during shutdown")
                    .with_location(ErrorLocation::new(file!(), line!())),
            ));
        }
        Ok(())
    }
}
/// Body of one worker thread: run tasks from the worker's own deque first,
/// then probe the other workers' deques in round-robin order, and finally
/// sleep briefly when nothing is runnable. The loop ends once `stop` is set
/// and no task could be found anywhere; locally accumulated counters are
/// then folded into the shared stats cell.
fn worker_loop(
    id: usize,
    n: usize,
    deques: Arc<Vec<Arc<WorkStealingDeque<BoxTask>>>>,
    stop: Arc<std::sync::atomic::AtomicBool>,
    stats: StatsCell,
    steal_attempts: usize,
    idle_sleep_us: u64,
) {
    let mut completed = 0u64;
    let mut steals_won = 0u64;
    let mut steals_lost = 0u64;
    loop {
        // Fast path: work from our own deque.
        if let StealResult::Success(task) = deques[id].steal() {
            task();
            completed += 1;
            continue;
        }
        // Probe victims starting just after ourselves, wrapping around.
        let mut stolen_task = None;
        for round in 0..steal_attempts {
            let victim = (id + 1 + round) % n;
            if victim == id {
                continue;
            }
            match deques[victim].steal() {
                StealResult::Success(task) => {
                    stolen_task = Some(task);
                    break;
                }
                StealResult::Retry => steals_lost += 1,
                StealResult::Empty => {}
            }
        }
        if let Some(task) = stolen_task {
            task();
            completed += 1;
            steals_won += 1;
            continue;
        }
        // One last look at our own deque before idling or exiting.
        if let StealResult::Success(task) = deques[id].steal() {
            task();
            completed += 1;
            continue;
        }
        if stop.load(Ordering::Relaxed) {
            break;
        }
        thread::sleep(std::time::Duration::from_micros(idle_sleep_us));
    }
    // Fold per-thread counters into the shared stats exactly once, on exit.
    if let Ok(mut g) = stats.lock() {
        g.tasks_completed += completed;
        g.steal_successes += steals_won;
        g.steal_failures += steals_lost;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicU64;

    // Owner-side semantics: pop is LIFO and returns None once drained.
    #[test]
    fn deque_push_pop_single_thread() {
        let dq: WorkStealingDeque<i32> = WorkStealingDeque::new();
        assert!(dq.is_empty());
        dq.push(1).expect("push 1");
        dq.push(2).expect("push 2");
        dq.push(3).expect("push 3");
        assert_eq!(dq.len(), 3);
        assert_eq!(dq.pop().expect("pop"), Some(3));
        assert_eq!(dq.pop().expect("pop"), Some(2));
        assert_eq!(dq.pop().expect("pop"), Some(1));
        assert_eq!(dq.pop().expect("pop"), None);
    }

    // A second thread steals from the top; it spins on Retry and returns
    // -1 only if it observes the deque as empty.
    #[test]
    fn deque_steal_basic() {
        let dq = Arc::new(WorkStealingDeque::<i32>::new());
        dq.push(10).expect("push");
        dq.push(20).expect("push");
        let dq2 = Arc::clone(&dq);
        let stealer = thread::spawn(move || loop {
            match dq2.steal() {
                StealResult::Success(v) => return v,
                StealResult::Empty => return -1,
                StealResult::Retry => {}
            }
        });
        let stolen = stealer.join().expect("stealer thread");
        assert!(stolen == 10 || stolen == 20 || stolen == -1);
    }

    // Pushing past INITIAL_CAPACITY forces the grow path; nothing is lost.
    #[test]
    fn deque_grows_automatically() {
        let dq: WorkStealingDeque<usize> = WorkStealingDeque::new();
        for i in 0..200 {
            dq.push(i).expect("push");
        }
        let mut collected = Vec::new();
        while let Ok(Some(v)) = dq.pop() {
            collected.push(v);
        }
        assert_eq!(collected.len(), 200);
    }

    // Tasks come out strictly by priority regardless of submission order.
    #[test]
    fn priority_queue_ordering() {
        let q = Arc::new(PriorityTaskQueue::new(16));
        let results = Arc::new(Mutex::new(Vec::new()));
        let r1 = Arc::clone(&results);
        q.submit(Priority::Low, move || {
            r1.lock().expect("lock").push("low");
        })
        .expect("submit low");
        let r2 = Arc::clone(&results);
        q.submit(Priority::High, move || {
            r2.lock().expect("lock").push("high");
        })
        .expect("submit high");
        let r3 = Arc::clone(&results);
        q.submit(Priority::Normal, move || {
            r3.lock().expect("lock").push("normal");
        })
        .expect("submit normal");
        // Close first so dequeue() terminates (returns None) once drained.
        q.close();
        while let Ok(Some(task)) = q.dequeue() {
            task();
        }
        let res = results.lock().expect("lock");
        assert_eq!(*res, vec!["high", "normal", "low"]);
    }

    // Equal-priority tasks are served in submission (FIFO) order.
    #[test]
    fn priority_queue_fifo_within_level() {
        let q = Arc::new(PriorityTaskQueue::new(32));
        let results = Arc::new(Mutex::new(Vec::new()));
        for i in 0..5u32 {
            let r = Arc::clone(&results);
            q.submit(Priority::Normal, move || {
                r.lock().expect("lock").push(i);
            })
            .expect("submit");
        }
        q.close();
        while let Ok(Some(task)) = q.dequeue() {
            task();
        }
        let res = results.lock().expect("lock");
        assert_eq!(*res, vec![0, 1, 2, 3, 4]);
    }

    // End-to-end: every submitted task runs exactly once across 4 workers.
    #[test]
    fn scheduler_runs_tasks() {
        let cfg = SchedulerConfig {
            num_workers: 4,
            steal_attempts: 16,
            idle_sleep_us: 100,
        };
        let sched = WorkStealingScheduler::new(cfg).expect("new scheduler");
        let counter = Arc::new(AtomicU64::new(0));
        let n_tasks = 100usize;
        for _ in 0..n_tasks {
            let c = Arc::clone(&counter);
            sched
                .submit(move || {
                    c.fetch_add(1, Ordering::Relaxed);
                })
                .expect("submit");
        }
        // Poll with a generous deadline so the test cannot hang forever.
        let deadline = std::time::Instant::now() + std::time::Duration::from_secs(30);
        while counter.load(Ordering::Relaxed) < n_tasks as u64 {
            if std::time::Instant::now() > deadline {
                break;
            }
            thread::sleep(std::time::Duration::from_millis(1));
        }
        assert_eq!(counter.load(Ordering::Relaxed), n_tasks as u64);
        sched.shutdown().expect("shutdown");
    }
}