core/context/
impls.rs

1#![allow(non_snake_case)]
2use crate::context::error::{CANCELLED, ContextError, DEADLINE_EXCEEDED, Error};
3use crate::context::state::{CancelState, DoneHandle, StopFunc};
4use crate::context::value::ValueKey;
5use std::any::Any;
6use std::fmt::{Debug, Formatter};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::sync::OnceLock;
11use std::task::{Context as TaskContext, Poll};
12use std::thread;
13use std::time::{Duration, Instant};
14use tokio::runtime::Handle;
15
/// Cancellation callback returned by the `With*` constructors; mirrors Go's
/// `context.CancelFunc`. It can be invoked at most once.
pub type CancelFunc = Box<dyn FnOnce() + Send + 'static>;

/// Cancellation callback that additionally accepts an optional cause; mirrors Go's
/// `context.CancelCauseFunc`. It can be invoked at most once.
pub type CancelCauseFunc = Box<
    dyn FnOnce(Option<Arc<dyn std::error::Error + Send + Sync>>) + Send + 'static,
>;
21
22pub struct Context {
23    inner: Arc<ContextInner>,
24}
25
26enum ContextInner {
27    Empty,
28    Cancelable(CancelCtx),
29    Deadline(DeadlineCtx),
30    Value(ValueCtx),
31    WithoutCancel(WithoutCancelCtx),
32}
33
34struct CancelCtx {
35    parent: Context,
36    state: Arc<CancelState>,
37}
38
39struct DeadlineCtx {
40    parent: Context,
41    state: Arc<CancelState>,
42    deadline: Instant,
43}
44
45struct WithoutCancelCtx {
46    parent: Context,
47}
48
49struct ValueCtx {
50    parent: Context,
51    key: Arc<dyn ValueKey>,
52    value: Arc<dyn Any + Send + Sync>,
53}
54
55impl Clone for Context {
56    fn clone(&self) -> Self {
57        if let Some(state) = self.state_arc_opt() {
58            state.add_handle();
59        }
60        Self {
61            inner: self.inner.clone(),
62        }
63    }
64}
65
66impl Drop for Context {
67    fn drop(&mut self) {
68        if let Some(state) = self.state_arc_opt() {
69            state.release_handle();
70        }
71    }
72}
73
74impl Debug for Context {
75    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
76        f.debug_struct("Context").finish_non_exhaustive()
77    }
78}
79
80impl Context {
81    fn empty() -> Self {
82        Self {
83            inner: Arc::new(ContextInner::Empty),
84        }
85    }
86
87    fn cancelable(parent: Context, state: Arc<CancelState>) -> Self {
88        Self {
89            inner: Arc::new(ContextInner::Cancelable(CancelCtx { parent, state })),
90        }
91    }
92
93    fn new_deadline(parent: Context, state: Arc<CancelState>, deadline: Instant) -> Self {
94        Self {
95            inner: Arc::new(ContextInner::Deadline(DeadlineCtx {
96                parent,
97                state,
98                deadline,
99            })),
100        }
101    }
102
103    fn new_value(
104        parent: Context,
105        key: Arc<dyn ValueKey>,
106        value: Arc<dyn Any + Send + Sync>,
107    ) -> Self {
108        Self {
109            inner: Arc::new(ContextInner::Value(ValueCtx { parent, key, value })),
110        }
111    }
112
113    fn without_cancel(parent: Context) -> Self {
114        Self {
115            inner: Arc::new(ContextInner::WithoutCancel(WithoutCancelCtx { parent })),
116        }
117    }
118
119    pub fn deadline(&self) -> Option<Instant> {
120        match self.inner.as_ref() {
121            ContextInner::Empty => None,
122            ContextInner::Cancelable(ctx) => ctx.parent.deadline(),
123            ContextInner::Deadline(ctx) => Some(ctx.deadline),
124            ContextInner::Value(ctx) => ctx.parent.deadline(),
125            ContextInner::WithoutCancel(ctx) => ctx.parent.deadline(),
126        }
127    }
128
129    pub fn done(&self) -> DoneHandle {
130        match self.inner.as_ref() {
131            ContextInner::Empty => DoneHandle::never(),
132            ContextInner::Cancelable(ctx) => CancelState::done_handle(&ctx.state),
133            ContextInner::Deadline(ctx) => CancelState::done_handle(&ctx.state),
134            ContextInner::Value(ctx) => ctx.parent.done(),
135            ContextInner::WithoutCancel(_) => DoneHandle::never(),
136        }
137    }
138
139    /// 异步版 Done:返回一个 Future,在 Context 完成时立即 ready;Never 时永 pending。
140    pub fn done_async(&self) -> DoneFuture {
141        match self.done() {
142            DoneHandle::Never => DoneFuture::Never,
143            DoneHandle::Active(state) => {
144                if state.is_done() {
145                    return DoneFuture::Ready;
146                }
147                let notify = state.notify();
148                DoneFuture::Wait(Box::pin(async move {
149                    notify.notified().await;
150                }))
151            }
152        }
153    }
154
155    pub fn err(&self) -> Option<ContextError> {
156        match self.inner.as_ref() {
157            ContextInner::Empty => None,
158            ContextInner::Cancelable(ctx) => ctx.state.err(),
159            ContextInner::Deadline(ctx) => ctx.state.err(),
160            ContextInner::Value(ctx) => ctx.parent.err(),
161            ContextInner::WithoutCancel(_) => None,
162        }
163    }
164
165    pub fn cause(&self) -> Option<Arc<dyn std::error::Error + Send + Sync>> {
166        match self.inner.as_ref() {
167            ContextInner::Empty => None,
168            ContextInner::Cancelable(ctx) => ctx.state.cause(),
169            ContextInner::Deadline(ctx) => ctx.state.cause(),
170            ContextInner::Value(ctx) => ctx.parent.cause(),
171            ContextInner::WithoutCancel(_) => None,
172        }
173    }
174
175    pub fn value(&self, key: &dyn ValueKey) -> Option<Arc<dyn Any + Send + Sync>> {
176        match self.inner.as_ref() {
177            ContextInner::Value(ctx) => {
178                if ctx.key.equals(key) {
179                    Some(ctx.value.clone())
180                } else {
181                    ctx.parent.value(key)
182                }
183            }
184            ContextInner::Empty => None,
185            ContextInner::Cancelable(ctx) => ctx.parent.value(key),
186            ContextInner::Deadline(ctx) => ctx.parent.value(key),
187            ContextInner::WithoutCancel(ctx) => ctx.parent.value(key),
188        }
189    }
190
191    fn state_arc(&self) -> Arc<CancelState> {
192        match self.inner.as_ref() {
193            ContextInner::Cancelable(ctx) => ctx.state.clone(),
194            ContextInner::Deadline(ctx) => ctx.state.clone(),
195            ContextInner::Value(ctx) => ctx.parent.state_arc(),
196            ContextInner::WithoutCancel(_) | ContextInner::Empty => CancelState::new_root(),
197        }
198    }
199
200    fn state_arc_opt(&self) -> Option<Arc<CancelState>> {
201        match self.inner.as_ref() {
202            ContextInner::Cancelable(ctx) => Some(ctx.state.clone()),
203            ContextInner::Deadline(ctx) => Some(ctx.state.clone()),
204            ContextInner::Value(ctx) => ctx.parent.state_arc_opt(),
205            ContextInner::WithoutCancel(_) | ContextInner::Empty => None,
206        }
207    }
208}
209
/// Future returned by `Context::done_async`, the async counterpart of `Context::done`.
///
/// - `Ready`: the context is already done; every poll returns `Poll::Ready(())`.
/// - `Wait`: wraps a boxed future that completes when the context becomes done.
/// - `Never`: the context can never be cancelled; every poll returns `Poll::Pending`.
#[must_use = "futures do nothing unless polled or awaited"]
pub enum DoneFuture {
    Ready,
    Wait(Pin<Box<dyn Future<Output = ()> + Send + 'static>>),
    Never,
}

impl Future for DoneFuture {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll<Self::Output> {
        // Every variant is `Unpin`: the inner future is already behind `Pin<Box<_>>`. That
        // makes `DoneFuture: Unpin`, so the pin can be removed safely without `unsafe`.
        let this = self.get_mut();
        match this {
            DoneFuture::Ready => Poll::Ready(()),
            DoneFuture::Never => Poll::Pending,
            DoneFuture::Wait(inner) => {
                if inner.as_mut().poll(cx).is_pending() {
                    return Poll::Pending;
                }
                // Cache completion so the finished inner future is dropped and never
                // polled again.
                *this = DoneFuture::Ready;
                Poll::Ready(())
            }
        }
    }
}
235
236pub fn Background() -> Context {
237    static BACKGROUND_CTX: OnceLock<Context> = OnceLock::new();
238    BACKGROUND_CTX.get_or_init(Context::empty).clone()
239}
240
241pub fn TODO() -> Context {
242    static TODO_CTX: OnceLock<Context> = OnceLock::new();
243    TODO_CTX.get_or_init(Context::empty).clone()
244}
245
246pub fn WithoutCancel(parent: Context) -> Context {
247    Context::without_cancel(parent)
248}
249
250pub fn WithValue<K, V>(parent: Context, key: K, value: V) -> Context
251where
252    K: ValueKey,
253    V: Any + Send + Sync + 'static,
254{
255    Context::new_value(
256        parent,
257        Arc::new(key),
258        Arc::new(value) as Arc<dyn Any + Send + Sync>,
259    )
260}
261
262pub fn WithCancel(parent: Context) -> (Context, CancelFunc) {
263    WithCancelCause(parent).map_cancel(|f| Box::new(move || f(None)))
264}
265
266pub fn WithCancelCause(parent: Context) -> (Context, CancelCauseFunc) {
267    let state = match parent.inner.as_ref() {
268        ContextInner::Empty => CancelState::new_root(),
269        _ => CancelState::child_of(&parent.state_arc()),
270    };
271    propagate_parent(parent.clone(), state.clone());
272    let ctx = Context::cancelable(parent, state.clone());
273    let cancel = Box::new(
274        move |cause: Option<Arc<dyn std::error::Error + Send + Sync>>| {
275            let final_cause = cause.or_else(default_canceled);
276            state.cancel(Error::Canceled, final_cause);
277        },
278    );
279    (ctx, cancel)
280}
281
282pub fn WithDeadline(parent: Context, deadline: Instant) -> (Context, CancelFunc) {
283    WithDeadlineCause(parent, deadline, None).map_cancel(|f| Box::new(move || f()))
284}
285
286pub fn WithDeadlineCause(
287    parent: Context,
288    deadline: Instant,
289    cause: Option<Arc<dyn std::error::Error + Send + Sync>>,
290) -> (Context, CancelFunc) {
291    let effective_deadline = match parent.deadline() {
292        Some(parent_deadline) if parent_deadline <= deadline => parent_deadline,
293        _ => deadline,
294    };
295
296    let state = match parent.inner.as_ref() {
297        ContextInner::Empty => CancelState::new_root(),
298        _ => CancelState::child_of(&parent.state_arc()),
299    };
300    let ctx = Context::new_deadline(parent.clone(), state.clone(), effective_deadline);
301    propagate_parent(parent, state.clone());
302    start_deadline_timer(state.clone(), effective_deadline, cause.clone());
303
304    let cancel = Box::new(move || {
305        let cancel_cause = cause.clone().or_else(default_canceled);
306        state.cancel(Error::Canceled, cancel_cause);
307    });
308    (ctx, cancel)
309}
310
311pub fn WithTimeout(parent: Context, timeout: Duration) -> (Context, CancelFunc) {
312    WithDeadline(parent, Instant::now() + timeout)
313}
314
315pub fn WithTimeoutCause(
316    parent: Context,
317    timeout: Duration,
318    cause: Option<Arc<dyn std::error::Error + Send + Sync>>,
319) -> (Context, CancelFunc) {
320    WithDeadlineCause(parent, Instant::now() + timeout, cause)
321}
322
323pub fn AfterFunc(ctx: &Context, f: impl FnOnce() + Send + 'static) -> StopFunc {
324    ctx.done().register(f)
325}
326
327pub fn Cause(ctx: &Context) -> Option<Arc<dyn std::error::Error + Send + Sync>> {
328    ctx.cause()
329}
330
331/// 异步等待上下文完成,返回 Err 时即 ctx.err()。
332pub async fn Done(ctx: Context) -> Option<ContextError> {
333    ctx.done_async().await;
334    ctx.err()
335}
336
337/// 并发等待业务 Future 与 ctx 完成(取消/超时)。ctx 完成时优先返回其错误。
338pub async fn ContextAware<T, F>(ctx: Context, fut: F) -> Result<T, ContextError>
339where
340    F: Future<Output = Result<T, ContextError>>,
341{
342    let done = ctx.done_async();
343    tokio::select! {
344        res = fut => res,
345        _ = done => Err(ctx.err().unwrap_or(CANCELLED)),
346    }
347}
348
349fn start_deadline_timer(
350    state: Arc<CancelState>,
351    deadline: Instant,
352    cause: Option<Arc<dyn std::error::Error + Send + Sync>>,
353) {
354    if deadline <= Instant::now() {
355        let deadline_cause = cause.clone().or_else(default_deadline);
356        state.cancel(Error::DeadlineExceeded, deadline_cause);
357        return;
358    }
359    let sleep_dur = deadline.saturating_duration_since(Instant::now());
360    if let Ok(handle) = Handle::try_current() {
361        handle.spawn(async move {
362            tokio::time::sleep(sleep_dur).await;
363            let deadline_cause = cause.clone().or_else(default_deadline);
364            state.cancel(Error::DeadlineExceeded, deadline_cause);
365        });
366    } else {
367        thread::spawn(move || {
368            thread::sleep(sleep_dur);
369            let deadline_cause = cause.clone().or_else(default_deadline);
370            state.cancel(Error::DeadlineExceeded, deadline_cause);
371        });
372    }
373}
374
375fn propagate_parent(parent: Context, state: Arc<CancelState>) {
376    if state.is_done() {
377        return;
378    }
379    if let Some(err) = parent.err() {
380        let kind = map_error_kind(&err);
381        let inherited = parent
382            .cause()
383            .or_else(|| Some(Arc::new(err) as Arc<dyn std::error::Error + Send + Sync>));
384        state.cancel(kind, inherited);
385        return;
386    }
387    let done = parent.done();
388    done.register(move || {
389        let err = parent.err();
390        let kind = err.as_ref().map(map_error_kind).unwrap_or(Error::Canceled);
391        let inherited = parent
392            .cause()
393            .or_else(|| err.map(|e| Arc::new(e) as Arc<dyn std::error::Error + Send + Sync>));
394        state.cancel(kind, inherited);
395    });
396}
397
398fn map_error_kind(err: &ContextError) -> Error {
399    err.kind()
400}
401
402fn default_canceled() -> Option<Arc<dyn std::error::Error + Send + Sync>> {
403    Some(Arc::new(CANCELLED) as Arc<dyn std::error::Error + Send + Sync>)
404}
405
406fn default_deadline() -> Option<Arc<dyn std::error::Error + Send + Sync>> {
407    Some(Arc::new(DEADLINE_EXCEEDED) as Arc<dyn std::error::Error + Send + Sync>)
408}
409
410trait MapCancel<T> {
411    fn map_cancel(self, f: impl FnOnce(T) -> CancelFunc) -> (Context, CancelFunc);
412}
413
414impl MapCancel<CancelCauseFunc> for (Context, CancelCauseFunc) {
415    fn map_cancel(self, f: impl FnOnce(CancelCauseFunc) -> CancelFunc) -> (Context, CancelFunc) {
416        let (ctx, c) = self;
417        (ctx, f(c))
418    }
419}
420
421impl MapCancel<CancelFunc> for (Context, CancelFunc) {
422    fn map_cancel(self, f: impl FnOnce(CancelFunc) -> CancelFunc) -> (Context, CancelFunc) {
423        let (ctx, c) = self;
424        (ctx, f(c))
425    }
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::thread::sleep;

    /// Upper bound when waiting for timer- or callback-driven effects. It is generous so that
    /// loaded CI hosts do not flake, yet the tests finish as soon as the effect is observed.
    const WAIT_LIMIT: Duration = Duration::from_secs(5);

    fn assert_canceled(ctx: &Context) {
        let err = ctx.err().expect("expected canceled");
        assert_eq!(err.kind(), Error::Canceled);
    }

    /// Polls `ctx.err()` until it is set, failing after `WAIT_LIMIT`. This replaces sleeping
    /// a fixed amount and hoping the timer thread has already run.
    fn wait_for_err(ctx: &Context) -> ContextError {
        let start = Instant::now();
        loop {
            if let Some(err) = ctx.err() {
                return err;
            }
            assert!(
                start.elapsed() < WAIT_LIMIT,
                "context was not cancelled within {:?}",
                WAIT_LIMIT
            );
            sleep(Duration::from_millis(1));
        }
    }

    #[test]
    fn background_never_cancels() {
        let ctx = Background();
        assert!(ctx.deadline().is_none());
        assert!(!ctx.done().is_done());
        assert!(ctx.err().is_none());
        assert!(ctx.cause().is_none());
        assert!(ctx.value(&"k").is_none());
    }

    #[test]
    fn cancel_func_cancels() {
        let (ctx, cancel) = WithCancel(Background());
        cancel();
        assert_canceled(&ctx);
        assert!(Cause(&ctx).is_some());
    }

    #[test]
    fn cancel_cause_propagates() {
        #[derive(Debug)]
        struct MyErr;
        impl std::fmt::Display for MyErr {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("mine")
            }
        }
        impl std::error::Error for MyErr {}

        let (ctx, cancel) = WithCancelCause(Background());
        let err = Arc::new(MyErr) as Arc<dyn std::error::Error + Send + Sync>;
        cancel(Some(err.clone()));
        assert_canceled(&ctx);
        let cause = Cause(&ctx).unwrap();
        assert!(cause.downcast_ref::<MyErr>().is_some());
    }

    #[test]
    fn parent_deadline_cancels_child() {
        let (parent, _) = WithTimeout(Background(), Duration::from_millis(50));
        let (child, _) = WithCancel(parent);
        let err = wait_for_err(&child);
        assert_eq!(err.kind(), Error::DeadlineExceeded);
    }

    #[test]
    fn deadline_timer_triggers() {
        let (ctx, _) = WithDeadline(Background(), Instant::now() + Duration::from_millis(30));
        let err = wait_for_err(&ctx);
        assert_eq!(err.kind(), Error::DeadlineExceeded);
    }

    #[test]
    fn deadline_cause_used() {
        #[derive(Debug)]
        struct CauseErr;
        impl std::fmt::Display for CauseErr {
            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                f.write_str("cause")
            }
        }
        impl std::error::Error for CauseErr {}

        let (ctx, cancel) = WithDeadlineCause(
            Background(),
            Instant::now() + Duration::from_millis(100),
            Some(Arc::new(CauseErr)),
        );
        cancel();
        let cause = Cause(&ctx).unwrap();
        assert!(cause.downcast_ref::<CauseErr>().is_some());
    }

    #[test]
    fn value_lookup_respects_hierarchy() {
        let root = WithValue(Background(), "a", 1u32);
        let child = WithValue(root, "b", 2u32);
        let val_a = child.value(&"a").unwrap();
        let val_b = child.value(&"b").unwrap();
        assert_eq!(*val_a.downcast::<u32>().unwrap(), 1);
        assert_eq!(*val_b.downcast::<u32>().unwrap(), 2);
    }

    #[test]
    fn without_cancel_detaches() {
        let (parent, cancel) = WithCancel(Background());
        let child = WithoutCancel(parent);
        cancel();
        assert!(child.err().is_none());
        assert!(child.cause().is_none());
        assert!(!child.done().is_done());
    }

    #[test]
    fn after_func_runs_on_cancel() {
        let (ctx, cancel) = WithCancel(Background());
        // A channel reports exactly when the callback ran, instead of sleeping and hoping.
        let (tx, rx) = mpsc::channel();
        AfterFunc(&ctx, move || {
            let _ = tx.send(());
        });
        cancel();
        rx.recv_timeout(WAIT_LIMIT)
            .expect("AfterFunc callback did not run after cancel");
    }

    #[test]
    fn after_func_stop_on_never_done() {
        let stop = AfterFunc(&Background(), || panic!("should not run"));
        assert!(!stop.Stop());
    }
}