use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use crate::sync::{Condvar, Mutex};
/// Shared state behind every clone of a `ShutdownSignal`.
struct Inner {
    /// Lock-free shutdown flag; set by `shutdown`, read by `is_shutdown`, exposed by `as_flag`.
    flag: AtomicBool,
    /// Mutex-protected copy of the shutdown flag. `shutdown` sets it under the lock, and
    /// `wait_timeout` re-checks it under the same lock before (and after) waiting on `condvar`.
    mutex: Mutex<bool>,
    /// Notified by `shutdown` (and, in tests, `notify_without_shutdown`) to wake `wait_timeout`.
    condvar: Condvar,
    /// One sender per `subscribe_broadcast` call; `shutdown` sends `()` on each of them.
    /// NOTE(review): nothing in this file ever removes entries, so the list grows with every
    /// subscription even after the matching receiver is dropped.
    #[cfg(feature = "rpc")]
    broadcast_txs: std::sync::Mutex<Vec<async_broadcast::Sender<()>>>,
}
/// Cloneable handle used to request a shutdown and to observe or wait for it.
///
/// Every clone points at the same shared state (held in an `Arc`), so calling
/// [`shutdown`](Self::shutdown) through any clone is visible through all of them.
pub struct ShutdownSignal {
    inner: Arc<Inner>,
}
impl Clone for ShutdownSignal {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
impl ShutdownSignal {
    /// Creates a signal in the "not shut down" state.
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Inner {
                flag: AtomicBool::new(false),
                mutex: Mutex::new(false),
                condvar: Condvar::new(),
                #[cfg(feature = "rpc")]
                broadcast_txs: std::sync::Mutex::new(Vec::new()),
            }),
        }
    }

    /// Requests shutdown and wakes everything waiting on it.
    ///
    /// Sets the atomic flag seen by [`is_shutdown`](Self::is_shutdown), sets the
    /// mutex-protected flag and wakes every thread blocked in
    /// [`wait_timeout`](Self::wait_timeout), and (with the `rpc` feature) sends `()` on every
    /// channel handed out by `subscribe_broadcast`.
    ///
    /// Calling it repeatedly is harmless. It does not panic on poisoned locks: the guarded
    /// data (a `bool` and a list of senders) stays valid even if a lock holder panicked.
    pub fn shutdown(&self) {
        // Publish the flag *before* taking any lock: `subscribe_broadcast` re-checks it while
        // holding the sender-list lock, which is what prevents a late subscriber from missing
        // this shutdown.
        self.inner.flag.store(true, Ordering::Release);
        #[cfg(feature = "parking_lot")]
        {
            let mut guard = self.inner.mutex.lock();
            // Set under the lock so a waiter either sees `true` on its check or is already
            // blocked on the condvar and receives the notification below.
            *guard = true;
            self.inner.condvar.notify_all();
            drop(guard);
        }
        #[cfg(not(feature = "parking_lot"))]
        {
            let mut guard = self
                .inner
                .mutex
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            *guard = true;
            self.inner.condvar.notify_all();
            drop(guard);
        }
        #[cfg(feature = "rpc")]
        {
            let txs = self
                .inner
                .broadcast_txs
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            for tx in txs.iter() {
                // Best effort: each channel has capacity 1 and only ever carries `()`, so a
                // full channel already holds a pending shutdown message.
                let _ = tx.try_broadcast(());
            }
        }
    }

    /// Returns `true` once [`shutdown`](Self::shutdown) has been called on any clone.
    #[inline]
    pub fn is_shutdown(&self) -> bool {
        self.inner.flag.load(Ordering::Acquire)
    }

    /// Blocks until shutdown is requested or `timeout` elapses.
    ///
    /// Returns `true` if shutdown was requested (including before this call) and `false` if
    /// the timeout elapsed first. Wakeups that are not caused by a shutdown do not end the
    /// wait early: the remaining time is recomputed against a fixed deadline.
    ///
    /// A `timeout` too large to add to the current instant (e.g. `Duration::MAX`) is
    /// accepted and effectively means "wait until shutdown" instead of panicking.
    pub fn wait_timeout(&self, timeout: Duration) -> bool {
        if self.is_shutdown() {
            return true;
        }
        #[cfg(feature = "parking_lot")]
        {
            let mut guard = self.inner.mutex.lock();
            // `None` if `now + timeout` would overflow `Instant`.
            let deadline = std::time::Instant::now().checked_add(timeout);
            loop {
                if *guard {
                    return true;
                }
                let remaining = Self::remaining_until(deadline, timeout);
                if remaining.is_zero() {
                    return *guard;
                }
                let result = self.inner.condvar.wait_for(&mut guard, remaining);
                if result.timed_out() {
                    return *guard;
                }
            }
        }
        #[cfg(not(feature = "parking_lot"))]
        {
            let mut guard = self
                .inner
                .mutex
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            // `None` if `now + timeout` would overflow `Instant`.
            let deadline = std::time::Instant::now().checked_add(timeout);
            loop {
                if *guard {
                    return true;
                }
                let remaining = Self::remaining_until(deadline, timeout);
                if remaining.is_zero() {
                    return *guard;
                }
                let (g, result) = self
                    .inner
                    .condvar
                    .wait_timeout(guard, remaining)
                    .unwrap_or_else(|poisoned| poisoned.into_inner());
                guard = g;
                if result.timed_out() {
                    return *guard;
                }
            }
        }
    }

    /// Time left until `deadline`; when the deadline overflowed (`None`), the original
    /// `timeout` is used for each wait instead.
    fn remaining_until(deadline: Option<std::time::Instant>, timeout: Duration) -> Duration {
        match deadline {
            Some(deadline) => deadline.saturating_duration_since(std::time::Instant::now()),
            None => timeout,
        }
    }

    /// Borrows the underlying atomic shutdown flag (e.g. for APIs that poll an `&AtomicBool`).
    ///
    /// NOTE(review): a caller could `store(false)` through this reference, which would
    /// desynchronize it from the mutex-protected flag used by `wait_timeout`; treat it as
    /// read-only.
    #[inline]
    pub fn as_flag(&self) -> &AtomicBool {
        &self.inner.flag
    }

    /// Returns a receiver that gets a `()` message when shutdown is requested.
    ///
    /// If shutdown has already been requested, the message is queued immediately, so late
    /// subscribers still observe it.
    #[cfg(feature = "rpc")]
    pub fn subscribe_broadcast(&self) -> async_broadcast::Receiver<()> {
        let (tx, rx) = async_broadcast::broadcast::<()>(1);
        let mut txs = self
            .inner
            .broadcast_txs
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        // Check the flag while holding the list lock. `shutdown` stores the flag before it
        // takes this lock, so either we observe `true` here, or our sender is in the list
        // before `shutdown` iterates it. Checking before locking could miss both.
        if self.is_shutdown() {
            let _ = tx.try_broadcast(());
        }
        txs.push(tx);
        rx
    }

    /// Installs a Ctrl-C handler that calls [`shutdown`](Self::shutdown) on a clone of this
    /// signal.
    ///
    /// Any error returned by `ctrlc::set_handler` is discarded.
    /// NOTE(review): consider surfacing that error to callers.
    #[cfg(feature = "rpc")]
    pub fn install_ctrlc(&self) {
        let signal = self.clone();
        let _ = ctrlc::set_handler(move || {
            signal.shutdown();
        });
    }

    /// Test-only: wakes all `wait_timeout` waiters without requesting shutdown, simulating a
    /// spurious wakeup.
    #[cfg(test)]
    pub(crate) fn notify_without_shutdown(&self) {
        // `notify_all` is called the same way for both condvar backends.
        self.inner.condvar.notify_all();
    }
}
impl Default for ShutdownSignal {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn not_shutdown_by_default() {
        let s = ShutdownSignal::new();
        assert!(!s.is_shutdown());
    }

    #[test]
    fn shutdown_sets_flag() {
        let s = ShutdownSignal::new();
        s.shutdown();
        assert!(s.is_shutdown());
    }

    #[test]
    fn shutdown_is_idempotent() {
        let s = ShutdownSignal::new();
        s.shutdown();
        s.shutdown();
        assert!(s.is_shutdown());
        assert!(s.wait_timeout(Duration::ZERO));
    }

    #[test]
    fn wait_timeout_zero_without_shutdown_returns_false() {
        let s = ShutdownSignal::new();
        assert!(!s.wait_timeout(Duration::ZERO));
    }

    #[test]
    fn wait_timeout_returns_immediately_after_shutdown() {
        let s = ShutdownSignal::new();
        s.shutdown();
        let start = std::time::Instant::now();
        let woken = s.wait_timeout(Duration::from_secs(60));
        assert!(woken);
        assert!(start.elapsed() < Duration::from_millis(100));
    }

    #[test]
    fn wait_timeout_wakes_on_shutdown() {
        let s = ShutdownSignal::new();
        let s2 = s.clone();
        let handle = std::thread::spawn(move || {
            let start = std::time::Instant::now();
            let woken = s2.wait_timeout(Duration::from_secs(60));
            (woken, start.elapsed())
        });
        std::thread::sleep(Duration::from_millis(50));
        s.shutdown();
        let (woken, elapsed) = handle.join().expect("thread panicked");
        assert!(woken);
        assert!(elapsed < Duration::from_secs(1));
    }

    #[test]
    fn wait_timeout_times_out_normally() {
        let s = ShutdownSignal::new();
        let start = std::time::Instant::now();
        let woken = s.wait_timeout(Duration::from_millis(50));
        assert!(!woken);
        assert!(start.elapsed() >= Duration::from_millis(40));
    }

    #[test]
    fn clone_shares_state() {
        let s1 = ShutdownSignal::new();
        let s2 = s1.clone();
        s1.shutdown();
        assert!(s2.is_shutdown());
    }

    #[test]
    fn spurious_wakeup_does_not_report_shutdown() {
        // Kept short (500 ms) so the suite stays fast; the 100 ms margin below still
        // distinguishes "waited out the timeout" from "returned on the fake wakeup".
        let s = ShutdownSignal::new();
        let s2 = s.clone();
        let handle = std::thread::spawn(move || {
            let start = std::time::Instant::now();
            let woken = s2.wait_timeout(Duration::from_millis(500));
            (woken, start.elapsed())
        });
        std::thread::sleep(Duration::from_millis(100));
        s.notify_without_shutdown();
        let (woken, elapsed) = handle.join().expect("thread panicked");
        assert!(
            !woken,
            "wait_timeout must return false on spurious wakeup (no shutdown)"
        );
        assert!(
            elapsed >= Duration::from_millis(400),
            "should have waited close to the full timeout, but returned in {elapsed:?}"
        );
    }

    #[cfg(feature = "rpc")]
    #[test]
    fn broadcast_subscriber_is_notified_on_shutdown() {
        let s = ShutdownSignal::new();
        let mut rx = s.subscribe_broadcast();
        assert!(rx.try_recv().is_err(), "no message expected before shutdown");
        s.shutdown();
        assert!(rx.try_recv().is_ok(), "subscriber must be notified by shutdown");
    }

    #[cfg(feature = "rpc")]
    #[test]
    fn broadcast_subscriber_after_shutdown_is_notified() {
        let s = ShutdownSignal::new();
        s.shutdown();
        let mut rx = s.subscribe_broadcast();
        assert!(
            rx.try_recv().is_ok(),
            "late subscriber must still observe shutdown"
        );
    }
}