/// Identifier of a task, as pushed to and popped from a `TaskQueue`.
pub type TaskId = usize;
/// Marker trait for values usable as a queue identifier (see `Queue::id`).
///
/// It has no methods. The blanket impl below makes every type that satisfies the
/// bounds a `QueueKey` automatically, so it never needs a manual impl.
pub trait QueueKey
where
    Self: Eq + Sized + Copy + Send + Sync + std::fmt::Debug + 'static,
{
}

// Blanket impl: any type meeting the bounds is a `QueueKey`.
impl<K: Eq + Sized + Copy + Send + Sync + std::fmt::Debug + 'static> QueueKey for K {}
/// A queue descriptor: an identifier plus a numeric `share`.
///
/// NOTE(review): `share` is presumably the queue's relative scheduling weight — confirm
/// against the scheduler that consumes it.
pub struct Queue<K>
where
    K: QueueKey,
{
    /// Identifier of this queue.
    id: K,
    /// Share value associated with this queue.
    share: u64,
}
impl<K: QueueKey> Queue<K> {
pub fn new(id: K, share: u64) -> Self {
Self { id, share }
}
pub fn id(&self) -> K {
self.id
}
pub fn share(&self) -> u64 {
self.share
}
}
use crate::mpsc::Mpsc;
use futures_util::task::AtomicWaker;
use std::cell::Cell;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::task::Waker;
/// Sentinel stored in `TaskQueue`'s LIFO slot to mean "no task".
///
/// Because it is a valid `usize`, a real task id equal to `usize::MAX` cannot be told
/// apart from an empty slot.
const LIFO_EMPTY: usize = usize::MAX;
/// A queue of task ids: a FIFO `Mpsc` plus, when enabled, a single-entry LIFO slot that
/// holds the most recently pushed task.
///
/// NOTE(review): `lifo_counter` is a `Cell`, so `pop` must not be called from more than
/// one thread at a time; the type does not enforce this.
pub struct TaskQueue {
// FIFO storage for task ids, including tasks displaced from the LIFO slot.
mpsc: Mpsc<TaskId>,
// Most recently pushed task when LIFO is enabled, or `LIFO_EMPTY` when vacant.
lifo_slot: AtomicUsize,
// Number of LIFO-enabled pops so far; drives `lifo_skip_interval`. Not thread-safe (`Cell`).
lifo_counter: Cell<usize>,
// Count of queued tasks, incremented by `push` and decremented by a successful `pop`.
len: AtomicUsize,
// Woken by `push` when `len` goes from 0 to 1.
waker: AtomicWaker,
// Every `lifo_skip_interval`-th pop does not take from the LIFO slot first.
lifo_skip_interval: usize,
// Whether `push` routes tasks through `lifo_slot`.
enable_lifo: bool,
}
impl TaskQueue {
    /// Creates an empty task queue.
    ///
    /// * `enable_lifo`: when `true`, each pushed task is parked in a single LIFO slot
    ///   (moving the previous occupant into the FIFO queue), and `pop` tries that slot
    ///   first.
    /// * `lifo_skip_interval`: every `lifo_skip_interval`-th `pop` tries the FIFO queue
    ///   before the LIFO slot, so tasks waiting in FIFO order keep making progress.
    ///   `0` means the LIFO slot is never skipped; `1` means it is always tried last.
    ///   Ignored when `enable_lifo` is `false`.
    pub fn new(enable_lifo: bool, lifo_skip_interval: usize) -> Self {
        Self {
            mpsc: Mpsc::new(),
            lifo_slot: AtomicUsize::new(LIFO_EMPTY),
            lifo_counter: Cell::new(0),
            len: AtomicUsize::new(0),
            waker: AtomicWaker::new(),
            lifo_skip_interval,
            enable_lifo,
        }
    }

    /// Enqueues `task_id` and wakes the registered waker if the queue was empty.
    ///
    /// With LIFO enabled, the task goes into the LIFO slot and any task already there is
    /// moved to the FIFO queue. Otherwise the task is appended to the FIFO queue.
    ///
    /// When LIFO is enabled, `task_id` must not equal `LIFO_EMPTY` (`usize::MAX`),
    /// because that value marks an empty slot. This is checked in debug builds.
    pub fn push(&self, task_id: TaskId) {
        if self.enable_lifo {
            debug_assert_ne!(
                task_id, LIFO_EMPTY,
                "task id collides with the LIFO empty-slot sentinel"
            );
            let displaced = self.lifo_slot.swap(task_id, Ordering::AcqRel);
            if displaced != LIFO_EMPTY {
                self.mpsc.push(displaced);
            }
        } else {
            self.mpsc.push(task_id);
        }
        // Wake only on the empty -> non-empty transition (previous length 0).
        if self.len.fetch_add(1, Ordering::AcqRel) == 0 {
            self.waker.wake();
        }
    }

    /// Removes and returns the next task, or `None` if none is available.
    ///
    /// With LIFO enabled, the LIFO slot is tried first except on every
    /// `lifo_skip_interval`-th call, which tries the FIFO queue first. The slot is still
    /// used as a fallback when the FIFO queue is empty, so a skipped turn never hides a
    /// queued task.
    ///
    /// NOTE(review): the skip counter is a `Cell`, so only one thread may call this at a
    /// time.
    pub fn pop(&self) -> Option<TaskId> {
        let result = self
            .pop_from_lifo()
            .or_else(|| self.mpsc.pop())
            // On a skip turn with an empty FIFO queue, returning `None` would strand the
            // task in the LIFO slot while `len > 0`. Since `push` only wakes on a 0 -> 1
            // length transition, a consumer that parks after seeing `None` would then
            // never be woken.
            .or_else(|| self.take_lifo_slot());
        if result.is_some() {
            self.len.fetch_sub(1, Ordering::Release);
        }
        result
    }

    /// Takes the LIFO slot's task if this pop is a LIFO-first turn.
    ///
    /// While LIFO is enabled, every call advances the pop counter. Every
    /// `lifo_skip_interval`-th call is a skip turn and returns `None` without touching
    /// the slot. An interval of `0` never skips.
    fn pop_from_lifo(&self) -> Option<TaskId> {
        if !self.enable_lifo {
            return None;
        }
        // `wrapping_add`: `+ 1` would panic on overflow in debug builds. Wrap-around only
        // perturbs the skip cadence once.
        let counter = self.lifo_counter.get().wrapping_add(1);
        self.lifo_counter.set(counter);
        // A zero interval would divide by zero; treat it as "never skip".
        let skip = self.lifo_skip_interval != 0 && counter % self.lifo_skip_interval == 0;
        if skip {
            None
        } else {
            self.take_lifo_slot()
        }
    }

    /// Empties the LIFO slot and returns the task it held, if any.
    /// Always `None` when LIFO is disabled.
    fn take_lifo_slot(&self) -> Option<TaskId> {
        if !self.enable_lifo {
            return None;
        }
        let id = self.lifo_slot.swap(LIFO_EMPTY, Ordering::AcqRel);
        if id == LIFO_EMPTY {
            None
        } else {
            Some(id)
        }
    }

    /// Returns `true` if the tracked length is zero.
    ///
    /// The length is updated separately from the queue contents (after the task is
    /// stored in `push`, after it is taken in `pop`). Under concurrency the result can
    /// therefore be momentarily stale.
    pub fn is_empty(&self) -> bool {
        self.len.load(Ordering::Acquire) == 0
    }

    /// Moves the task in the LIFO slot (if any) to the back of the FIFO queue.
    ///
    /// The task stays queued, so the length is unchanged. No-op when LIFO is disabled.
    pub fn drain_lifo_to_mpsc(&self) {
        if let Some(id) = self.take_lifo_slot() {
            self.mpsc.push(id);
        }
    }

    /// Registers `waker` with the internal `AtomicWaker`. `push` wakes it when the
    /// length goes from 0 to 1.
    pub fn register_waker(&self, waker: &Waker) {
        self.waker.register(waker);
    }
}
// SAFETY — NOTE(review), to be confirmed by the author:
// - `lifo_slot` and `len` are atomics, and `waker` is an `AtomicWaker`.
// - `lifo_counter` is a `Cell`, which is not `Sync`. It is read and written only on the
//   `pop` path, so `Sync` is sound only if at most one thread calls `pop` at a time.
//   Nothing in the type enforces this.
// - `Mpsc<TaskId>` is defined elsewhere; its contract must allow concurrent `push` from
//   many threads alongside the (single) consumer's `pop`.
unsafe impl Send for TaskQueue {}
unsafe impl Sync for TaskQueue {}
impl std::fmt::Debug for TaskQueue {
    /// Prints the configuration and a snapshot of the length; queued ids are not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let len_snapshot = self.len.load(Ordering::Acquire);
        let mut out = f.debug_struct("TaskQueue");
        out.field("lifo_skip_interval", &self.lifo_skip_interval);
        out.field("enable_lifo", &self.enable_lifo);
        out.field("len", &len_snapshot);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    /// Pops until the queue yields `None` and returns how many tasks came out.
    fn drain_count(queue: &TaskQueue) -> usize {
        std::iter::from_fn(|| queue.pop()).count()
    }

    #[test]
    fn test_task_queue_enqueue_empty_lifo() {
        let q = TaskQueue::new(true, 16);
        assert!(q.is_empty());
        q.push(1);
        assert!(!q.is_empty());
    }

    #[test]
    fn test_task_queue_enqueue_non_empty_lifo() {
        let q = TaskQueue::new(true, 16);
        for id in 1..=2 {
            q.push(id);
        }
        assert!(!q.is_empty());
    }

    #[test]
    fn test_task_queue_pop_lifo() {
        let q = TaskQueue::new(true, 16);
        q.push(1);
        q.push(2);
        // The newest task comes from the LIFO slot; the older one from the FIFO queue.
        let popped: Vec<TaskId> = std::iter::from_fn(|| q.pop()).collect();
        assert_eq!(popped, vec![2, 1]);
        assert!(q.is_empty());
    }

    #[test]
    fn test_task_queue_pop_skip_interval() {
        let q = TaskQueue::new(true, 4);
        for &id in &[10, 20, 30, 40] {
            q.push(id);
        }
        // 40 occupies the LIFO slot and comes out first; the rest drain in FIFO order.
        for &want in &[40, 10, 20, 30] {
            assert_eq!(q.pop(), Some(want));
        }
        for &id in &[50, 60] {
            q.push(id);
        }
        for &want in &[60, 50] {
            assert_eq!(q.pop(), Some(want));
        }
    }

    #[test]
    fn test_task_queue_lifo_disabled() {
        let q = TaskQueue::new(false, 16);
        (1..=3).for_each(|id| q.push(id));
        let popped: Vec<TaskId> = std::iter::from_fn(|| q.pop()).collect();
        assert_eq!(popped, vec![1, 2, 3]);
    }

    #[test]
    fn test_task_queue_is_empty() {
        let q = TaskQueue::new(true, 16);
        assert!(q.is_empty());
        q.push(1);
        assert!(!q.is_empty());
        let _ = q.pop();
        assert!(q.is_empty());
    }

    #[test]
    fn test_task_queue_concurrent_enqueue() {
        const THREADS: usize = 8;
        const PER_THREAD: usize = 1000;
        let queue = Arc::new(TaskQueue::new(true, 16));
        let mut workers = Vec::with_capacity(THREADS);
        for t in 0..THREADS {
            let q = Arc::clone(&queue);
            workers.push(thread::spawn(move || {
                let base = t * PER_THREAD;
                (base..base + PER_THREAD).for_each(|id| q.push(id));
            }));
        }
        for worker in workers {
            worker.join().unwrap();
        }
        assert_eq!(drain_count(&queue), THREADS * PER_THREAD);
        assert!(queue.is_empty());
    }

    #[test]
    fn test_task_queue_length_accuracy() {
        let q = TaskQueue::new(true, 16);
        (1..=10).for_each(|id| q.push(id));
        for _ in 0..5 {
            assert!(q.pop().is_some());
        }
        assert_eq!(drain_count(&q), 5);
        assert!(q.is_empty());
    }
}