Skip to main content

cognee_core/
cancellation.rs

1use std::sync::Arc;
2use tokio::sync::watch;
3
4/// Creates a linked (`CancellationHandle`, `CancellationToken`) pair.
5///
6/// The handle is given to the *owner* of a task; the token is passed into the
7/// task itself.  Dropping the handle does **not** cancel — call
8/// [`CancellationHandle::cancel`] explicitly.
9pub fn cancellation_pair() -> (CancellationHandle, CancellationToken) {
10    let (tx, rx) = watch::channel(false);
11    (
12        CancellationHandle {
13            sender: Arc::new(tx),
14        },
15        CancellationToken { receiver: rx },
16    )
17}
18/// Allows the owner of a task to request cancellation.
19///
20/// Clone-able so that multiple parties can independently cancel the same task.
21#[derive(Clone)]
22pub struct CancellationHandle {
23    sender: Arc<watch::Sender<bool>>,
24}
25
26impl CancellationHandle {
27    /// Signal cancellation to all associated [`CancellationToken`]s.
28    pub fn cancel(&self) {
29        // Ignore errors: all tokens have been dropped, nothing to signal.
30        let _ = self.sender.send(true);
31    }
32
33    /// Returns `true` if cancellation has already been requested.
34    pub fn is_cancelled(&self) -> bool {
35        *self.sender.borrow()
36    }
37}
38/// Passed into a running task so it can observe cancellation requests.
39///
40/// Clone-able: each clone independently tracks whether it has already seen the
41/// cancellation signal (via the `watch` channel's mark-seen semantics).
42#[derive(Clone)]
43pub struct CancellationToken {
44    receiver: watch::Receiver<bool>,
45}
46
47impl CancellationToken {
48    /// Returns `true` if cancellation has been requested.
49    pub fn is_cancelled(&self) -> bool {
50        *self.receiver.borrow()
51    }
52
53    /// Await until cancellation is requested.
54    ///
55    /// Returns immediately if already cancelled.  Also returns if the
56    /// [`CancellationHandle`] is dropped without cancelling (treat as
57    /// cancelled to avoid hanging forever).
58    pub async fn cancelled(&self) {
59        if self.is_cancelled() {
60            return;
61        }
62        let mut rx = self.receiver.clone();
63        loop {
64            match rx.changed().await {
65                Ok(_) => {
66                    if *rx.borrow() {
67                        return;
68                    }
69                }
70                // Sender dropped — treat as cancelled so tasks don't hang.
71                Err(_) => return,
72            }
73        }
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    // Ceiling for awaits that are expected to complete promptly; it is only
    // actually waited out when a test is about to fail.
    const TIMEOUT: Duration = Duration::from_secs(1);

    #[test]
    fn test_cancellation_signal_sync() {
        let (handle, token) = cancellation_pair();

        assert!(
            !token.is_cancelled(),
            "token should not be cancelled initially"
        );

        handle.cancel();

        assert!(
            token.is_cancelled(),
            "token should be cancelled after handle.cancel()"
        );
    }

    #[test]
    fn test_clones_share_state() {
        let (handle, token) = cancellation_pair();
        let other_handle = handle.clone();
        let other_token = token.clone();

        other_handle.cancel();

        assert!(handle.is_cancelled(), "all handle clones share one flag");
        assert!(token.is_cancelled(), "original token sees a clone's cancel");
        assert!(other_token.is_cancelled(), "cloned token sees the cancel");
    }

    #[tokio::test]
    async fn test_cancellation_signal_async() {
        let (handle, token) = cancellation_pair();

        assert!(!token.is_cancelled());

        handle.cancel();

        assert!(token.is_cancelled());

        // `cancelled().await` should return immediately since cancel was already called.
        let result = tokio::time::timeout(Duration::from_millis(100), token.cancelled()).await;

        assert!(
            result.is_ok(),
            "token.cancelled().await should complete immediately after cancel, not time out"
        );

        // Tokens keep no "already seen" state, so awaiting again also completes.
        let again = tokio::time::timeout(Duration::from_millis(100), token.cancelled()).await;
        assert!(
            again.is_ok(),
            "cancelled() should resolve on every call once cancelled"
        );
    }

    #[tokio::test]
    async fn test_cancel_wakes_pending_waiter() {
        let (handle, token) = cancellation_pair();

        // `join!` polls both branches each round. The `yield_now` makes the
        // second branch return `Pending` on its first poll, so the waiter is
        // already parked on the channel before `cancel()` runs.
        let (result, ()) = tokio::join!(
            tokio::time::timeout(TIMEOUT, token.cancelled()),
            async {
                tokio::task::yield_now().await;
                handle.cancel();
            }
        );

        assert!(
            result.is_ok(),
            "cancel() should wake a task already awaiting cancelled()"
        );
    }

    #[tokio::test]
    async fn test_dropping_every_handle_releases_waiters() {
        let (handle, token) = cancellation_pair();
        let other_handle = handle.clone();

        drop(handle);
        drop(other_handle);

        assert!(
            !token.is_cancelled(),
            "dropping the handles must not set the cancelled flag"
        );

        let result = tokio::time::timeout(TIMEOUT, token.cancelled()).await;
        assert!(
            result.is_ok(),
            "cancelled() must not hang once every handle has been dropped"
        );
    }

    #[tokio::test]
    async fn test_live_handle_keeps_waiter_pending() {
        let (handle, token) = cancellation_pair();

        let result = tokio::time::timeout(Duration::from_millis(20), token.cancelled()).await;
        assert!(
            result.is_err(),
            "cancelled() must stay pending while an uncancelled handle is alive"
        );

        // Dropped explicitly so the handle is provably alive during the wait above.
        drop(handle);
    }
}