1use serde::Deserialize;
5use serde_json::Value;
6use tracing::warn;
7
8#[derive(Debug, Deserialize)]
12#[serde(tag = "type", rename_all = "snake_case")]
13pub enum StreamMessage {
14 System(SystemMessage),
15 #[serde(rename = "assistant")]
16 Assistant(AssistantMessage),
17 Result(ResultMessage),
18 #[serde(rename = "user")]
19 #[allow(dead_code)]
20 User(UserMessage),
21 #[serde(rename = "stream_event")]
22 #[allow(dead_code)]
23 StreamEvent(StreamEventMessage),
24 #[serde(rename = "tool_progress")]
25 ToolProgress(ToolProgressMessage),
26 #[serde(rename = "control_request")]
27 ControlRequest(ControlRequestMessage),
28 #[serde(other)]
29 Unknown,
30}
31
32#[allow(dead_code)]
33#[derive(Debug, Deserialize)]
34pub struct SystemMessage {
35 #[serde(default)]
36 pub subtype: Option<String>,
37 #[serde(default)]
38 pub session_id: Option<String>,
39 #[serde(default)]
40 pub tools: Option<Vec<Value>>,
41 #[serde(default)]
42 pub model: Option<String>,
43 #[serde(default)]
44 pub cwd: Option<String>,
45}
46
47#[allow(dead_code)]
48#[derive(Debug, Deserialize)]
49pub struct AssistantMessage {
50 #[serde(default)]
51 pub message: Option<AssistantMessageBody>,
52}
53
54#[allow(dead_code)]
55#[derive(Debug, Deserialize)]
56pub struct AssistantMessageBody {
57 #[serde(default)]
58 pub content: Vec<ContentBlock>,
59}
60
61#[allow(dead_code)]
62#[derive(Debug, Deserialize)]
63pub struct ResultMessage {
64 #[serde(default)]
65 pub subtype: Option<String>,
66 #[serde(default)]
67 pub session_id: Option<String>,
68 #[serde(default)]
69 pub is_error: Option<bool>,
70 #[serde(default)]
71 pub result: Option<String>,
72 #[serde(default)]
73 pub num_turns: Option<u32>,
74 #[serde(default)]
75 pub duration_ms: Option<u64>,
76 #[serde(default)]
77 pub duration_api_ms: Option<u64>,
78}
79
80#[allow(dead_code)]
81#[derive(Debug, Deserialize)]
82pub struct UserMessage {
83 #[serde(default)]
84 pub message: Option<Value>,
85}
86
87#[allow(dead_code)]
88#[derive(Debug, Deserialize)]
89pub struct StreamEventMessage {
90 #[serde(default)]
91 pub subtype: Option<String>,
92 #[serde(flatten)]
93 pub data: Value,
94}
95
96#[allow(dead_code)]
97#[derive(Debug, Deserialize)]
98pub struct ToolProgressMessage {
99 #[serde(default)]
100 pub tool_name: Option<String>,
101 #[serde(default)]
102 pub tool_use_id: Option<String>,
103 #[serde(flatten)]
104 pub data: Value,
105}
106
107#[allow(dead_code)]
108#[derive(Debug, Deserialize)]
109pub struct ControlRequestMessage {
110 #[serde(default)]
111 pub request_id: Option<String>,
112 #[serde(default)]
113 pub request: Option<ControlRequestBody>,
114}
115
116#[allow(dead_code)]
117#[derive(Debug, Deserialize)]
118pub struct ControlRequestBody {
119 #[serde(default)]
120 pub subtype: Option<String>,
121 #[serde(default)]
122 pub tool_name: Option<String>,
123 #[serde(default)]
124 pub tool_input: Option<Value>,
125 #[serde(flatten)]
126 pub data: Value,
127}
128
129#[allow(dead_code)]
132#[derive(Debug, Deserialize)]
133#[serde(tag = "type", rename_all = "snake_case")]
134pub enum ContentBlock {
135 Text {
136 text: String,
137 },
138 #[serde(rename = "tool_use")]
139 ToolUse {
140 id: String,
141 name: String,
142 input: Value,
143 },
144 #[serde(rename = "tool_result")]
145 ToolResult {
146 tool_use_id: String,
147 #[serde(default)]
148 content: Option<Value>,
149 #[serde(default)]
150 is_error: Option<bool>,
151 },
152 Thinking {
153 thinking: String,
154 },
155 #[serde(other)]
156 Unknown,
157}
158
159pub fn user_message(prompt: &str, session_id: Option<&str>) -> Result<String, serde_json::Error> {
164 let mut msg = serde_json::json!({
165 "type": "user",
166 "message": {
167 "role": "user",
168 "content": prompt
169 }
170 });
171 if let Some(sid) = session_id {
172 msg["session_id"] = serde_json::json!(sid);
173 }
174 serde_json::to_string(&msg)
175}
176
177pub fn control_response(request_id: &str, response: Value) -> Result<String, serde_json::Error> {
180 let msg = serde_json::json!({
181 "type": "control_response",
182 "response": {
183 "subtype": "success",
184 "request_id": request_id,
185 "response": response
186 }
187 });
188 serde_json::to_string(&msg)
189}
190
191pub fn deny_tool(request_id: &str) -> String {
194 control_response(request_id, serde_json::json!({ "behavior": "deny" }))
195 .unwrap_or_else(|e| {
196 warn!("Failed to serialize deny_tool response: {}", e);
197 format!(
199 r#"{{"type":"control_response","response":{{"subtype":"success","request_id":"{}","response":{{"behavior":"deny"}}}}}}"#,
200 request_id
201 )
202 })
203}
204
205pub fn approve_tool(request_id: &str) -> String {
207 control_response(request_id, serde_json::json!({ "behavior": "allow" })).unwrap_or_else(|e| {
208 warn!("Failed to serialize approve_tool response: {}", e);
209 deny_tool(request_id)
211 })
212}
213
214pub fn answer_question(request_id: &str, questions: &Value, answer_text: &str) -> String {
217 let mut answers = serde_json::Map::new();
219 if let Some(arr) = questions.get("questions").and_then(|q| q.as_array()) {
220 for (i, _) in arr.iter().enumerate() {
221 answers.insert(i.to_string(), Value::String(answer_text.to_string()));
222 }
223 } else {
224 answers.insert("0".to_string(), Value::String(answer_text.to_string()));
226 }
227
228 let mut updated_input = questions.clone();
229 if let Some(obj) = updated_input.as_object_mut() {
230 obj.insert("answers".to_string(), Value::Object(answers));
231 }
232
233 control_response(
234 request_id,
235 serde_json::json!({
236 "behavior": "allow",
237 "updated_input": updated_input
238 }),
239 )
240 .unwrap_or_else(|e| {
241 warn!("Failed to serialize answer_question response: {}", e);
242 deny_tool(request_id)
244 })
245}
246
247pub fn parse_line(line: &str) -> Option<StreamMessage> {
252 let trimmed = line.trim();
253 if trimmed.is_empty() {
254 return None;
255 }
256 match serde_json::from_str::<StreamMessage>(trimmed) {
257 Ok(msg) => Some(msg),
258 Err(e) => {
259 warn!(
260 "Failed to parse stream-json line: {}. Line: {}",
261 e,
262 &trimmed[..crate::util::floor_char_boundary(trimmed, 200)]
263 );
264 None
265 }
266 }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `line` as an assistant message and returns its content blocks.
    fn assistant_blocks(line: &str) -> Vec<ContentBlock> {
        match parse_line(line).unwrap() {
            StreamMessage::Assistant(a) => a.message.unwrap().content,
            other => panic!("Expected Assistant, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_system_init() {
        let line =
            r#"{"type":"system","subtype":"init","session_id":"abc-123","model":"claude-sonnet"}"#;
        match parse_line(line).unwrap() {
            StreamMessage::System(sys) => {
                assert_eq!(sys.subtype.as_deref(), Some("init"));
                assert_eq!(sys.session_id.as_deref(), Some("abc-123"));
                assert_eq!(sys.model.as_deref(), Some("claude-sonnet"));
            }
            other => panic!("Expected System, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_assistant_text() {
        let blocks = assistant_blocks(
            r#"{"type":"assistant","message":{"content":[{"type":"text","text":"Hello world"}]}}"#,
        );
        assert_eq!(blocks.len(), 1);
        match &blocks[0] {
            ContentBlock::Text { text } => assert_eq!(text, "Hello world"),
            other => panic!("Expected Text, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_assistant_tool_use() {
        let blocks = assistant_blocks(
            r#"{"type":"assistant","message":{"content":[{"type":"tool_use","id":"t1","name":"Read","input":{"file_path":"/tmp/test"}}]}}"#,
        );
        match &blocks[0] {
            ContentBlock::ToolUse { id, name, input } => {
                assert_eq!(id, "t1");
                assert_eq!(name, "Read");
                assert_eq!(input["file_path"], "/tmp/test");
            }
            other => panic!("Expected ToolUse, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_assistant_tool_result() {
        let blocks = assistant_blocks(
            r#"{"type":"assistant","message":{"content":[{"type":"tool_result","tool_use_id":"t1","content":"ok","is_error":true}]}}"#,
        );
        match &blocks[0] {
            ContentBlock::ToolResult {
                tool_use_id,
                content,
                is_error,
            } => {
                assert_eq!(tool_use_id, "t1");
                assert_eq!(content.as_ref().and_then(Value::as_str), Some("ok"));
                assert_eq!(*is_error, Some(true));
            }
            other => panic!("Expected ToolResult, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_thinking_block() {
        let blocks = assistant_blocks(
            r#"{"type":"assistant","message":{"content":[{"type":"thinking","thinking":"Let me think..."},{"type":"text","text":"Here is my answer"}]}}"#,
        );
        assert_eq!(blocks.len(), 2);
        assert!(matches!(&blocks[0], ContentBlock::Thinking { .. }));
        assert!(matches!(&blocks[1], ContentBlock::Text { .. }));
    }

    #[test]
    fn test_parse_unknown_content_block() {
        let blocks = assistant_blocks(
            r#"{"type":"assistant","message":{"content":[{"type":"redacted_thinking","data":"xyz"}]}}"#,
        );
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::Unknown));
    }

    #[test]
    fn test_parse_assistant_missing_content_defaults_empty() {
        let blocks = assistant_blocks(r#"{"type":"assistant","message":{}}"#);
        assert!(blocks.is_empty());
    }

    #[test]
    fn test_parse_result() {
        let line = r#"{"type":"result","subtype":"success","session_id":"s1","is_error":false,"result":"Done","num_turns":3,"duration_ms":1500}"#;
        match parse_line(line).unwrap() {
            StreamMessage::Result(r) => {
                assert_eq!(r.subtype.as_deref(), Some("success"));
                assert_eq!(r.session_id.as_deref(), Some("s1"));
                assert_eq!(r.is_error, Some(false));
                assert_eq!(r.result.as_deref(), Some("Done"));
                assert_eq!(r.num_turns, Some(3));
                assert_eq!(r.duration_ms, Some(1500));
            }
            other => panic!("Expected Result, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_control_request() {
        let line = r#"{"type":"control_request","request_id":"req1","request":{"subtype":"tool_use","tool_name":"Bash"}}"#;
        match parse_line(line).unwrap() {
            StreamMessage::ControlRequest(c) => {
                assert_eq!(c.request_id.as_deref(), Some("req1"));
                let body = c.request.unwrap();
                assert_eq!(body.tool_name.as_deref(), Some("Bash"));
            }
            other => panic!("Expected ControlRequest, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_tool_progress() {
        let line = r#"{"type":"tool_progress","tool_name":"Bash","tool_use_id":"t1"}"#;
        match parse_line(line).unwrap() {
            StreamMessage::ToolProgress(tp) => {
                assert_eq!(tp.tool_name.as_deref(), Some("Bash"));
                assert_eq!(tp.tool_use_id.as_deref(), Some("t1"));
            }
            other => panic!("Expected ToolProgress, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_user_message() {
        let line = r#"{"type":"user","message":{"role":"user","content":"hi"}}"#;
        assert!(matches!(parse_line(line), Some(StreamMessage::User(_))));
    }

    #[test]
    fn test_parse_unknown_type() {
        let line = r#"{"type":"some_future_type","data":42}"#;
        assert!(matches!(parse_line(line).unwrap(), StreamMessage::Unknown));
    }

    #[test]
    fn test_parse_empty_line() {
        assert!(parse_line("").is_none());
        assert!(parse_line("   ").is_none());
    }

    #[test]
    fn test_parse_malformed_json() {
        assert!(parse_line("{invalid json}").is_none());
    }

    #[test]
    fn test_user_message_without_session() {
        let msg = user_message("hello", None).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap();
        assert_eq!(parsed["type"], "user");
        assert_eq!(parsed["message"]["content"], "hello");
        assert!(parsed.get("session_id").is_none());
    }

    #[test]
    fn test_user_message_with_session() {
        let msg = user_message("hello", Some("s1")).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap();
        assert_eq!(parsed["type"], "user");
        assert_eq!(parsed["message"]["content"], "hello");
        assert_eq!(parsed["session_id"], "s1");
    }

    #[test]
    fn test_control_response_wraps_payload() {
        let msg = control_response("req5", serde_json::json!({ "k": 1 })).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap();
        assert_eq!(parsed["type"], "control_response");
        assert_eq!(parsed["response"]["subtype"], "success");
        assert_eq!(parsed["response"]["request_id"], "req5");
        assert_eq!(parsed["response"]["response"]["k"], 1);
    }

    #[test]
    fn test_deny_tool() {
        let msg = deny_tool("req2");
        let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap();
        assert_eq!(parsed["type"], "control_response");
        assert_eq!(parsed["response"]["request_id"], "req2");
        assert_eq!(parsed["response"]["response"]["behavior"], "deny");
    }

    #[test]
    fn test_deny_tool_escapes_request_id() {
        let id = "a\"b\\c";
        let parsed: serde_json::Value = serde_json::from_str(&deny_tool(id)).unwrap();
        assert_eq!(parsed["response"]["request_id"], id);
    }

    #[test]
    fn test_approve_tool() {
        let msg = approve_tool("req3");
        let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap();
        assert_eq!(parsed["type"], "control_response");
        assert_eq!(parsed["response"]["request_id"], "req3");
        assert_eq!(parsed["response"]["response"]["behavior"], "allow");
    }

    #[test]
    fn test_answer_question_answers_each_question() {
        let input = serde_json::json!({ "questions": [{ "question": "A?" }, { "question": "B?" }] });
        let msg = answer_question("req4", &input, "yes");
        let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap();
        assert_eq!(parsed["response"]["request_id"], "req4");
        let verdict = &parsed["response"]["response"];
        assert_eq!(verdict["behavior"], "allow");
        assert_eq!(verdict["updated_input"]["questions"], input["questions"]);
        assert_eq!(verdict["updated_input"]["answers"]["0"], "yes");
        assert_eq!(verdict["updated_input"]["answers"]["1"], "yes");
        assert!(verdict["updated_input"]["answers"].get("2").is_none());
    }

    #[test]
    fn test_answer_question_without_questions_array() {
        let input = serde_json::json!({ "foo": 1 });
        let msg = answer_question("req6", &input, "ok");
        let parsed: serde_json::Value = serde_json::from_str(&msg).unwrap();
        let updated = &parsed["response"]["response"]["updated_input"];
        assert_eq!(updated["foo"], 1);
        assert_eq!(updated["answers"]["0"], "ok");
    }
}