// mssf_core/sync/token.rs

// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation.  All rights reserved.
// Licensed under the MIT License (MIT). See License.txt in the repo root for license information.
// ------------------------------------------------------------

use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, PoisonError};
use std::task::{Context, Poll, Waker};

use crate::runtime::executor::{BoxedCancelToken, CancelToken, EventFuture};
13
14/// A simple cancel token implementation
15#[derive(Clone, Debug)]
16pub struct SimpleCancelToken {
17    inner: Arc<TokenInner>,
18}
19
/// State shared by every clone of a [`SimpleCancelToken`].
#[derive(Debug)]
struct TokenInner {
    /// Becomes `true` when the token is cancelled; nothing in this file ever resets it.
    cancelled: AtomicBool,
    /// Wakers of tasks parked on a cancellation future; drained when the token is cancelled.
    wakers: Mutex<Vec<Waker>>,
}
25
26impl SimpleCancelToken {
27    pub fn new() -> Self {
28        SimpleCancelToken {
29            inner: Arc::new(TokenInner {
30                cancelled: AtomicBool::new(false),
31                wakers: Mutex::new(Vec::new()),
32            }),
33        }
34    }
35
36    pub fn new_boxed() -> BoxedCancelToken {
37        Box::new(Self::new())
38    }
39
40    pub fn cancel(&self) {
41        // Set the cancelled flag
42        self.inner.cancelled.store(true, Ordering::Release);
43
44        // Wake all waiting tasks
45        let mut wakers = self.inner.wakers.lock().unwrap();
46        for waker in wakers.drain(..) {
47            waker.wake();
48        }
49    }
50
51    pub fn is_cancelled(&self) -> bool {
52        self.inner.cancelled.load(Ordering::Acquire)
53    }
54
55    /// Returns a future that completes when cancellation is triggered
56    pub fn cancelled(&self) -> CancelledFuture {
57        CancelledFuture {
58            token: self.clone(),
59        }
60    }
61}
62
63/// This future is cancel safe.
64pub struct CancelledFuture {
65    token: SimpleCancelToken,
66}
67
68impl Future for CancelledFuture {
69    type Output = ();
70
71    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
72        if self.token.is_cancelled() {
73            return Poll::Ready(());
74        }
75
76        // Register this task's waker to be notified when cancelled
77        let mut wakers = self.token.inner.wakers.lock().unwrap();
78
79        // Double-check after acquiring the lock
80        if self.token.is_cancelled() {
81            return Poll::Ready(());
82        }
83
84        // Store the waker to be called when cancel() is invoked
85        wakers.push(cx.waker().clone());
86        Poll::Pending
87    }
88}
89
90impl Default for SimpleCancelToken {
91    fn default() -> Self {
92        Self::new()
93    }
94}
95
96/// Integrate with mssf trait system.
97impl CancelToken for SimpleCancelToken {
98    fn cancel(&self) {
99        self.cancel();
100    }
101
102    fn is_cancelled(&self) -> bool {
103        self.is_cancelled()
104    }
105
106    fn wait(&self) -> Pin<Box<dyn EventFuture>> {
107        Box::pin(self.cancelled())
108    }
109
110    fn clone_box(&self) -> Box<dyn CancelToken> {
111        Box::new(self.clone())
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cancel_token() {
        let token = SimpleCancelToken::new();
        assert!(!token.is_cancelled());
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn test_cancel_is_idempotent_and_shared_by_clones() {
        let token = SimpleCancelToken::default();
        let clone = token.clone();
        token.cancel();
        token.cancel();
        assert!(token.is_cancelled());
        assert!(clone.is_cancelled());
    }

    #[test]
    fn test_boxed_token_shares_state() {
        let boxed = SimpleCancelToken::new_boxed();
        let other = boxed.clone_box();
        assert!(!other.is_cancelled());
        boxed.cancel();
        assert!(other.is_cancelled());
    }

    #[tokio::test]
    async fn test_cancelled_future_ready_after_cancel() {
        let token = SimpleCancelToken::new();
        token.cancel();
        // Both futures must complete via the fast path without ever registering a waker.
        token.cancelled().await;
        token.wait().await;
    }

    #[tokio::test]
    async fn test_cancel_token_async() {
        let token = SimpleCancelToken::new();
        let h = tokio::spawn({
            let token = token.clone();
            async move {
                token.wait().await;
            }
        });
        // Let the spawned task run to its first `Pending` so `cancel()` must wake a
        // registered waker. Without this, the current-thread test runtime would first poll
        // the task after `cancel()`, and only the already-cancelled fast path would be tested.
        tokio::task::yield_now().await;
        assert!(!h.is_finished());
        token.cancel();
        assert!(token.is_cancelled());
        h.await.unwrap();
    }

    #[tokio::test]
    async fn test_cancel_token_multi() {
        let token = SimpleCancelToken::new();
        let mut join_set = tokio::task::JoinSet::new();

        for _ in 0..10 {
            let token = token.clone();
            join_set.spawn(async move {
                token.wait().await;
            });
            tokio::task::yield_now().await;
        }
        token.cancel();
        assert!(token.is_cancelled());
        join_set.join_all().await;
    }
}