use super::{Worklist, WorklistChannel};
use crossbeam_deque as cb;
use std::cell::Cell;
use std::cell::RefCell;
use std::sync::Arc;
use std::sync::RwLock;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
/// Worklist backed by `crossbeam_deque` queues.
///
/// Worker handles are obtained through `create_channel`; every handle shares
/// the same [`Shared`] state with this worklist.
pub struct CbWorklist<T> {
    /// State shared with every channel created from this worklist.
    shared: Arc<Shared<T>>,
    /// Length of the global injector as recorded by `CbWorklistBuilder::build`;
    /// stays 0 when the worklist was not produced by the builder.
    initial_len: usize,
}
impl<T> Default for CbWorklist<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> CbWorklist<T> {
fn new() -> Self {
let shared = Shared {
num_channels: AtomicUsize::new(0),
stealers: vec![].into(),
global_injector: cb::Injector::new(),
injectors: vec![].into(),
steal_limit: 16,
};
Self {
shared: Arc::new(shared),
initial_len: 0,
}
}
#[must_use]
pub fn new_builder() -> CbWorklistBuilder<T> {
CbWorklistBuilder {
worklist: Self::new(),
}
}
fn push(&self, item: T) {
self.shared.global_injector.push(item);
}
fn push_items(&self, items: impl IntoIterator<Item = T>) {
items.into_iter().for_each(|item| self.push(item))
}
}
/// Builder that fills a [`CbWorklist`] with initial items before it is handed
/// out via `build`.
pub struct CbWorklistBuilder<T> {
    /// Worklist under construction; items pushed through the builder go to
    /// its global injector.
    worklist: CbWorklist<T>,
}
impl<T> CbWorklistBuilder<T> {
#[must_use]
pub fn push(self, item: T) -> Self {
self.worklist.push(item);
self
}
#[must_use]
pub fn push_items(self, items: impl IntoIterator<Item = T>) -> Self {
self.worklist.push_items(items);
self
}
pub fn build(self) -> CbWorklist<T> {
let mut wl = self.worklist;
wl.initial_len = wl.shared.global_injector.len();
wl
}
}
/// State shared between a [`CbWorklist`] and all channels created from it.
struct Shared<T> {
    /// Number of channels created so far; incremented after the new channel's
    /// stealer and injector have been registered below.
    num_channels: AtomicUsize,
    /// One stealer per channel-local queue, in channel creation order.
    stealers: RwLock<Vec<cb::Stealer<T>>>,
    /// Queue holding builder-provided items and items handed back by closed
    /// channels.
    global_injector: cb::Injector<T>,
    /// One injector per channel, in channel creation order; `push_to`
    /// addresses channels by their index in this list.
    injectors: RwLock<Vec<Arc<cb::Injector<T>>>>,
    /// Limit passed to the `steal_batch_with_limit_and_pop` calls in `pop`.
    steal_limit: usize,
}
impl<T> Worklist<T> for CbWorklist<T> {
    type Channel = CbWorklistChannel<T>;

    /// Creates a new channel with its own FIFO local queue and its own
    /// injector, and registers both in the shared state.
    ///
    /// The channel's index in the shared lists (its creation order) is the
    /// `worker_id` other channels use with `push_to`.
    ///
    /// # Panics
    ///
    /// Panics if one of the shared `RwLock`s is poisoned.
    fn create_channel(&mut self) -> Self::Channel {
        let worker = cb::Worker::new_fifo();
        let stealer = worker.stealer();
        let inbox = Arc::new(cb::Injector::new());

        {
            let mut registered_stealers = self
                .shared
                .stealers
                .write()
                .expect("failed to acquire write lock");
            registered_stealers.push(stealer);
        }
        {
            let mut registered_inboxes = self
                .shared
                .injectors
                .write()
                .expect("failed to acquire write lock");
            registered_inboxes.push(Arc::clone(&inbox));
        }

        // Bump the counter only after both lists contain the new handles; the
        // Release store pairs with the Acquire load channels perform before
        // refreshing their cached handle lists.
        self.shared.num_channels.fetch_add(1, Ordering::Release);

        CbWorklistChannel {
            local_queue: worker,
            shared: Arc::clone(&self.shared),
            injector: inbox,
            // Caches start empty with a channel count of 0, so the first
            // refresh always reloads them from the shared state.
            injectors: RefCell::new(Vec::new()),
            stealers: RefCell::new(Vec::new()),
            num_channels: Cell::new(0),
        }
    }

    /// Returns the item count recorded when the worklist was built.
    fn initial_len(&self) -> usize {
        self.initial_len
    }

    /// Not implemented yet.
    ///
    /// # Panics
    ///
    /// Always panics via `todo!()`.
    fn stop(&mut self) {
        todo!()
    }
}
/// Per-worker handle onto a [`CbWorklist`].
pub struct CbWorklistChannel<T> {
    /// FIFO queue owned by this channel; `push` and the first stage of `pop`
    /// operate on it.
    local_queue: cb::Worker<T>,
    /// State shared with the worklist and all other channels.
    shared: Arc<Shared<T>>,
    /// This channel's own injector; it is also registered in
    /// `shared.injectors`, which is how `push_to` from other channels reaches it.
    injector: Arc<cb::Injector<T>>,
    /// Cached copy of `shared.injectors`, refreshed when the channel count changes.
    injectors: RefCell<Vec<Arc<cb::Injector<T>>>>,
    /// Cached copy of `shared.stealers`, refreshed when the channel count changes.
    stealers: RefCell<Vec<cb::Stealer<T>>>,
    /// Value of `shared.num_channels` observed at the last cache refresh.
    num_channels: Cell<usize>,
}
impl<T> CbWorklistChannel<T> {
    /// Refreshes the cached injector and stealer lists if the number of
    /// channels has changed since the last refresh.
    ///
    /// # Panics
    ///
    /// Panics if a shared `RwLock` is poisoned, or if one of the cache
    /// `RefCell`s is already borrowed.
    fn update_local_handles_to_other_channels(&self) {
        // Acquire pairs with the Release increment done in `create_channel`
        // after the new handles were registered.
        let current = self.shared.num_channels.load(Ordering::Acquire);
        if self.num_channels.get() == current {
            return;
        }
        self.num_channels.set(current);

        let mut cached_injectors = self.injectors.borrow_mut();
        cached_injectors.clear();
        cached_injectors.extend(self.shared.injectors.read().unwrap().iter().cloned());

        let mut cached_stealers = self.stealers.borrow_mut();
        cached_stealers.clear();
        cached_stealers.extend(self.shared.stealers.read().unwrap().iter().cloned());
    }
}
impl<T> WorklistChannel<T> for CbWorklistChannel<T> {
    /// Enqueues `item` on this channel's local queue.
    fn push(&self, item: T) {
        self.local_queue.push(item);
    }

    /// Enqueues `item` on the injector of the channel with index `worker_id`
    /// (channels are indexed in creation order).
    ///
    /// # Panics
    ///
    /// Panics with "no such channel ID" if no channel with that index exists.
    fn push_to(&self, item: T, worker_id: u32) {
        self.update_local_handles_to_other_channels();
        self.injectors
            .borrow()
            .get(worker_id as usize)
            .expect("no such channel ID")
            .push(item)
    }

    /// Takes the next item, trying these sources in order:
    ///
    /// 1. this channel's local queue,
    /// 2. this channel's own injector,
    /// 3. the registered local queue that currently reports the most items,
    /// 4. the global injector,
    /// 5. the registered per-channel injector that currently reports the most
    ///    items.
    ///
    /// Stages 2-5 steal a batch into the local queue and pop one item from it.
    /// Every steal result goes through `.success()`, so a `Steal::Retry` is
    /// treated like an empty source; `None` therefore does not guarantee that
    /// no work is left anywhere.
    fn pop(&self) -> Option<T> {
        self.local_queue
            .pop()
            .or_else(|| {
                // Items other channels sent to us via `push_to`.
                self.injector
                    .steal_batch_and_pop(&self.local_queue)
                    .success()
            })
            .or_else(|| {
                // Steal from the fullest local queue.
                self.update_local_handles_to_other_channels();
                self.stealers
                    .borrow()
                    .iter()
                    .max_by_key(|stealer| stealer.len())
                    .and_then(|stealer| {
                        stealer
                            .steal_batch_with_limit_and_pop(
                                &self.local_queue,
                                self.shared.steal_limit,
                            )
                            .success()
                    })
            })
            .or_else(|| {
                // Initial items and items returned by closed channels.
                self.shared
                    .global_injector
                    .steal_batch_with_limit_and_pop(&self.local_queue, self.shared.steal_limit)
                    .success()
            })
            .or_else(|| {
                // Last resort: take work addressed to another channel.
                self.update_local_handles_to_other_channels();
                self.injectors
                    .borrow()
                    .iter()
                    .max_by_key(|injector| injector.len())
                    .and_then(|injector| {
                        injector
                            .steal_batch_with_limit_and_pop(
                                &self.local_queue,
                                self.shared.steal_limit,
                            )
                            .success()
                    })
            })
    }

    /// Returns the number of items in this channel's local queue.
    fn local_len(&self) -> usize {
        self.local_queue.len()
    }

    /// Returns the number of items queued anywhere in the worklist: all
    /// channel-local queues, all per-channel injectors and the global injector.
    ///
    /// The value is a sum of individual length reads, not a single consistent
    /// snapshot of all queues.
    ///
    /// # Panics
    ///
    /// Panics if one of the shared `RwLock`s is poisoned.
    fn global_len(&self) -> usize {
        // `shared.stealers` already contains the stealer of this channel's own
        // local queue (registered in `create_channel`), so the local queue
        // must not be added separately or it would be counted twice.
        let in_local_queues: usize = self
            .shared
            .stealers
            .read()
            .expect("failed to acquire read lock")
            .iter()
            .map(|stealer| stealer.len())
            .sum();
        let in_channel_injectors: usize = self
            .shared
            .injectors
            .read()
            .expect("failed to acquire read lock")
            .iter()
            .map(|injector| injector.len())
            .sum();
        in_local_queues + in_channel_injectors + self.shared.global_injector.len()
    }

    /// Closes the channel, moving everything left in its local queue to the
    /// global injector so other channels can pick it up.
    ///
    /// The channel's injector and stealer are not removed from the shared
    /// lists.
    fn close(self) {
        while let Some(item) = self.local_queue.pop() {
            self.shared.global_injector.push(item);
        }
    }
}
#[test]
fn test_worklist() {
    let mut worklist = CbWorklist::new();
    let first = worklist.create_channel();
    let second = worklist.create_channel();

    // Each channel gets back what it pushed locally.
    first.push(1);
    second.push(2);
    assert_eq!(first.pop(), Some(1));
    assert_eq!(second.pop(), Some(2));

    // A channel whose local queue is drained steals from the other one.
    first.push(1);
    second.push(2);
    assert_eq!(first.pop(), Some(1));
    assert_eq!(first.pop(), Some(2));

    // Items sent with `push_to` arrive at the addressed channel.
    first.push_to(2, 1);
    second.push_to(1, 0);
    assert_eq!(first.pop(), Some(1));
    assert_eq!(second.pop(), Some(2));
}
#[test]
fn test_close_channel() {
    let mut worklist = CbWorklist::new();
    let first = worklist.create_channel();
    let second = worklist.create_channel();
    first.push(1);
    second.push(2);

    // Closing hands the pending items back to the worklist.
    first.close();
    second.close();

    // A channel created afterwards can still retrieve them, in close order.
    let late = worklist.create_channel();
    assert_eq!(late.pop(), Some(1));
    assert_eq!(late.pop(), Some(2));
}
#[test]
fn test_worklist_with_initial_elements() {
    let builder = CbWorklist::new_builder().push(1).push(2);
    let mut worklist = builder.build();
    // The builder records how many items were queued up front.
    assert_eq!(worklist.initial_len(), 2);

    // The initial items are handed out in insertion order.
    let channel = worklist.create_channel();
    assert_eq!(channel.pop(), Some(1));
    assert_eq!(channel.pop(), Some(2));
}