adapter_memory/
lib.rs

1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::Arc;
5use thiserror::Error;
6use tokio::sync::{broadcast, RwLock};
7use tracing::{debug, info};
8
9pub type Result<T> = std::result::Result<T, AdapterError>;
10
11#[derive(Error, Debug)]
12pub enum AdapterError {
13    #[error("Channel error: {0}")]
14    ChannelError(String),
15
16    #[error("Topic not found: {0}")]
17    TopicNotFound(String),
18
19    #[error("Serialization error: {0}")]
20    Serialization(#[from] serde_json::Error),
21}
22
23/// Message envelope
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct Message {
26    pub topic: String,
27    pub payload: serde_json::Value,
28    pub timestamp: String,
29    pub metadata: HashMap<String, String>,
30}
31
32impl Message {
33    pub fn new(topic: impl Into<String>, payload: serde_json::Value) -> Self {
34        use std::time::SystemTime;
35        Self {
36            topic: topic.into(),
37            payload,
38            timestamp: SystemTime::now()
39                .duration_since(SystemTime::UNIX_EPOCH)
40                .unwrap()
41                .as_secs()
42                .to_string(),
43            metadata: HashMap::new(),
44        }
45    }
46
47    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
48        self.metadata.insert(key.into(), value.into());
49        self
50    }
51}
52
53/// Message handler trait
54#[async_trait]
55pub trait MessageHandler: Send + Sync {
56    async fn handle(&self, message: Message) -> Result<()>;
57}
58
59/// Memory-based message broker
60pub struct MemoryAdapter {
61    channels: Arc<RwLock<HashMap<String, broadcast::Sender<Message>>>>,
62    buffer_size: usize,
63}
64
65impl MemoryAdapter {
66    pub fn new(buffer_size: usize) -> Self {
67        Self {
68            channels: Arc::new(RwLock::new(HashMap::new())),
69            buffer_size,
70        }
71    }
72
73    /// Create or get a channel for a topic
74    async fn get_or_create_channel(&self, topic: &str) -> broadcast::Sender<Message> {
75        let mut channels = self.channels.write().await;
76
77        if let Some(sender) = channels.get(topic) {
78            sender.clone()
79        } else {
80            let (sender, _) = broadcast::channel(self.buffer_size);
81            channels.insert(topic.to_string(), sender.clone());
82            info!("Created channel for topic: {}", topic);
83            sender
84        }
85    }
86
87    /// Publish a message to a topic
88    pub async fn publish(
89        &self,
90        topic: impl Into<String>,
91        payload: serde_json::Value,
92    ) -> Result<()> {
93        let topic = topic.into();
94        let message = Message::new(topic.clone(), payload);
95
96        let sender = self.get_or_create_channel(&topic).await;
97
98        sender
99            .send(message)
100            .map_err(|e| AdapterError::ChannelError(format!("Failed to publish: {}", e)))?;
101
102        debug!("Published message to topic: {}", topic);
103        Ok(())
104    }
105
106    /// Subscribe to a topic with a handler
107    pub async fn subscribe<H>(&self, topic: impl Into<String>, handler: Arc<H>) -> Result<()>
108    where
109        H: MessageHandler + 'static,
110    {
111        let topic = topic.into();
112        let sender = self.get_or_create_channel(&topic).await;
113        let mut receiver = sender.subscribe();
114
115        info!("Subscribed to topic: {}", topic);
116
117        tokio::spawn(async move {
118            while let Ok(message) = receiver.recv().await {
119                if let Err(e) = handler.handle(message).await {
120                    tracing::error!("Handler error: {}", e);
121                }
122            }
123        });
124
125        Ok(())
126    }
127
128    /// Subscribe with a closure
129    pub async fn subscribe_fn<F, Fut>(&self, topic: impl Into<String>, handler: F) -> Result<()>
130    where
131        F: Fn(Message) -> Fut + Send + Sync + 'static,
132        Fut: std::future::Future<Output = Result<()>> + Send + 'static,
133    {
134        struct ClosureHandler<F, Fut>
135        where
136            F: Fn(Message) -> Fut + Send + Sync,
137            Fut: std::future::Future<Output = Result<()>> + Send,
138        {
139            func: F,
140        }
141
142        #[async_trait]
143        impl<F, Fut> MessageHandler for ClosureHandler<F, Fut>
144        where
145            F: Fn(Message) -> Fut + Send + Sync,
146            Fut: std::future::Future<Output = Result<()>> + Send,
147        {
148            async fn handle(&self, message: Message) -> Result<()> {
149                (self.func)(message).await
150            }
151        }
152
153        let handler = Arc::new(ClosureHandler { func: handler });
154        self.subscribe(topic, handler).await
155    }
156
157    /// Get list of all topics
158    pub async fn list_topics(&self) -> Vec<String> {
159        let channels = self.channels.read().await;
160        channels.keys().cloned().collect()
161    }
162
163    /// Get subscriber count for a topic
164    pub async fn subscriber_count(&self, topic: &str) -> usize {
165        let channels = self.channels.read().await;
166        channels
167            .get(topic)
168            .map(|sender| sender.receiver_count())
169            .unwrap_or(0)
170    }
171}
172
173impl Default for MemoryAdapter {
174    fn default() -> Self {
175        Self::new(1000)
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;
    use tokio::time::{timeout, Duration};

    /// Upper bound on how long to wait for an asynchronous delivery before failing the test.
    const DELIVERY_TIMEOUT: Duration = Duration::from_secs(1);

    #[tokio::test]
    async fn test_publish_subscribe() {
        let adapter = MemoryAdapter::new(10);

        // The handler forwards payloads over a channel, so the test awaits delivery
        // deterministically instead of sleeping and hoping the spawned task has run.
        let (tx, mut rx) = mpsc::unbounded_channel();

        adapter
            .subscribe_fn("test_topic", move |msg| {
                let tx = tx.clone();
                async move {
                    tx.send(msg.payload)
                        .map_err(|e| AdapterError::ChannelError(e.to_string()))
                }
            })
            .await
            .unwrap();

        // No sleep needed here: `subscribe` registers its receiver before returning.
        adapter
            .publish("test_topic", serde_json::json!({"value": 42}))
            .await
            .unwrap();

        let payload = timeout(DELIVERY_TIMEOUT, rx.recv())
            .await
            .expect("message was not delivered in time")
            .expect("handler channel closed unexpectedly");
        assert_eq!(payload["value"], 42);

        // Exactly one message was published, so nothing else should be queued.
        assert!(rx.try_recv().is_err());
    }
}
215}