use crate::{Error, Message, Result};
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{broadcast, mpsc, RwLock};
/// Abstraction over an asynchronous, bidirectional message channel.
///
/// The `Send + Sync + 'static` bounds allow implementations to be moved
/// across threads and shared between tasks.
#[async_trait]
pub trait Transport: Send + Sync + 'static {
/// Sends `msg` through this transport.
async fn send(&self, msg: Message) -> Result<()>;
/// Receives the next message from this transport.
async fn recv(&self) -> Result<Message>;
/// Closes the transport.
async fn close(&self) -> Result<()>;
}
/// Singleton bus shared by every `IpmbTransport`, created lazily on first access.
///
/// NOTE(review): despite the "process" naming used throughout this file, this is
/// an in-memory static, so it only connects transports living in the same OS
/// process.
static MESSAGE_BUS: once_cell::sync::Lazy<Arc<MessageBus>> =
once_cell::sync::Lazy::new(|| Arc::new(MessageBus::new()));
/// In-memory fan-out bus: every sent message is broadcast to all current
/// subscribers. It also keeps a registry of logical process names.
struct MessageBus {
/// Broadcast sender; receivers are created via `subscribe()` in `register_process`.
sender: broadcast::Sender<Message>,
/// Registered process names mapped to their registration metadata.
processes: Arc<RwLock<HashMap<String, ProcessInfo>>>,
}
/// Metadata recorded for each registered process name.
struct ProcessInfo {
/// The registered name (duplicates the map key).
// Stored for bookkeeping only; nothing reads it yet, hence the allow.
#[allow(dead_code)]
name: String,
/// When the name was registered.
// Likewise not read anywhere yet.
#[allow(dead_code)]
registered_at: std::time::Instant,
}
impl MessageBus {
    /// Number of messages retained in the broadcast ring buffer. A subscriber
    /// that falls further behind than this observes `RecvError::Lagged`.
    const CAPACITY: usize = 1024;

    /// Creates an empty bus with no registered processes.
    fn new() -> Self {
        // The initial receiver is discarded; subscriptions are created on demand
        // in `register_process`.
        let (sender, _) = broadcast::channel(Self::CAPACITY);
        Self {
            sender,
            processes: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Records `name` in the process registry and returns a new subscription
    /// to the bus.
    ///
    /// Registering a name that is already present replaces its previous entry.
    /// The returned receiver only observes messages sent after this call.
    async fn register_process(&self, name: &str) -> broadcast::Receiver<Message> {
        let mut processes = self.processes.write().await;
        processes.insert(
            name.to_string(),
            ProcessInfo {
                name: name.to_string(),
                registered_at: std::time::Instant::now(),
            },
        );
        self.sender.subscribe()
    }

    /// Broadcasts `msg` to every current subscriber.
    ///
    /// # Errors
    /// Returns a connection error if the broadcast send fails, which happens
    /// when there are no live receivers.
    async fn send_message(&self, msg: Message) -> Result<()> {
        self.sender
            .send(msg)
            .map_err(|_| Error::connection_msg("broadcast channel send failed"))?;
        Ok(())
    }

    /// Removes `name` from the process registry; unknown names are ignored.
    ///
    /// Existing bus subscriptions are not affected.
    async fn unregister_process(&self, name: &str) {
        self.processes.write().await.remove(name);
    }
}
pub struct IpmbTransport {
process_name: String,
#[allow(dead_code)]
receiver: Arc<RwLock<broadcast::Receiver<Message>>>,
local_receiver: Arc<RwLock<mpsc::Receiver<Message>>>,
_receiver_task: Arc<tokio::task::JoinHandle<()>>,
}
impl IpmbTransport {
pub async fn new(process_name: &str) -> Result<Self> {
let mut bus_receiver = MESSAGE_BUS.register_process(process_name).await;
let (local_tx, local_rx) = mpsc::channel(1024);
let process_name_clone = process_name.to_string();
let receiver_task = tokio::spawn(async move {
while let Ok(msg) = bus_receiver.recv().await {
let should_receive = match &msg.target {
Some(target) => target == &process_name_clone,
None => {
true
}
};
if should_receive && local_tx.send(msg).await.is_err() {
break; }
}
});
Ok(Self {
process_name: process_name.to_string(),
receiver: Arc::new(RwLock::new(MESSAGE_BUS.sender.subscribe())),
local_receiver: Arc::new(RwLock::new(local_rx)),
_receiver_task: Arc::new(receiver_task),
})
}
}
#[async_trait]
impl Transport for IpmbTransport {
async fn send(&self, msg: Message) -> Result<()> {
MESSAGE_BUS.send_message(msg).await
}
async fn recv(&self) -> Result<Message> {
let mut receiver = self.local_receiver.write().await;
receiver
.recv()
.await
.ok_or(Error::connection_msg("message channel closed"))
}
async fn close(&self) -> Result<()> {
MESSAGE_BUS.unregister_process(&self.process_name).await;
Ok(())
}
}
/// Test double for `Transport` backed by in-memory `mpsc` channels.
#[cfg(test)]
pub struct MockTransport {
/// Outbound side: messages passed to `send` are pushed here.
tx: mpsc::Sender<Message>,
/// Inbound side read by `recv`.
rx: Arc<RwLock<mpsc::Receiver<Message>>>,
}
#[cfg(test)]
impl MockTransport {
    /// Creates a mock transport, together with the receiving end of its outbound
    /// channel, so tests can observe everything passed to `send`.
    ///
    /// The inbound sender is dropped before returning, so `recv` on the returned
    /// transport always fails. Use `MockTransport::with_inbound` to feed
    /// messages into `recv`.
    pub fn new() -> (Self, mpsc::Receiver<Message>) {
        let (transport, outbound, inbound) = Self::with_inbound();
        // Preserve the original contract: no inbound message can ever arrive.
        drop(inbound);
        (transport, outbound)
    }

    /// Like `MockTransport::new`, but also returns the sender that feeds the
    /// transport's `recv`, letting tests inject incoming messages.
    pub fn with_inbound() -> (Self, mpsc::Receiver<Message>, mpsc::Sender<Message>) {
        let (out_tx, out_rx) = mpsc::channel(100);
        let (in_tx, in_rx) = mpsc::channel(100);
        let transport = Self {
            tx: out_tx,
            rx: Arc::new(RwLock::new(in_rx)),
        };
        (transport, out_rx, in_tx)
    }
}
#[cfg(test)]
#[async_trait]
impl Transport for MockTransport {
    /// Pushes `msg` onto the outbound channel.
    ///
    /// # Errors
    /// Fails if the outbound receiver has been dropped.
    async fn send(&self, msg: Message) -> Result<()> {
        self.tx
            .send(msg)
            .await
            .map_err(|_| Error::connection_msg("mock transport send failed"))
    }

    /// Pops the next message from the inbound channel.
    ///
    /// # Errors
    /// Fails once every inbound sender is dropped and the channel is drained.
    async fn recv(&self) -> Result<Message> {
        let mut rx = self.rx.write().await;
        rx.recv()
            .await
            // Lazily construct the error so the success path doesn't build one.
            .ok_or_else(|| Error::connection_msg("mock transport recv failed"))
    }

    /// No-op; the mock holds no external resources.
    async fn close(&self) -> Result<()> {
        Ok(())
    }
}