1use std::path::{Path, PathBuf};
15use std::process::Stdio;
16
17use async_trait::async_trait;
18use serde_json::{Value, json};
19use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
20use tokio::process::{Child, Command};
21use tokio::sync::mpsc;
22use tracing::{debug, trace, warn};
23
24use crate::core::{AgentEvent, ClientFrame, Content, StopReason, TextChannel, Usage};
25use crate::driver::{Driver, DriverError};
26
/// Driver that talks to a spawned `claude` CLI child process.
///
/// The driver itself only holds channel endpoints and the child handle; the
/// actual stdin/stdout/stderr I/O is performed by background tasks started in
/// `spawn_inner`.
#[derive(Debug)]
pub struct ClaudeCodeDriver {
    /// Sending half of the channel feeding the stdin writer task.
    /// Set to `None` by `finish_input`; `send` then reports `AgentExited`.
    writer_tx: Option<mpsc::Sender<String>>,

    /// Receiving half of the channel filled by the stdout reader task with
    /// parsed agent events.
    reader_rx: mpsc::Receiver<AgentEvent>,

    /// Handle to the child process. Taken (left as `None`) by `shutdown`.
    child: Option<Child>,
}
43
44impl ClaudeCodeDriver {
45 pub fn finish_input(&mut self) {
54 self.writer_tx = None;
55 }
56}
57
58impl ClaudeCodeDriver {
59 pub async fn spawn(cwd: impl AsRef<Path>) -> Result<Self, DriverError> {
68 Self::builder(cwd).spawn().await
69 }
70
71 pub fn builder(cwd: impl AsRef<Path>) -> ClaudeCodeDriverBuilder {
73 ClaudeCodeDriverBuilder {
74 bin: None,
75 cwd: cwd.as_ref().to_path_buf(),
76 model: None,
77 session_id: None,
78 resume: None,
79 replay_user_messages: true,
80 dangerously_skip_permissions: true,
81 }
82 }
83
84 async fn spawn_inner(b: ClaudeCodeDriverBuilder) -> Result<Self, DriverError> {
85 let ClaudeCodeDriverBuilder {
86 bin,
87 cwd,
88 model,
89 session_id,
90 resume,
91 replay_user_messages,
92 dangerously_skip_permissions,
93 } = b;
94
95 let bin = bin
96 .or_else(|| std::env::var("CLAUDE_BIN").ok())
97 .unwrap_or_else(|| "claude".to_string());
98
99 let mut cmd = Command::new(&bin);
100 cmd.arg("-p")
101 .arg("--input-format=stream-json")
102 .arg("--output-format=stream-json")
103 .arg("--verbose")
104 .current_dir(&cwd)
105 .stdin(Stdio::piped())
106 .stdout(Stdio::piped())
107 .stderr(Stdio::piped())
108 .kill_on_drop(true);
109
110 if dangerously_skip_permissions {
111 cmd.arg("--dangerously-skip-permissions");
112 }
113 if replay_user_messages {
114 cmd.arg("--replay-user-messages");
115 }
116 if let Some(m) = &model {
117 cmd.arg("--model").arg(m);
118 }
119 if let Some(sid) = &session_id {
120 cmd.arg("--session-id").arg(sid);
121 }
122 if let Some(rid) = &resume {
123 cmd.arg("--resume").arg(rid);
124 }
125
126 for var in [
131 "CLAUDECODE",
132 "CLAUDE_CODE_ENTRYPOINT",
133 "CLAUDE_CODE_SSE_PORT",
134 "CLAUDE_CODE_OAUTH_TOKEN",
135 "CLAUDE_CODE_SESSION_ID",
136 "CLAUDE_SESSION_ID",
137 ] {
138 cmd.env_remove(var);
139 }
140
141 debug!(
142 bin = %bin,
143 cwd = %cwd.display(),
144 session_mode = replay_user_messages,
145 resume = ?resume,
146 session_id = ?session_id,
147 "spawning claude",
148 );
149
150 let mut child = cmd.spawn().map_err(|e| {
151 if e.kind() == std::io::ErrorKind::NotFound {
152 DriverError::BinaryNotFound(bin.clone())
153 } else {
154 DriverError::SpawnFailed(e)
155 }
156 })?;
157
158 let stdin = child.stdin.take().ok_or(DriverError::AgentExited)?;
159 let stdout = child.stdout.take().ok_or(DriverError::AgentExited)?;
160 let stderr = child.stderr.take().ok_or(DriverError::AgentExited)?;
161
162 let (writer_tx, writer_rx) = mpsc::channel::<String>(32);
163 let (reader_tx, reader_rx) = mpsc::channel::<AgentEvent>(64);
164
165 tokio::spawn(writer_task(stdin, writer_rx));
167
168 tokio::spawn(reader_task(stdout, reader_tx));
170
171 tokio::spawn(stderr_drain(stderr));
173
174 Ok(Self {
175 writer_tx: Some(writer_tx),
176 reader_rx,
177 child: Some(child),
178 })
179 }
180}
181
/// Configuration for launching a [`ClaudeCodeDriver`].
///
/// Created by `ClaudeCodeDriver::builder`; each setter consumes and returns
/// the builder, and `spawn` starts the process.
#[derive(Debug, Clone)]
pub struct ClaudeCodeDriverBuilder {
    /// Explicit binary to run. When `None`, `CLAUDE_BIN` and then `"claude"`
    /// are used at spawn time.
    bin: Option<String>,
    /// Working directory the child process is started in.
    cwd: PathBuf,
    /// Value passed via `--model`, if any.
    model: Option<String>,
    /// Value passed via `--session-id`, if any.
    session_id: Option<String>,
    /// Value passed via `--resume`, if any.
    resume: Option<String>,
    /// Whether to pass `--replay-user-messages`.
    replay_user_messages: bool,
    /// Whether to pass `--dangerously-skip-permissions`.
    dangerously_skip_permissions: bool,
}
220
221impl ClaudeCodeDriverBuilder {
222 pub fn bin(mut self, bin: impl Into<String>) -> Self {
224 self.bin = Some(bin.into());
225 self
226 }
227
228 pub fn model(mut self, model: impl Into<String>) -> Self {
230 self.model = Some(model.into());
231 self
232 }
233
234 pub fn session_id(mut self, uuid: impl Into<String>) -> Self {
238 self.session_id = Some(uuid.into());
239 self
240 }
241
242 pub fn resume(mut self, uuid: impl Into<String>) -> Self {
245 self.resume = Some(uuid.into());
246 self
247 }
248
249 pub fn replay_user_messages(mut self, on: bool) -> Self {
260 self.replay_user_messages = on;
261 self
262 }
263
264 pub fn dangerously_skip_permissions(mut self, on: bool) -> Self {
270 self.dangerously_skip_permissions = on;
271 self
272 }
273
274 pub async fn spawn(self) -> Result<ClaudeCodeDriver, DriverError> {
276 ClaudeCodeDriver::spawn_inner(self).await
277 }
278}
279
280#[async_trait]
281impl Driver for ClaudeCodeDriver {
282 async fn send(&mut self, frame: ClientFrame) -> Result<(), DriverError> {
283 let tx = self
284 .writer_tx
285 .as_ref()
286 .ok_or(DriverError::AgentExited)?;
287 let line = encode_client_frame(&frame)?;
288 trace!(line = %line, "→ claude");
289 tx.send(line).await.map_err(|_| DriverError::AgentExited)?;
290 Ok(())
291 }
292
293 async fn next_event(&mut self) -> Option<AgentEvent> {
294 self.reader_rx.recv().await
295 }
296
297 async fn shutdown(&mut self) -> Result<(), DriverError> {
298 if let Some(mut child) = self.child.take() {
299 let _ = child.start_kill();
300 let _ = child.wait().await;
301 }
302 Ok(())
303 }
304}
305
306async fn writer_task(
311 mut stdin: tokio::process::ChildStdin,
312 mut rx: mpsc::Receiver<String>,
313) {
314 while let Some(line) = rx.recv().await {
315 if let Err(e) = stdin.write_all(line.as_bytes()).await {
316 warn!(error = %e, "writer task: write failed, exiting");
317 return;
318 }
319 if !line.ends_with('\n') {
320 let _ = stdin.write_all(b"\n").await;
321 }
322 let _ = stdin.flush().await;
323 }
324 debug!("writer task: input channel closed, exiting");
325}
326
327async fn reader_task(stdout: tokio::process::ChildStdout, tx: mpsc::Sender<AgentEvent>) {
328 let mut lines = BufReader::new(stdout).lines();
329 loop {
330 match lines.next_line().await {
331 Ok(Some(line)) => {
332 trace!(line = %line, "← claude");
333 let value: Value = match serde_json::from_str(&line) {
334 Ok(v) => v,
335 Err(e) => {
336 warn!(error = %e, raw = %line, "reader: malformed JSON, skipping");
337 continue;
338 }
339 };
340 for event in parse_stream_frame(&value) {
341 if tx.send(event).await.is_err() {
342 return;
343 }
344 }
345 }
346 Ok(None) => {
347 debug!("reader: stdout EOF");
348 return;
349 }
350 Err(e) => {
351 warn!(error = %e, "reader: read error");
352 return;
353 }
354 }
355 }
356}
357
358async fn stderr_drain(stderr: tokio::process::ChildStderr) {
359 let mut lines = BufReader::new(stderr).lines();
360 while let Ok(Some(line)) = lines.next_line().await {
361 debug!(target: "cap_rs::stream_json::stderr", "{}", line);
362 }
363}
364
365fn encode_client_frame(frame: &ClientFrame) -> Result<String, DriverError> {
370 match frame {
371 ClientFrame::Prompt { content } => {
372 let parts: Vec<Value> = content
373 .iter()
374 .map(|c| match c {
375 Content::Text(t) => json!({"type": "text", "text": t}),
376 Content::Image { mime, data } => json!({
377 "type": "image",
378 "source": {
379 "type": "base64",
380 "media_type": mime,
381 "data": base64_encode(data),
382 }
383 }),
384 })
385 .collect();
386 let frame_json = json!({
387 "type": "user",
388 "message": {
389 "role": "user",
390 "content": parts
391 }
392 });
393 Ok(frame_json.to_string())
394 }
395 ClientFrame::Cancel => {
396 Ok(json!({"type": "control", "subtype": "cancel"}).to_string())
400 }
401 ClientFrame::AskUserAnswer { ask_id, value } => {
402 let text = format!("[answer to {ask_id}]: {value}");
405 Ok(json!({
406 "type": "user",
407 "message": {
408 "role": "user",
409 "content": [{"type": "text", "text": text}]
410 }
411 })
412 .to_string())
413 }
414 ClientFrame::PermissionResponse { req_id, decision } => {
415 let text = format!("[permission {req_id}]: {decision:?}");
416 Ok(json!({
417 "type": "user",
418 "message": {
419 "role": "user",
420 "content": [{"type": "text", "text": text}]
421 }
422 })
423 .to_string())
424 }
425 }
426}
427
428fn parse_stream_frame(frame: &Value) -> Vec<AgentEvent> {
430 let kind = frame.get("type").and_then(Value::as_str).unwrap_or("");
431 match kind {
432 "system" => match frame.get("subtype").and_then(Value::as_str).unwrap_or("") {
433 "init" => vec![AgentEvent::Ready {
434 session_id: frame
435 .get("session_id")
436 .and_then(Value::as_str)
437 .unwrap_or_default()
438 .to_string(),
439 model: frame
440 .get("model")
441 .and_then(Value::as_str)
442 .map(String::from),
443 }],
444 _ => vec![],
445 },
446
447 "assistant" => {
448 let msg = frame.get("message").cloned().unwrap_or(Value::Null);
449 let msg_id = msg
450 .get("id")
451 .and_then(Value::as_str)
452 .unwrap_or_default()
453 .to_string();
454 let content = msg
455 .get("content")
456 .and_then(Value::as_array)
457 .cloned()
458 .unwrap_or_default();
459
460 let mut events = Vec::new();
461 for block in content {
462 let btype = block.get("type").and_then(Value::as_str).unwrap_or("");
463 match btype {
464 "text" => {
465 let text = block
466 .get("text")
467 .and_then(Value::as_str)
468 .unwrap_or_default()
469 .to_string();
470 if !text.is_empty() {
471 events.push(AgentEvent::TextChunk {
472 msg_id: msg_id.clone(),
473 text,
474 channel: TextChannel::Assistant,
475 });
476 }
477 }
478 "thinking" => {
479 let text = block
480 .get("thinking")
481 .and_then(Value::as_str)
482 .or_else(|| block.get("text").and_then(Value::as_str))
483 .unwrap_or_default()
484 .to_string();
485 if !text.is_empty() {
486 events.push(AgentEvent::Thought {
487 msg_id: msg_id.clone(),
488 text,
489 });
490 }
491 }
492 "tool_use" => {
493 events.push(AgentEvent::ToolCallStart {
494 call_id: block
495 .get("id")
496 .and_then(Value::as_str)
497 .unwrap_or_default()
498 .to_string(),
499 name: block
500 .get("name")
501 .and_then(Value::as_str)
502 .unwrap_or_default()
503 .to_string(),
504 input: block.get("input").cloned().unwrap_or(Value::Null),
505 });
506 }
507 _ => {
508 trace!(block_type = btype, "ignoring unknown assistant block");
509 }
510 }
511 }
512 events
513 }
514
515 "user" => {
516 let content = frame
518 .get("message")
519 .and_then(|m| m.get("content"))
520 .and_then(Value::as_array)
521 .cloned()
522 .unwrap_or_default();
523 let mut events = Vec::new();
524 for block in content {
525 if block.get("type").and_then(Value::as_str) == Some("tool_result") {
526 let call_id = block
527 .get("tool_use_id")
528 .and_then(Value::as_str)
529 .unwrap_or_default()
530 .to_string();
531 let output = extract_tool_result_output(&block);
532 let is_error = block
533 .get("is_error")
534 .and_then(Value::as_bool)
535 .unwrap_or(false);
536 events.push(AgentEvent::ToolCallEnd {
537 call_id,
538 output,
539 is_error,
540 });
541 }
542 }
543 events
544 }
545
546 "result" => {
547 let usage = parse_usage(frame);
548 let stop_reason = match frame.get("subtype").and_then(Value::as_str) {
549 Some("success") => StopReason::EndTurn,
550 Some("error_max_turns") => StopReason::MaxTokens,
551 Some("error_during_execution") => StopReason::Error,
552 Some(other) if other.starts_with("error") => StopReason::Error,
553 _ => StopReason::EndTurn,
554 };
555 vec![AgentEvent::Done { stop_reason, usage }]
556 }
557
558 "stream_event" => {
559 vec![]
562 }
563
564 other => {
565 trace!(frame_type = other, "ignoring unknown stream-json frame");
566 vec![]
567 }
568 }
569}
570
571fn extract_tool_result_output(block: &Value) -> String {
572 match block.get("content") {
573 Some(Value::String(s)) => s.clone(),
574 Some(Value::Array(arr)) => arr
575 .iter()
576 .filter_map(|part| part.get("text").and_then(Value::as_str))
577 .collect::<Vec<_>>()
578 .join("\n"),
579 _ => String::new(),
580 }
581}
582
583fn parse_usage(frame: &Value) -> Usage {
584 let u = frame.get("usage").cloned().unwrap_or(Value::Null);
585 Usage {
586 input_tokens: u.get("input_tokens").and_then(Value::as_u64).unwrap_or(0),
587 output_tokens: u.get("output_tokens").and_then(Value::as_u64).unwrap_or(0),
588 cache_read_tokens: u
589 .get("cache_read_input_tokens")
590 .and_then(Value::as_u64)
591 .unwrap_or(0),
592 cache_creation_tokens: u
593 .get("cache_creation_input_tokens")
594 .and_then(Value::as_u64)
595 .unwrap_or(0),
596 cost_usd_estimate: frame.get("total_cost_usd").and_then(Value::as_f64),
597 duration: frame
598 .get("duration_ms")
599 .and_then(Value::as_u64)
600 .map(std::time::Duration::from_millis),
601 model_id: frame
602 .get("modelUsage")
603 .and_then(Value::as_object)
604 .and_then(|m| m.keys().next().cloned()),
605 }
606}
607
/// Encodes `data` using the standard base64 alphabet (`A-Z a-z 0-9 + /`)
/// with `=` padding. Empty input yields an empty string.
fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    // Maps the 6-bit field of `group` starting at bit `shift` to its character.
    let sextet = |group: u32, shift: u32| ALPHABET[((group >> shift) & 0x3f) as usize] as char;

    let mut encoded = String::with_capacity(((data.len() + 2) / 3) * 4);
    for chunk in data.chunks(3) {
        // `chunks` never yields an empty slice, so the first byte always exists;
        // the second and third are only present in longer groups.
        let second = chunk.get(1).copied();
        let third = chunk.get(2).copied();
        let group = (u32::from(chunk[0]) << 16)
            | (u32::from(second.unwrap_or(0)) << 8)
            | u32::from(third.unwrap_or(0));

        encoded.push(sextet(group, 18));
        encoded.push(sextet(group, 12));
        encoded.push(if second.is_some() { sextet(group, 6) } else { '=' });
        encoded.push(if third.is_some() { sextet(group, 0) } else { '=' });
    }
    encoded
}
636
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a raw JSON fixture and runs it through `parse_stream_frame`.
    fn events_from(raw: &str) -> Vec<AgentEvent> {
        let value: Value = serde_json::from_str(raw).expect("test fixture must be valid JSON");
        parse_stream_frame(&value)
    }

    #[test]
    fn parse_init_frame() {
        let events = events_from(
            r#"{"type":"system","subtype":"init","session_id":"sess_1","model":"claude-opus-4-7"}"#,
        );
        match events.as_slice() {
            [AgentEvent::Ready { session_id, model }] => {
                assert_eq!(session_id, "sess_1");
                assert_eq!(model.as_deref(), Some("claude-opus-4-7"));
            }
            other => panic!("wrong events: {other:?}"),
        }
    }

    #[test]
    fn parse_assistant_text() {
        let events = events_from(
            r#"{"type":"assistant","message":{"id":"msg_1","content":[{"type":"text","text":"hello"}]}}"#,
        );
        assert_eq!(events.len(), 1);
        match &events[0] {
            AgentEvent::TextChunk { text, .. } => assert_eq!(text, "hello"),
            other => panic!("wrong variant: {other:?}"),
        }
    }

    #[test]
    fn parse_tool_use() {
        let events = events_from(
            r#"{"type":"assistant","message":{"id":"m","content":[
                {"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}
            ]}}"#,
        );
        // Check the count first so a regression fails with a readable message
        // instead of an index-out-of-bounds panic.
        assert_eq!(events.len(), 1);
        match &events[0] {
            AgentEvent::ToolCallStart { name, .. } => assert_eq!(name, "Bash"),
            other => panic!("wrong: {other:?}"),
        }
    }

    #[test]
    fn parse_tool_result_joins_text_parts() {
        let events = events_from(
            r#"{"type":"user","message":{"role":"user","content":[
                {"type":"text","text":"ignored"},
                {"type":"tool_result","tool_use_id":"t1","is_error":true,
                 "content":[{"type":"text","text":"a"},{"type":"text","text":"b"}]}
            ]}}"#,
        );
        match events.as_slice() {
            [AgentEvent::ToolCallEnd {
                call_id,
                output,
                is_error,
            }] => {
                assert_eq!(call_id, "t1");
                assert_eq!(output, "a\nb");
                assert!(*is_error);
            }
            other => panic!("wrong events: {other:?}"),
        }
    }

    #[test]
    fn parse_result_with_usage() {
        let events = events_from(
            r#"{"type":"result","subtype":"success","duration_ms":1500,"total_cost_usd":0.0021,
                "usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":0,"cache_creation_input_tokens":0}}"#,
        );
        assert_eq!(events.len(), 1);
        match &events[0] {
            AgentEvent::Done { usage, stop_reason } => {
                assert_eq!(*stop_reason, StopReason::EndTurn);
                assert_eq!(usage.input_tokens, 10);
                assert_eq!(usage.output_tokens, 20);
                assert_eq!(usage.cost_usd_estimate, Some(0.0021));
                assert_eq!(
                    usage.duration,
                    Some(std::time::Duration::from_millis(1500))
                );
            }
            other => panic!("wrong: {other:?}"),
        }
    }

    #[test]
    fn unknown_frames_yield_no_events() {
        assert!(events_from(r#"{"type":"mystery"}"#).is_empty());
        assert!(events_from(r#"{"type":"stream_event"}"#).is_empty());
        assert!(events_from(r#"{}"#).is_empty());
    }

    #[test]
    fn encode_simple_prompt() {
        let frame = ClientFrame::Prompt {
            content: vec![Content::Text("hi".into())],
        };
        let line = encode_client_frame(&frame).unwrap();
        let v: Value = serde_json::from_str(&line).unwrap();
        assert_eq!(v["type"], "user");
        assert_eq!(v["message"]["content"][0]["text"], "hi");
    }

    #[test]
    fn base64_matches_rfc4648_vectors() {
        // Test vectors from RFC 4648, section 10.
        let vectors: [(&[u8], &str); 7] = [
            (b"", ""),
            (b"f", "Zg=="),
            (b"fo", "Zm8="),
            (b"foo", "Zm9v"),
            (b"foob", "Zm9vYg=="),
            (b"fooba", "Zm9vYmE="),
            (b"foobar", "Zm9vYmFy"),
        ];
        for (input, expected) in vectors {
            assert_eq!(base64_encode(input), expected);
        }
    }
}