#![allow(clippy::mutex_atomic)]
use std::fmt;
use std::sync::{Arc, Condvar, Mutex};
use std::thread::ThreadId;
/// Handle to a group of participants. Every clone shares the same state, and
/// dropping a handle removes it from the group.
pub(crate) struct WaitGroup {
    inner: Arc<Inner>,
}

/// State shared by all clones of one `WaitGroup`.
struct Inner {
    /// Notified when the handle count reaches zero or a panic id is recorded.
    cvar: Condvar,
    /// Handle count and panic record, guarded together so waiters can test
    /// both conditions atomically.
    state: Mutex<CondState>,
}

/// Mutable part of `Inner`, only ever accessed through `Inner::state`.
struct CondState {
    /// Number of live `WaitGroup` handles (starts at 1, +1 per clone, -1 per drop).
    count: usize,
    /// Thread id passed to `set_panic_id`, if one has been recorded.
    panic_id: Option<ThreadId>,
}
impl Default for WaitGroup {
fn default() -> Self {
Self {
inner: Arc::new(Inner {
cvar: Condvar::new(),
state: Mutex::new(CondState {
count: 1,
panic_id: None,
}),
}),
}
}
}
impl WaitGroup {
pub(crate) fn new() -> Self {
Self::default()
}
pub(crate) fn wait(self) -> Result<(), ThreadId> {
if let Some(id) = self.inner.state.lock().unwrap().panic_id {
return Err(id);
}
if self.inner.state.lock().unwrap().count == 1 {
return Ok(());
}
let inner = self.inner.clone();
drop(self);
let mut guard = inner.state.lock().unwrap();
while guard.count > 0 {
guard = inner.cvar.wait(guard).unwrap();
}
Ok(())
}
pub(crate) fn set_panic_id(&self, id: ThreadId) {
self.inner.state.lock().unwrap().panic_id = Some(id);
self.inner.cvar.notify_all();
}
}
impl Drop for WaitGroup {
    /// Removes this handle from the group; the handle that brings the count
    /// to zero wakes every waiter.
    fn drop(&mut self) {
        let mut state = self.inner.state.lock().unwrap();
        state.count -= 1;
        let was_last = state.count == 0;
        if was_last {
            self.inner.cvar.notify_all();
        }
    }
}
impl Clone for WaitGroup {
    /// Adds a new handle to the same group, bumping the shared handle count.
    fn clone(&self) -> WaitGroup {
        self.inner.state.lock().unwrap().count += 1;
        let shared = Arc::clone(&self.inner);
        WaitGroup { inner: shared }
    }
}
impl fmt::Debug for WaitGroup {
    /// Formats a snapshot of the handle count and any recorded panic id.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Copy the (Copy) fields out and release the lock before formatting.
        // The formatter may run arbitrary `fmt::Write` code. A panic there while
        // the guard is held would poison the mutex, and every later
        // `lock().unwrap()` (including in `Drop`) would then panic.
        let (count, panic_id) = {
            let guard = self.inner.state.lock().unwrap();
            (guard.count, guard.panic_id)
        };
        f.debug_struct("WaitGroup")
            .field("count", &count)
            .field("panic_id", &panic_id)
            .finish()
    }
}