core/context/impls.rs

1#![allow(non_snake_case)]
2use crate::context::error::{CANCELLED, ContextError, ContextErrorKind, DEADLINE_EXCEEDED};
3use crate::context::state::{CancelKind, CancelState, DoneHandle, StopFunc};
4use crate::context::value::ValueKey;
5use std::any::Any;
6use std::error::Error;
7use std::fmt::{Debug, Formatter};
8use std::future::Future;
9use std::pin::Pin;
10use std::sync::Arc;
11use std::sync::OnceLock;
12use std::task::{Context as TaskContext, Poll};
13use std::thread;
14use std::time::{Duration, Instant};
15use tokio::runtime::Handle;
16
/// A one-shot cancellation callback, mirroring Go's `context.CancelFunc`.
///
/// `Box<dyn ...>` defaults to a `'static` trait-object bound, so the callback owns everything
/// it captures and, being `Send`, may be invoked from another thread.
pub type CancelFunc = Box<dyn FnOnce() + Send>;
/// A one-shot cancellation callback that also records an optional cause, mirroring Go's
/// `context.CancelCauseFunc`.
pub type CancelCauseFunc = Box<dyn FnOnce(Option<Arc<dyn Error + Send + Sync>>) + Send>;
21
22#[derive(Clone)]
23pub struct Context {
24    inner: Arc<ContextInner>,
25}
26
27enum ContextInner {
28    Empty,
29    Cancelable(CancelCtx),
30    Deadline(DeadlineCtx),
31    Value(ValueCtx),
32    WithoutCancel(WithoutCancelCtx),
33}
34
35#[derive(Clone)]
36struct CancelCtx {
37    parent: Context,
38    state: Arc<CancelState>,
39}
40
41#[derive(Clone)]
42struct DeadlineCtx {
43    parent: Context,
44    state: Arc<CancelState>,
45    deadline: Instant,
46}
47
48#[derive(Clone)]
49struct WithoutCancelCtx {
50    parent: Context,
51}
52
53struct ValueCtx {
54    parent: Context,
55    key: Arc<dyn ValueKey>,
56    value: Arc<dyn Any + Send + Sync>,
57}
58
59impl Debug for Context {
60    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
61        f.debug_struct("Context").finish_non_exhaustive()
62    }
63}
64
65impl Context {
66    fn empty() -> Self {
67        Self {
68            inner: Arc::new(ContextInner::Empty),
69        }
70    }
71
72    fn cancelable(parent: Context, state: Arc<CancelState>) -> Self {
73        Self {
74            inner: Arc::new(ContextInner::Cancelable(CancelCtx { parent, state })),
75        }
76    }
77
78    fn new_deadline(parent: Context, state: Arc<CancelState>, deadline: Instant) -> Self {
79        Self {
80            inner: Arc::new(ContextInner::Deadline(DeadlineCtx {
81                parent,
82                state,
83                deadline,
84            })),
85        }
86    }
87
88    fn new_value(
89        parent: Context,
90        key: Arc<dyn ValueKey>,
91        value: Arc<dyn Any + Send + Sync>,
92    ) -> Self {
93        Self {
94            inner: Arc::new(ContextInner::Value(ValueCtx { parent, key, value })),
95        }
96    }
97
98    fn without_cancel(parent: Context) -> Self {
99        Self {
100            inner: Arc::new(ContextInner::WithoutCancel(WithoutCancelCtx { parent })),
101        }
102    }
103
104    pub fn deadline(&self) -> Option<Instant> {
105        match self.inner.as_ref() {
106            ContextInner::Empty => None,
107            ContextInner::Cancelable(ctx) => ctx.parent.deadline(),
108            ContextInner::Deadline(ctx) => Some(ctx.deadline),
109            ContextInner::Value(ctx) => ctx.parent.deadline(),
110            ContextInner::WithoutCancel(ctx) => ctx.parent.deadline(),
111        }
112    }
113
114    pub fn done(&self) -> DoneHandle {
115        match self.inner.as_ref() {
116            ContextInner::Empty => DoneHandle::never(),
117            ContextInner::Cancelable(ctx) => CancelState::done_handle(&ctx.state),
118            ContextInner::Deadline(ctx) => CancelState::done_handle(&ctx.state),
119            ContextInner::Value(ctx) => ctx.parent.done(),
120            ContextInner::WithoutCancel(_) => DoneHandle::never(),
121        }
122    }
123
124    /// 异步版 Done:返回一个 Future,在 Context 完成时立即 ready;Never 时永 pending。
125    pub fn done_async(&self) -> DoneFuture {
126        match self.done() {
127            DoneHandle::Never => DoneFuture::Never,
128            DoneHandle::Active(state) => {
129                if state.is_done() {
130                    return DoneFuture::Ready;
131                }
132                let notify = state.notify();
133                DoneFuture::Wait(Box::pin(async move {
134                    notify.notified().await;
135                }))
136            }
137        }
138    }
139
140    pub fn err(&self) -> Option<ContextError> {
141        match self.inner.as_ref() {
142            ContextInner::Empty => None,
143            ContextInner::Cancelable(ctx) => ctx.state.err(),
144            ContextInner::Deadline(ctx) => ctx.state.err(),
145            ContextInner::Value(ctx) => ctx.parent.err(),
146            ContextInner::WithoutCancel(_) => None,
147        }
148    }
149
150    pub fn cause(&self) -> Option<Arc<dyn Error + Send + Sync>> {
151        match self.inner.as_ref() {
152            ContextInner::Empty => None,
153            ContextInner::Cancelable(ctx) => ctx.state.cause(),
154            ContextInner::Deadline(ctx) => ctx.state.cause(),
155            ContextInner::Value(ctx) => ctx.parent.cause(),
156            ContextInner::WithoutCancel(_) => None,
157        }
158    }
159
160    pub fn value(&self, key: &dyn ValueKey) -> Option<Arc<dyn Any + Send + Sync>> {
161        match self.inner.as_ref() {
162            ContextInner::Value(ctx) => {
163                if ctx.key.equals(key) {
164                    Some(ctx.value.clone())
165                } else {
166                    ctx.parent.value(key)
167                }
168            }
169            ContextInner::Empty => None,
170            ContextInner::Cancelable(ctx) => ctx.parent.value(key),
171            ContextInner::Deadline(ctx) => ctx.parent.value(key),
172            ContextInner::WithoutCancel(ctx) => ctx.parent.value(key),
173        }
174    }
175}
176
/// Future returned by `Context::done_async`; resolves once the context is done.
pub enum DoneFuture {
    /// The context is already done: resolves on the first poll.
    Ready,
    /// Waiting on an inner future that completes when the context finishes.
    Wait(Pin<Box<dyn Future<Output = ()> + Send + 'static>>),
    /// The context can never be done: stays pending forever.
    Never,
}

impl Future for DoneFuture {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Self::Output> {
        // Every variant is `Unpin` (the waiting future is already boxed and pinned on the heap),
        // so `DoneFuture: Unpin` and the safe `get_mut` suffices — no `unsafe` needed.
        let this = self.get_mut();
        match this {
            DoneFuture::Ready => Poll::Ready(()),
            DoneFuture::Never => Poll::Pending,
            DoneFuture::Wait(inner) => match inner.as_mut().poll(cx) {
                Poll::Ready(()) => {
                    // Drop the finished inner future; later polls short-circuit to `Ready`.
                    *this = DoneFuture::Ready;
                    Poll::Ready(())
                }
                Poll::Pending => Poll::Pending,
            },
        }
    }
}
202
203pub fn Background() -> Context {
204    static BG: OnceLock<Context> = OnceLock::new();
205    BG.get_or_init(Context::empty).clone()
206}
207
208pub fn TODO() -> Context {
209    static TD: OnceLock<Context> = OnceLock::new();
210    TD.get_or_init(Context::empty).clone()
211}
212
213pub fn WithoutCancel(parent: Context) -> Context {
214    Context::without_cancel(parent)
215}
216
217pub fn WithValue<K, V>(parent: Context, key: K, value: V) -> Context
218where
219    K: ValueKey,
220    V: Any + Send + Sync + 'static,
221{
222    Context::new_value(
223        parent,
224        Arc::new(key),
225        Arc::new(value) as Arc<dyn Any + Send + Sync>,
226    )
227}
228
229pub fn WithCancel(parent: Context) -> (Context, CancelFunc) {
230    WithCancelCause(parent).map_cancel(|f| Box::new(move || f(None)))
231}
232
233pub fn WithCancelCause(parent: Context) -> (Context, CancelCauseFunc) {
234    let state = CancelState::new();
235    propagate_parent(parent.clone(), state.clone());
236    let ctx = Context::cancelable(parent, state.clone());
237    let cancel = Box::new(move |cause: Option<Arc<dyn Error + Send + Sync>>| {
238        let final_cause = cause.or_else(default_canceled);
239        state.cancel(CancelKind::Canceled, final_cause);
240    });
241    (ctx, cancel)
242}
243
244pub fn WithDeadline(parent: Context, deadline: Instant) -> (Context, CancelFunc) {
245    WithDeadlineCause(parent, deadline, None).map_cancel(|f| Box::new(move || f()))
246}
247
248pub fn WithDeadlineCause(
249    parent: Context,
250    deadline: Instant,
251    cause: Option<Arc<dyn Error + Send + Sync>>,
252) -> (Context, CancelFunc) {
253    let effective_deadline = match parent.deadline() {
254        Some(parent_deadline) if parent_deadline <= deadline => parent_deadline,
255        _ => deadline,
256    };
257
258    let state = CancelState::new();
259    let ctx = Context::new_deadline(parent.clone(), state.clone(), effective_deadline);
260    propagate_parent(parent, state.clone());
261    start_deadline_timer(state.clone(), effective_deadline, cause.clone());
262
263    let cancel = Box::new(move || {
264        let cancel_cause = cause.clone().or_else(default_canceled);
265        state.cancel(CancelKind::Canceled, cancel_cause);
266    });
267    (ctx, cancel)
268}
269
270pub fn WithTimeout(parent: Context, timeout: Duration) -> (Context, CancelFunc) {
271    WithDeadline(parent, Instant::now() + timeout)
272}
273
274pub fn WithTimeoutCause(
275    parent: Context,
276    timeout: Duration,
277    cause: Option<Arc<dyn Error + Send + Sync>>,
278) -> (Context, CancelFunc) {
279    WithDeadlineCause(parent, Instant::now() + timeout, cause)
280}
281
282pub fn AfterFunc(ctx: &Context, f: impl FnOnce() + Send + 'static) -> StopFunc {
283    ctx.done().register(f)
284}
285
286pub fn Cause(ctx: &Context) -> Option<Arc<dyn Error + Send + Sync>> {
287    ctx.cause()
288}
289
290/// 异步等待上下文完成,返回 Err 时即 ctx.err()。
291pub async fn Done(ctx: Context) -> Option<ContextError> {
292    ctx.done_async().await;
293    ctx.err()
294}
295
296/// 并发等待业务 Future 与 ctx 完成(取消/超时)。ctx 完成时优先返回其错误。
297pub async fn ContextAware<T, F>(ctx: Context, fut: F) -> Result<T, ContextError>
298where
299    F: Future<Output = Result<T, ContextError>>,
300{
301    let done = ctx.done_async();
302    tokio::select! {
303        res = fut => res,
304        _ = done => Err(ctx.err().unwrap_or(CANCELLED)),
305    }
306}
307
308fn start_deadline_timer(
309    state: Arc<CancelState>,
310    deadline: Instant,
311    cause: Option<Arc<dyn Error + Send + Sync>>,
312) {
313    if deadline <= Instant::now() {
314        let deadline_cause = cause.clone().or_else(default_deadline);
315        state.cancel(CancelKind::Deadline, deadline_cause);
316        return;
317    }
318    let sleep_dur = deadline.saturating_duration_since(Instant::now());
319    if let Ok(handle) = Handle::try_current() {
320        handle.spawn(async move {
321            tokio::time::sleep(sleep_dur).await;
322            let deadline_cause = cause.clone().or_else(default_deadline);
323            state.cancel(CancelKind::Deadline, deadline_cause);
324        });
325    } else {
326        thread::spawn(move || {
327            thread::sleep(sleep_dur);
328            let deadline_cause = cause.clone().or_else(default_deadline);
329            state.cancel(CancelKind::Deadline, deadline_cause);
330        });
331    }
332}
333
334fn propagate_parent(parent: Context, state: Arc<CancelState>) {
335    if state.is_done() {
336        return;
337    }
338    if let Some(err) = parent.err() {
339        let kind = map_error_kind(&err);
340        let inherited = parent
341            .cause()
342            .or_else(|| Some(Arc::new(err) as Arc<dyn Error + Send + Sync>));
343        state.cancel(kind, inherited);
344        return;
345    }
346    let done = parent.done();
347    done.register(move || {
348        let err = parent.err();
349        let kind = err
350            .as_ref()
351            .map(map_error_kind)
352            .unwrap_or(CancelKind::Canceled);
353        let inherited = parent
354            .cause()
355            .or_else(|| err.map(|e| Arc::new(e) as Arc<dyn Error + Send + Sync>));
356        state.cancel(kind, inherited);
357    });
358}
359
360fn map_error_kind(err: &ContextError) -> CancelKind {
361    match err.kind() {
362        ContextErrorKind::Canceled => CancelKind::Canceled,
363        ContextErrorKind::DeadlineExceeded => CancelKind::Deadline,
364    }
365}
366
367fn default_canceled() -> Option<Arc<dyn Error + Send + Sync>> {
368    Some(Arc::new(CANCELLED) as Arc<dyn Error + Send + Sync>)
369}
370
371fn default_deadline() -> Option<Arc<dyn Error + Send + Sync>> {
372    Some(Arc::new(DEADLINE_EXCEEDED) as Arc<dyn Error + Send + Sync>)
373}
374
375trait MapCancel<T> {
376    fn map_cancel(self, f: impl FnOnce(T) -> CancelFunc) -> (Context, CancelFunc);
377}
378
379impl MapCancel<CancelCauseFunc> for (Context, CancelCauseFunc) {
380    fn map_cancel(self, f: impl FnOnce(CancelCauseFunc) -> CancelFunc) -> (Context, CancelFunc) {
381        let (ctx, c) = self;
382        (ctx, f(c))
383    }
384}
385
386impl MapCancel<CancelFunc> for (Context, CancelFunc) {
387    fn map_cancel(self, f: impl FnOnce(CancelFunc) -> CancelFunc) -> (Context, CancelFunc) {
388        let (ctx, c) = self;
389        (ctx, f(c))
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::thread::sleep;

    /// Upper bound on waiting for asynchronous effects (timers, callbacks) before failing.
    const WAIT_LIMIT: Duration = Duration::from_secs(2);

    /// Asserts that `ctx` finished with a `Canceled` error.
    fn assert_canceled(ctx: &Context) {
        let err = ctx.err().expect("expected canceled");
        assert_eq!(err.kind(), ContextErrorKind::Canceled);
    }

    /// Polls `cond` every few milliseconds until it holds or `WAIT_LIMIT` elapses.
    ///
    /// This replaces fixed sleeps with thin margins (e.g. 80 ms for a 50 ms timer, or 10 ms for a
    /// callback). Those are flaky on a loaded machine and needlessly slow on a fast one.
    fn wait_until(mut cond: impl FnMut() -> bool) -> bool {
        let give_up = Instant::now() + WAIT_LIMIT;
        while Instant::now() < give_up {
            if cond() {
                return true;
            }
            sleep(Duration::from_millis(2));
        }
        cond()
    }

    #[test]
    fn background_never_cancels() {
        let ctx = Background();
        assert!(ctx.deadline().is_none());
        assert!(!ctx.done().is_done());
        assert!(ctx.err().is_none());
        assert!(ctx.cause().is_none());
        assert!(ctx.value(&"k").is_none());
    }

    #[test]
    fn cancel_func_cancels() {
        let (ctx, cancel) = WithCancel(Background());
        cancel();
        assert_canceled(&ctx);
        assert!(Cause(&ctx).is_some());
    }

    #[test]
    fn cancel_cause_propagates() {
        #[derive(Debug)]
        struct MyErr;
        impl std::fmt::Display for MyErr {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("mine")
            }
        }
        impl Error for MyErr {}

        let (ctx, cancel) = WithCancelCause(Background());
        let err = Arc::new(MyErr) as Arc<dyn Error + Send + Sync>;
        cancel(Some(err.clone()));
        assert_canceled(&ctx);
        let cause = Cause(&ctx).unwrap();
        assert!(cause.downcast_ref::<MyErr>().is_some());
    }

    #[test]
    fn parent_deadline_cancels_child() {
        let (parent, _) = WithTimeout(Background(), Duration::from_millis(50));
        let (child, _) = WithCancel(parent);
        assert!(
            wait_until(|| child.err().is_some()),
            "child was not canceled by the parent's deadline"
        );
        let err = child.err().expect("child canceled");
        assert_eq!(err.kind(), ContextErrorKind::DeadlineExceeded);
    }

    #[test]
    fn deadline_timer_triggers() {
        let (ctx, _) = WithDeadline(Background(), Instant::now() + Duration::from_millis(30));
        assert!(
            wait_until(|| ctx.err().is_some()),
            "deadline timer did not fire"
        );
        let err = ctx.err().expect("deadline");
        assert_eq!(err.kind(), ContextErrorKind::DeadlineExceeded);
    }

    #[test]
    fn deadline_cause_used() {
        #[derive(Debug)]
        struct CauseErr;
        impl std::fmt::Display for CauseErr {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("cause")
            }
        }
        impl Error for CauseErr {}

        let (ctx, cancel) = WithDeadlineCause(
            Background(),
            Instant::now() + Duration::from_millis(100),
            Some(Arc::new(CauseErr)),
        );
        cancel();
        let cause = Cause(&ctx).unwrap();
        assert!(cause.downcast_ref::<CauseErr>().is_some());
    }

    #[test]
    fn value_lookup_respects_hierarchy() {
        let root = WithValue(Background(), "a", 1u32);
        let child = WithValue(root, "b", 2u32);
        let val_a = child.value(&"a").unwrap();
        let val_b = child.value(&"b").unwrap();
        assert_eq!(*val_a.downcast::<u32>().unwrap(), 1);
        assert_eq!(*val_b.downcast::<u32>().unwrap(), 2);
    }

    #[test]
    fn without_cancel_detaches() {
        let (parent, cancel) = WithCancel(Background());
        let child = WithoutCancel(parent);
        cancel();
        assert!(child.err().is_none());
        assert!(child.cause().is_none());
        assert!(!child.done().is_done());
    }

    #[test]
    fn after_func_runs_on_cancel() {
        let (ctx, cancel) = WithCancel(Background());
        let flag = Arc::new(AtomicBool::new(false));
        let mark = flag.clone();
        AfterFunc(&ctx, move || {
            mark.store(true, Ordering::SeqCst);
        });
        cancel();
        ctx.done().wait();
        assert!(
            wait_until(|| flag.load(Ordering::SeqCst)),
            "AfterFunc callback did not run"
        );
    }

    #[test]
    fn after_func_stop_on_never_done() {
        let stop = AfterFunc(&Background(), || panic!("should not run"));
        assert!(!stop.Stop());
    }
}