1#![allow(non_snake_case)]
2use crate::context::error::{CANCELLED, ContextError, DEADLINE_EXCEEDED, Error};
3use crate::context::state::{CancelState, DoneHandle, StopFunc};
4use crate::context::value::ValueKey;
5use std::any::Any;
6use std::fmt::{Debug, Formatter};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::sync::OnceLock;
11use std::task::{Context as TaskContext, Poll};
12use std::thread;
13use std::time::{Duration, Instant};
14use tokio::runtime::Handle;
15
/// One-shot function that cancels the context it was returned with.
pub type CancelFunc = Box<dyn FnOnce() + 'static + Send>;

/// One-shot function that cancels its context while recording an optional cause.
///
/// The cancel functions built by `WithCancelCause` substitute the default
/// cancellation error when `None` is passed.
pub type CancelCauseFunc =
    Box<dyn FnOnce(Option<Arc<dyn std::error::Error + Sync + Send>>) + 'static + Send>;
21
/// A handle to one node of a context tree.
///
/// Cloning shares the node through an `Arc`. Clones and drops may also adjust a
/// handle count on the node's cancellation state (see the `Clone`/`Drop` impls).
pub struct Context {
    inner: Arc<ContextInner>,
}

/// The kind of node a [`Context`] points at. Every non-`Empty` node owns its parent.
enum ContextInner {
    /// Root node used by `Background`/`TODO`: no state, no deadline, no values.
    Empty,
    /// Node created by `WithCancel`/`WithCancelCause`.
    Cancelable(CancelCtx),
    /// Node created by `WithDeadline*`/`WithTimeout*`.
    Deadline(DeadlineCtx),
    /// Node carrying a single key/value pair (`WithValue`).
    Value(ValueCtx),
    /// Node that hides its parent's cancellation (`WithoutCancel`).
    WithoutCancel(WithoutCancelCtx),
}

/// Payload of a cancelable node.
// NOTE: fields drop in declaration order (`parent` before `state`); keep the order
// unless the handle accounting in `CancelState` has been re-checked.
struct CancelCtx {
    parent: Context,
    // Cancellation state owned by this node; also captured by its cancel function.
    state: Arc<CancelState>,
}

/// Payload of a deadline node.
struct DeadlineCtx {
    parent: Context,
    state: Arc<CancelState>,
    // Effective deadline; `WithDeadlineCause` has already clamped it to the parent's.
    deadline: Instant,
}

/// Payload of a detached node: values are still looked up through `parent`.
struct WithoutCancelCtx {
    parent: Context,
}

/// Payload of a value node: one type-erased key/value pair on top of `parent`.
struct ValueCtx {
    parent: Context,
    key: Arc<dyn ValueKey>,
    value: Arc<dyn Any + Send + Sync>,
}
54
55impl Clone for Context {
56 fn clone(&self) -> Self {
57 if let Some(state) = self.state_arc_opt() {
58 state.add_handle();
59 }
60 Self {
61 inner: self.inner.clone(),
62 }
63 }
64}
65
66impl Drop for Context {
67 fn drop(&mut self) {
68 if let Some(state) = self.state_arc_opt() {
69 state.release_handle();
70 }
71 }
72}
73
74impl Debug for Context {
75 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
76 f.debug_struct("Context").finish_non_exhaustive()
77 }
78}
79
80impl Context {
81 fn empty() -> Self {
82 Self {
83 inner: Arc::new(ContextInner::Empty),
84 }
85 }
86
87 fn cancelable(parent: Context, state: Arc<CancelState>) -> Self {
88 Self {
89 inner: Arc::new(ContextInner::Cancelable(CancelCtx { parent, state })),
90 }
91 }
92
93 fn new_deadline(parent: Context, state: Arc<CancelState>, deadline: Instant) -> Self {
94 Self {
95 inner: Arc::new(ContextInner::Deadline(DeadlineCtx {
96 parent,
97 state,
98 deadline,
99 })),
100 }
101 }
102
103 fn new_value(
104 parent: Context,
105 key: Arc<dyn ValueKey>,
106 value: Arc<dyn Any + Send + Sync>,
107 ) -> Self {
108 Self {
109 inner: Arc::new(ContextInner::Value(ValueCtx { parent, key, value })),
110 }
111 }
112
113 fn without_cancel(parent: Context) -> Self {
114 Self {
115 inner: Arc::new(ContextInner::WithoutCancel(WithoutCancelCtx { parent })),
116 }
117 }
118
119 pub fn deadline(&self) -> Option<Instant> {
120 match self.inner.as_ref() {
121 ContextInner::Empty => None,
122 ContextInner::Cancelable(ctx) => ctx.parent.deadline(),
123 ContextInner::Deadline(ctx) => Some(ctx.deadline),
124 ContextInner::Value(ctx) => ctx.parent.deadline(),
125 ContextInner::WithoutCancel(ctx) => ctx.parent.deadline(),
126 }
127 }
128
129 pub fn done(&self) -> DoneHandle {
130 match self.inner.as_ref() {
131 ContextInner::Empty => DoneHandle::never(),
132 ContextInner::Cancelable(ctx) => CancelState::done_handle(&ctx.state),
133 ContextInner::Deadline(ctx) => CancelState::done_handle(&ctx.state),
134 ContextInner::Value(ctx) => ctx.parent.done(),
135 ContextInner::WithoutCancel(_) => DoneHandle::never(),
136 }
137 }
138
139 pub fn done_async(&self) -> DoneFuture {
141 match self.done() {
142 DoneHandle::Never => DoneFuture::Never,
143 DoneHandle::Active(state) => {
144 if state.is_done() {
145 return DoneFuture::Ready;
146 }
147 let notify = state.notify();
148 DoneFuture::Wait(Box::pin(async move {
149 notify.notified().await;
150 }))
151 }
152 }
153 }
154
155 pub fn err(&self) -> Option<ContextError> {
156 match self.inner.as_ref() {
157 ContextInner::Empty => None,
158 ContextInner::Cancelable(ctx) => ctx.state.err(),
159 ContextInner::Deadline(ctx) => ctx.state.err(),
160 ContextInner::Value(ctx) => ctx.parent.err(),
161 ContextInner::WithoutCancel(_) => None,
162 }
163 }
164
165 pub fn cause(&self) -> Option<Arc<dyn std::error::Error + Send + Sync>> {
166 match self.inner.as_ref() {
167 ContextInner::Empty => None,
168 ContextInner::Cancelable(ctx) => ctx.state.cause(),
169 ContextInner::Deadline(ctx) => ctx.state.cause(),
170 ContextInner::Value(ctx) => ctx.parent.cause(),
171 ContextInner::WithoutCancel(_) => None,
172 }
173 }
174
175 pub fn value(&self, key: &dyn ValueKey) -> Option<Arc<dyn Any + Send + Sync>> {
176 match self.inner.as_ref() {
177 ContextInner::Value(ctx) => {
178 if ctx.key.equals(key) {
179 Some(ctx.value.clone())
180 } else {
181 ctx.parent.value(key)
182 }
183 }
184 ContextInner::Empty => None,
185 ContextInner::Cancelable(ctx) => ctx.parent.value(key),
186 ContextInner::Deadline(ctx) => ctx.parent.value(key),
187 ContextInner::WithoutCancel(ctx) => ctx.parent.value(key),
188 }
189 }
190
191 fn state_arc(&self) -> Arc<CancelState> {
192 match self.inner.as_ref() {
193 ContextInner::Cancelable(ctx) => ctx.state.clone(),
194 ContextInner::Deadline(ctx) => ctx.state.clone(),
195 ContextInner::Value(ctx) => ctx.parent.state_arc(),
196 ContextInner::WithoutCancel(_) | ContextInner::Empty => CancelState::new_root(),
197 }
198 }
199
200 fn state_arc_opt(&self) -> Option<Arc<CancelState>> {
201 match self.inner.as_ref() {
202 ContextInner::Cancelable(ctx) => Some(ctx.state.clone()),
203 ContextInner::Deadline(ctx) => Some(ctx.state.clone()),
204 ContextInner::Value(ctx) => ctx.parent.state_arc_opt(),
205 ContextInner::WithoutCancel(_) | ContextInner::Empty => None,
206 }
207 }
208}
209
/// Future returned by [`Context::done_async`]; resolves once the context is done.
///
/// * `Ready` — already done; resolves on the first poll.
/// * `Wait` — drives the boxed notification future, then becomes `Ready`.
/// * `Never` — the context can never be cancelled; stays pending forever.
#[must_use = "futures do nothing unless polled or awaited"]
pub enum DoneFuture {
    Ready,
    Wait(Pin<Box<dyn Future<Output = ()> + Send + 'static>>),
    Never,
}

impl Future for DoneFuture {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Self::Output> {
        // Every variant is `Unpin` (the inner future is already pinned on the heap),
        // so `DoneFuture: Unpin` and the safe `get_mut` suffices. No `unsafe` needed.
        let this = self.get_mut();
        match this {
            DoneFuture::Ready => Poll::Ready(()),
            // Deliberately registers no waker: nothing will ever complete this future.
            DoneFuture::Never => Poll::Pending,
            DoneFuture::Wait(inner) => {
                if inner.as_mut().poll(cx).is_pending() {
                    return Poll::Pending;
                }
                // Drop the finished inner future so it is never polled again.
                *this = DoneFuture::Ready;
                Poll::Ready(())
            }
        }
    }
}
235
236pub fn Background() -> Context {
237 static BACKGROUND_CTX: OnceLock<Context> = OnceLock::new();
238 BACKGROUND_CTX.get_or_init(Context::empty).clone()
239}
240
241pub fn TODO() -> Context {
242 static TODO_CTX: OnceLock<Context> = OnceLock::new();
243 TODO_CTX.get_or_init(Context::empty).clone()
244}
245
246pub fn WithoutCancel(parent: Context) -> Context {
247 Context::without_cancel(parent)
248}
249
250pub fn WithValue<K, V>(parent: Context, key: K, value: V) -> Context
251where
252 K: ValueKey,
253 V: Any + Send + Sync + 'static,
254{
255 Context::new_value(
256 parent,
257 Arc::new(key),
258 Arc::new(value) as Arc<dyn Any + Send + Sync>,
259 )
260}
261
262pub fn WithCancel(parent: Context) -> (Context, CancelFunc) {
263 WithCancelCause(parent).map_cancel(|f| Box::new(move || f(None)))
264}
265
266pub fn WithCancelCause(parent: Context) -> (Context, CancelCauseFunc) {
267 let state = match parent.inner.as_ref() {
268 ContextInner::Empty => CancelState::new_root(),
269 _ => CancelState::child_of(&parent.state_arc()),
270 };
271 propagate_parent(parent.clone(), state.clone());
272 let ctx = Context::cancelable(parent, state.clone());
273 let cancel = Box::new(
274 move |cause: Option<Arc<dyn std::error::Error + Send + Sync>>| {
275 let final_cause = cause.or_else(default_canceled);
276 state.cancel(Error::Canceled, final_cause);
277 },
278 );
279 (ctx, cancel)
280}
281
282pub fn WithDeadline(parent: Context, deadline: Instant) -> (Context, CancelFunc) {
283 WithDeadlineCause(parent, deadline, None).map_cancel(|f| Box::new(move || f()))
284}
285
286pub fn WithDeadlineCause(
287 parent: Context,
288 deadline: Instant,
289 cause: Option<Arc<dyn std::error::Error + Send + Sync>>,
290) -> (Context, CancelFunc) {
291 let effective_deadline = match parent.deadline() {
292 Some(parent_deadline) if parent_deadline <= deadline => parent_deadline,
293 _ => deadline,
294 };
295
296 let state = match parent.inner.as_ref() {
297 ContextInner::Empty => CancelState::new_root(),
298 _ => CancelState::child_of(&parent.state_arc()),
299 };
300 let ctx = Context::new_deadline(parent.clone(), state.clone(), effective_deadline);
301 propagate_parent(parent, state.clone());
302 start_deadline_timer(state.clone(), effective_deadline, cause.clone());
303
304 let cancel = Box::new(move || {
305 let cancel_cause = cause.clone().or_else(default_canceled);
306 state.cancel(Error::Canceled, cancel_cause);
307 });
308 (ctx, cancel)
309}
310
311pub fn WithTimeout(parent: Context, timeout: Duration) -> (Context, CancelFunc) {
312 WithDeadline(parent, Instant::now() + timeout)
313}
314
315pub fn WithTimeoutCause(
316 parent: Context,
317 timeout: Duration,
318 cause: Option<Arc<dyn std::error::Error + Send + Sync>>,
319) -> (Context, CancelFunc) {
320 WithDeadlineCause(parent, Instant::now() + timeout, cause)
321}
322
323pub fn AfterFunc(ctx: &Context, f: impl FnOnce() + Send + 'static) -> StopFunc {
324 ctx.done().register(f)
325}
326
327pub fn Cause(ctx: &Context) -> Option<Arc<dyn std::error::Error + Send + Sync>> {
328 ctx.cause()
329}
330
331pub async fn Done(ctx: Context) -> Option<ContextError> {
333 ctx.done_async().await;
334 ctx.err()
335}
336
337pub async fn ContextAware<T, F>(ctx: Context, fut: F) -> Result<T, ContextError>
339where
340 F: Future<Output = Result<T, ContextError>>,
341{
342 let done = ctx.done_async();
343 tokio::select! {
344 res = fut => res,
345 _ = done => Err(ctx.err().unwrap_or(CANCELLED)),
346 }
347}
348
349fn start_deadline_timer(
350 state: Arc<CancelState>,
351 deadline: Instant,
352 cause: Option<Arc<dyn std::error::Error + Send + Sync>>,
353) {
354 if deadline <= Instant::now() {
355 let deadline_cause = cause.clone().or_else(default_deadline);
356 state.cancel(Error::DeadlineExceeded, deadline_cause);
357 return;
358 }
359 let sleep_dur = deadline.saturating_duration_since(Instant::now());
360 if let Ok(handle) = Handle::try_current() {
361 handle.spawn(async move {
362 tokio::time::sleep(sleep_dur).await;
363 let deadline_cause = cause.clone().or_else(default_deadline);
364 state.cancel(Error::DeadlineExceeded, deadline_cause);
365 });
366 } else {
367 thread::spawn(move || {
368 thread::sleep(sleep_dur);
369 let deadline_cause = cause.clone().or_else(default_deadline);
370 state.cancel(Error::DeadlineExceeded, deadline_cause);
371 });
372 }
373}
374
375fn propagate_parent(parent: Context, state: Arc<CancelState>) {
376 if state.is_done() {
377 return;
378 }
379 if let Some(err) = parent.err() {
380 let kind = map_error_kind(&err);
381 let inherited = parent
382 .cause()
383 .or_else(|| Some(Arc::new(err) as Arc<dyn std::error::Error + Send + Sync>));
384 state.cancel(kind, inherited);
385 return;
386 }
387 let done = parent.done();
388 done.register(move || {
389 let err = parent.err();
390 let kind = err.as_ref().map(map_error_kind).unwrap_or(Error::Canceled);
391 let inherited = parent
392 .cause()
393 .or_else(|| err.map(|e| Arc::new(e) as Arc<dyn std::error::Error + Send + Sync>));
394 state.cancel(kind, inherited);
395 });
396}
397
398fn map_error_kind(err: &ContextError) -> Error {
399 err.kind()
400}
401
402fn default_canceled() -> Option<Arc<dyn std::error::Error + Send + Sync>> {
403 Some(Arc::new(CANCELLED) as Arc<dyn std::error::Error + Send + Sync>)
404}
405
406fn default_deadline() -> Option<Arc<dyn std::error::Error + Send + Sync>> {
407 Some(Arc::new(DEADLINE_EXCEEDED) as Arc<dyn std::error::Error + Send + Sync>)
408}
409
410trait MapCancel<T> {
411 fn map_cancel(self, f: impl FnOnce(T) -> CancelFunc) -> (Context, CancelFunc);
412}
413
414impl MapCancel<CancelCauseFunc> for (Context, CancelCauseFunc) {
415 fn map_cancel(self, f: impl FnOnce(CancelCauseFunc) -> CancelFunc) -> (Context, CancelFunc) {
416 let (ctx, c) = self;
417 (ctx, f(c))
418 }
419}
420
421impl MapCancel<CancelFunc> for (Context, CancelFunc) {
422 fn map_cancel(self, f: impl FnOnce(CancelFunc) -> CancelFunc) -> (Context, CancelFunc) {
423 let (ctx, c) = self;
424 (ctx, f(c))
425 }
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::thread::sleep;

    /// Generous upper bound for timer-driven tests. A slow or loaded machine then
    /// fails loudly instead of flaking on a tight fixed sleep.
    const WAIT_LIMIT: Duration = Duration::from_secs(5);

    fn assert_canceled(ctx: &Context) {
        let err = ctx.err().expect("expected canceled");
        assert_eq!(err.kind(), Error::Canceled);
    }

    /// Polls `cond` every couple of milliseconds until it holds or `WAIT_LIMIT`
    /// elapses; returns whether it became true.
    fn eventually(mut cond: impl FnMut() -> bool) -> bool {
        let started = Instant::now();
        while !cond() {
            if started.elapsed() >= WAIT_LIMIT {
                return false;
            }
            sleep(Duration::from_millis(2));
        }
        true
    }

    #[test]
    fn background_never_cancels() {
        let ctx = Background();
        assert!(ctx.deadline().is_none());
        assert!(!ctx.done().is_done());
        assert!(ctx.err().is_none());
        assert!(ctx.cause().is_none());
        assert!(ctx.value(&"k").is_none());
    }

    #[test]
    fn cancel_func_cancels() {
        let (ctx, cancel) = WithCancel(Background());
        cancel();
        assert_canceled(&ctx);
        assert!(Cause(&ctx).is_some());
    }

    #[test]
    fn cancel_cause_propagates() {
        #[derive(Debug)]
        struct MyErr;
        impl std::fmt::Display for MyErr {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("mine")
            }
        }
        impl std::error::Error for MyErr {}

        let (ctx, cancel) = WithCancelCause(Background());
        let err = Arc::new(MyErr) as Arc<dyn std::error::Error + Send + Sync>;
        cancel(Some(err.clone()));
        assert_canceled(&ctx);
        let cause = Cause(&ctx).unwrap();
        assert!(cause.downcast_ref::<MyErr>().is_some());
    }

    #[test]
    fn parent_deadline_cancels_child() {
        let (parent, _) = WithTimeout(Background(), Duration::from_millis(50));
        let (child, _) = WithCancel(parent);
        assert!(
            eventually(|| child.err().is_some()),
            "child was not canceled by the parent's deadline"
        );
        let err = child.err().expect("child canceled");
        assert_eq!(err.kind(), Error::DeadlineExceeded);
    }

    #[test]
    fn deadline_timer_triggers() {
        let (ctx, _) = WithDeadline(Background(), Instant::now() + Duration::from_millis(30));
        assert!(
            eventually(|| ctx.err().is_some()),
            "deadline timer never fired"
        );
        let err = ctx.err().expect("deadline");
        assert_eq!(err.kind(), Error::DeadlineExceeded);
    }

    #[test]
    fn deadline_cause_used() {
        #[derive(Debug)]
        struct CauseErr;
        impl std::fmt::Display for CauseErr {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("cause")
            }
        }
        impl std::error::Error for CauseErr {}

        let (ctx, cancel) = WithDeadlineCause(
            Background(),
            Instant::now() + Duration::from_millis(100),
            Some(Arc::new(CauseErr)),
        );
        cancel();
        let cause = Cause(&ctx).unwrap();
        assert!(cause.downcast_ref::<CauseErr>().is_some());
    }

    #[test]
    fn value_lookup_respects_hierarchy() {
        let root = WithValue(Background(), "a", 1u32);
        let child = WithValue(root, "b", 2u32);
        let val_a = child.value(&"a").unwrap();
        let val_b = child.value(&"b").unwrap();
        assert_eq!(*val_a.downcast::<u32>().unwrap(), 1);
        assert_eq!(*val_b.downcast::<u32>().unwrap(), 2);
    }

    #[test]
    fn without_cancel_detaches() {
        let (parent, cancel) = WithCancel(Background());
        let child = WithoutCancel(parent);
        cancel();
        assert!(child.err().is_none());
        assert!(child.cause().is_none());
        assert!(!child.done().is_done());
    }

    #[test]
    fn after_func_runs_on_cancel() {
        let (ctx, cancel) = WithCancel(Background());
        let flag = Arc::new(AtomicBool::new(false));
        let mark = flag.clone();
        AfterFunc(&ctx, move || {
            mark.store(true, Ordering::SeqCst);
        });
        cancel();
        ctx.done().wait();
        // The callback may run on another thread; poll instead of a fixed 10ms sleep.
        assert!(
            eventually(|| flag.load(Ordering::SeqCst)),
            "AfterFunc callback never ran"
        );
    }

    #[test]
    fn after_func_stop_on_never_done() {
        let stop = AfterFunc(&Background(), || panic!("should not run"));
        assert!(!stop.Stop());
    }
}