1use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::Instant;
15use tokio::sync::{broadcast, RwLock};
16
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
19#[serde(rename_all = "snake_case")]
20pub enum NotificationType {
21 Progress,
23 Cancelled,
25 ResourcesListChanged,
27 ResourcesUpdated,
29 ToolsListChanged,
31 PromptsListChanged,
33 RootsListChanged,
35 Custom,
37}
38
39impl std::fmt::Display for NotificationType {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 Self::Progress => write!(f, "progress"),
43 Self::Cancelled => write!(f, "cancelled"),
44 Self::ResourcesListChanged => write!(f, "resources/list_changed"),
45 Self::ResourcesUpdated => write!(f, "resources/updated"),
46 Self::ToolsListChanged => write!(f, "tools/list_changed"),
47 Self::PromptsListChanged => write!(f, "prompts/list_changed"),
48 Self::RootsListChanged => write!(f, "roots/list_changed"),
49 Self::Custom => write!(f, "custom"),
50 }
51 }
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct Notification {
57 pub notification_type: NotificationType,
59 pub server_name: String,
61 pub timestamp: DateTime<Utc>,
63 pub method: String,
65 #[serde(skip_serializing_if = "Option::is_none")]
67 pub params: Option<serde_json::Value>,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct ProgressNotification {
73 pub progress_token: String,
75 pub progress: u64,
77 #[serde(skip_serializing_if = "Option::is_none")]
79 pub total: Option<u64>,
80}
81
/// Snapshot of an operation whose progress is currently being tracked.
#[derive(Debug, Clone)]
pub struct ProgressState {
    /// Server that reported the progress.
    pub server_name: String,
    /// Progress token identifying the operation on that server.
    pub token: String,
    /// Most recently reported progress value.
    pub progress: u64,
    /// Most recently reported total, if one was given.
    pub total: Option<u64>,
    /// Instant of the first update seen for this operation while tracked;
    /// carried over unchanged by later updates.
    pub start_time: Instant,
    /// Instant of the most recent update.
    pub last_update: Instant,
}
98
99#[derive(Debug, Clone)]
101pub enum NotificationEvent {
102 Notification(Notification),
104 Progress {
106 server_name: String,
107 token: String,
108 progress: u64,
109 total: Option<u64>,
110 },
111 ProgressComplete { server_name: String, token: String },
113 Cancelled {
115 server_name: String,
116 request_id: String,
117 reason: Option<String>,
118 },
119 ListChanged {
121 server_name: String,
122 list_type: NotificationType,
123 },
124 ResourceUpdated { server_name: String, uri: String },
126 HistoryCleared { count: usize },
128}
129
130pub struct McpNotificationManager {
132 history: Arc<RwLock<Vec<Notification>>>,
133 progress_states: Arc<RwLock<HashMap<String, ProgressState>>>,
134 max_history_size: usize,
135 event_sender: broadcast::Sender<NotificationEvent>,
136}
137
138impl McpNotificationManager {
139 pub fn new(max_history_size: usize) -> Self {
141 let (event_sender, _) = broadcast::channel(256);
142 Self {
143 history: Arc::new(RwLock::new(Vec::new())),
144 progress_states: Arc::new(RwLock::new(HashMap::new())),
145 max_history_size,
146 event_sender,
147 }
148 }
149
150 pub fn subscribe(&self) -> broadcast::Receiver<NotificationEvent> {
152 self.event_sender.subscribe()
153 }
154
155 pub async fn handle_notification(
157 &self,
158 server_name: &str,
159 method: &str,
160 params: Option<serde_json::Value>,
161 ) {
162 let notification_type = Self::get_notification_type(method);
163
164 let notification = Notification {
165 notification_type,
166 server_name: server_name.to_string(),
167 timestamp: Utc::now(),
168 method: method.to_string(),
169 params: params.clone(),
170 };
171
172 self.add_to_history(notification.clone()).await;
174
175 let _ = self
177 .event_sender
178 .send(NotificationEvent::Notification(notification.clone()));
179
180 self.handle_specific_type(server_name, notification_type, params)
182 .await;
183 }
184
185 fn get_notification_type(method: &str) -> NotificationType {
187 match method {
188 "notifications/progress" => NotificationType::Progress,
189 "notifications/cancelled" => NotificationType::Cancelled,
190 "notifications/resources/list_changed" => NotificationType::ResourcesListChanged,
191 "notifications/resources/updated" => NotificationType::ResourcesUpdated,
192 "notifications/tools/list_changed" => NotificationType::ToolsListChanged,
193 "notifications/prompts/list_changed" => NotificationType::PromptsListChanged,
194 m if m.contains("roots/list_changed") => NotificationType::RootsListChanged,
195 _ => NotificationType::Custom,
196 }
197 }
198
199 async fn handle_specific_type(
201 &self,
202 server_name: &str,
203 notification_type: NotificationType,
204 params: Option<serde_json::Value>,
205 ) {
206 match notification_type {
207 NotificationType::Progress => {
208 if let Some(params) = params {
209 self.handle_progress(server_name, params).await;
210 }
211 }
212 NotificationType::Cancelled => {
213 if let Some(params) = params {
214 self.handle_cancelled(server_name, params).await;
215 }
216 }
217 NotificationType::ResourcesListChanged
218 | NotificationType::ToolsListChanged
219 | NotificationType::PromptsListChanged
220 | NotificationType::RootsListChanged => {
221 let _ = self.event_sender.send(NotificationEvent::ListChanged {
222 server_name: server_name.to_string(),
223 list_type: notification_type,
224 });
225 }
226 NotificationType::ResourcesUpdated => {
227 if let Some(params) = params {
228 if let Some(uri) = params.get("uri").and_then(|v| v.as_str()) {
229 let _ = self.event_sender.send(NotificationEvent::ResourceUpdated {
230 server_name: server_name.to_string(),
231 uri: uri.to_string(),
232 });
233 }
234 }
235 }
236 NotificationType::Custom => {}
237 }
238 }
239
240 async fn handle_progress(&self, server_name: &str, params: serde_json::Value) {
242 let progress_token = params
243 .get("progressToken")
244 .and_then(|v| v.as_str())
245 .unwrap_or("unknown")
246 .to_string();
247 let progress = params.get("progress").and_then(|v| v.as_u64()).unwrap_or(0);
248 let total = params.get("total").and_then(|v| v.as_u64());
249
250 let key = format!("{}:{}", server_name, progress_token);
251 let now = Instant::now();
252
253 let mut states = self.progress_states.write().await;
254 let start_time = states.get(&key).map(|e| e.start_time).unwrap_or(now);
255
256 states.insert(
257 key.clone(),
258 ProgressState {
259 server_name: server_name.to_string(),
260 token: progress_token.clone(),
261 progress,
262 total,
263 start_time,
264 last_update: now,
265 },
266 );
267
268 let _ = self.event_sender.send(NotificationEvent::Progress {
269 server_name: server_name.to_string(),
270 token: progress_token.clone(),
271 progress,
272 total,
273 });
274
275 let is_complete = total.map(|t| progress >= t).unwrap_or(false) || progress == 100;
277 if is_complete {
278 states.remove(&key);
279 let _ = self.event_sender.send(NotificationEvent::ProgressComplete {
280 server_name: server_name.to_string(),
281 token: progress_token,
282 });
283 }
284 }
285
286 async fn handle_cancelled(&self, server_name: &str, params: serde_json::Value) {
288 let request_id = params
289 .get("requestId")
290 .and_then(|v| v.as_str())
291 .unwrap_or("unknown")
292 .to_string();
293 let reason = params
294 .get("reason")
295 .and_then(|v| v.as_str())
296 .map(String::from);
297
298 let _ = self.event_sender.send(NotificationEvent::Cancelled {
299 server_name: server_name.to_string(),
300 request_id,
301 reason,
302 });
303 }
304
305 async fn add_to_history(&self, notification: Notification) {
307 let mut history = self.history.write().await;
308 history.push(notification);
309
310 if history.len() > self.max_history_size {
311 history.remove(0);
312 }
313 }
314
315 pub async fn get_history(&self, filter: Option<NotificationFilter>) -> Vec<Notification> {
317 let history = self.history.read().await;
318 let mut filtered: Vec<_> = history.iter().cloned().collect();
319
320 if let Some(f) = filter {
321 if let Some(server_name) = f.server_name {
322 filtered.retain(|n| n.server_name == server_name);
323 }
324 if let Some(notification_type) = f.notification_type {
325 filtered.retain(|n| n.notification_type == notification_type);
326 }
327 if let Some(since) = f.since {
328 filtered.retain(|n| n.timestamp >= since);
329 }
330 if let Some(limit) = f.limit {
331 let len = filtered.len();
332 if len > limit {
333 filtered = filtered.into_iter().skip(len - limit).collect();
334 }
335 }
336 }
337
338 filtered
339 }
340
341 pub async fn clear_history(&self) {
343 let mut history = self.history.write().await;
344 let count = history.len();
345 history.clear();
346 let _ = self
347 .event_sender
348 .send(NotificationEvent::HistoryCleared { count });
349 }
350
351 pub async fn clear_server_history(&self, server_name: &str) -> usize {
353 let mut history = self.history.write().await;
354 let before = history.len();
355 history.retain(|n| n.server_name != server_name);
356 before - history.len()
357 }
358
359 pub async fn get_active_progress(&self) -> Vec<ProgressState> {
361 self.progress_states
362 .read()
363 .await
364 .values()
365 .cloned()
366 .collect()
367 }
368
369 pub async fn get_server_progress(&self, server_name: &str) -> Vec<ProgressState> {
371 self.progress_states
372 .read()
373 .await
374 .values()
375 .filter(|p| p.server_name == server_name)
376 .cloned()
377 .collect()
378 }
379
380 pub async fn cancel_progress(&self, server_name: &str, token: &str) -> bool {
382 let key = format!("{}:{}", server_name, token);
383 self.progress_states.write().await.remove(&key).is_some()
384 }
385
386 pub async fn clear_progress(&self) {
388 self.progress_states.write().await.clear();
389 }
390
391 pub async fn get_stats(&self) -> NotificationStats {
393 let history = self.history.read().await;
394
395 let mut by_type: HashMap<NotificationType, usize> = HashMap::new();
396 let mut by_server: HashMap<String, usize> = HashMap::new();
397
398 for notification in history.iter() {
399 *by_type.entry(notification.notification_type).or_insert(0) += 1;
400 *by_server
401 .entry(notification.server_name.clone())
402 .or_insert(0) += 1;
403 }
404
405 NotificationStats {
406 total_notifications: history.len(),
407 max_history_size: self.max_history_size,
408 active_progress: self.progress_states.read().await.len(),
409 by_type,
410 by_server,
411 }
412 }
413}
414
415impl Default for McpNotificationManager {
416 fn default() -> Self {
417 Self::new(100)
418 }
419}
420
421#[derive(Debug, Clone, Default)]
423pub struct NotificationFilter {
424 pub server_name: Option<String>,
426 pub notification_type: Option<NotificationType>,
428 pub since: Option<DateTime<Utc>>,
430 pub limit: Option<usize>,
432}
433
434#[derive(Debug, Clone)]
436pub struct NotificationStats {
437 pub total_notifications: usize,
439 pub max_history_size: usize,
441 pub active_progress: usize,
443 pub by_type: HashMap<NotificationType, usize>,
445 pub by_server: HashMap<String, usize>,
447}
448
449pub fn create_progress_params(
451 token: &str,
452 progress: u64,
453 total: Option<u64>,
454) -> ProgressNotification {
455 ProgressNotification {
456 progress_token: token.to_string(),
457 progress,
458 total,
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    // Names shared by several tests.
    const SERVER: &str = "test-server";
    const PROGRESS: &str = "notifications/progress";
    const CANCELLED: &str = "notifications/cancelled";
    const TOOLS_CHANGED: &str = "notifications/tools/list_changed";

    /// Creates a manager and feeds it one progress notification for
    /// `token-1` with the given values.
    async fn manager_after_progress(progress: u64, total: u64) -> McpNotificationManager {
        let manager = McpNotificationManager::new(100);
        let params = serde_json::json!({
            "progressToken": "token-1",
            "progress": progress,
            "total": total
        });
        manager
            .handle_notification(SERVER, PROGRESS, Some(params))
            .await;
        manager
    }

    #[test]
    fn test_notification_type_display() {
        let cases = [
            (NotificationType::Progress, "progress"),
            (NotificationType::ToolsListChanged, "tools/list_changed"),
        ];
        for (kind, rendered) in cases {
            assert_eq!(kind.to_string(), rendered);
        }
    }

    #[test]
    fn test_get_notification_type() {
        let cases = [
            (PROGRESS, NotificationType::Progress),
            (TOOLS_CHANGED, NotificationType::ToolsListChanged),
            ("custom/event", NotificationType::Custom),
        ];
        for (method, expected) in cases {
            assert_eq!(
                McpNotificationManager::get_notification_type(method),
                expected
            );
        }
    }

    #[tokio::test]
    async fn test_handle_notification() {
        let manager = McpNotificationManager::new(100);
        manager
            .handle_notification(SERVER, TOOLS_CHANGED, None)
            .await;

        let history = manager.get_history(None).await;
        assert_eq!(history.len(), 1);

        let first = &history[0];
        assert_eq!(first.server_name, "test-server");
        assert_eq!(first.notification_type, NotificationType::ToolsListChanged);
    }

    #[tokio::test]
    async fn test_handle_progress() {
        let manager = manager_after_progress(50, 100).await;

        let active = manager.get_active_progress().await;
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].progress, 50);
        assert_eq!(active[0].total, Some(100));
    }

    #[tokio::test]
    async fn test_progress_complete() {
        let manager = manager_after_progress(100, 100).await;

        // A finished operation must no longer be tracked.
        let active = manager.get_active_progress().await;
        assert!(active.is_empty());
    }

    #[tokio::test]
    async fn test_history_filter() {
        let manager = McpNotificationManager::new(100);
        let incoming = [
            ("server-1", PROGRESS),
            ("server-2", TOOLS_CHANGED),
            ("server-1", CANCELLED),
        ];
        for (server, method) in incoming {
            manager.handle_notification(server, method, None).await;
        }

        let only_server_1 = NotificationFilter {
            server_name: Some("server-1".to_string()),
            ..Default::default()
        };

        let history = manager.get_history(Some(only_server_1)).await;
        assert_eq!(history.len(), 2);
    }

    #[tokio::test]
    async fn test_clear_history() {
        let manager = McpNotificationManager::new(100);
        for method in [PROGRESS, CANCELLED] {
            manager.handle_notification(SERVER, method, None).await;
        }

        manager.clear_history().await;

        let history = manager.get_history(None).await;
        assert!(history.is_empty());
    }

    #[tokio::test]
    async fn test_get_stats() {
        let manager = McpNotificationManager::new(100);
        let incoming = [
            ("server-1", PROGRESS),
            ("server-1", PROGRESS),
            ("server-2", TOOLS_CHANGED),
        ];
        for (server, method) in incoming {
            manager.handle_notification(server, method, None).await;
        }

        let stats = manager.get_stats().await;
        assert_eq!(stats.total_notifications, 3);
        assert_eq!(stats.by_server.get("server-1"), Some(&2));
        assert_eq!(stats.by_server.get("server-2"), Some(&1));
    }

    #[test]
    fn test_create_progress_params() {
        let built = create_progress_params("token-1", 50, Some(100));
        assert_eq!(built.progress_token, "token-1");
        assert_eq!(built.progress, 50);
        assert_eq!(built.total, Some(100));
    }
}