mcp_kit/server/
cancellation.rs1use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use tokio_util::sync::CancellationToken;
10
11use crate::protocol::RequestId;
12use crate::server::session::SessionId;
13
14#[derive(Clone, Default)]
19pub struct CancellationManager {
20 inner: Arc<RwLock<CancellationState>>,
21}
22
23#[derive(Default)]
24struct CancellationState {
25 pending: HashMap<(SessionId, RequestId), CancellationToken>,
27}
28
29impl CancellationManager {
30 pub fn new() -> Self {
32 Self::default()
33 }
34
35 pub async fn register(
40 &self,
41 session_id: &SessionId,
42 request_id: &RequestId,
43 ) -> CancellationToken {
44 let token = CancellationToken::new();
45 let mut state = self.inner.write().await;
46 state
47 .pending
48 .insert((session_id.clone(), request_id.clone()), token.clone());
49 token
50 }
51
52 pub async fn complete(&self, session_id: &SessionId, request_id: &RequestId) {
56 let mut state = self.inner.write().await;
57 state
58 .pending
59 .remove(&(session_id.clone(), request_id.clone()));
60 }
61
62 pub async fn cancel(&self, session_id: &SessionId, request_id: &RequestId) -> bool {
66 let state = self.inner.read().await;
67 if let Some(token) = state.pending.get(&(session_id.clone(), request_id.clone())) {
68 token.cancel();
69 true
70 } else {
71 false
72 }
73 }
74
75 pub async fn cancel_all(&self, session_id: &SessionId) {
79 let state = self.inner.read().await;
80 for ((sid, _), token) in state.pending.iter() {
81 if sid == session_id {
82 token.cancel();
83 }
84 }
85 drop(state);
86
87 let mut state = self.inner.write().await;
89 state.pending.retain(|(sid, _), _| sid != session_id);
90 }
91
92 pub async fn pending_count(&self) -> usize {
94 let state = self.inner.read().await;
95 state.pending.len()
96 }
97
98 pub async fn is_pending(&self, session_id: &SessionId, request_id: &RequestId) -> bool {
100 let state = self.inner.read().await;
101 state
102 .pending
103 .get(&(session_id.clone(), request_id.clone()))
104 .map(|t| !t.is_cancelled())
105 .unwrap_or(false)
106 }
107}
108
109pub struct RequestGuard {
113 manager: CancellationManager,
114 session_id: SessionId,
115 request_id: RequestId,
116 token: CancellationToken,
117}
118
119impl RequestGuard {
120 pub async fn new(
122 manager: CancellationManager,
123 session_id: SessionId,
124 request_id: RequestId,
125 ) -> Self {
126 let token = manager.register(&session_id, &request_id).await;
127 Self {
128 manager,
129 session_id,
130 request_id,
131 token,
132 }
133 }
134
135 pub fn token(&self) -> &CancellationToken {
137 &self.token
138 }
139
140 pub fn is_cancelled(&self) -> bool {
142 self.token.is_cancelled()
143 }
144
145 pub async fn cancelled(&self) {
147 self.token.cancelled().await
148 }
149}
150
151impl Drop for RequestGuard {
152 fn drop(&mut self) {
153 let manager = self.manager.clone();
155 let session_id = self.session_id.clone();
156 let request_id = self.request_id.clone();
157 tokio::spawn(async move {
158 manager.complete(&session_id, &request_id).await;
159 });
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_register_and_cancel() {
        let mgr = CancellationManager::new();
        let session = SessionId::new();
        let request_id = RequestId::Number(1);

        let token = mgr.register(&session, &request_id).await;
        assert!(!token.is_cancelled());
        assert!(mgr.is_pending(&session, &request_id).await);

        assert!(mgr.cancel(&session, &request_id).await);
        assert!(token.is_cancelled());
        // A cancelled request stays registered but no longer reports as pending.
        assert!(!mgr.is_pending(&session, &request_id).await);
        assert_eq!(mgr.pending_count().await, 1);
    }

    #[tokio::test]
    async fn test_cancel_unknown_returns_false() {
        let mgr = CancellationManager::new();
        let session = SessionId::new();

        assert!(!mgr.cancel(&session, &RequestId::Number(7)).await);
        assert_eq!(mgr.pending_count().await, 0);
    }

    #[tokio::test]
    async fn test_complete_removes() {
        let mgr = CancellationManager::new();
        let session = SessionId::new();
        let request_id = RequestId::Number(1);

        mgr.register(&session, &request_id).await;
        assert_eq!(mgr.pending_count().await, 1);

        mgr.complete(&session, &request_id).await;
        assert_eq!(mgr.pending_count().await, 0);
    }

    #[tokio::test]
    async fn test_cancel_all_cancels_and_removes() {
        let mgr = CancellationManager::new();
        let session = SessionId::new();

        let first = mgr.register(&session, &RequestId::Number(1)).await;
        let second = mgr.register(&session, &RequestId::Number(2)).await;
        assert_eq!(mgr.pending_count().await, 2);

        mgr.cancel_all(&session).await;

        assert!(first.is_cancelled());
        assert!(second.is_cancelled());
        assert_eq!(mgr.pending_count().await, 0);
    }
}