use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use arc_swap::ArcSwap;
use async_trait::async_trait;
use serde_json::Value;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use crate::error::{PipecatError, Result};
use crate::frames::{Frame, FrameDirection};
/// Maximum time `SubscriberHandle::shutdown` waits for a subscriber's dispatch
/// task to finish after cancellation before it aborts the task.
const DISPATCH_JOIN_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(2);
/// Capacity of each subscriber's bounded data channel when the bus is created
/// through `LocalAgentBus::new` (or `Default`).
pub const DEFAULT_DATA_CAPACITY: usize = 256;
/// Status value carried by `BusPayload::TaskResponse` and
/// `BusPayload::TaskResponseUrgent`.
///
/// NOTE(review): the exact lifecycle transitions between these states are not
/// defined in this file — confirm against the code that produces them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TaskStatus {
Pending,
Running,
Completed,
Failed,
Cancelled,
}
/// Body of a `BusMessage`.
///
/// Variants fall into two delivery classes, decided by `BusMessage::is_system`:
///
/// * **system** — `Activate`, `Deactivate`, `End`, `Cancel`, `AgentReady`,
///   `AgentRegistry`, `AgentError`, `TaskResponseUrgent`, `TaskUpdateUrgent`
///   and `TaskCancel`. `LocalAgentBus::send` puts these on the subscriber's
///   unbounded channel, which the dispatch loop polls before the data channel.
/// * **data** — every other variant. `LocalAgentBus::send` uses `try_send` on
///   the subscriber's bounded channel and counts the message as dropped when
///   that channel is full.
#[derive(Debug, Clone)]
pub enum BusPayload {
/// A pipeline frame together with the direction it travels in (data class).
Frame {
frame: Frame,
direction: FrameDirection,
},
/// System class; carries optional JSON arguments.
Activate {
args: Option<Value>,
},
/// System class; carries no data.
Deactivate,
/// System class; carries an optional human-readable reason.
End {
reason: Option<String>,
},
/// System class; carries an optional human-readable reason.
Cancel {
reason: Option<String>,
},
/// System class. The fields after `runner` have the same names and types as
/// those of `AgentRegistryEntry`.
AgentReady {
runner: String,
parent: Option<String>,
active: bool,
bridged: bool,
started_at: Option<f64>,
},
/// System class; a list of `AgentRegistryEntry` values tagged with a runner name.
AgentRegistry {
runner: String,
agents: Vec<AgentRegistryEntry>,
},
/// System class; carries an error description.
AgentError {
error: String,
},
/// Data class; identifies a task and optionally names it and attaches a JSON payload.
TaskRequest {
task_id: String,
task_name: Option<String>,
payload: Option<Value>,
},
/// Data class; task status plus optional JSON response.
TaskResponse {
task_id: String,
status: TaskStatus,
response: Option<Value>,
},
/// Same fields as `TaskResponse`, but classified as a system message.
TaskResponseUrgent {
task_id: String,
status: TaskStatus,
response: Option<Value>,
},
/// Data class; optional JSON update for a task.
TaskUpdate {
task_id: String,
update: Option<Value>,
},
/// Same fields as `TaskUpdate`, but classified as a system message.
TaskUpdateUrgent {
task_id: String,
update: Option<Value>,
},
/// Data class; carries only the task id.
TaskUpdateRequest {
task_id: String,
},
/// System class; task id plus an optional reason.
TaskCancel {
task_id: String,
reason: Option<String>,
},
/// Data class; task id plus optional JSON data.
TaskStreamStart {
task_id: String,
data: Option<Value>,
},
/// Data class; task id plus optional JSON data.
TaskStreamData {
task_id: String,
data: Option<Value>,
},
/// Data class; task id plus optional JSON data.
TaskStreamEnd {
task_id: String,
data: Option<Value>,
},
}
/// Envelope routed by an `AgentBus`.
#[derive(Debug, Clone)]
pub struct BusMessage {
/// Name of the sender. `LocalAgentBus::send` never delivers a message to the
/// subscriber whose name equals this value.
pub source: String,
/// When `Some`, `LocalAgentBus::send` delivers only to the subscriber with
/// this name; when `None`, to every subscriber except the source.
pub target: Option<String>,
/// The message body.
pub payload: BusPayload,
/// Sequence number. `BusMessage::new` sets it to 0 and `LocalAgentBus::send`
/// overwrites it from the bus-wide counter.
pub seq: u64,
}
impl BusMessage {
    /// Creates a message from `source` addressed to `target` (or to everyone
    /// when `target` is `None`).
    ///
    /// The sequence number starts at 0; `LocalAgentBus::send` assigns the real
    /// value when the message is sent.
    pub fn new(source: impl Into<String>, target: Option<String>, payload: BusPayload) -> Self {
        let source = source.into();
        Self {
            source,
            target,
            payload,
            seq: 0,
        }
    }

    /// Returns `true` when the payload belongs to the system (priority) class.
    ///
    /// The match is written without a catch-all arm so that every variant is
    /// classified explicitly.
    pub fn is_system(&self) -> bool {
        match &self.payload {
            BusPayload::End { .. }
            | BusPayload::Cancel { .. }
            | BusPayload::Activate { .. }
            | BusPayload::Deactivate
            | BusPayload::AgentReady { .. }
            | BusPayload::AgentRegistry { .. }
            | BusPayload::AgentError { .. }
            | BusPayload::TaskResponseUrgent { .. }
            | BusPayload::TaskUpdateUrgent { .. }
            | BusPayload::TaskCancel { .. } => true,
            BusPayload::Frame { .. }
            | BusPayload::TaskRequest { .. }
            | BusPayload::TaskResponse { .. }
            | BusPayload::TaskUpdate { .. }
            | BusPayload::TaskUpdateRequest { .. }
            | BusPayload::TaskStreamStart { .. }
            | BusPayload::TaskStreamData { .. }
            | BusPayload::TaskStreamEnd { .. } => false,
        }
    }
}
/// One element of the `agents` list in `BusPayload::AgentRegistry`.
///
/// Apart from `name`, the fields have the same names and types as those of
/// `BusPayload::AgentReady`.
#[derive(Debug, Clone)]
pub struct AgentRegistryEntry {
pub name: String,
pub parent: Option<String>,
pub active: bool,
pub bridged: bool,
// NOTE(review): the unit/epoch of this timestamp is not visible here — confirm with the producer.
pub started_at: Option<f64>,
}
/// A named receiver of bus messages.
#[async_trait]
pub trait BusSubscriber: Send + Sync {
/// Name used for routing: `LocalAgentBus` matches it against
/// `BusMessage::target` and `BusMessage::source`, and rejects a second
/// subscriber with the same name.
fn name(&self) -> &str;
/// Handles one message. `LocalAgentBus` calls this from the subscriber's
/// dispatch task and awaits each call before taking the next message.
async fn on_bus_message(&self, message: Arc<BusMessage>);
}
/// Message bus interface. The behavior notes below describe `LocalAgentBus`,
/// the implementation in this file.
#[async_trait]
pub trait AgentBus: Send + Sync {
/// Registers `subscriber`. `LocalAgentBus` returns an error when a
/// subscriber with the same name is already registered.
async fn subscribe(&self, subscriber: Arc<dyn BusSubscriber>) -> Result<()>;
/// Removes the subscriber called `name`, if any, and shuts down its dispatch task.
async fn unsubscribe(&self, name: &str);
/// Routes `message` to the matching subscribers. `LocalAgentBus` ignores
/// the call while the bus is stopped.
async fn send(&self, message: BusMessage);
/// Marks the bus as running.
async fn start(&self);
/// Marks the bus as stopped and shuts down every registered subscriber.
async fn stop(&self);
}
/// Bus-side bookkeeping for one registered subscriber and its dispatch task.
struct SubscriberHandle {
/// Subscriber name captured at subscribe time; used for routing and lookups.
name: Arc<str>,
/// Unbounded channel that `LocalAgentBus::send` uses for system messages.
sys_tx: mpsc::UnboundedSender<Arc<BusMessage>>,
/// Bounded channel that `LocalAgentBus::send` uses (via `try_send`) for data messages.
data_tx: mpsc::Sender<Arc<BusMessage>>,
/// Number of data messages discarded because `data_tx` was full.
dropped: Arc<AtomicU64>,
/// Cancelled by `shutdown` to ask the dispatch loop to exit.
cancel: CancellationToken,
/// Join handle of the dispatch task; `shutdown` takes it out, leaving `None`.
join: std::sync::Mutex<Option<JoinHandle<()>>>,
}
impl SubscriberHandle {
    /// Stops this subscriber's dispatch task.
    ///
    /// Cancels the token, then waits up to `DISPATCH_JOIN_TIMEOUT` for the task
    /// to finish. If the deadline passes the task is aborted (and not awaited
    /// further). A task that ended abnormally — for example because the
    /// subscriber's handler panicked — is reported in the log instead of being
    /// silently ignored.
    ///
    /// Calling this more than once is harmless: the join handle is taken on
    /// the first call, so later calls only re-cancel the token.
    async fn shutdown(&self) {
        self.cancel.cancel();

        // The guard is a temporary that is released at the end of this
        // statement, so the lock is never held across an `.await`. A poisoned
        // lock is recovered: the `Option` inside is always in a valid state.
        let task = self
            .join
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .take();

        if let Some(mut task) = task {
            // Polling through `&mut` keeps ownership of the handle so it can
            // still be aborted after the timeout fires.
            match tokio::time::timeout(DISPATCH_JOIN_TIMEOUT, &mut task).await {
                Ok(Ok(())) => {}
                Ok(Err(err)) => {
                    log::warn!(
                        "bus: dispatch task for '{}' ended abnormally: {}",
                        self.name,
                        err
                    );
                }
                Err(_elapsed) => {
                    log::warn!(
                        "bus: dispatch task for '{}' did not stop within {:?}, aborting",
                        self.name,
                        DISPATCH_JOIN_TIMEOUT
                    );
                    task.abort();
                }
            }
        }
    }
}
/// Per-subscriber delivery task.
///
/// Repeatedly picks the next message and hands it to
/// `subscriber.on_bus_message`, one at a time. The `select!` is `biased`, so
/// on every iteration cancellation is checked first, then the system channel,
/// then the data channel. The loop ends when `cancel` fires or when both
/// channels are closed.
///
/// After the loop ends, system messages that are already queued are still
/// delivered; queued data messages are not.
async fn dispatch_loop(
    subscriber: Arc<dyn BusSubscriber>,
    mut sys_rx: mpsc::UnboundedReceiver<Arc<BusMessage>>,
    mut data_rx: mpsc::Receiver<Arc<BusMessage>>,
    cancel: CancellationToken,
) {
    loop {
        // `None` means "stop": either cancellation was requested or both
        // channels are closed (the `else` arm).
        let next = tokio::select! {
            biased;
            _ = cancel.cancelled() => None,
            Some(system) = sys_rx.recv() => Some(system),
            Some(data) = data_rx.recv() => Some(data),
            else => None,
        };
        match next {
            Some(message) => subscriber.on_bus_message(message).await,
            None => break,
        }
    }

    // Flush whatever system messages were queued before the loop exited.
    loop {
        match sys_rx.try_recv() {
            Ok(message) => subscriber.on_bus_message(message).await,
            Err(_) => break,
        }
    }

    log::debug!("bus: dispatch loop for '{}' exited", subscriber.name());
}
/// In-process `AgentBus` that gives every subscriber its own dispatch task and
/// a pair of channels (unbounded for system messages, bounded for data).
pub struct LocalAgentBus {
/// Current subscriber list. Readers take a snapshot with `load`; writers
/// build a new `Vec` and `store`/`swap` it.
subscribers: ArcSwap<Vec<Arc<SubscriberHandle>>>,
/// Held while replacing `subscribers` so that concurrent subscribe,
/// unsubscribe and stop calls do not overwrite each other's update.
write_lock: std::sync::Mutex<()>,
/// Source of `BusMessage::seq` values; incremented once per accepted `send`.
seq: AtomicU64,
/// Capacity used for each new subscriber's data channel (at least 1).
data_capacity: usize,
/// `true` after construction and `start`, `false` after `stop`; `send` is a
/// no-op while this is `false`.
running: AtomicBool,
}
impl LocalAgentBus {
    /// Creates a running bus whose per-subscriber data channels hold
    /// `DEFAULT_DATA_CAPACITY` messages.
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_DATA_CAPACITY)
    }

    /// Creates a running bus with no subscribers whose per-subscriber data
    /// channels hold `data_capacity` messages. A value of 0 is raised to 1.
    pub fn with_capacity(data_capacity: usize) -> Self {
        // Floor at 1 (tokio's bounded `mpsc::channel` does not accept a capacity of 0).
        let data_capacity = std::cmp::max(data_capacity, 1);
        let subscribers = ArcSwap::from_pointee(Vec::new());
        Self {
            subscribers,
            write_lock: std::sync::Mutex::new(()),
            seq: AtomicU64::new(0),
            data_capacity,
            running: AtomicBool::new(true),
        }
    }

    /// Returns how many data messages have been dropped for the subscriber
    /// called `name` because its data channel was full, or `None` when no such
    /// subscriber is currently registered.
    pub fn dropped_count(&self, name: &str) -> Option<u64> {
        let snapshot = self.subscribers.load();
        for handle in snapshot.iter() {
            if &*handle.name == name {
                return Some(handle.dropped.load(Ordering::Relaxed));
            }
        }
        None
    }
}
impl Default for LocalAgentBus {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl AgentBus for LocalAgentBus {
    /// Registers `subscriber` and starts its dispatch task.
    ///
    /// # Errors
    ///
    /// Returns a pipeline error when a subscriber with the same name is
    /// already registered. The name is checked *before* any task is spawned,
    /// so a rejected subscriber never gets a dispatch task.
    async fn subscribe(&self, subscriber: Arc<dyn BusSubscriber>) -> Result<()> {
        let name: Arc<str> = Arc::from(subscriber.name());
        {
            // The lock guards `()`, so a poisoned lock carries no broken
            // state; recover the guard instead of panicking.
            let _guard = self
                .write_lock
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let current = self.subscribers.load_full();
            if current.iter().any(|s| s.name == name) {
                return Err(PipecatError::pipeline(format!(
                    "Bus subscriber '{}' already exists",
                    name
                )));
            }

            // Creating channels and spawning are synchronous, so doing this
            // under the lock adds no `.await` while the guard is alive.
            let (sys_tx, sys_rx) = mpsc::unbounded_channel();
            let (data_tx, data_rx) = mpsc::channel(self.data_capacity);
            let cancel = CancellationToken::new();
            let join = tokio::spawn(dispatch_loop(subscriber, sys_rx, data_rx, cancel.clone()));
            let handle = Arc::new(SubscriberHandle {
                name: name.clone(),
                sys_tx,
                data_tx,
                dropped: Arc::new(AtomicU64::new(0)),
                cancel,
                join: std::sync::Mutex::new(Some(join)),
            });

            // Copy-on-write: publish a new list, readers keep their snapshot.
            let mut next = Vec::with_capacity(current.len() + 1);
            next.extend(current.iter().cloned());
            next.push(handle);
            self.subscribers.store(Arc::new(next));
        }
        log::debug!("bus: subscribed '{}'", name);
        Ok(())
    }

    /// Removes the subscriber called `name` (if present) and waits for its
    /// dispatch task to shut down. Unknown names are ignored.
    async fn unsubscribe(&self, name: &str) {
        let removed = {
            let _guard = self
                .write_lock
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            let current = self.subscribers.load_full();
            let removed = current.iter().find(|s| &*s.name == name).cloned();
            if removed.is_some() {
                let next: Vec<Arc<SubscriberHandle>> = current
                    .iter()
                    .filter(|s| &*s.name != name)
                    .cloned()
                    .collect();
                self.subscribers.store(Arc::new(next));
            }
            removed
        };
        // Shut down outside the lock: it awaits the dispatch task.
        if let Some(handle) = removed {
            handle.shutdown().await;
            log::debug!("bus: unsubscribed '{}'", name);
        }
    }

    /// Stamps `message` with the next sequence number and queues it for every
    /// matching subscriber.
    ///
    /// A subscriber matches when its name differs from `message.source` and,
    /// if `message.target` is set, equals the target. System messages go to
    /// the unbounded channel; data messages use `try_send` on the bounded
    /// channel and are counted as dropped when it is full. Nothing is sent
    /// while the bus is stopped.
    async fn send(&self, message: BusMessage) {
        if !self.running.load(Ordering::Acquire) {
            log::debug!(
                "bus: dropping send from '{}' — bus is stopped",
                message.source
            );
            return;
        }
        let mut message = message;
        message.seq = self.seq.fetch_add(1, Ordering::Relaxed);
        let is_system = message.is_system();
        let msg = Arc::new(message);
        let subs = self.subscribers.load();
        for sub in subs.iter() {
            // Never echo a message back to its sender.
            if msg.source == *sub.name {
                continue;
            }
            if let Some(target) = &msg.target {
                if target.as_str() != &*sub.name {
                    continue;
                }
            }
            if is_system {
                if sub.sys_tx.send(msg.clone()).is_err() {
                    log::debug!(
                        "bus: system message to '{}' not delivered (receiver gone)",
                        sub.name
                    );
                }
            } else {
                match sub.data_tx.try_send(msg.clone()) {
                    Ok(()) => {}
                    Err(mpsc::error::TrySendError::Full(_)) => {
                        let n = sub.dropped.fetch_add(1, Ordering::Relaxed) + 1;
                        // Log the first drop and then every 100th to avoid flooding.
                        if n == 1 || n % 100 == 0 {
                            log::warn!(
                                "bus: data channel full for '{}' — {} message(s) dropped so far",
                                sub.name,
                                n
                            );
                        }
                    }
                    Err(mpsc::error::TrySendError::Closed(_)) => {
                        log::debug!(
                            "bus: data message to '{}' not delivered (receiver gone)",
                            sub.name
                        );
                    }
                }
            }
        }
    }

    /// Marks the bus as running so that `send` delivers messages again.
    async fn start(&self) {
        self.running.store(true, Ordering::Release);
    }

    /// Marks the bus as stopped, removes every subscriber and shuts down
    /// their dispatch tasks.
    async fn stop(&self) {
        self.running.store(false, Ordering::Release);
        let old = {
            let _guard = self
                .write_lock
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            self.subscribers.swap(Arc::new(Vec::new()))
        };
        // Signal every dispatch loop before joining any of them, so they all
        // wind down concurrently instead of each one only starting to stop
        // once its predecessor has been joined.
        for handle in old.iter() {
            handle.cancel.cancel();
        }
        for handle in old.iter() {
            handle.shutdown().await;
        }
        log::debug!("bus: stopped ({} subscriber(s) shut down)", old.len());
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;
    use std::time::Duration;
    use tokio::sync::Mutex as TokioMutex;
    use tokio::sync::Notify;

    /// Test subscriber that records every message it is handed.
    ///
    /// It can optionally block each handler call on a `Notify` gate or slow
    /// it down with a fixed delay.
    struct Recorder {
        name: String,
        received: TokioMutex<Vec<Arc<BusMessage>>>,
        received_count: AtomicUsize,
        gate: Option<Arc<Notify>>,
        delay: Option<Duration>,
        in_flight: AtomicBool,
        notify_on_receive: Arc<Notify>,
    }

    impl Recorder {
        /// Shared constructor behind `new`, `with_gate` and `with_delay`.
        fn build(name: &str, gate: Option<Arc<Notify>>, delay: Option<Duration>) -> Arc<Self> {
            Arc::new(Self {
                name: name.to_string(),
                received: TokioMutex::new(Vec::new()),
                received_count: AtomicUsize::new(0),
                gate,
                delay,
                in_flight: AtomicBool::new(false),
                notify_on_receive: Arc::new(Notify::new()),
            })
        }

        /// Recorder that handles messages immediately.
        fn new(name: &str) -> Arc<Self> {
            Self::build(name, None, None)
        }

        /// Recorder whose handler waits on `gate` before recording each message.
        fn with_gate(name: &str, gate: Arc<Notify>) -> Arc<Self> {
            Self::build(name, Some(gate), None)
        }

        /// Recorder whose handler sleeps for `delay` before recording each message.
        fn with_delay(name: &str, delay: Duration) -> Arc<Self> {
            Self::build(name, None, Some(delay))
        }

        /// Short labels for the recorded payloads, in arrival order.
        async fn payload_names(&self) -> Vec<&'static str> {
            let log = self.received.lock().await;
            let mut labels = Vec::with_capacity(log.len());
            for message in log.iter() {
                let label = match &message.payload {
                    BusPayload::Frame { .. } => "frame",
                    BusPayload::Cancel { .. } => "cancel",
                    BusPayload::End { .. } => "end",
                    BusPayload::TaskUpdate { .. } => "update",
                    _ => "other",
                };
                labels.push(label);
            }
            labels
        }
    }

    #[async_trait]
    impl BusSubscriber for Recorder {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        async fn on_bus_message(&self, message: Arc<BusMessage>) {
            self.in_flight.store(true, Ordering::SeqCst);
            if let Some(gate) = self.gate.as_ref() {
                gate.notified().await;
            }
            if let Some(pause) = self.delay {
                tokio::time::sleep(pause).await;
            }
            {
                let mut log = self.received.lock().await;
                log.push(message);
            }
            self.received_count.fetch_add(1, Ordering::SeqCst);
            self.in_flight.store(false, Ordering::SeqCst);
            self.notify_on_receive.notify_waiters();
        }
    }

    /// Data-class message (`TaskUpdate`) from `source` to `target`.
    fn data_msg(source: &str, target: Option<&str>) -> BusMessage {
        let payload = BusPayload::TaskUpdate {
            task_id: "t".into(),
            update: None,
        };
        BusMessage::new(source, target.map(String::from), payload)
    }

    /// System-class message (`Cancel`) from `source` to `target`.
    fn cancel_msg(source: &str, target: Option<&str>) -> BusMessage {
        let payload = BusPayload::Cancel { reason: None };
        BusMessage::new(source, target.map(String::from), payload)
    }

    /// Polls `cond` every 5 ms until it holds or `timeout` elapses; the
    /// condition is evaluated one final time at the deadline.
    async fn wait_for<F: Fn() -> bool>(cond: F, timeout: Duration) -> bool {
        let deadline = tokio::time::Instant::now() + timeout;
        loop {
            if tokio::time::Instant::now() >= deadline {
                return cond();
            }
            if cond() {
                return true;
            }
            tokio::time::sleep(Duration::from_millis(5)).await;
        }
    }

    /// A system message queued behind many data messages is handled near the front.
    #[tokio::test]
    async fn system_priority_over_data() {
        let bus = LocalAgentBus::new();
        let sub = Recorder::with_delay("slow", Duration::from_millis(10));
        bus.subscribe(sub.clone()).await.unwrap();

        for _ in 0..50 {
            bus.send(data_msg("src", Some("slow"))).await;
        }
        bus.send(cancel_msg("src", Some("slow"))).await;

        let all_delivered = wait_for(
            || sub.received_count.load(Ordering::SeqCst) >= 51,
            Duration::from_secs(10),
        )
        .await;
        assert!(all_delivered, "not all messages delivered");

        let names = sub.payload_names().await;
        let cancel_pos = names.iter().position(|n| *n == "cancel").unwrap();
        assert!(
            cancel_pos < 5,
            "Cancel was handled at position {cancel_pos}, expected near the front"
        );
    }

    /// Data messages beyond the channel capacity are counted as dropped,
    /// while a system message still gets through.
    #[tokio::test]
    async fn drop_accounting() {
        let bus = LocalAgentBus::with_capacity(4);
        let gate = Arc::new(Notify::new());
        let sub = Recorder::with_gate("blocked", gate.clone());
        bus.subscribe(sub.clone()).await.unwrap();

        // First message occupies the handler, which then blocks on the gate.
        bus.send(data_msg("src", Some("blocked"))).await;
        let started = wait_for(
            || sub.in_flight.load(Ordering::SeqCst),
            Duration::from_secs(2),
        )
        .await;
        assert!(started, "handler never started");

        // Ten more: four fit in the channel, six are dropped.
        for _ in 0..10 {
            bus.send(data_msg("src", Some("blocked"))).await;
        }
        assert_eq!(bus.dropped_count("blocked"), Some(6));

        bus.send(cancel_msg("src", Some("blocked"))).await;

        // Release the gate repeatedly so each queued message can be handled.
        for _ in 0..10 {
            gate.notify_waiters();
            tokio::time::sleep(Duration::from_millis(10)).await;
        }

        let drained = wait_for(
            || sub.received_count.load(Ordering::SeqCst) >= 6,
            Duration::from_secs(5),
        )
        .await;
        assert!(
            drained,
            "queued messages not delivered, got {}",
            sub.received_count.load(Ordering::SeqCst)
        );
        assert_eq!(sub.received_count.load(Ordering::SeqCst), 6);

        let names = sub.payload_names().await;
        assert!(names.contains(&"cancel"), "system message lost: {names:?}");
    }

    /// Broadcasts skip the sender; targeted messages reach only the target.
    #[tokio::test]
    async fn targeting_rules() {
        let bus = LocalAgentBus::new();
        let a = Recorder::new("a");
        let b = Recorder::new("b");
        let c = Recorder::new("c");
        bus.subscribe(a.clone()).await.unwrap();
        bus.subscribe(b.clone()).await.unwrap();
        bus.subscribe(c.clone()).await.unwrap();

        bus.send(data_msg("a", None)).await;
        bus.send(data_msg("a", Some("b"))).await;

        let settled = wait_for(
            || {
                b.received_count.load(Ordering::SeqCst) == 2
                    && c.received_count.load(Ordering::SeqCst) == 1
            },
            Duration::from_secs(2),
        )
        .await;
        assert!(settled);

        assert_eq!(a.received_count.load(Ordering::SeqCst), 0);
        assert_eq!(b.received_count.load(Ordering::SeqCst), 2);
        assert_eq!(c.received_count.load(Ordering::SeqCst), 1);
    }

    /// Sequence numbers are unique and each sender's messages arrive in send order.
    #[tokio::test]
    async fn seq_monotonicity_and_order() {
        let bus = Arc::new(LocalAgentBus::new());
        let sub = Recorder::new("sink");
        bus.subscribe(sub.clone()).await.unwrap();

        let mut handles = Vec::new();
        for i in 0..4 {
            let bus = bus.clone();
            let sender = tokio::spawn(async move {
                for j in 0..25 {
                    let mut m = data_msg(&format!("sender{i}"), Some("sink"));
                    if let BusPayload::TaskUpdate { task_id, .. } = &mut m.payload {
                        *task_id = format!("{i}-{j}");
                    }
                    bus.send(m).await;
                }
            });
            handles.push(sender);
        }
        for h in handles {
            h.await.unwrap();
        }

        let all_in = wait_for(
            || sub.received_count.load(Ordering::SeqCst) == 100,
            Duration::from_secs(5),
        )
        .await;
        assert!(all_in);

        let received = sub.received.lock().await;
        let mut seqs: Vec<u64> = received.iter().map(|m| m.seq).collect();
        seqs.sort_unstable();
        seqs.dedup();
        assert_eq!(seqs.len(), 100, "duplicate seq values");

        for i in 0..4 {
            // Per-sender message indices in the order they were delivered.
            let js: Vec<usize> = received
                .iter()
                .filter_map(|m| match &m.payload {
                    BusPayload::TaskUpdate { task_id, .. } => {
                        let (s, j) = task_id.split_once('-')?;
                        (s == i.to_string()).then(|| j.parse::<usize>().unwrap())
                    }
                    _ => None,
                })
                .collect();
            assert!(
                js.windows(2).all(|w| w[0] < w[1]),
                "sender {i} delivery out of order: {js:?}"
            );
        }
    }

    /// Unsubscribing waits for a handler that is already running to finish.
    #[tokio::test]
    async fn unsubscribe_while_handler_in_flight() {
        let bus = LocalAgentBus::new();
        let gate = Arc::new(Notify::new());
        let sub = Recorder::with_gate("busy", gate.clone());
        bus.subscribe(sub.clone()).await.unwrap();

        bus.send(data_msg("src", Some("busy"))).await;
        let started = wait_for(
            || sub.in_flight.load(Ordering::SeqCst),
            Duration::from_secs(2),
        )
        .await;
        assert!(started);

        // Open the gate shortly after `unsubscribe` has started waiting.
        let g = gate.clone();
        let release = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(100)).await;
            g.notify_waiters();
        });

        bus.unsubscribe("busy").await;
        release.await.unwrap();

        assert_eq!(sub.received_count.load(Ordering::SeqCst), 1);
        assert!(!sub.in_flight.load(Ordering::SeqCst));
    }

    /// After `stop`, neither data nor system messages are delivered.
    #[tokio::test]
    async fn send_after_stop_is_noop() {
        let bus = LocalAgentBus::new();
        let sub = Recorder::new("a");
        bus.subscribe(sub.clone()).await.unwrap();
        bus.stop().await;

        bus.send(data_msg("src", Some("a"))).await;
        bus.send(cancel_msg("src", Some("a"))).await;
        tokio::time::sleep(Duration::from_millis(50)).await;

        assert_eq!(sub.received_count.load(Ordering::SeqCst), 0);
    }

    /// A second subscriber with an already-registered name is rejected.
    #[tokio::test]
    async fn duplicate_subscribe_rejected() {
        let bus = LocalAgentBus::new();
        bus.subscribe(Recorder::new("dup")).await.unwrap();
        let second = bus.subscribe(Recorder::new("dup")).await;
        assert!(second.is_err());
    }
}