use std::sync::{Arc, Mutex, Weak};
use std::time::{Duration, Instant};
use crate::chan::{chan, Receiver, Sender};
/// Reason a [`Context`] stopped: why `err()` became `Some`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ContextError {
    /// Cancelled through a [`CancelFn`], or inherited from a parent that was.
    Cancelled,
    /// The deadline passed, or the error was inherited from a parent whose deadline passed.
    DeadlineExceeded,
}

impl std::fmt::Display for ContextError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::Cancelled => "context cancelled",
            Self::DeadlineExceeded => "context deadline exceeded",
        };
        f.write_str(message)
    }
}

/// Lets `ContextError` flow through `?` into `Box<dyn Error>` and other
/// generic error-handling code; `Display` supplies the message.
impl std::error::Error for ContextError {}
/// Shared state behind a [`Context`] and its [`CancelFn`]; both wrap an
/// `Arc` to this struct.
struct ContextInner {
    /// Deadline passed in by `make_child`; `None` for `background()` and
    /// `with_cancel` contexts.
    deadline: Option<Instant>,
    /// Sending half of the done channel. Taken out and closed by the first
    /// `cancel`, which is what wakes receivers of `done_rx`.
    done_tx: Mutex<Option<Sender<()>>>,
    /// Receiving half, handed out by `Context::done`.
    done_rx: Receiver<()>,
    /// Cancellation reason: `None` while live, set once by the first `cancel`.
    err: Mutex<Option<ContextError>>,
    /// Weak handles to child contexts, drained and cancelled when this context
    /// is cancelled. Weak so that a parent does not keep its children alive.
    children: Mutex<Vec<Weak<ContextInner>>>,
}
impl ContextInner {
    /// Cancels this context with `err`, then cascades the same error to every
    /// child that is still alive.
    ///
    /// Only the first call has an effect: once an error is recorded, later
    /// calls (from the cancel handle, a deadline timer, or a parent) return
    /// immediately and never overwrite the original reason.
    fn cancel(&self, err: ContextError) {
        // Record the reason before closing the channel, so anything woken via
        // `done()` already observes `err()` as `Some`.
        {
            let mut recorded = self.err.lock().unwrap();
            if recorded.is_some() {
                return;
            }
            *recorded = Some(err.clone());
        }

        // Closing the sender is what unblocks receivers waiting on `done()`.
        let mut sender_slot = self.done_tx.lock().unwrap();
        if let Some(sender) = sender_slot.take() {
            sender.close();
        }
        drop(sender_slot);

        // Detach the child list first so this context's `children` lock is not
        // held while the children cancel themselves recursively.
        let detached = std::mem::take(&mut *self.children.lock().unwrap());
        detached
            .iter()
            .filter_map(|weak| weak.upgrade())
            .for_each(|child| child.cancel(err.clone()));
    }
}
/// A cancellable context. Observe cancellation with [`Context::done`],
/// [`Context::err`] or [`Context::is_done`].
///
/// Clones share the same underlying state, so every clone observes the same
/// cancellation.
#[derive(Clone)]
pub struct Context(Arc<ContextInner>);

impl Context {
    /// Returns the done channel. Its sender is closed when this context is
    /// cancelled, which lets a blocked `recv()` on it return.
    pub fn done(&self) -> &Receiver<()> {
        let Context(inner) = self;
        &inner.done_rx
    }

    /// Returns the deadline this context was created with via [`with_deadline`], if any.
    ///
    /// Contexts created by [`background`] or [`with_cancel`] report `None`,
    /// even when an ancestor has a deadline.
    pub fn deadline(&self) -> Option<Instant> {
        let Context(inner) = self;
        inner.deadline
    }

    /// Returns why this context was cancelled, or `None` while it is still live.
    pub fn err(&self) -> Option<ContextError> {
        let recorded = self.0.err.lock().unwrap();
        recorded.as_ref().cloned()
    }

    /// Returns `true` once this context has been cancelled for any reason.
    pub fn is_done(&self) -> bool {
        self.0.err.lock().unwrap().is_some()
    }
}
/// Handle that cancels its [`Context`] with [`ContextError::Cancelled`].
///
/// Clones refer to the same context; cancelling through any of them is
/// equivalent, and only the first cancellation of the context has an effect.
#[derive(Clone)]
pub struct CancelFn(Arc<ContextInner>);

impl CancelFn {
    /// Cancels the associated context and, transitively, its live children.
    ///
    /// Does nothing if the context is already cancelled, whether by a previous
    /// call, a parent, or an expired deadline.
    pub fn cancel(&self) {
        let CancelFn(inner) = self;
        inner.cancel(ContextError::Cancelled);
    }
}
/// Returns a new root context with no deadline and no [`CancelFn`].
///
/// Nothing in this module ever cancels it. It exists to be the parent of
/// contexts created with [`with_cancel`], [`with_deadline`] and [`with_timeout`].
pub fn background() -> Context {
    let (tx, rx) = chan::<()>(0);
    let root = ContextInner {
        deadline: None,
        done_tx: Mutex::new(Some(tx)),
        done_rx: rx,
        err: Mutex::new(None),
        children: Mutex::new(Vec::new()),
    };
    Context(Arc::new(root))
}
/// Returns a child of `parent` plus a [`CancelFn`] that cancels it with
/// [`ContextError::Cancelled`].
///
/// The child is also cancelled, with the parent's error, when `parent` is
/// cancelled. If `parent` is already done, the child is cancelled before it is
/// returned. The child has no deadline of its own: `deadline()` returns `None`.
pub fn with_cancel(parent: &Context) -> (Context, CancelFn) {
    make_child(parent, None)
}
/// Returns a child of `parent` that is cancelled with
/// [`ContextError::DeadlineExceeded`] once `deadline` passes, together with a
/// [`CancelFn`] that cancels it earlier with [`ContextError::Cancelled`].
///
/// A deadline that has already passed cancels the child before it is returned.
/// Otherwise a goroutine sleeps until the deadline and then cancels the child.
/// The goroutine holds only a weak handle, so it does not keep the context
/// alive. An early cancel does not stop it; it keeps sleeping until the deadline.
pub fn with_deadline(parent: &Context, deadline: Instant) -> (Context, CancelFn) {
    let (ctx, cancel) = make_child(parent, Some(deadline));
    let now = Instant::now();
    if deadline <= now {
        cancel.0.cancel(ContextError::DeadlineExceeded);
    } else if !ctx.is_done() {
        // A child of an already-cancelled parent is cancelled inside
        // `make_child`. Any later `cancel` is then a no-op, so a timer would
        // only waste a goroutine; we skip it in that case.
        let remaining = deadline.duration_since(now);
        let timer_target = Arc::downgrade(&cancel.0);
        // SAFETY: NOTE(review): the contract of `spawn_goroutine` is not
        // visible in this module. This is the same call the original code made,
        // and the closure owns only a `Duration` and a `Weak<ContextInner>`.
        // Confirm that satisfies the scheduler's requirements.
        unsafe {
            crate::runtime::sched::spawn_goroutine(move || {
                crate::sleep(remaining);
                if let Some(inner) = timer_target.upgrade() {
                    inner.cancel(ContextError::DeadlineExceeded);
                }
            });
        }
    }
    (ctx, cancel)
}
pub fn with_timeout(parent: &Context, timeout: Duration) -> (Context, CancelFn) {
with_deadline(parent, Instant::now() + timeout)
}
/// Creates a child of `parent` with an optional deadline and returns it
/// together with its cancel handle.
///
/// If `parent` is already cancelled, the child is cancelled immediately with
/// the same error. Otherwise a weak handle to the child is registered with the
/// parent, so cancelling the parent cascades to the child.
///
/// NOTE(review): the child holds no reference to its parent. If every handle
/// to a parent is dropped, that parent's deadline timer can no longer upgrade
/// it, so it will never cascade a cancellation to this child. Confirm this is intended.
fn make_child(parent: &Context, deadline: Option<Instant>) -> (Context, CancelFn) {
    let (done_tx, done_rx) = chan::<()>(0);
    let inner = Arc::new(ContextInner {
        deadline,
        done_tx: Mutex::new(Some(done_tx)),
        done_rx,
        err: Mutex::new(None),
        children: Mutex::new(Vec::new()),
    });
    let parent_inner = &parent.0;
    // Check the parent's error and register the child while holding the
    // parent's `children` lock. `ContextInner::cancel` records `err` before it
    // drains `children` under that same lock. So either we see the error here,
    // or the parent's drain sees our entry. Checking and pushing under
    // separate locks let a concurrent cancel slip in between, leaving the
    // child registered after the drain and never cancelled.
    let inherited = {
        let mut siblings = parent_inner.children.lock().unwrap();
        let parent_err = parent_inner.err.lock().unwrap().clone();
        if parent_err.is_none() {
            // Entries are otherwise removed only when the parent is cancelled,
            // so a long-lived parent (e.g. `background()`) would accumulate
            // dead handles forever. Pruning only when the buffer is full keeps
            // registration amortised O(1).
            if siblings.len() == siblings.capacity() {
                siblings.retain(|weak| weak.strong_count() > 0);
            }
            siblings.push(Arc::downgrade(&inner));
        }
        parent_err
    };
    // Cancel outside the parent's lock; this touches only the child's own state.
    if let Some(err) = inherited {
        inner.cancel(err);
    }
    let cancel_fn = CancelFn(Arc::clone(&inner));
    (Context(inner), cancel_fn)
}
#[cfg(all(test, not(loom)))]
mod tests {
    use super::*;
    use crate::runtime::sched::run_impl;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// A fresh root context is live and carries no deadline.
    #[test]
    fn background_not_done() {
        let root = background();
        assert!(root.err().is_none());
        assert!(!root.is_done());
        assert!(root.deadline().is_none());
    }

    /// Invoking the cancel handle records `Cancelled` on the child.
    #[test]
    fn with_cancel_cancels() {
        let root = background();
        let (ctx, cancel) = with_cancel(&root);
        assert!(!ctx.is_done());
        cancel.cancel();
        assert_eq!(ctx.err(), Some(ContextError::Cancelled));
    }

    /// Cancelling twice is harmless and keeps the first error.
    #[test]
    fn with_cancel_idempotent() {
        let root = background();
        let (ctx, cancel) = with_cancel(&root);
        for _ in 0..2 {
            cancel.cancel();
        }
        assert_eq!(ctx.err(), Some(ContextError::Cancelled));
    }

    /// Cancelling a parent cascades to its children.
    #[test]
    fn cancel_propagates_to_child() {
        let root = background();
        let (parent, cancel_parent) = with_cancel(&root);
        let (child, _cancel_child) = with_cancel(&parent);
        cancel_parent.cancel();
        assert_eq!(child.err(), Some(ContextError::Cancelled));
    }

    /// Cancellation never flows upward from a child to its parent.
    #[test]
    fn child_cancel_does_not_affect_parent() {
        let root = background();
        let (parent, _cancel_parent) = with_cancel(&root);
        let (_child, cancel_child) = with_cancel(&parent);
        cancel_child.cancel();
        assert!(parent.err().is_none(), "parent must not be cancelled by child");
    }

    /// A child created under an already-cancelled parent starts out done.
    #[test]
    fn child_of_cancelled_parent_is_immediate() {
        let root = background();
        let (parent, cancel_parent) = with_cancel(&root);
        cancel_parent.cancel();
        let (child, _) = with_cancel(&parent);
        assert!(child.is_done(), "child must inherit parent's cancellation");
    }

    /// A goroutine blocked on `done()` is woken by `cancel()`.
    #[test]
    fn done_channel_fires_in_goroutine() {
        let woke = std::sync::Arc::new(AtomicBool::new(false));
        let woke_in_waiter = std::sync::Arc::clone(&woke);
        run_impl(move || {
            let root = background();
            let (ctx, cancel) = with_cancel(&root);
            // Same scheduler entry point that `with_deadline` uses for its timer.
            unsafe {
                crate::runtime::sched::spawn_goroutine(move || {
                    ctx.done().recv();
                    woke_in_waiter.store(true, Ordering::Release);
                });
            }
            // Yield a few times so the waiter can (presumably) reach `recv()` first.
            for _ in 0..20 {
                crate::gosched();
            }
            cancel.cancel();
            let give_up_at = Instant::now() + Duration::from_millis(500);
            while !woke.load(Ordering::Acquire) {
                assert!(Instant::now() < give_up_at, "done channel did not fire");
                crate::gosched();
            }
        });
    }

    /// `with_timeout` eventually closes `done()` and reports `DeadlineExceeded`.
    #[test]
    fn with_timeout_cancels_after_duration() {
        run_impl(|| {
            let root = background();
            let (ctx, _cancel) = with_timeout(&root, Duration::from_millis(20));
            ctx.done().recv();
            assert_eq!(ctx.err(), Some(ContextError::DeadlineExceeded));
        });
    }

    /// A deadline already in the past cancels before `with_deadline` returns.
    #[test]
    fn with_deadline_in_past_cancels_immediately() {
        run_impl(|| {
            let root = background();
            let one_second_ago = Instant::now() - Duration::from_secs(1);
            let (ctx, _cancel) = with_deadline(&root, one_second_ago);
            assert!(ctx.is_done(), "past deadline must cancel immediately");
        });
    }

    /// A cloned cancel handle cancels the same context.
    #[test]
    fn cancel_fn_clone_works() {
        let root = background();
        let (ctx, original) = with_cancel(&root);
        let duplicate = original.clone();
        duplicate.cancel();
        assert!(ctx.is_done());
    }
}