use std::{future::Future, pin::Pin, sync::Arc, time::Duration};
use tokio::sync::broadcast::{Receiver, Sender};
use tokio::sync::{broadcast, Mutex};
use tokio::time::Instant;
/// A cancellation context: the receiving side of a cancel signal, with an
/// optional deadline and an optional parent context.
///
/// Instances are created together with their controlling `Handle`.
pub struct Context {
/// Instant at which `done()` resolves on its own; `None` means no deadline.
pub(crate) timeout: Option<Instant>,
/// Receiver subscribed to the broadcast sender owned by the `Handle`.
pub(crate) cancel_receiver: Receiver<()>,
/// Optional parent; `done()` also resolves when the parent's `done()` does.
pub(crate) parent_ctx: Option<RefContext>,
}
/// Shareable wrapper around a `Context`, stored behind `Arc<Mutex<_>>`.
///
/// Cloning copies the `Arc`, so every clone refers to the same `Context`.
#[derive(Clone)]
pub struct RefContext(Arc<Mutex<Context>>);
/// Controlling side of a context: owns the broadcast sender that spawned
/// `Context`s subscribe to, plus the deadline and parent they inherit.
pub struct Handle {
/// Deadline copied into every context spawned from this handle.
pub(crate) timeout: Option<Instant>,
/// Sender half of the cancel channel; spawned contexts subscribe to it.
pub(crate) cancel_sender: Sender<()>,
/// Parent context handed to every context spawned from this handle.
pub(crate) parent_ctx: Option<RefContext>,
}
impl Handle {
pub fn cancel(self) {}
pub fn spawn_ctx(&mut self) -> Context {
if let Some(ref ctx) = self.parent_ctx {
Context {
timeout: self.timeout.clone(),
cancel_receiver: self.cancel_sender.subscribe(),
parent_ctx: Some(ctx.clone()),
}
} else {
Context {
timeout: self.timeout.clone(),
cancel_receiver: self.cancel_sender.subscribe(),
parent_ctx: None,
}
}
}
pub fn spawn_ref(&mut self) -> RefContext {
RefContext::from(self.spawn_ctx())
}
}
impl Context {
pub fn new() -> (Context, Handle) {
let (tx, _) = broadcast::channel(1);
let mut handle = Handle {
timeout: None,
cancel_sender: tx,
parent_ctx: None,
};
(handle.spawn_ctx(), handle)
}
pub fn with_timeout(timeout: Duration) -> (Context, Handle) {
let (tx, _) = broadcast::channel(1);
let mut handle = Handle {
timeout: Some(Instant::now() + timeout),
cancel_sender: tx,
parent_ctx: None,
};
(handle.spawn_ctx(), handle)
}
pub fn with_parent(parent_ctx: &RefContext, timeout: Option<Duration>) -> (Context, Handle) {
let timeout = if let Some(t) = timeout {
Some(Instant::now() + t)
} else {
None
};
let (tx, _) = broadcast::channel(1);
let mut handle = Handle {
timeout,
cancel_sender: tx,
parent_ctx: Some(parent_ctx.clone()),
};
(handle.spawn_ctx(), handle)
}
#[allow(unused_must_use)] pub fn done(&mut self) -> Pin<Box<dyn Future<Output = ()> + '_ + Send>> {
Box::pin(async move {
match (self.timeout, self.parent_ctx.as_ref()) {
(Some(instant), None) => {
tokio::select! {
_ = tokio::time::sleep_until(instant) => return,
_ = self.cancel_receiver.recv() => return,
}
}
(None, None) => {
self.cancel_receiver.recv().await;
}
(Some(instant), Some(ctx)) => {
let parent_ctx = ctx.clone();
let mut inner = parent_ctx.0.lock().await;
tokio::select! {
_ = tokio::time::sleep_until(instant) => return,
_ = self.cancel_receiver.recv() => return,
_ = inner.done() => return,
}
}
(None, Some(ctx)) => {
let parent_ctx = ctx.clone();
let mut inner = parent_ctx.0.lock().await;
tokio::select! {
_ = self.cancel_receiver.recv() => return,
_ = inner.done() => return,
}
}
}
})
}
}
impl RefContext {
pub fn new() -> (RefContext, Handle) {
let (context, handle) = Context::new();
(RefContext::from(context), handle)
}
pub fn with_timeout(timeout: Duration) -> (RefContext, Handle) {
let (context, handle) = Context::with_timeout(timeout);
(RefContext::from(context), handle)
}
pub fn with_parent(parent_ctx: &RefContext, timeout: Option<Duration>) -> (RefContext, Handle) {
let (context, handle) = Context::with_parent(parent_ctx, timeout);
(RefContext::from(context), handle)
}
pub fn done(&mut self) -> Pin<Box<dyn Future<Output = ()> + '_>> {
let soft_copy = self.clone();
Box::pin(async move {
let mut inner = soft_copy.0.lock().await;
inner.done().await
})
}
}
impl From<Context> for RefContext {
fn from(ctx: Context) -> Self {
RefContext(Arc::new(Mutex::new(ctx)))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::timeout;

    /// Cancelling the handle must make the context done (almost) immediately.
    #[tokio::test]
    async fn cancel_handle_cancels_context() {
        let (mut ctx, handle) = Context::new();
        handle.cancel();

        assert!(
            timeout(Duration::from_millis(1), ctx.done()).await.is_ok(),
            "context should be done once its handle is cancelled"
        );
    }

    /// A context with a timeout must be done once that timeout has elapsed.
    #[tokio::test]
    async fn duration_cancels_context() {
        // The handle must stay alive: dropping it would cancel the context.
        let (mut ctx, _handle) = Context::with_timeout(Duration::from_millis(10));

        assert!(
            timeout(Duration::from_millis(15), ctx.done()).await.is_ok(),
            "context should be done after its own timeout"
        );
    }

    /// Cancelling a parent must propagate to its children.
    #[tokio::test]
    async fn cancelling_parent_ctx_cancels_child() {
        let (parent_ctx, parent_handle) = RefContext::new();
        let (mut ctx, _handle) = Context::with_parent(&parent_ctx, None);
        parent_handle.cancel();

        assert!(
            timeout(Duration::from_millis(15), ctx.done()).await.is_ok(),
            "child should be done once its parent is cancelled"
        );
    }

    /// Cancelling a child must leave the parent pending.
    #[tokio::test]
    async fn cancelling_child_ctx_doesnt_cancel_parent() {
        let (mut parent_ctx, _parent_handle) = RefContext::new();
        let (_ctx, handle) = Context::with_parent(&parent_ctx, None);
        handle.cancel();

        // A timeout always polls `done()`, unlike racing it against an
        // already-ready future, which may never poll it at all.
        assert!(
            timeout(Duration::from_millis(5), parent_ctx.done()).await.is_err(),
            "parent must not be done when only a child is cancelled"
        );
    }

    /// A parent's deadline that is earlier than the child's must finish the child.
    #[tokio::test]
    async fn parent_timeout_cancels_child() {
        let (parent_ctx, _parent_handle) = RefContext::with_timeout(Duration::from_millis(5));
        let (mut ctx, _handle) =
            Context::with_parent(&parent_ctx, Some(Duration::from_millis(10)));

        assert!(
            timeout(Duration::from_millis(7), ctx.done()).await.is_ok(),
            "child should be done at the parent's earlier deadline"
        );
    }
}