use commonware_utils::sync::{Condvar, Mutex};
use futures::task::ArcWake;
use std::{
any::Any,
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
commonware_macros::stability_mod!(BETA, pub mod buffer);
pub mod signal;
#[cfg(not(target_arch = "wasm32"))]
pub(crate) mod thread;
mod handle;
pub use handle::Handle;
#[commonware_macros::stability(ALPHA)]
pub(crate) use handle::Panicked;
pub(crate) use handle::{Aborter, MetricHandle, Panicker};
mod cell;
pub use cell::Cell as ContextCell;
pub(crate) mod supervision;
/// Selects how a spawned task should be executed.
///
/// NOTE(review): the exact scheduling semantics of each variant are implemented by the runtime
/// that consumes this value, not in this file — confirm against the spawning code.
#[derive(Copy, Clone, Debug)]
pub enum Execution {
    /// Run the task outside the shared executor — presumably on a dedicated OS thread
    /// (cf. the `thread` module); confirm.
    Dedicated,
    /// Run the task on the shared executor. The meaning of the `bool` flag is defined by the
    /// consuming runtime — NOTE(review): confirm (it may mark blocking work).
    Shared(bool),
}

impl Default for Execution {
    /// The default is the shared executor with the flag unset: `Execution::Shared(false)`.
    fn default() -> Self {
        Execution::Shared(false)
    }
}
/// Yields control back to the executor exactly once.
///
/// On its first poll the inner future wakes the current task's waker and returns
/// `Poll::Pending`, giving the executor a chance to run other tasks before this one is polled
/// again. The second poll completes immediately.
pub async fn reschedule() {
    /// One-shot future: pending on the first poll, ready on every poll after that.
    struct YieldOnce {
        polled: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            if !self.polled {
                self.polled = true;
                // We are not waiting on any external event, so request an immediate re-poll.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    YieldOnce { polled: false }.await;
}
/// Converts a panic payload into a human-readable message.
///
/// Payloads produced by `panic!` with a literal (`&str`) or with a formatted message
/// (`String`) are returned verbatim. Any other payload type falls back to its `Debug`
/// representation.
pub(crate) fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    if let Some(literal) = err.downcast_ref::<&str>() {
        return (*literal).to_string();
    }
    match err.downcast_ref::<String>() {
        Some(owned) => owned.clone(),
        None => format!("{err:?}"),
    }
}
/// A reusable wake-up flag that lets a thread block until it is woken.
///
/// Waking it (through its [`ArcWake`] implementation) records a pending signal and notifies one
/// waiting thread. [`Blocker::wait`] blocks until a signal is pending and then consumes it.
/// Signals do not accumulate: several wakes delivered before the next `wait` are observed as a
/// single signal.
pub struct Blocker {
    /// `true` while a wake has been delivered but not yet consumed by `wait`.
    state: Mutex<bool>,
    /// Parks the waiting thread until `state` becomes `true`.
    cv: Condvar,
}

impl Blocker {
    /// Creates a blocker with no pending signal.
    ///
    /// It is returned inside an `Arc` so it can be handed to waker constructors that take
    /// `Arc<impl ArcWake>`.
    pub fn new() -> Arc<Self> {
        let blocker = Self {
            state: Mutex::new(false),
            cv: Condvar::new(),
        };
        Arc::new(blocker)
    }

    /// Blocks the calling thread until a wake is pending, then clears it.
    ///
    /// Returns immediately if a wake was delivered before this call. Spurious condition-variable
    /// wake-ups are harmless because the flag is re-checked after every return from the wait.
    pub fn wait(&self) {
        let mut pending = self.state.lock();
        loop {
            if *pending {
                // Consume the signal so the next `wait` blocks again.
                *pending = false;
                return;
            }
            self.cv.wait(&mut pending);
        }
    }
}

impl ArcWake for Blocker {
    /// Marks a signal as pending and wakes one thread blocked in [`Blocker::wait`].
    fn wake_by_ref(arc_self: &Arc<Self>) {
        // The guard is a temporary, so the lock is released at the end of this statement,
        // before notifying.
        *arc_self.state.lock() = true;
        arc_self.cv.notify_one();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::task::waker;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// A thread blocked in `wait` must not proceed until a wake is delivered.
    #[test]
    fn test_blocker_waits_until_wake() {
        let blocker = Blocker::new();
        let entered = Arc::new(AtomicBool::new(false));
        let finished = Arc::new(AtomicBool::new(false));

        let worker = {
            let blocker = Arc::clone(&blocker);
            let entered = Arc::clone(&entered);
            let finished = Arc::clone(&finished);
            std::thread::spawn(move || {
                entered.store(true, Ordering::SeqCst);
                blocker.wait();
                finished.store(true, Ordering::SeqCst);
            })
        };

        // Once the worker has started it cannot have finished: no wake has been sent yet.
        while !entered.load(Ordering::SeqCst) {
            std::thread::yield_now();
        }
        assert!(!finished.load(Ordering::SeqCst));

        waker(blocker).wake();
        worker.join().unwrap();
        assert!(finished.load(Ordering::SeqCst));
    }

    /// A wake delivered before anyone waits must be remembered, not lost.
    #[test]
    fn test_blocker_handles_pre_wake() {
        let blocker = Blocker::new();
        waker(Arc::clone(&blocker)).wake();

        let finished = Arc::new(AtomicBool::new(false));
        let finished_in_worker = Arc::clone(&finished);
        let worker = std::thread::spawn(move || {
            blocker.wait();
            finished_in_worker.store(true, Ordering::SeqCst);
        });
        worker.join().unwrap();

        assert!(finished.load(Ordering::SeqCst));
    }

    /// The same blocker can be waited on repeatedly, one wake per wait.
    #[test]
    fn test_blocker_reusable_across_signals() {
        const ROUNDS: usize = 2;
        let blocker = Blocker::new();
        let rounds_done = Arc::new(AtomicUsize::new(0));

        let worker = {
            let blocker = Arc::clone(&blocker);
            let rounds_done = Arc::clone(&rounds_done);
            std::thread::spawn(move || {
                (0..ROUNDS).for_each(|_| {
                    blocker.wait();
                    rounds_done.fetch_add(1, Ordering::SeqCst);
                });
            })
        };

        for round in 1..=ROUNDS {
            waker(Arc::clone(&blocker)).wake();
            // Let the worker consume this signal before sending the next one.
            while rounds_done.load(Ordering::SeqCst) < round {
                std::thread::yield_now();
            }
        }
        worker.join().unwrap();

        assert_eq!(rounds_done.load(Ordering::SeqCst), ROUNDS);
    }
}