Skip to main content

batty_cli/shim/
sdk_types.rs

1//! NDJSON message types for the Claude Code stream-json SDK protocol.
2//!
3//! These types model the messages exchanged over stdin/stdout when Claude Code
4//! runs in `-p --input-format=stream-json --output-format=stream-json` mode.
5
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
9// ---------------------------------------------------------------------------
10// Messages written TO Claude's stdin
11// ---------------------------------------------------------------------------
12
/// A user message sent to Claude Code via stdin.
///
/// Serialized (see [`SdkUserMessage::to_ndjson`]) as
/// `{"type": "user", "session_id": ..., "message": {...}, "parent_tool_use_id": ...}`.
#[derive(Debug, Serialize)]
pub struct SdkUserMessage {
    /// Wire-level `type` discriminator.
    #[serde(rename = "type")]
    pub msg_type: &'static str, // always "user"
    /// Session identifier copied verbatim into the JSON (may be empty).
    pub session_id: String,
    /// The role/content payload.
    pub message: UserMessageBody,
    /// Parent tool-use ID, if any. There is no `skip_serializing_if` here, so
    /// `None` is emitted as JSON `null` rather than omitted.
    pub parent_tool_use_id: Option<String>,
}
22
23impl SdkUserMessage {
24    pub fn new(session_id: &str, content: &str) -> Self {
25        Self {
26            msg_type: "user",
27            session_id: session_id.to_string(),
28            message: UserMessageBody {
29                role: "user".to_string(),
30                content: content.to_string(),
31            },
32            parent_tool_use_id: None,
33        }
34    }
35
36    /// Serialize to a single NDJSON line (no trailing newline).
37    pub fn to_ndjson(&self) -> String {
38        serde_json::to_string(self).expect("SdkUserMessage is always serializable")
39    }
40}
41
/// The `message` object nested inside an [`SdkUserMessage`].
#[derive(Debug, Serialize)]
pub struct UserMessageBody {
    /// Message role; [`SdkUserMessage::new`] always sets `"user"`.
    pub role: String,
    /// Message text, serialized as a single JSON string (not a content-block array).
    pub content: String,
}
47
/// A control response sent to Claude Code via stdin (e.g. permission approval).
///
/// Serialized (see [`SdkControlResponse::to_ndjson`]) as
/// `{"type": "control_response", "response": {...}}`.
#[derive(Debug, Serialize)]
pub struct SdkControlResponse {
    /// Wire-level `type` discriminator.
    #[serde(rename = "type")]
    pub msg_type: &'static str, // always "control_response"
    /// Envelope carrying the subtype, the echoed request ID and optional payload.
    pub response: ControlResponseBody,
}
55
56impl SdkControlResponse {
57    /// Build an approval response for a `can_use_tool` control request.
58    pub fn approve_tool(request_id: &str, tool_use_id: &str) -> Self {
59        Self {
60            msg_type: "control_response",
61            response: ControlResponseBody {
62                subtype: "success".to_string(),
63                request_id: request_id.to_string(),
64                response: Some(ToolApproval {
65                    tool_use_id: tool_use_id.to_string(),
66                    approved: true,
67                }),
68            },
69        }
70    }
71
72    /// Serialize to a single NDJSON line (no trailing newline).
73    pub fn to_ndjson(&self) -> String {
74        serde_json::to_string(self).expect("SdkControlResponse is always serializable")
75    }
76}
77
/// The `response` envelope nested inside an [`SdkControlResponse`].
#[derive(Debug, Serialize)]
pub struct ControlResponseBody {
    /// Outcome tag; [`SdkControlResponse::approve_tool`] uses `"success"`.
    pub subtype: String,
    /// ID of the `control_request` being answered, echoed back verbatim.
    pub request_id: String,
    /// Optional payload. When `None` the `response` key is omitted entirely
    /// (not emitted as `null`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response: Option<ToolApproval>,
}
85
/// Payload approving a single tool use.
#[derive(Debug, Serialize)]
pub struct ToolApproval {
    /// Serialized under the camelCase key `toolUseID`, unlike the snake_case
    /// keys used by the other outgoing messages in this file.
    #[serde(rename = "toolUseID")]
    pub tool_use_id: String,
    /// Approval decision; [`SdkControlResponse::approve_tool`] always sets `true`.
    pub approved: bool,
}
92
93// ---------------------------------------------------------------------------
94// Messages read FROM Claude's stdout
95// ---------------------------------------------------------------------------
96
/// A single NDJSON message received from Claude Code's stdout.
///
/// Uses a flat struct with optional fields rather than a tagged enum so that
/// unknown or new message types are silently tolerated (future-proof).
///
/// Only `type` is required. The struct does not use `deny_unknown_fields`, so
/// unrecognized keys are ignored. However, the strongly typed fields
/// (`subtype`, `session_id`, `uuid`, `result`, `num_turns`, `is_error`,
/// `errors`, `request_id`) must still match their declared types when present,
/// or the whole line fails to deserialize. Loosely shaped payloads are kept as
/// raw [`Value`]s instead.
#[derive(Debug, Deserialize)]
pub struct SdkOutput {
    /// Message kind, e.g. `assistant`, `result`, `stream_event`, `control_request`.
    #[serde(rename = "type")]
    pub msg_type: String,

    /// Secondary discriminator, e.g. `success` / `error_during_execution` on results.
    #[serde(default)]
    pub subtype: Option<String>,

    #[serde(default)]
    pub session_id: Option<String>,

    #[serde(default)]
    pub uuid: Option<String>,

    /// For `assistant` messages: the full message object with `content` array.
    #[serde(default)]
    pub message: Option<Value>,

    /// For `stream_event` messages: the stream event payload.
    #[serde(default)]
    pub event: Option<Value>,

    /// For `result` messages: the final text result.
    #[serde(default)]
    pub result: Option<String>,

    /// For `result` messages: number of API turns taken.
    #[serde(default)]
    pub num_turns: Option<u32>,

    /// For `result` messages: whether an error occurred.
    #[serde(default)]
    pub is_error: Option<bool>,

    /// For `result` error messages: list of error strings.
    #[serde(default)]
    pub errors: Option<Vec<String>>,

    /// For `result` messages: usage counters for the completed turn
    /// (read via `token_usage` / `usage_total_tokens`).
    #[serde(default)]
    pub usage: Option<Value>,

    /// For `result` messages: model-specific metadata (wire key `modelUsage`).
    #[serde(rename = "modelUsage", default)]
    pub model_usage: Option<Value>,

    /// For `control_request` messages: the request ID to echo in responses.
    #[serde(default)]
    pub request_id: Option<String>,

    /// For `control_request` messages: the request payload.
    #[serde(default)]
    pub request: Option<Value>,
}
155
/// Token counters extracted from a message's `usage` object.
///
/// When built by `SdkOutput::token_usage`, any counter that is missing or not
/// an unsigned integer in the JSON is recorded as `0`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SdkTokenUsage {
    /// `usage.input_tokens`.
    pub input_tokens: u64,
    /// `usage.cached_input_tokens`.
    pub cached_input_tokens: u64,
    /// Cache-creation tokens (see `SdkOutput::token_usage` for how the total
    /// and its ephemeral breakdown are reconciled).
    pub cache_creation_input_tokens: u64,
    /// `usage.cache_read_input_tokens`.
    pub cache_read_input_tokens: u64,
    /// `usage.output_tokens`.
    pub output_tokens: u64,
    /// `usage.reasoning_output_tokens`.
    pub reasoning_output_tokens: u64,
}

impl SdkTokenUsage {
    /// Sum of all six counters.
    ///
    /// The counters originate from external JSON, so their sum is not bounded.
    /// A plain `+` would panic in debug builds (and silently wrap in release)
    /// once the total exceeds `u64::MAX`. Saturating addition reports
    /// `u64::MAX` instead.
    pub fn total_tokens(&self) -> u64 {
        [
            self.input_tokens,
            self.cached_input_tokens,
            self.cache_creation_input_tokens,
            self.cache_read_input_tokens,
            self.output_tokens,
            self.reasoning_output_tokens,
        ]
        .iter()
        .fold(0u64, |total, &count| total.saturating_add(count))
    }
}
176
177impl SdkOutput {
178    /// Extract the `subtype` from a nested `request` object (for control requests).
179    pub fn request_subtype(&self) -> Option<String> {
180        self.request
181            .as_ref()
182            .and_then(|r| r.get("subtype"))
183            .and_then(|v| v.as_str())
184            .map(String::from)
185    }
186
187    /// Extract the `tool_use_id` from a `can_use_tool` control request.
188    pub fn request_tool_use_id(&self) -> Option<String> {
189        self.request
190            .as_ref()
191            .and_then(|r| r.get("tool_use_id"))
192            .and_then(|v| v.as_str())
193            .map(String::from)
194    }
195
196    pub fn model_name(&self) -> Option<String> {
197        self.message
198            .as_ref()
199            .and_then(|message| message.get("model"))
200            .and_then(|value| value.as_str())
201            .map(String::from)
202            .or_else(|| {
203                self.model_usage
204                    .as_ref()
205                    .and_then(|value| value.get("model"))
206                    .and_then(|value| value.as_str())
207                    .map(String::from)
208            })
209    }
210
211    pub fn token_usage(&self) -> Option<SdkTokenUsage> {
212        let usage = self.usage.as_ref()?;
213        let cache_creation = usage.get("cache_creation");
214        let cache_creation_classified =
215            json_u64(cache_creation.and_then(|value| value.get("ephemeral_5m_input_tokens")))
216                + json_u64(cache_creation.and_then(|value| value.get("ephemeral_1h_input_tokens")));
217        Some(SdkTokenUsage {
218            input_tokens: json_u64(usage.get("input_tokens")),
219            cached_input_tokens: json_u64(usage.get("cached_input_tokens")),
220            cache_creation_input_tokens: json_u64(usage.get("cache_creation_input_tokens"))
221                .max(cache_creation_classified),
222            cache_read_input_tokens: json_u64(usage.get("cache_read_input_tokens")),
223            output_tokens: json_u64(usage.get("output_tokens")),
224            reasoning_output_tokens: json_u64(usage.get("reasoning_output_tokens")),
225        })
226    }
227
228    pub fn usage_total_tokens(&self) -> u64 {
229        let Some(usage) = self.usage.as_ref() else {
230            return 0;
231        };
232        let cache_creation = usage.get("cache_creation");
233        json_u64(usage.get("input_tokens"))
234            + json_u64(usage.get("cached_input_tokens"))
235            + json_u64(usage.get("cache_creation_input_tokens"))
236            + json_u64(cache_creation.and_then(|value| value.get("ephemeral_5m_input_tokens")))
237            + json_u64(cache_creation.and_then(|value| value.get("ephemeral_1h_input_tokens")))
238            + json_u64(usage.get("cache_read_input_tokens"))
239            + json_u64(usage.get("output_tokens"))
240            + json_u64(usage.get("reasoning_output_tokens"))
241    }
242}
243
244fn json_u64(value: Option<&Value>) -> u64 {
245    value.and_then(Value::as_u64).unwrap_or(0)
246}
247
248// ---------------------------------------------------------------------------
249// Text extraction helpers
250// ---------------------------------------------------------------------------
251
252/// Extract text content from an `assistant` message's `content` array.
253///
254/// The `message` field contains `{ "role": "assistant", "content": [...] }`
255/// where each content block may be `{ "type": "text", "text": "..." }` or
256/// a tool_use block. We only extract text blocks.
257pub fn extract_assistant_text(message: &Value) -> String {
258    let content = match message.get("content") {
259        Some(Value::Array(arr)) => arr,
260        Some(Value::String(s)) => return s.clone(),
261        _ => return String::new(),
262    };
263
264    let mut parts = Vec::new();
265    for block in content {
266        if block.get("type").and_then(|t| t.as_str()) == Some("text") {
267            if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
268                parts.push(text);
269            }
270        }
271    }
272    parts.join("")
273}
274
275/// Extract incremental text from a `stream_event` payload.
276///
277/// Stream events contain `{ "type": "content_block_delta", "delta": { "type": "text_delta", "text": "..." } }`
278/// or similar structures. We extract the text delta if present.
279pub fn extract_stream_text(event: &Value) -> Option<String> {
280    // content_block_delta with text_delta
281    if let Some(delta) = event.get("delta") {
282        if delta.get("type").and_then(|t| t.as_str()) == Some("text_delta") {
283            return delta.get("text").and_then(|t| t.as_str()).map(String::from);
284        }
285    }
286    None
287}
288
289// ---------------------------------------------------------------------------
290// Tests
291// ---------------------------------------------------------------------------
292
#[cfg(test)]
mod tests {
    // Unit tests covering serialization of stdin messages, deserialization of
    // stdout messages, and the text-extraction helpers.
    use super::*;
    use serde_json::json;

    // --- SdkUserMessage ---

    #[test]
    fn user_message_serializes_correctly() {
        let msg = SdkUserMessage::new("sess-1", "Fix the bug");
        let json: Value = serde_json::from_str(&msg.to_ndjson()).unwrap();
        assert_eq!(json["type"], "user");
        assert_eq!(json["session_id"], "sess-1");
        assert_eq!(json["message"]["role"], "user");
        assert_eq!(json["message"]["content"], "Fix the bug");
        assert!(json["parent_tool_use_id"].is_null());
    }

    #[test]
    fn user_message_empty_session_id() {
        let msg = SdkUserMessage::new("", "hello");
        let json: Value = serde_json::from_str(&msg.to_ndjson()).unwrap();
        assert_eq!(json["session_id"], "");
    }

    // --- SdkControlResponse ---

    #[test]
    fn approve_tool_serializes_correctly() {
        let resp = SdkControlResponse::approve_tool("req-42", "tool-99");
        let json: Value = serde_json::from_str(&resp.to_ndjson()).unwrap();
        assert_eq!(json["type"], "control_response");
        assert_eq!(json["response"]["subtype"], "success");
        assert_eq!(json["response"]["request_id"], "req-42");
        assert_eq!(json["response"]["response"]["toolUseID"], "tool-99");
        assert_eq!(json["response"]["response"]["approved"], true);
    }

    // --- SdkOutput deserialization ---

    #[test]
    fn parse_assistant_message() {
        let line = r#"{"type":"assistant","session_id":"abc","uuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello world"}]}}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "assistant");
        assert_eq!(msg.session_id.as_deref(), Some("abc"));
        assert!(msg.message.is_some());
    }

    #[test]
    fn parse_result_token_usage_includes_cache_fields() {
        // cache_creation_input_tokens = max(3, 7 + 11) = 18;
        // total = 10 + 4 + 18 + 2 + 5 + 1 = 40.
        let line = r#"{"type":"result","usage":{"input_tokens":10,"cached_input_tokens":4,"cache_creation_input_tokens":3,"cache_read_input_tokens":2,"output_tokens":5,"reasoning_output_tokens":1,"cache_creation":{"ephemeral_5m_input_tokens":7,"ephemeral_1h_input_tokens":11}}}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        let usage = msg.token_usage().unwrap();
        assert_eq!(
            usage,
            SdkTokenUsage {
                input_tokens: 10,
                cached_input_tokens: 4,
                cache_creation_input_tokens: 18,
                cache_read_input_tokens: 2,
                output_tokens: 5,
                reasoning_output_tokens: 1,
            }
        );
        assert_eq!(usage.total_tokens(), 40);
    }

    #[test]
    fn model_name_prefers_message_then_model_usage() {
        let assistant_line = r#"{"type":"assistant","message":{"model":"claude-sonnet-4-5","content":[{"type":"text","text":"hello"}]}}"#;
        let assistant: SdkOutput = serde_json::from_str(assistant_line).unwrap();
        assert_eq!(assistant.model_name().as_deref(), Some("claude-sonnet-4-5"));

        let result_line = r#"{"type":"result","modelUsage":{"model":"claude-opus-4-1"}}"#;
        let result: SdkOutput = serde_json::from_str(result_line).unwrap();
        assert_eq!(result.model_name().as_deref(), Some("claude-opus-4-1"));
    }

    #[test]
    fn parse_result_success() {
        let line = r#"{"type":"result","subtype":"success","session_id":"abc","uuid":"u2","result":"done","num_turns":3,"is_error":false}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "result");
        assert_eq!(msg.subtype.as_deref(), Some("success"));
        assert_eq!(msg.result.as_deref(), Some("done"));
        assert_eq!(msg.num_turns, Some(3));
        assert_eq!(msg.is_error, Some(false));
    }

    #[test]
    fn parse_result_error() {
        let line = r#"{"type":"result","subtype":"error_during_execution","is_error":true,"errors":["context window exceeded"],"session_id":"x"}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "result");
        assert_eq!(msg.is_error, Some(true));
        assert_eq!(
            msg.errors.as_deref(),
            Some(&["context window exceeded".to_string()][..])
        );
    }

    #[test]
    fn parse_control_request() {
        let line = r#"{"type":"control_request","request_id":"req-1","request":{"subtype":"can_use_tool","tool_name":"Bash","tool_use_id":"tu-1","input":{"command":"ls"}}}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "control_request");
        assert_eq!(msg.request_id.as_deref(), Some("req-1"));
        assert_eq!(msg.request_subtype().as_deref(), Some("can_use_tool"));
        assert_eq!(msg.request_tool_use_id().as_deref(), Some("tu-1"));
    }

    #[test]
    fn parse_result_usage_and_model() {
        // 67 = 10 + 5 + 20 + 5 + 15 + 3 + 7 + 2: `usage_total_tokens` adds both
        // `cache_creation_input_tokens` (20) and its ephemeral 5m/1h breakdown
        // (5 + 15), unlike `token_usage`, which would take max(20, 20) = 20.
        let line = r#"{"type":"result","session_id":"x","usage":{"input_tokens":10,"cached_input_tokens":5,"cache_creation_input_tokens":20,"cache_creation":{"ephemeral_5m_input_tokens":5,"ephemeral_1h_input_tokens":15},"cache_read_input_tokens":3,"output_tokens":7,"reasoning_output_tokens":2},"message":{"model":"claude-opus-4-6-1m"}}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.model_name().as_deref(), Some("claude-opus-4-6-1m"));
        assert_eq!(msg.usage_total_tokens(), 67);
    }

    #[test]
    fn parse_stream_event() {
        let line = r#"{"type":"stream_event","session_id":"abc","uuid":"u3","event":{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hi"}}}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "stream_event");
        assert!(msg.event.is_some());
    }

    #[test]
    fn unknown_message_type_is_tolerated() {
        // Unknown `type` values and unknown keys (`some_field`) must not fail parsing.
        let line = r#"{"type":"future_new_type","session_id":"abc","some_field":42}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "future_new_type");
    }

    // --- Text extraction ---

    #[test]
    fn extract_text_from_content_array() {
        let msg = json!({
            "role": "assistant",
            "content": [
                {"type": "text", "text": "Hello "},
                {"type": "tool_use", "id": "t1", "name": "Bash", "input": {}},
                {"type": "text", "text": "world"}
            ]
        });
        assert_eq!(extract_assistant_text(&msg), "Hello world");
    }

    #[test]
    fn extract_text_from_string_content() {
        let msg = json!({"role": "assistant", "content": "plain string"});
        assert_eq!(extract_assistant_text(&msg), "plain string");
    }

    #[test]
    fn extract_text_empty_content() {
        let msg = json!({"role": "assistant", "content": []});
        assert_eq!(extract_assistant_text(&msg), "");
    }

    #[test]
    fn extract_text_no_content_field() {
        let msg = json!({"role": "assistant"});
        assert_eq!(extract_assistant_text(&msg), "");
    }

    #[test]
    fn extract_text_only_tool_use_blocks() {
        let msg = json!({
            "content": [
                {"type": "tool_use", "id": "t1", "name": "Read", "input": {}}
            ]
        });
        assert_eq!(extract_assistant_text(&msg), "");
    }

    #[test]
    fn extract_stream_text_delta() {
        let event = json!({
            "type": "content_block_delta",
            "index": 0,
            "delta": {"type": "text_delta", "text": "incremental"}
        });
        assert_eq!(extract_stream_text(&event), Some("incremental".to_string()));
    }

    #[test]
    fn extract_stream_text_non_text_delta() {
        let event = json!({
            "type": "content_block_delta",
            "delta": {"type": "input_json_delta", "partial_json": "{}"}
        });
        assert_eq!(extract_stream_text(&event), None);
    }

    #[test]
    fn extract_stream_text_no_delta() {
        let event = json!({"type": "content_block_start"});
        assert_eq!(extract_stream_text(&event), None);
    }
}