use super::util::{DynLifetimeView, LifetimeParameterized, Status};
use crate::macros::{log_debug, log_error};
use crossbeam_utils::CachePadded;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
/// Outcome of a single [`Borrower::borrow`] call.
///
/// Derives `Debug`/`PartialEq`/`Eq` so callers can log the state and compare
/// it directly (e.g. `state == WorkerState::Finished`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkerState {
    /// The borrower ran its closure for a round and can call `borrow` again.
    Ready,
    /// The lender signalled completion via `Lender::finish_workers`; the
    /// closure was not run.
    Finished,
}
/// Status the lender (main thread) waits on during a round.
#[derive(Debug)]
#[derive(Clone, Copy)]
#[derive(PartialEq, Eq)]
enum MainStatus {
    /// Initial/re-armed state: the lender keeps waiting while this is set.
    Waiting,
    /// Set by the last worker of a round to wake the lender.
    Ready,
}
/// Signal published by the lender to every worker.
///
/// Derives `Debug` like the other status enums in this module so it can be
/// logged and used in `debug_assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WorkerStatus {
    /// Process the round of the given colour.
    Round(RoundColor),
    /// Stop: `Borrower::borrow` returns `WorkerState::Finished`.
    Finished,
}
/// Colour of a lending round.
///
/// Lender and borrowers each flip their colour once per round; alternating
/// between two values lets a worker distinguish "the round I just finished"
/// from "the next round".
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RoundColor {
    Blue,
    Red,
}
impl RoundColor {
    /// Flips the colour in place (`Blue` <-> `Red`).
    fn toggle(&mut self) {
        let flipped = if *self == RoundColor::Blue {
            RoundColor::Red
        } else {
            RoundColor::Blue
        };
        *self = flipped;
    }
}
/// Creates one [`Lender`] and `num_threads` [`Borrower`]s that share a single
/// context.
///
/// Every participant starts on the same round colour, so the first `lend` and
/// the first `borrow` on each borrower toggle to the same colour in lockstep.
pub fn make_lending_group<T: LifetimeParameterized>(
    num_threads: usize,
) -> (Lender<T>, Vec<Borrower<T>>) {
    let initial_round = RoundColor::Blue;
    let context = Arc::new(SharedContext {
        num_active_threads: CachePadded::new(AtomicUsize::new(0)),
        num_panicking_threads: CachePadded::new(AtomicUsize::new(0)),
        worker_status: Status::new(WorkerStatus::Round(initial_round)),
        main_status: Status::new(MainStatus::Waiting),
        value: RwLock::new(DynLifetimeView::empty()),
    });

    let mut borrowers = Vec::with_capacity(num_threads);
    for _thread_id in 0..num_threads {
        borrowers.push(Borrower {
            #[cfg(feature = "log")]
            id: _thread_id,
            round: initial_round,
            shared_context: Arc::clone(&context),
        });
    }

    let lender = Lender {
        num_threads,
        round: initial_round,
        shared_context: context,
    };
    (lender, borrowers)
}
/// State shared through an `Arc` between a [`Lender`] and all of its
/// borrowers.
struct SharedContext<T: LifetimeParameterized> {
    /// Workers that have not finished the current round yet. `Lender::lend`
    /// stores `num_threads` here; each worker's `Notifier` decrements it.
    /// Kept on its own cache line via `CachePadded`.
    num_active_threads: CachePadded<AtomicUsize>,
    /// Number of workers that panicked while processing a round; checked by
    /// `Lender::lend` once all workers are done.
    num_panicking_threads: CachePadded<AtomicUsize>,
    /// Lender -> workers signal: start a round of a given colour, or finish.
    worker_status: Status<WorkerStatus>,
    /// Last worker -> lender signal that the round is complete.
    main_status: Status<MainStatus>,
    /// Lifetime-erased view of the value lent for the current round,
    /// populated by `Lender::lend`.
    value: RwLock<DynLifetimeView<T>>,
}
/// Main-thread side of a lending group.
///
/// Each call to [`Lender::lend`] makes a borrowed value available to every
/// [`Borrower`] created by [`make_lending_group`] and blocks until all of them
/// are done with it.
pub struct Lender<T: LifetimeParameterized> {
    /// Number of borrowers in the group; each round waits for this many.
    num_threads: usize,
    /// Colour of the current (or most recent) round; toggled by `lend`.
    round: RoundColor,
    /// State shared with every borrower of the group.
    shared_context: Arc<SharedContext<T>>,
}
impl<T: LifetimeParameterized> Lender<T> {
    /// Lends `value` to every borrower for one round and blocks until all of
    /// them have finished running their closure on it.
    ///
    /// The shared slot holding `value` is cleared before this function
    /// returns *or* unwinds, so the lifetime-erased view never outlives
    /// `value`. The main-thread status lock is released before any panic, so
    /// it is never poisoned by this function.
    ///
    /// If the group has no borrowers this returns immediately: no worker would
    /// ever wake the main thread, so waiting would deadlock.
    ///
    /// # Panics
    ///
    /// Panics if one or more worker threads panicked while processing the
    /// round.
    pub fn lend(&mut self, value: &T::T<'_>) {
        if self.num_threads == 0 {
            return;
        }

        // Arm the countdown before publishing the round: workers start
        // decrementing it as soon as they observe the new colour.
        self.shared_context
            .num_active_threads
            .store(self.num_threads, Ordering::SeqCst);
        self.round.toggle();
        let round = self.round;
        self.shared_context.value.write().unwrap().set(value);
        log_debug!("[main thread, round {round:?}] Ready to compute a parallel pipeline.");
        self.shared_context
            .worker_status
            .notify_all(WorkerStatus::Round(round));

        log_debug!("[main thread, round {round:?}] Waiting for all threads to finish computing this pipeline.");
        let mut guard = self
            .shared_context
            .main_status
            .wait_while(|status| *status == MainStatus::Waiting);
        debug_assert_eq!(*guard, MainStatus::Ready);
        // Re-arm for the next round and release the lock *before* a possible
        // panic below: unwinding while holding the guard would poison the
        // mutex and make every later worker notification fail.
        *guard = MainStatus::Waiting;
        drop(guard);

        // Each worker drops its read guard before decrementing the counter, so
        // nobody can still be reading the value. Clear it unconditionally so
        // the erased reference to `value` is gone even if we panic below.
        self.shared_context.value.write().unwrap().clear();

        let num_panicking_threads = self
            .shared_context
            .num_panicking_threads
            .load(Ordering::SeqCst);
        if num_panicking_threads != 0 {
            log_error!(
                "[main thread, round {round:?}] {num_panicking_threads} worker thread(s) panicked!"
            );
            panic!("{num_panicking_threads} worker thread(s) panicked!");
        }
        log_debug!(
            "[main thread, round {round:?}] All threads have now finished computing this pipeline."
        );
    }

    /// Signals every borrower to stop.
    ///
    /// Publishes the `Finished` worker status. `Borrower::borrow` treats that
    /// status as a wake-up and returns [`WorkerState::Finished`] without
    /// running its closure.
    pub fn finish_workers(&mut self) {
        log_debug!("[main thread] Notifying threads to finish...");
        self.shared_context
            .worker_status
            .notify_all(WorkerStatus::Finished);
    }
}
/// Worker-thread side of a lending group, created by [`make_lending_group`].
///
/// Each call to [`Borrower::borrow`] waits for the lender's next round and
/// runs a closure on the lent value.
pub struct Borrower<T: LifetimeParameterized> {
    /// Index of this borrower in the group; only used in log messages.
    #[cfg(feature = "log")]
    id: usize,
    /// Colour of the round this borrower last waited for; toggled at the
    /// start of every `borrow` so it tracks the lender's colour.
    round: RoundColor,
    /// State shared with the lender and the other borrowers.
    shared_context: Arc<SharedContext<T>>,
}
impl<T: LifetimeParameterized> Borrower<T> {
    /// Waits for the lender to start the next round (or to finish the group)
    /// and, for a new round, runs `f` on the lent value.
    ///
    /// Returns [`WorkerState::Finished`] without calling `f` when the lender
    /// has published the finish signal, and [`WorkerState::Ready`] after `f`
    /// has run for a round.
    ///
    /// If `f` panics, the `Notifier` records the panic (so `Lender::lend`
    /// panics too) and the panic then propagates out of this call.
    pub fn borrow(&mut self, f: impl FnOnce(&T::T<'_>)) -> WorkerState {
        // Advance to the colour of the round we expect next; the lender
        // toggles its own colour the same way at the start of each `lend`.
        self.round.toggle();
        let round = self.round;
        log_debug!(
            "[thread {}, round {round:?}] Waiting for start signal",
            self.id
        );
        // Wake on `Finished`, or once the lender has published *our* colour.
        // Because colours alternate, a worker that already finished the
        // current round is not released again until the next one starts.
        // The status is copied out, so the status lock is released at the end
        // of this statement.
        let worker_status: WorkerStatus =
            *self
                .shared_context
                .worker_status
                .wait_while(|status| match status {
                    WorkerStatus::Finished => false,
                    WorkerStatus::Round(r) => *r != round,
                });
        match worker_status {
            WorkerStatus::Finished => {
                log_debug!(
                    "[thread {}, round {round:?}] Received finish signal",
                    self.id
                );
                WorkerState::Finished
            }
            WorkerStatus::Round(r) => {
                debug_assert_eq!(round, r);
                log_debug!(
                    "[thread {}, round {round:?}] Received start signal. Processing...",
                    self.id
                );
                // Created before touching the value so that its `Drop`
                // reports completion to the lender even if `f` panics.
                let notifier = Notifier {
                    #[cfg(feature = "log")]
                    id: self.id,
                    #[cfg(feature = "log")]
                    round,
                    shared_context: &self.shared_context,
                };
                // Inner scope: the read guard is dropped (also while
                // unwinding) before `notifier`, so the lender is only woken
                // after this worker has released the value.
                {
                    let guard = self.shared_context.value.read().unwrap();
                    // SAFETY: `Lender::lend` stores the value before publishing
                    // this round and only clears it after every worker's
                    // `Notifier` has decremented the active-thread count, which
                    // happens after this guard is dropped. `unwrap` relies on
                    // the slot having been populated for this round.
                    // NOTE(review): the exact contract of `DynLifetimeView::get`
                    // lives in `util` — confirm it only requires the lent value
                    // to still be alive.
                    let value = unsafe { guard.get().unwrap() };
                    f(value);
                }
                drop(notifier);
                WorkerState::Ready
            }
        }
    }
}
/// Completion guard held by a borrower while it processes one round.
///
/// Dropping it, either normally or while unwinding, does three things:
/// - records a panic if one is in progress;
/// - decrements the active-thread count;
/// - wakes the lender if this was the last active worker.
struct Notifier<'a, T: LifetimeParameterized> {
    /// Borrower index, for log messages only.
    #[cfg(feature = "log")]
    id: usize,
    /// Round being processed, for log messages only.
    #[cfg(feature = "log")]
    round: RoundColor,
    /// State shared with the lender.
    shared_context: &'a SharedContext<T>,
}
impl<T: LifetimeParameterized> Drop for Notifier<'_, T> {
    fn drop(&mut self) {
        #[cfg(feature = "log")]
        let round = self.round;
        // Being dropped during unwinding means this worker panicked while
        // processing the round (typically inside the closure passed to
        // `borrow`); count it so that `Lender::lend` can report it.
        if std::thread::panicking() {
            log_error!(
                "[thread {}] Detected panic in this thread, notifying the main thread",
                self.id
            );
            self.shared_context
                .num_panicking_threads
                .fetch_add(1, Ordering::SeqCst);
        }
        // `fetch_sub` returns the value *before* the decrement.
        let thread_count = self
            .shared_context
            .num_active_threads
            .fetch_sub(1, Ordering::SeqCst);
        // `lend` arms the counter with one slot per borrower, so it must not
        // already be zero here.
        debug_assert!(thread_count > 0);
        log_debug!(
            "[thread {}, round {round:?}] Decremented the number of active threads: {}.",
            self.id,
            thread_count - 1
        );
        if thread_count == 1 {
            // Last worker of the round: publish `Ready` so the lender's
            // `wait_while` in `lend` returns.
            log_debug!(
                "[thread {}, round {round:?}] We're the last thread. Waking up the main thread.",
                self.id
            );
            match self
                .shared_context
                .main_status
                .try_notify_one(MainStatus::Ready)
            {
                Ok(_) => log_debug!(
                    "[thread {}, round {round:?}] Notified the main thread.",
                    self.id
                ),
                Err(e) => {
                    log_error!(
                        "[thread {}] Failed to notify the main thread, the mutex was poisoned: {e:?}",
                        self.id
                    );
                    // NOTE(review): if this thread is already unwinding,
                    // panicking again inside `drop` aborts the whole process.
                    panic!("Failed to notify the main thread, the mutex was poisoned: {e:?}");
                }
            }
        } else {
            log_debug!(
                "[thread {}, round {round:?}] Waiting for other threads to finish.",
                self.id
            );
        }
    }
}