use chrono::Utc;
use distri_a2a::{JsonRpcRequest, MessageKind, MessageSendParams};
use distri_types::{AgentEvent, AgentEventType, Message, ToolCall, ToolResponse};
use futures_util::StreamExt;
use reqwest_eventsource::{Error as EsError, Event, RequestBuilderExt};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use uuid::Uuid;
use crate::{ExternalToolRegistry, HookRegistry};
use crate::config::{self, BuildHttpClient};
/// Errors produced while streaming from an agent or reporting tool results back.
#[derive(Debug, Error)]
pub enum StreamError {
/// A `reqwest` error; converted automatically through `From<reqwest::Error>`.
#[error("http request failed: {0}")]
Http(#[from] reqwest::Error),
/// The event stream could not be opened or yielded an error; carries the
/// stringified underlying error.
#[error("stream failed: {0}")]
Event(String),
/// A response could not be interpreted, e.g. a message failed to convert or
/// the complete-tool endpoint answered with a non-success status.
#[error("invalid response: {0}")]
InvalidResponse(String),
/// The server reported a JSON-RPC error; carries the error's `message`.
#[error("server error: {0}")]
Server(String),
/// An external tool handler returned an error; carries its stringified form.
#[error("external tool handler failed: {0}")]
ExternalTool(String),
/// A `serde_json` error; converted automatically through `From<serde_json::Error>`.
#[error("serialization error: {0}")]
Serialization(#[from] serde_json::Error),
}
/// One decoded item of an agent stream, as handed to the `on_event` callback.
#[derive(Debug, Clone)]
pub struct StreamItem {
/// The converted message; `None` when the streamed payload was not a
/// `MessageKind::Message`.
pub message: Option<Message>,
/// The agent event decoded from the payload's metadata; `None` when the
/// payload carried no metadata or the metadata was not an `AgentEventType`.
pub agent_event: Option<AgentEvent>,
}
/// Client that streams agent responses and can dispatch tool calls and inline
/// hook requests to optional local registries.
#[derive(Clone)]
pub struct AgentStreamClient {
// Copy of the configuration's `base_url`, used to build endpoint URLs.
base_url: String,
// Configuration the client was built from; also passed to the platform-tools
// request factory when a stream is started.
config: config::DistriConfig,
// HTTP client used for the stream request and for complete-tool calls.
http: reqwest::Client,
// Handlers for tool calls executed on the client side; `None` disables handling.
tool_registry: Option<ExternalToolRegistry>,
// Handlers for inline hook requests; `None` disables handling.
hook_registry: Option<HookRegistry>,
}
impl AgentStreamClient {
/// Creates a client from a [`config::DistriConfig`] built for `base_url`.
///
/// # Panics
///
/// Panics if the HTTP client cannot be built; see `from_config`.
pub fn new(base_url: impl Into<String>) -> Self {
    Self::from_config(config::DistriConfig::new(base_url))
}
/// Creates a client from an existing configuration.
///
/// The base URL is copied out of `cfg`, the HTTP client is built through
/// [`BuildHttpClient`], and no tool or hook registry is attached.
///
/// # Panics
///
/// Panics if the HTTP client cannot be built from `cfg`.
pub fn from_config(cfg: config::DistriConfig) -> Self {
    let http = <config::DistriConfig as BuildHttpClient>::build_http_client(&cfg)
        .expect("Failed to build HTTP client for AgentStreamClient");
    Self {
        // Cloned before `cfg` is moved into the `config` field below.
        base_url: cfg.base_url.clone(),
        config: cfg,
        http,
        tool_registry: None,
        hook_registry: None,
    }
}
pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
self.http = client;
self
}
pub fn with_tool_registry(mut self, registry: ExternalToolRegistry) -> Self {
self.tool_registry = Some(registry);
self
}
pub fn with_hook_registry(mut self, registry: HookRegistry) -> Self {
self.hook_registry = Some(registry);
self
}
pub async fn stream_agent<H, Fut>(
&self,
agent_id: &str,
params: MessageSendParams,
mut on_event: H,
) -> Result<(), StreamError>
where
H: FnMut(StreamItem) -> Fut,
Fut: std::future::Future<Output = ()> + Send,
{
let url = format!(
"{}/agents/{}",
self.base_url.trim_end_matches('/'),
agent_id
);
let params = {
use distri_types::configuration::DefinitionOverrides;
let mut p = params;
let factory = crate::platform_tools::build_distri_request_factory(&self.config);
let meta = p.metadata.get_or_insert_with(|| serde_json::json!({}));
let mut overrides: DefinitionOverrides = meta
.get("definition_overrides")
.and_then(|v| serde_json::from_value(v.clone()).ok())
.unwrap_or_default();
overrides
.dynamic_tools
.get_or_insert_with(Vec::new)
.push(factory);
meta.as_object_mut()
.unwrap()
.insert("definition_overrides".to_string(), serde_json::to_value(&overrides).unwrap());
p
};
let rpc = JsonRpcRequest {
jsonrpc: "2.0".to_string(),
id: Some(serde_json::Value::String(Uuid::new_v4().to_string())),
method: "message/stream".to_string(),
params: serde_json::to_value(params)?,
};
let builder = self.http.post(url).json(&rpc);
let mut es = builder
.eventsource()
.map_err(|e| StreamError::Event(e.to_string()))?;
while let Some(event) = es.next().await {
match event {
Ok(Event::Open) => continue,
Ok(Event::Message(message)) => {
if message.data.trim().is_empty() {
continue;
}
let rpc: RpcResponse = serde_json::from_str(&message.data)?;
if let Some(err) = rpc.error {
return Err(StreamError::Server(err.message));
}
let Some(result) = rpc.result else {
continue;
};
let message_kind: MessageKind = serde_json::from_value(result)?;
let agent_event =
Self::agent_event_from_message(agent_id, &message_kind).unwrap_or(None);
let distri_message = convert_kind(&message_kind)?;
if let Some(agent_event) = agent_event.clone() {
if let AgentEventType::InlineHookRequested { request } = &agent_event.event
&& let Some(registry) = &self.hook_registry {
registry.try_handle(agent_id, request).await;
}
if let AgentEventType::ToolCalls { tool_calls, .. } = &agent_event.event {
self.try_handle_external_tools(agent_id, &agent_event, tool_calls)
.await?;
}
}
on_event(StreamItem {
message: distri_message,
agent_event,
})
.await;
}
Err(EsError::StreamEnded) => break,
Err(err) => return Err(StreamError::Event(err.to_string())),
}
}
Ok(())
}
/// Builds an [`AgentEvent`] from the metadata attached to a streamed payload.
///
/// Returns `Ok(None)` when the payload carries no metadata (always the case for
/// artifacts) or when the metadata does not deserialize into an
/// [`AgentEventType`]. Missing context and task ids fall back to
/// `"unknown_thread"` and `"unknown_task"`. Both `run_id` and `agent_id` of the
/// event are set to `agent_id`, and the timestamp is the time of decoding.
///
/// The current implementation never returns `Err`.
fn agent_event_from_message(
    agent_id: &str,
    message: &MessageKind,
) -> Result<Option<AgentEvent>, StreamError> {
    let (metadata, context_id, task_id) = match message {
        MessageKind::Message(msg) => (
            msg.metadata.clone(),
            msg.context_id.clone(),
            msg.task_id.clone(),
        ),
        MessageKind::TaskStatusUpdate(update) => (
            update.metadata.clone(),
            Some(update.context_id.clone()),
            Some(update.task_id.clone()),
        ),
        MessageKind::Artifact(_) => (None, None, None),
    };
    // Absent metadata and metadata of another shape are both treated as "no event".
    let Some(event) =
        metadata.and_then(|meta| serde_json::from_value::<AgentEventType>(meta).ok())
    else {
        return Ok(None);
    };
    Ok(Some(AgentEvent {
        timestamp: Utc::now(),
        thread_id: context_id.unwrap_or_else(|| "unknown_thread".to_string()),
        run_id: agent_id.to_string(),
        task_id: task_id.unwrap_or_else(|| "unknown_task".to_string()),
        event,
        agent_id: agent_id.to_string(),
        user_id: None,
        identifier_id: None,
        workspace_id: None,
        channel_id: None,
    }))
}
/// Offers each tool call to the external tool registry and reports handled
/// results back to the server.
///
/// Does nothing when no tool registry is configured. Calls the registry declines
/// (returns `None` for) are skipped. Processing stops at the first failure.
///
/// # Errors
///
/// Returns [`StreamError::ExternalTool`] when a handler fails, or the error from
/// `complete_tool` when the result cannot be delivered.
async fn try_handle_external_tools(
    &self,
    agent_id: &str,
    agent_event: &AgentEvent,
    tool_calls: &[ToolCall],
) -> Result<(), StreamError> {
    let Some(registry) = &self.tool_registry else {
        return Ok(());
    };
    for call in tool_calls {
        let Some(outcome) = registry
            .try_handle(agent_id, &call.tool_name, call, agent_event)
            .await
        else {
            continue;
        };
        let response = outcome.map_err(|err| StreamError::ExternalTool(err.to_string()))?;
        self.complete_tool(agent_id, &call.tool_call_id, response)
            .await?;
    }
    Ok(())
}
async fn complete_tool(
&self,
agent_id: &str,
tool_call_id: &str,
tool_response: ToolResponse,
) -> Result<(), StreamError> {
let url = format!(
"{}/agents/{}/complete-tool",
self.base_url.trim_end_matches('/'),
agent_id
);
let payload = CompleteToolRequest {
tool_call_id: tool_call_id.to_string(),
tool_response,
};
let resp = self.http.post(&url).json(&payload).send().await?;
let status = resp.status();
if !status.is_success() {
let body = resp.text().await.unwrap_or_default();
tracing::error!("complete_tool failed ({}): {}", status, body);
return Err(StreamError::InvalidResponse(format!(
"complete_tool failed ({}): {}",
status, body
)));
}
Ok(())
}
}
/// JSON body posted to the complete-tool endpoint.
#[derive(Debug, Serialize, Deserialize)]
struct CompleteToolRequest {
// Identifier of the tool call this response belongs to.
tool_call_id: String,
// Result produced by the external tool handler.
tool_response: ToolResponse,
}
/// JSON-RPC response envelope as parsed from each event on the stream.
#[derive(Debug, Deserialize)]
// `jsonrpc` and `id` are deserialized but not read by the code in this file.
#[allow(dead_code)]
struct RpcResponse {
// Required: a payload without this field fails to deserialize.
pub jsonrpc: String,
// Successful payload; `None` when absent.
#[serde(default)]
pub result: Option<serde_json::Value>,
// Error object reported by the server; `None` when absent.
#[serde(default)]
pub error: Option<RpcError>,
// Request identifier; `None` when absent.
#[serde(default)]
pub id: Option<serde_json::Value>,
}
/// Error object carried inside an `RpcResponse`.
#[derive(Debug, Deserialize)]
// Only `message` is read by the code in this file; `code` and `data` are kept
// so the full error object deserializes.
#[allow(dead_code)]
struct RpcError {
// Numeric error code (required).
pub code: i32,
// Error text (required); surfaced through `StreamError::Server`.
pub message: String,
// Optional additional error details; `None` when absent.
#[serde(default)]
pub data: Option<serde_json::Value>,
}
/// Converts a streamed payload into a [`Message`] when it is a message.
///
/// Returns `Ok(None)` for every payload kind other than `MessageKind::Message`.
///
/// # Errors
///
/// Returns [`StreamError::InvalidResponse`] with the stringified conversion error
/// when the message cannot be converted.
fn convert_kind(kind: &MessageKind) -> Result<Option<Message>, StreamError> {
    let MessageKind::Message(msg) = kind else {
        return Ok(None);
    };
    match distri_types::Message::try_from(msg.clone()) {
        Ok(converted) => Ok(Some(converted)),
        Err(e) => Err(StreamError::InvalidResponse(e.to_string())),
    }
}