Skip to main content

mcp_kit/server/
cancellation.rs

1//! Request cancellation support.
2//!
3//! Tracks pending requests and allows cancellation via the
4//! `notifications/cancelled` notification from clients.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use tokio_util::sync::CancellationToken;
10
11use crate::protocol::RequestId;
12use crate::server::session::SessionId;
13
14/// Manages pending requests and their cancellation tokens.
15///
16/// When a client sends `notifications/cancelled`, we can abort the
17/// corresponding in-flight request.
18#[derive(Clone, Default)]
19pub struct CancellationManager {
20    inner: Arc<RwLock<CancellationState>>,
21}
22
23#[derive(Default)]
24struct CancellationState {
25    /// Map from (session_id, request_id) to cancellation token
26    pending: HashMap<(SessionId, RequestId), CancellationToken>,
27}
28
29impl CancellationManager {
30    /// Create a new cancellation manager.
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    /// Register a new pending request and get a cancellation token.
36    ///
37    /// The handler should check `token.is_cancelled()` periodically or
38    /// use `token.cancelled().await` to respond to cancellation.
39    pub async fn register(
40        &self,
41        session_id: &SessionId,
42        request_id: &RequestId,
43    ) -> CancellationToken {
44        let token = CancellationToken::new();
45        let mut state = self.inner.write().await;
46        state
47            .pending
48            .insert((session_id.clone(), request_id.clone()), token.clone());
49        token
50    }
51
52    /// Remove a completed request from tracking.
53    ///
54    /// Call this when the request completes (success or error).
55    pub async fn complete(&self, session_id: &SessionId, request_id: &RequestId) {
56        let mut state = self.inner.write().await;
57        state
58            .pending
59            .remove(&(session_id.clone(), request_id.clone()));
60    }
61
62    /// Cancel a pending request.
63    ///
64    /// Returns `true` if the request was found and cancelled, `false` otherwise.
65    pub async fn cancel(&self, session_id: &SessionId, request_id: &RequestId) -> bool {
66        let state = self.inner.read().await;
67        if let Some(token) = state.pending.get(&(session_id.clone(), request_id.clone())) {
68            token.cancel();
69            true
70        } else {
71            false
72        }
73    }
74
75    /// Cancel all pending requests for a session.
76    ///
77    /// Call this when a session disconnects.
78    pub async fn cancel_all(&self, session_id: &SessionId) {
79        let state = self.inner.read().await;
80        for ((sid, _), token) in state.pending.iter() {
81            if sid == session_id {
82                token.cancel();
83            }
84        }
85        drop(state);
86
87        // Clean up entries
88        let mut state = self.inner.write().await;
89        state.pending.retain(|(sid, _), _| sid != session_id);
90    }
91
92    /// Get the number of pending requests.
93    pub async fn pending_count(&self) -> usize {
94        let state = self.inner.read().await;
95        state.pending.len()
96    }
97
98    /// Check if a request is still pending (not cancelled).
99    pub async fn is_pending(&self, session_id: &SessionId, request_id: &RequestId) -> bool {
100        let state = self.inner.read().await;
101        state
102            .pending
103            .get(&(session_id.clone(), request_id.clone()))
104            .map(|t| !t.is_cancelled())
105            .unwrap_or(false)
106    }
107}
108
109/// A guard that automatically completes the request when dropped.
110///
111/// Use this to ensure requests are properly cleaned up even on early returns.
112pub struct RequestGuard {
113    manager: CancellationManager,
114    session_id: SessionId,
115    request_id: RequestId,
116    token: CancellationToken,
117}
118
119impl RequestGuard {
120    /// Create a new request guard.
121    pub async fn new(
122        manager: CancellationManager,
123        session_id: SessionId,
124        request_id: RequestId,
125    ) -> Self {
126        let token = manager.register(&session_id, &request_id).await;
127        Self {
128            manager,
129            session_id,
130            request_id,
131            token,
132        }
133    }
134
135    /// Get the cancellation token.
136    pub fn token(&self) -> &CancellationToken {
137        &self.token
138    }
139
140    /// Check if the request has been cancelled.
141    pub fn is_cancelled(&self) -> bool {
142        self.token.is_cancelled()
143    }
144
145    /// Wait for cancellation.
146    pub async fn cancelled(&self) {
147        self.token.cancelled().await
148    }
149}
150
151impl Drop for RequestGuard {
152    fn drop(&mut self) {
153        // We can't await in drop, so spawn a task
154        let manager = self.manager.clone();
155        let session_id = self.session_id.clone();
156        let request_id = self.request_id.clone();
157        tokio::spawn(async move {
158            manager.complete(&session_id, &request_id).await;
159        });
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_register_and_cancel() {
        let mgr = CancellationManager::new();
        let session = SessionId::new();
        let request_id = RequestId::Number(1);

        let token = mgr.register(&session, &request_id).await;
        assert!(!token.is_cancelled());
        assert!(mgr.is_pending(&session, &request_id).await);

        assert!(mgr.cancel(&session, &request_id).await);
        assert!(token.is_cancelled());
        // Still tracked until completed, but no longer reported as pending.
        assert!(!mgr.is_pending(&session, &request_id).await);
        assert_eq!(mgr.pending_count().await, 1);
    }

    #[tokio::test]
    async fn test_cancel_unknown_returns_false() {
        let mgr = CancellationManager::new();
        let session = SessionId::new();

        assert!(!mgr.cancel(&session, &RequestId::Number(42)).await);
        assert_eq!(mgr.pending_count().await, 0);
    }

    #[tokio::test]
    async fn test_complete_removes() {
        let mgr = CancellationManager::new();
        let session = SessionId::new();
        let request_id = RequestId::Number(1);

        mgr.register(&session, &request_id).await;
        assert_eq!(mgr.pending_count().await, 1);

        mgr.complete(&session, &request_id).await;
        assert_eq!(mgr.pending_count().await, 0);
        assert!(!mgr.is_pending(&session, &request_id).await);
    }

    #[tokio::test]
    async fn test_cancel_all_cancels_and_removes() {
        let mgr = CancellationManager::new();
        let session = SessionId::new();

        let first = mgr.register(&session, &RequestId::Number(1)).await;
        let second = mgr.register(&session, &RequestId::Number(2)).await;
        assert_eq!(mgr.pending_count().await, 2);

        mgr.cancel_all(&session).await;
        assert!(first.is_cancelled());
        assert!(second.is_cancelled());
        assert_eq!(mgr.pending_count().await, 0);
    }
}