use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::{Arc, Mutex};

use async_trait::async_trait;

use crate::error::KowalskiError;
use crate::federation::acl::AclEnvelope;
/// How many distinct envelope ids the broker remembers for duplicate
/// suppression. Once more ids than this have been published, the oldest one is
/// forgotten and an envelope carrying it would be delivered again.
const RECENT_ENVELOPE_IDS_CAP: usize = 2_048;
/// Transport abstraction for publishing ACL envelopes.
///
/// The `Send + Sync` supertraits allow an implementor to be shared across
/// threads (for example behind an `Arc<dyn MessageBroker>`).
#[async_trait]
pub trait MessageBroker: Send + Sync {
/// Publishes `envelope` through this broker.
///
/// # Errors
/// Returns a [`KowalskiError`] when the implementation fails to publish; what
/// counts as a failure is implementation-defined.
async fn publish(&self, envelope: &AclEnvelope) -> Result<(), KowalskiError>;
}
type SubscriberVec = Vec<tokio::sync::mpsc::Sender<AclEnvelope>>;
#[derive(Clone)]
pub struct MpscBroker {
inner: Arc<Mutex<HashMap<String, SubscriberVec>>>,
recent_envelope_ids: Arc<Mutex<VecDeque<String>>>,
}
impl MpscBroker {
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(HashMap::new())),
recent_envelope_ids: Arc::new(Mutex::new(VecDeque::with_capacity(
RECENT_ENVELOPE_IDS_CAP,
))),
}
}
pub fn subscribe(
&self,
topic: &str,
buffer: usize,
) -> tokio::sync::mpsc::Receiver<AclEnvelope> {
let (tx, rx) = tokio::sync::mpsc::channel(buffer);
self.inner
.lock()
.expect("mpsc broker lock")
.entry(topic.to_string())
.or_default()
.push(tx);
rx
}
pub async fn publish_to_topic(&self, envelope: &AclEnvelope) -> Result<(), KowalskiError> {
{
let mut recent = self
.recent_envelope_ids
.lock()
.expect("mpsc broker dedupe lock");
if recent.contains(&envelope.id) {
return Ok(());
}
recent.push_back(envelope.id.clone());
while recent.len() > RECENT_ENVELOPE_IDS_CAP {
recent.pop_front();
}
}
let topic = envelope.topic.clone();
let senders: Vec<_> = {
let g = self.inner.lock().expect("mpsc broker lock");
g.get(&topic).cloned().unwrap_or_default()
};
let mut any_sent = false;
let mut dropped = 0usize;
for s in senders {
match s.send(envelope.clone()).await {
Ok(_) => any_sent = true,
Err(_) => dropped += 1,
}
}
if dropped > 0 {
let mut g = self.inner.lock().expect("mpsc broker lock");
if let Some(existing) = g.get_mut(&topic) {
existing.retain(|s| !s.is_closed());
}
}
if !any_sent {
return Ok(());
}
Ok(())
}
}
#[async_trait]
impl MessageBroker for MpscBroker {
async fn publish(&self, envelope: &AclEnvelope) -> Result<(), KowalskiError> {
self.publish_to_topic(envelope).await
}
}
impl Default for MpscBroker {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::federation::acl::AclMessage;

    /// Every subscriber on a topic receives its own copy of the envelope.
    #[tokio::test]
    async fn two_subscribers_receive_delegate() {
        let broker = MpscBroker::new();
        let mut a = broker.subscribe("tasks", 8);
        let mut b = broker.subscribe("tasks", 8);
        let env = AclEnvelope::new(
            "tasks",
            "orch",
            AclMessage::TaskDelegate {
                task_id: "1".into(),
                from_agent: "orch".into(),
                to_agent: "agent-b".into(),
                instruction: "go".into(),
                delegation_depth: 0,
                max_delegation_depth: None,
            },
        );
        broker.publish_to_topic(&env).await.unwrap();
        let ra = a.recv().await.unwrap();
        let rb = b.recv().await.unwrap();
        assert_eq!(ra.payload, env.payload);
        assert_eq!(rb.payload, env.payload);
    }

    /// Republishing the same envelope (same id) is suppressed.
    #[tokio::test]
    async fn duplicate_envelope_id_not_delivered_twice() {
        let broker = MpscBroker::new();
        let mut sub = broker.subscribe("tasks", 8);
        let env = AclEnvelope::new(
            "tasks",
            "orch",
            AclMessage::Ping {
                text: "once".into(),
            },
        );
        broker.publish_to_topic(&env).await.unwrap();
        broker.publish_to_topic(&env).await.unwrap();
        let _ = sub.recv().await.unwrap();
        assert!(sub.try_recv().is_err());
    }

    /// Subscribers never see envelopes published to a different topic.
    #[tokio::test]
    async fn topic_isolation() {
        let broker = MpscBroker::new();
        let mut t1 = broker.subscribe("t1", 4);
        broker
            .publish_to_topic(&AclEnvelope::new(
                "t2",
                "x",
                AclMessage::Ping { text: "hi".into() },
            ))
            .await
            .unwrap();
        assert!(t1.try_recv().is_err());
    }

    /// A subscriber whose receiver was dropped neither fails the publish nor
    /// blocks delivery to live subscribers, and it is pruned from the topic.
    #[tokio::test]
    async fn dropped_subscriber_is_pruned_and_live_one_still_receives() {
        let broker = MpscBroker::new();
        let gone = broker.subscribe("tasks", 4);
        let mut live = broker.subscribe("tasks", 4);
        drop(gone);
        let env = AclEnvelope::new(
            "tasks",
            "orch",
            AclMessage::Ping {
                text: "still here".into(),
            },
        );
        broker.publish_to_topic(&env).await.unwrap();
        let got = live.recv().await.unwrap();
        assert_eq!(got.payload, env.payload);
        let remaining = broker
            .inner
            .lock()
            .unwrap()
            .get("tasks")
            .map(Vec::len);
        assert_eq!(remaining, Some(1));
    }
}