use std::borrow::Borrow;
use std::fmt::Debug;
use std::sync::atomic::{AtomicIsize, AtomicU8, AtomicUsize, Ordering};
use std::sync::RwLock;
use std::{num::NonZeroUsize, thread};
use petgraph::{data::DataMap, visit::GraphBase};
use crate::conflict_resolution::ConflictResolvingOperator;
use crate::local_view::LocalGraphView;
use crate::worklists::WorklistChannel;
use crate::worklists::{PushFnWrapper, Worklist};
use crate::{BorrowDataCell, LabellingOperator, ReadonlyOperator};
/// Executes worklist-driven graph operators across a pool of scoped
/// worker threads.
pub struct MultiThreadExecutor {
    // Number of worker threads to spawn; defaults to the machine's
    // available parallelism (see `new`).
    num_threads: NonZeroUsize,
}
impl Default for MultiThreadExecutor {
fn default() -> Self {
Self::new()
}
}
impl MultiThreadExecutor {
    /// Creates an executor sized to the machine's available parallelism,
    /// falling back to a single thread when that cannot be queried.
    pub fn new() -> Self {
        Self {
            num_threads: std::thread::available_parallelism()
                .unwrap_or(NonZeroUsize::new(1).unwrap()),
        }
    }
    /// Builder-style setter overriding the number of worker threads.
    pub fn with_num_threads(mut self, num_threads: NonZeroUsize) -> Self {
        self.num_threads = num_threads;
        self
    }
    /// Runs a node-labelling operator by wrapping it in a
    /// `ConflictResolvingOperator` (which presumably arbitrates concurrent
    /// label writes — see that type) and delegating to `run_readonly`.
    pub fn run_node_labelling<Wl, Op, G>(&self, initial_worklist: Wl, operator: Op, graph: G)
    where
        Wl: Worklist<Op::WorkItem>,
        Wl::Channel: Send,
        Op: LabellingOperator<G> + Clone + Send,
        G: GraphBase + DataMap + Sync,
        G::NodeWeight: BorrowDataCell<UserData = Op::NodeWeight>,
        Op::WorkItem: Copy,
    {
        let op_wrapper = ConflictResolvingOperator::new(operator);
        self.run_readonly(initial_worklist, op_wrapper, graph)
    }
    /// Drains `initial_worklist` (plus everything transitively pushed by
    /// `operator`) on `self.num_threads` scoped worker threads, then stops
    /// the worklist. Blocks until every worker has terminated.
    pub fn run_readonly<Wl, Op, G>(&self, mut initial_worklist: Wl, operator: Op, graph: G)
    where
        Wl: Worklist<Op::WorkItem>,
        Wl::Channel: Send,
        Op: ReadonlyOperator<G> + Clone + Send,
        G: GraphBase + Sync,
    {
        // Reborrow so the spawn closures capture `&G` rather than moving G.
        let graph = &graph;
        let shared = WorkerSharedData::new(self.num_threads);
        thread::scope(|s| {
            {
                // Seed the pending-task total with the initial worklist
                // size. Workers sum ALL per-worker counters, so crediting
                // the entire initial load to counter 0 is sufficient.
                let initial_worklist_len = initial_worklist.initial_len();
                shared.pending_task_counters[0].set(initial_worklist_len as isize);
            }
            // Hold the write lock on the thread-handle list while spawning:
            // a worker that tries to read the list (to unpark peers) blocks
            // until every handle has been registered below.
            let mut thread_ids_guard = shared.threads.write().unwrap();
            // One clone of the operator per worker thread.
            let operator_clones = itertools::repeat_n(operator, self.num_threads.get());
            let thread_ids = operator_clones.enumerate().map(|(id, operator)| {
                let worker = Worker {
                    shared: &shared,
                    id: id as u32,
                    channel: initial_worklist.create_channel(),
                    graph,
                    operator,
                };
                let thread = s.spawn(move || {
                    worker.run();
                });
                thread.thread().clone()
            });
            // `map` above is lazy: `extend` drives the actual spawning,
            // still under the write lock.
            thread_ids_guard.extend(thread_ids);
            // Release the lock so workers may read the handle list.
            drop(thread_ids_guard);
        });
        // End of scope: all workers have joined. Shut the worklist down;
        // exact semantics live in `Worklist::stop` — presumably cleanup.
        initial_worklist.stop();
    }
}
/// State shared by reference (via `thread::scope`) between all workers of
/// one `run_readonly` invocation.
struct WorkerSharedData {
    // One counter per worker, indexed by worker id; the sum over all of
    // them is the number of pushed-but-not-yet-completed tasks.
    pending_task_counters: Vec<PendingTasksCounter>,
    // One signal per worker, indexed by worker id, telling a parked
    // worker whether to resume or exit.
    unpark_signals: Vec<UnparkSignal>,
    // Number of workers that currently believe the worklist is empty;
    // termination is checked when it reaches `num_threads`.
    worklist_maybe_empty: AtomicUsize,
    // Handles of all spawned workers, used to `unpark` peers.
    threads: RwLock<Vec<thread::Thread>>,
    // Total worker count.
    num_threads: NonZeroUsize,
}
impl WorkerSharedData {
fn new(num_threads: NonZeroUsize) -> Self {
let threads: RwLock<Vec<thread::Thread>> =
RwLock::new(Vec::with_capacity(num_threads.get()));
let pending_task_counters = (0..num_threads.get()).map(|_| Default::default()).collect();
let unpark_signals = (0..num_threads.get()).map(|_| Default::default()).collect();
Self {
pending_task_counters,
unpark_signals,
worklist_maybe_empty: Default::default(),
threads,
num_threads,
}
}
}
/// One worker thread's private view: its own channel endpoint and operator
/// clone, plus references to the shared graph and pool-wide state.
struct Worker<'a, Ch, G, Op> {
    // This worker's index, used to address its entries in the shared
    // per-worker vectors.
    id: u32,
    // Endpoint of the worklist for popping tasks and pushing new ones.
    channel: Ch,
    // The (read-only) graph being processed.
    graph: &'a G,
    // This worker's clone of the user operator.
    operator: Op,
    // Pool-wide counters, signals, and thread handles.
    shared: &'a WorkerSharedData,
}
impl<'a, Ch, G, Op> Worker<'a, Ch, G, Op>
where
    G: GraphBase,
    Ch: WorklistChannel<Op::WorkItem>,
    Op: ReadonlyOperator<G>,
{
    /// Main worker loop: pop and process tasks until global termination is
    /// detected by `handle_empty_worklist`.
    fn run(self) {
        let pending_task_counter = &self.shared.pending_task_counters[self.id as usize];
        loop {
            if let Some(active_node) = self.channel.pop() {
                let num_new_work_items = self.process_task(active_node);
                // Decrement only AFTER processing: tasks pushed inside
                // `process_task` were already counted, so the global pending
                // sum cannot spuriously reach zero while work remains.
                pending_task_counter.decrement();
                if num_new_work_items > 0 {
                    self.wake_starved_threads();
                }
            } else if self.handle_empty_worklist() == LoopControl::Break {
                // (collapsed from `else { if … }` — clippy collapsible_else_if)
                break;
            }
        }
    }
    /// Applies the operator to one work item. Every task the operator pushes
    /// is counted both locally (the return value, used to decide whether to
    /// wake starved peers) and in this worker's shared pending-task counter.
    fn process_task(&self, work_item: Op::WorkItem) -> usize {
        let mut num_new_work_items = 0;
        let pending_task_counter = &self.shared.pending_task_counters[self.id as usize];
        let active_node = *work_item.borrow();
        let local_view = LocalGraphView::new(self.graph, active_node);
        let push = PushFnWrapper::new(
            |item| {
                num_new_work_items += 1;
                // Count the task before it becomes poppable so the pending
                // sum never under-reports outstanding work.
                pending_task_counter.increment();
                self.channel.push_to(item, self.id)
            },
            self.id,
        );
        self.operator.op(work_item, local_view, push);
        num_new_work_items
    }
}
impl<'a, Ch, G, Op> Worker<'a, Ch, G, Op> {
    /// Sums every worker's pending-task counter. Increments happen before a
    /// task becomes poppable and decrements after it completes (see
    /// `process_task`/`run`), so a total of zero means no task is
    /// outstanding anywhere.
    fn num_pending_tasks(&self) -> isize {
        self.shared
            .pending_task_counters
            .iter()
            .map(|c| c.get())
            .sum()
    }
    /// Termination protocol, entered whenever this worker's channel is
    /// empty. Returns `Break` when the whole pool should shut down:
    /// - registers this worker as "maybe empty";
    /// - if it is the LAST worker to go empty and no tasks are pending
    ///   anywhere, broadcasts the exit signal and unparks all peers;
    /// - if it is the last to go empty but tasks are still in flight,
    ///   backs out of the count and lets the main loop retry;
    /// - otherwise parks until a peer signals new work or termination.
    fn handle_empty_worklist(&self) -> LoopControl {
        // `fetch_add` returns the pre-increment value; `+ 1` converts it
        // to our post-increment count.
        let num_empty = self
            .shared
            .worklist_maybe_empty
            .fetch_add(1, Ordering::Relaxed)
            + 1;
        assert!(
            num_empty <= self.shared.num_threads.get(),
            "unmatched increment/decrement"
        );
        if num_empty == self.shared.num_threads.get() {
            // We are the last worker to report empty: check termination.
            if self.num_pending_tasks() == 0 {
                // No outstanding work anywhere: signal every worker
                // (including ourselves, harmlessly — we break right away)
                // and unpark the peers so they observe the signal.
                for (thread, signal) in self
                    .shared
                    .threads
                    .read()
                    .unwrap()
                    .iter()
                    .zip(&self.shared.unpark_signals)
                {
                    // Set the signal BEFORE unparking so a woken peer is
                    // guaranteed to find `Exit` in its mailbox.
                    signal.set_exit();
                    if thread.id() != thread::current().id() {
                        thread.unpark();
                    }
                }
                LoopControl::Break
            } else {
                // Tasks are still in flight (pushed but not yet popped):
                // deregister from the "maybe empty" count and retry.
                self.shared
                    .worklist_maybe_empty
                    .fetch_sub(1, Ordering::Relaxed);
                LoopControl::NoAction
            }
        } else {
            // Other workers may still produce work: park until one of them
            // signals us. `park` can also wake spuriously, so the signal
            // byte decides what to do.
            let signal = &self.shared.unpark_signals[self.id as usize];
            loop {
                thread::park();
                // Atomically consume a pending unpark request; a pending
                // exit request (sticky bit) takes precedence.
                match signal.clear_unpark() {
                    UnparkAction::ParkAgain => {} // spurious wakeup
                    UnparkAction::Unpark => {
                        // New work may be available: deregister from the
                        // "maybe empty" count and resume the main loop.
                        self.shared
                            .worklist_maybe_empty
                            .fetch_sub(1, Ordering::Relaxed);
                        break LoopControl::NoAction;
                    }
                    UnparkAction::Exit => {
                        // No decrement here: the pool is shutting down and
                        // the counter is not consulted again.
                        assert_eq!(self.num_pending_tasks(), 0);
                        break LoopControl::Break;
                    }
                }
            }
        }
    }
    /// Wakes all parked peers once this worker has accumulated more local
    /// tasks than there are threads, so starved workers can resume and pick
    /// up work (exact redistribution semantics live in the `WorklistChannel`
    /// impl). Generic over the item type `T` because it only needs
    /// `local_len`, not the full operator bounds of the other impl block.
    fn wake_starved_threads<T>(&self)
    where
        Ch: WorklistChannel<T>,
    {
        let has_waiting_threads = self.shared.worklist_maybe_empty.load(Ordering::Acquire) > 0;
        let has_enough_tasks = self.channel.local_len() > self.shared.num_threads.get();
        if has_waiting_threads && has_enough_tasks
        {
            for (thread, signal) in self
                .shared
                .threads
                .read()
                .unwrap()
                .iter()
                .zip(&self.shared.unpark_signals)
            {
                if thread.id() != thread::current().id() {
                    // Signal before unparking so the wakeup is not lost.
                    signal.set_unpark();
                    thread.unpark();
                }
            }
        }
    }
}
/// Tells `Worker::run` whether to keep looping after an empty-worklist
/// episode. `#[must_use]` so a returned decision is never silently dropped.
#[derive(Copy, Clone, PartialEq, Eq)]
#[must_use]
enum LoopControl {
    // Continue the main loop; work may be available again.
    NoAction,
    // Global termination detected: leave the worker loop.
    Break,
}
/// Per-worker count of tasks that have been pushed but not yet completed.
///
/// Aligned to 64 bytes (a typical cache line), presumably so adjacent
/// workers' counters do not share a line (false sharing).
#[derive(Default)]
#[repr(align(64))]
struct PendingTasksCounter {
    value: AtomicIsize,
}

impl PendingTasksCounter {
    /// Overwrites the counter; used to seed it with the initial worklist
    /// length.
    fn set(&self, value: isize) {
        self.value.store(value, Ordering::Release);
    }

    /// Reads the current count.
    fn get(&self) -> isize {
        self.value.load(Ordering::Acquire)
    }

    /// Records one newly pushed task.
    fn increment(&self) {
        self.value.fetch_add(1, Ordering::Release);
    }

    /// Records one completed task.
    fn decrement(&self) {
        self.value.fetch_sub(1, Ordering::Release);
    }
}
/// A one-byte signal mailbox telling a parked worker what to do on wakeup:
/// bit 0 (0b01) requests a resume, bit 1 (0b10) requests exit.
/// Aligned to 64 bytes (a typical cache line), presumably to avoid false
/// sharing between the per-worker signals stored contiguously.
#[derive(Default)]
#[repr(align(64))]
struct UnparkSignal {
    signal: AtomicU8,
}
impl UnparkSignal {
    /// Sets the exit bit (sticky — never cleared); returns the action
    /// encoded by the previous value.
    fn set_exit(&self) -> UnparkAction {
        self.signal.fetch_or(0b10, Ordering::Relaxed).into()
    }
    /// Clears the unpark bit while leaving the exit bit intact; returns the
    /// action encoded by the value BEFORE clearing, so a pending `Exit`
    /// still wins over a simultaneously pending `Unpark`.
    fn clear_unpark(&self) -> UnparkAction {
        self.signal.fetch_and(!0b1, Ordering::Relaxed).into()
    }
    /// Sets the unpark bit; returns the action encoded by the previous
    /// value.
    fn set_unpark(&self) -> UnparkAction {
        self.signal.fetch_or(0b1, Ordering::Relaxed).into()
    }
}
/// What a woken worker should do, decoded from an `UnparkSignal` byte.
#[derive(Debug, Default, PartialEq, Eq, Clone, Copy)]
enum UnparkAction {
    /// No signal bit was set (spurious wakeup): park again.
    #[default]
    ParkAgain,
    /// New work may be available: resume the main loop.
    Unpark,
    /// Global termination was signalled: exit the worker loop.
    Exit,
}

impl From<u8> for UnparkAction {
    /// Decodes a raw signal byte. The exit bit (0b10) dominates the unpark
    /// bit (0b01); all zero means no request is pending.
    fn from(value: u8) -> Self {
        match (value & 0b10, value & 0b01) {
            (0, 0) => Self::ParkAgain,
            (0, _) => Self::Unpark,
            (_, _) => Self::Exit,
        }
    }
}

impl From<UnparkAction> for u8 {
    /// Encodes an action back into its canonical single-bit signal byte.
    fn from(value: UnparkAction) -> Self {
        match value {
            UnparkAction::ParkAgain => 0b00,
            UnparkAction::Unpark => 0b01,
            UnparkAction::Exit => 0b10,
        }
    }
}