1use std::collections::{HashMap, HashSet};
7use std::sync::{Arc, RwLock};
8
9use serde::{Deserialize, Serialize};
10use tokio::sync::mpsc;
11
12use crate::transport::traits::JsonRpcNotification;
13
/// Parameters of a resource subscribe request.
///
/// NOTE(review): presumably deserialized from the `params` of a `resources/subscribe`
/// JSON-RPC request — confirm against the request dispatcher.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubscribeRequest {
    /// URI of the resource to subscribe to.
    pub uri: String,
}
20
/// Parameters of a resource unsubscribe request.
///
/// NOTE(review): presumably deserialized from the `params` of a `resources/unsubscribe`
/// JSON-RPC request — confirm against the request dispatcher.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UnsubscribeRequest {
    /// URI of the resource to stop receiving updates for.
    pub uri: String,
}
27
/// Payload of a resource-updated notification.
///
/// Has the same `{ "uri": ... }` shape as the params that
/// `SubscriptionManager::notify_resource_updated` builds inline for
/// `notifications/resources/updated` (assuming serde's default field naming).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourceUpdatedNotification {
    /// URI of the resource whose contents changed.
    pub uri: String,
}
34
35#[derive(Clone)]
40pub struct SubscriptionManager {
41 uri_to_sessions: Arc<RwLock<HashMap<String, HashSet<String>>>>,
43
44 session_to_uris: Arc<RwLock<HashMap<String, HashSet<String>>>>,
46
47 notification_tx: Option<mpsc::UnboundedSender<JsonRpcNotification>>,
49
50 session_notifiers: Arc<RwLock<HashMap<String, mpsc::UnboundedSender<JsonRpcNotification>>>>,
52}
53
54impl SubscriptionManager {
55 pub fn new() -> Self {
57 Self {
58 uri_to_sessions: Arc::new(RwLock::new(HashMap::new())),
59 session_to_uris: Arc::new(RwLock::new(HashMap::new())),
60 notification_tx: None,
61 session_notifiers: Arc::new(RwLock::new(HashMap::new())),
62 }
63 }
64
65 pub fn with_notifications(notification_tx: mpsc::UnboundedSender<JsonRpcNotification>) -> Self {
67 Self {
68 uri_to_sessions: Arc::new(RwLock::new(HashMap::new())),
69 session_to_uris: Arc::new(RwLock::new(HashMap::new())),
70 notification_tx: Some(notification_tx),
71 session_notifiers: Arc::new(RwLock::new(HashMap::new())),
72 }
73 }
74
75 pub fn register_session_notifier(
80 &self,
81 session_id: &str,
82 tx: mpsc::UnboundedSender<JsonRpcNotification>,
83 ) {
84 if let Ok(mut notifiers) = self.session_notifiers.write() {
85 notifiers.insert(session_id.to_string(), tx);
86 }
87 }
88
89 pub fn unregister_session_notifier(&self, session_id: &str) {
91 if let Ok(mut notifiers) = self.session_notifiers.write() {
92 notifiers.remove(session_id);
93 }
94 }
95
96 pub fn subscribe(&self, session_id: &str, uri: &str) -> bool {
100 let is_new_uri_subscription;
101
102 {
104 let mut uri_map = self.uri_to_sessions.write().unwrap();
105 let sessions = uri_map.entry(uri.to_string()).or_default();
106 is_new_uri_subscription = sessions.insert(session_id.to_string());
107 }
108
109 {
111 let mut session_map = self.session_to_uris.write().unwrap();
112 let uris = session_map.entry(session_id.to_string()).or_default();
113 uris.insert(uri.to_string());
114 }
115
116 tracing::debug!(
117 session_id = %session_id,
118 uri = %uri,
119 new = is_new_uri_subscription,
120 "Resource subscription"
121 );
122
123 is_new_uri_subscription
124 }
125
126 pub fn unsubscribe(&self, session_id: &str, uri: &str) -> bool {
130 let was_subscribed;
131
132 {
134 let mut uri_map = self.uri_to_sessions.write().unwrap();
135 if let Some(sessions) = uri_map.get_mut(uri) {
136 was_subscribed = sessions.remove(session_id);
137 if sessions.is_empty() {
139 uri_map.remove(uri);
140 }
141 } else {
142 was_subscribed = false;
143 }
144 }
145
146 {
148 let mut session_map = self.session_to_uris.write().unwrap();
149 if let Some(uris) = session_map.get_mut(session_id) {
150 uris.remove(uri);
151 if uris.is_empty() {
153 session_map.remove(session_id);
154 }
155 }
156 }
157
158 tracing::debug!(
159 session_id = %session_id,
160 uri = %uri,
161 was_subscribed = was_subscribed,
162 "Resource unsubscription"
163 );
164
165 was_subscribed
166 }
167
168 pub fn remove_session(&self, session_id: &str) {
172 let uris: Vec<String> = {
174 let mut session_map = self.session_to_uris.write().unwrap();
175 session_map
176 .remove(session_id)
177 .map(|uris| uris.into_iter().collect())
178 .unwrap_or_default()
179 };
180
181 {
183 let mut uri_map = self.uri_to_sessions.write().unwrap();
184 for uri in &uris {
185 if let Some(sessions) = uri_map.get_mut(uri) {
186 sessions.remove(session_id);
187 if sessions.is_empty() {
188 uri_map.remove(uri);
189 }
190 }
191 }
192 }
193
194 self.unregister_session_notifier(session_id);
196
197 if !uris.is_empty() {
198 tracing::debug!(
199 session_id = %session_id,
200 subscription_count = uris.len(),
201 "Removed all session subscriptions"
202 );
203 }
204 }
205
206 pub fn is_subscribed(&self, session_id: &str, uri: &str) -> bool {
208 self.uri_to_sessions
209 .read()
210 .ok()
211 .and_then(|map| map.get(uri).map(|sessions| sessions.contains(session_id)))
212 .unwrap_or(false)
213 }
214
215 pub fn get_subscribers(&self, uri: &str) -> Vec<String> {
217 self.uri_to_sessions
218 .read()
219 .ok()
220 .and_then(|map| map.get(uri).map(|sessions| sessions.iter().cloned().collect()))
221 .unwrap_or_default()
222 }
223
224 pub fn get_session_subscriptions(&self, session_id: &str) -> Vec<String> {
226 self.session_to_uris
227 .read()
228 .ok()
229 .and_then(|map| map.get(session_id).map(|uris| uris.iter().cloned().collect()))
230 .unwrap_or_default()
231 }
232
233 pub fn subscription_count(&self) -> usize {
235 self.uri_to_sessions
236 .read()
237 .ok()
238 .map(|map| map.values().map(|s| s.len()).sum())
239 .unwrap_or(0)
240 }
241
242 pub fn subscribed_uri_count(&self) -> usize {
244 self.uri_to_sessions
245 .read()
246 .ok()
247 .map(|map| map.len())
248 .unwrap_or(0)
249 }
250
251 pub fn notify_resource_updated(&self, uri: &str) {
255 let subscribers = self.get_subscribers(uri);
256
257 if subscribers.is_empty() {
258 return;
259 }
260
261 let notification = JsonRpcNotification::new(
262 "notifications/resources/updated",
263 Some(serde_json::json!({ "uri": uri })),
264 );
265
266 let notifiers = self.session_notifiers.read().ok();
268 let mut notified_count = 0;
269
270 if let Some(notifiers) = notifiers {
271 for session_id in &subscribers {
272 if let Some(tx) = notifiers.get(session_id)
273 && tx.send(notification.clone()).is_ok()
274 {
275 notified_count += 1;
276 }
277 }
278 }
279
280 if notified_count == 0
282 && let Some(ref tx) = self.notification_tx
283 {
284 let _ = tx.send(notification);
287 }
288
289 tracing::debug!(
290 uri = %uri,
291 subscriber_count = subscribers.len(),
292 notified_count = notified_count,
293 "Resource update notification sent"
294 );
295 }
296
297 pub fn notify_resources_updated(&self, uris: &[&str]) {
299 for uri in uris {
300 self.notify_resource_updated(uri);
301 }
302 }
303}
304
305impl Default for SubscriptionManager {
306 fn default() -> Self {
307 Self::new()
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_subscription_manager_creation() {
        let manager = SubscriptionManager::new();
        assert_eq!(manager.subscription_count(), 0);
        assert_eq!(manager.subscribed_uri_count(), 0);
    }

    #[test]
    fn test_subscribe() {
        let manager = SubscriptionManager::new();

        assert!(manager.subscribe("session-1", "file:///test.txt"));

        assert!(!manager.subscribe("session-1", "file:///test.txt"));

        assert!(manager.subscribe("session-2", "file:///test.txt"));

        assert_eq!(manager.subscription_count(), 2);
        assert_eq!(manager.subscribed_uri_count(), 1);
    }

    #[test]
    fn test_unsubscribe() {
        let manager = SubscriptionManager::new();

        manager.subscribe("session-1", "file:///test.txt");
        manager.subscribe("session-2", "file:///test.txt");

        assert!(manager.unsubscribe("session-1", "file:///test.txt"));

        assert!(!manager.unsubscribe("session-1", "file:///test.txt"));

        assert!(manager.is_subscribed("session-2", "file:///test.txt"));

        assert_eq!(manager.subscription_count(), 1);
    }

    #[test]
    fn test_unsubscribe_unknown_is_noop() {
        let manager = SubscriptionManager::new();

        assert!(!manager.unsubscribe("session-1", "file:///missing.txt"));
        assert_eq!(manager.subscription_count(), 0);
        assert_eq!(manager.subscribed_uri_count(), 0);
        assert!(manager.get_session_subscriptions("session-1").is_empty());
    }

    #[test]
    fn test_remove_session() {
        let manager = SubscriptionManager::new();

        manager.subscribe("session-1", "file:///a.txt");
        manager.subscribe("session-1", "file:///b.txt");
        manager.subscribe("session-2", "file:///a.txt");

        manager.remove_session("session-1");

        assert!(!manager.is_subscribed("session-1", "file:///a.txt"));
        assert!(!manager.is_subscribed("session-1", "file:///b.txt"));
        assert!(manager.is_subscribed("session-2", "file:///a.txt"));

        assert_eq!(manager.subscribed_uri_count(), 1);
        assert_eq!(manager.subscription_count(), 1);
        // The reverse index must be cleared too, not only the URI index.
        assert!(manager.get_session_subscriptions("session-1").is_empty());
    }

    #[test]
    fn test_remove_unknown_session_is_noop() {
        let manager = SubscriptionManager::new();

        manager.subscribe("session-1", "file:///test.txt");
        manager.remove_session("session-2");

        assert!(manager.is_subscribed("session-1", "file:///test.txt"));
        assert_eq!(manager.subscription_count(), 1);
    }

    #[test]
    fn test_get_subscribers() {
        let manager = SubscriptionManager::new();

        manager.subscribe("session-1", "file:///test.txt");
        manager.subscribe("session-2", "file:///test.txt");
        manager.subscribe("session-3", "file:///other.txt");

        let subscribers = manager.get_subscribers("file:///test.txt");
        assert_eq!(subscribers.len(), 2);
        assert!(subscribers.contains(&"session-1".to_string()));
        assert!(subscribers.contains(&"session-2".to_string()));
    }

    #[test]
    fn test_get_session_subscriptions() {
        let manager = SubscriptionManager::new();

        manager.subscribe("session-1", "file:///a.txt");
        manager.subscribe("session-1", "file:///b.txt");
        manager.subscribe("session-2", "file:///c.txt");

        let subs = manager.get_session_subscriptions("session-1");
        assert_eq!(subs.len(), 2);
        assert!(subs.contains(&"file:///a.txt".to_string()));
        assert!(subs.contains(&"file:///b.txt".to_string()));
    }

    #[test]
    fn test_is_subscribed() {
        let manager = SubscriptionManager::new();

        assert!(!manager.is_subscribed("session-1", "file:///test.txt"));

        manager.subscribe("session-1", "file:///test.txt");

        assert!(manager.is_subscribed("session-1", "file:///test.txt"));
        assert!(!manager.is_subscribed("session-2", "file:///test.txt"));
    }

    #[test]
    fn test_clones_share_state() {
        let manager = SubscriptionManager::new();
        let handle = manager.clone();

        handle.subscribe("session-1", "file:///test.txt");

        assert!(manager.is_subscribed("session-1", "file:///test.txt"));
        assert_eq!(manager.subscription_count(), 1);
    }

    #[tokio::test]
    async fn test_notify_resource_updated() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let manager = SubscriptionManager::with_notifications(tx);

        manager.subscribe("session-1", "file:///test.txt");

        manager.notify_resource_updated("file:///test.txt");

        let notification = rx.recv().await.expect("Should receive notification");
        assert_eq!(notification.method, "notifications/resources/updated");

        let params = notification.params.expect("Should have params");
        assert_eq!(params["uri"], "file:///test.txt");
    }

    #[tokio::test]
    async fn test_per_session_notifier() {
        let manager = SubscriptionManager::new();
        let (tx1, mut rx1) = mpsc::unbounded_channel();
        let (tx2, mut rx2) = mpsc::unbounded_channel();

        manager.register_session_notifier("session-1", tx1);
        manager.register_session_notifier("session-2", tx2);

        manager.subscribe("session-1", "file:///test.txt");
        manager.notify_resource_updated("file:///test.txt");

        let notification = rx1.recv().await.expect("Session-1 should receive notification");
        assert_eq!(notification.method, "notifications/resources/updated");

        assert!(rx2.try_recv().is_err());
    }

    #[test]
    fn test_session_notifier_takes_precedence_over_fallback() {
        let (global_tx, mut global_rx) = mpsc::unbounded_channel();
        let manager = SubscriptionManager::with_notifications(global_tx);
        let (session_tx, mut session_rx) = mpsc::unbounded_channel();

        manager.register_session_notifier("session-1", session_tx);
        manager.subscribe("session-1", "file:///test.txt");
        manager.notify_resource_updated("file:///test.txt");

        let notification = session_rx.try_recv().expect("Session-1 should receive notification");
        assert_eq!(notification.method, "notifications/resources/updated");
        assert!(global_rx.try_recv().is_err());
    }

    #[test]
    fn test_fallback_when_session_channel_closed() {
        let (global_tx, mut global_rx) = mpsc::unbounded_channel();
        let manager = SubscriptionManager::with_notifications(global_tx);
        let (session_tx, session_rx) = mpsc::unbounded_channel();

        manager.register_session_notifier("session-1", session_tx);
        drop(session_rx);
        manager.subscribe("session-1", "file:///test.txt");
        manager.notify_resource_updated("file:///test.txt");

        assert!(global_rx.try_recv().is_ok());
    }

    #[test]
    fn test_remove_session_unregisters_notifier() {
        let manager = SubscriptionManager::new();
        let (tx, mut rx) = mpsc::unbounded_channel();

        manager.register_session_notifier("session-1", tx);
        manager.subscribe("session-1", "file:///test.txt");
        manager.remove_session("session-1");

        // The manager held the only sender; unregistering it closes the channel.
        assert!(matches!(
            rx.try_recv(),
            Err(mpsc::error::TryRecvError::Disconnected)
        ));
    }
}