use super::{SignalBus, SignalReceiver};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{broadcast, RwLock};
/// In-process [`SignalBus`] implementation backed by `tokio::sync::broadcast` channels.
///
/// Each channel name maps to a single broadcast sender, created lazily by the
/// first `subscribe` call for that name. Signals emitted to a name that has no
/// sender are silently dropped, and a receiver only observes signals sent after
/// it subscribed.
#[derive(Debug)]
pub struct InMemorySignalBus {
    /// Broadcast sender for every channel name that has been subscribed to.
    channels: Arc<RwLock<HashMap<String, broadcast::Sender<Vec<u8>>>>>,
    /// Capacity passed to `broadcast::channel` whenever this bus creates a channel.
    capacity: usize,
}
impl InMemorySignalBus {
    /// Channel buffer capacity used by `InMemorySignalBus::default()`.
    pub const DEFAULT_CAPACITY: usize = 1024;

    /// Creates a bus with no channels whose channels will buffer `capacity` signals.
    ///
    /// `capacity` is only consumed when a channel is first created by `subscribe`.
    /// It must be non-zero and at most `usize::MAX / 2`, as required by
    /// `tokio::sync::broadcast::channel`; it is not validated here.
    pub fn new(capacity: usize) -> Self {
        Self {
            channels: Arc::new(RwLock::new(HashMap::new())),
            capacity,
        }
    }
}

impl Default for InMemorySignalBus {
    /// Creates a bus using [`InMemorySignalBus::DEFAULT_CAPACITY`].
    fn default() -> Self {
        Self::new(Self::DEFAULT_CAPACITY)
    }
}
#[async_trait]
impl SignalBus for InMemorySignalBus {
    /// Broadcasts `signal` to every live receiver of `channel`.
    ///
    /// Best-effort: if `channel` has never been subscribed to (or its entry was
    /// pruned), or all of its receivers have been dropped, the signal is
    /// discarded and `Ok(())` is still returned.
    async fn emit(&self, channel: &str, signal: &[u8]) -> anyhow::Result<()> {
        let channels = self.channels.read().await;
        if let Some(sender) = channels.get(channel) {
            // `send` only fails when no receiver is active; dropping the signal in
            // that case is the intended fire-and-forget behavior.
            let _ = sender.send(signal.to_vec());
        }
        Ok(())
    }

    /// Returns a new receiver for `channel`, creating the channel on first use.
    ///
    /// # Errors
    ///
    /// Returns an error when the channel has to be created and the bus capacity
    /// is zero or greater than `usize::MAX / 2` (values that would make
    /// `tokio::sync::broadcast::channel` panic).
    async fn subscribe(&self, channel: &str) -> anyhow::Result<SignalReceiver<Vec<u8>>> {
        // Fast path: joining an existing channel only needs the shared lock.
        {
            let channels = self.channels.read().await;
            if let Some(sender) = channels.get(channel) {
                return Ok(sender.subscribe());
            }
        }

        anyhow::ensure!(
            self.capacity > 0 && self.capacity <= usize::MAX / 2,
            "invalid signal bus capacity {}: must be between 1 and usize::MAX / 2",
            self.capacity
        );

        let mut channels = self.channels.write().await;
        // The entry API re-checks presence: another task may have created the
        // channel between releasing the read lock and acquiring the write lock.
        let sender = channels
            .entry(channel.to_string())
            .or_insert_with(|| broadcast::channel(self.capacity).0);
        Ok(sender.subscribe())
    }

    /// Drops the bookkeeping for `channel` once it has no live receivers.
    ///
    /// Receivers detach by being dropped; this only removes the channel's
    /// sender when no receiver remains, so existing receivers are never
    /// disconnected and a later `subscribe` simply creates a fresh channel.
    /// Always returns `Ok(())`.
    async fn unsubscribe(&self, channel: &str) -> anyhow::Result<()> {
        let mut channels = self.channels.write().await;
        let idle = matches!(
            channels.get(channel),
            Some(sender) if sender.receiver_count() == 0
        );
        if idle {
            channels.remove(channel);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_inmemory_signal_bus_new() {
        let bus = InMemorySignalBus::new(100);
        assert_eq!(bus.capacity, 100);
    }

    #[tokio::test]
    async fn test_inmemory_signal_bus_default() {
        let bus = InMemorySignalBus::default();
        assert_eq!(bus.capacity, 1024);
    }

    #[tokio::test]
    async fn test_subscribe_and_receive() {
        let bus = InMemorySignalBus::default();
        let mut rx = bus.subscribe("test-channel").await.unwrap();
        bus.emit("test-channel", b"hello world").await.unwrap();
        let received = rx.recv().await.unwrap();
        assert_eq!(received, b"hello world".to_vec());
    }

    #[tokio::test]
    async fn test_multiple_subscribers() {
        let bus = InMemorySignalBus::default();
        let mut rx1 = bus.subscribe("multi-channel").await.unwrap();
        let mut rx2 = bus.subscribe("multi-channel").await.unwrap();
        bus.emit("multi-channel", b"broadcast").await.unwrap();
        let received1 = rx1.recv().await.unwrap();
        let received2 = rx2.recv().await.unwrap();
        assert_eq!(received1, b"broadcast".to_vec());
        assert_eq!(received2, b"broadcast".to_vec());
    }

    #[tokio::test]
    async fn test_emit_without_subscribers() {
        let bus = InMemorySignalBus::default();
        let result = bus.emit("no-subscribers", b"data").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_emit_after_all_receivers_dropped() {
        let bus = InMemorySignalBus::default();
        let rx = bus.subscribe("dropped").await.unwrap();
        drop(rx);
        // With no live receiver the signal is discarded, but emit still succeeds.
        assert!(bus.emit("dropped", b"nobody").await.is_ok());

        // Resubscribing works and only sees signals sent afterwards.
        let mut rx = bus.subscribe("dropped").await.unwrap();
        bus.emit("dropped", b"again").await.unwrap();
        assert_eq!(rx.recv().await.unwrap(), b"again".to_vec());
    }

    #[tokio::test]
    async fn test_emit_to_different_channels() {
        let bus = InMemorySignalBus::default();
        let mut rx1 = bus.subscribe("channel-a").await.unwrap();
        let mut rx2 = bus.subscribe("channel-b").await.unwrap();
        bus.emit("channel-a", b"msg-a").await.unwrap();
        bus.emit("channel-b", b"msg-b").await.unwrap();
        let received1 = rx1.recv().await.unwrap();
        let received2 = rx2.recv().await.unwrap();
        assert_eq!(received1, b"msg-a".to_vec());
        assert_eq!(received2, b"msg-b".to_vec());
    }

    #[tokio::test]
    async fn test_unsubscribe() {
        let bus = InMemorySignalBus::default();
        let mut rx = bus.subscribe("unsub-channel").await.unwrap();
        let result = bus.unsubscribe("unsub-channel").await;
        assert!(result.is_ok());

        // A receiver that is still alive must keep receiving after unsubscribe.
        bus.emit("unsub-channel", b"still here").await.unwrap();
        assert_eq!(rx.recv().await.unwrap(), b"still here".to_vec());
    }

    #[tokio::test]
    async fn test_unsubscribe_unknown_channel() {
        let bus = InMemorySignalBus::default();
        assert!(bus.unsubscribe("never-subscribed").await.is_ok());
    }

    #[tokio::test]
    async fn test_multiple_messages() {
        let bus = InMemorySignalBus::default();
        let mut rx = bus.subscribe("multi-msg").await.unwrap();
        bus.emit("multi-msg", b"first").await.unwrap();
        bus.emit("multi-msg", b"second").await.unwrap();
        bus.emit("multi-msg", b"third").await.unwrap();
        assert_eq!(rx.recv().await.unwrap(), b"first".to_vec());
        assert_eq!(rx.recv().await.unwrap(), b"second".to_vec());
        assert_eq!(rx.recv().await.unwrap(), b"third".to_vec());
    }

    #[tokio::test]
    async fn test_late_subscriber_misses_messages() {
        let bus = InMemorySignalBus::default();
        let mut rx1 = bus.subscribe("late-sub").await.unwrap();
        bus.emit("late-sub", b"early").await.unwrap();
        let mut rx2 = bus.subscribe("late-sub").await.unwrap();
        bus.emit("late-sub", b"late").await.unwrap();
        assert_eq!(rx1.recv().await.unwrap(), b"early".to_vec());
        assert_eq!(rx1.recv().await.unwrap(), b"late".to_vec());
        assert_eq!(rx2.recv().await.unwrap(), b"late".to_vec());
    }

    #[tokio::test]
    async fn test_concurrent_emit() {
        let bus = Arc::new(InMemorySignalBus::default());
        let mut rx = bus.subscribe("concurrent").await.unwrap();
        let bus1 = bus.clone();
        let bus2 = bus.clone();
        let h1 = tokio::spawn(async move {
            for i in 0..5 {
                bus1.emit("concurrent", format!("msg-a-{}", i).as_bytes())
                    .await
                    .unwrap();
            }
        });
        let h2 = tokio::spawn(async move {
            for i in 0..5 {
                bus2.emit("concurrent", format!("msg-b-{}", i).as_bytes())
                    .await
                    .unwrap();
            }
        });
        h1.await.unwrap();
        h2.await.unwrap();

        let mut received = Vec::new();
        while let Ok(msg) = rx.try_recv() {
            received.push(msg);
        }
        assert_eq!(received.len(), 10);

        // Interleaving is nondeterministic, so compare contents order-independently.
        let mut expected: Vec<Vec<u8>> = (0..5)
            .flat_map(|i| vec![format!("msg-a-{}", i), format!("msg-b-{}", i)])
            .map(String::into_bytes)
            .collect();
        expected.sort();
        received.sort();
        assert_eq!(received, expected);
    }
}