1use async_trait::async_trait;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::Arc;
5use thiserror::Error;
6use tokio::sync::{broadcast, RwLock};
7use tracing::{debug, info};
8
9pub type Result<T> = std::result::Result<T, AdapterError>;
10
11#[derive(Error, Debug)]
12pub enum AdapterError {
13 #[error("Channel error: {0}")]
14 ChannelError(String),
15
16 #[error("Topic not found: {0}")]
17 TopicNotFound(String),
18
19 #[error("Serialization error: {0}")]
20 Serialization(#[from] serde_json::Error),
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct Message {
26 pub topic: String,
27 pub payload: serde_json::Value,
28 pub timestamp: String,
29 pub metadata: HashMap<String, String>,
30}
31
32impl Message {
33 pub fn new(topic: impl Into<String>, payload: serde_json::Value) -> Self {
34 use std::time::SystemTime;
35 Self {
36 topic: topic.into(),
37 payload,
38 timestamp: SystemTime::now()
39 .duration_since(SystemTime::UNIX_EPOCH)
40 .unwrap()
41 .as_secs()
42 .to_string(),
43 metadata: HashMap::new(),
44 }
45 }
46
47 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
48 self.metadata.insert(key.into(), value.into());
49 self
50 }
51}
52
53#[async_trait]
55pub trait MessageHandler: Send + Sync {
56 async fn handle(&self, message: Message) -> Result<()>;
57}
58
59pub struct MemoryAdapter {
61 channels: Arc<RwLock<HashMap<String, broadcast::Sender<Message>>>>,
62 buffer_size: usize,
63}
64
65impl MemoryAdapter {
66 pub fn new(buffer_size: usize) -> Self {
67 Self {
68 channels: Arc::new(RwLock::new(HashMap::new())),
69 buffer_size,
70 }
71 }
72
73 async fn get_or_create_channel(&self, topic: &str) -> broadcast::Sender<Message> {
75 let mut channels = self.channels.write().await;
76
77 if let Some(sender) = channels.get(topic) {
78 sender.clone()
79 } else {
80 let (sender, _) = broadcast::channel(self.buffer_size);
81 channels.insert(topic.to_string(), sender.clone());
82 info!("Created channel for topic: {}", topic);
83 sender
84 }
85 }
86
87 pub async fn publish(
89 &self,
90 topic: impl Into<String>,
91 payload: serde_json::Value,
92 ) -> Result<()> {
93 let topic = topic.into();
94 let message = Message::new(topic.clone(), payload);
95
96 let sender = self.get_or_create_channel(&topic).await;
97
98 sender
99 .send(message)
100 .map_err(|e| AdapterError::ChannelError(format!("Failed to publish: {}", e)))?;
101
102 debug!("Published message to topic: {}", topic);
103 Ok(())
104 }
105
106 pub async fn subscribe<H>(&self, topic: impl Into<String>, handler: Arc<H>) -> Result<()>
108 where
109 H: MessageHandler + 'static,
110 {
111 let topic = topic.into();
112 let sender = self.get_or_create_channel(&topic).await;
113 let mut receiver = sender.subscribe();
114
115 info!("Subscribed to topic: {}", topic);
116
117 tokio::spawn(async move {
118 while let Ok(message) = receiver.recv().await {
119 if let Err(e) = handler.handle(message).await {
120 tracing::error!("Handler error: {}", e);
121 }
122 }
123 });
124
125 Ok(())
126 }
127
128 pub async fn subscribe_fn<F, Fut>(&self, topic: impl Into<String>, handler: F) -> Result<()>
130 where
131 F: Fn(Message) -> Fut + Send + Sync + 'static,
132 Fut: std::future::Future<Output = Result<()>> + Send + 'static,
133 {
134 struct ClosureHandler<F, Fut>
135 where
136 F: Fn(Message) -> Fut + Send + Sync,
137 Fut: std::future::Future<Output = Result<()>> + Send,
138 {
139 func: F,
140 }
141
142 #[async_trait]
143 impl<F, Fut> MessageHandler for ClosureHandler<F, Fut>
144 where
145 F: Fn(Message) -> Fut + Send + Sync,
146 Fut: std::future::Future<Output = Result<()>> + Send,
147 {
148 async fn handle(&self, message: Message) -> Result<()> {
149 (self.func)(message).await
150 }
151 }
152
153 let handler = Arc::new(ClosureHandler { func: handler });
154 self.subscribe(topic, handler).await
155 }
156
157 pub async fn list_topics(&self) -> Vec<String> {
159 let channels = self.channels.read().await;
160 channels.keys().cloned().collect()
161 }
162
163 pub async fn subscriber_count(&self, topic: &str) -> usize {
165 let channels = self.channels.read().await;
166 channels
167 .get(topic)
168 .map(|sender| sender.receiver_count())
169 .unwrap_or(0)
170 }
171}
172
173impl Default for MemoryAdapter {
174 fn default() -> Self {
175 Self::new(1000)
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;
    use tokio::time::{timeout, Duration};

    /// Upper bound on how long a test waits for a handler to run. It is generous so the
    /// test stays reliable on a loaded machine; a passing run does not wait this long.
    const DELIVERY_TIMEOUT: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn test_publish_subscribe() {
        let adapter = MemoryAdapter::new(10);

        // The handler forwards each payload over a channel so the test can await
        // delivery deterministically instead of sleeping and hoping the task ran.
        let (tx, mut rx) = mpsc::unbounded_channel();
        adapter
            .subscribe_fn("test_topic", move |msg| {
                let tx = tx.clone();
                async move {
                    tx.send(msg.payload)
                        .map_err(|e| AdapterError::ChannelError(e.to_string()))
                }
            })
            .await
            .unwrap();

        // `subscribe` registers its receiver before returning, so publishing right
        // away cannot miss the subscriber.
        adapter
            .publish("test_topic", serde_json::json!({"value": 42}))
            .await
            .unwrap();

        let payload = timeout(DELIVERY_TIMEOUT, rx.recv())
            .await
            .expect("handler was not invoked in time")
            .expect("handler channel closed unexpectedly");
        assert_eq!(payload["value"], 42);
        // Exactly one message was published, so nothing else may be queued.
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_topics_and_subscriber_count() {
        let adapter = MemoryAdapter::new(10);
        assert!(adapter.list_topics().await.is_empty());
        assert_eq!(adapter.subscriber_count("test_topic").await, 0);

        adapter
            .subscribe_fn("test_topic", |_msg| async { Ok::<(), AdapterError>(()) })
            .await
            .unwrap();

        assert_eq!(adapter.list_topics().await, vec!["test_topic".to_string()]);
        assert_eq!(adapter.subscriber_count("test_topic").await, 1);
    }
}