use std::{
error, fmt,
future::{pending, Future},
pin::{pin, Pin},
sync::{Arc, Mutex, Weak},
task::{Context, Poll, Waker},
};
use slabmap::SlabMap;
// Only compiled when rustdoc collects doctests (`cfg(doctest)`): attaching the README
// (path relative to this source file) as the doc comment of an otherwise empty module
// makes the README's code examples run as doctests.
#[cfg(doctest)]
pub mod tests {
#[doc = include_str!("../README.md")]
pub mod readme {}
}
/// Shared state behind a `CancellationTokenSource`.
///
/// The inner `Option` is `Some` while the source is live and becomes `None` once it has
/// been canceled (cancellation takes the data out).
struct RawTokenSource(Mutex<Option<Data>>);

impl RawTokenSource {
    /// Creates a live (not yet canceled) source that holds `parent`, its registration
    /// with a parent source (or an empty registration for a root source).
    fn new(parent: CancellationTokenRegistration) -> Self {
        let live = Data::new(parent);
        Self(Mutex::new(Some(live)))
    }

    /// Returns `true` once cancellation has happened.
    fn is_canceled(&self) -> bool {
        let state = self.0.lock().unwrap();
        state.is_none()
    }
}
impl OnCanceled for RawTokenSource {
fn on_canceled(&self) {
let Some(data) = self.0.lock().unwrap().take() else {
return;
};
data.cbs.into_iter().for_each(|(_, cb)| cb.on_canceled());
}
}
/// Mutable state of a live token source.
struct Data {
    /// Callbacks to fire on cancellation, keyed by the slab key handed to registrations.
    cbs: SlabMap<CancelCallback>,
    /// This source's registration with its parent (empty for a root source). It is held
    /// only so that dropping `Data` unregisters this source from the parent.
    _parent: CancellationTokenRegistration,
}

impl Data {
    /// Creates an empty callback table tied to `parent`.
    fn new(parent: CancellationTokenRegistration) -> Self {
        let cbs = SlabMap::new();
        Self {
            cbs,
            _parent: parent,
        }
    }
}
pub struct CancellationTokenSource(Option<Arc<RawTokenSource>>);
impl CancellationTokenSource {
fn new_canceled() -> Self {
Self(None)
}
pub fn new() -> Self {
Self(Some(Arc::new(RawTokenSource::new(
CancellationTokenRegistration(None),
))))
}
#[doc(alias = "CreateLinkedTokenSource")]
pub fn with_parent(parent: &CancellationToken) -> Self {
match &parent.0 {
RawToken::IsCanceled(true) => Self::new_canceled(),
RawToken::IsCanceled(false) => Self::new(),
RawToken::Source(source) => {
if let Some(data) = &mut *source.0.lock().unwrap() {
Self(Some(Arc::new_cyclic(|child: &Weak<RawTokenSource>| {
RawTokenSource::new(CancellationTokenRegistration(Some(RawRegistration {
source: source.clone(),
key: data.cbs.insert(CancelCallback::Weak(child.clone())),
})))
})))
} else {
Self::new_canceled()
}
}
}
}
pub fn cancel(&self) {
if let Some(source) = &self.0 {
source.on_canceled();
}
}
pub fn cancel_defer(&self) -> CancelOnDrop {
CancelOnDrop(Some(self.clone()))
}
pub fn token(&self) -> CancellationToken {
if let Some(source) = &self.0 {
CancellationToken(RawToken::Source(source.clone()))
} else {
CancellationToken(RawToken::IsCanceled(true))
}
}
pub fn is_canceled(&self) -> bool {
if let Some(source) = &self.0 {
source.is_canceled()
} else {
true
}
}
}
impl Clone for CancellationTokenSource {
    /// Returns another handle to the same shared source; canceling through either handle
    /// cancels both.
    fn clone(&self) -> Self {
        let shared = Option::clone(&self.0);
        Self(shared)
    }
}
impl Default for CancellationTokenSource {
fn default() -> Self {
Self::new()
}
}
impl fmt::Debug for CancellationTokenSource {
    /// Shows only whether the source has been canceled.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("CancellationTokenSource");
        out.field("is_canceled", &self.is_canceled());
        out.finish()
    }
}
/// Internal representation of a [`CancellationToken`].
#[derive(Clone)]
enum RawToken {
    /// A token observing a live (or since-canceled) source.
    Source(Arc<RawTokenSource>),
    /// A token whose state is fixed forever: `true` means already canceled, `false` means
    /// it can never be canceled.
    IsCanceled(bool),
}

/// Handle used to observe cancellation; cloning it only copies a flag or bumps a
/// reference count.
#[derive(Clone)]
pub struct CancellationToken(RawToken);
impl CancellationToken {
pub const fn new(is_canceled: bool) -> Self {
Self(RawToken::IsCanceled(is_canceled))
}
pub fn can_be_canceled(&self) -> bool {
match &self.0 {
RawToken::IsCanceled(is_canceled) => *is_canceled,
RawToken::Source(_) => true,
}
}
pub fn is_canceled(&self) -> bool {
match &self.0 {
RawToken::IsCanceled(is_canceled) => *is_canceled,
RawToken::Source(source) => source.is_canceled(),
}
}
#[doc(alias = "ThrowIfCancellationRequested")]
pub fn canceled(&self) -> MayBeCanceled {
if self.is_canceled() {
Err(Canceled)
} else {
Ok(())
}
}
pub fn register(&self, cb: CancelCallback) -> CancellationTokenRegistration {
let is_canceled = match &self.0 {
RawToken::IsCanceled(is_canceled) => *is_canceled,
RawToken::Source(source) => {
if let Some(data) = &mut *source.0.lock().unwrap() {
return CancellationTokenRegistration(Some(RawRegistration {
source: source.clone(),
key: data.cbs.insert(cb),
}));
} else {
true
}
}
};
if is_canceled {
cb.on_canceled();
}
CancellationTokenRegistration::empty()
}
pub async fn wait(&self) {
match &self.0 {
RawToken::IsCanceled(false) => pending().await,
RawToken::IsCanceled(true) => {}
RawToken::Source(source) => WaitForCancellation(WakerRegistration::new(source)).await,
}
}
pub async fn run<T>(&self, future: impl Future<Output = T>) -> MayBeCanceled<T> {
match &self.0 {
RawToken::IsCanceled(false) => Ok(future.await),
RawToken::IsCanceled(true) => Err(Canceled),
RawToken::Source(source) => {
WithCanceled {
r: WakerRegistration::new(source),
future: pin!(future),
}
.await
}
}
}
}
impl Default for CancellationToken {
fn default() -> Self {
Self::new(false)
}
}
impl fmt::Debug for CancellationToken {
    /// Shows whether the token can be canceled and whether it already is.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("CancellationToken");
        out.field("can_be_canceled", &self.can_be_canceled());
        out.field("is_canceled", &self.is_canceled());
        out.finish()
    }
}
/// Receiver of a cancellation notification.
///
/// Implementors must be `Sync + Send` because the notification is delivered on whichever
/// thread triggers the cancellation.
pub trait OnCanceled: Sync + Send {
    /// Called when cancellation is requested.
    fn on_canceled(&self);
}

/// A callback that can be registered with a cancellation token.
#[non_exhaustive]
pub enum CancelCallback {
    /// A one-shot closure, called on cancellation.
    FnOnce(Box<dyn FnOnce() + Sync + Send>),
    /// A task waker, woken on cancellation.
    Waker(Waker),
    /// An owned [`OnCanceled`] implementation.
    Box(Box<dyn OnCanceled>),
    /// A shared [`OnCanceled`] implementation.
    Arc(Arc<dyn OnCanceled>),
    /// A weak reference; nothing happens if the target has already been dropped.
    Weak(Weak<dyn OnCanceled>),
}

impl CancelCallback {
    /// Consumes the callback and delivers the cancellation notification.
    fn on_canceled(self) {
        match self {
            Self::FnOnce(callback) => callback(),
            Self::Waker(waker) => waker.wake(),
            Self::Box(target) => target.on_canceled(),
            Self::Arc(target) => target.on_canceled(),
            Self::Weak(target) => {
                let Some(alive) = target.upgrade() else {
                    return;
                };
                alive.on_canceled();
            }
        }
    }
}
/// Location of one callback in a source's callback table.
struct RawRegistration {
    /// The source whose table holds the callback; this registration keeps it alive.
    source: Arc<RawTokenSource>,
    /// Slab key of the callback inside `source`'s table.
    key: usize,
}

/// Handle for a callback registered on a cancellation token.
///
/// Dropping the handle removes the callback if it has not run yet. Call
/// [`detach`](Self::detach) to keep the callback registered instead.
#[derive(Default)]
#[must_use = "dropping the registration unregisters the callback immediately; call `detach` to keep it registered"]
pub struct CancellationTokenRegistration(Option<RawRegistration>);

impl CancellationTokenRegistration {
    /// A registration that refers to nothing; dropping it is a no-op.
    fn empty() -> Self {
        Self(None)
    }

    /// Consumes the handle without unregistering the callback.
    pub fn detach(mut self) {
        self.0.take();
    }
}

impl fmt::Debug for CancellationTokenRegistration {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("CancellationTokenRegistration")
            .field("is_empty", &self.0.is_none())
            .finish()
    }
}

impl Drop for CancellationTokenRegistration {
    fn drop(&mut self) {
        let Some(raw) = self.0.take() else {
            return;
        };
        // Take the callback out while holding the lock, but drop it only after the lock has
        // been released. The callback may own values (for example another registration on
        // the same source) whose own `Drop` locks this mutex again. Re-locking a
        // `std::sync::Mutex` on the same thread deadlocks or panics.
        let removed = raw
            .source
            .0
            .lock()
            .unwrap()
            .as_mut()
            .and_then(|data| data.cbs.remove(raw.key));
        drop(removed);
    }
}
/// A task waker registered with a source, reused across polls of one future.
struct WakerRegistration<'a> {
    source: &'a RawTokenSource,
    /// Slab key of the registered waker once `set` has succeeded at least once.
    key: Option<usize>,
}

impl<'a> WakerRegistration<'a> {
    pub fn new(source: &'a RawTokenSource) -> Self {
        Self { source, key: None }
    }

    /// Returns `true` once the source has been canceled.
    pub fn is_canceled(&self) -> bool {
        self.source.is_canceled()
    }

    /// Registers or refreshes `waker` so that it is woken on cancellation.
    ///
    /// Returns `false`, registering nothing, if the source has already been canceled.
    /// Registration happens under the source's lock, so a cancellation racing with this
    /// call is either seen here (returns `false`) or wakes the stored waker.
    pub fn set(&mut self, waker: &Waker) -> bool {
        let mut guard = self.source.0.lock().unwrap();
        let Some(data) = guard.as_mut() else {
            return false;
        };
        match self.key {
            Some(key) => {
                // A future is normally re-polled by the same task; skip the clone and the
                // slot replacement when the stored waker already wakes that task.
                let unchanged =
                    matches!(&data.cbs[key], CancelCallback::Waker(w) if w.will_wake(waker));
                if !unchanged {
                    data.cbs[key] = CancelCallback::Waker(waker.clone());
                }
            }
            None => self.key = Some(data.cbs.insert(CancelCallback::Waker(waker.clone()))),
        }
        true
    }
}

impl Drop for WakerRegistration<'_> {
    fn drop(&mut self) {
        let Some(key) = self.key.take() else {
            return;
        };
        // Remove under the lock, but drop the removed waker only after releasing it.
        let removed = self
            .source
            .0
            .lock()
            .unwrap()
            .as_mut()
            .and_then(|data| data.cbs.remove(key));
        drop(removed);
    }
}
/// Future that completes once the source behind its waker registration is canceled.
struct WaitForCancellation<'a>(WakerRegistration<'a>);

impl Future for WaitForCancellation<'_> {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `set` refuses to register only when the source has already been canceled.
        let registered = self.0.set(cx.waker());
        match registered {
            true => Poll::Pending,
            false => Poll::Ready(()),
        }
    }
}
/// Future adapter that drives `future` but resolves to `Err(Canceled)` as soon as the
/// source behind `r` is canceled.
struct WithCanceled<'a, Fut> {
    r: WakerRegistration<'a>,
    future: Pin<&'a mut Fut>,
}

impl<Fut: Future> Future for WithCanceled<'_, Fut> {
    type Output = Result<Fut::Output, Canceled>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        if this.r.is_canceled() {
            return Poll::Ready(Err(Canceled));
        }
        if let Poll::Ready(output) = this.future.as_mut().poll(cx) {
            return Poll::Ready(Ok(output));
        }
        // Still pending: make sure cancellation wakes this task. If registration fails,
        // the source was canceled after the check above.
        if !this.r.set(cx.waker()) {
            return Poll::Ready(Err(Canceled));
        }
        Poll::Pending
    }
}
/// Result of an operation that may be interrupted by cancellation.
pub type MayBeCanceled<T = ()> = Result<T, Canceled>;

/// Error value reported when an operation was canceled.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, PartialOrd, Ord)]
pub struct Canceled;

impl fmt::Display for Canceled {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honours width/precision/alignment flags, exactly like `str`'s `Display`.
        f.pad("operation has been cancelled")
    }
}

impl error::Error for Canceled {}
/// Guard that cancels its source when dropped, unless [`detach`](CancelOnDrop::detach)
/// is called first.
///
/// Marked `#[must_use]` because discarding it (e.g. a bare `source.cancel_defer();`
/// statement) drops it on the spot and cancels the source immediately.
#[must_use = "dropping `CancelOnDrop` cancels the source immediately; bind it to a variable"]
pub struct CancelOnDrop(Option<CancellationTokenSource>);

impl CancelOnDrop {
    /// Disarms the guard so that dropping it does not cancel the source.
    pub fn detach(mut self) {
        self.0.take();
    }
}

impl Drop for CancelOnDrop {
    fn drop(&mut self) {
        if let Some(source) = self.0.take() {
            source.cancel();
        }
    }
}