use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use async_trait::async_trait;
use futures::future::BoxFuture;
use serde_json::Value;
use tokio::sync::{mpsc, Mutex, OnceCell};
use tokio::task::JoinHandle;
use crate::clock::BaseClock;
use crate::error::{PipecatError, Result};
use crate::frames::{Frame, FrameDirection};
use crate::observer::BaseObserver;
use crate::pipeline::PipelineTask;
use super::bus::{AgentBus, BusMessage, BusPayload, BusSubscriber, TaskStatus};
use super::edges::BusOutputEdge;
use super::registry::AgentRegistry;
use super::task::{TaskContext, TaskUpdate};
/// How long `BaseAgent` waits for each child agent to be reported finished by the
/// registry after sending it an `End` message during shutdown. A child that overruns
/// is logged as a warning and shutdown continues without it.
const CHILD_END_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);
/// Context handed to a [`TaskHandler`] for a single incoming task request.
///
/// It carries the request data plus everything needed to answer it. The helper
/// methods forward to the shared [`TaskContext`] and address every reply from
/// `agent_name` to `source` for `task_id`.
#[derive(Clone)]
pub struct TaskRequestCtx {
    /// Identifier of the task; passed along with every reply for this request.
    pub task_id: String,
    /// Optional task name; `BaseAgent` uses it to select a registered handler.
    pub task_name: Option<String>,
    /// JSON payload sent with the request, if any.
    pub payload: Option<Value>,
    /// Name of the agent that sent the request (the reply destination).
    pub source: String,
    /// Name of the agent handling the request (the reply sender).
    pub agent_name: String,
    /// Shared task context used to send completions and stream updates.
    pub task_ctx: Arc<TaskContext>,
}
impl TaskRequestCtx {
    /// Finishes the task with `status` and an optional `response`, replying from this
    /// agent to the requester.
    pub async fn complete(&self, status: TaskStatus, response: Option<Value>) {
        let Self {
            task_id,
            source,
            agent_name,
            task_ctx,
            ..
        } = self;
        task_ctx
            .complete_task(agent_name, source, task_id.clone(), status, response)
            .await;
    }

    /// Opens a response stream for this task, optionally with initial `data`.
    pub async fn stream_start(&self, data: Option<Value>) {
        let Self {
            task_id,
            source,
            agent_name,
            task_ctx,
            ..
        } = self;
        task_ctx
            .stream_start(agent_name, source, task_id.clone(), data)
            .await;
    }

    /// Sends one chunk of streamed `data` for this task.
    pub async fn stream_data(&self, data: Option<Value>) {
        let Self {
            task_id,
            source,
            agent_name,
            task_ctx,
            ..
        } = self;
        task_ctx
            .stream_data(agent_name, source, task_id.clone(), data)
            .await;
    }

    /// Closes the response stream for this task, optionally with trailing `data`.
    pub async fn stream_end(&self, data: Option<Value>) {
        let Self {
            task_id,
            source,
            agent_name,
            task_ctx,
            ..
        } = self;
        task_ctx
            .stream_end(agent_name, source, task_id.clone(), data)
            .await;
    }
}
/// Asynchronous task handler. It receives the request context and returns a boxed
/// `'static` future. `BaseAgent` spawns that future with `tokio::spawn` and tracks
/// it until it finishes or is aborted.
pub type TaskHandler = Arc<dyn Fn(TaskRequestCtx) -> BoxFuture<'static, ()> + Send + Sync>;
/// A spawned task handler that is still tracked by `BaseAgent`.
struct ActiveJob {
    /// Name of the agent that requested the task; cancellation replies are sent here.
    source: String,
    /// Handle of the spawned handler task, kept so the job can be aborted.
    join: JoinHandle<()>,
}
/// An agent: a bus subscriber with a managed lifecycle.
///
/// NOTE(review): the intended call order appears to be `setup` → `run`, with
/// `end`/`cancel` stopping it (that is how `BaseAgent` uses these hooks); confirm
/// for other implementors.
#[async_trait]
pub trait Agent: BusSubscriber {
    /// Name of this agent's parent agent, if it has one.
    fn parent(&self) -> Option<&str>;
    /// Attaches the agent to the bus and registry (`BaseAgent` errors if called twice).
    async fn setup(&self, bus: Arc<dyn AgentBus>, registry: Arc<AgentRegistry>) -> Result<()>;
    /// Runs the agent (`BaseAgent`: runs its pipeline to completion with `clock` and
    /// the optional `observer`).
    async fn run(
        &self,
        clock: Arc<dyn BaseClock>,
        observer: Option<Arc<dyn BaseObserver>>,
    ) -> Result<()>;
    /// Ends the agent with an optional reason (`BaseAgent`: ends its children, waits
    /// for them, cancels running jobs and pushes an end frame).
    async fn end(&self, reason: Option<String>) -> Result<()>;
    /// Cancels the agent with an optional reason (`BaseAgent`: cancels its children
    /// without waiting, cancels running jobs and pushes a cancel frame).
    async fn cancel(&self, reason: Option<String>) -> Result<()>;
    /// Whether the agent is currently active.
    fn active(&self) -> bool;
    /// Whether the agent accepts frames bridged over the bus.
    fn bridged(&self) -> bool;
    /// Whether the agent has started running.
    fn ready(&self) -> bool;
}
/// Reusable [`Agent`] implementation that wraps a [`PipelineTask`] and connects it
/// to the agent bus.
///
/// It injects bridged frames into the pipeline, reacts to activation and lifecycle
/// messages, and dispatches task requests to registered handlers.
pub struct BaseAgent {
    /// Agent name; used as the bus source and to match targeted messages.
    name: String,
    /// Parent agent name, if any (reported to the registry).
    parent: Option<String>,
    /// The pipeline this agent runs.
    pipeline_task: PipelineTask,
    /// Sender obtained from `pipeline_task`, used to inject frames into it.
    push_tx: mpsc::Sender<(Frame, FrameDirection)>,
    /// Activation flag; shared with the output edge when one is attached.
    active: Arc<AtomicBool>,
    /// `None`: bridged frames are ignored. `Some(names)`: bridged frames are accepted
    /// from `names`, or from any agent when the list is empty.
    bridged: Option<Vec<String>>,
    /// Set when `run` starts.
    ready: AtomicBool,
    /// Set by the first `end`/`cancel` call so shutdown runs only once.
    ending: AtomicBool,
    /// Bus handle, set once by `setup`.
    bus: OnceCell<Arc<dyn AgentBus>>,
    /// Registry handle, set once by `setup`.
    registry: OnceCell<Arc<AgentRegistry>>,
    /// Task context, created once by `setup`.
    task_ctx: OnceCell<Arc<TaskContext>>,
    /// Task handlers keyed by task name.
    handlers: HashMap<String, TaskHandler>,
    /// Fallback handler for unnamed tasks or names without a registered handler.
    default_handler: Option<TaskHandler>,
    /// Running handler jobs keyed by task id; shared with the spawned jobs so they can
    /// remove themselves when done.
    active_jobs: Arc<Mutex<HashMap<String, ActiveJob>>>,
    /// Optional output edge appended to the pipeline; it receives the bus in `setup`.
    output_edge: Option<BusOutputEdge>,
}
impl BaseAgent {
pub fn new(
name: impl Into<String>,
pipeline_task: PipelineTask,
bridged: Option<Vec<String>>,
active_on_start: bool,
) -> Self {
let push_tx = pipeline_task.push_sender();
Self {
name: name.into(),
parent: None,
pipeline_task,
push_tx,
active: Arc::new(AtomicBool::new(active_on_start)),
bridged,
ready: AtomicBool::new(false),
ending: AtomicBool::new(false),
bus: OnceCell::new(),
registry: OnceCell::new(),
task_ctx: OnceCell::new(),
handlers: HashMap::new(),
default_handler: None,
active_jobs: Arc::new(Mutex::new(HashMap::new())),
output_edge: None,
}
}
pub fn bridged_pipeline(
name: impl Into<String>,
processors: Vec<crate::frames::FrameProcessor>,
params: crate::pipeline::PipelineParams,
peers: Vec<String>,
active_on_start: bool,
) -> Self {
let name = name.into();
let edge = BusOutputEdge::new(name.clone(), peers.clone());
let mut all = processors;
all.push(edge.to_processor());
let pipeline_task = PipelineTask::new(all, params);
let agent = Self::new(name, pipeline_task, Some(peers), active_on_start);
edge.bind_activation(agent.active.clone());
agent.with_output_edge(edge)
}
pub fn with_output_edge(mut self, edge: BusOutputEdge) -> Self {
edge.bind_activation(self.active.clone());
self.output_edge = Some(edge);
self
}
pub fn with_parent(mut self, parent: impl Into<String>) -> Self {
self.parent = Some(parent.into());
self
}
pub fn on_task(mut self, name: impl Into<String>, handler: TaskHandler) -> Self {
self.handlers.insert(name.into(), handler);
self
}
pub fn on_task_default(mut self, handler: TaskHandler) -> Self {
self.default_handler = Some(handler);
self
}
pub fn task_ctx(&self) -> Option<Arc<TaskContext>> {
self.task_ctx.get().cloned()
}
pub fn pipeline(&self) -> &PipelineTask {
&self.pipeline_task
}
pub fn active_flag(&self) -> Arc<AtomicBool> {
self.active.clone()
}
async fn announce_ready(&self) {
let registry = match self.registry.get() {
Some(r) => r,
None => return,
};
let info = super::registry::AgentInfo {
name: self.name.clone(),
runner: registry.runner_name().to_string(),
parent: self.parent.clone(),
active: self.active.load(Ordering::Relaxed),
bridged: self.bridged.is_some(),
started_at: Some(crate::clock::system_clock().get_time()),
};
registry.register(info.clone()).await;
if let Some(bus) = self.bus.get() {
let msg = BusMessage::new(
self.name.clone(),
None,
BusPayload::AgentReady {
runner: info.runner,
parent: info.parent,
active: info.active,
bridged: info.bridged,
started_at: info.started_at,
},
);
bus.send(msg).await;
}
}
fn accepts_bridged_from(&self, source: &str) -> bool {
match &self.bridged {
None => false,
Some(names) => names.is_empty() || names.iter().any(|n| n == source),
}
}
async fn handle_task_request(
&self,
task_id: &str,
task_name: &Option<String>,
payload: &Option<Value>,
source: &str,
) {
let task_ctx = match self.task_ctx.get() {
Some(ctx) => ctx.clone(),
None => {
log::error!(
"Agent '{}': TaskRequest before setup, dropping task {}",
self.name,
task_id
);
return;
}
};
let handler = task_name
.as_deref()
.and_then(|n| self.handlers.get(n))
.or(self.default_handler.as_ref())
.cloned();
let handler = match handler {
Some(h) => h,
None => {
log::warn!(
"Agent '{}': no handler for task '{}', failing task {}",
self.name,
task_name.as_deref().unwrap_or("<default>"),
task_id
);
task_ctx
.complete_task(
&self.name,
source,
task_id.to_string(),
TaskStatus::Failed,
Some(serde_json::json!({
"error": "no handler",
"task_name": task_name,
})),
)
.await;
return;
}
};
let ctx = TaskRequestCtx {
task_id: task_id.to_string(),
task_name: task_name.clone(),
payload: payload.clone(),
source: source.to_string(),
agent_name: self.name.clone(),
task_ctx,
};
let mut jobs = self.active_jobs.lock().await;
let jobs_for_task = self.active_jobs.clone();
let tid = task_id.to_string();
let fut = handler(ctx);
let join = tokio::spawn(async move {
fut.await;
jobs_for_task.lock().await.remove(&tid);
});
jobs.insert(
task_id.to_string(),
ActiveJob {
source: source.to_string(),
join,
},
);
}
async fn handle_task_cancel(&self, task_id: &str, reason: &Option<String>) {
let job = self.active_jobs.lock().await.remove(task_id);
if let Some(job) = job {
job.join.abort();
if let Some(ctx) = self.task_ctx.get() {
ctx.complete_task(
&self.name,
&job.source,
task_id.to_string(),
TaskStatus::Cancelled,
reason.as_ref().map(|r| serde_json::json!({ "reason": r })),
)
.await;
}
}
}
async fn cascade_to_children(&self, reason: &Option<String>, is_end: bool) {
let (bus, registry) = match (self.bus.get(), self.registry.get()) {
(Some(b), Some(r)) => (b, r),
_ => return,
};
let children = registry.children_of(&self.name).await;
if children.is_empty() {
return;
}
for child in &children {
let payload = if is_end {
BusPayload::End {
reason: reason.clone(),
}
} else {
BusPayload::Cancel {
reason: reason.clone(),
}
};
bus.send(BusMessage::new(
self.name.clone(),
Some(child.clone()),
payload,
))
.await;
}
if is_end {
for child in &children {
if tokio::time::timeout(CHILD_END_TIMEOUT, registry.wait_finished(child))
.await
.is_err()
{
log::warn!(
"Agent '{}': child '{}' did not finish within {:?}, continuing shutdown",
self.name,
child,
CHILD_END_TIMEOUT
);
}
}
}
}
async fn cleanup_jobs(&self, reason: &str) {
let jobs: Vec<(String, ActiveJob)> = self.active_jobs.lock().await.drain().collect();
if let Some(ctx) = self.task_ctx.get() {
for (task_id, job) in jobs {
job.join.abort();
ctx.complete_task(
&self.name,
&job.source,
task_id,
TaskStatus::Cancelled,
Some(serde_json::json!({ "reason": reason })),
)
.await;
}
ctx.fail_all_pending(reason).await;
} else {
for (_, job) in jobs {
job.join.abort();
}
}
}
}
#[async_trait]
impl BusSubscriber for BaseAgent {
    fn name(&self) -> &str {
        &self.name
    }

    /// Handles one bus message addressed to this agent, or broadcast with no target.
    ///
    /// Messages this agent sent itself, and messages targeted at another agent, are
    /// ignored. Control and task-request payloads are acted on directly. Task
    /// responses, updates and stream events are routed through the task context, if
    /// `setup` has created one.
    async fn on_bus_message(&self, message: Arc<BusMessage>) {
        let addressed_elsewhere = match &message.target {
            Some(target) => target != &self.name,
            None => false,
        };
        if message.source == self.name || addressed_elsewhere {
            return;
        }

        let (task_id, update) = match &message.payload {
            BusPayload::Frame { frame, direction } => {
                let wanted = self.active.load(Ordering::Relaxed)
                    && self.accepts_bridged_from(&message.source);
                if wanted {
                    let _ = self.push_tx.send((frame.clone(), *direction)).await;
                }
                return;
            }
            BusPayload::Activate { .. } => {
                self.active.store(true, Ordering::Relaxed);
                return;
            }
            BusPayload::Deactivate => {
                self.active.store(false, Ordering::Relaxed);
                return;
            }
            BusPayload::End { reason } => {
                let _ = self.end(reason.clone()).await;
                return;
            }
            BusPayload::Cancel { reason } => {
                let _ = self.cancel(reason.clone()).await;
                return;
            }
            BusPayload::TaskRequest {
                task_id,
                task_name,
                payload,
            } => {
                self.handle_task_request(task_id, task_name, payload, &message.source)
                    .await;
                return;
            }
            BusPayload::TaskCancel { task_id, reason } => {
                self.handle_task_cancel(task_id, reason).await;
                return;
            }
            BusPayload::TaskResponse {
                task_id,
                status,
                response,
            }
            | BusPayload::TaskResponseUrgent {
                task_id,
                status,
                response,
            } => (
                task_id,
                TaskUpdate::Response {
                    status: *status,
                    response: response.clone(),
                },
            ),
            BusPayload::TaskUpdate { task_id, update }
            | BusPayload::TaskUpdateUrgent { task_id, update } => (
                task_id,
                TaskUpdate::Update {
                    update: update.clone(),
                },
            ),
            BusPayload::TaskStreamStart { task_id, data } => {
                (task_id, TaskUpdate::StreamStart { data: data.clone() })
            }
            BusPayload::TaskStreamData { task_id, data } => {
                (task_id, TaskUpdate::StreamData { data: data.clone() })
            }
            BusPayload::TaskStreamEnd { task_id, data } => {
                (task_id, TaskUpdate::StreamEnd { data: data.clone() })
            }
            _ => return,
        };

        // Task traffic can only be routed once `setup` has created the task context.
        if let Some(ctx) = self.task_ctx.get() {
            ctx.route_update(task_id, update).await;
        }
    }
}
#[async_trait]
impl Agent for BaseAgent {
    fn parent(&self) -> Option<&str> {
        self.parent.as_deref()
    }

    /// Stores the bus and registry and builds the task context from them. If an output
    /// edge is attached, it is given the bus. Any second call fails with a pipeline error.
    async fn setup(&self, bus: Arc<dyn AgentBus>, registry: Arc<AgentRegistry>) -> Result<()> {
        let called_twice = || PipecatError::pipeline("BaseAgent::setup called more than once");
        self.bus.set(bus.clone()).map_err(|_| called_twice())?;
        self.registry
            .set(registry.clone())
            .map_err(|_| called_twice())?;
        let task_ctx = Arc::new(TaskContext::new(bus.clone(), registry));
        self.task_ctx.set(task_ctx).map_err(|_| called_twice())?;
        if let Some(edge) = self.output_edge.as_ref() {
            edge.set_bus(bus);
        }
        Ok(())
    }

    /// Marks the agent ready, announces it, and runs the pipeline to completion.
    /// Whatever the pipeline's outcome, it then cancels leftover jobs and marks the
    /// agent finished in the registry before returning that outcome.
    async fn run(
        &self,
        clock: Arc<dyn BaseClock>,
        observer: Option<Arc<dyn BaseObserver>>,
    ) -> Result<()> {
        self.ready.store(true, Ordering::Relaxed);
        self.announce_ready().await;
        let outcome = self.pipeline_task.run(clock, observer).await;
        self.cleanup_jobs("agent ended").await;
        if let Some(registry) = self.registry.get() {
            registry.mark_finished(&self.name).await;
        }
        outcome
    }

    /// Ends the agent. Children are sent `End` and awaited, running jobs are
    /// cancelled, and an end frame is pushed downstream. Only the first `end`/`cancel`
    /// call has any effect.
    async fn end(&self, reason: Option<String>) -> Result<()> {
        let already_stopping = self.ending.swap(true, Ordering::SeqCst);
        if !already_stopping {
            self.cascade_to_children(&reason, true).await;
            let cleanup_reason = reason.as_deref().unwrap_or("agent ended");
            self.cleanup_jobs(cleanup_reason).await;
            let end_frame = if let Some(text) = reason {
                Frame::end_with(text)
            } else {
                Frame::end()
            };
            let _ = self
                .push_tx
                .send((end_frame, FrameDirection::Downstream))
                .await;
        }
        Ok(())
    }

    /// Cancels the agent. Children are sent `Cancel` (not awaited), running jobs are
    /// cancelled, and a cancel frame is pushed downstream. Only the first
    /// `end`/`cancel` call has any effect.
    async fn cancel(&self, reason: Option<String>) -> Result<()> {
        let already_stopping = self.ending.swap(true, Ordering::SeqCst);
        if !already_stopping {
            self.cascade_to_children(&reason, false).await;
            let cleanup_reason = reason.as_deref().unwrap_or("agent cancelled");
            self.cleanup_jobs(cleanup_reason).await;
            let cancel_frame = if let Some(text) = reason {
                Frame::cancel_with(text)
            } else {
                Frame::cancel()
            };
            let _ = self
                .push_tx
                .send((cancel_frame, FrameDirection::Downstream))
                .await;
        }
        Ok(())
    }

    fn active(&self) -> bool {
        self.active.load(Ordering::Relaxed)
    }

    fn bridged(&self) -> bool {
        self.bridged.is_some()
    }

    fn ready(&self) -> bool {
        self.ready.load(Ordering::Relaxed)
    }
}