mcp_host/managers/
subscription.rs

1//! Resource subscription manager
2//!
3//! Tracks which sessions are subscribed to which resource URIs and provides
4//! notification dispatch when resources change.
5
6use std::collections::{HashMap, HashSet};
7use std::sync::{Arc, RwLock};
8
9use serde::{Deserialize, Serialize};
10use tokio::sync::mpsc;
11
12use crate::transport::traits::JsonRpcNotification;
13
14/// Resource subscription request parameters
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct SubscribeRequest {
17    /// URI of the resource to subscribe to
18    pub uri: String,
19}
20
21/// Resource unsubscribe request parameters
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct UnsubscribeRequest {
24    /// URI of the resource to unsubscribe from
25    pub uri: String,
26}
27
28/// Resource updated notification parameters
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ResourceUpdatedNotification {
31    /// URI of the updated resource
32    pub uri: String,
33}
34
35/// Subscription manager for tracking resource subscriptions
36///
37/// Thread-safe structure that maps resource URIs to subscribed sessions
38/// and sessions to their subscribed URIs.
39#[derive(Clone)]
40pub struct SubscriptionManager {
41    /// Maps resource URI -> set of session IDs
42    uri_to_sessions: Arc<RwLock<HashMap<String, HashSet<String>>>>,
43
44    /// Maps session ID -> set of resource URIs (for efficient cleanup)
45    session_to_uris: Arc<RwLock<HashMap<String, HashSet<String>>>>,
46
47    /// Notification channel for dispatching updates
48    notification_tx: Option<mpsc::UnboundedSender<JsonRpcNotification>>,
49
50    /// Per-session notification channels (for targeted notifications)
51    session_notifiers: Arc<RwLock<HashMap<String, mpsc::UnboundedSender<JsonRpcNotification>>>>,
52}
53
54impl SubscriptionManager {
55    /// Create a new subscription manager
56    pub fn new() -> Self {
57        Self {
58            uri_to_sessions: Arc::new(RwLock::new(HashMap::new())),
59            session_to_uris: Arc::new(RwLock::new(HashMap::new())),
60            notification_tx: None,
61            session_notifiers: Arc::new(RwLock::new(HashMap::new())),
62        }
63    }
64
65    /// Create subscription manager with a global notification channel
66    pub fn with_notifications(notification_tx: mpsc::UnboundedSender<JsonRpcNotification>) -> Self {
67        Self {
68            uri_to_sessions: Arc::new(RwLock::new(HashMap::new())),
69            session_to_uris: Arc::new(RwLock::new(HashMap::new())),
70            notification_tx: Some(notification_tx),
71            session_notifiers: Arc::new(RwLock::new(HashMap::new())),
72        }
73    }
74
75    /// Register a per-session notification channel
76    ///
77    /// This allows targeted notifications to specific sessions rather than
78    /// broadcasting to all.
79    pub fn register_session_notifier(
80        &self,
81        session_id: &str,
82        tx: mpsc::UnboundedSender<JsonRpcNotification>,
83    ) {
84        if let Ok(mut notifiers) = self.session_notifiers.write() {
85            notifiers.insert(session_id.to_string(), tx);
86        }
87    }
88
89    /// Unregister a session's notification channel
90    pub fn unregister_session_notifier(&self, session_id: &str) {
91        if let Ok(mut notifiers) = self.session_notifiers.write() {
92            notifiers.remove(session_id);
93        }
94    }
95
96    /// Subscribe a session to a resource URI
97    ///
98    /// Returns true if this is a new subscription, false if already subscribed.
99    pub fn subscribe(&self, session_id: &str, uri: &str) -> bool {
100        let is_new_uri_subscription;
101
102        // Add to uri_to_sessions
103        {
104            let mut uri_map = self.uri_to_sessions.write().unwrap();
105            let sessions = uri_map.entry(uri.to_string()).or_default();
106            is_new_uri_subscription = sessions.insert(session_id.to_string());
107        }
108
109        // Add to session_to_uris
110        {
111            let mut session_map = self.session_to_uris.write().unwrap();
112            let uris = session_map.entry(session_id.to_string()).or_default();
113            uris.insert(uri.to_string());
114        }
115
116        tracing::debug!(
117            session_id = %session_id,
118            uri = %uri,
119            new = is_new_uri_subscription,
120            "Resource subscription"
121        );
122
123        is_new_uri_subscription
124    }
125
126    /// Unsubscribe a session from a resource URI
127    ///
128    /// Returns true if subscription existed and was removed.
129    pub fn unsubscribe(&self, session_id: &str, uri: &str) -> bool {
130        let was_subscribed;
131
132        // Remove from uri_to_sessions
133        {
134            let mut uri_map = self.uri_to_sessions.write().unwrap();
135            if let Some(sessions) = uri_map.get_mut(uri) {
136                was_subscribed = sessions.remove(session_id);
137                // Clean up empty sets
138                if sessions.is_empty() {
139                    uri_map.remove(uri);
140                }
141            } else {
142                was_subscribed = false;
143            }
144        }
145
146        // Remove from session_to_uris
147        {
148            let mut session_map = self.session_to_uris.write().unwrap();
149            if let Some(uris) = session_map.get_mut(session_id) {
150                uris.remove(uri);
151                // Clean up empty sets
152                if uris.is_empty() {
153                    session_map.remove(session_id);
154                }
155            }
156        }
157
158        tracing::debug!(
159            session_id = %session_id,
160            uri = %uri,
161            was_subscribed = was_subscribed,
162            "Resource unsubscription"
163        );
164
165        was_subscribed
166    }
167
168    /// Remove all subscriptions for a session
169    ///
170    /// Call this when a session closes to clean up all its subscriptions.
171    pub fn remove_session(&self, session_id: &str) {
172        // Get all URIs this session was subscribed to
173        let uris: Vec<String> = {
174            let mut session_map = self.session_to_uris.write().unwrap();
175            session_map
176                .remove(session_id)
177                .map(|uris| uris.into_iter().collect())
178                .unwrap_or_default()
179        };
180
181        // Remove session from each URI's subscriber set
182        {
183            let mut uri_map = self.uri_to_sessions.write().unwrap();
184            for uri in &uris {
185                if let Some(sessions) = uri_map.get_mut(uri) {
186                    sessions.remove(session_id);
187                    if sessions.is_empty() {
188                        uri_map.remove(uri);
189                    }
190                }
191            }
192        }
193
194        // Remove session notifier
195        self.unregister_session_notifier(session_id);
196
197        if !uris.is_empty() {
198            tracing::debug!(
199                session_id = %session_id,
200                subscription_count = uris.len(),
201                "Removed all session subscriptions"
202            );
203        }
204    }
205
206    /// Check if a session is subscribed to a URI
207    pub fn is_subscribed(&self, session_id: &str, uri: &str) -> bool {
208        self.uri_to_sessions
209            .read()
210            .ok()
211            .and_then(|map| map.get(uri).map(|sessions| sessions.contains(session_id)))
212            .unwrap_or(false)
213    }
214
215    /// Get all sessions subscribed to a URI
216    pub fn get_subscribers(&self, uri: &str) -> Vec<String> {
217        self.uri_to_sessions
218            .read()
219            .ok()
220            .and_then(|map| map.get(uri).map(|sessions| sessions.iter().cloned().collect()))
221            .unwrap_or_default()
222    }
223
224    /// Get all URIs a session is subscribed to
225    pub fn get_session_subscriptions(&self, session_id: &str) -> Vec<String> {
226        self.session_to_uris
227            .read()
228            .ok()
229            .and_then(|map| map.get(session_id).map(|uris| uris.iter().cloned().collect()))
230            .unwrap_or_default()
231    }
232
233    /// Get total subscription count
234    pub fn subscription_count(&self) -> usize {
235        self.uri_to_sessions
236            .read()
237            .ok()
238            .map(|map| map.values().map(|s| s.len()).sum())
239            .unwrap_or(0)
240    }
241
242    /// Get count of unique URIs with subscribers
243    pub fn subscribed_uri_count(&self) -> usize {
244        self.uri_to_sessions
245            .read()
246            .ok()
247            .map(|map| map.len())
248            .unwrap_or(0)
249    }
250
251    /// Notify all subscribers that a resource has been updated
252    ///
253    /// Sends `notifications/resources/updated` to all sessions subscribed to the URI.
254    pub fn notify_resource_updated(&self, uri: &str) {
255        let subscribers = self.get_subscribers(uri);
256
257        if subscribers.is_empty() {
258            return;
259        }
260
261        let notification = JsonRpcNotification::new(
262            "notifications/resources/updated",
263            Some(serde_json::json!({ "uri": uri })),
264        );
265
266        // Try per-session notifications first
267        let notifiers = self.session_notifiers.read().ok();
268        let mut notified_count = 0;
269
270        if let Some(notifiers) = notifiers {
271            for session_id in &subscribers {
272                if let Some(tx) = notifiers.get(session_id)
273                    && tx.send(notification.clone()).is_ok()
274                {
275                    notified_count += 1;
276                }
277            }
278        }
279
280        // Fall back to global notification channel if no per-session notifiers
281        if notified_count == 0
282            && let Some(ref tx) = self.notification_tx
283        {
284            // Note: This broadcasts to all connections, not just subscribers
285            // In a multi-session scenario, you'd need session-aware transport
286            let _ = tx.send(notification);
287        }
288
289        tracing::debug!(
290            uri = %uri,
291            subscriber_count = subscribers.len(),
292            notified_count = notified_count,
293            "Resource update notification sent"
294        );
295    }
296
297    /// Notify subscribers for multiple URIs (batch operation)
298    pub fn notify_resources_updated(&self, uris: &[&str]) {
299        for uri in uris {
300            self.notify_resource_updated(uri);
301        }
302    }
303}
304
305impl Default for SubscriptionManager {
306    fn default() -> Self {
307        Self::new()
308    }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_subscription_manager_creation() {
        let manager = SubscriptionManager::new();
        assert_eq!(manager.subscription_count(), 0);
        assert_eq!(manager.subscribed_uri_count(), 0);
    }

    #[test]
    fn test_subscribe() {
        let manager = SubscriptionManager::new();

        // First subscription returns true
        assert!(manager.subscribe("session-1", "file:///test.txt"));

        // Duplicate subscription returns false
        assert!(!manager.subscribe("session-1", "file:///test.txt"));

        // Different session, same URI returns true
        assert!(manager.subscribe("session-2", "file:///test.txt"));

        assert_eq!(manager.subscription_count(), 2);
        assert_eq!(manager.subscribed_uri_count(), 1);
    }

    #[test]
    fn test_unsubscribe() {
        let manager = SubscriptionManager::new();

        manager.subscribe("session-1", "file:///test.txt");
        manager.subscribe("session-2", "file:///test.txt");

        // Unsubscribe existing returns true
        assert!(manager.unsubscribe("session-1", "file:///test.txt"));

        // Unsubscribe again returns false
        assert!(!manager.unsubscribe("session-1", "file:///test.txt"));

        // Other session still subscribed
        assert!(manager.is_subscribed("session-2", "file:///test.txt"));

        assert_eq!(manager.subscription_count(), 1);
    }

    #[test]
    fn test_unsubscribe_last_subscriber_cleans_up() {
        let manager = SubscriptionManager::new();

        manager.subscribe("session-1", "file:///test.txt");
        assert!(manager.unsubscribe("session-1", "file:///test.txt"));

        // Both indexes drop their now-empty entries
        assert_eq!(manager.subscription_count(), 0);
        assert_eq!(manager.subscribed_uri_count(), 0);
        assert!(manager.get_subscribers("file:///test.txt").is_empty());
        assert!(manager.get_session_subscriptions("session-1").is_empty());
    }

    #[test]
    fn test_unsubscribe_unknown_is_noop() {
        let manager = SubscriptionManager::new();

        manager.subscribe("session-1", "file:///a.txt");

        // Unknown URI for a known session, and unknown session for a known URI
        assert!(!manager.unsubscribe("session-1", "file:///b.txt"));
        assert!(!manager.unsubscribe("session-2", "file:///a.txt"));

        assert!(manager.is_subscribed("session-1", "file:///a.txt"));
        assert_eq!(manager.subscription_count(), 1);
    }

    #[test]
    fn test_remove_session() {
        let manager = SubscriptionManager::new();

        manager.subscribe("session-1", "file:///a.txt");
        manager.subscribe("session-1", "file:///b.txt");
        manager.subscribe("session-2", "file:///a.txt");

        manager.remove_session("session-1");

        assert!(!manager.is_subscribed("session-1", "file:///a.txt"));
        assert!(!manager.is_subscribed("session-1", "file:///b.txt"));
        assert!(manager.is_subscribed("session-2", "file:///a.txt"));

        // file:///b.txt should be cleaned up since no subscribers
        assert_eq!(manager.subscribed_uri_count(), 1);
    }

    #[test]
    fn test_remove_unknown_session_is_noop() {
        let manager = SubscriptionManager::new();

        manager.subscribe("session-1", "file:///a.txt");
        manager.remove_session("session-2");

        assert!(manager.is_subscribed("session-1", "file:///a.txt"));
        assert_eq!(manager.subscription_count(), 1);
    }

    #[test]
    fn test_clones_share_subscriptions() {
        let manager = SubscriptionManager::new();
        let clone = manager.clone();

        clone.subscribe("session-1", "file:///a.txt");

        assert!(manager.is_subscribed("session-1", "file:///a.txt"));
        assert_eq!(manager.subscription_count(), 1);
    }

    #[test]
    fn test_get_subscribers() {
        let manager = SubscriptionManager::new();

        manager.subscribe("session-1", "file:///test.txt");
        manager.subscribe("session-2", "file:///test.txt");
        manager.subscribe("session-3", "file:///other.txt");

        let subscribers = manager.get_subscribers("file:///test.txt");
        assert_eq!(subscribers.len(), 2);
        assert!(subscribers.contains(&"session-1".to_string()));
        assert!(subscribers.contains(&"session-2".to_string()));
    }

    #[test]
    fn test_get_session_subscriptions() {
        let manager = SubscriptionManager::new();

        manager.subscribe("session-1", "file:///a.txt");
        manager.subscribe("session-1", "file:///b.txt");
        manager.subscribe("session-2", "file:///c.txt");

        let subs = manager.get_session_subscriptions("session-1");
        assert_eq!(subs.len(), 2);
        assert!(subs.contains(&"file:///a.txt".to_string()));
        assert!(subs.contains(&"file:///b.txt".to_string()));
    }

    #[test]
    fn test_is_subscribed() {
        let manager = SubscriptionManager::new();

        assert!(!manager.is_subscribed("session-1", "file:///test.txt"));

        manager.subscribe("session-1", "file:///test.txt");

        assert!(manager.is_subscribed("session-1", "file:///test.txt"));
        assert!(!manager.is_subscribed("session-2", "file:///test.txt"));
    }

    #[tokio::test]
    async fn test_notify_resource_updated() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let manager = SubscriptionManager::with_notifications(tx);

        manager.subscribe("session-1", "file:///test.txt");

        manager.notify_resource_updated("file:///test.txt");

        // Should receive notification
        let notification = rx.recv().await.expect("Should receive notification");
        assert_eq!(notification.method, "notifications/resources/updated");

        let params = notification.params.expect("Should have params");
        assert_eq!(params["uri"], "file:///test.txt");
    }

    #[tokio::test]
    async fn test_notify_without_subscribers_sends_nothing() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let manager = SubscriptionManager::with_notifications(tx);

        manager.notify_resource_updated("file:///test.txt");

        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_per_session_notifier() {
        let manager = SubscriptionManager::new();
        let (tx1, mut rx1) = mpsc::unbounded_channel();
        let (tx2, mut rx2) = mpsc::unbounded_channel();

        manager.register_session_notifier("session-1", tx1);
        manager.register_session_notifier("session-2", tx2);

        manager.subscribe("session-1", "file:///test.txt");
        // session-2 not subscribed

        manager.notify_resource_updated("file:///test.txt");

        // session-1 should receive notification
        let notification = rx1.recv().await.expect("Session-1 should receive notification");
        assert_eq!(notification.method, "notifications/resources/updated");

        // session-2 should NOT receive (not subscribed)
        // Use try_recv to check without blocking
        assert!(rx2.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_per_session_delivery_skips_global_channel() {
        let (global_tx, mut global_rx) = mpsc::unbounded_channel();
        let manager = SubscriptionManager::with_notifications(global_tx);
        let (tx1, mut rx1) = mpsc::unbounded_channel();

        manager.register_session_notifier("session-1", tx1);
        manager.subscribe("session-1", "file:///test.txt");

        manager.notify_resource_updated("file:///test.txt");

        assert!(rx1.try_recv().is_ok());
        // Delivered per-session, so the global fallback must not fire
        assert!(global_rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_closed_session_channel_falls_back_to_global() {
        let (global_tx, mut global_rx) = mpsc::unbounded_channel();
        let manager = SubscriptionManager::with_notifications(global_tx);
        let (tx1, rx1) = mpsc::unbounded_channel();

        manager.register_session_notifier("session-1", tx1);
        drop(rx1);
        manager.subscribe("session-1", "file:///test.txt");

        manager.notify_resource_updated("file:///test.txt");

        let notification = global_rx
            .try_recv()
            .expect("Global channel should receive the fallback notification");
        assert_eq!(notification.method, "notifications/resources/updated");
    }

    #[tokio::test]
    async fn test_remove_session_unregisters_notifier() {
        let manager = SubscriptionManager::new();
        let (tx1, mut rx1) = mpsc::unbounded_channel();

        manager.register_session_notifier("session-1", tx1);
        manager.subscribe("session-1", "file:///test.txt");

        manager.remove_session("session-1");

        // Re-subscribing does not re-register the removed notifier
        manager.subscribe("session-1", "file:///test.txt");
        manager.notify_resource_updated("file:///test.txt");

        assert!(rx1.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_notify_resources_updated_batch() {
        let manager = SubscriptionManager::new();
        let (tx1, mut rx1) = mpsc::unbounded_channel();

        manager.register_session_notifier("session-1", tx1);
        manager.subscribe("session-1", "file:///a.txt");
        manager.subscribe("session-1", "file:///b.txt");

        // file:///c.txt has no subscribers and must produce nothing
        manager.notify_resources_updated(&["file:///a.txt", "file:///b.txt", "file:///c.txt"]);

        for expected in ["file:///a.txt", "file:///b.txt"] {
            let notification = rx1.try_recv().expect("Should receive notification");
            let params = notification.params.expect("Should have params");
            assert_eq!(params["uri"], expected);
        }
        assert!(rx1.try_recv().is_err());
    }
}