killswitch 0.4.2

Killswitch used to broadcast a shutdown request.
Documentation
//! This library provides two separate structures for signalling (and
//! receiving) termination requests [in `async` contexts]:
//!
//! - [`KillSwitch`] acts as both a trigger and a receiver.
//! - [`pair::KillTrig`] and [`pair::KillWait`] (created using
//!   [`pair::create()`]) act as a kill signal sender and receiver.
//!
//! # KillSwitch
//! Signal a request for (multiple) async tasks to self-terminate.
//!
//! ```
//! use std::error::Error;
//! use tokio::time::{sleep, Duration};
//! use killswitch::KillSwitch;
//!
//! #[tokio::main]
//! async fn main() -> Result<(), Box<dyn Error>> {
//!   let ks = KillSwitch::new();
//!
//!   tokio::spawn(killable(String::from("test1"), ks.clone()));
//!   tokio::spawn(killable(String::from("test2"), ks.clone()));
//!
//!   sleep(Duration::from_secs(1)).await;
//!
//!   println!("Triggering kill switch");
//!   ks.trigger();
//!
//!   tokio::spawn(killable(String::from("test3"), ks.clone()));
//!   tokio::spawn(killable(String::from("test4"), ks.clone()));
//!
//!   // Wait for all waiters to drop
//!   ks.finalize().await;
//!
//!   Ok(())
//! }
//!
//! async fn killable(s: String, ks: KillSwitch) {
//!   println!("killable({}) entered", s);
//!   ks.wait().await;
//!   println!("killable({}) leaving", s);
//! }
//! ```
//!
//! `killswitch` was developed to help create abortable async tasks in
//! conjunction with multiple-wait features such as the `tokio::select!` macro.

pub mod pair;

use std::{
  collections::HashMap,
  future::Future,
  num::NonZeroUsize,
  pin::Pin,
  sync::atomic::{AtomicBool, AtomicUsize, Ordering},
  sync::Arc,
  task::{Context, Poll, Waker}
};

use parking_lot::Mutex;


/// Lock-protected portion of the state shared by all clones of a
/// `KillSwitch`.
struct State {
  /// Wakers of waiters that are pending on the killswitch, keyed by each
  /// waiter's identifier.  They are woken (and removed) when the killswitch
  /// is triggered.
  waiting: HashMap<usize, Waker>,

  /// Wakers of tasks waiting for `waiting` to become empty, keyed by each
  /// finalizing future's identifier.
  final_wait: HashMap<usize, Waker>,
}


/// Buffer shared, through an `Arc`, among all `KillSwitch` handles and the
/// futures they hand out.
struct Shared {
  /// Counter from which internal (per-waiter) identifiers are drawn.
  id: AtomicUsize,

  /// Set once the killswitch has been triggered; cleared again by a reset.
  triggered: AtomicBool,

  /// Waker bookkeeping, which may only be accessed while holding the lock.
  state: Mutex<State>,
}

impl Shared {
  /// Hand out the current value of the identifier counter and advance it by
  /// one.
  ///
  /// The counter wraps around on overflow, so zero (and, after wrapping,
  /// identifiers that are still in use) may be returned; callers must skip
  /// those themselves.
  #[inline]
  fn id(&self) -> usize {
    AtomicUsize::fetch_add(&self.id, 1, Ordering::SeqCst)
  }
}


/// Handle used both to signal termination and to wait for it.
///
/// Cloning a `KillSwitch` is cheap and yields another handle onto the same
/// underlying switch, so triggering any clone is observed through all of
/// them.
pub struct KillSwitch(Arc<Shared>);

impl KillSwitch {
  /// Create a new, untriggered kill switch object.
  ///
  /// Equivalent to [`KillSwitch::default()`].
  pub fn new() -> Self {
    Self::default()
  }


  /// Mark killswitch as triggered and signal all waiting tasks that they
  /// should terminate.
  ///
  /// Every waker registered by a pending [`WaitFuture`] is removed from the
  /// shared state and woken, so its task re-polls and observes the
  /// triggered state.
  #[inline]
  pub fn trigger(&self) {
    // Mark killswitch as "set" *before* taking the lock.  A waiter that is
    // about to register re-checks this flag while holding the lock, so it
    // either sees the flag or has its waker drained below.
    self.0.triggered.store(true, Ordering::SeqCst);

    // Detach the registered wakers while holding the lock, but wake them
    // only after the lock has been released.  This keeps the critical
    // section short, and means a waker that synchronously polls or drops a
    // future tied to this killswitch can't deadlock on the (non-reentrant)
    // mutex.
    let wakers = {
      let mut state = self.0.state.lock();
      std::mem::take(&mut state.waiting)
    };
    for (_, waker) in wakers {
      waker.wake();
    }
  }


  /// Return a future that resolves once the killswitch has been triggered.
  ///
  /// If the killswitch is already triggered, the future resolves on its
  /// first poll.
  #[inline]
  pub fn wait(&self) -> WaitFuture {
    WaitFuture {
      ctx: Arc::clone(&self.0),
      id: None
    }
  }

  /// Return a Future that will return `Ready` once there are no more waiters
  /// waiting on this killswitch to be triggered.
  ///
  /// The KillSwitch must be in triggered state when the returned future is
  /// polled; otherwise the future resolves to `Err(())`.
  pub fn finalize(&self) -> FinalizedFuture {
    FinalizedFuture {
      ctx: Arc::clone(&self.0),
      id: None
    }
  }

  /// Reset `KillSwitch`.
  ///
  /// Only the triggered flag is cleared; no wakers are touched.
  ///
  /// Care should be taken when using this.  Generally speaking it should only
  /// be used immediately following a [`KillSwitch::finalize()`]:
  ///
  /// ```
  /// use killswitch::KillSwitch;
  ///
  /// # tokio_test::block_on(async {
  /// let ks = KillSwitch::default();
  /// assert_eq!(ks.is_triggered(), false);
  ///
  /// // Trigger kill switch
  /// ks.trigger();
  /// assert_eq!(ks.is_triggered(), true);
  /// ks.finalize().await;
  /// ks.reset();
  ///
  /// // KillSwitch became untriggered again
  /// assert_eq!(ks.is_triggered(), false);
  /// # });
  /// ```
  ///
  /// Applications should prefer to call [`KillSwitch::finalize_reset()`]
  /// rather than calling `finalize()` and `reset()`.
  pub fn reset(&self) {
    self.0.triggered.store(false, Ordering::SeqCst);
  }

  /// Finalize and reset `KillSwitch`.
  ///
  /// # Errors
  /// Returns `Err(())` if the `KillSwitch` wasn't in triggered state; in that
  /// case the switch is not reset.
  pub async fn finalize_reset(&self) -> Result<(), ()> {
    self.finalize().await?;
    self.reset();
    Ok(())
  }

  /// Return a boolean indicating whether kill switch has been triggered.
  ///
  /// Returns `true` if kill switch has been triggered.  Returns `false`
  /// otherwise.
  pub fn is_triggered(&self) -> bool {
    self.0.triggered.load(Ordering::SeqCst)
  }
}

impl Default for KillSwitch {
  fn default() -> Self {
    Self(Arc::new(Shared {
      id: AtomicUsize::new(1),
      triggered: AtomicBool::new(false),
      state: Mutex::new(State {
        waiting: HashMap::new(),
        final_wait: HashMap::new()
      })
    }))
  }
}


impl Clone for KillSwitch {
  fn clone(&self) -> KillSwitch {
    KillSwitch(Arc::clone(&self.0))
  }
}


/// A future used to wait for a [`KillSwitch`] to become triggered.
///
/// Created by [`KillSwitch::wait()`].  Like every future it does nothing
/// unless it is polled (typically via `.await`), hence `#[must_use]`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct WaitFuture {
  /// Shared killswitch state this future observes.
  ctx: Arc<Shared>,

  /// This waiter's identifier (its key in the shared `waiting` map).
  ///
  /// Assigned lazily, when the future is first polled and returns `Pending`.
  id: Option<NonZeroUsize>
}

impl Future for WaitFuture {
  type Output = ();

  /// Resolve once the killswitch has been triggered.
  ///
  /// While the killswitch is untriggered, the current task's waker is stored
  /// in the shared `waiting` map under this future's identifier.  The
  /// identifier is allocated on the first pending poll and reused on every
  /// later one, so re-polling (e.g. from `tokio::select!` or after a
  /// spurious wakeup) only refreshes the stored waker.  It does not add new
  /// entries that `Drop` would never remove.
  fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    // `WaitFuture` only holds `Unpin` fields, so a plain `&mut` is fine.
    let this = self.get_mut();

    // If the killswitch has already been triggered, then avoid taking the
    // lock and just return Ready.
    if this.ctx.triggered.load(Ordering::SeqCst) {
      return Poll::Ready(());
    }

    let mut state = this.ctx.state.lock();

    // Re-check once the lock has been acquired: `trigger()` sets the flag
    // before taking the lock to drain `waiting`, so a waker registered past
    // this point is guaranteed to be drained and woken.
    if this.ctx.triggered.load(Ordering::SeqCst) {
      return Poll::Ready(());
    }

    let id = match this.id {
      Some(id) => id,
      None => {
        // Allocate a fresh non-zero identifier that isn't currently in use.
        let id = loop {
          if let Some(candidate) = NonZeroUsize::new(this.ctx.id()) {
            if !state.waiting.contains_key(&candidate.get()) {
              break candidate;
            }
          }
        };
        this.id = Some(id);
        id
      }
    };

    // Insert or refresh: the task may be polled with a different waker.
    state.waiting.insert(id.get(), cx.waker().clone());

    Poll::Pending
  }
}

impl Drop for WaitFuture {
  /// When a `WaitFuture` is dropped, make sure its waker is released from the
  /// internal `waiting` map.
  ///
  /// If removing it leaves that map empty, every task waiting for all
  /// waiters to be dropped (i.e. registered in `final_wait` by a
  /// [`FinalizedFuture`]) is woken.
  fn drop(&mut self) {
    // A future that never returned `Pending` never registered a waker.
    let id = match self.id {
      Some(id) => id,
      None => return
    };

    // Collect the finalize-wakers under the lock, but invoke them only after
    // the lock has been released.  This keeps the critical section short and
    // avoids a self-deadlock on the non-reentrant mutex should a waker
    // synchronously poll or drop a future tied to this killswitch.
    let finalizers = {
      let mut state = self.ctx.state.lock();
      state.waiting.remove(&id.get());
      if state.waiting.is_empty() {
        std::mem::take(&mut state.final_wait)
      } else {
        HashMap::new()
      }
    };

    for (_, waker) in finalizers {
      waker.wake();
    }
  }
}


/// Used to wait for all waiters to deregister.
///
/// Created by [`KillSwitch::finalize()`].  Like every future it does nothing
/// unless it is polled (typically via `.await`), hence `#[must_use]`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct FinalizedFuture {
  /// Shared killswitch state this future observes.
  ctx: Arc<Shared>,

  /// This finalizer's identifier (its key in the shared `final_wait` map).
  ///
  /// Assigned lazily, when the future is first polled and returns `Pending`.
  id: Option<NonZeroUsize>
}

impl Future for FinalizedFuture {
  type Output = Result<(), ()>;

  /// Resolve once no waiter remains registered in the `waiting` map.
  ///
  /// Resolves to `Err(())` if the killswitch is not triggered when polled.
  /// Otherwise, while waiters remain, the current task's waker is stored in
  /// the `final_wait` map under this future's identifier.  That identifier
  /// is allocated once and reused by later polls, so re-polling refreshes
  /// the stored waker rather than piling up new entries.
  fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
    // `FinalizedFuture` only holds `Unpin` fields, so a plain `&mut` is fine.
    let this = self.get_mut();

    if !this.ctx.triggered.load(Ordering::SeqCst) {
      // killswitch must be triggered
      return Poll::Ready(Err(()));
    }

    let mut state = this.ctx.state.lock();
    if state.waiting.is_empty() {
      return Poll::Ready(Ok(()));
    }

    let id = match this.id {
      Some(id) => id,
      None => {
        // Allocate a fresh non-zero identifier that doesn't clash with any
        // other finalizer's entry in `final_wait` (the map we insert into).
        let id = loop {
          if let Some(candidate) = NonZeroUsize::new(this.ctx.id()) {
            if !state.final_wait.contains_key(&candidate.get()) {
              break candidate;
            }
          }
        };
        this.id = Some(id);
        id
      }
    };

    // Insert or refresh: the task may be polled with a different waker.
    state.final_wait.insert(id.get(), cx.waker().clone());

    Poll::Pending
  }
}

impl Drop for FinalizedFuture {
  /// Deregister this future's waker from the `final_wait` map, if it ever
  /// registered one, so no stale entry is left behind in the shared state.
  fn drop(&mut self) {
    if let Some(key) = self.id.map(NonZeroUsize::get) {
      self.ctx.state.lock().final_wait.remove(&key);
    }
  }
}

// vim: set ft=rust et sw=2 ts=2 sts=2 cinoptions=2 tw=79 :