1use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::sync::{broadcast, RwLock};
13
14use super::error::{McpError, McpResult};
15
/// Why an in-flight MCP request was cancelled.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CancellationReason {
    /// The user explicitly asked for the request to stop.
    UserCancelled,
    /// The request ran past its allotted time.
    Timeout,
    /// The server asked for the request to be cancelled.
    ServerRequest,
    /// Cancellation triggered by a shutdown.
    Shutdown,
    /// Cancellation triggered by an error condition.
    Error,
}
30
31impl std::fmt::Display for CancellationReason {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self {
34 Self::UserCancelled => write!(f, "Request cancelled by user"),
35 Self::Timeout => write!(f, "Request timed out"),
36 Self::ServerRequest => write!(f, "Cancelled at server request"),
37 Self::Shutdown => write!(f, "Cancelled due to shutdown"),
38 Self::Error => write!(f, "Cancelled due to error"),
39 }
40 }
41}
42
/// Metadata recorded for a request while it is being tracked for cancellation.
#[derive(Clone, Debug)]
pub struct CancellableRequest {
    /// Request id; also the key under which the request is tracked.
    pub id: String,
    /// Name of the server the request was sent to.
    pub server_name: String,
    /// Method name of the request (e.g. `tools/call`).
    pub method: String,
    /// When tracking began; used to compute elapsed durations.
    pub start_time: Instant,
    /// Optional timeout supplied at registration. Nothing in this type enforces it.
    pub timeout: Option<Duration>,
}
57
58#[derive(Debug, Clone)]
60pub struct CancellationResult {
61 pub success: bool,
63 pub reason: CancellationReason,
65 pub request_id: String,
67 pub server_name: String,
69 pub duration: Duration,
71}
72
73#[derive(Debug, Clone)]
78pub struct CancellationToken {
79 inner: Arc<RwLock<CancellationTokenInner>>,
80 sender: broadcast::Sender<CancellationReason>,
81}
82
83#[derive(Debug)]
84struct CancellationTokenInner {
85 cancelled: bool,
86 reason: Option<CancellationReason>,
87 timestamp: Option<Instant>,
88}
89
90impl CancellationToken {
91 pub fn new() -> Self {
93 let (sender, _) = broadcast::channel(16);
94 Self {
95 inner: Arc::new(RwLock::new(CancellationTokenInner {
96 cancelled: false,
97 reason: None,
98 timestamp: None,
99 })),
100 sender,
101 }
102 }
103
104 pub async fn is_cancelled(&self) -> bool {
106 self.inner.read().await.cancelled
107 }
108
109 pub async fn reason(&self) -> Option<CancellationReason> {
111 self.inner.read().await.reason
112 }
113
114 pub async fn timestamp(&self) -> Option<Instant> {
116 self.inner.read().await.timestamp
117 }
118
119 pub async fn cancel(&self, reason: CancellationReason) {
121 let mut inner = self.inner.write().await;
122 if inner.cancelled {
123 return;
124 }
125
126 inner.cancelled = true;
127 inner.reason = Some(reason);
128 inner.timestamp = Some(Instant::now());
129
130 let _ = self.sender.send(reason);
131 }
132
133 pub async fn throw_if_cancelled(&self) -> McpResult<()> {
135 let inner = self.inner.read().await;
136 if inner.cancelled {
137 let reason = inner.reason.unwrap_or(CancellationReason::UserCancelled);
138 return Err(McpError::cancelled(
139 reason.to_string(),
140 Some(reason.to_string()),
141 ));
142 }
143 Ok(())
144 }
145
146 pub fn subscribe(&self) -> broadcast::Receiver<CancellationReason> {
148 self.sender.subscribe()
149 }
150}
151
152impl Default for CancellationToken {
153 fn default() -> Self {
154 Self::new()
155 }
156}
157
158#[derive(Debug, Clone)]
160pub enum CancellationEvent {
161 RequestRegistered {
163 id: String,
164 server_name: String,
165 method: String,
166 },
167 RequestUnregistered { id: String, server_name: String },
169 RequestCancelled(CancellationResult),
171 ServerCancelled { server_name: String, count: usize },
173 AllCancelled { count: usize },
175}
176
177pub struct McpCancellationManager {
185 requests: Arc<RwLock<HashMap<String, CancellableRequest>>>,
186 tokens: Arc<RwLock<HashMap<String, CancellationToken>>>,
187 event_sender: broadcast::Sender<CancellationEvent>,
188}
189
190impl McpCancellationManager {
191 pub fn new() -> Self {
193 let (event_sender, _) = broadcast::channel(256);
194 Self {
195 requests: Arc::new(RwLock::new(HashMap::new())),
196 tokens: Arc::new(RwLock::new(HashMap::new())),
197 event_sender,
198 }
199 }
200
201 pub fn subscribe(&self) -> broadcast::Receiver<CancellationEvent> {
203 self.event_sender.subscribe()
204 }
205
206 pub async fn register_request(
208 &self,
209 id: impl Into<String>,
210 server_name: impl Into<String>,
211 method: impl Into<String>,
212 timeout: Option<Duration>,
213 ) -> CancellationToken {
214 let id = id.into();
215 let server_name = server_name.into();
216 let method = method.into();
217
218 let request = CancellableRequest {
219 id: id.clone(),
220 server_name: server_name.clone(),
221 method: method.clone(),
222 start_time: Instant::now(),
223 timeout,
224 };
225
226 let token = CancellationToken::new();
227
228 self.requests.write().await.insert(id.clone(), request);
229 self.tokens.write().await.insert(id.clone(), token.clone());
230
231 let _ = self
232 .event_sender
233 .send(CancellationEvent::RequestRegistered {
234 id,
235 server_name,
236 method,
237 });
238
239 token
240 }
241
242 pub async fn unregister_request(&self, id: &str) -> bool {
244 let request = self.requests.write().await.remove(id);
245 self.tokens.write().await.remove(id);
246
247 if let Some(req) = request {
248 let _ = self
249 .event_sender
250 .send(CancellationEvent::RequestUnregistered {
251 id: id.to_string(),
252 server_name: req.server_name,
253 });
254 true
255 } else {
256 false
257 }
258 }
259
260 pub async fn has_request(&self, id: &str) -> bool {
262 self.requests.read().await.contains_key(id)
263 }
264
265 pub async fn get_request(&self, id: &str) -> Option<CancellableRequest> {
267 self.requests.read().await.get(id).cloned()
268 }
269
270 pub async fn get_all_requests(&self) -> Vec<CancellableRequest> {
272 self.requests.read().await.values().cloned().collect()
273 }
274
275 pub async fn get_server_requests(&self, server_name: &str) -> Vec<CancellableRequest> {
277 self.requests
278 .read()
279 .await
280 .values()
281 .filter(|r| r.server_name == server_name)
282 .cloned()
283 .collect()
284 }
285
286 pub async fn cancel_request(
288 &self,
289 id: &str,
290 reason: CancellationReason,
291 ) -> Option<CancellationResult> {
292 let request = self.requests.write().await.remove(id)?;
293 let token = self.tokens.write().await.remove(id);
294
295 if let Some(t) = token {
297 t.cancel(reason).await;
298 }
299
300 let duration = request.start_time.elapsed();
301 let result = CancellationResult {
302 success: true,
303 reason,
304 request_id: id.to_string(),
305 server_name: request.server_name,
306 duration,
307 };
308
309 let _ = self
310 .event_sender
311 .send(CancellationEvent::RequestCancelled(result.clone()));
312
313 Some(result)
314 }
315
316 pub async fn cancel_server_requests(
318 &self,
319 server_name: &str,
320 reason: CancellationReason,
321 ) -> Vec<CancellationResult> {
322 let requests = self.get_server_requests(server_name).await;
323 let mut results = Vec::new();
324
325 for request in requests {
326 if let Some(result) = self.cancel_request(&request.id, reason).await {
327 results.push(result);
328 }
329 }
330
331 let _ = self.event_sender.send(CancellationEvent::ServerCancelled {
332 server_name: server_name.to_string(),
333 count: results.len(),
334 });
335
336 results
337 }
338
339 pub async fn cancel_all(&self, reason: CancellationReason) -> Vec<CancellationResult> {
341 let requests = self.get_all_requests().await;
342 let mut results = Vec::new();
343
344 for request in requests {
345 if let Some(result) = self.cancel_request(&request.id, reason).await {
346 results.push(result);
347 }
348 }
349
350 let _ = self.event_sender.send(CancellationEvent::AllCancelled {
351 count: results.len(),
352 });
353
354 results
355 }
356
357 pub async fn get_stats(&self) -> CancellationStats {
359 let requests = self.get_all_requests().await;
360
361 let mut by_server: HashMap<String, usize> = HashMap::new();
362 let mut with_timeout = 0;
363
364 for request in &requests {
365 *by_server.entry(request.server_name.clone()).or_insert(0) += 1;
366 if request.timeout.is_some() {
367 with_timeout += 1;
368 }
369 }
370
371 CancellationStats {
372 active_requests: requests.len(),
373 by_server,
374 with_timeout,
375 }
376 }
377
378 pub async fn get_request_durations(&self) -> Vec<RequestDuration> {
380 self.requests
381 .read()
382 .await
383 .values()
384 .map(|r| RequestDuration {
385 id: r.id.clone(),
386 server_name: r.server_name.clone(),
387 method: r.method.clone(),
388 duration: r.start_time.elapsed(),
389 })
390 .collect()
391 }
392
393 pub async fn find_long_running_requests(&self, threshold: Duration) -> Vec<CancellableRequest> {
395 self.requests
396 .read()
397 .await
398 .values()
399 .filter(|r| r.start_time.elapsed() > threshold)
400 .cloned()
401 .collect()
402 }
403
404 pub async fn cleanup(&self) {
406 self.requests.write().await.clear();
407 self.tokens.write().await.clear();
408 }
409}
410
411impl Default for McpCancellationManager {
412 fn default() -> Self {
413 Self::new()
414 }
415}
416
/// Snapshot of the tracked-request population.
#[derive(Clone, Debug)]
pub struct CancellationStats {
    /// Total number of tracked requests.
    pub active_requests: usize,
    /// Number of tracked requests per server name.
    pub by_server: HashMap<String, usize>,
    /// Number of tracked requests registered with a timeout.
    pub with_timeout: usize,
}
427
/// How long a tracked request had been running when it was sampled.
#[derive(Clone, Debug)]
pub struct RequestDuration {
    /// Request id.
    pub id: String,
    /// Server the request belongs to.
    pub server_name: String,
    /// Method name of the request.
    pub method: String,
    /// Elapsed time since registration at the moment of sampling.
    pub duration: Duration,
}
440
441#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
443pub struct CancelledNotification {
444 pub request_id: String,
446 #[serde(skip_serializing_if = "Option::is_none")]
448 pub reason: Option<String>,
449}
450
451impl CancelledNotification {
452 pub fn new(request_id: impl Into<String>, reason: Option<String>) -> Self {
454 Self {
455 request_id: request_id.into(),
456 reason,
457 }
458 }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cancellation_reason_display() {
        assert_eq!(
            CancellationReason::UserCancelled.to_string(),
            "Request cancelled by user"
        );
        assert_eq!(CancellationReason::Timeout.to_string(), "Request timed out");
        assert_eq!(
            CancellationReason::ServerRequest.to_string(),
            "Cancelled at server request"
        );
        assert_eq!(
            CancellationReason::Shutdown.to_string(),
            "Cancelled due to shutdown"
        );
        assert_eq!(
            CancellationReason::Error.to_string(),
            "Cancelled due to error"
        );
    }

    #[tokio::test]
    async fn test_cancellation_token_new() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled().await);
        assert!(token.reason().await.is_none());
        assert!(token.timestamp().await.is_none());
    }

    #[tokio::test]
    async fn test_cancellation_token_cancel() {
        let token = CancellationToken::new();
        token.cancel(CancellationReason::UserCancelled).await;

        assert!(token.is_cancelled().await);
        assert_eq!(
            token.reason().await,
            Some(CancellationReason::UserCancelled)
        );
        assert!(token.timestamp().await.is_some());
    }

    #[tokio::test]
    async fn test_cancellation_token_first_reason_wins() {
        let token = CancellationToken::new();
        let mut rx = token.subscribe();

        token.cancel(CancellationReason::Timeout).await;
        token.cancel(CancellationReason::UserCancelled).await;

        assert_eq!(token.reason().await, Some(CancellationReason::Timeout));
        // Exactly one broadcast: the first reason, and nothing after it.
        assert_eq!(rx.recv().await.unwrap(), CancellationReason::Timeout);
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_cancellation_token_clone_shares_state() {
        let token = CancellationToken::new();
        let clone = token.clone();
        clone.cancel(CancellationReason::Shutdown).await;
        assert!(token.is_cancelled().await);
    }

    #[tokio::test]
    async fn test_cancellation_token_throw_if_cancelled() {
        let token = CancellationToken::new();
        assert!(token.throw_if_cancelled().await.is_ok());

        token.cancel(CancellationReason::Timeout).await;
        assert!(token.throw_if_cancelled().await.is_err());
    }

    #[tokio::test]
    async fn test_manager_register_request() {
        let manager = McpCancellationManager::new();
        let token = manager
            .register_request("req-1", "server-1", "tools/call", None)
            .await;

        assert!(!token.is_cancelled().await);
        assert!(manager.has_request("req-1").await);

        let request = manager.get_request("req-1").await.unwrap();
        assert_eq!(request.id, "req-1");
        assert_eq!(request.server_name, "server-1");
        assert_eq!(request.method, "tools/call");
    }

    #[tokio::test]
    async fn test_manager_unregister_request() {
        let manager = McpCancellationManager::new();
        manager
            .register_request("req-1", "server-1", "tools/call", None)
            .await;

        assert!(manager.unregister_request("req-1").await);
        assert!(!manager.has_request("req-1").await);
        assert!(!manager.unregister_request("req-1").await);
    }

    #[tokio::test]
    async fn test_manager_unregister_unknown_request() {
        let manager = McpCancellationManager::new();
        assert!(!manager.unregister_request("missing").await);
    }

    #[tokio::test]
    async fn test_manager_cancel_request() {
        let manager = McpCancellationManager::new();
        let token = manager
            .register_request("req-1", "server-1", "tools/call", None)
            .await;

        let result = manager
            .cancel_request("req-1", CancellationReason::UserCancelled)
            .await;

        assert!(result.is_some());
        let result = result.unwrap();
        assert!(result.success);
        assert_eq!(result.reason, CancellationReason::UserCancelled);
        assert_eq!(result.request_id, "req-1");
        assert_eq!(result.server_name, "server-1");
        assert!(token.is_cancelled().await);
        assert!(!manager.has_request("req-1").await);
    }

    #[tokio::test]
    async fn test_manager_cancel_unknown_request() {
        let manager = McpCancellationManager::new();
        assert!(manager
            .cancel_request("missing", CancellationReason::UserCancelled)
            .await
            .is_none());
    }

    #[tokio::test]
    async fn test_manager_emits_events() {
        let manager = McpCancellationManager::new();
        let mut events = manager.subscribe();

        manager
            .register_request("req-1", "server-1", "tools/call", None)
            .await;
        assert!(manager
            .cancel_request("req-1", CancellationReason::Timeout)
            .await
            .is_some());

        match events.recv().await.unwrap() {
            CancellationEvent::RequestRegistered {
                id,
                server_name,
                method,
            } => {
                assert_eq!(id, "req-1");
                assert_eq!(server_name, "server-1");
                assert_eq!(method, "tools/call");
            }
            other => panic!("unexpected event: {other:?}"),
        }
        match events.recv().await.unwrap() {
            CancellationEvent::RequestCancelled(result) => {
                assert_eq!(result.request_id, "req-1");
                assert_eq!(result.reason, CancellationReason::Timeout);
            }
            other => panic!("unexpected event: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_manager_cancel_server_requests() {
        let manager = McpCancellationManager::new();
        manager
            .register_request("req-1", "server-1", "tools/call", None)
            .await;
        manager
            .register_request("req-2", "server-1", "resources/read", None)
            .await;
        manager
            .register_request("req-3", "server-2", "tools/call", None)
            .await;

        assert_eq!(manager.get_server_requests("server-1").await.len(), 2);

        let results = manager
            .cancel_server_requests("server-1", CancellationReason::Shutdown)
            .await;

        assert_eq!(results.len(), 2);
        assert!(!manager.has_request("req-1").await);
        assert!(!manager.has_request("req-2").await);
        assert!(manager.has_request("req-3").await);
    }

    #[tokio::test]
    async fn test_manager_cancel_all() {
        let manager = McpCancellationManager::new();
        manager
            .register_request("req-1", "server-1", "tools/call", None)
            .await;
        manager
            .register_request("req-2", "server-2", "tools/call", None)
            .await;

        let results = manager.cancel_all(CancellationReason::Shutdown).await;

        assert_eq!(results.len(), 2);
        assert!(manager.get_all_requests().await.is_empty());
    }

    #[tokio::test]
    async fn test_manager_cancel_all_cancels_tokens() {
        let manager = McpCancellationManager::new();
        let first = manager
            .register_request("req-1", "server-1", "tools/call", None)
            .await;
        let second = manager
            .register_request("req-2", "server-2", "tools/call", None)
            .await;

        manager.cancel_all(CancellationReason::Shutdown).await;

        assert_eq!(first.reason().await, Some(CancellationReason::Shutdown));
        assert_eq!(second.reason().await, Some(CancellationReason::Shutdown));
    }

    #[tokio::test]
    async fn test_manager_get_stats() {
        let manager = McpCancellationManager::new();
        manager
            .register_request(
                "req-1",
                "server-1",
                "tools/call",
                Some(Duration::from_secs(30)),
            )
            .await;
        manager
            .register_request("req-2", "server-1", "resources/read", None)
            .await;
        manager
            .register_request("req-3", "server-2", "tools/call", None)
            .await;

        let stats = manager.get_stats().await;

        assert_eq!(stats.active_requests, 3);
        assert_eq!(stats.by_server.get("server-1"), Some(&2));
        assert_eq!(stats.by_server.get("server-2"), Some(&1));
        assert_eq!(stats.with_timeout, 1);
    }

    #[tokio::test]
    async fn test_manager_durations_and_long_running() {
        let manager = McpCancellationManager::new();
        manager
            .register_request("req-1", "server-1", "tools/call", None)
            .await;

        let durations = manager.get_request_durations().await;
        assert_eq!(durations.len(), 1);
        assert_eq!(durations[0].id, "req-1");

        assert!(manager
            .find_long_running_requests(Duration::from_secs(3600))
            .await
            .is_empty());
    }

    #[tokio::test]
    async fn test_manager_cleanup() {
        let manager = McpCancellationManager::new();
        manager
            .register_request("req-1", "server-1", "tools/call", None)
            .await;

        manager.cleanup().await;

        assert!(!manager.has_request("req-1").await);
        assert_eq!(manager.get_stats().await.active_requests, 0);
    }

    #[test]
    fn test_cancelled_notification() {
        let notification = CancelledNotification::new("req-1", Some("User cancelled".to_string()));
        assert_eq!(notification.request_id, "req-1");
        assert_eq!(notification.reason, Some("User cancelled".to_string()));
    }
}