pub mod message;
pub use message::{InboundMessage, MediaAttachment, MediaType, OutboundMessage};
use crate::error::{Result, ZeptoError};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Mutex;
/// Default capacity, in messages, of each queue created by `MessageBus::new`.
const DEFAULT_BUFFER_SIZE: usize = 100;

/// An in-process bus made of two bounded queues: one carrying
/// `InboundMessage`s and one carrying `OutboundMessage`s.
///
/// Each receiving half sits behind an `Arc<tokio::sync::Mutex<_>>`, so cloned
/// bus handles share the same queues and concurrent consumers take turns
/// receiving from them.
#[derive(Debug)]
pub struct MessageBus {
    /// Sending half of the inbound queue.
    inbound_tx: mpsc::Sender<InboundMessage>,
    /// Receiving half of the inbound queue, shared by every clone of the bus.
    inbound_rx: Arc<Mutex<mpsc::Receiver<InboundMessage>>>,
    /// Sending half of the outbound queue.
    outbound_tx: mpsc::Sender<OutboundMessage>,
    /// Receiving half of the outbound queue, shared by every clone of the bus.
    outbound_rx: Arc<Mutex<mpsc::Receiver<OutboundMessage>>>,
}
impl MessageBus {
pub fn new() -> Self {
Self::with_buffer_size(DEFAULT_BUFFER_SIZE)
}
pub fn with_buffer_size(buffer_size: usize) -> Self {
let (inbound_tx, inbound_rx) = mpsc::channel(buffer_size);
let (outbound_tx, outbound_rx) = mpsc::channel(buffer_size);
Self {
inbound_tx,
inbound_rx: Arc::new(Mutex::new(inbound_rx)),
outbound_tx,
outbound_rx: Arc::new(Mutex::new(outbound_rx)),
}
}
pub async fn publish_inbound(&self, msg: InboundMessage) -> Result<()> {
self.inbound_tx
.send(msg)
.await
.map_err(|_| ZeptoError::BusClosed)
}
pub async fn consume_inbound(&self) -> Option<InboundMessage> {
self.inbound_rx.lock().await.recv().await
}
pub async fn publish_outbound(&self, msg: OutboundMessage) -> Result<()> {
self.outbound_tx
.send(msg)
.await
.map_err(|_| ZeptoError::BusClosed)
}
pub async fn consume_outbound(&self) -> Option<OutboundMessage> {
self.outbound_rx.lock().await.recv().await
}
pub fn inbound_sender(&self) -> mpsc::Sender<InboundMessage> {
self.inbound_tx.clone()
}
pub fn outbound_sender(&self) -> mpsc::Sender<OutboundMessage> {
self.outbound_tx.clone()
}
pub fn try_publish_inbound(&self, msg: InboundMessage) -> Result<()> {
self.inbound_tx.try_send(msg).map_err(|e| match e {
mpsc::error::TrySendError::Full(_) => {
ZeptoError::Channel("inbound buffer full".to_string())
}
mpsc::error::TrySendError::Closed(_) => ZeptoError::BusClosed,
})
}
pub fn try_publish_outbound(&self, msg: OutboundMessage) -> Result<()> {
self.outbound_tx.try_send(msg).map_err(|e| match e {
mpsc::error::TrySendError::Full(_) => {
ZeptoError::Channel("outbound buffer full".to_string())
}
mpsc::error::TrySendError::Closed(_) => ZeptoError::BusClosed,
})
}
}
impl Default for MessageBus {
fn default() -> Self {
Self::new()
}
}
/// Cloning yields another handle onto the same bus.
///
/// The senders are cloned and the shared receiver `Arc`s have their reference
/// counts bumped, so the original and the clone publish to and consume from
/// the same pair of queues.
impl Clone for MessageBus {
    fn clone(&self) -> Self {
        let Self {
            inbound_tx,
            inbound_rx,
            outbound_tx,
            outbound_rx,
        } = self;
        Self {
            inbound_tx: inbound_tx.clone(),
            inbound_rx: Arc::clone(inbound_rx),
            outbound_tx: outbound_tx.clone(),
            outbound_rx: Arc::clone(outbound_rx),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Publishes `count` distinct messages onto the inbound queue without
    /// waiting, panicking if any of them does not fit.
    fn fill_inbound(bus: &MessageBus, count: usize) {
        for i in 0..count {
            let msg = InboundMessage::new("test", "user", "chat", &format!("Msg {}", i));
            bus.try_publish_inbound(msg).unwrap();
        }
    }

    /// Asserts that one more non-blocking inbound publish is rejected as
    /// "buffer full".
    fn assert_inbound_full(bus: &MessageBus) {
        let overflow = InboundMessage::new("test", "user", "chat", "overflow");
        let result = bus.try_publish_inbound(overflow);
        assert!(matches!(result, Err(ZeptoError::Channel(_))));
    }

    #[test]
    fn test_inbound_message_creation() {
        let msg = InboundMessage::new("telegram", "user123", "chat456", "Hello");
        assert_eq!(msg.channel, "telegram");
        assert_eq!(msg.content, "Hello");
        assert_eq!(msg.session_key, "telegram:chat456");
    }

    #[tokio::test]
    async fn test_message_bus_creation() {
        // `new` must give the inbound queue exactly the default capacity.
        let bus = MessageBus::new();
        fill_inbound(&bus, DEFAULT_BUFFER_SIZE);
        assert_inbound_full(&bus);
    }

    #[tokio::test]
    async fn test_message_bus_with_custom_buffer() {
        let bus = MessageBus::with_buffer_size(50);
        fill_inbound(&bus, 50);
        assert_inbound_full(&bus);
    }

    #[tokio::test]
    async fn test_message_bus_default() {
        let bus = MessageBus::default();
        fill_inbound(&bus, DEFAULT_BUFFER_SIZE);
        assert_inbound_full(&bus);
    }

    #[tokio::test]
    async fn test_message_bus_clone() {
        // A clone must share both queues with the original handle.
        let bus1 = MessageBus::new();
        let bus2 = bus1.clone();

        let inbound = InboundMessage::new("telegram", "user", "chat", "via clone");
        bus2.publish_inbound(inbound).await.unwrap();
        let received = bus1.consume_inbound().await.unwrap();
        assert_eq!(received.content, "via clone");

        let outbound = OutboundMessage::new("telegram", "chat", "via original");
        bus1.publish_outbound(outbound).await.unwrap();
        let sent = bus2.consume_outbound().await.unwrap();
        assert_eq!(sent.content, "via original");

        drop(bus1);
        drop(bus2);
    }

    #[tokio::test]
    async fn test_bus_inbound_flow() {
        let bus = MessageBus::new();
        let msg = InboundMessage::new("telegram", "user123", "chat456", "Hello");
        bus.publish_inbound(msg).await.unwrap();
        let received = bus.consume_inbound().await.unwrap();
        assert_eq!(received.content, "Hello");
        assert_eq!(received.channel, "telegram");
        assert_eq!(received.sender_id, "user123");
        assert_eq!(received.chat_id, "chat456");
    }

    #[tokio::test]
    async fn test_bus_outbound_flow() {
        let bus = MessageBus::new();
        let msg = OutboundMessage::new("telegram", "chat456", "Response");
        bus.publish_outbound(msg).await.unwrap();
        let received = bus.consume_outbound().await.unwrap();
        assert_eq!(received.content, "Response");
        assert_eq!(received.channel, "telegram");
        assert_eq!(received.chat_id, "chat456");
    }

    #[tokio::test]
    async fn test_bus_multiple_messages() {
        let bus = MessageBus::new();
        for i in 0..5 {
            let msg = InboundMessage::new("telegram", "user", "chat", &format!("Message {}", i));
            bus.publish_inbound(msg).await.unwrap();
        }
        for i in 0..5 {
            let received = bus.consume_inbound().await.unwrap();
            assert_eq!(received.content, format!("Message {}", i));
        }
    }

    #[tokio::test]
    async fn test_bus_sender_clones() {
        let bus = MessageBus::new();
        let sender1 = bus.inbound_sender();
        let sender2 = bus.inbound_sender();
        let msg1 = InboundMessage::new("telegram", "user1", "chat1", "From sender 1");
        let msg2 = InboundMessage::new("discord", "user2", "chat2", "From sender 2");
        sender1.send(msg1).await.unwrap();
        sender2.send(msg2).await.unwrap();
        let received1 = bus.consume_inbound().await.unwrap();
        let received2 = bus.consume_inbound().await.unwrap();
        assert_eq!(received1.content, "From sender 1");
        assert_eq!(received2.content, "From sender 2");
    }

    #[tokio::test]
    async fn test_bus_concurrent_access() {
        let bus = Arc::new(MessageBus::new());
        let bus_clone = Arc::clone(&bus);
        let producer = tokio::spawn(async move {
            for i in 0..10 {
                let msg = InboundMessage::new("test", "user", "chat", &format!("Msg {}", i));
                bus_clone.publish_inbound(msg).await.unwrap();
            }
        });
        let bus_clone2 = Arc::clone(&bus);
        let consumer = tokio::spawn(async move {
            let mut contents = Vec::with_capacity(10);
            for _ in 0..10 {
                // The bus holds a sender, so `None` here would be a bug rather
                // than an end-of-stream; fail loudly instead of spinning.
                let msg = bus_clone2
                    .consume_inbound()
                    .await
                    .expect("bus is alive, so the inbound channel stays open");
                contents.push(msg.content);
            }
            contents
        });
        producer.await.unwrap();
        let consumed = consumer.await.unwrap();
        // A single producer over an mpsc queue must preserve FIFO order.
        let expected: Vec<String> = (0..10).map(|i| format!("Msg {}", i)).collect();
        assert_eq!(consumed, expected);
    }

    #[tokio::test]
    async fn test_try_publish_inbound() {
        let bus = MessageBus::with_buffer_size(2);
        let msg1 = InboundMessage::new("test", "user", "chat", "Msg 1");
        let msg2 = InboundMessage::new("test", "user", "chat", "Msg 2");
        bus.try_publish_inbound(msg1).unwrap();
        bus.try_publish_inbound(msg2).unwrap();
        let msg3 = InboundMessage::new("test", "user", "chat", "Msg 3");
        let result = bus.try_publish_inbound(msg3);
        assert!(matches!(result, Err(ZeptoError::Channel(_))));
    }

    #[tokio::test]
    async fn test_try_publish_outbound() {
        let bus = MessageBus::with_buffer_size(2);
        let msg1 = OutboundMessage::new("test", "chat", "Msg 1");
        let msg2 = OutboundMessage::new("test", "chat", "Msg 2");
        bus.try_publish_outbound(msg1).unwrap();
        bus.try_publish_outbound(msg2).unwrap();
        let msg3 = OutboundMessage::new("test", "chat", "Msg 3");
        let result = bus.try_publish_outbound(msg3);
        assert!(matches!(result, Err(ZeptoError::Channel(_))));
    }

    #[tokio::test]
    async fn test_outbound_with_reply() {
        let bus = MessageBus::new();
        let msg = OutboundMessage::new("telegram", "chat456", "This is a reply")
            .with_reply("original_msg_123");
        bus.publish_outbound(msg).await.unwrap();
        let received = bus.consume_outbound().await.unwrap();
        assert_eq!(received.reply_to, Some("original_msg_123".to_string()));
    }

    #[tokio::test]
    async fn test_inbound_with_media() {
        let bus = MessageBus::new();
        let media = MediaAttachment::new(MediaType::Image)
            .with_url("https://example.com/image.png")
            .with_filename("photo.png");
        let msg = InboundMessage::new("telegram", "user123", "chat456", "Check this out!")
            .with_media(media);
        bus.publish_inbound(msg).await.unwrap();
        let received = bus.consume_inbound().await.unwrap();
        assert!(received.has_media());
        let attachment = received.media.unwrap();
        assert_eq!(attachment.media_type, MediaType::Image);
        assert!(attachment.has_url());
    }

    #[tokio::test]
    async fn test_bus_reply_to_inbound() {
        let bus = MessageBus::new();
        let inbound = InboundMessage::new("telegram", "user123", "chat456", "Hello bot!");
        bus.publish_inbound(inbound).await.unwrap();
        let received = bus.consume_inbound().await.unwrap();
        let response = OutboundMessage::reply_to(&received, "Hello human!");
        bus.publish_outbound(response).await.unwrap();
        let outgoing = bus.consume_outbound().await.unwrap();
        assert_eq!(outgoing.channel, "telegram");
        assert_eq!(outgoing.chat_id, "chat456");
        assert_eq!(outgoing.content, "Hello human!");
    }
}