1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
use async_recursion::async_recursion;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tokio::sync::RwLock;
/// One level of the subscription trie: outgoing edges keyed by the next topic
/// byte, plus a counter of subscriptions whose topic ends at this node.
#[derive(Default, Debug)]
struct TrieNode {
    /// Child nodes, keyed by the next byte of the topic; each child is shared
    /// and individually lockable.
    children: HashMap<u8, Arc<RwLock<TrieNode>>>,
    /// Count of subscriptions ending exactly at this node.
    count: AtomicUsize,
}
/// Manages topic subscriptions using a prefix trie for efficient matching.
///
/// Holds only the root node; every node in the trie is wrapped in
/// `Arc<RwLock<_>>` so it can be shared and locked independently.
#[derive(Default, Debug)]
pub(crate) struct SubscriptionTrie {
    /// Root node; corresponds to the empty topic prefix.
    root: Arc<RwLock<TrieNode>>,
}
impl SubscriptionTrie {
    /// Creates an empty trie with no subscriptions.
    pub fn new() -> Self {
        Self::default()
    }

    /// Retrieves all currently active subscription topics.
    ///
    /// Returns one `Vec<u8>` per trie node whose subscription count is
    /// non-zero, i.e. one entry per distinct subscribed prefix (a topic that
    /// was subscribed several times appears once). The order of the result is
    /// unspecified because children are stored in a `HashMap`.
    pub async fn get_all_topics(&self) -> Vec<Vec<u8>> {
        let mut topics = Vec::new();
        // Start recursive traversal from the root with an empty prefix.
        self.collect_topics_recursive(self.root.clone(), Vec::new(), &mut topics)
            .await;
        topics
    }

    /// Adds a subscription topic (prefix).
    ///
    /// Missing nodes along the path are created. The count on the node for the
    /// full topic is then incremented, so subscribing to the same topic twice
    /// requires two `unsubscribe` calls before it stops matching.
    pub async fn subscribe(&self, topic: &[u8]) {
        let mut current_node_arc = self.root.clone();
        for &byte in topic {
            // A write lock is taken because the child may need to be inserted.
            let mut current_node_w = current_node_arc.write().await;
            let next_node_arc = current_node_w
                .children
                .entry(byte)
                .or_insert_with(|| Arc::new(RwLock::new(TrieNode::default())))
                .clone();
            // Release this node's lock before descending, so only one node is
            // locked at a time.
            drop(current_node_w);
            current_node_arc = next_node_arc;
        }
        // `count` is atomic, so a read lock is sufficient to increment it.
        let final_node_r = current_node_arc.read().await;
        final_node_r.count.fetch_add(1, Ordering::Relaxed);
        tracing::debug!(topic = ?String::from_utf8_lossy(topic), "Subscribed");
    }

    /// Removes one subscription for `topic` (prefix).
    ///
    /// Returns `true` only when this call removed the last subscription for
    /// exactly this topic (its count went from 1 to 0). Returns `false` when:
    /// - the topic's path does not exist in the trie,
    /// - the count is still non-zero after the decrement, or
    /// - the count was already zero, in which case it is left unchanged.
    ///
    /// Note: This basic version doesn't prune empty nodes.
    pub async fn unsubscribe(&self, topic: &[u8]) -> bool {
        let mut current_node_arc = self.root.clone();
        for &byte in topic {
            let next_node_option = {
                let current_node_r = current_node_arc.read().await;
                current_node_r.children.get(&byte).cloned()
            }; // Read guard is released here, before `current_node_arc` is reassigned.
            match next_node_option {
                Some(next_node_arc) => current_node_arc = next_node_arc,
                None => {
                    tracing::debug!(topic = ?String::from_utf8_lossy(topic), "Unsubscribe failed: Topic prefix not found");
                    return false;
                }
            }
        }

        // Decrement only if the count is non-zero. A plain `fetch_sub` on a zero
        // count would wrap to `usize::MAX` and, until undone, let concurrent
        // `matches` / `get_all_topics` calls observe a phantom subscription.
        let decrement_result = {
            let final_node_r = current_node_arc.read().await;
            final_node_r
                .count
                .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |count| {
                    count.checked_sub(1)
                })
        };

        match decrement_result {
            Ok(old_count) => {
                // `old_count >= 1` is guaranteed here, so the subtraction cannot underflow.
                tracing::debug!(topic = ?String::from_utf8_lossy(topic), new_count = old_count - 1, "Unsubscribed");
                // TODO: Implement pruning of the trie if old_count was 1 and the node has
                // no children. Pruning requires write locks back up the path and careful
                // handling of races with concurrent `subscribe` calls.
                old_count == 1 // True if this was the last subscription for this exact topic.
            }
            Err(_) => {
                tracing::warn!(topic = ?String::from_utf8_lossy(topic), "Unsubscribe attempt on topic with zero count");
                false
            }
        }
    }

    /// Checks if a given message topic matches *any* active subscription prefix.
    ///
    /// A subscription matches when its topic is a prefix of `message_topic`
    /// (including the empty subscription and an exact match). The walk stops
    /// at the first node with a non-zero count, or when no child exists for
    /// the next byte.
    pub async fn matches(&self, message_topic: &[u8]) -> bool {
        let mut current_node_arc = self.root.clone();
        // Each iteration tests the node for the prefix consumed so far (starting
        // with the root / empty prefix), then descends along the next byte.
        for &byte in message_topic {
            let (matched_prefix, next_node_option) = {
                let current_node_r = current_node_arc.read().await;
                (
                    current_node_r.count.load(Ordering::Relaxed) > 0,
                    current_node_r.children.get(&byte).cloned(),
                )
            }; // Read guard is released here, before `current_node_arc` is reassigned.
            if matched_prefix {
                return true; // Matched a subscription prefix during traversal.
            }
            match next_node_option {
                Some(next_node_arc) => current_node_arc = next_node_arc,
                None => return false, // No subscription path continues with this byte.
            }
        }
        // The whole message topic was consumed; the final node matches if a
        // subscription ends exactly here. For an empty topic this is the root.
        let final_node_r = current_node_arc.read().await;
        final_node_r.count.load(Ordering::Relaxed) > 0
    }

    /// Recursive helper that appends to `all_topics` the prefix of every node
    /// under `node_arc` (inclusive) whose subscription count is non-zero.
    ///
    /// `current_prefix` is the byte path from the root to `node_arc`.
    #[async_recursion]
    async fn collect_topics_recursive(
        &self,
        node_arc: Arc<RwLock<TrieNode>>,
        current_prefix: Vec<u8>,
        all_topics: &mut Vec<Vec<u8>>,
    ) {
        let node_read_guard = node_arc.read().await;
        if node_read_guard.count.load(Ordering::Relaxed) > 0 {
            all_topics.push(current_prefix.clone());
        }
        // Snapshot the children (cheap `Arc` clones) so the read lock can be
        // released before recursing; this keeps at most one lock held per level.
        let children_to_visit: Vec<(u8, Arc<RwLock<TrieNode>>)> = node_read_guard
            .children
            .iter()
            .map(|(byte, child_arc)| (*byte, child_arc.clone()))
            .collect();
        drop(node_read_guard);

        for (byte, child_node_arc) in children_to_visit {
            let mut next_prefix = current_prefix.clone();
            next_prefix.push(byte);
            self.collect_topics_recursive(child_node_arc, next_prefix, all_topics)
                .await;
        }
    }
}