nt_memory/coordination/
pubsub.rs1use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::{mpsc, RwLock};
6
/// Raw message payload carried over the broker: an owned, opaque byte buffer.
///
/// Every subscriber of a topic receives its own clone of the published payload.
pub type Message = Vec<u8>;
9
10pub struct Subscription {
12 pub receiver: mpsc::Receiver<Message>,
14}
15
16pub struct PubSubBroker {
18 subscribers: Arc<RwLock<HashMap<String, Vec<mpsc::Sender<Message>>>>>,
20
21 buffer_size: usize,
23}
24
25impl PubSubBroker {
26 pub fn new() -> Self {
28 Self {
29 subscribers: Arc::new(RwLock::new(HashMap::new())),
30 buffer_size: 1000,
31 }
32 }
33
34 pub fn with_buffer_size(mut self, size: usize) -> Self {
36 self.buffer_size = size;
37 self
38 }
39
40 pub async fn subscribe(&self, topic: &str) -> anyhow::Result<mpsc::Receiver<Message>> {
42 let (tx, rx) = mpsc::channel(self.buffer_size);
43
44 let mut subscribers = self.subscribers.write().await;
45 subscribers
46 .entry(topic.to_string())
47 .or_insert_with(Vec::new)
48 .push(tx);
49
50 tracing::debug!("Subscribed to topic: {}", topic);
51
52 Ok(rx)
53 }
54
55 pub async fn publish(&self, topic: &str, message: Message) -> anyhow::Result<()> {
57 let subscribers = self.subscribers.read().await;
58
59 if let Some(subs) = subscribers.get(topic) {
60 let mut sent = 0;
61 let mut failed = 0;
62
63 for sender in subs {
64 match sender.try_send(message.clone()) {
65 Ok(()) => sent += 1,
66 Err(e) => {
67 tracing::warn!("Failed to send message: {:?}", e);
68 failed += 1;
69 }
70 }
71 }
72
73 tracing::debug!(
74 "Published to {}: {} sent, {} failed",
75 topic,
76 sent,
77 failed
78 );
79 } else {
80 tracing::debug!("No subscribers for topic: {}", topic);
81 }
82
83 Ok(())
84 }
85
86 pub async fn unsubscribe_all(&self, topic: &str) {
88 let mut subscribers = self.subscribers.write().await;
89 subscribers.remove(topic);
90
91 tracing::debug!("Unsubscribed all from topic: {}", topic);
92 }
93
94 pub async fn subscriber_count(&self, topic: &str) -> usize {
96 let subscribers = self.subscribers.read().await;
97 subscribers.get(topic).map(|s| s.len()).unwrap_or(0)
98 }
99
100 pub async fn list_topics(&self) -> Vec<String> {
102 let subscribers = self.subscribers.read().await;
103 subscribers.keys().cloned().collect()
104 }
105
106 pub async fn clear(&self) {
108 let mut subscribers = self.subscribers.write().await;
109 subscribers.clear();
110
111 tracing::debug!("Cleared all subscriptions");
112 }
113}
114
115impl Default for PubSubBroker {
116 fn default() -> Self {
117 Self::new()
118 }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_pubsub_basic() {
        let broker = PubSubBroker::new();

        let mut rx = broker.subscribe("test_topic").await.unwrap();

        let message = b"test message".to_vec();
        broker.publish("test_topic", message.clone()).await.unwrap();

        let received = rx.recv().await.unwrap();
        assert_eq!(received, message);
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let broker = PubSubBroker::new();

        let mut rx1 = broker.subscribe("topic").await.unwrap();
        let mut rx2 = broker.subscribe("topic").await.unwrap();
        let mut rx3 = broker.subscribe("topic").await.unwrap();

        assert_eq!(broker.subscriber_count("topic").await, 3);

        let message = b"broadcast".to_vec();
        broker.publish("topic", message.clone()).await.unwrap();

        assert_eq!(rx1.recv().await.unwrap(), message);
        assert_eq!(rx2.recv().await.unwrap(), message);
        assert_eq!(rx3.recv().await.unwrap(), message);
    }

    #[tokio::test]
    async fn test_topic_isolation() {
        let broker = PubSubBroker::new();

        let mut rx1 = broker.subscribe("topic1").await.unwrap();
        let mut rx2 = broker.subscribe("topic2").await.unwrap();

        broker.publish("topic1", b"message1".to_vec()).await.unwrap();

        assert_eq!(rx1.recv().await.unwrap(), b"message1");

        // `publish` enqueues synchronously via `try_send`, so a misrouted message would
        // already be sitting in rx2's buffer; no need to wait on a timer.
        assert!(matches!(
            rx2.try_recv(),
            Err(tokio::sync::mpsc::error::TryRecvError::Empty)
        ));
    }

    #[tokio::test]
    async fn test_unsubscribe() {
        let broker = PubSubBroker::new();

        let _rx = broker.subscribe("topic").await.unwrap();
        assert_eq!(broker.subscriber_count("topic").await, 1);

        broker.unsubscribe_all("topic").await;
        assert_eq!(broker.subscriber_count("topic").await, 0);
    }

    #[tokio::test]
    async fn test_publish_without_subscribers_is_ok() {
        let broker = PubSubBroker::new();

        broker.publish("nobody", b"x".to_vec()).await.unwrap();

        assert_eq!(broker.subscriber_count("nobody").await, 0);
        assert!(broker.list_topics().await.is_empty());
    }

    #[tokio::test]
    async fn test_list_topics_and_clear() {
        let broker = PubSubBroker::new();

        let _a = broker.subscribe("a").await.unwrap();
        let _b = broker.subscribe("b").await.unwrap();

        let mut topics = broker.list_topics().await;
        topics.sort();
        assert_eq!(topics, vec!["a".to_string(), "b".to_string()]);

        broker.clear().await;
        assert!(broker.list_topics().await.is_empty());
    }

    #[tokio::test]
    async fn test_full_buffer_drops_message() {
        let broker = PubSubBroker::new().with_buffer_size(1);

        let mut rx = broker.subscribe("topic").await.unwrap();

        broker.publish("topic", b"first".to_vec()).await.unwrap();
        // The single-slot buffer is occupied, so this one is dropped, not queued.
        broker.publish("topic", b"second".to_vec()).await.unwrap();

        assert_eq!(rx.recv().await.unwrap(), b"first");
        assert!(matches!(
            rx.try_recv(),
            Err(tokio::sync::mpsc::error::TryRecvError::Empty)
        ));
    }
}