Skip to main content

batty_cli/shim/
sdk_types.rs

//! NDJSON message types for the Claude Code stream-json SDK protocol.
//!
//! These types model the messages exchanged over stdin/stdout when Claude Code
//! runs in `-p --input-format=stream-json --output-format=stream-json` mode.

6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
9// ---------------------------------------------------------------------------
10// Messages written TO Claude's stdin
11// ---------------------------------------------------------------------------
12
13/// A user message sent to Claude Code via stdin.
14#[derive(Debug, Serialize)]
15pub struct SdkUserMessage {
16    #[serde(rename = "type")]
17    pub msg_type: &'static str, // always "user"
18    pub session_id: String,
19    pub message: UserMessageBody,
20    pub parent_tool_use_id: Option<String>,
21}
22
23impl SdkUserMessage {
24    pub fn new(session_id: &str, content: &str) -> Self {
25        Self {
26            msg_type: "user",
27            session_id: session_id.to_string(),
28            message: UserMessageBody {
29                role: "user".to_string(),
30                content: content.to_string(),
31            },
32            parent_tool_use_id: None,
33        }
34    }
35
36    /// Serialize to a single NDJSON line (no trailing newline).
37    pub fn to_ndjson(&self) -> String {
38        serde_json::to_string(self).expect("SdkUserMessage is always serializable")
39    }
40}
41
42#[derive(Debug, Serialize)]
43pub struct UserMessageBody {
44    pub role: String,
45    pub content: String,
46}
47
48/// A control response sent to Claude Code via stdin (e.g. permission approval).
49#[derive(Debug, Serialize)]
50pub struct SdkControlResponse {
51    #[serde(rename = "type")]
52    pub msg_type: &'static str, // always "control_response"
53    pub response: ControlResponseBody,
54}
55
56impl SdkControlResponse {
57    /// Build an approval response for a `can_use_tool` control request.
58    pub fn approve_tool(request_id: &str, tool_use_id: &str) -> Self {
59        Self {
60            msg_type: "control_response",
61            response: ControlResponseBody {
62                subtype: "success".to_string(),
63                request_id: request_id.to_string(),
64                response: Some(ToolApproval {
65                    tool_use_id: tool_use_id.to_string(),
66                    approved: true,
67                }),
68            },
69        }
70    }
71
72    /// Serialize to a single NDJSON line (no trailing newline).
73    pub fn to_ndjson(&self) -> String {
74        serde_json::to_string(self).expect("SdkControlResponse is always serializable")
75    }
76}
77
78#[derive(Debug, Serialize)]
79pub struct ControlResponseBody {
80    pub subtype: String,
81    pub request_id: String,
82    #[serde(skip_serializing_if = "Option::is_none")]
83    pub response: Option<ToolApproval>,
84}
85
86#[derive(Debug, Serialize)]
87pub struct ToolApproval {
88    #[serde(rename = "toolUseID")]
89    pub tool_use_id: String,
90    pub approved: bool,
91}
92
93// ---------------------------------------------------------------------------
94// Messages read FROM Claude's stdout
95// ---------------------------------------------------------------------------
96
97/// A single NDJSON message received from Claude Code's stdout.
98///
99/// Uses a flat struct with optional fields rather than a tagged enum so that
100/// unknown or new message types are silently tolerated (future-proof).
101#[derive(Debug, Deserialize)]
102pub struct SdkOutput {
103    #[serde(rename = "type")]
104    pub msg_type: String,
105
106    #[serde(default)]
107    pub subtype: Option<String>,
108
109    #[serde(default)]
110    pub session_id: Option<String>,
111
112    #[serde(default)]
113    pub uuid: Option<String>,
114
115    /// For `assistant` messages: the full message object with `content` array.
116    #[serde(default)]
117    pub message: Option<Value>,
118
119    /// For `stream_event` messages: the stream event payload.
120    #[serde(default)]
121    pub event: Option<Value>,
122
123    /// For `result` messages: the final text result.
124    #[serde(default)]
125    pub result: Option<String>,
126
127    /// For `result` messages: number of API turns taken.
128    #[serde(default)]
129    pub num_turns: Option<u32>,
130
131    /// For `result` messages: whether an error occurred.
132    #[serde(default)]
133    pub is_error: Option<bool>,
134
135    /// For `result` error messages: list of error strings.
136    #[serde(default)]
137    pub errors: Option<Vec<String>>,
138
139    /// For `control_request` messages: the request ID to echo in responses.
140    #[serde(default)]
141    pub request_id: Option<String>,
142
143    /// For `control_request` messages: the request payload.
144    #[serde(default)]
145    pub request: Option<Value>,
146}
147
148impl SdkOutput {
149    /// Extract the `subtype` from a nested `request` object (for control requests).
150    pub fn request_subtype(&self) -> Option<String> {
151        self.request
152            .as_ref()
153            .and_then(|r| r.get("subtype"))
154            .and_then(|v| v.as_str())
155            .map(String::from)
156    }
157
158    /// Extract the `tool_use_id` from a `can_use_tool` control request.
159    pub fn request_tool_use_id(&self) -> Option<String> {
160        self.request
161            .as_ref()
162            .and_then(|r| r.get("tool_use_id"))
163            .and_then(|v| v.as_str())
164            .map(String::from)
165    }
166}
167
168// ---------------------------------------------------------------------------
169// Text extraction helpers
170// ---------------------------------------------------------------------------
171
172/// Extract text content from an `assistant` message's `content` array.
173///
174/// The `message` field contains `{ "role": "assistant", "content": [...] }`
175/// where each content block may be `{ "type": "text", "text": "..." }` or
176/// a tool_use block. We only extract text blocks.
177pub fn extract_assistant_text(message: &Value) -> String {
178    let content = match message.get("content") {
179        Some(Value::Array(arr)) => arr,
180        Some(Value::String(s)) => return s.clone(),
181        _ => return String::new(),
182    };
183
184    let mut parts = Vec::new();
185    for block in content {
186        if block.get("type").and_then(|t| t.as_str()) == Some("text") {
187            if let Some(text) = block.get("text").and_then(|t| t.as_str()) {
188                parts.push(text);
189            }
190        }
191    }
192    parts.join("")
193}
194
195/// Extract incremental text from a `stream_event` payload.
196///
197/// Stream events contain `{ "type": "content_block_delta", "delta": { "type": "text_delta", "text": "..." } }`
198/// or similar structures. We extract the text delta if present.
199pub fn extract_stream_text(event: &Value) -> Option<String> {
200    // content_block_delta with text_delta
201    if let Some(delta) = event.get("delta") {
202        if delta.get("type").and_then(|t| t.as_str()) == Some("text_delta") {
203            return delta.get("text").and_then(|t| t.as_str()).map(String::from);
204        }
205    }
206    None
207}
208
209// ---------------------------------------------------------------------------
210// Tests
211// ---------------------------------------------------------------------------
212
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // --- SdkUserMessage ---

    #[test]
    fn user_message_serializes_correctly() {
        let msg = SdkUserMessage::new("sess-1", "Fix the bug");
        let json: Value = serde_json::from_str(&msg.to_ndjson()).unwrap();
        assert_eq!(json["type"], "user");
        assert_eq!(json["session_id"], "sess-1");
        assert_eq!(json["message"]["role"], "user");
        assert_eq!(json["message"]["content"], "Fix the bug");
        assert!(json["parent_tool_use_id"].is_null());
    }

    #[test]
    fn user_message_empty_session_id() {
        let msg = SdkUserMessage::new("", "hello");
        let json: Value = serde_json::from_str(&msg.to_ndjson()).unwrap();
        assert_eq!(json["session_id"], "");
    }

    /// NDJSON framing requires one message per line, so newlines embedded in
    /// the prompt must be escaped rather than written raw.
    #[test]
    fn user_message_with_newlines_stays_on_one_line() {
        let content = "line one\nline two\r\n";
        let line = SdkUserMessage::new("sess-1", content).to_ndjson();
        assert!(!line.contains('\n'));
        assert!(!line.contains('\r'));
        let json: Value = serde_json::from_str(&line).unwrap();
        assert_eq!(json["message"]["content"], content);
    }

    // --- SdkControlResponse ---

    #[test]
    fn approve_tool_serializes_correctly() {
        let resp = SdkControlResponse::approve_tool("req-42", "tool-99");
        let json: Value = serde_json::from_str(&resp.to_ndjson()).unwrap();
        assert_eq!(json["type"], "control_response");
        assert_eq!(json["response"]["subtype"], "success");
        assert_eq!(json["response"]["request_id"], "req-42");
        assert_eq!(json["response"]["response"]["toolUseID"], "tool-99");
        assert_eq!(json["response"]["response"]["approved"], true);
    }

    // --- SdkOutput deserialization ---

    #[test]
    fn parse_assistant_message() {
        let line = r#"{"type":"assistant","session_id":"abc","uuid":"u1","message":{"role":"assistant","content":[{"type":"text","text":"hello world"}]}}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "assistant");
        assert_eq!(msg.session_id.as_deref(), Some("abc"));
        assert!(msg.message.is_some());
    }

    #[test]
    fn parse_result_success() {
        let line = r#"{"type":"result","subtype":"success","session_id":"abc","uuid":"u2","result":"done","num_turns":3,"is_error":false}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "result");
        assert_eq!(msg.subtype.as_deref(), Some("success"));
        assert_eq!(msg.result.as_deref(), Some("done"));
        assert_eq!(msg.num_turns, Some(3));
        assert_eq!(msg.is_error, Some(false));
    }

    #[test]
    fn parse_result_error() {
        let line = r#"{"type":"result","subtype":"error_during_execution","is_error":true,"errors":["context window exceeded"],"session_id":"x"}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "result");
        assert_eq!(msg.is_error, Some(true));
        assert_eq!(
            msg.errors.as_deref(),
            Some(&["context window exceeded".to_string()][..])
        );
    }

    #[test]
    fn parse_control_request() {
        let line = r#"{"type":"control_request","request_id":"req-1","request":{"subtype":"can_use_tool","tool_name":"Bash","tool_use_id":"tu-1","input":{"command":"ls"}}}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "control_request");
        assert_eq!(msg.request_id.as_deref(), Some("req-1"));
        assert_eq!(msg.request_subtype().as_deref(), Some("can_use_tool"));
        assert_eq!(msg.request_tool_use_id().as_deref(), Some("tu-1"));
    }

    #[test]
    fn request_helpers_return_none_without_request() {
        let line = r#"{"type":"assistant","session_id":"abc"}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.request_subtype(), None);
        assert_eq!(msg.request_tool_use_id(), None);
    }

    #[test]
    fn request_helpers_ignore_non_string_values() {
        let line = r#"{"type":"control_request","request_id":"req-2","request":{"subtype":7,"tool_use_id":null}}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.request_subtype(), None);
        assert_eq!(msg.request_tool_use_id(), None);
    }

    #[test]
    fn parse_stream_event() {
        let line = r#"{"type":"stream_event","session_id":"abc","uuid":"u3","event":{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hi"}}}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "stream_event");
        assert!(msg.event.is_some());
    }

    #[test]
    fn unknown_message_type_is_tolerated() {
        let line = r#"{"type":"future_new_type","session_id":"abc","some_field":42}"#;
        let msg: SdkOutput = serde_json::from_str(line).unwrap();
        assert_eq!(msg.msg_type, "future_new_type");
    }

    /// `type` is the only required field; a line without it must be rejected.
    #[test]
    fn missing_type_field_is_rejected() {
        let line = r#"{"session_id":"abc"}"#;
        assert!(serde_json::from_str::<SdkOutput>(line).is_err());
    }

    // --- Text extraction ---

    #[test]
    fn extract_text_from_content_array() {
        let msg = json!({
            "role": "assistant",
            "content": [
                {"type": "text", "text": "Hello "},
                {"type": "tool_use", "id": "t1", "name": "Bash", "input": {}},
                {"type": "text", "text": "world"}
            ]
        });
        assert_eq!(extract_assistant_text(&msg), "Hello world");
    }

    #[test]
    fn extract_text_from_string_content() {
        let msg = json!({"role": "assistant", "content": "plain string"});
        assert_eq!(extract_assistant_text(&msg), "plain string");
    }

    #[test]
    fn extract_text_empty_content() {
        let msg = json!({"role": "assistant", "content": []});
        assert_eq!(extract_assistant_text(&msg), "");
    }

    #[test]
    fn extract_text_no_content_field() {
        let msg = json!({"role": "assistant"});
        assert_eq!(extract_assistant_text(&msg), "");
    }

    #[test]
    fn extract_text_null_content() {
        let msg = json!({"role": "assistant", "content": null});
        assert_eq!(extract_assistant_text(&msg), "");
    }

    #[test]
    fn extract_text_only_tool_use_blocks() {
        let msg = json!({
            "content": [
                {"type": "tool_use", "id": "t1", "name": "Read", "input": {}}
            ]
        });
        assert_eq!(extract_assistant_text(&msg), "");
    }

    /// Text blocks whose `text` is missing or not a string are skipped rather
    /// than aborting extraction.
    #[test]
    fn extract_text_skips_malformed_text_blocks() {
        let msg = json!({
            "content": [
                {"type": "text"},
                {"type": "text", "text": 42},
                {"type": "text", "text": "kept"}
            ]
        });
        assert_eq!(extract_assistant_text(&msg), "kept");
    }

    #[test]
    fn extract_stream_text_delta() {
        let event = json!({
            "type": "content_block_delta",
            "index": 0,
            "delta": {"type": "text_delta", "text": "incremental"}
        });
        assert_eq!(extract_stream_text(&event), Some("incremental".to_string()));
    }

    #[test]
    fn extract_stream_text_non_text_delta() {
        let event = json!({
            "type": "content_block_delta",
            "delta": {"type": "input_json_delta", "partial_json": "{}"}
        });
        assert_eq!(extract_stream_text(&event), None);
    }

    #[test]
    fn extract_stream_text_no_delta() {
        let event = json!({"type": "content_block_start"});
        assert_eq!(extract_stream_text(&event), None);
    }

    #[test]
    fn extract_stream_text_delta_without_text() {
        let event = json!({
            "type": "content_block_delta",
            "delta": {"type": "text_delta"}
        });
        assert_eq!(extract_stream_text(&event), None);
    }
}