use async_trait::async_trait;
use tokio::sync::broadcast;
use crate::{CollaborationEvent, ExecutionError};
/// Abstraction over a channel that carries [`CollaborationEvent`]s between
/// publishers and subscribers.
///
/// Implementors must be `Send + Sync`, so a transport can be shared across
/// threads/tasks.
#[async_trait]
pub trait CollaborationTransport: Send + Sync {
/// Publishes `event` on this transport.
///
/// # Errors
///
/// Returns an [`ExecutionError`] when the implementation fails to publish.
/// The in-file `LocalTransport` implementation always returns `Ok(())`.
async fn publish(&self, event: CollaborationEvent) -> Result<(), ExecutionError>;
/// Creates a new, boxed receiver attached to this transport.
fn subscribe(&self) -> Box<dyn CollaborationReceiver>;
}
/// Receiving half handed out by [`CollaborationTransport::subscribe`].
#[async_trait]
pub trait CollaborationReceiver: Send {
/// Waits for the next event.
///
/// Returns `Some(event)` when an event is available and `None` when no
/// event can be produced; the in-file `LocalReceiver` returns `None` once
/// its underlying channel reports that it is closed.
async fn recv(&mut self) -> Option<CollaborationEvent>;
}
/// In-process [`CollaborationTransport`] built on a Tokio broadcast channel.
///
/// Only the sending half is stored; a receiving half is created from it for
/// every call to `subscribe`.
#[derive(Debug)]
pub struct LocalTransport {
// Sending half of the broadcast channel; also used to mint new receivers.
tx: broadcast::Sender<CollaborationEvent>,
}
impl LocalTransport {
    /// Creates a transport whose broadcast channel is created with room for
    /// `capacity` events.
    ///
    /// A `capacity` of `0` is raised to `1`: `tokio::sync::broadcast::channel`
    /// panics on a zero capacity, and a transport that cannot be constructed
    /// from a (possibly configuration-supplied) zero is never what the caller
    /// wants.
    ///
    /// # Panics
    ///
    /// Still panics if `capacity` exceeds the upper limit accepted by
    /// `broadcast::channel` (documented by Tokio as `usize::MAX / 2`).
    pub fn new(capacity: usize) -> Self {
        // The receiver returned alongside the sender is not kept; subscribers
        // obtain their own receiver later through the stored sender.
        let (tx, _initial_rx) = broadcast::channel(capacity.max(1));
        Self { tx }
    }
}
impl Default for LocalTransport {
fn default() -> Self {
Self::new(256)
}
}
#[async_trait]
impl CollaborationTransport for LocalTransport {
    /// Broadcasts `event` to every receiver currently subscribed.
    ///
    /// Delivery is best-effort: the outcome of the underlying send is
    /// discarded, so this method always returns `Ok(())`, including when
    /// nobody is subscribed.
    async fn publish(&self, event: CollaborationEvent) -> Result<(), ExecutionError> {
        match self.tx.send(event) {
            // Handed to at least one receiver.
            Ok(_) => Ok(()),
            // The send failed (Tokio reports this when there are no active
            // receivers); the event is dropped and that is not an error here.
            Err(_) => Ok(()),
        }
    }

    /// Returns a boxed receiver backed by a fresh subscription to the channel.
    fn subscribe(&self) -> Box<dyn CollaborationReceiver> {
        let rx = self.tx.subscribe();
        Box::new(LocalReceiver { rx })
    }
}
/// Receiver returned by `LocalTransport::subscribe`; wraps the receiving half
/// of the broadcast channel.
struct LocalReceiver {
// Receiving half obtained from the transport's sender via `subscribe()`.
rx: broadcast::Receiver<CollaborationEvent>,
}
#[async_trait]
impl CollaborationReceiver for LocalReceiver {
async fn recv(&mut self) -> Option<CollaborationEvent> {
loop {
match self.rx.recv().await {
Ok(event) => return Some(event),
Err(broadcast::error::RecvError::Lagged(skipped)) => {
tracing::warn!(
skipped,
"local transport receiver lagged, {skipped} events dropped"
);
continue;
}
Err(broadcast::error::RecvError::Closed) => return None,
}
}
}
}
/// Tests for `LocalTransport` and the receiver it hands out.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CollaborationEventKind;

    /// A published event reaches a receiver that subscribed beforehand.
    #[tokio::test]
    async fn local_transport_publish_and_receive() {
        let transport = LocalTransport::new(16);
        let mut rx = transport.subscribe();
        let event = CollaborationEvent::new(
            "c1",
            "api-routes",
            "backend",
            CollaborationEventKind::WorkPublished,
        );
        transport.publish(event).await.unwrap();
        let received = rx.recv().await.unwrap();
        assert_eq!(received.correlation_id, "c1");
        assert_eq!(received.kind, CollaborationEventKind::WorkPublished);
    }

    /// Every subscriber gets its own copy of a published event.
    #[tokio::test]
    async fn local_transport_multiple_subscribers() {
        let transport = LocalTransport::new(16);
        let mut rx1 = transport.subscribe();
        let mut rx2 = transport.subscribe();
        transport
            .publish(CollaborationEvent::new(
                "c1",
                "topic",
                "agent",
                CollaborationEventKind::NeedWork,
            ))
            .await
            .unwrap();
        let e1 = rx1.recv().await.unwrap();
        let e2 = rx2.recv().await.unwrap();
        assert_eq!(e1.correlation_id, "c1");
        assert_eq!(e2.correlation_id, "c1");
    }

    /// Publishing without any subscriber is not an error.
    #[tokio::test]
    async fn local_transport_publish_with_no_subscribers_succeeds() {
        let transport = LocalTransport::new(16);
        let result = transport
            .publish(CollaborationEvent::new(
                "c1",
                "topic",
                "agent",
                CollaborationEventKind::Completed,
            ))
            .await;
        assert!(result.is_ok());
    }

    /// The `Default` transport is usable end to end.
    #[tokio::test]
    async fn local_transport_default_capacity() {
        let transport = LocalTransport::default();
        let mut rx = transport.subscribe();
        transport
            .publish(CollaborationEvent::new(
                "c1",
                "topic",
                "agent",
                CollaborationEventKind::WorkClaimed,
            ))
            .await
            .unwrap();
        let event = rx.recv().await.unwrap();
        assert_eq!(event.kind, CollaborationEventKind::WorkClaimed);
    }

    /// All fields of an event survive the trip through the transport.
    #[tokio::test]
    async fn local_transport_preserves_event_fields() {
        let transport = LocalTransport::new(16);
        let mut rx = transport.subscribe();
        let original = CollaborationEvent::new(
            "corr-42",
            "database-schema",
            "db_engineer",
            CollaborationEventKind::FeedbackRequested,
        )
        .consumer("reviewer")
        .payload(serde_json::json!({ "tables": ["users", "orders"] }))
        .timestamp(1719000000000);
        transport.publish(original).await.unwrap();
        let received = rx.recv().await.unwrap();
        assert_eq!(received.correlation_id, "corr-42");
        assert_eq!(received.topic, "database-schema");
        assert_eq!(received.producer, "db_engineer");
        assert_eq!(received.consumer.as_deref(), Some("reviewer"));
        assert_eq!(received.kind, CollaborationEventKind::FeedbackRequested);
        assert_eq!(received.payload, serde_json::json!({ "tables": ["users", "orders"] }));
        assert_eq!(received.timestamp, 1719000000000);
    }

    /// Events are received in the order they were published.
    #[tokio::test]
    async fn local_transport_event_ordering() {
        let transport = LocalTransport::new(16);
        let mut rx = transport.subscribe();
        for i in 0..5 {
            transport
                .publish(CollaborationEvent::new(
                    format!("c{i}"),
                    "topic",
                    "agent",
                    CollaborationEventKind::NeedWork,
                ))
                .await
                .unwrap();
        }
        for i in 0..5 {
            let event = rx.recv().await.unwrap();
            assert_eq!(event.correlation_id, format!("c{i}"));
        }
    }

    /// Once the transport (the only sender) is dropped and nothing is
    /// buffered, the receiver reports the end of the stream.
    #[tokio::test]
    async fn local_transport_recv_returns_none_after_transport_dropped() {
        let transport = LocalTransport::new(16);
        let mut rx = transport.subscribe();
        drop(transport);
        assert!(rx.recv().await.is_none());
    }

    /// A receiver that falls behind skips the overwritten events and resumes
    /// with the oldest one still retained instead of surfacing the lag.
    #[tokio::test]
    async fn local_transport_lagged_receiver_resumes_with_oldest_retained_event() {
        // Capacity 2 with 4 publishes: "c0" and "c1" are overwritten.
        let transport = LocalTransport::new(2);
        let mut rx = transport.subscribe();
        for i in 0..4 {
            transport
                .publish(CollaborationEvent::new(
                    format!("c{i}"),
                    "topic",
                    "agent",
                    CollaborationEventKind::NeedWork,
                ))
                .await
                .unwrap();
        }
        let first = rx.recv().await.unwrap();
        let second = rx.recv().await.unwrap();
        assert_eq!(first.correlation_id, "c2");
        assert_eq!(second.correlation_id, "c3");
    }
}