Skip to main content

mcp_kit/server/
subscription.rs

1//! Resource subscription management.
2//!
3//! Tracks which resources clients have subscribed to, enabling the server
4//! to send targeted `notifications/resources/updated` when resources change.
5
6use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10use crate::server::session::SessionId;
11
12/// Manages resource subscriptions across all sessions.
13///
14/// Thread-safe and can be shared across handlers and background tasks.
15#[derive(Clone, Default)]
16pub struct SubscriptionManager {
17    inner: Arc<RwLock<SubscriptionState>>,
18}
19
20#[derive(Default)]
21struct SubscriptionState {
22    /// Map from resource URI to set of subscribed session IDs
23    by_resource: HashMap<String, HashSet<SessionId>>,
24    /// Map from session ID to set of subscribed resource URIs
25    by_session: HashMap<SessionId, HashSet<String>>,
26}
27
28impl SubscriptionManager {
29    /// Create a new subscription manager.
30    pub fn new() -> Self {
31        Self::default()
32    }
33
34    /// Subscribe a session to a resource.
35    ///
36    /// Returns `true` if this is a new subscription, `false` if already subscribed.
37    pub async fn subscribe(&self, session_id: &SessionId, uri: &str) -> bool {
38        let mut state = self.inner.write().await;
39
40        let resource_subs = state
41            .by_resource
42            .entry(uri.to_string())
43            .or_insert_with(HashSet::new);
44        let is_new = resource_subs.insert(session_id.clone());
45
46        if is_new {
47            state
48                .by_session
49                .entry(session_id.clone())
50                .or_insert_with(HashSet::new)
51                .insert(uri.to_string());
52        }
53
54        is_new
55    }
56
57    /// Unsubscribe a session from a resource.
58    ///
59    /// Returns `true` if the subscription existed, `false` otherwise.
60    pub async fn unsubscribe(&self, session_id: &SessionId, uri: &str) -> bool {
61        let mut state = self.inner.write().await;
62
63        let removed = if let Some(resource_subs) = state.by_resource.get_mut(uri) {
64            let removed = resource_subs.remove(session_id);
65            if resource_subs.is_empty() {
66                state.by_resource.remove(uri);
67            }
68            removed
69        } else {
70            false
71        };
72
73        if removed {
74            if let Some(session_subs) = state.by_session.get_mut(session_id) {
75                session_subs.remove(uri);
76                if session_subs.is_empty() {
77                    state.by_session.remove(session_id);
78                }
79            }
80        }
81
82        removed
83    }
84
85    /// Unsubscribe a session from all resources.
86    ///
87    /// Call this when a session disconnects.
88    pub async fn unsubscribe_all(&self, session_id: &SessionId) {
89        let mut state = self.inner.write().await;
90
91        if let Some(uris) = state.by_session.remove(session_id) {
92            for uri in uris {
93                if let Some(resource_subs) = state.by_resource.get_mut(&uri) {
94                    resource_subs.remove(session_id);
95                    if resource_subs.is_empty() {
96                        state.by_resource.remove(&uri);
97                    }
98                }
99            }
100        }
101    }
102
103    /// Get all session IDs subscribed to a resource.
104    pub async fn subscribers(&self, uri: &str) -> Vec<SessionId> {
105        let state = self.inner.read().await;
106        state
107            .by_resource
108            .get(uri)
109            .map(|subs| subs.iter().cloned().collect())
110            .unwrap_or_default()
111    }
112
113    /// Get all resources a session is subscribed to.
114    pub async fn subscriptions(&self, session_id: &SessionId) -> Vec<String> {
115        let state = self.inner.read().await;
116        state
117            .by_session
118            .get(session_id)
119            .map(|subs| subs.iter().cloned().collect())
120            .unwrap_or_default()
121    }
122
123    /// Check if a session is subscribed to a resource.
124    pub async fn is_subscribed(&self, session_id: &SessionId, uri: &str) -> bool {
125        let state = self.inner.read().await;
126        state
127            .by_resource
128            .get(uri)
129            .map(|subs| subs.contains(session_id))
130            .unwrap_or(false)
131    }
132
133    /// Get the number of subscribers for a resource.
134    pub async fn subscriber_count(&self, uri: &str) -> usize {
135        let state = self.inner.read().await;
136        state.by_resource.get(uri).map(|s| s.len()).unwrap_or(0)
137    }
138
139    /// Get total number of active subscriptions.
140    pub async fn total_subscriptions(&self) -> usize {
141        let state = self.inner.read().await;
142        state.by_resource.values().map(|s| s.len()).sum()
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_subscribe_unsubscribe() {
        let mgr = SubscriptionManager::new();
        let session = SessionId::new();

        assert!(mgr.subscribe(&session, "file:///test.txt").await);
        assert!(!mgr.subscribe(&session, "file:///test.txt").await); // duplicate

        assert!(mgr.is_subscribed(&session, "file:///test.txt").await);
        assert_eq!(mgr.subscriber_count("file:///test.txt").await, 1);

        assert!(mgr.unsubscribe(&session, "file:///test.txt").await);
        assert!(!mgr.is_subscribed(&session, "file:///test.txt").await);
    }

    #[tokio::test]
    async fn test_unsubscribe_all() {
        let mgr = SubscriptionManager::new();
        let session = SessionId::new();

        mgr.subscribe(&session, "file:///a.txt").await;
        mgr.subscribe(&session, "file:///b.txt").await;

        assert_eq!(mgr.subscriptions(&session).await.len(), 2);

        mgr.unsubscribe_all(&session).await;

        assert_eq!(mgr.subscriptions(&session).await.len(), 0);
        assert_eq!(mgr.total_subscriptions().await, 0);
    }

    #[tokio::test]
    async fn test_unsubscribe_missing_returns_false() {
        let mgr = SubscriptionManager::new();
        let session = SessionId::new();

        // Nothing subscribed at all.
        assert!(!mgr.unsubscribe(&session, "file:///missing.txt").await);

        // Subscribed to a different resource than the one being removed.
        mgr.subscribe(&session, "file:///a.txt").await;
        assert!(!mgr.unsubscribe(&session, "file:///b.txt").await);
        assert!(mgr.is_subscribed(&session, "file:///a.txt").await);
    }

    #[tokio::test]
    async fn test_last_unsubscribe_prunes_indices() {
        let mgr = SubscriptionManager::new();
        let session = SessionId::new();

        mgr.subscribe(&session, "file:///a.txt").await;
        mgr.subscribe(&session, "file:///b.txt").await;

        assert!(mgr.unsubscribe(&session, "file:///a.txt").await);
        assert_eq!(mgr.subscriber_count("file:///a.txt").await, 0);
        assert_eq!(
            mgr.subscriptions(&session).await,
            vec!["file:///b.txt".to_string()]
        );

        assert!(mgr.unsubscribe(&session, "file:///b.txt").await);
        assert!(mgr.subscriptions(&session).await.is_empty());
        assert_eq!(mgr.total_subscriptions().await, 0);

        // A removed subscription can be re-created and counts as new again.
        assert!(mgr.subscribe(&session, "file:///a.txt").await);
    }

    #[tokio::test]
    async fn test_unsubscribe_all_without_subscriptions_is_noop() {
        let mgr = SubscriptionManager::new();
        let session = SessionId::new();

        mgr.unsubscribe_all(&session).await;

        assert!(mgr.subscriptions(&session).await.is_empty());
        assert_eq!(mgr.total_subscriptions().await, 0);
    }

    #[tokio::test]
    async fn test_sessions_are_independent() {
        let mgr = SubscriptionManager::new();
        let a = SessionId::new();
        let b = SessionId::new();
        let uri = "file:///shared.txt";

        assert!(mgr.subscribe(&a, uri).await);
        assert!(mgr.subscribe(&b, uri).await);
        assert_eq!(mgr.subscriber_count(uri).await, 2);

        let subs = mgr.subscribers(uri).await;
        assert!(subs.contains(&a));
        assert!(subs.contains(&b));

        // Disconnecting one session must not affect the other.
        mgr.unsubscribe_all(&a).await;
        assert!(!mgr.is_subscribed(&a, uri).await);
        assert!(mgr.is_subscribed(&b, uri).await);
        assert_eq!(mgr.subscriber_count(uri).await, 1);
        assert_eq!(mgr.total_subscriptions().await, 1);
    }
}