1use devsper_core::{Bus, BusMessage};
2use anyhow::Result;
3use async_trait::async_trait;
4use std::collections::HashMap;
5use std::pin::Pin;
6use std::future::Future;
7use std::sync::Arc;
8use tokio::sync::{broadcast, RwLock};
9use tracing::debug;
10
11pub struct InMemoryBus {
14 channels: Arc<RwLock<HashMap<String, broadcast::Sender<BusMessage>>>>,
16}
17
18impl InMemoryBus {
19 pub fn new() -> Self {
20 Self {
21 channels: Arc::new(RwLock::new(HashMap::new())),
22 }
23 }
24
25 async fn get_or_create_sender(&self, topic: &str) -> broadcast::Sender<BusMessage> {
26 {
27 let channels = self.channels.read().await;
28 if let Some(tx) = channels.get(topic) {
29 return tx.clone();
30 }
31 }
32 let mut channels = self.channels.write().await;
33 if let Some(tx) = channels.get(topic) {
35 return tx.clone();
36 }
37 let (tx, _) = broadcast::channel(1024);
38 channels.insert(topic.to_string(), tx.clone());
39 tx
40 }
41}
42
43impl Default for InMemoryBus {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49#[async_trait]
50impl Bus for InMemoryBus {
51 async fn publish(&self, msg: BusMessage) -> Result<()> {
52 let tx = self.get_or_create_sender(&msg.topic).await;
53 debug!(topic = %msg.topic, "Bus publish");
54 let _ = tx.send(msg);
56 Ok(())
57 }
58
59 async fn subscribe(
60 &self,
61 topic: &str,
62 handler: Box<
63 dyn Fn(BusMessage) -> Pin<Box<dyn Future<Output = ()> + Send>>
64 + Send + Sync,
65 >,
66 ) -> Result<()> {
67 let tx = self.get_or_create_sender(topic).await;
68 let mut rx = tx.subscribe();
69 let handler = Arc::new(handler);
70
71 tokio::spawn(async move {
72 while let Ok(msg) = rx.recv().await {
73 handler(msg).await;
74 }
75 });
76
77 Ok(())
78 }
79
80 async fn start(&self) -> Result<()> {
81 Ok(()) }
83
84 async fn stop(&self) -> Result<()> {
85 Ok(()) }
87}
88
#[cfg(test)]
mod tests {
    use super::*;
    use devsper_core::RunId;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::time::{sleep, Duration, Instant};

    /// Upper bound on how long a test waits for asynchronous delivery.
    const DELIVERY_TIMEOUT: Duration = Duration::from_secs(1);

    /// Polls `counter` until it reaches `expected` or `DELIVERY_TIMEOUT` elapses.
    ///
    /// Handlers run on spawned tasks, so delivery is asynchronous. Polling
    /// replaces a fixed sleep, which is slow in the common case and can be too
    /// short on a loaded machine.
    async fn wait_for_count(counter: &AtomicUsize, expected: usize) {
        let deadline = Instant::now() + DELIVERY_TIMEOUT;
        while counter.load(Ordering::SeqCst) < expected && Instant::now() < deadline {
            sleep(Duration::from_millis(1)).await;
        }
    }

    /// Builds a subscription handler that increments `counter` once per message.
    fn counting_handler(
        counter: Arc<AtomicUsize>,
    ) -> Box<dyn Fn(BusMessage) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync> {
        Box::new(move |_msg: BusMessage| {
            let c = counter.clone();
            Box::pin(async move {
                c.fetch_add(1, Ordering::SeqCst);
            })
        })
    }

    #[tokio::test]
    async fn publish_subscribe_roundtrip() {
        let bus = InMemoryBus::new();
        let counter = Arc::new(AtomicUsize::new(0));

        bus.subscribe("test.topic", counting_handler(counter.clone()))
            .await
            .unwrap();

        let msg = BusMessage::new(RunId::new(), "test.topic", serde_json::json!({"x": 1}));
        bus.publish(msg).await.unwrap();

        wait_for_count(&counter, 1).await;
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn multiple_subscribers_all_receive() {
        let bus = InMemoryBus::new();
        let c1 = Arc::new(AtomicUsize::new(0));
        let c2 = Arc::new(AtomicUsize::new(0));

        bus.subscribe("shared", counting_handler(c1.clone()))
            .await
            .unwrap();
        bus.subscribe("shared", counting_handler(c2.clone()))
            .await
            .unwrap();

        bus.publish(BusMessage::new(RunId::new(), "shared", serde_json::json!(null)))
            .await
            .unwrap();

        wait_for_count(&c1, 1).await;
        wait_for_count(&c2, 1).await;
        assert_eq!(c1.load(Ordering::SeqCst), 1);
        assert_eq!(c2.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn publish_without_subscribers_is_ok() {
        let bus = InMemoryBus::new();
        let msg = BusMessage::new(RunId::new(), "nobody.listens", serde_json::json!(null));
        assert!(bus.publish(msg).await.is_ok());
    }
}