Skip to main content

aster/mcp/
cancellation.rs

//! MCP Cancellation Module
//!
//! Implements request cancellation for MCP operations. Provides:
//! - Request tracking and cancellation
//! - Per-request timeout tracking
//! - A cancellation token pattern
//! - A module-local `CancellationToken` built on tokio primitives (`RwLock` and
//!   `broadcast`); it is not the `tokio_util` `CancellationToken`

9use std::collections::HashMap;
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12use tokio::sync::{broadcast, RwLock};
13
14use super::error::{McpError, McpResult};
15
/// Why a request (or a group of requests) was cancelled.
///
/// The `Display` form is a short, human-readable sentence.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CancellationReason {
    /// The user explicitly asked for the request to stop.
    UserCancelled,
    /// The request exceeded its allotted time.
    Timeout,
    /// The server asked for the request to be cancelled.
    ServerRequest,
    /// The system is shutting down.
    Shutdown,
    /// An error forced the request to be abandoned.
    Error,
}

impl std::fmt::Display for CancellationReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            CancellationReason::UserCancelled => "Request cancelled by user",
            CancellationReason::Timeout => "Request timed out",
            CancellationReason::ServerRequest => "Cancelled at server request",
            CancellationReason::Shutdown => "Cancelled due to shutdown",
            CancellationReason::Error => "Cancelled due to error",
        };
        f.write_str(message)
    }
}

/// Bookkeeping record for a request whose execution can be cancelled.
#[derive(Debug, Clone)]
pub struct CancellableRequest {
    /// Identifier of the request.
    pub id: String,
    /// Name of the MCP server the request targets.
    pub server_name: String,
    /// MCP method being invoked (e.g. `tools/call`).
    pub method: String,
    /// Instant at which tracking of the request began.
    pub start_time: Instant,
    /// Optional time budget for the request; `None` means no timeout was given.
    pub timeout: Option<Duration>,
}

58/// Cancellation result
59#[derive(Debug, Clone)]
60pub struct CancellationResult {
61    /// Whether cancellation was successful
62    pub success: bool,
63    /// Cancellation reason
64    pub reason: CancellationReason,
65    /// Request ID
66    pub request_id: String,
67    /// Server name
68    pub server_name: String,
69    /// Duration since request started
70    pub duration: Duration,
71}
72
73/// Cancellation token for request tracking
74///
75/// Provides a way to check if a request has been cancelled
76/// and to register callbacks for cancellation events.
77#[derive(Debug, Clone)]
78pub struct CancellationToken {
79    inner: Arc<RwLock<CancellationTokenInner>>,
80    sender: broadcast::Sender<CancellationReason>,
81}
82
83#[derive(Debug)]
84struct CancellationTokenInner {
85    cancelled: bool,
86    reason: Option<CancellationReason>,
87    timestamp: Option<Instant>,
88}
89
90impl CancellationToken {
91    /// Create a new cancellation token
92    pub fn new() -> Self {
93        let (sender, _) = broadcast::channel(16);
94        Self {
95            inner: Arc::new(RwLock::new(CancellationTokenInner {
96                cancelled: false,
97                reason: None,
98                timestamp: None,
99            })),
100            sender,
101        }
102    }
103
104    /// Check if cancellation has been requested
105    pub async fn is_cancelled(&self) -> bool {
106        self.inner.read().await.cancelled
107    }
108
109    /// Get cancellation reason
110    pub async fn reason(&self) -> Option<CancellationReason> {
111        self.inner.read().await.reason
112    }
113
114    /// Get cancellation timestamp
115    pub async fn timestamp(&self) -> Option<Instant> {
116        self.inner.read().await.timestamp
117    }
118
119    /// Request cancellation
120    pub async fn cancel(&self, reason: CancellationReason) {
121        let mut inner = self.inner.write().await;
122        if inner.cancelled {
123            return;
124        }
125
126        inner.cancelled = true;
127        inner.reason = Some(reason);
128        inner.timestamp = Some(Instant::now());
129
130        let _ = self.sender.send(reason);
131    }
132
133    /// Throw if cancelled
134    pub async fn throw_if_cancelled(&self) -> McpResult<()> {
135        let inner = self.inner.read().await;
136        if inner.cancelled {
137            let reason = inner.reason.unwrap_or(CancellationReason::UserCancelled);
138            return Err(McpError::cancelled(
139                reason.to_string(),
140                Some(reason.to_string()),
141            ));
142        }
143        Ok(())
144    }
145
146    /// Subscribe to cancellation events
147    pub fn subscribe(&self) -> broadcast::Receiver<CancellationReason> {
148        self.sender.subscribe()
149    }
150}
151
152impl Default for CancellationToken {
153    fn default() -> Self {
154        Self::new()
155    }
156}
157
158/// Cancellation event for broadcasting
159#[derive(Debug, Clone)]
160pub enum CancellationEvent {
161    /// Request registered
162    RequestRegistered {
163        id: String,
164        server_name: String,
165        method: String,
166    },
167    /// Request unregistered
168    RequestUnregistered { id: String, server_name: String },
169    /// Request cancelled
170    RequestCancelled(CancellationResult),
171    /// Server requests cancelled
172    ServerCancelled { server_name: String, count: usize },
173    /// All requests cancelled
174    AllCancelled { count: usize },
175}
176
177/// Manages request cancellation for MCP operations
178///
179/// Features:
180/// - Request registration and tracking
181/// - Manual and timeout-based cancellation
182/// - Cancellation notification
183/// - Event emission for monitoring
184pub struct McpCancellationManager {
185    requests: Arc<RwLock<HashMap<String, CancellableRequest>>>,
186    tokens: Arc<RwLock<HashMap<String, CancellationToken>>>,
187    event_sender: broadcast::Sender<CancellationEvent>,
188}
189
190impl McpCancellationManager {
191    /// Create a new cancellation manager
192    pub fn new() -> Self {
193        let (event_sender, _) = broadcast::channel(256);
194        Self {
195            requests: Arc::new(RwLock::new(HashMap::new())),
196            tokens: Arc::new(RwLock::new(HashMap::new())),
197            event_sender,
198        }
199    }
200
201    /// Subscribe to cancellation events
202    pub fn subscribe(&self) -> broadcast::Receiver<CancellationEvent> {
203        self.event_sender.subscribe()
204    }
205
206    /// Register a cancellable request
207    pub async fn register_request(
208        &self,
209        id: impl Into<String>,
210        server_name: impl Into<String>,
211        method: impl Into<String>,
212        timeout: Option<Duration>,
213    ) -> CancellationToken {
214        let id = id.into();
215        let server_name = server_name.into();
216        let method = method.into();
217
218        let request = CancellableRequest {
219            id: id.clone(),
220            server_name: server_name.clone(),
221            method: method.clone(),
222            start_time: Instant::now(),
223            timeout,
224        };
225
226        let token = CancellationToken::new();
227
228        self.requests.write().await.insert(id.clone(), request);
229        self.tokens.write().await.insert(id.clone(), token.clone());
230
231        let _ = self
232            .event_sender
233            .send(CancellationEvent::RequestRegistered {
234                id,
235                server_name,
236                method,
237            });
238
239        token
240    }
241
242    /// Unregister a request (called when completed successfully)
243    pub async fn unregister_request(&self, id: &str) -> bool {
244        let request = self.requests.write().await.remove(id);
245        self.tokens.write().await.remove(id);
246
247        if let Some(req) = request {
248            let _ = self
249                .event_sender
250                .send(CancellationEvent::RequestUnregistered {
251                    id: id.to_string(),
252                    server_name: req.server_name,
253                });
254            true
255        } else {
256            false
257        }
258    }
259
260    /// Check if a request is registered
261    pub async fn has_request(&self, id: &str) -> bool {
262        self.requests.read().await.contains_key(id)
263    }
264
265    /// Get a registered request
266    pub async fn get_request(&self, id: &str) -> Option<CancellableRequest> {
267        self.requests.read().await.get(id).cloned()
268    }
269
270    /// Get all registered requests
271    pub async fn get_all_requests(&self) -> Vec<CancellableRequest> {
272        self.requests.read().await.values().cloned().collect()
273    }
274
275    /// Get requests for a specific server
276    pub async fn get_server_requests(&self, server_name: &str) -> Vec<CancellableRequest> {
277        self.requests
278            .read()
279            .await
280            .values()
281            .filter(|r| r.server_name == server_name)
282            .cloned()
283            .collect()
284    }
285
286    /// Cancel a request
287    pub async fn cancel_request(
288        &self,
289        id: &str,
290        reason: CancellationReason,
291    ) -> Option<CancellationResult> {
292        let request = self.requests.write().await.remove(id)?;
293        let token = self.tokens.write().await.remove(id);
294
295        // Cancel the token
296        if let Some(t) = token {
297            t.cancel(reason).await;
298        }
299
300        let duration = request.start_time.elapsed();
301        let result = CancellationResult {
302            success: true,
303            reason,
304            request_id: id.to_string(),
305            server_name: request.server_name,
306            duration,
307        };
308
309        let _ = self
310            .event_sender
311            .send(CancellationEvent::RequestCancelled(result.clone()));
312
313        Some(result)
314    }
315
316    /// Cancel all requests for a server
317    pub async fn cancel_server_requests(
318        &self,
319        server_name: &str,
320        reason: CancellationReason,
321    ) -> Vec<CancellationResult> {
322        let requests = self.get_server_requests(server_name).await;
323        let mut results = Vec::new();
324
325        for request in requests {
326            if let Some(result) = self.cancel_request(&request.id, reason).await {
327                results.push(result);
328            }
329        }
330
331        let _ = self.event_sender.send(CancellationEvent::ServerCancelled {
332            server_name: server_name.to_string(),
333            count: results.len(),
334        });
335
336        results
337    }
338
339    /// Cancel all requests
340    pub async fn cancel_all(&self, reason: CancellationReason) -> Vec<CancellationResult> {
341        let requests = self.get_all_requests().await;
342        let mut results = Vec::new();
343
344        for request in requests {
345            if let Some(result) = self.cancel_request(&request.id, reason).await {
346                results.push(result);
347            }
348        }
349
350        let _ = self.event_sender.send(CancellationEvent::AllCancelled {
351            count: results.len(),
352        });
353
354        results
355    }
356
357    /// Get statistics about cancellations
358    pub async fn get_stats(&self) -> CancellationStats {
359        let requests = self.get_all_requests().await;
360
361        let mut by_server: HashMap<String, usize> = HashMap::new();
362        let mut with_timeout = 0;
363
364        for request in &requests {
365            *by_server.entry(request.server_name.clone()).or_insert(0) += 1;
366            if request.timeout.is_some() {
367                with_timeout += 1;
368            }
369        }
370
371        CancellationStats {
372            active_requests: requests.len(),
373            by_server,
374            with_timeout,
375        }
376    }
377
378    /// Get request durations
379    pub async fn get_request_durations(&self) -> Vec<RequestDuration> {
380        self.requests
381            .read()
382            .await
383            .values()
384            .map(|r| RequestDuration {
385                id: r.id.clone(),
386                server_name: r.server_name.clone(),
387                method: r.method.clone(),
388                duration: r.start_time.elapsed(),
389            })
390            .collect()
391    }
392
393    /// Find requests exceeding a duration threshold
394    pub async fn find_long_running_requests(&self, threshold: Duration) -> Vec<CancellableRequest> {
395        self.requests
396            .read()
397            .await
398            .values()
399            .filter(|r| r.start_time.elapsed() > threshold)
400            .cloned()
401            .collect()
402    }
403
404    /// Clean up all requests
405    pub async fn cleanup(&self) {
406        self.requests.write().await.clear();
407        self.tokens.write().await.clear();
408    }
409}
410
411impl Default for McpCancellationManager {
412    fn default() -> Self {
413        Self::new()
414    }
415}
416
/// Snapshot of the requests currently tracked for cancellation.
#[derive(Debug, Clone)]
pub struct CancellationStats {
    /// Total number of tracked requests.
    pub active_requests: usize,
    /// Tracked request count per server name.
    pub by_server: HashMap<String, usize>,
    /// How many tracked requests were registered with a timeout.
    pub with_timeout: usize,
}

/// How long one tracked request has been running.
#[derive(Debug, Clone)]
pub struct RequestDuration {
    /// Identifier of the request.
    pub id: String,
    /// Server handling the request.
    pub server_name: String,
    /// MCP method of the request.
    pub method: String,
    /// Time elapsed since the request started.
    pub duration: Duration,
}

441/// Cancelled notification for MCP protocol
442#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
443pub struct CancelledNotification {
444    /// Request ID that was cancelled
445    pub request_id: String,
446    /// Optional reason for cancellation
447    #[serde(skip_serializing_if = "Option::is_none")]
448    pub reason: Option<String>,
449}
450
451impl CancelledNotification {
452    /// Create a new cancelled notification
453    pub fn new(request_id: impl Into<String>, reason: Option<String>) -> Self {
454        Self {
455            request_id: request_id.into(),
456            reason,
457        }
458    }
459}
460
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers each `(id, server, method)` triple with no timeout.
    async fn register_untimed(manager: &McpCancellationManager, entries: &[(&str, &str, &str)]) {
        for &(id, server, method) in entries {
            manager.register_request(id, server, method, None).await;
        }
    }

    #[test]
    fn test_cancellation_reason_display() {
        let cases = [
            (CancellationReason::UserCancelled, "Request cancelled by user"),
            (CancellationReason::Timeout, "Request timed out"),
            (CancellationReason::Shutdown, "Cancelled due to shutdown"),
        ];
        for (reason, expected) in cases.iter() {
            assert_eq!(reason.to_string(), *expected);
        }
    }

    #[tokio::test]
    async fn test_cancellation_token_new() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled().await);
        assert_eq!(token.reason().await, None);
    }

    #[tokio::test]
    async fn test_cancellation_token_cancel() {
        let token = CancellationToken::new();
        token.cancel(CancellationReason::UserCancelled).await;

        assert!(token.is_cancelled().await);
        assert_eq!(
            token.reason().await,
            Some(CancellationReason::UserCancelled)
        );
    }

    #[tokio::test]
    async fn test_cancellation_token_throw_if_cancelled() {
        let token = CancellationToken::new();
        let before = token.throw_if_cancelled().await;
        assert!(before.is_ok());

        token.cancel(CancellationReason::Timeout).await;
        let after = token.throw_if_cancelled().await;
        assert!(after.is_err());
    }

    #[tokio::test]
    async fn test_manager_register_request() {
        let manager = McpCancellationManager::new();
        let token = manager
            .register_request("req-1", "server-1", "tools/call", None)
            .await;

        assert!(!token.is_cancelled().await);
        assert!(manager.has_request("req-1").await);
    }

    #[tokio::test]
    async fn test_manager_unregister_request() {
        let manager = McpCancellationManager::new();
        register_untimed(&manager, &[("req-1", "server-1", "tools/call")]).await;

        assert!(manager.unregister_request("req-1").await);
        assert!(!manager.has_request("req-1").await);
    }

    #[tokio::test]
    async fn test_manager_cancel_request() {
        let manager = McpCancellationManager::new();
        let token = manager
            .register_request("req-1", "server-1", "tools/call", None)
            .await;

        let outcome = manager
            .cancel_request("req-1", CancellationReason::UserCancelled)
            .await
            .expect("req-1 should still be registered");

        assert!(outcome.success);
        assert_eq!(outcome.reason, CancellationReason::UserCancelled);
        assert!(token.is_cancelled().await);
    }

    #[tokio::test]
    async fn test_manager_cancel_server_requests() {
        let manager = McpCancellationManager::new();
        register_untimed(
            &manager,
            &[
                ("req-1", "server-1", "tools/call"),
                ("req-2", "server-1", "resources/read"),
                ("req-3", "server-2", "tools/call"),
            ],
        )
        .await;

        let outcomes = manager
            .cancel_server_requests("server-1", CancellationReason::Shutdown)
            .await;

        assert_eq!(outcomes.len(), 2);
        for gone in ["req-1", "req-2"].iter() {
            assert!(!manager.has_request(gone).await);
        }
        assert!(manager.has_request("req-3").await);
    }

    #[tokio::test]
    async fn test_manager_cancel_all() {
        let manager = McpCancellationManager::new();
        register_untimed(
            &manager,
            &[
                ("req-1", "server-1", "tools/call"),
                ("req-2", "server-2", "tools/call"),
            ],
        )
        .await;

        let outcomes = manager.cancel_all(CancellationReason::Shutdown).await;

        assert_eq!(outcomes.len(), 2);
        assert!(manager.get_all_requests().await.is_empty());
    }

    #[tokio::test]
    async fn test_manager_get_stats() {
        let manager = McpCancellationManager::new();
        manager
            .register_request(
                "req-1",
                "server-1",
                "tools/call",
                Some(Duration::from_secs(30)),
            )
            .await;
        register_untimed(
            &manager,
            &[
                ("req-2", "server-1", "resources/read"),
                ("req-3", "server-2", "tools/call"),
            ],
        )
        .await;

        let stats = manager.get_stats().await;

        assert_eq!(stats.active_requests, 3);
        assert_eq!(stats.by_server.get("server-1"), Some(&2));
        assert_eq!(stats.by_server.get("server-2"), Some(&1));
        assert_eq!(stats.with_timeout, 1);
    }

    #[test]
    fn test_cancelled_notification() {
        let notification = CancelledNotification::new("req-1", Some("User cancelled".to_string()));
        assert_eq!(notification.request_id, "req-1");
        assert_eq!(notification.reason, Some("User cancelled".to_string()));
    }
}