use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::Duration;
use serde_json::Value;
use tokio::sync::{mpsc, oneshot, Mutex};
use tokio::time::timeout;
use crate::error::{PipecatError, Result};
use super::bus::{AgentBus, BusMessage, BusPayload, TaskStatus};
use super::registry::AgentRegistry;
/// How long [`TaskContext::dispatch`] waits, by default, for the target agent to
/// show up in the registry before giving up.
pub const DEFAULT_READY_TIMEOUT: Duration = Duration::from_millis(10_000);
/// One event delivered for an in-flight task.
///
/// `Response` is the event [`TaskHandle`] waits for: both of its await methods
/// return as soon as they receive it. All other variants are intermediate.
#[derive(Debug, Clone)]
pub enum TaskUpdate {
    /// Final outcome of the task: its status plus an optional response payload.
    Response { status: TaskStatus, response: Option<Value> },
    /// Stream-start event with optional data.
    StreamStart { data: Option<Value> },
    /// Stream-data event with optional data.
    StreamData { data: Option<Value> },
    /// Stream-end event with optional data.
    StreamEnd { data: Option<Value> },
    /// Generic intermediate update with an optional payload.
    Update { update: Option<Value> },
}
/// Final outcome of a task, produced by [`TaskHandle::await_completion`] and
/// [`TaskHandle::stream_updates`] from the task's `Response` update.
#[derive(Debug, Clone)]
pub struct TaskResult {
    /// Id of the task this result belongs to.
    pub task_id: String,
    /// Status carried by the final response.
    pub status: TaskStatus,
    /// Optional payload carried by the final response.
    pub response: Option<Value>,
}
/// Caller-side handle for a task started via [`TaskContext::dispatch`] or
/// [`TaskContext::dispatch_with`].
///
/// Updates arrive on an internal channel whose sending half is kept by the
/// `TaskContext` and fed by `TaskContext::route_update`.
pub struct TaskHandle {
    /// Identifier generated for the task at dispatch time.
    pub task_id: String,
    /// Name of the agent the task request was addressed to.
    pub target_agent: String,
    // Receiving half of the per-task update channel.
    rx: mpsc::UnboundedReceiver<TaskUpdate>,
}
impl TaskHandle {
pub async fn await_completion(
mut self,
timeout_duration: Option<Duration>,
) -> Result<TaskResult> {
let task_id = self.task_id.clone();
loop {
let update = match timeout_duration {
Some(d) => timeout(d, self.rx.recv())
.await
.map_err(|_| PipecatError::pipeline("Task timeout"))?,
None => self.rx.recv().await,
};
match update {
Some(TaskUpdate::Response { status, response }) => {
return Ok(TaskResult {
task_id,
status,
response,
})
}
Some(_) => continue,
None => return Err(PipecatError::pipeline("Task did not return a response")),
}
}
}
pub async fn stream_updates(
self,
timeout_duration: Option<Duration>,
) -> Result<(Vec<TaskUpdate>, TaskResult)> {
let mut rx = self.rx;
let task_id = self.task_id.clone();
let mut updates = Vec::new();
loop {
let update = match timeout_duration {
Some(d) => timeout(d, rx.recv())
.await
.map_err(|_| PipecatError::pipeline("Task stream timeout"))?,
None => rx.recv().await,
};
match update {
Some(TaskUpdate::Response { status, response }) => {
return Ok((
updates,
TaskResult {
task_id,
status,
response,
},
));
}
Some(u) => updates.push(u),
None => {
return Err(PipecatError::pipeline(
"Task stream closed without response",
))
}
}
}
}
}
/// Async callback registered through `TaskContext::on_update`.
///
/// It is invoked with a clone of each [`TaskUpdate`] routed to the task it was
/// registered for.
pub type UpdateHandler = Arc<
    dyn Fn(TaskUpdate) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync,
>;
/// Dispatches tasks over an [`AgentBus`] and routes their updates back to the
/// matching [`TaskHandle`]s and registered [`UpdateHandler`]s.
pub struct TaskContext {
    // Transport used for every outgoing task message.
    bus: Arc<dyn AgentBus>,
    // Consulted to check that a target agent is present before dispatching.
    registry: Arc<AgentRegistry>,
    // task_id -> sending half of the channel feeding that task's TaskHandle.
    pending: Mutex<HashMap<String, mpsc::UnboundedSender<TaskUpdate>>>,
    // task_id -> callbacks run, in registration order, for each routed update.
    update_handlers: Mutex<HashMap<String, Vec<UpdateHandler>>>,
}
impl TaskContext {
pub fn new(bus: Arc<dyn AgentBus>, registry: Arc<AgentRegistry>) -> Self {
Self {
bus,
registry,
pending: Mutex::new(HashMap::new()),
update_handlers: Mutex::new(HashMap::new()),
}
}
pub async fn dispatch(
&self,
source: &str,
target: &str,
task_name: Option<String>,
payload: Option<Value>,
) -> Result<TaskHandle> {
self.dispatch_with(
source,
target,
task_name,
payload,
Some(DEFAULT_READY_TIMEOUT),
)
.await
}
pub async fn dispatch_with(
&self,
source: &str,
target: &str,
task_name: Option<String>,
payload: Option<Value>,
ready_timeout: Option<Duration>,
) -> Result<TaskHandle> {
if let Some(wait) = ready_timeout {
self.await_target_ready(target, wait).await?;
}
let task_id = uuid::Uuid::new_v4().to_string();
let (tx, rx) = mpsc::unbounded_channel();
{
let mut pending = self.pending.lock().await;
pending.insert(task_id.clone(), tx);
}
let msg = BusMessage::new(
source.to_string(),
Some(target.to_string()),
BusPayload::TaskRequest {
task_id: task_id.clone(),
task_name,
payload,
},
);
self.bus.send(msg).await;
Ok(TaskHandle {
task_id,
target_agent: target.to_string(),
rx,
})
}
async fn await_target_ready(&self, target: &str, wait: Duration) -> Result<()> {
if self.registry.get(target).await.is_some() {
return Ok(());
}
let (ready_tx, ready_rx) = oneshot::channel::<()>();
let ready_tx = Arc::new(std::sync::Mutex::new(Some(ready_tx)));
self.registry
.watch(
target,
Arc::new(move |_info| {
let ready_tx = ready_tx.clone();
Box::pin(async move {
if let Some(tx) = ready_tx.lock().unwrap().take() {
let _ = tx.send(());
}
})
}),
)
.await;
timeout(wait, ready_rx)
.await
.map_err(|_| PipecatError::pipeline(format!("Target agent '{}' not ready", target)))?
.map_err(|_| PipecatError::pipeline(format!("Target agent '{}' not ready", target)))?;
Ok(())
}
pub async fn fail_all_pending(&self, reason: &str) {
let pending: HashMap<_, _> = {
let mut guard = self.pending.lock().await;
guard.drain().collect()
};
if !pending.is_empty() {
log::debug!(
"TaskContext: failing {} pending task(s): {}",
pending.len(),
reason
);
}
drop(pending); self.update_handlers.lock().await.clear();
}
pub async fn on_update(&self, task_id: &str, handler: UpdateHandler) {
let mut handlers = self.update_handlers.lock().await;
handlers
.entry(task_id.to_string())
.or_default()
.push(handler);
}
pub async fn route_update(&self, task_id: &str, update: TaskUpdate) {
{
let mut pending = self.pending.lock().await;
if let Some(tx) = pending.get(task_id) {
let is_final = matches!(update, TaskUpdate::Response { .. });
let _ = tx.send(update.clone());
if is_final {
pending.remove(task_id);
}
}
}
let handlers: Vec<UpdateHandler> = {
let h = self.update_handlers.lock().await;
h.get(task_id).cloned().unwrap_or_default()
};
for handler in handlers {
handler(update.clone()).await;
}
}
pub async fn stream_start(
&self,
source: &str,
target: &str,
task_id: String,
data: Option<Value>,
) {
let msg = BusMessage::new(
source.to_string(),
Some(target.to_string()),
BusPayload::TaskStreamStart { task_id, data },
);
self.bus.send(msg).await;
}
pub async fn stream_data(
&self,
source: &str,
target: &str,
task_id: String,
data: Option<Value>,
) {
let msg = BusMessage::new(
source.to_string(),
Some(target.to_string()),
BusPayload::TaskStreamData { task_id, data },
);
self.bus.send(msg).await;
}
pub async fn stream_end(
&self,
source: &str,
target: &str,
task_id: String,
data: Option<Value>,
) {
let msg = BusMessage::new(
source.to_string(),
Some(target.to_string()),
BusPayload::TaskStreamEnd { task_id, data },
);
self.bus.send(msg).await;
}
pub async fn complete_task(
&self,
source: &str,
target: &str,
task_id: String,
status: TaskStatus,
response: Option<Value>,
) {
let msg = BusMessage::new(
source.to_string(),
Some(target.to_string()),
BusPayload::TaskResponse {
task_id,
status,
response,
},
);
self.bus.send(msg).await;
}
pub async fn cancel_task(
&self,
source: &str,
target: &str,
task_id: String,
reason: Option<String>,
) {
let msg = BusMessage::new(
source.to_string(),
Some(target.to_string()),
BusPayload::TaskCancel { task_id, reason },
);
self.bus.send(msg).await;
}
pub async fn urgent_response(
&self,
source: &str,
target: &str,
task_id: String,
status: TaskStatus,
response: Option<Value>,
) {
let msg = BusMessage::new(
source.to_string(),
Some(target.to_string()),
BusPayload::TaskResponseUrgent {
task_id,
status,
response,
},
);
self.bus.send(msg).await;
}
pub async fn urgent_update(
&self,
source: &str,
target: &str,
task_id: String,
update: Option<Value>,
) {
let msg = BusMessage::new(
source.to_string(),
Some(target.to_string()),
BusPayload::TaskUpdateUrgent { task_id, update },
);
self.bus.send(msg).await;
}
}