use std::collections::{BinaryHeap, VecDeque};
use std::sync::atomic::Ordering;
use std::sync::{Arc, Mutex, RwLock};
use super::{Worklist, WorklistChannel};
/// Worklist whose queues hand items back in insertion (first-in, first-out) order,
/// via `VecDeque::push_back` / `VecDeque::pop_front`.
pub type FifoWorklist<T> = BiglockWorklist<VecDeque<T>>;
/// Worklist alias that is currently backed by `VecDeque`, exactly like `FifoWorklist`.
///
/// NOTE(review): despite the name, this pops in FIFO order. The `Vec<T>` `Queue` impl
/// (push and pop at the back) would give LIFO order. It is left as is because existing
/// callers may rely on the current ordering. Confirm, then switch deliberately to
/// `BiglockWorklist<Vec<T>>`.
pub type LifoWorklist<T> = BiglockWorklist<VecDeque<T>>;
/// Worklist backed by a `BinaryHeap`, so `pop` returns the greatest queued item first.
pub type PriorityWorklist<T> = BiglockWorklist<BinaryHeap<T>>;
/// Shared, growable list of individually locked queues.
type SharedQueues<Q> = Arc<RwLock<Vec<Arc<Mutex<Q>>>>>;

/// Worklist whose queues are each protected by a single coarse lock.
///
/// Channels created from it share `storage`. They either all use queue 0 or each use
/// their own queue, depending on `per_channel_queues`.
#[derive(Debug)]
pub struct BiglockWorklist<Q> {
    /// Queue 0 holds the initial elements; channels may add further queues.
    storage: SharedQueues<Q>,
    /// Number of live channels; shared with every channel handed out.
    num_open_channels: Arc<std::sync::atomic::AtomicUsize>,
    /// `true`: each channel pushes to and pops from its own queue.
    /// `false`: every channel uses queue 0.
    per_channel_queues: bool,
    /// Number of elements the worklist was constructed with.
    initial_len: usize,
}
/// A container that `BiglockWorklist` can use as one of its queues.
///
/// The order in which `pop` returns items is decided by each implementation.
trait Queue<T>: QueueLen {
    /// Removes and returns the next item, or `None` if the container is empty.
    fn pop(&mut self) -> Option<T>;

    /// Adds `item` to the container.
    fn push(&mut self, item: T);
}
/// Reports how many items a queue currently holds.
pub trait QueueLen {
    /// Number of items currently stored.
    fn len(&self) -> usize;

    /// `true` exactly when `len` is zero.
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
/// `Vec` reports its size through its inherent `len`.
impl<T> QueueLen for Vec<T> {
    fn len(&self) -> usize {
        // Spelled out so it is obvious this delegates to the inherent `Vec::len`.
        <Vec<T>>::len(self)
    }
}

/// A `Vec` pushes and pops at its back, so items come out last-in-first-out.
impl<T> Queue<T> for Vec<T> {
    fn push(&mut self, item: T) {
        <Vec<T>>::push(self, item);
    }

    fn pop(&mut self) -> Option<T> {
        <Vec<T>>::pop(self)
    }
}
/// `VecDeque` reports its size through its inherent `len`.
impl<T> QueueLen for VecDeque<T> {
    fn len(&self) -> usize {
        <VecDeque<T>>::len(self)
    }
}

/// A `VecDeque` used as a queue: items enter at the back and leave from the front,
/// so they come out first-in-first-out.
impl<T> Queue<T> for VecDeque<T> {
    fn push(&mut self, item: T) {
        <VecDeque<T>>::push_back(self, item);
    }

    fn pop(&mut self) -> Option<T> {
        <VecDeque<T>>::pop_front(self)
    }
}
/// `BinaryHeap` reports its size through its inherent `len`; no `Ord` bound is required.
impl<T> QueueLen for BinaryHeap<T> {
    fn len(&self) -> usize {
        <BinaryHeap<T>>::len(self)
    }
}

/// A `BinaryHeap` used as a priority queue: `pop` yields the greatest remaining item.
impl<T: Ord> Queue<T> for BinaryHeap<T> {
    fn push(&mut self, item: T) {
        <BinaryHeap<T>>::push(self, item);
    }

    fn pop(&mut self) -> Option<T> {
        <BinaryHeap<T>>::pop(self)
    }
}
impl<Q> BiglockWorklist<Q>
where
Q: QueueLen,
{
pub fn new_global_queue(initial_elements: Q) -> Self {
let initial_len = initial_elements.len();
Self {
storage: Arc::new(vec![Arc::new(initial_elements.into())].into()),
num_open_channels: Default::default(),
per_channel_queues: false,
initial_len,
}
}
pub fn new_with_local_queues(initial_elements: Q) -> Self {
let initial_len = initial_elements.len();
Self {
storage: Arc::new(vec![Arc::new(initial_elements.into())].into()),
num_open_channels: Default::default(),
per_channel_queues: true,
initial_len,
}
}
}
impl<T, Q> Worklist<T> for BiglockWorklist<Q>
where
    Q: Queue<T> + Default,
{
    type Channel = BiglockWorklistChannel<Q>;

    /// Hands out a new channel onto this worklist and counts it as open.
    ///
    /// In global-queue mode every channel uses queue 0, so no extra queues are allocated.
    /// Allocating them would only add empty queues that `global_len` has to lock.
    /// With per-channel queues, a channel opened while no other channel is open reuses
    /// queue 0, which holds the initial elements. Every other channel gets a freshly
    /// appended empty queue, and its id is that queue's index.
    fn create_channel(&mut self) -> Self::Channel {
        let mut queues = self.storage.write().unwrap();
        let reuse_first_queue = self.num_open_channels.load(Ordering::Relaxed) == 0;
        let channel_id = if !self.per_channel_queues || reuse_first_queue {
            0
        } else {
            let new_id = queues.len();
            queues.push(Arc::new(Mutex::new(Q::default())));
            new_id
        };
        self.num_open_channels.fetch_add(1, Ordering::Relaxed);
        BiglockWorklistChannel {
            channel_id,
            storage: self.storage.clone(),
            num_open_channels: self.num_open_channels.clone(),
            per_channel_queues: self.per_channel_queues,
        }
    }

    /// Does nothing for this implementation. Its channels' `pop` never blocks, so there
    /// are no waiters to wake.
    fn stop(&mut self) {}

    /// Number of elements the worklist was constructed with.
    fn initial_len(&self) -> usize {
        self.initial_len
    }
}
/// Handle through which one participant pushes to and pops from a `BiglockWorklist`.
///
/// Dropping a channel decrements the worklist's open-channel count.
#[derive(Debug)]
pub struct BiglockWorklistChannel<Q> {
    /// Index of this channel's own queue in `storage`. Only consulted when
    /// `per_channel_queues` is `true`; otherwise queue 0 is used.
    channel_id: usize,
    /// Queues shared with the worklist and every other channel.
    storage: Arc<RwLock<Vec<Arc<Mutex<Q>>>>>,
    /// Open-channel count shared with the worklist.
    num_open_channels: Arc<std::sync::atomic::AtomicUsize>,
    /// Copied from the worklist: whether channels use their own queues.
    per_channel_queues: bool,
}
impl<T, Q> WorklistChannel<T> for BiglockWorklistChannel<Q>
where
    Q: Queue<T>,
{
    /// Pushes `item` onto this channel's own queue (queue 0 in global-queue mode).
    fn push(&self, item: T) {
        self.push_to(item, self.channel_id as u32)
    }

    /// Pushes `item` onto the queue at index `channel_id`.
    ///
    /// In global-queue mode the id is ignored and queue 0 is used.
    ///
    /// # Panics
    /// Panics if per-channel queues are enabled and no queue exists at `channel_id`,
    /// or if a lock is poisoned.
    fn push_to(&self, item: T, channel_id: u32) {
        let queue_index = if self.per_channel_queues {
            channel_id as usize
        } else {
            0
        };
        self.storage.read().unwrap()[queue_index]
            .lock()
            .expect("acquiring mutex failed")
            .push(item);
    }

    /// Pops the next item, or returns `None` if none could be obtained.
    ///
    /// Global-queue mode pops from queue 0. Per-channel mode first pops from this
    /// channel's own queue. If that is empty, it visits the other queues round-robin,
    /// starting just after this channel, and takes one item from the first queue that
    /// holds more than `STEAL_THRESHOLD` items. Items sitting in other queues at or below
    /// the threshold are not stolen, so `None` does not imply `global_len() == 0`.
    fn pop(&self) -> Option<T> {
        // Minimum backlog (exclusive) a victim queue needs before another channel steals.
        const STEAL_THRESHOLD: usize = 16;

        let storage = self.storage.read().unwrap();
        if !self.per_channel_queues {
            return storage[0].lock().expect("acquiring mutex failed").pop();
        }

        // The guard is a temporary dropped at the end of this statement, so our own queue
        // is unlocked before any other queue is locked below.
        let own_item = storage[self.channel_id]
            .lock()
            .expect("acquiring mutex failed")
            .pop();
        if let Some(item) = own_item {
            return Some(item);
        }

        let queue_count = storage.len();
        (1..queue_count)
            .map(|offset| (self.channel_id + offset) % queue_count)
            .find_map(|victim| {
                let mut guard = storage[victim].lock().expect("failed to get mutex");
                if guard.len() > STEAL_THRESHOLD {
                    guard.pop()
                } else {
                    None
                }
            })
    }

    /// Items in this channel's own queue; in global-queue mode, the total of all queues.
    fn local_len(&self) -> usize {
        if !self.per_channel_queues {
            return self.global_len();
        }
        self.storage.read().unwrap()[self.channel_id]
            .lock()
            .expect("acquiring mutex failed")
            .len()
    }

    /// Total items across all queues. Each queue is locked in turn, so the sum is not an
    /// atomic snapshot while other channels are active.
    fn global_len(&self) -> usize {
        self.storage
            .read()
            .unwrap()
            .iter()
            .map(|q| q.lock().expect("acquiring mutex failed").len())
            .sum()
    }

    /// Closes this channel.
    ///
    /// Consuming `self` runs `Drop`, which releases this channel's slot in the shared
    /// open-channel count.
    ///
    /// NOTE(review): items still in this channel's own queue stay there. In per-channel
    /// mode, other channels only steal them while that queue holds more than the steal
    /// threshold. Confirm whether closing should hand them off.
    fn close(self) {
        drop(self);
    }
}
/// Releases this channel's slot in the open-channel count shared with its worklist.
impl<Q> Drop for BiglockWorklistChannel<Q> {
    fn drop(&mut self) {
        let open_channels = &self.num_open_channels;
        open_channels.fetch_sub(1, Ordering::Relaxed);
    }
}
/// Items pushed through different channels onto one global queue are popped in
/// insertion order by any channel, including from other threads.
#[test]
fn test_biglock_worklist() {
    use std::thread;
    // The assertions below check first-in-first-out order, so use the FIFO alias explicitly.
    let mut wl = FifoWorklist::new_global_queue(Default::default());
    let ch1 = wl.create_channel();
    let ch2 = wl.create_channel();
    let ch3 = wl.create_channel();
    ch2.push(1);
    thread::spawn(move || {
        ch1.push(2);
    })
    .join()
    .unwrap();
    assert_eq!(ch3.pop(), Some(1));
    // Join this thread as well; otherwise a failing assertion inside it goes unnoticed.
    thread::spawn(move || {
        assert_eq!(ch2.pop(), Some(2));
    })
    .join()
    .unwrap();
    assert_eq!(ch3.pop(), None);
    assert_eq!(ch3.global_len(), 0);
}
/// A priority worklist returns items from largest to smallest, whichever channel
/// (and thread) pushed them.
#[test]
fn test_priority_queue() {
    use std::thread;
    let mut worklist = PriorityWorklist::new_global_queue(Default::default());
    let producer = worklist.create_channel();
    let consumer = worklist.create_channel();
    consumer.push(2);
    let handle = thread::spawn(move || {
        producer.push(3);
        producer.push(1);
    });
    handle.join().unwrap();
    for &expected in &[3, 2, 1] {
        assert_eq!(consumer.pop(), Some(expected));
    }
}