use std::path::{Path, PathBuf};
use std::process::Stdio;
use async_trait::async_trait;
use serde_json::{Value, json};
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::mpsc;
use tracing::{debug, trace, warn};
use crate::core::{AgentEvent, ClientFrame, Content, StopReason, TextChannel, Usage};
use crate::driver::{Driver, DriverError};
/// Driver that runs the `claude` CLI as a child process and exchanges
/// newline-delimited JSON frames with it over the child's stdin/stdout.
#[derive(Debug)]
pub struct ClaudeCodeDriver {
/// Queue of already-encoded lines consumed by the stdin writer task.
/// Set to `None` by `finish_input`, after which `send` reports `AgentExited`.
writer_tx: Option<mpsc::Sender<String>>,
/// Events decoded from the child's stdout by the reader task.
reader_rx: mpsc::Receiver<AgentEvent>,
/// Handle to the spawned process (created with `kill_on_drop(true)`);
/// taken out and killed by `shutdown`.
child: Option<Child>,
}
impl ClaudeCodeDriver {
    /// Declares that no further frames will be sent.
    ///
    /// Releases the driver's sender half of the writer channel. Once the
    /// channel is closed the writer task finishes its loop and drops the
    /// child's stdin handle it owns; later calls to `send` fail with
    /// `DriverError::AgentExited`.
    pub fn finish_input(&mut self) {
        drop(self.writer_tx.take());
    }
}
impl ClaudeCodeDriver {
    /// Launches the CLI in `cwd` using the builder's default settings.
    ///
    /// # Errors
    /// Returns whatever [`ClaudeCodeDriverBuilder::spawn`] reports.
    pub async fn spawn(cwd: impl AsRef<Path>) -> Result<Self, DriverError> {
        let defaults = Self::builder(cwd);
        defaults.spawn().await
    }

    /// Starts a builder rooted at `cwd`.
    ///
    /// Defaults: no explicit binary, model, session id or resume id;
    /// `replay_user_messages` and `dangerously_skip_permissions` both enabled.
    pub fn builder(cwd: impl AsRef<Path>) -> ClaudeCodeDriverBuilder {
        let cwd = cwd.as_ref().to_path_buf();
        ClaudeCodeDriverBuilder {
            bin: None,
            cwd,
            model: None,
            session_id: None,
            resume: None,
            replay_user_messages: true,
            dangerously_skip_permissions: true,
        }
    }

    /// Builds the command line from `b`, spawns the process and wires up the
    /// stdin writer, stdout reader and stderr drain tasks.
    ///
    /// # Errors
    /// * `DriverError::BinaryNotFound` when the OS reports the executable as missing.
    /// * `DriverError::SpawnFailed` for any other spawn failure.
    /// * `DriverError::AgentExited` when one of the piped stdio handles is absent.
    async fn spawn_inner(b: ClaudeCodeDriverBuilder) -> Result<Self, DriverError> {
        // Variables removed from the child's environment. Presumably this keeps a
        // driver launched from inside another Claude Code session from inheriting
        // that session's identity — TODO confirm.
        const SCRUBBED_ENV: [&str; 6] = [
            "CLAUDECODE",
            "CLAUDE_CODE_ENTRYPOINT",
            "CLAUDE_CODE_SSE_PORT",
            "CLAUDE_CODE_OAUTH_TOKEN",
            "CLAUDE_CODE_SESSION_ID",
            "CLAUDE_SESSION_ID",
        ];

        let ClaudeCodeDriverBuilder {
            bin: bin_override,
            cwd: workdir,
            model,
            session_id,
            resume,
            replay_user_messages: replay,
            dangerously_skip_permissions: skip_permissions,
        } = b;

        // Resolution order: explicit builder value, then $CLAUDE_BIN, then `claude` on PATH.
        let bin = match bin_override {
            Some(explicit) => explicit,
            None => std::env::var("CLAUDE_BIN").unwrap_or_else(|_| "claude".to_string()),
        };

        let mut cmd = Command::new(&bin);
        cmd.args([
            "-p",
            "--input-format=stream-json",
            "--output-format=stream-json",
            "--verbose",
        ]);
        cmd.current_dir(&workdir);
        cmd.stdin(Stdio::piped());
        cmd.stdout(Stdio::piped());
        cmd.stderr(Stdio::piped());
        cmd.kill_on_drop(true);

        // Boolean switches first, then the optional `--flag value` pairs.
        if skip_permissions {
            cmd.arg("--dangerously-skip-permissions");
        }
        if replay {
            cmd.arg("--replay-user-messages");
        }
        let valued_flags = [
            ("--model", &model),
            ("--session-id", &session_id),
            ("--resume", &resume),
        ];
        for (flag, value) in valued_flags {
            if let Some(value) = value {
                cmd.arg(flag).arg(value);
            }
        }

        for name in SCRUBBED_ENV {
            cmd.env_remove(name);
        }

        debug!(
            bin = %bin,
            cwd = %workdir.display(),
            session_mode = replay,
            resume = ?resume,
            session_id = ?session_id,
            "spawning claude",
        );

        let mut child = match cmd.spawn() {
            Ok(child) => child,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
                return Err(DriverError::BinaryNotFound(bin));
            }
            Err(e) => return Err(DriverError::SpawnFailed(e)),
        };

        // All three handles were requested as pipes above; a missing one is
        // reported as the agent having gone away.
        let handles = (
            child.stdin.take(),
            child.stdout.take(),
            child.stderr.take(),
        );
        let (stdin, stdout, stderr) = match handles {
            (Some(stdin), Some(stdout), Some(stderr)) => (stdin, stdout, stderr),
            _ => return Err(DriverError::AgentExited),
        };

        let (writer_tx, writer_rx) = mpsc::channel::<String>(32);
        let (reader_tx, reader_rx) = mpsc::channel::<AgentEvent>(64);

        tokio::spawn(writer_task(stdin, writer_rx));
        tokio::spawn(reader_task(stdout, reader_tx));
        tokio::spawn(stderr_drain(stderr));

        Ok(Self {
            writer_tx: Some(writer_tx),
            reader_rx,
            child: Some(child),
        })
    }
}
/// Collects launch options for a [`ClaudeCodeDriver`] before spawning it.
#[derive(Debug, Clone)]
pub struct ClaudeCodeDriverBuilder {
/// Executable to run. When `None`, the `CLAUDE_BIN` environment variable is
/// consulted and, failing that, `claude` is used.
bin: Option<String>,
/// Working directory the child process is started in.
cwd: PathBuf,
/// Value forwarded as `--model`, if set.
model: Option<String>,
/// Value forwarded as `--session-id`, if set.
session_id: Option<String>,
/// Value forwarded as `--resume`, if set.
resume: Option<String>,
/// When true, `--replay-user-messages` is passed to the CLI.
replay_user_messages: bool,
/// When true, `--dangerously-skip-permissions` is passed to the CLI.
dangerously_skip_permissions: bool,
}
impl ClaudeCodeDriverBuilder {
    /// Overrides the executable to launch.
    pub fn bin(self, bin: impl Into<String>) -> Self {
        Self {
            bin: Some(bin.into()),
            ..self
        }
    }

    /// Sets the value passed to the CLI as `--model`.
    pub fn model(self, model: impl Into<String>) -> Self {
        Self {
            model: Some(model.into()),
            ..self
        }
    }

    /// Sets the value passed to the CLI as `--session-id`.
    pub fn session_id(self, uuid: impl Into<String>) -> Self {
        Self {
            session_id: Some(uuid.into()),
            ..self
        }
    }

    /// Sets the value passed to the CLI as `--resume`.
    pub fn resume(self, uuid: impl Into<String>) -> Self {
        Self {
            resume: Some(uuid.into()),
            ..self
        }
    }

    /// Controls whether `--replay-user-messages` is passed to the CLI.
    pub fn replay_user_messages(self, on: bool) -> Self {
        Self {
            replay_user_messages: on,
            ..self
        }
    }

    /// Controls whether `--dangerously-skip-permissions` is passed to the CLI.
    pub fn dangerously_skip_permissions(self, on: bool) -> Self {
        Self {
            dangerously_skip_permissions: on,
            ..self
        }
    }

    /// Consumes the builder and launches the process.
    ///
    /// # Errors
    /// Returns the `DriverError` produced while spawning the child.
    pub async fn spawn(self) -> Result<ClaudeCodeDriver, DriverError> {
        let launched = ClaudeCodeDriver::spawn_inner(self).await;
        launched
    }
}
#[async_trait]
impl Driver for ClaudeCodeDriver {
    /// Encodes `frame` as one stream-json line and queues it for the writer task.
    ///
    /// # Errors
    /// * `DriverError::AgentExited` when input has been finished, the driver has
    ///   been shut down, or the writer task is no longer receiving.
    /// * Any error returned by `encode_client_frame`.
    async fn send(&mut self, frame: ClientFrame) -> Result<(), DriverError> {
        let tx = self.writer_tx.as_ref().ok_or(DriverError::AgentExited)?;
        let line = encode_client_frame(&frame)?;
        // Plain-ASCII direction marker (driver -> CLI); the previous literal was a
        // mis-decoded arrow glyph that rendered as garbage in logs.
        trace!(line = %line, "-> claude");
        tx.send(line).await.map_err(|_| DriverError::AgentExited)
    }

    /// Waits for the next decoded event; `None` once the reader task has ended
    /// and every buffered event has been consumed.
    async fn next_event(&mut self) -> Option<AgentEvent> {
        self.reader_rx.recv().await
    }

    /// Kills the child process (if still held) and waits for it to exit.
    ///
    /// Kill/wait failures are ignored on purpose: shutdown is best-effort and
    /// always returns `Ok(())`. Calling it again is a no-op.
    async fn shutdown(&mut self) -> Result<(), DriverError> {
        // Close the input queue first so a later `send` fails with `AgentExited`
        // instead of accepting a frame for a process that is being killed.
        self.writer_tx = None;
        if let Some(mut child) = self.child.take() {
            let _ = child.start_kill();
            let _ = child.wait().await;
        }
        Ok(())
    }
}
async fn writer_task(
mut stdin: tokio::process::ChildStdin,
mut rx: mpsc::Receiver<String>,
) {
while let Some(line) = rx.recv().await {
if let Err(e) = stdin.write_all(line.as_bytes()).await {
warn!(error = %e, "writer task: write failed, exiting");
return;
}
if !line.ends_with('\n') {
let _ = stdin.write_all(b"\n").await;
}
let _ = stdin.flush().await;
}
debug!("writer task: input channel closed, exiting");
}
/// Reads the child's stdout line by line, decodes each line as JSON and
/// forwards the resulting events to `tx`.
///
/// Blank lines are skipped and lines that are not valid JSON are logged and
/// skipped. The task ends on stdout EOF, on a read error, or when the receiving
/// side of `tx` has been dropped.
async fn reader_task(stdout: tokio::process::ChildStdout, tx: mpsc::Sender<AgentEvent>) {
    let mut lines = BufReader::new(stdout).lines();
    loop {
        let line = match lines.next_line().await {
            Ok(Some(line)) => line,
            Ok(None) => {
                debug!("reader: stdout EOF");
                return;
            }
            Err(e) => {
                warn!(error = %e, "reader: read error");
                return;
            }
        };

        // An empty line carries no frame; skipping it avoids a spurious
        // "malformed JSON" warning.
        if line.trim().is_empty() {
            continue;
        }

        // Plain-ASCII direction marker (CLI -> driver); the previous literal was a
        // mis-decoded arrow glyph that rendered as garbage in logs.
        trace!(line = %line, "<- claude");

        let value: Value = match serde_json::from_str(&line) {
            Ok(value) => value,
            Err(e) => {
                warn!(error = %e, raw = %line, "reader: malformed JSON, skipping");
                continue;
            }
        };

        for event in parse_stream_frame(&value) {
            // A closed channel means nobody is listening any more.
            if tx.send(event).await.is_err() {
                return;
            }
        }
    }
}
/// Copies each line of the child's stderr into the debug log under the
/// `cap_rs::stream_json::stderr` target.
///
/// Stops at EOF or at the first read error.
async fn stderr_drain(stderr: tokio::process::ChildStderr) {
    let mut reader = BufReader::new(stderr).lines();
    loop {
        match reader.next_line().await {
            Ok(Some(text)) => {
                debug!(target: "cap_rs::stream_json::stderr", "{}", text);
            }
            Ok(None) | Err(_) => break,
        }
    }
}
fn encode_client_frame(frame: &ClientFrame) -> Result<String, DriverError> {
match frame {
ClientFrame::Prompt { content } => {
let parts: Vec<Value> = content
.iter()
.map(|c| match c {
Content::Text(t) => json!({"type": "text", "text": t}),
Content::Image { mime, data } => json!({
"type": "image",
"source": {
"type": "base64",
"media_type": mime,
"data": base64_encode(data),
}
}),
})
.collect();
let frame_json = json!({
"type": "user",
"message": {
"role": "user",
"content": parts
}
});
Ok(frame_json.to_string())
}
ClientFrame::Cancel => {
Ok(json!({"type": "control", "subtype": "cancel"}).to_string())
}
ClientFrame::AskUserAnswer { ask_id, value } => {
let text = format!("[answer to {ask_id}]: {value}");
Ok(json!({
"type": "user",
"message": {
"role": "user",
"content": [{"type": "text", "text": text}]
}
})
.to_string())
}
ClientFrame::PermissionResponse { req_id, decision } => {
let text = format!("[permission {req_id}]: {decision:?}");
Ok(json!({
"type": "user",
"message": {
"role": "user",
"content": [{"type": "text", "text": text}]
}
})
.to_string())
}
}
}
fn parse_stream_frame(frame: &Value) -> Vec<AgentEvent> {
let kind = frame.get("type").and_then(Value::as_str).unwrap_or("");
match kind {
"system" => match frame.get("subtype").and_then(Value::as_str).unwrap_or("") {
"init" => vec![AgentEvent::Ready {
session_id: frame
.get("session_id")
.and_then(Value::as_str)
.unwrap_or_default()
.to_string(),
model: frame
.get("model")
.and_then(Value::as_str)
.map(String::from),
}],
_ => vec![],
},
"assistant" => {
let msg = frame.get("message").cloned().unwrap_or(Value::Null);
let msg_id = msg
.get("id")
.and_then(Value::as_str)
.unwrap_or_default()
.to_string();
let content = msg
.get("content")
.and_then(Value::as_array)
.cloned()
.unwrap_or_default();
let mut events = Vec::new();
for block in content {
let btype = block.get("type").and_then(Value::as_str).unwrap_or("");
match btype {
"text" => {
let text = block
.get("text")
.and_then(Value::as_str)
.unwrap_or_default()
.to_string();
if !text.is_empty() {
events.push(AgentEvent::TextChunk {
msg_id: msg_id.clone(),
text,
channel: TextChannel::Assistant,
});
}
}
"thinking" => {
let text = block
.get("thinking")
.and_then(Value::as_str)
.or_else(|| block.get("text").and_then(Value::as_str))
.unwrap_or_default()
.to_string();
if !text.is_empty() {
events.push(AgentEvent::Thought {
msg_id: msg_id.clone(),
text,
});
}
}
"tool_use" => {
events.push(AgentEvent::ToolCallStart {
call_id: block
.get("id")
.and_then(Value::as_str)
.unwrap_or_default()
.to_string(),
name: block
.get("name")
.and_then(Value::as_str)
.unwrap_or_default()
.to_string(),
input: block.get("input").cloned().unwrap_or(Value::Null),
});
}
_ => {
trace!(block_type = btype, "ignoring unknown assistant block");
}
}
}
events
}
"user" => {
let content = frame
.get("message")
.and_then(|m| m.get("content"))
.and_then(Value::as_array)
.cloned()
.unwrap_or_default();
let mut events = Vec::new();
for block in content {
if block.get("type").and_then(Value::as_str) == Some("tool_result") {
let call_id = block
.get("tool_use_id")
.and_then(Value::as_str)
.unwrap_or_default()
.to_string();
let output = extract_tool_result_output(&block);
let is_error = block
.get("is_error")
.and_then(Value::as_bool)
.unwrap_or(false);
events.push(AgentEvent::ToolCallEnd {
call_id,
output,
is_error,
});
}
}
events
}
"result" => {
let usage = parse_usage(frame);
let stop_reason = match frame.get("subtype").and_then(Value::as_str) {
Some("success") => StopReason::EndTurn,
Some("error_max_turns") => StopReason::MaxTokens,
Some("error_during_execution") => StopReason::Error,
Some(other) if other.starts_with("error") => StopReason::Error,
_ => StopReason::EndTurn,
};
vec![AgentEvent::Done { stop_reason, usage }]
}
"stream_event" => {
vec![]
}
other => {
trace!(frame_type = other, "ignoring unknown stream-json frame");
vec![]
}
}
}
/// Flattens a `tool_result` block's `content` into plain text.
///
/// A string is returned as-is. For an array, the string `text` field of each
/// element that has one is collected and the pieces are joined with `\n`.
/// Any other shape (or a missing `content`) yields an empty string.
fn extract_tool_result_output(block: &Value) -> String {
    let content = match block.get("content") {
        Some(content) => content,
        None => return String::new(),
    };
    if let Some(text) = content.as_str() {
        return text.to_owned();
    }

    let mut output = String::new();
    if let Some(parts) = content.as_array() {
        let texts = parts
            .iter()
            .filter_map(|part| part.get("text").and_then(Value::as_str));
        for (idx, text) in texts.enumerate() {
            if idx > 0 {
                output.push('\n');
            }
            output.push_str(text);
        }
    }
    output
}
/// Extracts accounting data from a `result` frame.
///
/// Token counters are read from the nested `usage` object and default to `0`
/// when absent or not unsigned integers. Cost (`total_cost_usd`), duration
/// (`duration_ms`) and model id (the first key of `modelUsage`) are read from
/// the frame itself and stay `None` when missing.
fn parse_usage(frame: &Value) -> Usage {
    let usage = frame.get("usage");
    let counter = |key: &str| -> u64 {
        usage
            .and_then(|u| u.get(key))
            .and_then(Value::as_u64)
            .unwrap_or(0)
    };

    let duration = frame
        .get("duration_ms")
        .and_then(Value::as_u64)
        .map(std::time::Duration::from_millis);
    let model_id = frame
        .get("modelUsage")
        .and_then(Value::as_object)
        .and_then(|models| models.keys().next().cloned());
    let cost_usd_estimate = frame.get("total_cost_usd").and_then(Value::as_f64);

    Usage {
        input_tokens: counter("input_tokens"),
        output_tokens: counter("output_tokens"),
        cache_read_tokens: counter("cache_read_input_tokens"),
        cache_creation_tokens: counter("cache_creation_input_tokens"),
        cost_usd_estimate,
        duration,
        model_id,
    }
}
/// Encodes `data` with the standard base64 alphabet (`A-Z a-z 0-9 + /`),
/// padding the output with `=` to a multiple of four characters.
fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    // Picks the 6-bit group that starts `shift` bits from the bottom of `word`.
    let sextet = |word: u32, shift: u32| ALPHABET[((word >> shift) & 0x3f) as usize] as char;

    let mut encoded = String::with_capacity(((data.len() + 2) / 3) * 4);
    for group in data.chunks(3) {
        // Pack up to three bytes into the top 24 bits; absent bytes count as zero.
        let first = u32::from(group[0]);
        let second = group.get(1).copied().map_or(0, u32::from);
        let third = group.get(2).copied().map_or(0, u32::from);
        let word = (first << 16) | (second << 8) | third;

        encoded.push(sextet(word, 18));
        encoded.push(sextet(word, 12));
        match group.len() {
            1 => encoded.push_str("=="),
            2 => {
                encoded.push(sextet(word, 6));
                encoded.push('=');
            }
            _ => {
                encoded.push(sextet(word, 6));
                encoded.push(sextet(word, 0));
            }
        }
    }
    encoded
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes `raw` as JSON and runs it through `parse_stream_frame`.
    fn events_for(raw: &str) -> Vec<AgentEvent> {
        let frame: Value = serde_json::from_str(raw).unwrap();
        parse_stream_frame(&frame)
    }

    /// A `system`/`init` frame yields exactly one `Ready` event.
    #[test]
    fn parse_init_frame() {
        let events = events_for(
            r#"{"type":"system","subtype":"init","session_id":"sess_1","model":"claude-opus-4-7"}"#,
        );
        assert_eq!(events.len(), 1);
        let first = &events[0];
        assert!(matches!(first, AgentEvent::Ready { .. }));
    }

    /// A single assistant text block becomes one `TextChunk` with that text.
    #[test]
    fn parse_assistant_text() {
        let events = events_for(
            r#"{"type":"assistant","message":{"id":"msg_1","content":[{"type":"text","text":"hello"}]}}"#,
        );
        assert_eq!(events.len(), 1);
        if let AgentEvent::TextChunk { text, .. } = &events[0] {
            assert_eq!(text, "hello");
        } else {
            let other = &events[0];
            panic!("wrong variant: {other:?}");
        }
    }

    /// A `tool_use` block becomes a `ToolCallStart` carrying the tool name.
    #[test]
    fn parse_tool_use() {
        let events = events_for(
            r#"{"type":"assistant","message":{"id":"m","content":[
{"type":"tool_use","id":"t1","name":"Bash","input":{"command":"ls"}}
]}}"#,
        );
        if let AgentEvent::ToolCallStart { name, .. } = &events[0] {
            assert_eq!(name, "Bash");
        } else {
            let other = &events[0];
            panic!("wrong: {other:?}");
        }
    }

    /// A successful `result` frame maps to `Done` with `EndTurn` and the usage figures.
    #[test]
    fn parse_result_with_usage() {
        let events = events_for(
            r#"{"type":"result","subtype":"success","duration_ms":1500,"total_cost_usd":0.0021,
"usage":{"input_tokens":10,"output_tokens":20,"cache_read_input_tokens":0,"cache_creation_input_tokens":0}}"#,
        );
        if let AgentEvent::Done { usage, stop_reason } = &events[0] {
            assert_eq!(*stop_reason, StopReason::EndTurn);
            assert_eq!(usage.input_tokens, 10);
            assert_eq!(usage.output_tokens, 20);
            assert_eq!(usage.cost_usd_estimate, Some(0.0021));
        } else {
            let other = &events[0];
            panic!("wrong: {other:?}");
        }
    }

    /// A text-only prompt encodes as a `user` frame whose first content part holds the text.
    #[test]
    fn encode_simple_prompt() {
        let prompt = ClientFrame::Prompt {
            content: vec![Content::Text("hi".into())],
        };
        let encoded = encode_client_frame(&prompt).unwrap();
        let decoded: Value = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded["type"], "user");
        assert_eq!(decoded["message"]["content"][0]["text"], "hi");
    }
}