Skip to main content

aster/mcp/
notifications.rs

1//! MCP Notifications Module
2//!
3//! Handles notification messages from MCP servers. Notifications are one-way
4//! messages that don't require a response, used for:
5//! - Progress updates
6//! - Resource/tool/prompt list changes
7//! - Request cancellations
8//! - Custom server events
9
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Instant;
15use tokio::sync::{broadcast, RwLock};
16
17/// Notification types
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
19#[serde(rename_all = "snake_case")]
20pub enum NotificationType {
21    /// Progress update
22    Progress,
23    /// Request cancelled
24    Cancelled,
25    /// Resources list changed
26    ResourcesListChanged,
27    /// Resources updated
28    ResourcesUpdated,
29    /// Tools list changed
30    ToolsListChanged,
31    /// Prompts list changed
32    PromptsListChanged,
33    /// Roots list changed
34    RootsListChanged,
35    /// Custom notification
36    Custom,
37}
38
39impl std::fmt::Display for NotificationType {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            Self::Progress => write!(f, "progress"),
43            Self::Cancelled => write!(f, "cancelled"),
44            Self::ResourcesListChanged => write!(f, "resources/list_changed"),
45            Self::ResourcesUpdated => write!(f, "resources/updated"),
46            Self::ToolsListChanged => write!(f, "tools/list_changed"),
47            Self::PromptsListChanged => write!(f, "prompts/list_changed"),
48            Self::RootsListChanged => write!(f, "roots/list_changed"),
49            Self::Custom => write!(f, "custom"),
50        }
51    }
52}
53
54/// Base notification
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct Notification {
57    /// Notification type
58    pub notification_type: NotificationType,
59    /// Server name
60    pub server_name: String,
61    /// Timestamp
62    pub timestamp: DateTime<Utc>,
63    /// Method name
64    pub method: String,
65    /// Optional parameters
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub params: Option<serde_json::Value>,
68}
69
70/// Progress notification parameters
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct ProgressNotification {
73    /// Progress token
74    pub progress_token: String,
75    /// Current progress value
76    pub progress: u64,
77    /// Total value (if known)
78    #[serde(skip_serializing_if = "Option::is_none")]
79    pub total: Option<u64>,
80}
81
/// Tracking record for one in-flight progress operation.
#[derive(Clone, Debug)]
pub struct ProgressState {
    /// Server reporting the progress.
    pub server_name: String,
    /// Progress token supplied by the server.
    pub token: String,
    /// Most recently reported progress value.
    pub progress: u64,
    /// Most recently reported total, when the server supplied one.
    pub total: Option<u64>,
    /// When the first update for this token was seen.
    pub start_time: Instant,
    /// When the latest update for this token was seen.
    pub last_update: Instant,
}
98
99/// Notification event for broadcasting
100#[derive(Debug, Clone)]
101pub enum NotificationEvent {
102    /// General notification received
103    Notification(Notification),
104    /// Progress update
105    Progress {
106        server_name: String,
107        token: String,
108        progress: u64,
109        total: Option<u64>,
110    },
111    /// Progress completed
112    ProgressComplete { server_name: String, token: String },
113    /// Request cancelled
114    Cancelled {
115        server_name: String,
116        request_id: String,
117        reason: Option<String>,
118    },
119    /// List changed
120    ListChanged {
121        server_name: String,
122        list_type: NotificationType,
123    },
124    /// Resource updated
125    ResourceUpdated { server_name: String, uri: String },
126    /// History cleared
127    HistoryCleared { count: usize },
128}
129
130/// Manages notifications from MCP servers
131pub struct McpNotificationManager {
132    history: Arc<RwLock<Vec<Notification>>>,
133    progress_states: Arc<RwLock<HashMap<String, ProgressState>>>,
134    max_history_size: usize,
135    event_sender: broadcast::Sender<NotificationEvent>,
136}
137
138impl McpNotificationManager {
139    /// Create a new notification manager
140    pub fn new(max_history_size: usize) -> Self {
141        let (event_sender, _) = broadcast::channel(256);
142        Self {
143            history: Arc::new(RwLock::new(Vec::new())),
144            progress_states: Arc::new(RwLock::new(HashMap::new())),
145            max_history_size,
146            event_sender,
147        }
148    }
149
150    /// Subscribe to notification events
151    pub fn subscribe(&self) -> broadcast::Receiver<NotificationEvent> {
152        self.event_sender.subscribe()
153    }
154
155    /// Handle a notification from a server
156    pub async fn handle_notification(
157        &self,
158        server_name: &str,
159        method: &str,
160        params: Option<serde_json::Value>,
161    ) {
162        let notification_type = Self::get_notification_type(method);
163
164        let notification = Notification {
165            notification_type,
166            server_name: server_name.to_string(),
167            timestamp: Utc::now(),
168            method: method.to_string(),
169            params: params.clone(),
170        };
171
172        // Add to history
173        self.add_to_history(notification.clone()).await;
174
175        // Emit general event
176        let _ = self
177            .event_sender
178            .send(NotificationEvent::Notification(notification.clone()));
179
180        // Handle specific types
181        self.handle_specific_type(server_name, notification_type, params)
182            .await;
183    }
184
185    /// Get notification type from method name
186    fn get_notification_type(method: &str) -> NotificationType {
187        match method {
188            "notifications/progress" => NotificationType::Progress,
189            "notifications/cancelled" => NotificationType::Cancelled,
190            "notifications/resources/list_changed" => NotificationType::ResourcesListChanged,
191            "notifications/resources/updated" => NotificationType::ResourcesUpdated,
192            "notifications/tools/list_changed" => NotificationType::ToolsListChanged,
193            "notifications/prompts/list_changed" => NotificationType::PromptsListChanged,
194            m if m.contains("roots/list_changed") => NotificationType::RootsListChanged,
195            _ => NotificationType::Custom,
196        }
197    }
198
199    /// Handle specific notification types
200    async fn handle_specific_type(
201        &self,
202        server_name: &str,
203        notification_type: NotificationType,
204        params: Option<serde_json::Value>,
205    ) {
206        match notification_type {
207            NotificationType::Progress => {
208                if let Some(params) = params {
209                    self.handle_progress(server_name, params).await;
210                }
211            }
212            NotificationType::Cancelled => {
213                if let Some(params) = params {
214                    self.handle_cancelled(server_name, params).await;
215                }
216            }
217            NotificationType::ResourcesListChanged
218            | NotificationType::ToolsListChanged
219            | NotificationType::PromptsListChanged
220            | NotificationType::RootsListChanged => {
221                let _ = self.event_sender.send(NotificationEvent::ListChanged {
222                    server_name: server_name.to_string(),
223                    list_type: notification_type,
224                });
225            }
226            NotificationType::ResourcesUpdated => {
227                if let Some(params) = params {
228                    if let Some(uri) = params.get("uri").and_then(|v| v.as_str()) {
229                        let _ = self.event_sender.send(NotificationEvent::ResourceUpdated {
230                            server_name: server_name.to_string(),
231                            uri: uri.to_string(),
232                        });
233                    }
234                }
235            }
236            NotificationType::Custom => {}
237        }
238    }
239
240    /// Handle progress notification
241    async fn handle_progress(&self, server_name: &str, params: serde_json::Value) {
242        let progress_token = params
243            .get("progressToken")
244            .and_then(|v| v.as_str())
245            .unwrap_or("unknown")
246            .to_string();
247        let progress = params.get("progress").and_then(|v| v.as_u64()).unwrap_or(0);
248        let total = params.get("total").and_then(|v| v.as_u64());
249
250        let key = format!("{}:{}", server_name, progress_token);
251        let now = Instant::now();
252
253        let mut states = self.progress_states.write().await;
254        let start_time = states.get(&key).map(|e| e.start_time).unwrap_or(now);
255
256        states.insert(
257            key.clone(),
258            ProgressState {
259                server_name: server_name.to_string(),
260                token: progress_token.clone(),
261                progress,
262                total,
263                start_time,
264                last_update: now,
265            },
266        );
267
268        let _ = self.event_sender.send(NotificationEvent::Progress {
269            server_name: server_name.to_string(),
270            token: progress_token.clone(),
271            progress,
272            total,
273        });
274
275        // Check if complete
276        let is_complete = total.map(|t| progress >= t).unwrap_or(false) || progress == 100;
277        if is_complete {
278            states.remove(&key);
279            let _ = self.event_sender.send(NotificationEvent::ProgressComplete {
280                server_name: server_name.to_string(),
281                token: progress_token,
282            });
283        }
284    }
285
286    /// Handle cancelled notification
287    async fn handle_cancelled(&self, server_name: &str, params: serde_json::Value) {
288        let request_id = params
289            .get("requestId")
290            .and_then(|v| v.as_str())
291            .unwrap_or("unknown")
292            .to_string();
293        let reason = params
294            .get("reason")
295            .and_then(|v| v.as_str())
296            .map(String::from);
297
298        let _ = self.event_sender.send(NotificationEvent::Cancelled {
299            server_name: server_name.to_string(),
300            request_id,
301            reason,
302        });
303    }
304
305    /// Add notification to history
306    async fn add_to_history(&self, notification: Notification) {
307        let mut history = self.history.write().await;
308        history.push(notification);
309
310        if history.len() > self.max_history_size {
311            history.remove(0);
312        }
313    }
314
315    /// Get notification history
316    pub async fn get_history(&self, filter: Option<NotificationFilter>) -> Vec<Notification> {
317        let history = self.history.read().await;
318        let mut filtered: Vec<_> = history.iter().cloned().collect();
319
320        if let Some(f) = filter {
321            if let Some(server_name) = f.server_name {
322                filtered.retain(|n| n.server_name == server_name);
323            }
324            if let Some(notification_type) = f.notification_type {
325                filtered.retain(|n| n.notification_type == notification_type);
326            }
327            if let Some(since) = f.since {
328                filtered.retain(|n| n.timestamp >= since);
329            }
330            if let Some(limit) = f.limit {
331                let len = filtered.len();
332                if len > limit {
333                    filtered = filtered.into_iter().skip(len - limit).collect();
334                }
335            }
336        }
337
338        filtered
339    }
340
341    /// Clear history
342    pub async fn clear_history(&self) {
343        let mut history = self.history.write().await;
344        let count = history.len();
345        history.clear();
346        let _ = self
347            .event_sender
348            .send(NotificationEvent::HistoryCleared { count });
349    }
350
351    /// Clear history for a specific server
352    pub async fn clear_server_history(&self, server_name: &str) -> usize {
353        let mut history = self.history.write().await;
354        let before = history.len();
355        history.retain(|n| n.server_name != server_name);
356        before - history.len()
357    }
358
359    /// Get active progress operations
360    pub async fn get_active_progress(&self) -> Vec<ProgressState> {
361        self.progress_states
362            .read()
363            .await
364            .values()
365            .cloned()
366            .collect()
367    }
368
369    /// Get progress for a specific server
370    pub async fn get_server_progress(&self, server_name: &str) -> Vec<ProgressState> {
371        self.progress_states
372            .read()
373            .await
374            .values()
375            .filter(|p| p.server_name == server_name)
376            .cloned()
377            .collect()
378    }
379
380    /// Cancel progress tracking for a token
381    pub async fn cancel_progress(&self, server_name: &str, token: &str) -> bool {
382        let key = format!("{}:{}", server_name, token);
383        self.progress_states.write().await.remove(&key).is_some()
384    }
385
386    /// Clear all progress tracking
387    pub async fn clear_progress(&self) {
388        self.progress_states.write().await.clear();
389    }
390
391    /// Get statistics
392    pub async fn get_stats(&self) -> NotificationStats {
393        let history = self.history.read().await;
394
395        let mut by_type: HashMap<NotificationType, usize> = HashMap::new();
396        let mut by_server: HashMap<String, usize> = HashMap::new();
397
398        for notification in history.iter() {
399            *by_type.entry(notification.notification_type).or_insert(0) += 1;
400            *by_server
401                .entry(notification.server_name.clone())
402                .or_insert(0) += 1;
403        }
404
405        NotificationStats {
406            total_notifications: history.len(),
407            max_history_size: self.max_history_size,
408            active_progress: self.progress_states.read().await.len(),
409            by_type,
410            by_server,
411        }
412    }
413}
414
415impl Default for McpNotificationManager {
416    fn default() -> Self {
417        Self::new(100)
418    }
419}
420
421/// Filter for notification history
422#[derive(Debug, Clone, Default)]
423pub struct NotificationFilter {
424    /// Filter by server name
425    pub server_name: Option<String>,
426    /// Filter by notification type
427    pub notification_type: Option<NotificationType>,
428    /// Filter by timestamp (since)
429    pub since: Option<DateTime<Utc>>,
430    /// Limit number of results
431    pub limit: Option<usize>,
432}
433
434/// Notification statistics
435#[derive(Debug, Clone)]
436pub struct NotificationStats {
437    /// Total notifications in history
438    pub total_notifications: usize,
439    /// Maximum history size
440    pub max_history_size: usize,
441    /// Active progress operations
442    pub active_progress: usize,
443    /// Notifications by type
444    pub by_type: HashMap<NotificationType, usize>,
445    /// Notifications by server
446    pub by_server: HashMap<String, usize>,
447}
448
449/// Create progress notification parameters
450pub fn create_progress_params(
451    token: &str,
452    progress: u64,
453    total: Option<u64>,
454) -> ProgressNotification {
455    ProgressNotification {
456        progress_token: token.to_string(),
457        progress,
458        total,
459    }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    /// Every test starts from a manager whose history is bounded at 100.
    fn fresh_manager() -> McpNotificationManager {
        McpNotificationManager::new(100)
    }

    /// Progress params for the single token these tests track.
    fn progress_params(progress: u64, total: u64) -> serde_json::Value {
        serde_json::json!({
            "progressToken": "token-1",
            "progress": progress,
            "total": total
        })
    }

    /// Feed parameterless notifications into `manager`, in order.
    async fn send_all(manager: &McpNotificationManager, batch: &[(&str, &str)]) {
        for (server, method) in batch {
            manager.handle_notification(server, method, None).await;
        }
    }

    #[test]
    fn test_notification_type_display() {
        let expectations = [
            (NotificationType::Progress, "progress"),
            (NotificationType::ToolsListChanged, "tools/list_changed"),
        ];
        for (kind, label) in expectations.iter() {
            assert_eq!(kind.to_string(), *label);
        }
    }

    #[test]
    fn test_get_notification_type() {
        let expectations = [
            ("notifications/progress", NotificationType::Progress),
            (
                "notifications/tools/list_changed",
                NotificationType::ToolsListChanged,
            ),
            ("custom/event", NotificationType::Custom),
        ];
        for (method, expected) in expectations.iter() {
            let actual = McpNotificationManager::get_notification_type(method);
            assert_eq!(actual, *expected);
        }
    }

    #[tokio::test]
    async fn test_handle_notification() {
        let manager = fresh_manager();
        send_all(
            &manager,
            &[("test-server", "notifications/tools/list_changed")],
        )
        .await;

        let history = manager.get_history(None).await;
        assert_eq!(history.len(), 1);
        let entry = &history[0];
        assert_eq!(entry.server_name, "test-server");
        assert_eq!(entry.notification_type, NotificationType::ToolsListChanged);
    }

    #[tokio::test]
    async fn test_handle_progress() {
        let manager = fresh_manager();
        manager
            .handle_notification(
                "test-server",
                "notifications/progress",
                Some(progress_params(50, 100)),
            )
            .await;

        let active = manager.get_active_progress().await;
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].progress, 50);
        assert_eq!(active[0].total, Some(100));
    }

    #[tokio::test]
    async fn test_progress_complete() {
        let manager = fresh_manager();
        manager
            .handle_notification(
                "test-server",
                "notifications/progress",
                Some(progress_params(100, 100)),
            )
            .await;

        // Reaching the total removes the operation from the active set.
        assert!(manager.get_active_progress().await.is_empty());
    }

    #[tokio::test]
    async fn test_history_filter() {
        let manager = fresh_manager();
        send_all(
            &manager,
            &[
                ("server-1", "notifications/progress"),
                ("server-2", "notifications/tools/list_changed"),
                ("server-1", "notifications/cancelled"),
            ],
        )
        .await;

        let only_server_1 = NotificationFilter {
            server_name: Some("server-1".to_string()),
            ..Default::default()
        };
        let history = manager.get_history(Some(only_server_1)).await;
        assert_eq!(history.len(), 2);
    }

    #[tokio::test]
    async fn test_clear_history() {
        let manager = fresh_manager();
        send_all(
            &manager,
            &[
                ("test-server", "notifications/progress"),
                ("test-server", "notifications/cancelled"),
            ],
        )
        .await;

        manager.clear_history().await;

        assert!(manager.get_history(None).await.is_empty());
    }

    #[tokio::test]
    async fn test_get_stats() {
        let manager = fresh_manager();
        send_all(
            &manager,
            &[
                ("server-1", "notifications/progress"),
                ("server-1", "notifications/progress"),
                ("server-2", "notifications/tools/list_changed"),
            ],
        )
        .await;

        let stats = manager.get_stats().await;
        assert_eq!(stats.total_notifications, 3);
        assert_eq!(stats.by_server.get("server-1"), Some(&2));
        assert_eq!(stats.by_server.get("server-2"), Some(&1));
    }

    #[test]
    fn test_create_progress_params() {
        let built = create_progress_params("token-1", 50, Some(100));
        assert_eq!(built.progress_token, "token-1");
        assert_eq!(built.progress, 50);
        assert_eq!(built.total, Some(100));
    }
}