1use crate::context::error::{CANCELLED, ContextError, DEADLINE_EXCEEDED, Error};
2use std::fmt::{Debug, Formatter};
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::{Arc, Condvar, Mutex, Weak};
5use std::time::{Duration, Instant};
6use tokio::sync::Notify;
7
/// Lifecycle state of a single cancellation node.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Status {
    /// Not yet canceled.
    Active,
    /// Canceled explicitly (or via a canceled parent).
    Canceled,
    /// Canceled because a deadline passed.
    DeadlineExceeded,
}
14
15struct Inner {
22 status: Status,
23 cause: Option<Arc<dyn std::error::Error + Send + Sync>>,
24 parent: Option<Weak<CancelState>>,
25 children: Vec<Weak<CancelState>>,
26 parent_idx: Option<usize>,
27 done: bool,
28 handle_count: AtomicUsize,
29 next_id: usize,
30 callbacks_head: *mut CallbackNode,
31}
32
33unsafe impl Send for Inner {}
34unsafe impl Sync for Inner {}
35
/// One entry in a cancellation state's intrusive list of registered callbacks.
///
/// Nodes live on the heap and are chained through the raw `next` pointer.
pub(crate) struct CallbackNode {
    /// Registration id assigned by the owning state (not read anywhere yet).
    #[allow(dead_code)]
    id: usize,
    /// Callback to invoke on cancellation; `None` once it has been taken out.
    callback: Option<Box<dyn FnOnce() + Send>>,
    /// Following node in the list, or null at the tail.
    next: *mut Self,
}

// SAFETY: the contained callback is `Send`, and the raw `next` pointer is only
// dereferenced by the owning state while its mutex is held.
unsafe impl Send for CallbackNode {}
// SAFETY: see above — the list is never traversed without the owning state's lock.
unsafe impl Sync for CallbackNode {}
46
47pub struct CancelState {
49 inner: Mutex<Inner>,
50 cvar: Condvar,
51 notify: Arc<Notify>,
52}
53
54impl CancelState {
55 pub fn new_root() -> Arc<Self> {
56 Arc::new(Self {
57 inner: Mutex::new(Inner {
58 status: Status::Active,
59 cause: None,
60 parent: None,
61 children: Vec::new(),
62 parent_idx: None,
63 done: false,
64 handle_count: AtomicUsize::new(1),
65 next_id: 0,
66 callbacks_head: std::ptr::null_mut(),
67 }),
68 cvar: Condvar::new(),
69 notify: Arc::new(Notify::new()),
70 })
71 }
72
73 pub fn child_of(parent: &Arc<Self>) -> Arc<Self> {
74 let child = Arc::new(Self {
75 inner: Mutex::new(Inner {
76 status: Status::Active,
77 cause: None,
78 parent: Some(Arc::downgrade(parent)),
79 children: Vec::new(),
80 parent_idx: None,
81 done: false,
82 handle_count: AtomicUsize::new(1),
83 next_id: 0,
84 callbacks_head: std::ptr::null_mut(),
85 }),
86 cvar: Condvar::new(),
87 notify: Arc::new(Notify::new()),
88 });
89
90 let weak_child = Arc::downgrade(&child);
92 let mut guard = parent.inner.lock().unwrap();
93 let idx = guard.children.len();
94 guard.children.push(weak_child);
95 drop(guard);
96 child.inner.lock().unwrap().parent_idx = Some(idx);
97
98 child
99 }
100
101 pub fn done_handle(this: &Arc<Self>) -> DoneHandle {
102 DoneHandle::Active(this.clone())
103 }
104
105 pub fn err(&self) -> Option<ContextError> {
106 let guard = self.inner.lock().unwrap();
107 match guard.status {
108 Status::Active => None,
109 Status::Canceled => Some(ContextError::with_cause(
110 Error::Canceled,
111 guard.cause.clone(),
112 )),
113 Status::DeadlineExceeded => Some(ContextError::with_cause(
114 Error::DeadlineExceeded,
115 guard.cause.clone(),
116 )),
117 }
118 }
119
120 pub fn cause(&self) -> Option<Arc<dyn std::error::Error + Send + Sync>> {
121 self.inner.lock().unwrap().cause.clone()
122 }
123
124 pub fn is_done(&self) -> bool {
125 self.inner.lock().unwrap().done
126 }
127
128 pub fn add_handle(&self) {
129 let guard = self.inner.lock().unwrap();
130 guard.handle_count.fetch_add(1, Ordering::Relaxed);
131 }
132
133 pub fn release_handle(self: &Arc<Self>) {
134 let mut current = self.clone();
135 loop {
136 let parent_opt = {
137 let mut guard = current.inner.lock().unwrap();
138 guard.handle_count.fetch_sub(1, Ordering::Relaxed);
139 guard.children.retain(|w| w.upgrade().is_some());
140 let has_handles = guard.handle_count.load(Ordering::Relaxed) > 0;
141 let has_children = !guard.children.is_empty();
142 let done = guard.done;
143 let parent = guard.parent.clone();
144 let parent_idx = guard.parent_idx;
145 drop(guard);
146 if has_handles || has_children || !done {
147 return;
148 }
149 parent.zip(parent_idx)
150 };
151
152 let Some((parent_weak, idx)) = parent_opt else {
153 return;
154 };
155 let Some(parent) = parent_weak.upgrade() else {
156 return;
157 };
158
159 if !Self::detach_from_parent_idx(&parent, ¤t, idx) {
160 return;
161 }
162 current = parent;
163 }
164 }
165
166 pub fn cancel(
167 self: &Arc<Self>,
168 kind: Error,
169 cause: Option<Arc<dyn std::error::Error + Send + Sync>>,
170 ) {
171 let cause_for_self = cause.clone();
173 let (callbacks, children, notify_needed) = {
174 let mut guard = self.inner.lock().unwrap();
175 if guard.done {
176 return;
177 }
178
179 guard.status = match kind {
180 Error::Canceled => Status::Canceled,
181 Error::DeadlineExceeded => Status::DeadlineExceeded,
182 };
183 if guard.cause.is_none() {
184 guard.cause = cause_for_self.clone().or_else(|| {
185 let err = match kind {
186 Error::Canceled => CANCELLED,
187 Error::DeadlineExceeded => DEADLINE_EXCEEDED,
188 };
189 Some(Arc::new(err) as Arc<dyn std::error::Error + Send + Sync>)
190 });
191 }
192 guard.done = true;
193
194 let head = guard.callbacks_head;
195 guard.callbacks_head = std::ptr::null_mut();
196 let callbacks = if head.is_null() {
197 Vec::new()
198 } else {
199 Self::drain_callbacks(head)
200 };
201
202 let children = guard
203 .children
204 .iter()
205 .filter_map(|w| w.upgrade())
206 .collect::<Vec<_>>();
207 (callbacks, children, true)
208 };
209
210 if notify_needed {
211 self.notify.notify_waiters();
212 self.cvar.notify_all();
213 }
214
215 for cb in callbacks {
217 cb();
218 }
219
220 for child in children {
222 child.cancel(kind, cause.clone());
223 }
224
225 self.prune_if_detached();
227 }
228
229 fn prune_if_detached(self: &Arc<Self>) {
230 let mut current = self.clone();
231 loop {
232 let parent_info = {
233 let mut guard = current.inner.lock().unwrap();
234 if !guard.done {
235 return;
236 }
237 guard.children.retain(|w| w.upgrade().is_some());
238 if !guard.children.is_empty() {
239 return;
240 }
241 guard.parent.clone().zip(guard.parent_idx)
242 };
243
244 let Some((parent_weak, idx)) = parent_info else {
245 return;
246 };
247 let Some(parent) = parent_weak.upgrade() else {
248 return;
249 };
250
251 if !Self::detach_from_parent_idx(&parent, ¤t, idx) {
252 return;
253 }
254
255 current = parent;
256 }
257 }
258
259 pub fn notify(&self) -> Arc<Notify> {
260 self.notify.clone()
261 }
262
263 fn drain_callbacks(mut head: *mut CallbackNode) -> Vec<Box<dyn FnOnce() + Send + 'static>> {
264 let mut out = Vec::new();
265 while !head.is_null() {
266 let node = unsafe { Box::from_raw(head) };
267 if let Some(callback) = node.callback {
268 out.push(callback);
269 }
270 head = node.next;
271 }
272 out
273 }
274
275 pub fn register(
276 &self,
277 owner: Arc<CancelState>,
278 cb: Box<dyn FnOnce() + Send + 'static>,
279 ) -> StopFunc {
280 let mut guard = self.inner.lock().unwrap();
281 if guard.done {
282 drop(guard);
283 cb();
284 return StopFunc::noop();
285 }
286 let id = guard.next_id;
287 guard.next_id += 1;
288 let node = Box::new(CallbackNode {
289 id,
290 callback: Some(cb),
291 next: guard.callbacks_head,
292 });
293 let ptr = Box::into_raw(node);
294 guard.callbacks_head = ptr;
295 StopFunc::new(owner, ptr)
296 }
297
298 pub(crate) fn remove(&self, ptr: *mut CallbackNode) -> bool {
299 let mut guard = self.inner.lock().unwrap();
300 if guard.done {
301 return false;
302 }
303 let mut current = guard.callbacks_head;
304 let mut prev: *mut CallbackNode = std::ptr::null_mut();
305 while !current.is_null() {
306 if current == ptr {
307 let next = unsafe { (*current).next };
308 if prev.is_null() {
309 guard.callbacks_head = next;
310 } else {
311 unsafe { (*prev).next = next };
312 }
313 let mut boxed = unsafe { Box::from_raw(current) };
314 let existed = boxed.callback.take().is_some();
315 drop(boxed);
316 return existed;
317 }
318 prev = current;
319 current = unsafe { (*current).next };
320 }
321 false
322 }
323
324 fn detach_from_parent_idx(
325 parent: &Arc<CancelState>,
326 child: &Arc<CancelState>,
327 idx: usize,
328 ) -> bool {
329 let mut p_guard = parent.inner.lock().unwrap();
330 let len_before = p_guard.children.len();
331 if idx >= len_before {
332 return false;
333 }
334 let last = p_guard.children.pop().unwrap();
335 let len_after = p_guard.children.len();
336 if idx < len_after {
337 p_guard.children[idx] = last;
338 if let Some(last_child) = p_guard.children[idx].upgrade() {
339 last_child.inner.lock().unwrap().parent_idx = Some(idx);
340 }
341 }
342 drop(p_guard);
343 child.inner.lock().unwrap().parent_idx = None;
344 true
345 }
346
347 pub fn wait(&self) {
348 let mut guard = self.inner.lock().unwrap();
349 while !guard.done {
350 guard = self.cvar.wait(guard).unwrap();
351 }
352 }
353
354 pub fn wait_timeout(&self, dur: Duration) -> bool {
355 let mut guard = self.inner.lock().unwrap();
356 let deadline = Instant::now() + dur;
357 while !guard.done {
358 let now = Instant::now();
359 if now >= deadline {
360 return guard.done;
361 }
362 let remaining = deadline.saturating_duration_since(now);
363 let (g, timeout_res) = self.cvar.wait_timeout(guard, remaining).unwrap();
364 guard = g;
365 if timeout_res.timed_out() {
366 return guard.done;
367 }
368 }
369 true
370 }
371}
372
373impl Debug for CancelState {
374 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
375 let status = self.inner.lock().unwrap().status;
376 f.debug_struct("CancelState")
377 .field("status", &status)
378 .finish()
379 }
380}
381
382#[derive(Clone, Debug)]
384pub enum DoneHandle {
385 Never,
386 Active(Arc<CancelState>),
387}
388
389impl DoneHandle {
390 pub const fn never() -> Self {
391 Self::Never
392 }
393
394 pub fn is_done(&self) -> bool {
395 match self {
396 DoneHandle::Never => false,
397 DoneHandle::Active(state) => state.is_done(),
398 }
399 }
400
401 pub fn wait(&self) {
402 if let DoneHandle::Active(state) = self {
403 state.wait();
404 }
405 }
406
407 pub fn wait_timeout(&self, dur: Duration) -> bool {
408 match self {
409 DoneHandle::Never => false,
410 DoneHandle::Active(state) => state.wait_timeout(dur),
411 }
412 }
413
414 pub fn register(&self, cb: impl FnOnce() + Send + 'static) -> StopFunc {
415 match self {
416 DoneHandle::Never => StopFunc::noop(),
417 DoneHandle::Active(state) => state.register(state.clone(), Box::new(cb)),
418 }
419 }
420}
421
/// Handle returned by callback registration that can deregister the callback.
///
/// Wraps a one-shot closure that is consumed by `Stop`; `None` once it has been taken.
pub struct StopFunc {
    inner: Option<Box<dyn FnOnce() -> bool + Send>>,
}
426
427impl StopFunc {
428 fn new(state: Arc<CancelState>, ptr: *mut CallbackNode) -> Self {
429 let ptr_usize = ptr as usize;
430 Self {
431 inner: Some(Box::new(move || {
432 state.remove(ptr_usize as *mut CallbackNode)
433 })),
434 }
435 }
436
437 pub fn noop() -> Self {
438 Self {
439 inner: Some(Box::new(|| false)),
440 }
441 }
442
443 #[allow(non_snake_case)]
444 pub fn Stop(mut self) -> bool {
445 if let Some(f) = self.inner.take() {
446 f()
447 } else {
448 false
449 }
450 }
451}
452
453impl Clone for StopFunc {
454 fn clone(&self) -> Self {
455 Self {
457 inner: Some(Box::new(|| false)),
458 }
459 }
460}