armdb 0.1.11

sharded bitcask key-value storage optimized for NVMe
Documentation
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;

use crate::sync::{Condvar, Mutex};

/// State shared (behind an `Arc`) by every clone of a `ShutdownSignal`.
struct Inner {
    /// Lock-free shutdown flag: stored as `true` by `shutdown()`, read by
    /// `is_shutdown()` and handed out by `as_flag()`.
    flag: AtomicBool,
    /// Mutex-guarded copy of the shutdown state; this is the lock that
    /// `condvar` waits on in `wait_timeout()`.
    mutex: Mutex<bool>,
    /// Notified (all waiters) by `shutdown()` to wake blocked threads.
    condvar: Condvar,
    /// Senders registered via `subscribe_broadcast()`; `shutdown()` sends a
    /// `()` on each of them.
    #[cfg(feature = "rpc")]
    broadcast_txs: std::sync::Mutex<Vec<async_broadcast::Sender<()>>>,
}

/// Shared shutdown signal for coordinating graceful termination of background
/// workers (compactor, replication, RPC).
///
/// Cloning is cheap (`Arc` inside). Pass the same signal to every worker that
/// should stop together: triggering any clone is observed by all of them.
///
/// Workers can observe the signal by polling the flag, by blocking with a
/// timeout on the condition variable, or (with the `rpc` feature) through a
/// broadcast receiver.
///
/// ```ignore
/// let signal = ShutdownSignal::new();
/// let compactor = Compactor::start_with_signal(compact_fn, interval, signal.clone());
///
/// // Later — trigger shutdown:
/// signal.shutdown();           // wakes all waiters immediately
/// ```
pub struct ShutdownSignal {
    // Shared state; every clone points at the same allocation.
    inner: Arc<Inner>,
}

impl Clone for ShutdownSignal {
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
        }
    }
}

impl ShutdownSignal {
    /// Create a new signal (not yet triggered).
    pub fn new() -> Self {
        Self {
            inner: Arc::new(Inner {
                flag: AtomicBool::new(false),
                mutex: Mutex::new(false),
                condvar: Condvar::new(),
                #[cfg(feature = "rpc")]
                broadcast_txs: std::sync::Mutex::new(Vec::new()),
            }),
        }
    }

    /// Trigger shutdown: sets the flag, wakes all threads blocked in
    /// [`wait_timeout`](Self::wait_timeout), and broadcasts to async receivers.
    ///
    /// Calling this more than once is harmless: the flag stays set and the
    /// extra broadcasts are discarded.
    ///
    /// # Panics
    ///
    /// Without the `parking_lot` feature, panics if the internal mutex is
    /// poisoned.
    pub fn shutdown(&self) {
        // Publish the flag first so lock-free pollers and late subscribers
        // (which check it under the broadcast lock) observe the shutdown.
        self.inner.flag.store(true, Ordering::Release);

        // Wake sync waiters. The guarded bool is flipped under the lock so a
        // waiter cannot miss the notification between its check and its wait.
        #[cfg(feature = "parking_lot")]
        {
            let mut guard = self.inner.mutex.lock();
            *guard = true;
            self.inner.condvar.notify_all();
            drop(guard);
        }
        #[cfg(not(feature = "parking_lot"))]
        {
            let mut guard = self.inner.mutex.lock().expect("shutdown mutex poisoned");
            *guard = true;
            self.inner.condvar.notify_all();
            drop(guard);
        }

        // Wake async waiters (RPC accept loops / connections).
        #[cfg(feature = "rpc")]
        {
            // A poisoned lock must not prevent async waiters from being woken:
            // the guarded `Vec` is still usable, so recover the guard.
            let txs = self
                .inner
                .broadcast_txs
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            for tx in txs.iter() {
                // Best effort: a full or closed channel needs no further wake-up.
                let _ = tx.try_broadcast(());
            }
        }
    }

    /// Non-blocking check.
    #[inline]
    pub fn is_shutdown(&self) -> bool {
        self.inner.flag.load(Ordering::Acquire)
    }

    /// Sleep for at most `timeout`. Returns `true` if woken by shutdown
    /// (i.e. shutdown was triggered), `false` on normal timeout.
    ///
    /// A wakeup that is not caused by [`shutdown`](Self::shutdown) never makes
    /// this return `true`; the wait is resumed for the remaining time instead.
    ///
    /// # Panics
    ///
    /// Without the `parking_lot` feature, panics if the internal mutex is
    /// poisoned.
    pub fn wait_timeout(&self, timeout: Duration) -> bool {
        // Fast path: avoid taking the lock once shutdown has been published.
        if self.is_shutdown() {
            return true;
        }

        #[cfg(feature = "parking_lot")]
        {
            let mut guard = self.inner.mutex.lock();
            if *guard {
                return true;
            }
            let result = self.inner.condvar.wait_for(&mut guard, timeout);
            // parking_lot: wait_for returns WaitTimeoutResult
            // If timed_out() is false, the condvar was notified.
            *guard || !result.timed_out()
        }

        #[cfg(not(feature = "parking_lot"))]
        {
            // `None` means the deadline is not representable; in that case each
            // wait simply uses the full `timeout` again.
            let deadline = std::time::Instant::now().checked_add(timeout);
            let mut guard = self.inner.mutex.lock().expect("shutdown mutex poisoned");

            // The condvar may wake spuriously, so the guarded flag (not the
            // wait result) decides whether shutdown happened.
            while !*guard {
                let remaining = match deadline {
                    Some(at) => at.saturating_duration_since(std::time::Instant::now()),
                    None => timeout,
                };
                if remaining.is_zero() {
                    return false;
                }
                let (reacquired, _wait_result) = self
                    .inner
                    .condvar
                    .wait_timeout(guard, remaining)
                    .expect("shutdown condvar poisoned");
                guard = reacquired;
            }
            true
        }
    }

    /// Raw flag reference — for code paths that poll `AtomicBool` directly
    /// (e.g. inner replication loops).
    #[inline]
    pub fn as_flag(&self) -> &AtomicBool {
        &self.inner.flag
    }

    /// Register an `async_broadcast` sender that will be triggered on
    /// [`shutdown()`](Self::shutdown). Returns the corresponding receiver.
    ///
    /// If shutdown has already been triggered, the returned receiver has a
    /// message queued immediately.
    ///
    /// Used by the RPC accept loops so they can `select!` between accepting
    /// connections and the shutdown signal.
    #[cfg(feature = "rpc")]
    pub fn subscribe_broadcast(&self) -> async_broadcast::Receiver<()> {
        let (tx, rx) = async_broadcast::broadcast::<()>(1);

        // Check the flag and register the sender under the same lock that
        // `shutdown()` takes to broadcast. `shutdown()` sets the flag before
        // taking that lock, so either we see the flag here, or our sender is
        // already registered when `shutdown()` iterates the list — a concurrent
        // shutdown can never slip between the check and the push.
        let mut txs = self
            .inner
            .broadcast_txs
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        // If already shut down, fire immediately.
        if self.is_shutdown() {
            let _ = tx.try_broadcast(());
        }

        txs.push(tx);
        drop(txs);
        rx
    }

    /// Install a Ctrl-C (SIGINT) handler that triggers this signal.
    ///
    /// Safe to call multiple times — only the first call installs the handler.
    #[cfg(feature = "rpc")]
    pub fn install_ctrlc(&self) {
        let signal = self.clone();
        // ctrlc::set_handler returns Err if a handler is already installed.
        let _ = ctrlc::set_handler(move || {
            signal.shutdown();
        });
    }
}

impl Default for ShutdownSignal {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Instant;

    /// A freshly created signal reports "not shut down".
    #[test]
    fn not_shutdown_by_default() {
        assert!(!ShutdownSignal::new().is_shutdown());
    }

    /// `shutdown()` flips the flag observed by `is_shutdown()`.
    #[test]
    fn shutdown_sets_flag() {
        let signal = ShutdownSignal::new();
        signal.shutdown();
        assert!(signal.is_shutdown());
    }

    /// Waiting on an already-triggered signal must not block.
    #[test]
    fn wait_timeout_returns_immediately_after_shutdown() {
        let signal = ShutdownSignal::new();
        signal.shutdown();

        let began = Instant::now();
        let woken = signal.wait_timeout(Duration::from_secs(60));
        let took = began.elapsed();

        assert!(woken);
        assert!(took < Duration::from_millis(100));
    }

    /// A thread blocked in `wait_timeout` is released by `shutdown()` long
    /// before its timeout expires.
    #[test]
    fn wait_timeout_wakes_on_shutdown() {
        let signal = ShutdownSignal::new();
        let waiter_signal = signal.clone();

        let waiter = thread::spawn(move || {
            let began = Instant::now();
            let woken = waiter_signal.wait_timeout(Duration::from_secs(60));
            (woken, began.elapsed())
        });

        // Give the waiter time to block before triggering.
        thread::sleep(Duration::from_millis(50));
        signal.shutdown();

        let (woken, elapsed) = waiter.join().expect("thread panicked");
        assert!(woken);
        assert!(elapsed < Duration::from_secs(1));
    }

    /// Without a shutdown the wait runs out and reports `false`.
    #[test]
    fn wait_timeout_times_out_normally() {
        let signal = ShutdownSignal::new();

        let began = Instant::now();
        let woken = signal.wait_timeout(Duration::from_millis(50));
        let took = began.elapsed();

        assert!(!woken);
        assert!(took >= Duration::from_millis(40));
    }

    /// Clones share one underlying state.
    #[test]
    fn clone_shares_state() {
        let original = ShutdownSignal::new();
        let copy = original.clone();
        original.shutdown();
        assert!(copy.is_shutdown());
    }
}