use std::collections::HashMap;
use serde_json::{Value, json};
use tokio::sync::{mpsc, watch};
use tokio_util::sync::CancellationToken;
use crate::engine::driver::PhaseDriver;
use crate::engine::types::{Direction, DriveResult, ProtocolEvent};
use crate::error::EngineError;
use crate::protocol::agui::build_run_agent_input;
use crate::transport::{JSONRPC_VERSION, JsonRpcMessage, JsonRpcRequest, Transport};
/// Phase driver that sends a `RunAgentInput` request over a transport and
/// relays the messages it receives back as [`ProtocolEvent`]s.
pub struct ContextAgUiDriver {
    /// Transport used to send the request and to receive subsequent messages.
    transport: Box<dyn Transport>,
    /// Thread identifier handed to `build_run_agent_input` on every phase.
    thread_id: String,
}
impl ContextAgUiDriver {
    /// Builds a driver that owns `transport` and tags every run input with
    /// `thread_id`.
    #[must_use]
    pub fn new(transport: Box<dyn Transport>, thread_id: String) -> Self {
        Self { transport, thread_id }
    }
}
#[async_trait::async_trait]
impl PhaseDriver for ContextAgUiDriver {
    /// Drives a single phase: builds a `RunAgentInput` from `state` and the
    /// current extractor values, sends it as a `run_agent_input` JSON-RPC
    /// request, then relays every received message on `event_tx` until the
    /// run finishes, the transport closes, or `cancel` fires.
    ///
    /// # Returns
    ///
    /// * `DriveResult::Complete` when a message whose method is
    ///   `run_finished` arrives, or when `cancel` is triggered.
    /// * `DriveResult::TransportClosed` when the transport yields `Ok(None)`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `build_run_agent_input`, and returns
    /// `EngineError::Driver` when the input cannot be serialized, the request
    /// cannot be sent, or the transport reports a receive error.
    async fn drive_phase(
        &mut self,
        _phase_index: usize,
        state: &Value,
        extractors: watch::Receiver<HashMap<String, String>>,
        event_tx: mpsc::Sender<ProtocolEvent>,
        cancel: CancellationToken,
    ) -> Result<DriveResult, EngineError> {
        // Snapshot the extractor map so the watch borrow is released
        // immediately instead of being held across the awaits below.
        let current_extractors = extractors.borrow().clone();
        let input = build_run_agent_input(state, &current_extractors, &self.thread_id)?;
        let input_value = serde_json::to_value(&input)
            .map_err(|e| EngineError::Driver(format!("serialize RunAgentInput: {e}")))?;

        // Event reporting is best-effort: a failed send on the event channel
        // is deliberately ignored and does not abort the phase.
        let _ = event_tx
            .send(ProtocolEvent {
                direction: Direction::Outgoing,
                method: "run_agent_input".to_string(),
                content: input_value.clone(),
            })
            .await;

        let msg = JsonRpcMessage::Request(JsonRpcRequest {
            jsonrpc: JSONRPC_VERSION.to_string(),
            method: "run_agent_input".to_string(),
            params: Some(input_value),
            id: json!(uuid::Uuid::new_v4().to_string()),
        });
        self.transport
            .send_message(&msg)
            .await
            .map_err(|e| EngineError::Driver(format!("send RunAgentInput: {e}")))?;

        loop {
            tokio::select! {
                result = self.transport.receive_message() => {
                    match result {
                        Ok(Some(msg)) => {
                            let (method, content) = extract_event_from_message(&msg);
                            // Decide before `method` is moved into the event.
                            let is_run_finished = method == "run_finished";
                            let _ = event_tx
                                .send(ProtocolEvent {
                                    direction: Direction::Incoming,
                                    method,
                                    content,
                                })
                                .await;
                            if is_run_finished {
                                return Ok(DriveResult::Complete);
                            }
                        }
                        Ok(None) => return Ok(DriveResult::TransportClosed),
                        Err(e) => {
                            return Err(EngineError::Driver(format!(
                                "AG-UI transport receive error: {e}"
                            )));
                        }
                    }
                }
                // Cancellation ends the phase and is reported as `Complete`.
                () = cancel.cancelled() => return Ok(DriveResult::Complete),
            }
        }
    }
}
fn extract_event_from_message(msg: &JsonRpcMessage) -> (String, Value) {
match msg {
JsonRpcMessage::Notification(notif) => {
let content = notif.params.clone().unwrap_or(json!({}));
(notif.method.clone(), content)
}
JsonRpcMessage::Response(resp) => {
let content = resp
.result
.clone()
.or_else(|| resp.error.as_ref().map(|e| json!({"error": e.message})))
.unwrap_or(json!({}));
("response".to_string(), content)
}
JsonRpcMessage::Request(req) => {
let content = req.params.clone().unwrap_or(json!({}));
(req.method.clone(), content)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transport::JsonRpcNotification;

    /// A notification carrying params yields its method and those params.
    #[test]
    fn test_extract_event_from_notification() {
        let notification =
            JsonRpcNotification::new("text_message_content", Some(json!({"delta": "hello"})));
        let wrapped = JsonRpcMessage::Notification(notification);

        let (method, content) = extract_event_from_message(&wrapped);

        assert_eq!(method, "text_message_content");
        assert_eq!(content["delta"], "hello");
    }

    /// A notification without params yields an empty JSON object as content.
    #[test]
    fn test_extract_event_from_notification_no_params() {
        let notification = JsonRpcNotification::new("run_finished", None);
        let wrapped = JsonRpcMessage::Notification(notification);

        let (method, content) = extract_event_from_message(&wrapped);

        assert_eq!(method, "run_finished");
        assert_eq!(content, json!({}));
    }
}