use async_task::Runnable;
use crossbeam_utils::sync::ShardedLock;
use event_listener::{Event, EventListener};
use rustc_hash::FxHashMap;
use smallvec::SmallVec;
use st3::fifo::{Stealer, Worker};
use std::{
cell::RefCell,
collections::VecDeque,
sync::atomic::{AtomicU64, Ordering},
};
/// Shared task queue that per-worker [`LocalQueue`]s register with and fall
/// back to.
pub struct GlobalQueue {
/// Mutex-protected FIFO of tasks; `push` appends at the back and local
/// queues take batches from the front.
queue: parking_lot::Mutex<VecDeque<Runnable>>,
/// Stealer handle of every live local queue, keyed by the id handed out in
/// `subscribe` and removed again when that local queue is dropped.
stealers: ShardedLock<FxHashMap<u64, Stealer<Runnable>>>,
/// Counter used to assign an id to each subscribed local queue.
id_ctr: AtomicU64,
/// Event notified by `push` and `rebalance`; listeners are obtained
/// through `wait`.
event: Event,
}
impl GlobalQueue {
    /// Builds an empty global queue: no queued tasks, no registered local
    /// queues, and an id counter starting at zero.
    pub fn new() -> Self {
        let queue = Default::default();
        let stealers = Default::default();
        let id_ctr = AtomicU64::new(0);
        let event = Event::new();
        Self {
            queue,
            stealers,
            id_ctr,
            event,
        }
    }

    /// Appends `task` to the back of the shared queue and then notifies one
    /// listener of the internal event.
    pub fn push(&self, task: Runnable) {
        {
            // Keep the guard in its own scope so the lock is released before
            // the notification is sent.
            let mut pending = self.queue.lock();
            pending.push_back(task);
        }
        self.event.notify(1);
    }

    /// Notifies every listener of the internal event (relaxed variant).
    pub fn rebalance(&self) {
        let everyone = usize::MAX;
        self.event.notify_relaxed(everyone);
    }

    /// Creates a new local queue (capacity 1024) tied to this global queue
    /// and registers its stealer handle under a freshly assigned id.
    ///
    /// # Panics
    ///
    /// Panics if the stealer registry lock is poisoned.
    pub fn subscribe(&self) -> LocalQueue<'_> {
        let local = Worker::<Runnable>::new(1024);
        let id = self.id_ctr.fetch_add(1, Ordering::Relaxed);
        {
            let mut registry = self.stealers.write().unwrap();
            registry.insert(id, local.stealer());
        }
        LocalQueue {
            id,
            global: self,
            local,
        }
    }

    /// Returns a listener on the internal event, which is notified by
    /// [`GlobalQueue::push`] and [`GlobalQueue::rebalance`].
    pub fn wait(&self) -> EventListener {
        self.event.listen()
    }
}
/// Per-worker queue created by `GlobalQueue::subscribe`; on drop it returns
/// its remaining tasks to the global queue and deregisters itself.
pub struct LocalQueue<'a> {
/// Id assigned at subscription time; also the key of this queue's stealer
/// in the global registry.
id: u64,
/// The global queue this local queue was subscribed to.
global: &'a GlobalQueue,
/// Worker side of the local FIFO queue (created with capacity 1024 in
/// `GlobalQueue::subscribe`).
local: Worker<Runnable>,
}
impl Drop for LocalQueue<'_> {
    /// Returns every task still held locally to the global queue, then
    /// removes this queue's stealer from the global registry.
    ///
    /// Panics if the stealer registry lock is poisoned.
    fn drop(&mut self) {
        let global = self.global;

        // Drain the local queue one task at a time into the global queue so
        // that no pending task is lost with this worker.
        std::iter::from_fn(|| self.local.pop()).for_each(|task| global.push(task));

        // Deregister only after draining, so the stealer entry is removed last.
        let mut registry = global.stealers.write().unwrap();
        registry.remove(&self.id);
    }
}
impl<'a> LocalQueue<'a> {
    /// Pops a task from the local queue; when the local queue yields nothing,
    /// tries to steal a batch from another local queue or the global queue.
    ///
    /// Returns `None` when no task could be obtained from any source.
    pub fn pop(&self) -> Option<Runnable> {
        self.local.pop().or_else(|| self.steal_and_pop())
    }

    /// Pushes `runnable` onto the local queue, overflowing to the global
    /// queue when the local push is rejected.
    pub fn push(&self, runnable: Runnable) {
        if let Err(runnable) = self.local.push(runnable) {
            log::trace!("{} pushed globally", self.id);
            self.global.push(runnable);
        } else {
            log::trace!("{} pushed locally", self.id);
        }
    }

    /// Steals a batch of tasks into the local queue and returns one of them.
    ///
    /// Registered stealers are tried first, in random order; if none of them
    /// yields a task, a batch is taken from the global queue, provided its
    /// lock can be acquired without blocking.
    ///
    /// # Panics
    ///
    /// Panics if the stealer registry lock is poisoned.
    fn steal_and_pop(&self) -> Option<Runnable> {
        {
            let stealers = self.global.stealers.read().unwrap();
            // Visit the registered queues in random order.
            let mut ids: SmallVec<[u64; 64]> = stealers.keys().copied().collect();
            fastrand::shuffle(&mut ids);
            for id in ids {
                // Ask for roughly half of the victim's tasks, capped at 64.
                if let Ok((val, count)) =
                    stealers[&id].steal_and_pop(&self.local, |n| (n / 2 + 1).min(64))
                {
                    // Thief id first, then the amount stolen (`count` plus the
                    // task returned directly), then the victim id.
                    log::trace!("{} stole {} from {id}", self.id, count + 1);
                    return Some(val);
                }
            }
        }
        // Fall back to the global queue, but never block on its lock.
        if let Some(mut global) = self.global.queue.try_lock() {
            // About half of the global queue, capped at 64 and never more
            // than it holds (so zero when it is empty).
            let to_steal = (global.len() / 2 + 1).min(64).min(global.len());
            for _ in 0..to_steal {
                let stolen = global.pop_front().unwrap();
                // If the local queue rejects the task, return that task to
                // the caller instead of losing it.
                if let Err(back) = self.local.push(stolen) {
                    return Some(back);
                }
            }
            return self.local.pop();
        }
        None
    }
}