1use clasp_core::{address::Pattern, SignalType, SubscribeOptions};
9use parking_lot::RwLock;
10use std::collections::{HashMap, HashSet};
11
12use crate::SessionId;
13
#[derive(Debug, Clone)]
/// A single subscription owned by one session: a compiled address pattern
/// plus an optional signal-type filter.
pub struct Subscription {
    /// Subscription id. Together with `session_id` it forms the key under which
    /// `SubscriptionManager` stores the subscription.
    pub id: u32,
    /// Session that owns this subscription.
    pub session_id: SessionId,
    /// Compiled address pattern this subscription listens on.
    pub pattern: Pattern,
    /// Accepted signal types; an empty set accepts every signal type.
    pub types: HashSet<SignalType>,
    /// Options supplied with the subscribe request (stored, not interpreted here).
    pub options: SubscribeOptions,
}
28
29impl Subscription {
30 pub fn new(
31 id: u32,
32 session_id: SessionId,
33 pattern: &str,
34 types: Vec<SignalType>,
35 options: SubscribeOptions,
36 ) -> Result<Self, clasp_core::Error> {
37 let pattern = Pattern::compile(pattern)?;
38
39 Ok(Self {
40 id,
41 session_id,
42 pattern,
43 types: types.into_iter().collect(),
44 options,
45 })
46 }
47
48 pub fn matches(&self, address: &str, signal_type: Option<SignalType>) -> bool {
50 if !self.pattern.matches(address) {
52 return false;
53 }
54
55 if !self.types.is_empty() {
57 if let Some(st) = signal_type {
58 if !self.types.contains(&st) {
59 return false;
60 }
61 }
62 }
63
64 true
65 }
66}
67
#[derive(Debug, Clone)]
/// Lightweight per-subscription record stored at the trie node where the
/// subscription's pattern ends.
struct SubscriberEntry {
    /// Owning session; this is what a successful match reports.
    session_id: SessionId,
    /// Subscription id; paired with `session_id` to locate the entry on removal.
    sub_id: u32,
    /// Copy of the subscription's type filter; an empty set accepts every type.
    types: HashSet<SignalType>,
    /// The pattern's full address string, set only when the pattern contains a
    /// partial wildcard segment (one containing `*` that is neither `*` nor `**`).
    /// The trie routes such segments through its single-segment wildcard edge,
    /// so candidate matches are re-checked against this pattern with `glob_match`.
    verify_pattern: Option<String>,
}
83
#[derive(Debug, Default)]
/// One node of the subscription trie; edges are labelled by address segment.
struct TrieNode {
    /// Edges for literal (wildcard-free) segments, keyed by segment text.
    children: HashMap<String, TrieNode>,
    /// Edge for a single-segment wildcard: `*`, or any other segment containing
    /// `*` except `**`.
    wildcard: Option<Box<TrieNode>>,
    /// Edge for `**`, which may match zero or more segments.
    globstar: Option<Box<TrieNode>>,
    /// Subscriptions whose pattern ends exactly at this node.
    subscribers: Vec<SubscriberEntry>,
}
96
97impl TrieNode {
98 fn is_empty(&self) -> bool {
99 self.subscribers.is_empty()
100 && self.children.is_empty()
101 && self.wildcard.is_none()
102 && self.globstar.is_none()
103 }
104
105 fn insert(&mut self, segments: &[&str], entry: SubscriberEntry) {
107 if segments.is_empty() {
108 self.subscribers.push(entry);
109 return;
110 }
111
112 let seg = segments[0];
113 let rest = &segments[1..];
114
115 if seg == "**" {
116 self.globstar
117 .get_or_insert_with(|| Box::new(TrieNode::default()))
118 .insert(rest, entry);
119 } else if seg == "*" || seg.contains('*') {
120 self.wildcard
124 .get_or_insert_with(|| Box::new(TrieNode::default()))
125 .insert(rest, entry);
126 } else {
127 self.children
128 .entry(seg.to_string())
129 .or_default()
130 .insert(rest, entry);
131 }
132 }
133
134 fn remove(&mut self, segments: &[&str], session_id: &str, sub_id: u32) -> bool {
136 if segments.is_empty() {
137 let before = self.subscribers.len();
138 self.subscribers
139 .retain(|e| !(e.session_id == session_id && e.sub_id == sub_id));
140 return self.subscribers.len() < before;
141 }
142
143 let seg = segments[0];
144 let rest = &segments[1..];
145
146 if seg == "**" {
147 if let Some(ref mut gs) = self.globstar {
148 let removed = gs.remove(rest, session_id, sub_id);
149 if gs.is_empty() {
150 self.globstar = None;
151 }
152 return removed;
153 }
154 false
155 } else if seg == "*" || seg.contains('*') {
156 if let Some(ref mut wc) = self.wildcard {
157 let removed = wc.remove(rest, session_id, sub_id);
158 if wc.is_empty() {
159 self.wildcard = None;
160 }
161 return removed;
162 }
163 false
164 } else {
165 let key = seg.to_string();
166 if let Some(child) = self.children.get_mut(&key) {
167 let removed = child.remove(rest, session_id, sub_id);
168 if child.is_empty() {
169 self.children.remove(&key);
170 }
171 removed
172 } else {
173 false
174 }
175 }
176 }
177
178 fn remove_session(&mut self, session_id: &str) {
180 self.subscribers.retain(|e| e.session_id != session_id);
181
182 for child in self.children.values_mut() {
183 child.remove_session(session_id);
184 }
185 self.children.retain(|_, c| !c.is_empty());
186
187 if let Some(ref mut wc) = self.wildcard {
188 wc.remove_session(session_id);
189 if wc.is_empty() {
190 self.wildcard = None;
191 }
192 }
193
194 if let Some(ref mut gs) = self.globstar {
195 gs.remove_session(session_id);
196 if gs.is_empty() {
197 self.globstar = None;
198 }
199 }
200 }
201
202 fn find_matches(
204 &self,
205 segments: &[&str],
206 idx: usize,
207 signal_type: Option<SignalType>,
208 address: &str,
209 results: &mut HashSet<SessionId>,
210 ) {
211 if let Some(ref gs) = self.globstar {
213 for i in idx..=segments.len() {
214 if i == segments.len() {
215 collect_filtered(&gs.subscribers, signal_type, address, results);
217 collect_zero_remaining(gs, signal_type, address, results);
219 } else {
220 if let Some(child) = gs.children.get(segments[i]) {
222 child.find_matches(segments, i + 1, signal_type, address, results);
223 }
224 if let Some(ref wc) = gs.wildcard {
226 wc.find_matches(segments, i + 1, signal_type, address, results);
227 }
228 if let Some(ref nested_gs) = gs.globstar {
230 nested_gs.find_matches(segments, i, signal_type, address, results);
233 }
234 }
235 }
236 }
237
238 if idx >= segments.len() {
240 collect_filtered(&self.subscribers, signal_type, address, results);
241 return;
242 }
243
244 let seg = segments[idx];
245
246 if let Some(child) = self.children.get(seg) {
248 child.find_matches(segments, idx + 1, signal_type, address, results);
249 }
250
251 if let Some(ref wc) = self.wildcard {
253 wc.find_matches(segments, idx + 1, signal_type, address, results);
254 }
255 }
256}
257
258fn collect_zero_remaining(
261 node: &TrieNode,
262 signal_type: Option<SignalType>,
263 address: &str,
264 results: &mut HashSet<SessionId>,
265) {
266 if let Some(ref gs) = node.globstar {
267 collect_filtered(&gs.subscribers, signal_type, address, results);
268 collect_zero_remaining(gs, signal_type, address, results);
269 }
270}
271
272fn collect_filtered(
275 subscribers: &[SubscriberEntry],
276 signal_type: Option<SignalType>,
277 address: &str,
278 results: &mut HashSet<SessionId>,
279) {
280 for entry in subscribers {
281 if let Some(ref pat) = entry.verify_pattern {
283 if !clasp_core::address::glob_match(pat, address) {
284 continue;
285 }
286 }
287
288 if entry.types.is_empty() || signal_type.is_none_or(|st| entry.types.contains(&st)) {
290 results.insert(entry.session_id.clone());
291 }
292 }
293}
294
/// All mutable state of a `SubscriptionManager`, guarded by a single lock.
struct TrieInner {
    /// Trie of `SubscriberEntry` records consulted by `find_subscribers`.
    root: TrieNode,
    /// Full subscriptions keyed by `(session, id)`; used to answer `len` and to
    /// recover a subscription's pattern (its trie path) on removal.
    subscriptions: HashMap<(SessionId, u32), Subscription>,
}
305
/// Registry of subscriptions with trie-based address matching.
///
/// All state sits behind one `parking_lot::RwLock`: lookups take the read
/// lock, mutations take the write lock.
pub struct SubscriptionManager {
    inner: RwLock<TrieInner>,
}
310
311impl SubscriptionManager {
312 pub fn new() -> Self {
313 Self {
314 inner: RwLock::new(TrieInner {
315 root: TrieNode::default(),
316 subscriptions: HashMap::new(),
317 }),
318 }
319 }
320
321 pub fn add(&self, sub: Subscription) {
323 let pattern_segments: Vec<String> = sub.pattern.address().segments().to_vec();
324 let segments: Vec<&str> = pattern_segments.iter().map(|s| s.as_str()).collect();
325
326 let has_partial_wildcard = pattern_segments
329 .iter()
330 .any(|s| s.contains('*') && s != "*" && s != "**");
331
332 let entry = SubscriberEntry {
333 session_id: sub.session_id.clone(),
334 sub_id: sub.id,
335 types: sub.types.clone(),
336 verify_pattern: if has_partial_wildcard {
337 Some(sub.pattern.address().as_str().to_string())
338 } else {
339 None
340 },
341 };
342
343 let key = (sub.session_id.clone(), sub.id);
344 let mut inner = self.inner.write();
345 inner.root.insert(&segments, entry);
346 inner.subscriptions.insert(key, sub);
347 }
348
349 pub fn remove(&self, session_id: &SessionId, id: u32) -> Option<Subscription> {
351 let mut inner = self.inner.write();
352 let key = (session_id.clone(), id);
353 if let Some(sub) = inner.subscriptions.remove(&key) {
354 let pattern_segments: Vec<String> = sub.pattern.address().segments().to_vec();
355 let segments: Vec<&str> = pattern_segments.iter().map(|s| s.as_str()).collect();
356 inner.root.remove(&segments, session_id, id);
357 Some(sub)
358 } else {
359 None
360 }
361 }
362
363 pub fn remove_session(&self, session_id: &SessionId) {
365 let mut inner = self.inner.write();
366 inner.subscriptions.retain(|k, _| k.0 != *session_id);
367 inner.root.remove_session(session_id);
368 }
369
370 pub fn find_subscribers(
372 &self,
373 address: &str,
374 signal_type: Option<SignalType>,
375 ) -> Vec<SessionId> {
376 let segments: Vec<&str> = if address.len() > 1 {
378 address[1..].split('/').collect()
379 } else {
380 vec![""]
382 };
383
384 let mut results = HashSet::new();
385 let inner = self.inner.read();
386 inner
387 .root
388 .find_matches(&segments, 0, signal_type, address, &mut results);
389
390 results.into_iter().collect()
391 }
392
393 pub fn len(&self) -> usize {
395 self.inner.read().subscriptions.len()
396 }
397
398 pub fn is_empty(&self) -> bool {
400 self.inner.read().subscriptions.is_empty()
401 }
402}
403
404impl Default for SubscriptionManager {
405 fn default() -> Self {
406 Self::new()
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a subscription with default options; panics on an invalid pattern.
    fn make_sub(id: u32, session: &str, pattern: &str, types: Vec<SignalType>) -> Subscription {
        Subscription::new(
            id,
            session.to_string(),
            pattern,
            types,
            SubscribeOptions::default(),
        )
        .unwrap()
    }

    /// Owned session id, for membership checks on `find_subscribers` results.
    fn sid(name: &str) -> String {
        name.to_string()
    }

    /// Builds a manager holding the given `(id, session, pattern)` subscriptions,
    /// each accepting every signal type.
    fn manager_with(subs: &[(u32, &str, &str)]) -> SubscriptionManager {
        let manager = SubscriptionManager::new();
        for &(id, session, pattern) in subs {
            manager.add(make_sub(id, session, pattern, vec![]));
        }
        manager
    }

    #[test]
    fn test_subscription_matching() {
        let sub = make_sub(1, "session1", "/lumen/scene/*/layer/*/opacity", vec![]);

        assert!(sub.matches("/lumen/scene/0/layer/3/opacity", None));
        assert!(!sub.matches("/lumen/scene/0/opacity", None));
    }

    #[test]
    fn test_manager() {
        let manager = manager_with(&[(1, "session1", "/test/**")]);

        let subscribers = manager.find_subscribers("/test/foo/bar", None);
        assert!(subscribers.contains(&sid("session1")));
    }

    #[test]
    fn test_root_globstar_subscription() {
        let manager = manager_with(&[(1, "session1", "/**")]);

        let cases = [
            ("/a/b/c", "/** should match /a/b/c"),
            ("/foo", "/** should match /foo"),
            ("/deeply/nested/path/here", "/** should match deeply nested paths"),
        ];
        for (address, why) in cases {
            assert!(
                manager.find_subscribers(address, None).contains(&sid("session1")),
                "{}",
                why
            );
        }
    }

    #[test]
    fn test_multiple_globstar_patterns() {
        let manager = manager_with(&[
            (1, "global", "/**"),
            (2, "lumen", "/lumen/**"),
            (3, "other", "/other/**"),
        ]);

        let subscribers = manager.find_subscribers("/lumen/scene/0", None);
        assert!(subscribers.contains(&sid("global")));
        assert!(subscribers.contains(&sid("lumen")));
        assert!(!subscribers.contains(&sid("other")));

        let subscribers = manager.find_subscribers("/other/data", None);
        assert!(subscribers.contains(&sid("global")));
        assert!(subscribers.contains(&sid("other")));
        assert!(!subscribers.contains(&sid("lumen")));
    }

    #[test]
    fn test_remove_cleans_up_by_prefix() {
        let manager = manager_with(&[(1, "session1", "/test/**")]);
        assert_eq!(manager.len(), 1);

        assert!(manager.remove(&sid("session1"), 1).is_some());
        assert_eq!(manager.len(), 0);

        manager.add(make_sub(2, "session2", "/test/**", vec![]));
        let subscribers = manager.find_subscribers("/test/foo", None);
        assert_eq!(subscribers.len(), 1);
        assert!(subscribers.contains(&sid("session2")));
    }

    #[test]
    fn test_remove_session_cleans_up_by_prefix() {
        let manager = manager_with(&[
            (1, "session1", "/test/**"),
            (2, "session1", "/other/**"),
            (1, "session2", "/test/**"),
        ]);
        assert_eq!(manager.len(), 3);

        manager.remove_session(&sid("session1"));
        assert_eq!(manager.len(), 1);

        let subscribers = manager.find_subscribers("/test/foo", None);
        assert_eq!(subscribers.len(), 1);
        assert!(subscribers.contains(&sid("session2")));

        assert_eq!(manager.find_subscribers("/other/foo", None).len(), 0);
    }

    #[test]
    fn test_exact_address_match() {
        let manager = manager_with(&[(1, "s1", "/chat/room/abc/messages")]);
        let count = |address: &str| manager.find_subscribers(address, None).len();

        assert_eq!(count("/chat/room/abc/messages"), 1);
        assert_eq!(count("/chat/room/xyz/messages"), 0);
        assert_eq!(count("/chat/room/abc"), 0);
    }

    #[test]
    fn test_single_wildcard() {
        let manager = manager_with(&[(1, "s1", "/chat/room/*/messages")]);
        let count = |address: &str| manager.find_subscribers(address, None).len();

        assert_eq!(count("/chat/room/abc/messages"), 1);
        assert_eq!(count("/chat/room/xyz/messages"), 1);
        // `*` spans exactly one segment, never two.
        assert_eq!(count("/chat/room/a/b/messages"), 0);
    }

    #[test]
    fn test_globstar_matches_zero_segments() {
        let manager = manager_with(&[(1, "s1", "/chat/**")]);
        let count = |address: &str| manager.find_subscribers(address, None).len();

        assert_eq!(count("/chat"), 1);
        assert_eq!(count("/chat/room"), 1);
        assert_eq!(count("/chat/room/abc/messages"), 1);
    }

    #[test]
    fn test_signal_type_filtering() {
        let manager = SubscriptionManager::new();
        manager.add(make_sub(1, "s1", "/data/**", vec![SignalType::Param]));
        manager.add(make_sub(1, "s2", "/data/**", vec![SignalType::Event]));
        manager.add(make_sub(1, "s3", "/data/**", vec![]));

        let param_subs = manager.find_subscribers("/data/x", Some(SignalType::Param));
        assert!(param_subs.contains(&sid("s1")));
        assert!(!param_subs.contains(&sid("s2")));
        assert!(param_subs.contains(&sid("s3")));

        let event_subs = manager.find_subscribers("/data/x", Some(SignalType::Event));
        assert!(!event_subs.contains(&sid("s1")));
        assert!(event_subs.contains(&sid("s2")));
        assert!(event_subs.contains(&sid("s3")));

        // With no signal type given, every subscriber matches.
        assert_eq!(manager.find_subscribers("/data/x", None).len(), 3);
    }

    #[test]
    fn test_multiple_wildcards_in_pattern() {
        let manager = manager_with(&[(1, "s1", "/scene/*/layer/*/opacity")]);
        let count = |address: &str| manager.find_subscribers(address, None).len();

        assert_eq!(count("/scene/0/layer/3/opacity"), 1);
        assert_eq!(count("/scene/main/layer/bg/opacity"), 1);
        assert_eq!(count("/scene/0/layer/3/color"), 0);
    }

    #[test]
    fn test_overlapping_patterns() {
        let manager = manager_with(&[
            (1, "exact", "/chat/room/abc/messages"),
            (1, "wild", "/chat/room/*/messages"),
            (1, "glob", "/chat/**"),
            (1, "root", "/**"),
        ]);

        let subs = manager.find_subscribers("/chat/room/abc/messages", None);
        assert_eq!(subs.len(), 4, "All four patterns should match");
        for name in ["exact", "wild", "glob", "root"] {
            assert!(subs.contains(&sid(name)));
        }
    }

    #[test]
    fn test_trie_prunes_empty_nodes() {
        let manager = manager_with(&[(1, "s1", "/a/b/c"), (2, "s1", "/a/b/d")]);

        manager.remove(&sid("s1"), 1);
        assert_eq!(manager.len(), 1);
        assert_eq!(manager.find_subscribers("/a/b/d", None).len(), 1);
        assert_eq!(manager.find_subscribers("/a/b/c", None).len(), 0);

        manager.remove(&sid("s1"), 2);
        assert_eq!(manager.len(), 0);
        assert!(manager.is_empty());
    }
}