mcp_kit/server/
subscription.rs1use std::collections::{HashMap, HashSet};
7use std::sync::Arc;
8use tokio::sync::RwLock;
9
10use crate::server::session::SessionId;
11
12#[derive(Clone, Default)]
16pub struct SubscriptionManager {
17 inner: Arc<RwLock<SubscriptionState>>,
18}
19
20#[derive(Default)]
21struct SubscriptionState {
22 by_resource: HashMap<String, HashSet<SessionId>>,
24 by_session: HashMap<SessionId, HashSet<String>>,
26}
27
28impl SubscriptionManager {
29 pub fn new() -> Self {
31 Self::default()
32 }
33
34 pub async fn subscribe(&self, session_id: &SessionId, uri: &str) -> bool {
38 let mut state = self.inner.write().await;
39
40 let resource_subs = state
41 .by_resource
42 .entry(uri.to_string())
43 .or_insert_with(HashSet::new);
44 let is_new = resource_subs.insert(session_id.clone());
45
46 if is_new {
47 state
48 .by_session
49 .entry(session_id.clone())
50 .or_insert_with(HashSet::new)
51 .insert(uri.to_string());
52 }
53
54 is_new
55 }
56
57 pub async fn unsubscribe(&self, session_id: &SessionId, uri: &str) -> bool {
61 let mut state = self.inner.write().await;
62
63 let removed = if let Some(resource_subs) = state.by_resource.get_mut(uri) {
64 let removed = resource_subs.remove(session_id);
65 if resource_subs.is_empty() {
66 state.by_resource.remove(uri);
67 }
68 removed
69 } else {
70 false
71 };
72
73 if removed {
74 if let Some(session_subs) = state.by_session.get_mut(session_id) {
75 session_subs.remove(uri);
76 if session_subs.is_empty() {
77 state.by_session.remove(session_id);
78 }
79 }
80 }
81
82 removed
83 }
84
85 pub async fn unsubscribe_all(&self, session_id: &SessionId) {
89 let mut state = self.inner.write().await;
90
91 if let Some(uris) = state.by_session.remove(session_id) {
92 for uri in uris {
93 if let Some(resource_subs) = state.by_resource.get_mut(&uri) {
94 resource_subs.remove(session_id);
95 if resource_subs.is_empty() {
96 state.by_resource.remove(&uri);
97 }
98 }
99 }
100 }
101 }
102
103 pub async fn subscribers(&self, uri: &str) -> Vec<SessionId> {
105 let state = self.inner.read().await;
106 state
107 .by_resource
108 .get(uri)
109 .map(|subs| subs.iter().cloned().collect())
110 .unwrap_or_default()
111 }
112
113 pub async fn subscriptions(&self, session_id: &SessionId) -> Vec<String> {
115 let state = self.inner.read().await;
116 state
117 .by_session
118 .get(session_id)
119 .map(|subs| subs.iter().cloned().collect())
120 .unwrap_or_default()
121 }
122
123 pub async fn is_subscribed(&self, session_id: &SessionId, uri: &str) -> bool {
125 let state = self.inner.read().await;
126 state
127 .by_resource
128 .get(uri)
129 .map(|subs| subs.contains(session_id))
130 .unwrap_or(false)
131 }
132
133 pub async fn subscriber_count(&self, uri: &str) -> usize {
135 let state = self.inner.read().await;
136 state.by_resource.get(uri).map(|s| s.len()).unwrap_or(0)
137 }
138
139 pub async fn total_subscriptions(&self) -> usize {
141 let state = self.inner.read().await;
142 state.by_resource.values().map(|s| s.len()).sum()
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_subscribe_unsubscribe() {
        let mgr = SubscriptionManager::new();
        let session = SessionId::new();

        assert!(mgr.subscribe(&session, "file:///test.txt").await);
        // A repeated subscribe changes nothing and reports that.
        assert!(!mgr.subscribe(&session, "file:///test.txt").await);
        assert!(mgr.is_subscribed(&session, "file:///test.txt").await);
        assert_eq!(mgr.subscriber_count("file:///test.txt").await, 1);

        assert!(mgr.unsubscribe(&session, "file:///test.txt").await);
        assert!(!mgr.is_subscribed(&session, "file:///test.txt").await);

        // A second unsubscribe has nothing to remove, and both indexes are empty.
        assert!(!mgr.unsubscribe(&session, "file:///test.txt").await);
        assert_eq!(mgr.subscriber_count("file:///test.txt").await, 0);
        assert!(mgr.subscriptions(&session).await.is_empty());
        assert_eq!(mgr.total_subscriptions().await, 0);
    }

    #[tokio::test]
    async fn test_unsubscribe_unknown_returns_false() {
        let mgr = SubscriptionManager::new();
        let session = SessionId::new();

        assert!(!mgr.unsubscribe(&session, "file:///missing.txt").await);
        assert_eq!(mgr.total_subscriptions().await, 0);
    }

    #[tokio::test]
    async fn test_unsubscribe_all() {
        let mgr = SubscriptionManager::new();
        let session = SessionId::new();

        assert!(mgr.subscribe(&session, "file:///a.txt").await);
        assert!(mgr.subscribe(&session, "file:///b.txt").await);

        assert_eq!(mgr.subscriptions(&session).await.len(), 2);

        mgr.unsubscribe_all(&session).await;

        assert_eq!(mgr.subscriptions(&session).await.len(), 0);
        assert_eq!(mgr.total_subscriptions().await, 0);
    }

    #[tokio::test]
    async fn test_multiple_sessions_are_independent() {
        let mgr = SubscriptionManager::new();
        let alice = SessionId::new();
        let bob = SessionId::new();
        assert!(
            alice != bob,
            "test requires SessionId::new() to produce distinct ids"
        );

        assert!(mgr.subscribe(&alice, "file:///shared.txt").await);
        assert!(mgr.subscribe(&bob, "file:///shared.txt").await);
        assert!(mgr.subscribe(&bob, "file:///bob.txt").await);

        let subs = mgr.subscribers("file:///shared.txt").await;
        assert_eq!(subs.len(), 2);
        assert!(subs.contains(&alice) && subs.contains(&bob));
        assert_eq!(mgr.total_subscriptions().await, 3);

        // Dropping one session must not disturb the other's subscriptions.
        mgr.unsubscribe_all(&bob).await;
        assert!(mgr.is_subscribed(&alice, "file:///shared.txt").await);
        assert!(!mgr.is_subscribed(&bob, "file:///shared.txt").await);
        assert_eq!(mgr.subscriber_count("file:///shared.txt").await, 1);
        assert_eq!(mgr.subscriber_count("file:///bob.txt").await, 0);
        assert_eq!(mgr.total_subscriptions().await, 1);

        let mut uris = mgr.subscriptions(&alice).await;
        uris.sort();
        assert_eq!(uris, vec!["file:///shared.txt".to_string()]);
    }
}