1use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
9#[derive(Debug, Serialize)]
15pub struct SdkUserMessage {
16 #[serde(rename = "type")]
17 pub msg_type: &'static str, pub session_id: String,
19 pub message: UserMessageBody,
20 pub parent_tool_use_id: Option<String>,
21}
22
23impl SdkUserMessage {
24 pub fn new(session_id: &str, content: &str) -> Self {
25 Self {
26 msg_type: "user",
27 session_id: session_id.to_string(),
28 message: UserMessageBody {
29 role: "user".to_string(),
30 content: content.to_string(),
31 },
32 parent_tool_use_id: None,
33 }
34 }
35
36 pub fn to_ndjson(&self) -> String {
38 serde_json::to_string(self).expect("SdkUserMessage is always serializable")
39 }
40}
41
42#[derive(Debug, Serialize)]
43pub struct UserMessageBody {
44 pub role: String,
45 pub content: String,
46}
47
48#[derive(Debug, Serialize)]
50pub struct SdkControlResponse {
51 #[serde(rename = "type")]
52 pub msg_type: &'static str, pub response: ControlResponseBody,
54}
55
56impl SdkControlResponse {
57 pub fn approve_tool(request_id: &str, tool_use_id: &str) -> Self {
59 Self {
60 msg_type: "control_response",
61 response: ControlResponseBody {
62 subtype: "success".to_string(),
63 request_id: request_id.to_string(),
64 response: Some(ToolApproval {
65 tool_use_id: tool_use_id.to_string(),
66 approved: true,
67 }),
68 },
69 }
70 }
71
72 pub fn to_ndjson(&self) -> String {
74 serde_json::to_string(self).expect("SdkControlResponse is always serializable")
75 }
76}
77
78#[derive(Debug, Serialize)]
79pub struct ControlResponseBody {
80 pub subtype: String,
81 pub request_id: String,
82 #[serde(skip_serializing_if = "Option::is_none")]
83 pub response: Option<ToolApproval>,
84}
85
86#[derive(Debug, Serialize)]
87pub struct ToolApproval {
88 #[serde(rename = "toolUseID")]
89 pub tool_use_id: String,
90 pub approved: bool,
91}
92
93#[derive(Debug, Deserialize)]
102pub struct SdkOutput {
103 #[serde(rename = "type")]
104 pub msg_type: String,
105
106 #[serde(default)]
107 pub subtype: Option<String>,
108
109 #[serde(default)]
110 pub session_id: Option<String>,
111
112 #[serde(default)]
113 pub uuid: Option<String>,
114
115 #[serde(default)]
117 pub message: Option<Value>,
118
119 #[serde(default)]
121 pub event: Option<Value>,
122
123 #[serde(default)]
125 pub result: Option<String>,
126
127 #[serde(default)]
129 pub num_turns: Option<u32>,
130
131 #[serde(default)]
133 pub is_error: Option<bool>,
134
135 #[serde(default)]
137 pub errors: Option<Vec<String>>,
138
139 #[serde(default)]
141 pub request_id: Option<String>,
142
143 #[serde(default)]
145 pub request: Option<Value>,
146}
147
148impl SdkOutput {
149 pub fn request_subtype(&self) -> Option<String> {
151 self.request
152 .as_ref()
153 .and_then(|r| r.get("subtype"))
154 .and_then(|v| v.as_str())
155 .map(String::from)
156 }
157
158 pub fn request_tool_use_id(&self) -> Option<String> {
160 self.request
161 .as_ref()
162 .and_then(|r| r.get("tool_use_id"))
163 .and_then(|v| v.as_str())
164 .map(String::from)
165 }
166}
167
168pub fn extract_assistant_text(message: &Value) -> String {
178 let content = match message.get("content") {
179 Some(Value::Array(arr)) => arr,
180 Some(Value::String(s)) => return s.clone(),
181 _ => return String::new(),
182 };
183
184 let mut parts = Vec::new();
185 for block in content {
186 if block.get("type").and_then(|t| t.as_str()) == Some("text") {
187 if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
188 parts.push(text);
189 }
190 }
191 }
192 parts.join("")
193}
194
195pub fn extract_stream_text(event: &Value) -> Option<String> {
200 if let Some(delta) = event.get("delta") {
202 if delta.get("type").and_then(|t| t.as_str()) == Some("text_delta") {
203 return delta.get("text").and_then(|t| t.as_str()).map(String::from);
204 }
205 }
206 None
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ---- Outbound messages -------------------------------------------------

    #[test]
    fn user_message_serializes_correctly() {
        let msg = SdkUserMessage::new("sess-1", "Fix the bug");
        let json: Value = serde_json::from_str(&msg.to_ndjson()).unwrap();
        assert_eq!(json["type"], "user");
        assert_eq!(json["session_id"], "sess-1");
        assert_eq!(json["message"]["role"], "user");
        assert_eq!(json["message"]["content"], "Fix the bug");
        assert!(json["parent_tool_use_id"].is_null());
    }

    #[test]
    fn user_message_empty_session_id() {
        let msg = SdkUserMessage::new("", "hello");
        let json: Value = serde_json::from_str(&msg.to_ndjson()).unwrap();
        assert_eq!(json["session_id"], "");
    }

    #[test]
    fn user_message_ndjson_is_single_line() {
        // Newlines inside the content must be escaped so the message stays on
        // one line; otherwise newline-delimited framing would break.
        let msg = SdkUserMessage::new("sess-1", "line one\nline two");
        let line = msg.to_ndjson();
        assert!(!line.contains('\n'));
        let json: Value = serde_json::from_str(&line).unwrap();
        assert_eq!(json["message"]["content"], "line one\nline two");
    }

    #[test]
    fn approve_tool_serializes_correctly() {
        let resp = SdkControlResponse::approve_tool("req-42", "tool-99");
        let json: Value = serde_json::from_str(&resp.to_ndjson()).unwrap();
        assert_eq!(json["type"], "control_response");
        assert_eq!(json["response"]["subtype"], "success");
        assert_eq!(json["response"]["request_id"], "req-42");
        assert_eq!(json["response"]["response"]["toolUseID"], "tool-99");
        assert_eq!(json["response"]["response"]["approved"], true);
    }

    // ---- Inbound parsing ---------------------------------------------------

    #[test]
    fn parse_assistant_message() {
        let line = r#"{"type":"assistant","session_id":"abc","uuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello world"}]}}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "assistant");
        assert_eq!(msg.session_id.as_deref(), Some("abc"));
        assert!(msg.message.is_some());
    }

    #[test]
    fn parse_result_success() {
        let line = r#"{"type":"result","subtype":"success","session_id":"abc","uuid":"u2","result":"done","num_turns":3,"is_error":false}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "result");
        assert_eq!(msg.subtype.as_deref(), Some("success"));
        assert_eq!(msg.result.as_deref(), Some("done"));
        assert_eq!(msg.num_turns, Some(3));
        assert_eq!(msg.is_error, Some(false));
    }

    #[test]
    fn parse_result_error() {
        let line = r#"{"type":"result","subtype":"error_during_execution","is_error":true,"errors":["context window exceeded"],"session_id":"x"}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "result");
        assert_eq!(msg.is_error, Some(true));
        assert_eq!(
            msg.errors.as_deref(),
            Some(&["context window exceeded".to_string()][..])
        );
    }

    #[test]
    fn parse_control_request() {
        let line = r#"{"type":"control_request","request_id":"req-1","request":{"subtype":"can_use_tool","tool_name":"Bash","tool_use_id":"tu-1","input":{"command":"ls"}}}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "control_request");
        assert_eq!(msg.request_id.as_deref(), Some("req-1"));
        assert_eq!(msg.request_subtype().as_deref(), Some("can_use_tool"));
        assert_eq!(msg.request_tool_use_id().as_deref(), Some("tu-1"));
    }

    #[test]
    fn request_accessors_without_request() {
        let msg: SdkOutput = serde_json::from_str(r#"{"type":"assistant"}"#).unwrap();
        assert_eq!(msg.request_subtype(), None);
        assert_eq!(msg.request_tool_use_id(), None);
    }

    #[test]
    fn request_accessors_ignore_non_string_values() {
        let line = r#"{"type":"control_request","request":{"subtype":7,"tool_use_id":null}}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.request_subtype(), None);
        assert_eq!(msg.request_tool_use_id(), None);
    }

    #[test]
    fn parse_stream_event() {
        let line = r#"{"type":"stream_event","session_id":"abc","uuid":"u3","event":{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hi"}}}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "stream_event");
        assert!(msg.event.is_some());
    }

    #[test]
    fn unknown_message_type_is_tolerated() {
        let line = r#"{"type":"future_new_type","session_id":"abc","some_field":42}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "future_new_type");
    }

    #[test]
    fn missing_type_field_is_rejected() {
        // `type` is the only required key on inbound messages.
        let line = r#"{"session_id":"abc"}"#;
        assert!(serde_json::from_str::<SdkOutput>(line).is_err());
    }

    // ---- Assistant text extraction -----------------------------------------

    #[test]
    fn extract_text_from_content_array() {
        let msg = json!({
            "role": "assistant",
            "content": [
                {"type": "text", "text": "Hello "},
                {"type": "tool_use", "id": "t1", "name": "Bash", "input": {}},
                {"type": "text", "text": "world"}
            ]
        });
        assert_eq!(extract_assistant_text(&msg), "Hello world");
    }

    #[test]
    fn extract_text_from_string_content() {
        let msg = json!({"role": "assistant", "content": "plain string"});
        assert_eq!(extract_assistant_text(&msg), "plain string");
    }

    #[test]
    fn extract_text_empty_content() {
        let msg = json!({"role": "assistant", "content": []});
        assert_eq!(extract_assistant_text(&msg), "");
    }

    #[test]
    fn extract_text_no_content_field() {
        let msg = json!({"role": "assistant"});
        assert_eq!(extract_assistant_text(&msg), "");
    }

    #[test]
    fn extract_text_non_string_non_array_content() {
        let msg = json!({"role": "assistant", "content": 42});
        assert_eq!(extract_assistant_text(&msg), "");
    }

    #[test]
    fn extract_text_only_tool_use_blocks() {
        let msg = json!({
            "content": [
                {"type": "tool_use", "id": "t1", "name": "Read", "input": {}}
            ]
        });
        assert_eq!(extract_assistant_text(&msg), "");
    }

    #[test]
    fn extract_text_skips_text_block_without_string_text() {
        let msg = json!({"content": [{"type": "text"}, {"type": "text", "text": "ok"}]});
        assert_eq!(extract_assistant_text(&msg), "ok");
    }

    // ---- Stream text extraction --------------------------------------------

    #[test]
    fn extract_stream_text_delta() {
        let event = json!({
            "type": "content_block_delta",
            "index": 0,
            "delta": {"type": "text_delta", "text": "incremental"}
        });
        assert_eq!(extract_stream_text(&event), Some("incremental".to_string()));
    }

    #[test]
    fn extract_stream_text_non_text_delta() {
        let event = json!({
            "type": "content_block_delta",
            "delta": {"type": "input_json_delta", "partial_json": "{}"}
        });
        assert_eq!(extract_stream_text(&event), None);
    }

    #[test]
    fn extract_stream_text_no_delta() {
        let event = json!({"type": "content_block_start"});
        assert_eq!(extract_stream_text(&event), None);
    }

    #[test]
    fn extract_stream_text_delta_without_text() {
        let event = json!({"delta": {"type": "text_delta"}});
        assert_eq!(extract_stream_text(&event), None);
    }
}