enact_core/signal/
inmemory.rs1use super::{SignalBus, SignalReceiver};
16use async_trait::async_trait;
17use std::collections::HashMap;
18use std::sync::Arc;
19use tokio::sync::{broadcast, RwLock};
20
21pub struct InMemorySignalBus {
26 channels: Arc<RwLock<HashMap<String, broadcast::Sender<Vec<u8>>>>>,
27 capacity: usize,
28}
29
30impl InMemorySignalBus {
31 pub fn new(capacity: usize) -> Self {
33 Self {
34 channels: Arc::new(RwLock::new(HashMap::new())),
35 capacity,
36 }
37 }
38
39 #[allow(clippy::should_implement_trait)]
41 pub fn default() -> Self {
42 Self::new(1024)
43 }
44}
45
46#[async_trait]
47impl SignalBus for InMemorySignalBus {
48 async fn emit(&self, channel: &str, signal: &[u8]) -> anyhow::Result<()> {
49 let channels = self.channels.read().await;
50 if let Some(sender) = channels.get(channel) {
51 let _ = sender.send(signal.to_vec());
53 }
54 Ok(())
55 }
56
57 async fn subscribe(&self, channel: &str) -> anyhow::Result<SignalReceiver<Vec<u8>>> {
58 let mut channels = self.channels.write().await;
59 let sender = channels
60 .entry(channel.to_string())
61 .or_insert_with(|| broadcast::channel(self.capacity).0);
62 Ok(sender.subscribe())
63 }
64
65 async fn unsubscribe(&self, _channel: &str) -> anyhow::Result<()> {
66 Ok(())
68 }
69}
70
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_inmemory_signal_bus_new() {
        let bus = InMemorySignalBus::new(100);
        assert_eq!(bus.capacity, 100);
    }

    #[tokio::test]
    async fn test_inmemory_signal_bus_default() {
        let bus = InMemorySignalBus::default();
        assert_eq!(bus.capacity, 1024);
    }

    #[tokio::test]
    async fn test_subscribe_and_receive() {
        let bus = InMemorySignalBus::default();

        let mut rx = bus.subscribe("test-channel").await.unwrap();

        bus.emit("test-channel", b"hello world").await.unwrap();

        let received = rx.recv().await.unwrap();
        assert_eq!(received, b"hello world".to_vec());
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let bus = InMemorySignalBus::default();

        let mut rx1 = bus.subscribe("multi-channel").await.unwrap();
        let mut rx2 = bus.subscribe("multi-channel").await.unwrap();

        bus.emit("multi-channel", b"broadcast").await.unwrap();

        let received1 = rx1.recv().await.unwrap();
        let received2 = rx2.recv().await.unwrap();

        assert_eq!(received1, b"broadcast".to_vec());
        assert_eq!(received2, b"broadcast".to_vec());
    }

    #[tokio::test]
    async fn test_emit_without_subscribers() {
        let bus = InMemorySignalBus::default();

        let result = bus.emit("no-subscribers", b"data").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_emit_before_subscribe_is_not_delivered() {
        let bus = InMemorySignalBus::default();

        bus.emit("pre-subscribe", b"lost").await.unwrap();
        let mut rx = bus.subscribe("pre-subscribe").await.unwrap();

        // Signals are not buffered for channels nobody was listening on.
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_emit_to_different_channels() {
        let bus = InMemorySignalBus::default();

        let mut rx1 = bus.subscribe("channel-a").await.unwrap();
        let mut rx2 = bus.subscribe("channel-b").await.unwrap();

        bus.emit("channel-a", b"msg-a").await.unwrap();
        bus.emit("channel-b", b"msg-b").await.unwrap();

        let received1 = rx1.recv().await.unwrap();
        let received2 = rx2.recv().await.unwrap();

        assert_eq!(received1, b"msg-a".to_vec());
        assert_eq!(received2, b"msg-b".to_vec());

        // Neither receiver may see the other channel's signal.
        assert!(rx1.try_recv().is_err());
        assert!(rx2.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_unsubscribe() {
        let bus = InMemorySignalBus::default();

        let mut rx = bus.subscribe("unsub-channel").await.unwrap();

        bus.unsubscribe("unsub-channel").await.unwrap();

        // Unsubscribing by name must not disconnect a receiver that is still alive.
        bus.emit("unsub-channel", b"still here").await.unwrap();
        assert_eq!(rx.recv().await.unwrap(), b"still here".to_vec());
    }

    #[tokio::test]
    async fn test_unsubscribe_unknown_channel() {
        let bus = InMemorySignalBus::default();

        assert!(bus.unsubscribe("never-subscribed").await.is_ok());
    }

    #[tokio::test]
    async fn test_resubscribe_after_unsubscribe() {
        let bus = InMemorySignalBus::default();

        drop(bus.subscribe("resub-channel").await.unwrap());
        bus.unsubscribe("resub-channel").await.unwrap();

        let mut rx = bus.subscribe("resub-channel").await.unwrap();
        bus.emit("resub-channel", b"again").await.unwrap();

        assert_eq!(rx.recv().await.unwrap(), b"again".to_vec());
    }

    #[tokio::test]
    async fn test_multiple_messages() {
        let bus = InMemorySignalBus::default();

        let mut rx = bus.subscribe("multi-msg").await.unwrap();

        bus.emit("multi-msg", b"first").await.unwrap();
        bus.emit("multi-msg", b"second").await.unwrap();
        bus.emit("multi-msg", b"third").await.unwrap();

        assert_eq!(rx.recv().await.unwrap(), b"first".to_vec());
        assert_eq!(rx.recv().await.unwrap(), b"second".to_vec());
        assert_eq!(rx.recv().await.unwrap(), b"third".to_vec());
        assert!(rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_late_subscriber_misses_messages() {
        let bus = InMemorySignalBus::default();

        let mut rx1 = bus.subscribe("late-sub").await.unwrap();

        bus.emit("late-sub", b"early").await.unwrap();

        let mut rx2 = bus.subscribe("late-sub").await.unwrap();

        bus.emit("late-sub", b"late").await.unwrap();

        assert_eq!(rx1.recv().await.unwrap(), b"early".to_vec());
        assert_eq!(rx1.recv().await.unwrap(), b"late".to_vec());

        // The late subscriber sees only the signal emitted after it subscribed.
        assert_eq!(rx2.recv().await.unwrap(), b"late".to_vec());
        assert!(rx2.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_concurrent_emit() {
        let bus = Arc::new(InMemorySignalBus::default());

        let mut rx = bus.subscribe("concurrent").await.unwrap();

        let bus1 = bus.clone();
        let bus2 = bus.clone();

        let h1 = tokio::spawn(async move {
            for i in 0..5 {
                bus1.emit("concurrent", format!("msg-a-{}", i).as_bytes())
                    .await
                    .unwrap();
            }
        });

        let h2 = tokio::spawn(async move {
            for i in 0..5 {
                bus2.emit("concurrent", format!("msg-b-{}", i).as_bytes())
                    .await
                    .unwrap();
            }
        });

        h1.await.unwrap();
        h2.await.unwrap();

        let mut received = Vec::new();
        while let Ok(msg) = rx.try_recv() {
            received.push(String::from_utf8(msg).unwrap());
        }

        assert_eq!(received.len(), 10);

        // The interleaving between producers is arbitrary, but each producer's
        // own signals must all arrive, in the order that producer emitted them.
        for prefix in ["msg-a-", "msg-b-"] {
            let from_producer: Vec<String> = received
                .iter()
                .filter(|msg| msg.starts_with(prefix))
                .cloned()
                .collect();
            let expected: Vec<String> = (0..5).map(|i| format!("{}{}", prefix, i)).collect();
            assert_eq!(from_producer, expected);
        }
    }
}