core/context/
state.rs

1use crate::context::error::{CANCELLED, ContextError, DEADLINE_EXCEEDED, Error};
2use std::fmt::{Debug, Formatter};
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::{Arc, Condvar, Mutex, Weak};
5use std::time::{Duration, Instant};
6use tokio::sync::Notify;
7
/// Lifecycle state of a cancellation node.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Status {
    /// The node has not been cancelled.
    Active,
    /// Recorded by `cancel` when it is called with `Error::Canceled`.
    Canceled,
    /// Recorded by `cancel` when it is called with `Error::DeadlineExceeded`.
    DeadlineExceeded,
}
14
15/// CancelState 内部数据。
16/// 不变量与锁序:
17/// - 锁顺序:始终先锁父再锁子,避免死锁。
18/// - parent_idx:在父 children 中的位置,detach 时用 O(1) 位置移除并更新被交换节点。
19/// - handle_count:Context 克隆/Drop 维护;handle_count 为 0 且 children 为空且 done 时向上裁剪。
20/// - callbacks_head:侵入式回调链表头,只能在持有 self.inner 锁时读写。
21struct Inner {
22    status: Status,
23    cause: Option<Arc<dyn std::error::Error + Send + Sync>>,
24    parent: Option<Weak<CancelState>>,
25    children: Vec<Weak<CancelState>>,
26    parent_idx: Option<usize>,
27    done: bool,
28    handle_count: AtomicUsize,
29    next_id: usize,
30    callbacks_head: *mut CallbackNode,
31}
32
33unsafe impl Send for Inner {}
34unsafe impl Sync for Inner {}
35
/// Node of the intrusive callback list, modeled after the linked list used by
/// tokio-util.
pub(crate) struct CallbackNode {
    /// Registration id assigned by `register`; currently not read anywhere.
    #[allow(dead_code)]
    id: usize,
    /// The pending callback; `None` once it has been taken.
    callback: Option<Box<dyn FnOnce() + Send + 'static>>,
    /// Next node in the list, or null at the tail.
    next: *mut CallbackNode,
}

// SAFETY (as relied upon by this file): the raw `next` pointer is only followed
// while the owning `CancelState` mutex is held, and the boxed callback is
// itself `Send`.
unsafe impl Send for CallbackNode {}
unsafe impl Sync for CallbackNode {}
46
47/// 取消状态核心(树形结构,紧贴 tokio-util 设计)。
48pub struct CancelState {
49    inner: Mutex<Inner>,
50    cvar: Condvar,
51    notify: Arc<Notify>,
52}
53
54impl CancelState {
55    pub fn new_root() -> Arc<Self> {
56        Arc::new(Self {
57            inner: Mutex::new(Inner {
58                status: Status::Active,
59                cause: None,
60                parent: None,
61                children: Vec::new(),
62                parent_idx: None,
63                done: false,
64                handle_count: AtomicUsize::new(1),
65                next_id: 0,
66                callbacks_head: std::ptr::null_mut(),
67            }),
68            cvar: Condvar::new(),
69            notify: Arc::new(Notify::new()),
70        })
71    }
72
73    pub fn child_of(parent: &Arc<Self>) -> Arc<Self> {
74        let child = Arc::new(Self {
75            inner: Mutex::new(Inner {
76                status: Status::Active,
77                cause: None,
78                parent: Some(Arc::downgrade(parent)),
79                children: Vec::new(),
80                parent_idx: None,
81                done: false,
82                handle_count: AtomicUsize::new(1),
83                next_id: 0,
84                callbacks_head: std::ptr::null_mut(),
85            }),
86            cvar: Condvar::new(),
87            notify: Arc::new(Notify::new()),
88        });
89
90        // 挂载到父节点(锁父后推入)
91        let weak_child = Arc::downgrade(&child);
92        let mut guard = parent.inner.lock().unwrap();
93        let idx = guard.children.len();
94        guard.children.push(weak_child);
95        drop(guard);
96        child.inner.lock().unwrap().parent_idx = Some(idx);
97
98        child
99    }
100
101    pub fn done_handle(this: &Arc<Self>) -> DoneHandle {
102        DoneHandle::Active(this.clone())
103    }
104
105    pub fn err(&self) -> Option<ContextError> {
106        let guard = self.inner.lock().unwrap();
107        match guard.status {
108            Status::Active => None,
109            Status::Canceled => Some(ContextError::with_cause(
110                Error::Canceled,
111                guard.cause.clone(),
112            )),
113            Status::DeadlineExceeded => Some(ContextError::with_cause(
114                Error::DeadlineExceeded,
115                guard.cause.clone(),
116            )),
117        }
118    }
119
120    pub fn cause(&self) -> Option<Arc<dyn std::error::Error + Send + Sync>> {
121        self.inner.lock().unwrap().cause.clone()
122    }
123
124    pub fn is_done(&self) -> bool {
125        self.inner.lock().unwrap().done
126    }
127
128    pub fn add_handle(&self) {
129        let guard = self.inner.lock().unwrap();
130        guard.handle_count.fetch_add(1, Ordering::Relaxed);
131    }
132
133    pub fn release_handle(self: &Arc<Self>) {
134        let mut current = self.clone();
135        loop {
136            let parent_opt = {
137                let mut guard = current.inner.lock().unwrap();
138                guard.handle_count.fetch_sub(1, Ordering::Relaxed);
139                guard.children.retain(|w| w.upgrade().is_some());
140                let has_handles = guard.handle_count.load(Ordering::Relaxed) > 0;
141                let has_children = !guard.children.is_empty();
142                let done = guard.done;
143                let parent = guard.parent.clone();
144                let parent_idx = guard.parent_idx;
145                drop(guard);
146                if has_handles || has_children || !done {
147                    return;
148                }
149                parent.zip(parent_idx)
150            };
151
152            let Some((parent_weak, idx)) = parent_opt else {
153                return;
154            };
155            let Some(parent) = parent_weak.upgrade() else {
156                return;
157            };
158
159            if !Self::detach_from_parent_idx(&parent, &current, idx) {
160                return;
161            }
162            current = parent;
163        }
164    }
165
166    pub fn cancel(
167        self: &Arc<Self>,
168        kind: Error,
169        cause: Option<Arc<dyn std::error::Error + Send + Sync>>,
170    ) {
171        // 尝试标记自身
172        let cause_for_self = cause.clone();
173        let (callbacks, children, notify_needed) = {
174            let mut guard = self.inner.lock().unwrap();
175            if guard.done {
176                return;
177            }
178
179            guard.status = match kind {
180                Error::Canceled => Status::Canceled,
181                Error::DeadlineExceeded => Status::DeadlineExceeded,
182            };
183            if guard.cause.is_none() {
184                guard.cause = cause_for_self.clone().or_else(|| {
185                    let err = match kind {
186                        Error::Canceled => CANCELLED,
187                        Error::DeadlineExceeded => DEADLINE_EXCEEDED,
188                    };
189                    Some(Arc::new(err) as Arc<dyn std::error::Error + Send + Sync>)
190                });
191            }
192            guard.done = true;
193
194            let head = guard.callbacks_head;
195            guard.callbacks_head = std::ptr::null_mut();
196            let callbacks = if head.is_null() {
197                Vec::new()
198            } else {
199                Self::drain_callbacks(head)
200            };
201
202            let children = guard
203                .children
204                .iter()
205                .filter_map(|w| w.upgrade())
206                .collect::<Vec<_>>();
207            (callbacks, children, true)
208        };
209
210        if notify_needed {
211            self.notify.notify_waiters();
212            self.cvar.notify_all();
213        }
214
215        // 直接执行回调,避免额外 spawn
216        for cb in callbacks {
217            cb();
218        }
219
220        // 递归取消子节点
221        for child in children {
222            child.cancel(kind, cause.clone());
223        }
224
225        // 若已完成且无子节点,向上裁剪
226        self.prune_if_detached();
227    }
228
229    fn prune_if_detached(self: &Arc<Self>) {
230        let mut current = self.clone();
231        loop {
232            let parent_info = {
233                let mut guard = current.inner.lock().unwrap();
234                if !guard.done {
235                    return;
236                }
237                guard.children.retain(|w| w.upgrade().is_some());
238                if !guard.children.is_empty() {
239                    return;
240                }
241                guard.parent.clone().zip(guard.parent_idx)
242            };
243
244            let Some((parent_weak, idx)) = parent_info else {
245                return;
246            };
247            let Some(parent) = parent_weak.upgrade() else {
248                return;
249            };
250
251            if !Self::detach_from_parent_idx(&parent, &current, idx) {
252                return;
253            }
254
255            current = parent;
256        }
257    }
258
259    pub fn notify(&self) -> Arc<Notify> {
260        self.notify.clone()
261    }
262
263    fn drain_callbacks(mut head: *mut CallbackNode) -> Vec<Box<dyn FnOnce() + Send + 'static>> {
264        let mut out = Vec::new();
265        while !head.is_null() {
266            let node = unsafe { Box::from_raw(head) };
267            if let Some(callback) = node.callback {
268                out.push(callback);
269            }
270            head = node.next;
271        }
272        out
273    }
274
275    pub fn register(
276        &self,
277        owner: Arc<CancelState>,
278        cb: Box<dyn FnOnce() + Send + 'static>,
279    ) -> StopFunc {
280        let mut guard = self.inner.lock().unwrap();
281        if guard.done {
282            drop(guard);
283            cb();
284            return StopFunc::noop();
285        }
286        let id = guard.next_id;
287        guard.next_id += 1;
288        let node = Box::new(CallbackNode {
289            id,
290            callback: Some(cb),
291            next: guard.callbacks_head,
292        });
293        let ptr = Box::into_raw(node);
294        guard.callbacks_head = ptr;
295        StopFunc::new(owner, ptr)
296    }
297
298    pub(crate) fn remove(&self, ptr: *mut CallbackNode) -> bool {
299        let mut guard = self.inner.lock().unwrap();
300        if guard.done {
301            return false;
302        }
303        let mut current = guard.callbacks_head;
304        let mut prev: *mut CallbackNode = std::ptr::null_mut();
305        while !current.is_null() {
306            if current == ptr {
307                let next = unsafe { (*current).next };
308                if prev.is_null() {
309                    guard.callbacks_head = next;
310                } else {
311                    unsafe { (*prev).next = next };
312                }
313                let mut boxed = unsafe { Box::from_raw(current) };
314                let existed = boxed.callback.take().is_some();
315                drop(boxed);
316                return existed;
317            }
318            prev = current;
319            current = unsafe { (*current).next };
320        }
321        false
322    }
323
324    fn detach_from_parent_idx(
325        parent: &Arc<CancelState>,
326        child: &Arc<CancelState>,
327        idx: usize,
328    ) -> bool {
329        let mut p_guard = parent.inner.lock().unwrap();
330        let len_before = p_guard.children.len();
331        if idx >= len_before {
332            return false;
333        }
334        let last = p_guard.children.pop().unwrap();
335        let len_after = p_guard.children.len();
336        if idx < len_after {
337            p_guard.children[idx] = last;
338            if let Some(last_child) = p_guard.children[idx].upgrade() {
339                last_child.inner.lock().unwrap().parent_idx = Some(idx);
340            }
341        }
342        drop(p_guard);
343        child.inner.lock().unwrap().parent_idx = None;
344        true
345    }
346
347    pub fn wait(&self) {
348        let mut guard = self.inner.lock().unwrap();
349        while !guard.done {
350            guard = self.cvar.wait(guard).unwrap();
351        }
352    }
353
354    pub fn wait_timeout(&self, dur: Duration) -> bool {
355        let mut guard = self.inner.lock().unwrap();
356        let deadline = Instant::now() + dur;
357        while !guard.done {
358            let now = Instant::now();
359            if now >= deadline {
360                return guard.done;
361            }
362            let remaining = deadline.saturating_duration_since(now);
363            let (g, timeout_res) = self.cvar.wait_timeout(guard, remaining).unwrap();
364            guard = g;
365            if timeout_res.timed_out() {
366                return guard.done;
367            }
368        }
369        true
370    }
371}
372
373impl Debug for CancelState {
374    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
375        let status = self.inner.lock().unwrap().status;
376        f.debug_struct("CancelState")
377            .field("status", &status)
378            .finish()
379    }
380}
381
382/// `Done()` 的返回句柄。
383#[derive(Clone, Debug)]
384pub enum DoneHandle {
385    Never,
386    Active(Arc<CancelState>),
387}
388
389impl DoneHandle {
390    pub const fn never() -> Self {
391        Self::Never
392    }
393
394    pub fn is_done(&self) -> bool {
395        match self {
396            DoneHandle::Never => false,
397            DoneHandle::Active(state) => state.is_done(),
398        }
399    }
400
401    pub fn wait(&self) {
402        if let DoneHandle::Active(state) = self {
403            state.wait();
404        }
405    }
406
407    pub fn wait_timeout(&self, dur: Duration) -> bool {
408        match self {
409            DoneHandle::Never => false,
410            DoneHandle::Active(state) => state.wait_timeout(dur),
411        }
412    }
413
414    pub fn register(&self, cb: impl FnOnce() + Send + 'static) -> StopFunc {
415        match self {
416            DoneHandle::Never => StopFunc::noop(),
417            DoneHandle::Active(state) => state.register(state.clone(), Box::new(cb)),
418        }
419    }
420}
421
/// Counterpart of the stop function returned by Go's `AfterFunc`.
pub struct StopFunc {
    /// One-shot action run by `Stop`; it returns `true` when a pending
    /// callback was actually removed.
    inner: Option<Box<dyn FnOnce() -> bool + Send + 'static>>,
}
426
427impl StopFunc {
428    fn new(state: Arc<CancelState>, ptr: *mut CallbackNode) -> Self {
429        let ptr_usize = ptr as usize;
430        Self {
431            inner: Some(Box::new(move || {
432                state.remove(ptr_usize as *mut CallbackNode)
433            })),
434        }
435    }
436
437    pub fn noop() -> Self {
438        Self {
439            inner: Some(Box::new(|| false)),
440        }
441    }
442
443    #[allow(non_snake_case)]
444    pub fn Stop(mut self) -> bool {
445        if let Some(f) = self.inner.take() {
446            f()
447        } else {
448            false
449        }
450    }
451}
452
453impl Clone for StopFunc {
454    fn clone(&self) -> Self {
455        // StopFunc 在 Go 中允许多次调用,我们在这里提供一次性的克隆,返回的闭包在重复调用时返回 false。
456        Self {
457            inner: Some(Box::new(|| false)),
458        }
459    }
460}