use std::{
collections::VecDeque,
sync::{atomic::AtomicBool, Arc, Condvar, Mutex},
};
/// Producer side of a multi-consumer FIFO queue.
///
/// Items are added with [`WorkStealingQueue::push`] and consumed through
/// [`WorkerHandle`]s obtained from [`WorkStealingQueue::worker`].
///
/// Despite the name, there are no per-worker deques. Every handle pops from the
/// front of the same single mutex-protected `VecDeque`.
pub(crate) struct WorkStealingQueue<T> {
    // State shared with every `WorkerHandle` created from this queue.
    inner: Arc<Inner<T>>,
}
/// State shared between a [`WorkStealingQueue`] and all of its [`WorkerHandle`]s.
struct Inner<T> {
    /// Pending items in FIFO order: `push` appends at the back, workers pop from the front.
    queue: Mutex<VecDeque<T>>,
    /// Signalled with `notify_one` when an item is pushed and `notify_all` when the queue is closed.
    condvar: Condvar,
    /// Set to `true` by `close` (nothing in this file ever resets it). `push` checks it to
    /// reject new items.
    ///
    /// Blocked workers check this flag while holding `queue`'s lock just before waiting on
    /// `condvar`. Writers should therefore update it while holding that same lock, or a
    /// wake-up can be lost.
    closed: AtomicBool,
}
impl<T> WorkStealingQueue<T> {
    /// Creates an empty, open queue.
    pub(crate) fn new() -> Self {
        Self {
            inner: Arc::new(Inner {
                queue: Mutex::new(VecDeque::new()),
                condvar: Condvar::new(),
                closed: AtomicBool::new(false),
            }),
        }
    }

    /// Returns a new [`WorkerHandle`] that consumes items from this queue.
    ///
    /// All handles share the same underlying queue, so each pushed item is
    /// taken by exactly one handle.
    pub(crate) fn worker(&self) -> WorkerHandle<T> {
        WorkerHandle {
            inner: Arc::clone(&self.inner),
        }
    }

    /// Appends `item` to the back of the queue and wakes one blocked worker.
    ///
    /// Returns `true` if the item was enqueued. Returns `false` if the queue has
    /// already been closed; in that case `item` is dropped.
    pub(crate) fn push(&self, item: T) -> bool {
        {
            let mut queue = self.inner.queue.lock().unwrap();
            // Check the flag under the lock. `close` also sets it under this lock, so a
            // push can no longer land after `close` has returned. Without this, an item
            // could be enqueued after every worker had already seen "closed and empty" and
            // exited, while the caller was told the push succeeded.
            if self
                .inner
                .closed
                .load(core::sync::atomic::Ordering::Acquire)
            {
                return false;
            }
            queue.push_back(item);
        }
        // Notify after the guard is dropped so the woken worker can take the lock right away.
        self.inner.condvar.notify_one();
        true
    }

    /// Closes the queue. Subsequent pushes are rejected and every blocked worker is woken.
    ///
    /// Items already in the queue are still handed out to workers. Once the queue
    /// is drained, `WorkerHandle::steal` returns `None`.
    pub(crate) fn close(&self) {
        {
            // Set the flag while holding the queue lock. `WorkerHandle::steal` checks
            // `closed` under this lock and then atomically releases it inside
            // `Condvar::wait`. If the flag were set without the lock, the store and the
            // `notify_all` below could fall between that check and the wait. The
            // notification would then be missed and the worker would block forever.
            let _queue = self.inner.queue.lock().unwrap();
            self.inner
                .closed
                .store(true, core::sync::atomic::Ordering::Release);
        }
        self.inner.condvar.notify_all();
    }

    /// Returns the number of queued items.
    ///
    /// The result is a snapshot: other threads may push or steal as soon as the
    /// lock is released.
    pub(crate) fn len(&self) -> usize {
        self.inner.queue.lock().unwrap().len()
    }

    /// Returns `true` if no items are currently queued (a snapshot, like [`Self::len`]).
    pub(crate) fn is_empty(&self) -> bool {
        self.inner.queue.lock().unwrap().is_empty()
    }
}
impl<T> Default for WorkStealingQueue<T> {
fn default() -> Self {
Self::new()
}
}
/// Consumer handle for a [`WorkStealingQueue`], created by [`WorkStealingQueue::worker`].
///
/// Every handle pops from the same shared queue. Cloning a handle only clones the
/// inner `Arc`, so it is cheap.
pub(crate) struct WorkerHandle<T> {
    // State shared with the producer and with every other handle.
    inner: Arc<Inner<T>>,
}
impl<T> WorkerHandle<T> {
    /// Takes the front item, blocking while the queue is empty and still open.
    ///
    /// Items pushed before the queue was closed are still returned. `None` is
    /// returned only once the queue is both closed and empty.
    pub(crate) fn steal(&self) -> Option<T> {
        let shared = &*self.inner;
        let mut guard = shared.queue.lock().unwrap();
        // Loop on the predicate to tolerate spurious wake-ups from `Condvar::wait`.
        while guard.is_empty() {
            if shared
                .closed
                .load(core::sync::atomic::Ordering::Acquire)
            {
                return None;
            }
            guard = shared.condvar.wait(guard).unwrap();
        }
        guard.pop_front()
    }

    /// Takes the front item without blocking, or returns `None` if the queue is empty.
    pub(crate) fn try_steal(&self) -> Option<T> {
        let mut guard = self.inner.queue.lock().unwrap();
        guard.pop_front()
    }

    /// Reports whether the queue has been closed and holds no more items.
    ///
    /// Both conditions are read while holding the queue lock.
    pub(crate) fn is_closed_and_empty(&self) -> bool {
        let guard = self.inner.queue.lock().unwrap();
        guard.is_empty()
            && self
                .inner
                .closed
                .load(core::sync::atomic::Ordering::Acquire)
    }
}
impl<T> Clone for WorkerHandle<T> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
}
#[cfg(test)]
mod tests {
    use std::{thread, time::Duration};

    use super::*;

    #[test]
    fn test_basic_functionality() {
        let queue = WorkStealingQueue::new();
        let worker = queue.worker();
        assert!(queue.push(1));
        assert!(queue.push(2));
        assert!(queue.push(3));
        assert_eq!(worker.steal(), Some(1));
        assert_eq!(worker.steal(), Some(2));
        assert_eq!(worker.try_steal(), Some(3));
        assert_eq!(worker.try_steal(), None);
        queue.close();
        assert!(!queue.push(4));
        assert!(worker.is_closed_and_empty());
    }

    #[test]
    fn test_multiple_workers() {
        let queue = WorkStealingQueue::new();
        let worker1 = queue.worker();
        let worker2 = queue.worker();
        for i in 0..10 {
            assert!(queue.push(i));
        }
        let mut results = Vec::new();
        while let Some(item) = worker1.try_steal() {
            results.push(item);
        }
        while let Some(item) = worker2.try_steal() {
            results.push(item);
        }
        results.sort();
        assert_eq!(results, (0..10).collect::<Vec<_>>());
    }

    #[test]
    fn test_blocking_behavior() {
        let queue = WorkStealingQueue::new();
        let worker = queue.worker();
        let queue_clone = WorkStealingQueue {
            inner: Arc::clone(&queue.inner),
        };
        let producer = thread::spawn(move || {
            thread::sleep(Duration::from_millis(50));
            assert!(queue_clone.push(42));
            queue_clone.close();
        });
        assert_eq!(worker.steal(), Some(42));
        assert_eq!(worker.steal(), None);
        // Join so that a panic in the producer thread fails this test instead of being lost.
        producer.join().unwrap();
    }

    #[test]
    fn test_len_and_is_empty() {
        let queue = WorkStealingQueue::new();
        assert!(queue.is_empty());
        assert_eq!(queue.len(), 0);
        assert!(queue.push("a"));
        assert!(queue.push("b"));
        assert_eq!(queue.len(), 2);
        assert!(!queue.is_empty());
        let worker = queue.worker();
        assert_eq!(worker.try_steal(), Some("a"));
        assert_eq!(queue.len(), 1);
    }

    #[test]
    fn test_cloned_handles_share_queue() {
        let queue = WorkStealingQueue::new();
        let worker = queue.worker();
        let clone = worker.clone();
        assert!(queue.push(1));
        assert_eq!(clone.try_steal(), Some(1));
        assert_eq!(worker.try_steal(), None);
    }

    #[test]
    fn test_close_wakes_all_blocked_workers() {
        // Regression test: `close` must wake every worker blocked in `steal`.
        let queue: WorkStealingQueue<u32> = WorkStealingQueue::new();
        let consumers: Vec<_> = (0..4)
            .map(|_| {
                let worker = queue.worker();
                thread::spawn(move || worker.steal())
            })
            .collect();
        // Give the consumers a chance to actually block before closing.
        thread::sleep(Duration::from_millis(20));
        queue.close();
        for consumer in consumers {
            assert_eq!(consumer.join().unwrap(), None);
        }
    }
}