1#![allow(non_snake_case)]
2use crate::context::error::{CANCELLED, ContextError, ContextErrorKind, DEADLINE_EXCEEDED};
3use crate::context::state::{CancelKind, CancelState, DoneHandle, StopFunc};
4use crate::context::value::ValueKey;
5use std::any::Any;
6use std::error::Error;
7use std::fmt::{Debug, Formatter};
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::sync::OnceLock;
12use std::task::{Context as TaskContext, Poll};
13use std::thread;
14use std::time::{Duration, Instant};
15use tokio::runtime::Handle;
16
/// Cancels the associated context when called. Returned by [`WithCancel`],
/// [`WithDeadline`], [`WithDeadlineCause`], [`WithTimeout`] and
/// [`WithTimeoutCause`].
pub type CancelFunc = Box<dyn FnOnce() + Send + 'static>;
/// Cancels the associated context, recording the given error as its cause
/// (`None` falls back to the default "canceled" error). Returned by
/// [`WithCancelCause`].
pub type CancelCauseFunc = Box<dyn FnOnce(Option<Arc<dyn Error + Send + Sync>>) + Send + 'static>;
21
/// A node in a context tree carrying cancellation, an optional deadline and
/// request-scoped values.
///
/// Cloning is cheap: clones share the same underlying node through an `Arc`.
#[derive(Clone)]
pub struct Context {
    inner: Arc<ContextInner>,
}
26
/// The kinds of context node; each non-root variant links to its parent.
enum ContextInner {
    /// Root node used by `Background()` / `TODO()`.
    Empty,
    /// Node created by `WithCancel` / `WithCancelCause`.
    Cancelable(CancelCtx),
    /// Node created by the `WithDeadline*` / `WithTimeout*` family.
    Deadline(DeadlineCtx),
    /// Node created by `WithValue`, holding a single key/value pair.
    Value(ValueCtx),
    /// Node created by `WithoutCancel`, detached from the parent's cancellation.
    WithoutCancel(WithoutCancelCtx),
}
34
/// Payload of a cancelable node: its parent plus the shared cancellation state
/// that its cancel function and parent propagation both write to.
#[derive(Clone)]
struct CancelCtx {
    parent: Context,
    state: Arc<CancelState>,
}
40
/// Payload of a deadline node: like [`CancelCtx`], plus the deadline reported
/// by `Context::deadline`.
#[derive(Clone)]
struct DeadlineCtx {
    parent: Context,
    state: Arc<CancelState>,
    // Effective deadline: `WithDeadlineCause` stores the earlier of the
    // requested deadline and the parent's deadline here.
    deadline: Instant,
}
47
/// Payload of a `WithoutCancel` node. The parent is kept only so value
/// lookups can continue up the chain.
#[derive(Clone)]
struct WithoutCancelCtx {
    parent: Context,
}
52
/// Payload of a value node: one type-erased key/value pair plus the parent to
/// fall back to when the key does not match.
struct ValueCtx {
    parent: Context,
    key: Arc<dyn ValueKey>,
    value: Arc<dyn Any + Send + Sync>,
}
58
59impl Debug for Context {
60 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
61 f.debug_struct("Context").finish_non_exhaustive()
62 }
63}
64
65impl Context {
66 fn empty() -> Self {
67 Self {
68 inner: Arc::new(ContextInner::Empty),
69 }
70 }
71
72 fn cancelable(parent: Context, state: Arc<CancelState>) -> Self {
73 Self {
74 inner: Arc::new(ContextInner::Cancelable(CancelCtx { parent, state })),
75 }
76 }
77
78 fn new_deadline(parent: Context, state: Arc<CancelState>, deadline: Instant) -> Self {
79 Self {
80 inner: Arc::new(ContextInner::Deadline(DeadlineCtx {
81 parent,
82 state,
83 deadline,
84 })),
85 }
86 }
87
88 fn new_value(
89 parent: Context,
90 key: Arc<dyn ValueKey>,
91 value: Arc<dyn Any + Send + Sync>,
92 ) -> Self {
93 Self {
94 inner: Arc::new(ContextInner::Value(ValueCtx { parent, key, value })),
95 }
96 }
97
98 fn without_cancel(parent: Context) -> Self {
99 Self {
100 inner: Arc::new(ContextInner::WithoutCancel(WithoutCancelCtx { parent })),
101 }
102 }
103
104 pub fn deadline(&self) -> Option<Instant> {
105 match self.inner.as_ref() {
106 ContextInner::Empty => None,
107 ContextInner::Cancelable(ctx) => ctx.parent.deadline(),
108 ContextInner::Deadline(ctx) => Some(ctx.deadline),
109 ContextInner::Value(ctx) => ctx.parent.deadline(),
110 ContextInner::WithoutCancel(ctx) => ctx.parent.deadline(),
111 }
112 }
113
114 pub fn done(&self) -> DoneHandle {
115 match self.inner.as_ref() {
116 ContextInner::Empty => DoneHandle::never(),
117 ContextInner::Cancelable(ctx) => CancelState::done_handle(&ctx.state),
118 ContextInner::Deadline(ctx) => CancelState::done_handle(&ctx.state),
119 ContextInner::Value(ctx) => ctx.parent.done(),
120 ContextInner::WithoutCancel(_) => DoneHandle::never(),
121 }
122 }
123
124 pub fn done_async(&self) -> DoneFuture {
126 match self.done() {
127 DoneHandle::Never => DoneFuture::Never,
128 DoneHandle::Active(state) => {
129 if state.is_done() {
130 return DoneFuture::Ready;
131 }
132 let notify = state.notify();
133 DoneFuture::Wait(Box::pin(async move {
134 notify.notified().await;
135 }))
136 }
137 }
138 }
139
140 pub fn err(&self) -> Option<ContextError> {
141 match self.inner.as_ref() {
142 ContextInner::Empty => None,
143 ContextInner::Cancelable(ctx) => ctx.state.err(),
144 ContextInner::Deadline(ctx) => ctx.state.err(),
145 ContextInner::Value(ctx) => ctx.parent.err(),
146 ContextInner::WithoutCancel(_) => None,
147 }
148 }
149
150 pub fn cause(&self) -> Option<Arc<dyn Error + Send + Sync>> {
151 match self.inner.as_ref() {
152 ContextInner::Empty => None,
153 ContextInner::Cancelable(ctx) => ctx.state.cause(),
154 ContextInner::Deadline(ctx) => ctx.state.cause(),
155 ContextInner::Value(ctx) => ctx.parent.cause(),
156 ContextInner::WithoutCancel(_) => None,
157 }
158 }
159
160 pub fn value(&self, key: &dyn ValueKey) -> Option<Arc<dyn Any + Send + Sync>> {
161 match self.inner.as_ref() {
162 ContextInner::Value(ctx) => {
163 if ctx.key.equals(key) {
164 Some(ctx.value.clone())
165 } else {
166 ctx.parent.value(key)
167 }
168 }
169 ContextInner::Empty => None,
170 ContextInner::Cancelable(ctx) => ctx.parent.value(key),
171 ContextInner::Deadline(ctx) => ctx.parent.value(key),
172 ContextInner::WithoutCancel(ctx) => ctx.parent.value(key),
173 }
174 }
175}
176
/// Future returned by [`Context::done_async`]; resolves once the context is
/// done.
///
/// Marked `#[must_use]`: a concrete type does not inherit the `Future` trait's
/// `must_use`, so a bare `ctx.done_async();` would otherwise silently do nothing.
#[must_use = "futures do nothing unless `.await`ed or polled"]
pub enum DoneFuture {
    /// The context was already done; resolves immediately.
    Ready,
    /// Waiting for the context's cancellation notification.
    Wait(Pin<Box<dyn Future<Output = ()> + Send + 'static>>),
    /// The context can never be canceled; never resolves.
    Never,
}
183
184impl Future for DoneFuture {
185 type Output = ();
186
187 fn poll(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Self::Output> {
188 let this = unsafe { self.get_unchecked_mut() };
189 match this {
190 DoneFuture::Ready => Poll::Ready(()),
191 DoneFuture::Never => Poll::Pending,
192 DoneFuture::Wait(rx) => match rx.as_mut().poll(cx) {
193 Poll::Ready(_) => {
194 *this = DoneFuture::Ready;
195 Poll::Ready(())
196 }
197 Poll::Pending => Poll::Pending,
198 },
199 }
200 }
201}
202
203pub fn Background() -> Context {
204 static BG: OnceLock<Context> = OnceLock::new();
205 BG.get_or_init(Context::empty).clone()
206}
207
208pub fn TODO() -> Context {
209 static TD: OnceLock<Context> = OnceLock::new();
210 TD.get_or_init(Context::empty).clone()
211}
212
213pub fn WithoutCancel(parent: Context) -> Context {
214 Context::without_cancel(parent)
215}
216
217pub fn WithValue<K, V>(parent: Context, key: K, value: V) -> Context
218where
219 K: ValueKey,
220 V: Any + Send + Sync + 'static,
221{
222 Context::new_value(
223 parent,
224 Arc::new(key),
225 Arc::new(value) as Arc<dyn Any + Send + Sync>,
226 )
227}
228
229pub fn WithCancel(parent: Context) -> (Context, CancelFunc) {
230 WithCancelCause(parent).map_cancel(|f| Box::new(move || f(None)))
231}
232
233pub fn WithCancelCause(parent: Context) -> (Context, CancelCauseFunc) {
234 let state = CancelState::new();
235 propagate_parent(parent.clone(), state.clone());
236 let ctx = Context::cancelable(parent, state.clone());
237 let cancel = Box::new(move |cause: Option<Arc<dyn Error + Send + Sync>>| {
238 let final_cause = cause.or_else(default_canceled);
239 state.cancel(CancelKind::Canceled, final_cause);
240 });
241 (ctx, cancel)
242}
243
244pub fn WithDeadline(parent: Context, deadline: Instant) -> (Context, CancelFunc) {
245 WithDeadlineCause(parent, deadline, None).map_cancel(|f| Box::new(move || f()))
246}
247
248pub fn WithDeadlineCause(
249 parent: Context,
250 deadline: Instant,
251 cause: Option<Arc<dyn Error + Send + Sync>>,
252) -> (Context, CancelFunc) {
253 let effective_deadline = match parent.deadline() {
254 Some(parent_deadline) if parent_deadline <= deadline => parent_deadline,
255 _ => deadline,
256 };
257
258 let state = CancelState::new();
259 let ctx = Context::new_deadline(parent.clone(), state.clone(), effective_deadline);
260 propagate_parent(parent, state.clone());
261 start_deadline_timer(state.clone(), effective_deadline, cause.clone());
262
263 let cancel = Box::new(move || {
264 let cancel_cause = cause.clone().or_else(default_canceled);
265 state.cancel(CancelKind::Canceled, cancel_cause);
266 });
267 (ctx, cancel)
268}
269
270pub fn WithTimeout(parent: Context, timeout: Duration) -> (Context, CancelFunc) {
271 WithDeadline(parent, Instant::now() + timeout)
272}
273
274pub fn WithTimeoutCause(
275 parent: Context,
276 timeout: Duration,
277 cause: Option<Arc<dyn Error + Send + Sync>>,
278) -> (Context, CancelFunc) {
279 WithDeadlineCause(parent, Instant::now() + timeout, cause)
280}
281
282pub fn AfterFunc(ctx: &Context, f: impl FnOnce() + Send + 'static) -> StopFunc {
283 ctx.done().register(f)
284}
285
286pub fn Cause(ctx: &Context) -> Option<Arc<dyn Error + Send + Sync>> {
287 ctx.cause()
288}
289
290pub async fn Done(ctx: Context) -> Option<ContextError> {
292 ctx.done_async().await;
293 ctx.err()
294}
295
296pub async fn ContextAware<T, F>(ctx: Context, fut: F) -> Result<T, ContextError>
298where
299 F: Future<Output = Result<T, ContextError>>,
300{
301 let done = ctx.done_async();
302 tokio::select! {
303 res = fut => res,
304 _ = done => Err(ctx.err().unwrap_or(CANCELLED)),
305 }
306}
307
308fn start_deadline_timer(
309 state: Arc<CancelState>,
310 deadline: Instant,
311 cause: Option<Arc<dyn Error + Send + Sync>>,
312) {
313 if deadline <= Instant::now() {
314 let deadline_cause = cause.clone().or_else(default_deadline);
315 state.cancel(CancelKind::Deadline, deadline_cause);
316 return;
317 }
318 let sleep_dur = deadline.saturating_duration_since(Instant::now());
319 if let Ok(handle) = Handle::try_current() {
320 handle.spawn(async move {
321 tokio::time::sleep(sleep_dur).await;
322 let deadline_cause = cause.clone().or_else(default_deadline);
323 state.cancel(CancelKind::Deadline, deadline_cause);
324 });
325 } else {
326 thread::spawn(move || {
327 thread::sleep(sleep_dur);
328 let deadline_cause = cause.clone().or_else(default_deadline);
329 state.cancel(CancelKind::Deadline, deadline_cause);
330 });
331 }
332}
333
334fn propagate_parent(parent: Context, state: Arc<CancelState>) {
335 if state.is_done() {
336 return;
337 }
338 if let Some(err) = parent.err() {
339 let kind = map_error_kind(&err);
340 let inherited = parent
341 .cause()
342 .or_else(|| Some(Arc::new(err) as Arc<dyn Error + Send + Sync>));
343 state.cancel(kind, inherited);
344 return;
345 }
346 let done = parent.done();
347 done.register(move || {
348 let err = parent.err();
349 let kind = err
350 .as_ref()
351 .map(map_error_kind)
352 .unwrap_or(CancelKind::Canceled);
353 let inherited = parent
354 .cause()
355 .or_else(|| err.map(|e| Arc::new(e) as Arc<dyn Error + Send + Sync>));
356 state.cancel(kind, inherited);
357 });
358}
359
360fn map_error_kind(err: &ContextError) -> CancelKind {
361 match err.kind() {
362 ContextErrorKind::Canceled => CancelKind::Canceled,
363 ContextErrorKind::DeadlineExceeded => CancelKind::Deadline,
364 }
365}
366
367fn default_canceled() -> Option<Arc<dyn Error + Send + Sync>> {
368 Some(Arc::new(CANCELLED) as Arc<dyn Error + Send + Sync>)
369}
370
371fn default_deadline() -> Option<Arc<dyn Error + Send + Sync>> {
372 Some(Arc::new(DEADLINE_EXCEEDED) as Arc<dyn Error + Send + Sync>)
373}
374
375trait MapCancel<T> {
376 fn map_cancel(self, f: impl FnOnce(T) -> CancelFunc) -> (Context, CancelFunc);
377}
378
379impl MapCancel<CancelCauseFunc> for (Context, CancelCauseFunc) {
380 fn map_cancel(self, f: impl FnOnce(CancelCauseFunc) -> CancelFunc) -> (Context, CancelFunc) {
381 let (ctx, c) = self;
382 (ctx, f(c))
383 }
384}
385
386impl MapCancel<CancelFunc> for (Context, CancelFunc) {
387 fn map_cancel(self, f: impl FnOnce(CancelFunc) -> CancelFunc) -> (Context, CancelFunc) {
388 let (ctx, c) = self;
389 (ctx, f(c))
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::thread::sleep;

    /// Asserts that `ctx` was canceled explicitly rather than by a deadline.
    fn assert_canceled(ctx: &Context) {
        let err = ctx.err().expect("expected canceled");
        assert_eq!(err.kind(), ContextErrorKind::Canceled);
    }

    #[test]
    fn background_never_cancels() {
        let ctx = Background();
        assert!(ctx.deadline().is_none());
        assert!(!ctx.done().is_done());
        assert!(ctx.err().is_none());
        assert!(ctx.cause().is_none());
        assert!(ctx.value(&"k").is_none());
    }

    #[test]
    fn cancel_func_cancels() {
        let (ctx, cancel) = WithCancel(Background());
        cancel();
        assert_canceled(&ctx);
        assert!(Cause(&ctx).is_some());
    }

    #[test]
    fn cancel_cause_propagates() {
        #[derive(Debug)]
        struct MyErr;
        impl std::fmt::Display for MyErr {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("mine")
            }
        }
        impl Error for MyErr {}

        let (ctx, cancel) = WithCancelCause(Background());
        let err = Arc::new(MyErr) as Arc<dyn Error + Send + Sync>;
        cancel(Some(err.clone()));
        assert_canceled(&ctx);
        let cause = Cause(&ctx).unwrap();
        assert!(cause.downcast_ref::<MyErr>().is_some());
    }

    #[test]
    fn parent_deadline_cancels_child() {
        let (parent, _) = WithTimeout(Background(), Duration::from_millis(50));
        let (child, _) = WithCancel(parent);
        sleep(Duration::from_millis(80));
        let err = child.err().expect("child canceled");
        assert_eq!(err.kind(), ContextErrorKind::DeadlineExceeded);
    }

    #[test]
    fn deadline_timer_triggers() {
        let (ctx, _) = WithDeadline(Background(), Instant::now() + Duration::from_millis(30));
        sleep(Duration::from_millis(60));
        let err = ctx.err().expect("deadline");
        assert_eq!(err.kind(), ContextErrorKind::DeadlineExceeded);
    }

    #[test]
    fn deadline_cause_used() {
        #[derive(Debug)]
        struct CauseErr;
        impl std::fmt::Display for CauseErr {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("cause")
            }
        }
        impl Error for CauseErr {}

        let (ctx, cancel) = WithDeadlineCause(
            Background(),
            Instant::now() + Duration::from_millis(100),
            Some(Arc::new(CauseErr)),
        );
        cancel();
        let cause = Cause(&ctx).unwrap();
        assert!(cause.downcast_ref::<CauseErr>().is_some());
    }

    #[test]
    fn value_lookup_respects_hierarchy() {
        let root = WithValue(Background(), "a", 1u32);
        let child = WithValue(root, "b", 2u32);
        let val_a = child.value(&"a").unwrap();
        let val_b = child.value(&"b").unwrap();
        assert_eq!(*val_a.downcast::<u32>().unwrap(), 1);
        assert_eq!(*val_b.downcast::<u32>().unwrap(), 2);
    }

    #[test]
    fn without_cancel_detaches() {
        let (parent, cancel) = WithCancel(Background());
        let child = WithoutCancel(parent);
        cancel();
        assert!(child.err().is_none());
        assert!(child.cause().is_none());
        assert!(!child.done().is_done());
    }

    #[test]
    fn after_func_runs_on_cancel() {
        let (ctx, cancel) = WithCancel(Background());
        let flag = Arc::new(AtomicBool::new(false));
        let mark = flag.clone();
        AfterFunc(&ctx, move || {
            mark.store(true, Ordering::SeqCst);
        });
        cancel();
        ctx.done().wait();
        // Give a callback that may run on another thread time to finish.
        sleep(Duration::from_millis(10));
        assert!(flag.load(Ordering::SeqCst));
    }

    #[test]
    fn after_func_stop_on_never_done() {
        let stop = AfterFunc(&Background(), || panic!("should not run"));
        assert!(!stop.Stop());
    }
}
518}