1use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
9#[derive(Debug, Serialize)]
15pub struct SdkUserMessage {
16 #[serde(rename = "type")]
17 pub msg_type: &'static str, pub session_id: String,
19 pub message: UserMessageBody,
20 pub parent_tool_use_id: Option<String>,
21}
22
23impl SdkUserMessage {
24 pub fn new(session_id: &str, content: &str) -> Self {
25 Self {
26 msg_type: "user",
27 session_id: session_id.to_string(),
28 message: UserMessageBody {
29 role: "user".to_string(),
30 content: content.to_string(),
31 },
32 parent_tool_use_id: None,
33 }
34 }
35
36 pub fn to_ndjson(&self) -> String {
38 serde_json::to_string(self).expect("SdkUserMessage is always serializable")
39 }
40}
41
42#[derive(Debug, Serialize)]
43pub struct UserMessageBody {
44 pub role: String,
45 pub content: String,
46}
47
48#[derive(Debug, Serialize)]
50pub struct SdkControlResponse {
51 #[serde(rename = "type")]
52 pub msg_type: &'static str, pub response: ControlResponseBody,
54}
55
56impl SdkControlResponse {
57 pub fn approve_tool(request_id: &str, tool_use_id: &str) -> Self {
59 Self {
60 msg_type: "control_response",
61 response: ControlResponseBody {
62 subtype: "success".to_string(),
63 request_id: request_id.to_string(),
64 response: Some(ToolApproval {
65 tool_use_id: tool_use_id.to_string(),
66 approved: true,
67 }),
68 },
69 }
70 }
71
72 pub fn to_ndjson(&self) -> String {
74 serde_json::to_string(self).expect("SdkControlResponse is always serializable")
75 }
76}
77
78#[derive(Debug, Serialize)]
79pub struct ControlResponseBody {
80 pub subtype: String,
81 pub request_id: String,
82 #[serde(skip_serializing_if = "Option::is_none")]
83 pub response: Option<ToolApproval>,
84}
85
86#[derive(Debug, Serialize)]
87pub struct ToolApproval {
88 #[serde(rename = "toolUseID")]
89 pub tool_use_id: String,
90 pub approved: bool,
91}
92
93#[derive(Debug, Deserialize)]
102pub struct SdkOutput {
103 #[serde(rename = "type")]
104 pub msg_type: String,
105
106 #[serde(default)]
107 pub subtype: Option<String>,
108
109 #[serde(default)]
110 pub session_id: Option<String>,
111
112 #[serde(default)]
113 pub uuid: Option<String>,
114
115 #[serde(default)]
117 pub message: Option<Value>,
118
119 #[serde(default)]
121 pub event: Option<Value>,
122
123 #[serde(default)]
125 pub result: Option<String>,
126
127 #[serde(default)]
129 pub num_turns: Option<u32>,
130
131 #[serde(default)]
133 pub is_error: Option<bool>,
134
135 #[serde(default)]
137 pub errors: Option<Vec<String>>,
138
139 #[serde(default)]
141 pub usage: Option<Value>,
142
143 #[serde(rename = "modelUsage", default)]
145 pub model_usage: Option<Value>,
146
147 #[serde(default)]
149 pub request_id: Option<String>,
150
151 #[serde(default)]
153 pub request: Option<Value>,
154}
155
/// Normalized token counters extracted from an [`SdkOutput`]'s `usage` object.
///
/// Each field mirrors the same-named key of `usage` (missing keys read as `0`), with one
/// exception. [`SdkOutput::token_usage`] fills `cache_creation_input_tokens` with the larger
/// of the flat counter and the `cache_creation` 5m/1h breakdown.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SdkTokenUsage {
    pub input_tokens: u64,
    pub cached_input_tokens: u64,
    pub cache_creation_input_tokens: u64,
    pub cache_read_input_tokens: u64,
    pub output_tokens: u64,
    pub reasoning_output_tokens: u64,
}

impl SdkTokenUsage {
    /// Returns the sum of all six counters.
    pub fn total_tokens(&self) -> u64 {
        let counters = [
            self.input_tokens,
            self.cached_input_tokens,
            self.cache_creation_input_tokens,
            self.cache_read_input_tokens,
            self.output_tokens,
            self.reasoning_output_tokens,
        ];
        counters.iter().sum()
    }
}
176
177impl SdkOutput {
178 pub fn request_subtype(&self) -> Option<String> {
180 self.request
181 .as_ref()
182 .and_then(|r| r.get("subtype"))
183 .and_then(|v| v.as_str())
184 .map(String::from)
185 }
186
187 pub fn request_tool_use_id(&self) -> Option<String> {
189 self.request
190 .as_ref()
191 .and_then(|r| r.get("tool_use_id"))
192 .and_then(|v| v.as_str())
193 .map(String::from)
194 }
195
196 pub fn model_name(&self) -> Option<String> {
197 self.message
198 .as_ref()
199 .and_then(|message| message.get("model"))
200 .and_then(|value| value.as_str())
201 .map(String::from)
202 .or_else(|| {
203 self.model_usage
204 .as_ref()
205 .and_then(|value| value.get("model"))
206 .and_then(|value| value.as_str())
207 .map(String::from)
208 })
209 }
210
211 pub fn token_usage(&self) -> Option<SdkTokenUsage> {
212 let usage = self.usage.as_ref()?;
213 let cache_creation = usage.get("cache_creation");
214 let cache_creation_classified =
215 json_u64(cache_creation.and_then(|value| value.get("ephemeral_5m_input_tokens")))
216 + json_u64(cache_creation.and_then(|value| value.get("ephemeral_1h_input_tokens")));
217 Some(SdkTokenUsage {
218 input_tokens: json_u64(usage.get("input_tokens")),
219 cached_input_tokens: json_u64(usage.get("cached_input_tokens")),
220 cache_creation_input_tokens: json_u64(usage.get("cache_creation_input_tokens"))
221 .max(cache_creation_classified),
222 cache_read_input_tokens: json_u64(usage.get("cache_read_input_tokens")),
223 output_tokens: json_u64(usage.get("output_tokens")),
224 reasoning_output_tokens: json_u64(usage.get("reasoning_output_tokens")),
225 })
226 }
227
228 pub fn usage_total_tokens(&self) -> u64 {
229 let Some(usage) = self.usage.as_ref() else {
230 return 0;
231 };
232 let cache_creation = usage.get("cache_creation");
233 json_u64(usage.get("input_tokens"))
234 + json_u64(usage.get("cached_input_tokens"))
235 + json_u64(usage.get("cache_creation_input_tokens"))
236 + json_u64(cache_creation.and_then(|value| value.get("ephemeral_5m_input_tokens")))
237 + json_u64(cache_creation.and_then(|value| value.get("ephemeral_1h_input_tokens")))
238 + json_u64(usage.get("cache_read_input_tokens"))
239 + json_u64(usage.get("output_tokens"))
240 + json_u64(usage.get("reasoning_output_tokens"))
241 }
242}
243
244fn json_u64(value: Option<&Value>) -> u64 {
245 value.and_then(Value::as_u64).unwrap_or(0)
246}
247
248pub fn extract_assistant_text(message: &Value) -> String {
258 let content = match message.get("content") {
259 Some(Value::Array(arr)) => arr,
260 Some(Value::String(s)) => return s.clone(),
261 _ => return String::new(),
262 };
263
264 let mut parts = Vec::new();
265 for block in content {
266 if block.get("type").and_then(|t| t.as_str()) == Some("text") {
267 if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
268 parts.push(text);
269 }
270 }
271 }
272 parts.join("")
273}
274
275pub fn extract_stream_text(event: &Value) -> Option<String> {
280 if let Some(delta) = event.get("delta") {
282 if delta.get("type").and_then(|t| t.as_str()) == Some("text_delta") {
283 return delta.get("text").and_then(|t| t.as_str()).map(String::from);
284 }
285 }
286 None
287}
288
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Deserializes one inbound NDJSON fixture line, panicking on a malformed fixture.
    fn parse(line: &str) -> SdkOutput {
        serde_json::from_str(line).expect("fixture line must deserialize into SdkOutput")
    }

    /// Parses an outbound NDJSON line back into a JSON tree so its fields can be indexed.
    fn reparse(line: &str) -> Value {
        serde_json::from_str(line).expect("outbound line must be valid JSON")
    }

    // ---- outbound frames: SdkUserMessage / SdkControlResponse ----

    #[test]
    fn user_message_serializes_correctly() {
        let wire = reparse(&SdkUserMessage::new("sess-1", "Fix the bug").to_ndjson());
        assert_eq!(wire["type"], "user");
        assert_eq!(wire["session_id"], "sess-1");
        assert_eq!(wire["message"]["role"], "user");
        assert_eq!(wire["message"]["content"], "Fix the bug");
        assert!(wire["parent_tool_use_id"].is_null());
    }

    #[test]
    fn user_message_empty_session_id() {
        let wire = reparse(&SdkUserMessage::new("", "hello").to_ndjson());
        assert_eq!(wire["session_id"], "");
    }

    #[test]
    fn approve_tool_serializes_correctly() {
        let wire = reparse(&SdkControlResponse::approve_tool("req-42", "tool-99").to_ndjson());
        let body = &wire["response"];
        let approval = &body["response"];
        assert_eq!(wire["type"], "control_response");
        assert_eq!(body["subtype"], "success");
        assert_eq!(body["request_id"], "req-42");
        assert_eq!(approval["toolUseID"], "tool-99");
        assert_eq!(approval["approved"], true);
    }

    // ---- inbound lines: SdkOutput parsing and accessors ----

    #[test]
    fn parse_assistant_message() {
        let line = r#"{"type":"assistant","session_id":"abc","uuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello world"}]}}"#;
        let parsed = parse(line);
        assert_eq!(parsed.msg_type, "assistant");
        assert_eq!(parsed.session_id.as_deref(), Some("abc"));
        assert!(parsed.message.is_some());
    }

    #[test]
    fn parse_result_token_usage_includes_cache_fields() {
        let line = r#"{"type":"result","usage":{"input_tokens":10,"cached_input_tokens":4,"cache_creation_input_tokens":3,"cache_read_input_tokens":2,"output_tokens":5,"reasoning_output_tokens":1,"cache_creation":{"ephemeral_5m_input_tokens":7,"ephemeral_1h_input_tokens":11}}}"#;
        let usage = parse(line).token_usage().expect("usage object is present");
        let expected = SdkTokenUsage {
            input_tokens: 10,
            cached_input_tokens: 4,
            cache_creation_input_tokens: 18,
            cache_read_input_tokens: 2,
            output_tokens: 5,
            reasoning_output_tokens: 1,
        };
        assert_eq!(usage, expected);
        assert_eq!(usage.total_tokens(), 40);
    }

    #[test]
    fn model_name_prefers_message_then_model_usage() {
        let from_message = parse(
            r#"{"type":"assistant","message":{"model":"claude-sonnet-4-5","content":[{"type":"text","text":"hello"}]}}"#,
        );
        assert_eq!(from_message.model_name().as_deref(), Some("claude-sonnet-4-5"));

        let from_model_usage = parse(r#"{"type":"result","modelUsage":{"model":"claude-opus-4-1"}}"#);
        assert_eq!(from_model_usage.model_name().as_deref(), Some("claude-opus-4-1"));
    }

    #[test]
    fn parse_result_success() {
        let line = r#"{"type":"result","subtype":"success","session_id":"abc","uuid":"u2","result":"done","num_turns":3,"is_error":false}"#;
        let parsed = parse(line);
        assert_eq!(parsed.msg_type, "result");
        assert_eq!(parsed.subtype.as_deref(), Some("success"));
        assert_eq!(parsed.result.as_deref(), Some("done"));
        assert_eq!(parsed.num_turns, Some(3));
        assert_eq!(parsed.is_error, Some(false));
    }

    #[test]
    fn parse_result_error() {
        let line = r#"{"type":"result","subtype":"error_during_execution","is_error":true,"errors":["context window exceeded"],"session_id":"x"}"#;
        let parsed = parse(line);
        assert_eq!(parsed.msg_type, "result");
        assert_eq!(parsed.is_error, Some(true));
        assert_eq!(
            parsed.errors,
            Some(vec![String::from("context window exceeded")])
        );
    }

    #[test]
    fn parse_control_request() {
        let line = r#"{"type":"control_request","request_id":"req-1","request":{"subtype":"can_use_tool","tool_name":"Bash","tool_use_id":"tu-1","input":{"command":"ls"}}}"#;
        let parsed = parse(line);
        assert_eq!(parsed.msg_type, "control_request");
        assert_eq!(parsed.request_id.as_deref(), Some("req-1"));
        assert_eq!(parsed.request_subtype().as_deref(), Some("can_use_tool"));
        assert_eq!(parsed.request_tool_use_id().as_deref(), Some("tu-1"));
    }

    #[test]
    fn parse_result_usage_and_model() {
        let line = r#"{"type":"result","session_id":"x","usage":{"input_tokens":10,"cached_input_tokens":5,"cache_creation_input_tokens":20,"cache_creation":{"ephemeral_5m_input_tokens":5,"ephemeral_1h_input_tokens":15},"cache_read_input_tokens":3,"output_tokens":7,"reasoning_output_tokens":2},"message":{"model":"claude-opus-4-6-1m"}}"#;
        let parsed = parse(line);
        assert_eq!(parsed.model_name().as_deref(), Some("claude-opus-4-6-1m"));
        assert_eq!(parsed.usage_total_tokens(), 67);
    }

    #[test]
    fn parse_stream_event() {
        let line = r#"{"type":"stream_event","session_id":"abc","uuid":"u3","event":{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hi"}}}"#;
        let parsed = parse(line);
        assert_eq!(parsed.msg_type, "stream_event");
        assert!(parsed.event.is_some());
    }

    #[test]
    fn unknown_message_type_is_tolerated() {
        let parsed = parse(r#"{"type":"future_new_type","session_id":"abc","some_field":42}"#);
        assert_eq!(parsed.msg_type, "future_new_type");
    }

    // ---- content extraction helpers ----

    #[test]
    fn extract_text_from_content_array() {
        let message = json!({
            "role": "assistant",
            "content": [
                {"type": "text", "text": "Hello "},
                {"type": "tool_use", "id": "t1", "name": "Bash", "input": {}},
                {"type": "text", "text": "world"}
            ]
        });
        assert_eq!(extract_assistant_text(&message), "Hello world");
    }

    #[test]
    fn extract_text_from_string_content() {
        let message = json!({"role": "assistant", "content": "plain string"});
        assert_eq!(extract_assistant_text(&message), "plain string");
    }

    #[test]
    fn extract_text_empty_content() {
        let message = json!({"role": "assistant", "content": []});
        assert_eq!(extract_assistant_text(&message), "");
    }

    #[test]
    fn extract_text_no_content_field() {
        let message = json!({"role": "assistant"});
        assert_eq!(extract_assistant_text(&message), "");
    }

    #[test]
    fn extract_text_only_tool_use_blocks() {
        let message = json!({
            "content": [
                {"type": "tool_use", "id": "t1", "name": "Read", "input": {}}
            ]
        });
        assert_eq!(extract_assistant_text(&message), "");
    }

    #[test]
    fn extract_stream_text_delta() {
        let event = json!({
            "type": "content_block_delta",
            "index": 0,
            "delta": {"type": "text_delta", "text": "incremental"}
        });
        assert_eq!(extract_stream_text(&event), Some(String::from("incremental")));
    }

    #[test]
    fn extract_stream_text_non_text_delta() {
        let event = json!({
            "type": "content_block_delta",
            "delta": {"type": "input_json_delta", "partial_json": "{}"}
        });
        assert_eq!(extract_stream_text(&event), None);
    }

    #[test]
    fn extract_stream_text_no_delta() {
        let event = json!({"type": "content_block_start"});
        assert_eq!(extract_stream_text(&event), None);
    }
}