use std::time::Duration;
use crate::client::WireClient;
use crate::error::WireError;
use crate::message::{parse_wire_message, WireMessage};
use crate::protocol::{
ApprovalResponseKind, DisplayBlock, Event, HookAction, Request, ToolOutput, ToolReturnValue,
};
/// Extension methods layered on top of [`WireClient`] that yield parsed
/// [`WireMessage`]s instead of raw messages.
pub trait WireClientExt: WireClient {
    /// Reads the next raw message via `read_raw_message` and parses it with
    /// [`parse_wire_message`].
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying read, and returns whatever
    /// error `parse_wire_message` reports for the raw message.
    fn read_message(
        &mut self,
    ) -> impl std::future::Future<Output = Result<WireMessage, WireError>> + Send {
        async move { parse_wire_message(self.read_raw_message().await?) }
    }

    /// Same as [`WireClientExt::read_message`], but forwards `timeout` to
    /// `read_raw_message_timeout` for the underlying read.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying read (presumably including a
    /// timeout error; confirm against the `WireClient` implementation), and
    /// returns whatever error `parse_wire_message` reports.
    fn read_message_timeout(
        &mut self,
        timeout: Duration,
    ) -> impl std::future::Future<Output = Result<WireMessage, WireError>> + Send {
        async move { parse_wire_message(self.read_raw_message_timeout(timeout).await?) }
    }
}
/// Blanket implementation: every [`WireClient`] (sized or not) automatically
/// gets the [`WireClientExt`] helpers.
impl<T> WireClientExt for T where T: WireClient + ?Sized {}
/// Convenience accessors for inspecting an event.
///
/// The method descriptions below reflect the `Event` implementation in this file.
pub trait EventExt {
/// Returns the event's type name as an owned `String`.
fn event_type(&self) -> String;
/// Returns the event's type name lower-cased, with `_` inserted before every
/// uppercase character except the first (a name like `FooBar` becomes `foo_bar`).
fn normalized_event_type(&self) -> String;
/// Returns the event serialized to a `serde_json::Value`, or `Null` when
/// serialization fails.
fn payload(&self) -> serde_json::Value;
}
impl EventExt for Event {
    /// Returns the event's type name (as reported by `type_name`) as an owned string.
    fn event_type(&self) -> String {
        let name = self.type_name();
        name.to_string()
    }

    /// Converts the event's type name to a snake_case-like form.
    ///
    /// An underscore is inserted before every uppercase character other than
    /// the first, and each character is ASCII-lowercased.
    fn normalized_event_type(&self) -> String {
        self.type_name()
            .chars()
            .enumerate()
            .fold(String::new(), |mut out, (idx, c)| {
                // A word boundary is any uppercase character past position 0.
                if idx > 0 && c.is_uppercase() {
                    out.push('_');
                }
                out.push(c.to_ascii_lowercase());
                out
            })
    }

    /// Serializes the event to JSON, falling back to `Null` if serialization fails.
    fn payload(&self) -> serde_json::Value {
        serde_json::to_value(self).unwrap_or(serde_json::Value::Null)
    }
}
/// Helpers for classifying a request and producing an automatic reply to it.
///
/// The method descriptions below reflect the `Request` implementation in this file.
pub trait RequestExt {
/// Returns the name of the request variant (e.g. `"ApprovalRequest"`) as an owned `String`.
fn kind(&self) -> String;
/// Builds the JSON response payload used when nothing interactive answers the request.
fn default_response(&self) -> serde_json::Value;
}
impl RequestExt for Request {
fn kind(&self) -> String {
match self {
Self::ApprovalRequest(_) => "ApprovalRequest",
Self::ToolCallRequest(_) => "ToolCallRequest",
Self::QuestionRequest(_) => "QuestionRequest",
Self::HookRequest(_) => "HookRequest",
}
.to_string()
}
fn default_response(&self) -> serde_json::Value {
match self {
Self::ApprovalRequest(req) => serde_json::json!({
"request_id": req.id,
"response": ApprovalResponseKind::ApproveForSession,
"feedback": "Auto-approved by non-interactive worker."
}),
Self::ToolCallRequest(req) => serde_json::json!({
"tool_call_id": req.id,
"return_value": ToolReturnValue {
is_error: true,
output: ToolOutput::Text(String::new()),
message: format!("External tool '{}' is not registered.", req.name),
display: vec![DisplayBlock::brief("External tool unavailable.")],
extras: None,
}
}),
Self::QuestionRequest(req) => {
let answers: Vec<serde_json::Value> = req
.questions
.iter()
.map(|q| {
q.options.first().map_or(serde_json::Value::Null, |o| {
serde_json::Value::String(o.label.clone())
})
})
.collect();
serde_json::json!({
"request_id": req.id,
"answers": answers,
"message": "Selected default answers because workers run non-interactively."
})
}
Self::HookRequest(req) => serde_json::json!({
"request_id": req.id,
"action": HookAction::Allow,
"reason": format!(
"No hook policy is configured for '{}' on '{}'.",
req.event, req.target
)
}),
}
}
}