use std::collections::HashMap;
use std::sync::Arc;
use anyhow::{Result, bail};
use async_trait::async_trait;
use tokio::sync::{Mutex, broadcast};
use super::traits::{Transport, TransportAddress};
use crate::network::{MessageEnvelope, TransportType};
/// In-process publish/subscribe transport built on tokio `broadcast` channels.
///
/// Outgoing envelopes are routed by `send` into per-topic channels. Incoming envelopes are
/// read by `receive` from a separate channel whose sending half is exposed through
/// `incoming_sender`.
pub struct PubSubTransport {
/// One broadcast sender per topic name, created lazily the first time a topic is
/// subscribed to or published on.
topics: Arc<Mutex<HashMap<String, broadcast::Sender<MessageEnvelope>>>>,
/// Sending half of the incoming channel; clones are handed out by `incoming_sender`.
incoming_tx: broadcast::Sender<MessageEnvelope>,
/// Receiving half of the incoming channel, read by `receive`. It sits behind a mutex
/// because `receive` only has `&self` while `recv` needs `&mut`.
incoming_rx: Mutex<Option<broadcast::Receiver<MessageEnvelope>>>,
/// Set by `connect` and cleared by `disconnect`; `send`/`receive` fail while it is false.
connected: bool,
/// Capacity used for the incoming channel and for every per-topic channel.
buffer_size: usize,
}
impl PubSubTransport {
    /// Creates a disconnected transport whose channels each buffer up to 256 envelopes.
    pub fn new() -> Self {
        Self::with_buffer_size(256)
    }

    /// Creates a disconnected transport whose channels each buffer `buffer_size` envelopes.
    ///
    /// This capacity applies to the incoming channel now and to every topic channel
    /// created later.
    ///
    /// # Panics
    ///
    /// Panics if `buffer_size` is zero, because tokio's `broadcast::channel` rejects a
    /// zero capacity.
    pub fn with_buffer_size(buffer_size: usize) -> Self {
        let (tx, rx) = broadcast::channel(buffer_size);
        Self {
            topics: Arc::new(Mutex::new(HashMap::new())),
            incoming_tx: tx,
            incoming_rx: Mutex::new(Some(rx)),
            connected: false,
            buffer_size,
        }
    }

    /// Returns a fresh receiver for `topic`, creating the topic's channel on first use.
    ///
    /// The receiver only observes envelopes published after this call.
    pub async fn subscribe_topic(&self, topic: &str) -> broadcast::Receiver<MessageEnvelope> {
        let mut map = self.topics.lock().await;
        // Subscribe while the map lock is still held, so the sender cannot be swapped out
        // between lookup and subscription.
        Self::topic_entry(&mut map, topic, self.buffer_size).subscribe()
    }

    /// Returns a clone of the sender for `topic`, creating the topic's channel on first use.
    async fn get_topic_sender(&self, topic: &str) -> broadcast::Sender<MessageEnvelope> {
        let mut map = self.topics.lock().await;
        Self::topic_entry(&mut map, topic, self.buffer_size).clone()
    }

    /// Looks up the sender for `topic` in an already-locked topic map.
    ///
    /// If the topic has no sender yet, a new channel with `capacity` slots is inserted.
    fn topic_entry<'m>(
        map: &'m mut HashMap<String, broadcast::Sender<MessageEnvelope>>,
        topic: &str,
        capacity: usize,
    ) -> &'m broadcast::Sender<MessageEnvelope> {
        map.entry(topic.to_owned())
            .or_insert_with(|| broadcast::channel(capacity).0)
    }

    /// Returns a sender that injects envelopes into the stream read by `receive`.
    pub fn incoming_sender(&self) -> broadcast::Sender<MessageEnvelope> {
        self.incoming_tx.clone()
    }
}
impl Default for PubSubTransport {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Transport for PubSubTransport {
async fn connect(&mut self, target: &TransportAddress) -> Result<()> {
match target {
TransportAddress::Channel(_) => {
self.connected = true;
Ok(())
}
_ => bail!("PubSubTransport only supports Channel addresses"),
}
}
async fn disconnect(&mut self) -> Result<()> {
self.connected = false;
Ok(())
}
async fn send(&self, envelope: &MessageEnvelope) -> Result<()> {
if !self.connected {
bail!("PubSubTransport not connected");
}
match &envelope.recipient {
crate::network::MessageTarget::Topic(topic) => {
let sender = self.get_topic_sender(topic).await;
let _ = sender.send(envelope.clone());
Ok(())
}
crate::network::MessageTarget::Broadcast => {
let topics = self.topics.lock().await;
for sender in topics.values() {
let _ = sender.send(envelope.clone());
}
Ok(())
}
crate::network::MessageTarget::Direct(_) => {
bail!("PubSubTransport does not support direct messages; use topic addressing");
}
}
}
async fn receive(&self) -> Result<Option<MessageEnvelope>> {
if !self.connected {
bail!("PubSubTransport not connected");
}
let mut rx_guard = self.incoming_rx.lock().await;
if let Some(rx) = rx_guard.as_mut() {
match rx.recv().await {
Ok(envelope) => Ok(Some(envelope)),
Err(broadcast::error::RecvError::Closed) => Ok(None),
Err(broadcast::error::RecvError::Lagged(n)) => {
tracing::warn!("PubSubTransport receiver lagged by {n} messages");
match rx.recv().await {
Ok(envelope) => Ok(Some(envelope)),
_ => Ok(None),
}
}
}
} else {
Ok(None)
}
}
fn transport_type(&self) -> TransportType {
TransportType::PubSub
}
fn is_connected(&self) -> bool {
self.connected
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::network::Payload;
    use uuid::Uuid;

    /// Connects `transport` to a throwaway channel address and hands it back.
    async fn connected(mut transport: PubSubTransport) -> PubSubTransport {
        transport
            .connect(&TransportAddress::Channel("test".into()))
            .await
            .unwrap();
        transport
    }

    /// Builds a topic-addressed text envelope with a random UUID as its first argument.
    fn topic_envelope(topic: &str, text: &str) -> MessageEnvelope {
        MessageEnvelope::topic(Uuid::new_v4(), topic, Payload::Text(text.into()))
    }

    #[tokio::test]
    async fn pubsub_topic_delivery() {
        let transport = PubSubTransport::new();
        let mut rx = transport.subscribe_topic("events").await;
        let env = topic_envelope("events", "update");
        let sender = transport.get_topic_sender("events").await;
        sender.send(env.clone()).unwrap();
        let received = rx.recv().await.unwrap();
        assert_eq!(received.id, env.id);
    }

    #[tokio::test]
    async fn pubsub_send_delivers_to_topic_subscriber() {
        let transport = connected(PubSubTransport::new()).await;
        let mut rx = transport.subscribe_topic("events").await;
        let env = topic_envelope("events", "update");
        transport.send(&env).await.unwrap();
        assert_eq!(rx.recv().await.unwrap().id, env.id);
    }

    #[tokio::test]
    async fn pubsub_no_subscribers_ok() {
        let transport = connected(PubSubTransport::new()).await;
        let env = topic_envelope("nobody-listening", "hello");
        transport.send(&env).await.unwrap();
    }

    #[tokio::test]
    async fn pubsub_rejects_direct() {
        let transport = connected(PubSubTransport::new()).await;
        let env = MessageEnvelope::direct(
            Uuid::new_v4(),
            Uuid::new_v4(),
            Payload::Text("hello".into()),
        );
        assert!(transport.send(&env).await.is_err());
    }

    #[tokio::test]
    async fn pubsub_send_requires_connection() {
        let transport = PubSubTransport::new();
        let env = topic_envelope("events", "update");
        assert!(transport.send(&env).await.is_err());
    }

    #[tokio::test]
    async fn pubsub_receive_requires_connection() {
        let transport = PubSubTransport::new();
        assert!(transport.receive().await.is_err());
    }

    #[tokio::test]
    async fn pubsub_receive_returns_injected_envelope() {
        let transport = connected(PubSubTransport::new()).await;
        let env = topic_envelope("inbox", "ping");
        transport.incoming_sender().send(env.clone()).unwrap();
        let received = transport
            .receive()
            .await
            .unwrap()
            .expect("injected envelope should be delivered");
        assert_eq!(received.id, env.id);
    }

    #[tokio::test]
    async fn pubsub_receive_resumes_after_lag() {
        // Capacity 2 with 4 sends: the two oldest are overwritten before the first read.
        let transport = connected(PubSubTransport::with_buffer_size(2)).await;
        let tx = transport.incoming_sender();
        let envs: Vec<MessageEnvelope> = (0..4)
            .map(|i| topic_envelope("inbox", &format!("m{i}")))
            .collect();
        for env in &envs {
            tx.send(env.clone()).unwrap();
        }
        let received = transport
            .receive()
            .await
            .unwrap()
            .expect("receive should skip the lag and return the oldest retained envelope");
        assert_eq!(received.id, envs[2].id);
    }

    #[test]
    fn pubsub_transport_type() {
        let t = PubSubTransport::new();
        assert_eq!(t.transport_type(), TransportType::PubSub);
        assert!(!t.is_connected());
    }
}