use crate::rand::{FastRand, RngSeedGenerator};
use crossbeam_deque::{Injector, Steal};
use st3::fifo::Worker;
use std::collections::VecDeque;
use std::fmt::Debug;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
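/// A work-stealing queue: a shared MPMC [`Injector`] plus a fixed set of
/// bounded local [`Worker`] queues that overflow into it and refill from it.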
#[repr(C)]
#[derive(Debug)]
pub struct WorkStealQueue<T: Debug> {
shared_queue: Injector<T>,
len: AtomicUsize,
local_queues: VecDeque<Worker<T>>,
index: AtomicUsize,
seed_generator: RngSeedGenerator,
}
impl<T: Debug> Drop for WorkStealQueue<T> {
fn drop(&mut self) {
if !std::thread::panicking() {
for local_queue in &self.local_queues {
assert!(local_queue.pop().is_none(), "local queue not empty");
}
assert!(self.pop().is_none(), "global queue not empty");
}
}
}
impl<T: Debug> WorkStealQueue<T> {
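    /// Returns a lazily allocated, process-wide instance built with [`Default`].
    /// The backing `static` below is a single slot shared by every
    /// monomorphization, so this accessor should only ever be used with one
    /// element type `T` in a program.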
#[allow(unsafe_code, trivial_casts, box_pointers)]
    pub fn get_instance<'s>() -> &'s WorkStealQueue<T> {
        static INSTANCE: AtomicUsize = AtomicUsize::new(0);
        let mut ret = INSTANCE.load(Ordering::Acquire);
        if ret == 0 {
            // Lazily allocate the instance; if another thread publishes one
            // first, free this allocation and use the winner instead.
            let ptr = Box::into_raw(Box::<Self>::default());
            ret = INSTANCE
                .compare_exchange(0, ptr as usize, Ordering::Release, Ordering::Acquire)
                .map(|_| ptr as usize)
                .unwrap_or_else(|winner| {
                    unsafe { drop(Box::from_raw(ptr)) };
                    winner
                });
        }
        unsafe { &*(ret as *const Self) }
    }
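    /// Creates a queue with `local_queues_size` local queues, each sized for at
    /// least `local_capacity` items before spilling into the shared queue.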
#[must_use]
pub fn new(local_queues_size: usize, local_capacity: usize) -> Self {
WorkStealQueue {
shared_queue: Injector::new(),
len: AtomicUsize::new(0),
local_queues: (0..local_queues_size)
.map(|_| Worker::new(local_capacity))
.collect(),
index: AtomicUsize::new(0),
seed_generator: RngSeedGenerator::default(),
}
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
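    /// Best-effort count of the items currently in the shared queue; items
    /// sitting in local queues are not included.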
pub fn len(&self) -> usize {
self.len.load(Ordering::Acquire)
}
pub fn push(&self, item: T) {
self.shared_queue.push(item);
        // Atomic increment: a separate load-then-store would lose counts under
        // concurrent pushes.
        self.len.fetch_add(1, Ordering::Release);
}
pub fn pop(&self) -> Option<T> {
if self.is_empty() {
return None;
}
loop {
match self.shared_queue.steal() {
Steal::Success(item) => {
                    self.len.fetch_sub(1, Ordering::Release);
return Some(item);
}
Steal::Retry => continue,
Steal::Empty => return None,
}
}
}
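    /// Hands out one of the local queues in round-robin order, wrapped in a
    /// [`LocalQueue`] handle that borrows this shared queue.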
pub fn local_queue(&self) -> LocalQueue<'_, T> {
let mut index = self.index.fetch_add(1, Ordering::Relaxed);
if index == usize::MAX {
self.index.store(0, Ordering::Relaxed);
}
index %= self.local_queues.len();
let local = self
.local_queues
.get(index)
            .unwrap_or_else(|| panic!("local queue {index} not found!"));
LocalQueue::new(self, local, FastRand::new(self.seed_generator.next_seed()))
}
}
impl<T: Debug> Default for WorkStealQueue<T> {
fn default() -> Self {
Self::new(num_cpus::get(), 256)
}
}
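/// A handle to one local queue: pushes overflow into the parent
/// [`WorkStealQueue`], and pops fall back to stealing from sibling local
/// queues and finally to draining the shared queue.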
#[repr(C)]
#[derive(Debug)]
pub struct LocalQueue<'l, T: Debug> {
tick: AtomicU32,
shared: &'l WorkStealQueue<T>,
stealing: AtomicBool,
queue: &'l Worker<T>,
rand: FastRand,
}
impl<T: Debug> Default for LocalQueue<'_, T> {
fn default() -> Self {
WorkStealQueue::get_instance().local_queue()
}
}
impl<T: Debug> Drop for LocalQueue<'_, T> {
fn drop(&mut self) {
if !std::thread::panicking() {
assert!(self.queue.pop().is_none(), "local queue not empty");
}
}
}
impl<'l, T: Debug> LocalQueue<'l, T> {
fn new(shared: &'l WorkStealQueue<T>, queue: &'l Worker<T>, rand: FastRand) -> Self {
LocalQueue {
tick: AtomicU32::new(0),
shared,
stealing: AtomicBool::new(false),
queue,
rand,
}
}
pub fn is_empty(&self) -> bool {
self.queue.is_empty()
}
pub fn is_full(&self) -> bool {
self.queue.spare_capacity() == 0
}
pub fn len(&self) -> usize {
self.queue.capacity() - self.queue.spare_capacity()
}
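    // The `stealing` flag lets only one caller at a time run the steal path of
    // `pop_front` for this local queue.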
fn try_lock(&self) -> bool {
self.stealing
.compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
.is_ok()
}
fn release_lock(&self) {
self.stealing.store(false, Ordering::Release);
}
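    /// Pushes an item onto the local queue; when it is full, half of the local
    /// items plus the new item are spilled into the shared queue instead.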
pub fn push_back(&self, item: T) {
if let Err(item) = self.queue.push(item) {
let count = self.len() / 2;
for _ in 0..count {
if let Some(item) = self.queue.pop() {
self.shared.push(item);
}
}
self.shared.push(item);
}
}
fn tick(&self) -> u32 {
let val = self.tick.fetch_add(1, Ordering::Release);
if val == u32::MAX {
self.tick.store(0, Ordering::Release);
return 0;
}
val + 1
}
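    /// Pops the next item: every 61st call polls the shared queue first (so
    /// globally submitted work is not starved), then the local queue is tried,
    /// then siblings are stolen from, and finally the shared queue is drained.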
#[allow(clippy::cast_possible_truncation)]
pub fn pop_front(&self) -> Option<T> {
if self.tick() % 61 == 0 {
if let Some(val) = self.shared.pop() {
return Some(val);
}
}
if let Some(val) = self.queue.pop() {
return Some(val);
}
if self.try_lock() {
let local_queues = &self.shared.local_queues;
let num = local_queues.len();
let start = self.rand.fastrand_n(num as u32) as usize;
for i in 0..num {
let i = (start + i) % num;
if let Some(another) = local_queues.get(i) {
                    // Compare the workers themselves, not the addresses of the
                    // local references, so a queue never tries to steal from itself.
                    if std::ptr::eq(another, self.queue) {
continue;
}
if another.is_empty() {
continue;
}
if self.queue.spare_capacity() == 0 {
continue;
}
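                    // Steal at most half of the victim's items, and never more
                    // than this queue's spare capacity.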
if another
.stealer()
.steal(self.queue, |n| {
n.min(self.queue.spare_capacity())
.min(((another.capacity() - another.spare_capacity()) + 1) / 2)
})
.is_ok()
{
self.release_lock();
return self.queue.pop();
}
}
}
self.release_lock();
}
self.shared.pop()
}
}
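
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal usage sketch (an assumed test, not part of the original code):
    // items pushed to the shared queue are visible through any local queue,
    // and local pushes spill into the shared queue once the local queue fills.
    #[test]
    fn round_trip_through_local_queue() {
        let queue = WorkStealQueue::new(2, 32);
        queue.push(1);
        queue.push(2);

        let local = queue.local_queue();
        local.push_back(3);

        // Drain everything before the handles are dropped, because both `Drop`
        // impls assert that no items are left behind.
        let mut popped = Vec::new();
        while let Some(v) = local.pop_front() {
            popped.push(v);
        }
        popped.sort_unstable();
        assert_eq!(popped, vec![1, 2, 3]);
    }
}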