nt_memory/coordination/
pubsub.rs

1//! Pub/sub messaging system for agent coordination
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::{mpsc, RwLock};
6
/// Raw byte payload carried from a publisher to each subscriber of a topic.
pub type Message = std::vec::Vec<u8>;
9
10/// Subscription handle
11pub struct Subscription {
12    /// Receiver channel
13    pub receiver: mpsc::Receiver<Message>,
14}
15
16/// Pub/sub broker
17pub struct PubSubBroker {
18    /// Topic subscribers
19    subscribers: Arc<RwLock<HashMap<String, Vec<mpsc::Sender<Message>>>>>,
20
21    /// Channel buffer size
22    buffer_size: usize,
23}
24
25impl PubSubBroker {
26    /// Create new pub/sub broker
27    pub fn new() -> Self {
28        Self {
29            subscribers: Arc::new(RwLock::new(HashMap::new())),
30            buffer_size: 1000,
31        }
32    }
33
34    /// Configure buffer size
35    pub fn with_buffer_size(mut self, size: usize) -> Self {
36        self.buffer_size = size;
37        self
38    }
39
40    /// Subscribe to topic
41    pub async fn subscribe(&self, topic: &str) -> anyhow::Result<mpsc::Receiver<Message>> {
42        let (tx, rx) = mpsc::channel(self.buffer_size);
43
44        let mut subscribers = self.subscribers.write().await;
45        subscribers
46            .entry(topic.to_string())
47            .or_insert_with(Vec::new)
48            .push(tx);
49
50        tracing::debug!("Subscribed to topic: {}", topic);
51
52        Ok(rx)
53    }
54
55    /// Publish message to topic
56    pub async fn publish(&self, topic: &str, message: Message) -> anyhow::Result<()> {
57        let subscribers = self.subscribers.read().await;
58
59        if let Some(subs) = subscribers.get(topic) {
60            let mut sent = 0;
61            let mut failed = 0;
62
63            for sender in subs {
64                match sender.try_send(message.clone()) {
65                    Ok(()) => sent += 1,
66                    Err(e) => {
67                        tracing::warn!("Failed to send message: {:?}", e);
68                        failed += 1;
69                    }
70                }
71            }
72
73            tracing::debug!(
74                "Published to {}: {} sent, {} failed",
75                topic,
76                sent,
77                failed
78            );
79        } else {
80            tracing::debug!("No subscribers for topic: {}", topic);
81        }
82
83        Ok(())
84    }
85
86    /// Unsubscribe all from topic
87    pub async fn unsubscribe_all(&self, topic: &str) {
88        let mut subscribers = self.subscribers.write().await;
89        subscribers.remove(topic);
90
91        tracing::debug!("Unsubscribed all from topic: {}", topic);
92    }
93
94    /// Get topic subscriber count
95    pub async fn subscriber_count(&self, topic: &str) -> usize {
96        let subscribers = self.subscribers.read().await;
97        subscribers.get(topic).map(|s| s.len()).unwrap_or(0)
98    }
99
100    /// List all topics
101    pub async fn list_topics(&self) -> Vec<String> {
102        let subscribers = self.subscribers.read().await;
103        subscribers.keys().cloned().collect()
104    }
105
106    /// Clear all subscriptions
107    pub async fn clear(&self) {
108        let mut subscribers = self.subscribers.write().await;
109        subscribers.clear();
110
111        tracing::debug!("Cleared all subscriptions");
112    }
113}
114
115impl Default for PubSubBroker {
116    fn default() -> Self {
117        Self::new()
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_pubsub_basic() {
        let broker = PubSubBroker::new();

        // Subscribe
        let mut rx = broker.subscribe("test_topic").await.unwrap();

        // Publish
        let message = b"test message".to_vec();
        broker.publish("test_topic", message.clone()).await.unwrap();

        // Receive
        let received = rx.recv().await.unwrap();
        assert_eq!(received, message);
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let broker = PubSubBroker::new();

        // Multiple subscribers
        let mut rx1 = broker.subscribe("topic").await.unwrap();
        let mut rx2 = broker.subscribe("topic").await.unwrap();
        let mut rx3 = broker.subscribe("topic").await.unwrap();

        assert_eq!(broker.subscriber_count("topic").await, 3);

        // Publish
        let message = b"broadcast".to_vec();
        broker.publish("topic", message.clone()).await.unwrap();

        // All should receive
        assert_eq!(rx1.recv().await.unwrap(), message);
        assert_eq!(rx2.recv().await.unwrap(), message);
        assert_eq!(rx3.recv().await.unwrap(), message);
    }

    #[tokio::test]
    async fn test_topic_isolation() {
        let broker = PubSubBroker::new();

        let mut rx1 = broker.subscribe("topic1").await.unwrap();
        let mut rx2 = broker.subscribe("topic2").await.unwrap();

        // Publish to topic1
        broker.publish("topic1", b"message1".to_vec()).await.unwrap();

        // Only rx1 should receive
        assert_eq!(rx1.recv().await.unwrap(), b"message1");

        // `publish` enqueues with `try_send` before it returns, so anything meant
        // for rx2 would already be buffered. An empty channel proves isolation
        // deterministically, without waiting on a timer.
        assert!(matches!(
            rx2.try_recv(),
            Err(tokio::sync::mpsc::error::TryRecvError::Empty)
        ));
    }

    #[tokio::test]
    async fn test_unsubscribe() {
        let broker = PubSubBroker::new();

        let _rx = broker.subscribe("topic").await.unwrap();
        assert_eq!(broker.subscriber_count("topic").await, 1);

        broker.unsubscribe_all("topic").await;
        assert_eq!(broker.subscriber_count("topic").await, 0);
    }
}