1use std::time::Duration;
7
8use crate::client::WireClient;
9use crate::error::WireError;
10use crate::message::{parse_wire_message, WireMessage};
11use crate::protocol::{
12 ApprovalResponseKind, DisplayBlock, Event, HookAction, Request, ToolOutput, ToolReturnValue,
13};
14
15pub trait WireClientExt: WireClient {
21 fn read_message(
28 &mut self,
29 ) -> impl std::future::Future<Output = Result<WireMessage, WireError>> + Send {
30 async move {
31 let raw = self.read_raw_message().await?;
32 parse_wire_message(raw)
33 }
34 }
35
36 fn read_message_timeout(
41 &mut self,
42 timeout: Duration,
43 ) -> impl std::future::Future<Output = Result<WireMessage, WireError>> + Send {
44 async move {
45 let raw = self.read_raw_message_timeout(timeout).await?;
46 parse_wire_message(raw)
47 }
48 }
49}
50
51impl<T: WireClient + ?Sized> WireClientExt for T {}
52
53pub trait EventExt {
55 fn event_type(&self) -> String;
57
58 fn normalized_event_type(&self) -> String;
60
61 fn payload(&self) -> serde_json::Value;
63}
64
65impl EventExt for Event {
66 fn event_type(&self) -> String {
67 self.type_name().to_string()
68 }
69
70 fn normalized_event_type(&self) -> String {
71 let pascal = self.type_name();
72 let mut snake = String::new();
73 for (i, ch) in pascal.chars().enumerate() {
74 if ch.is_uppercase() && i > 0 {
75 snake.push('_');
76 }
77 snake.push(ch.to_ascii_lowercase());
78 }
79 snake
80 }
81
82 fn payload(&self) -> serde_json::Value {
83 match serde_json::to_value(self) {
84 Ok(v) => v,
85 Err(_) => serde_json::Value::Null,
86 }
87 }
88}
89
90pub trait RequestExt {
92 fn kind(&self) -> String;
94
95 fn default_response(&self) -> serde_json::Value;
102}
103
104impl RequestExt for Request {
105 fn kind(&self) -> String {
106 match self {
107 Request::ApprovalRequest(_) => "ApprovalRequest",
108 Request::ToolCallRequest(_) => "ToolCallRequest",
109 Request::QuestionRequest(_) => "QuestionRequest",
110 Request::HookRequest(_) => "HookRequest",
111 }
112 .to_string()
113 }
114
115 fn default_response(&self) -> serde_json::Value {
116 match self {
117 Request::ApprovalRequest(req) => serde_json::json!({
118 "request_id": req.id,
119 "response": ApprovalResponseKind::ApproveForSession,
120 "feedback": "Auto-approved by non-interactive worker."
121 }),
122 Request::ToolCallRequest(req) => serde_json::json!({
123 "tool_call_id": req.id,
124 "return_value": ToolReturnValue {
125 is_error: true,
126 output: ToolOutput::Text(String::new()),
127 message: format!("External tool '{}' is not registered.", req.name),
128 display: vec![DisplayBlock::brief("External tool unavailable.")],
129 extras: None,
130 }
131 }),
132 Request::QuestionRequest(req) => {
133 let answers: Vec<serde_json::Value> = req
134 .questions
135 .iter()
136 .map(|q| {
137 q.options
138 .first()
139 .map(|o| serde_json::Value::String(o.label.clone()))
140 .unwrap_or(serde_json::Value::Null)
141 })
142 .collect();
143 serde_json::json!({
144 "request_id": req.id,
145 "answers": answers,
146 "message": "Selected default answers because workers run non-interactively."
147 })
148 }
149 Request::HookRequest(req) => serde_json::json!({
150 "request_id": req.id,
151 "action": HookAction::Allow,
152 "reason": format!(
153 "No hook policy is configured for '{}' on '{}'.",
154 req.event, req.target
155 )
156 }),
157 }
158 }
159}