1use std::path::PathBuf;
2
3use async_trait::async_trait;
4
5use crate::config::{PermissionMode, TaskConfig};
6use crate::error::{Error, Result};
7use crate::event::*;
8use crate::process::{spawn_and_stream, StreamHandle};
9use crate::runner::AgentRunner;
10
11pub struct ClaudeRunner;
22
23#[async_trait]
24impl AgentRunner for ClaudeRunner {
25 fn name(&self) -> &str {
26 "claude"
27 }
28
29 fn is_available(&self) -> bool {
30 crate::runner::is_any_binary_available(crate::config::AgentKind::Claude)
31 }
32
33 fn binary_path(&self, config: &TaskConfig) -> Result<PathBuf> {
34 crate::runner::resolve_binary(crate::config::AgentKind::Claude, config)
35 }
36
37 fn build_args(&self, config: &TaskConfig) -> Vec<String> {
38 let mut args = vec![
39 "-p".to_string(),
40 config.prompt.clone(),
41 "--output-format".to_string(),
42 "stream-json".to_string(),
43 "--verbose".to_string(),
44 ];
45
46 if let Some(ref model) = config.model {
47 args.push("--model".to_string());
48 args.push(model.clone());
49 }
50
51 match config.permission_mode {
52 PermissionMode::FullAccess => {
53 args.push("--dangerously-skip-permissions".to_string());
54 }
55 PermissionMode::ReadOnly => {
56 args.push("--permission-mode".to_string());
57 args.push("plan".to_string());
58 }
59 }
60
61 if let Some(turns) = config.max_turns {
62 args.push("--max-turns".to_string());
63 args.push(turns.to_string());
64 }
65
66 if let Some(budget) = config.max_budget_usd {
67 args.push("--max-budget-usd".to_string());
68 args.push(budget.to_string());
69 }
70
71 if let Some(ref sp) = config.system_prompt {
72 args.push("--system-prompt".to_string());
73 args.push(sp.clone());
74 }
75
76 if let Some(ref asp) = config.append_system_prompt {
77 args.push("--append-system-prompt".to_string());
78 args.push(asp.clone());
79 }
80
81 args.extend(config.extra_args.iter().cloned());
82 args
83 }
84
85 fn build_env(&self, _config: &TaskConfig) -> Vec<(String, String)> {
86 vec![]
88 }
89
90 async fn run(
91 &self,
92 config: &TaskConfig,
93 cancel_token: Option<tokio_util::sync::CancellationToken>,
94 ) -> Result<StreamHandle> {
95 spawn_and_stream(self, config, parse_claude_line, cancel_token).await
96 }
97
98 fn capabilities(&self) -> crate::runner::AgentCapabilities {
99 crate::runner::AgentCapabilities {
100 supports_system_prompt: true,
101 supports_budget: true,
102 supports_model: true,
103 supports_max_turns: true,
104 supports_append_system_prompt: true,
105 }
106 }
107}
108
109fn parse_claude_line(line: &str) -> Vec<Result<Event>> {
110 let value: serde_json::Value = match serde_json::from_str(line) {
111 Ok(v) => v,
112 Err(e) => return vec![Err(Error::ParseError(format!("invalid JSON: {e}: {line}")))],
113 };
114
115 let event_type = match value.get("type").and_then(|v| v.as_str()) {
116 Some(t) => t,
117 None => return vec![],
118 };
119
120 match event_type {
121 "system" => {
122 let subtype = value.get("subtype").and_then(|v| v.as_str()).unwrap_or("");
123 if subtype == "init" {
124 vec![Ok(Event::SessionStart(SessionStartEvent {
125 session_id: value
126 .get("session_id")
127 .and_then(|v| v.as_str())
128 .unwrap_or("")
129 .to_string(),
130 agent: "claude".to_string(),
131 model: value
132 .get("model")
133 .and_then(|v| v.as_str())
134 .map(|s| s.to_string()),
135 cwd: value
136 .get("cwd")
137 .and_then(|v| v.as_str())
138 .map(|s| s.to_string()),
139 timestamp_ms: 0,
140 }))]
141 } else {
142 vec![]
143 }
144 }
145
146 "assistant" => {
147 let mut events = Vec::new();
150
151 let content = value.pointer("/message/content").and_then(|v| v.as_array());
152 if let Some(blocks) = content {
153 let mut text_parts = Vec::new();
154
155 for block in blocks {
156 let block_type = block.get("type").and_then(|v| v.as_str()).unwrap_or("");
157 match block_type {
158 "text" => {
159 if let Some(t) = block.get("text").and_then(|v| v.as_str()) {
160 text_parts.push(t);
161 }
162 }
163 "tool_use" => {
164 let call_id = block
165 .get("id")
166 .and_then(|v| v.as_str())
167 .unwrap_or("")
168 .to_string();
169 let tool_name = block
170 .get("name")
171 .and_then(|v| v.as_str())
172 .unwrap_or("unknown")
173 .to_string();
174 let input = block.get("input").cloned();
175 events.push(Ok(Event::ToolStart(ToolStartEvent {
176 call_id,
177 tool_name,
178 input,
179 timestamp_ms: 0,
180 })));
181 }
182 _ => {}
183 }
184 }
185
186 let text = text_parts.join("");
187 if !text.is_empty() {
188 events.insert(
190 0,
191 Ok(Event::Message(MessageEvent {
192 role: Role::Assistant,
193 text,
194 usage: None,
195 timestamp_ms: 0,
196 })),
197 );
198 }
199 }
200
201 events
202 }
203
204 "user" => {
205 let mut events = Vec::new();
207
208 let content = value.pointer("/message/content").and_then(|v| v.as_array());
209 if let Some(blocks) = content {
210 for block in blocks {
211 let block_type = block.get("type").and_then(|v| v.as_str()).unwrap_or("");
212 if block_type == "tool_result" {
213 let call_id = block
214 .get("tool_use_id")
215 .and_then(|v| v.as_str())
216 .unwrap_or("")
217 .to_string();
218 let is_error = block
219 .get("is_error")
220 .and_then(|v| v.as_bool())
221 .unwrap_or(false);
222 let output = block
223 .get("content")
224 .map(|v| {
225 if let Some(s) = v.as_str() {
226 s.to_string()
227 } else if let Some(arr) = v.as_array() {
228 arr.iter()
229 .filter_map(|item| {
230 item.get("text").and_then(|t| t.as_str())
231 })
232 .collect::<Vec<_>>()
233 .join("")
234 } else {
235 v.to_string()
236 }
237 });
238 events.push(Ok(Event::ToolEnd(ToolEndEvent {
239 call_id,
240 tool_name: "unknown".to_string(),
241 success: !is_error,
242 output,
243 usage: None,
244 timestamp_ms: 0,
245 })));
246 }
247 }
248 }
249
250 events
251 }
252
253 "stream_event" => {
254 let mut events = Vec::new();
255
256 let delta_text = value
258 .pointer("/event/delta/text")
259 .and_then(|v| v.as_str())
260 .unwrap_or("");
261 if !delta_text.is_empty() {
262 events.push(Ok(Event::TextDelta(TextDeltaEvent {
263 text: delta_text.to_string(),
264 timestamp_ms: 0,
265 })));
266 }
267
268 if let Some(usage_val) = value.pointer("/event/usage").or_else(|| value.get("usage")) {
270 let usage = parse_usage_data(usage_val);
271 if usage.input_tokens.is_some() || usage.output_tokens.is_some() || usage.cost_usd.is_some() {
272 events.push(Ok(Event::UsageDelta(UsageDeltaEvent {
273 usage,
274 timestamp_ms: 0,
275 })));
276 }
277 }
278
279 events
280 }
281
282 "result" => {
283 let subtype = value
284 .get("subtype")
285 .and_then(|v| v.as_str())
286 .unwrap_or("success");
287 let success = subtype == "success";
288 let result_text = value
289 .get("result")
290 .and_then(|v| v.as_str())
291 .unwrap_or("")
292 .to_string();
293 let session_id = value
294 .get("session_id")
295 .and_then(|v| v.as_str())
296 .unwrap_or("")
297 .to_string();
298 let duration_ms = value.get("duration_ms").and_then(|v| v.as_u64());
299 let total_cost_usd = value.get("total_cost_usd").and_then(|v| v.as_f64());
300
301 let usage = value.get("usage").map(parse_usage_data);
302
303 vec![Ok(Event::Result(ResultEvent {
304 success,
305 text: result_text,
306 session_id,
307 duration_ms,
308 total_cost_usd,
309 usage,
310 timestamp_ms: 0,
311 }))]
312 }
313
314 _ => vec![],
315 }
316}
317
318fn parse_usage_data(value: &serde_json::Value) -> UsageData {
319 UsageData {
320 input_tokens: value.get("input_tokens").and_then(|v| v.as_u64()),
321 output_tokens: value.get("output_tokens").and_then(|v| v.as_u64()),
322 cache_read_tokens: value
323 .get("cache_read_input_tokens")
324 .or_else(|| value.get("cache_read_tokens"))
325 .and_then(|v| v.as_u64()),
326 cache_creation_tokens: value
327 .get("cache_creation_input_tokens")
328 .or_else(|| value.get("cache_creation_tokens"))
329 .and_then(|v| v.as_u64()),
330 cost_usd: value
331 .get("cost_usd")
332 .or_else(|| value.get("cost"))
333 .and_then(|v| v.as_f64()),
334 }
335}
336
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the argument immediately following `flag`, if `flag` is present.
    fn flag_value<'a>(args: &'a [String], flag: &str) -> Option<&'a str> {
        args.iter()
            .position(|a| a == flag)
            .and_then(|i| args.get(i + 1))
            .map(String::as_str)
    }

    #[test]
    fn parse_init_event() {
        let line = r#"{"type":"system","subtype":"init","session_id":"abc-123","model":"opus","cwd":"/tmp"}"#;
        let events = parse_claude_line(line);
        assert_eq!(events.len(), 1);
        let event = events.into_iter().next().unwrap().unwrap();
        match event {
            Event::SessionStart(s) => {
                assert_eq!(s.session_id, "abc-123");
                assert_eq!(s.agent, "claude");
                assert_eq!(s.model, Some("opus".into()));
                assert_eq!(s.cwd, Some("/tmp".into()));
            }
            other => panic!("expected SessionStart, got {other:?}"),
        }
    }

    #[test]
    fn parse_system_non_init_is_ignored() {
        let line = r#"{"type":"system","subtype":"something_else","session_id":"abc"}"#;
        assert!(parse_claude_line(line).is_empty());
    }

    #[test]
    fn parse_unknown_or_untyped_lines_are_ignored() {
        assert!(parse_claude_line(r#"{"type":"brand_new_event"}"#).is_empty());
        assert!(parse_claude_line(r#"{"no_type":true}"#).is_empty());
    }

    #[test]
    fn parse_invalid_json_is_error() {
        let events = parse_claude_line("this is not json");
        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], Err(Error::ParseError(_))));
    }

    #[test]
    fn parse_assistant_message() {
        let line = r#"{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"Hello world"}]}}"#;
        let events = parse_claude_line(line);
        assert_eq!(events.len(), 1);
        let event = events.into_iter().next().unwrap().unwrap();
        match event {
            Event::Message(m) => {
                assert_eq!(m.role, Role::Assistant);
                assert_eq!(m.text, "Hello world");
            }
            other => panic!("expected Message, got {other:?}"),
        }
    }

    #[test]
    fn parse_assistant_with_tool_use() {
        let line = r#"{"type":"assistant","message":{"role":"assistant","content":[{"type":"text","text":"Let me check"},{"type":"tool_use","id":"tu-1","name":"bash","input":{"command":"ls"}}]}}"#;
        let events = parse_claude_line(line);
        assert_eq!(events.len(), 2);
        assert!(matches!(&events[0], Ok(Event::Message(m)) if m.text == "Let me check"));
        assert!(matches!(&events[1], Ok(Event::ToolStart(t)) if t.tool_name == "bash" && t.call_id == "tu-1"));
    }

    #[test]
    fn parse_assistant_tool_use_only() {
        let line = r#"{"type":"assistant","message":{"role":"assistant","content":[{"type":"tool_use","id":"tu-9","name":"read","input":{"path":"a.rs"}}]}}"#;
        let events = parse_claude_line(line);
        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], Ok(Event::ToolStart(t)) if t.call_id == "tu-9" && t.input.is_some()));
    }

    #[test]
    fn parse_user_tool_result() {
        let line = r#"{"type":"user","message":{"role":"user","content":[{"type":"tool_result","tool_use_id":"tu-1","content":"file.txt\nREADME.md"}]}}"#;
        let events = parse_claude_line(line);
        assert_eq!(events.len(), 1);
        match events.into_iter().next().unwrap().unwrap() {
            Event::ToolEnd(t) => {
                assert_eq!(t.call_id, "tu-1");
                assert!(t.success);
                assert_eq!(t.output, Some("file.txt\nREADME.md".into()));
            }
            other => panic!("expected ToolEnd, got {other:?}"),
        }
    }

    #[test]
    fn parse_user_tool_result_array_content() {
        let line = r#"{"type":"user","message":{"role":"user","content":[{"type":"tool_result","tool_use_id":"tu-3","content":[{"type":"text","text":"line 1\n"},{"type":"text","text":"line 2"}]}]}}"#;
        let events = parse_claude_line(line);
        assert_eq!(events.len(), 1);
        match events.into_iter().next().unwrap().unwrap() {
            Event::ToolEnd(t) => {
                assert_eq!(t.call_id, "tu-3");
                assert_eq!(t.output, Some("line 1\nline 2".into()));
            }
            other => panic!("expected ToolEnd, got {other:?}"),
        }
    }

    #[test]
    fn parse_user_tool_result_error() {
        let line = r#"{"type":"user","message":{"role":"user","content":[{"type":"tool_result","tool_use_id":"tu-2","is_error":true,"content":"command not found"}]}}"#;
        let events = parse_claude_line(line);
        assert_eq!(events.len(), 1);
        match events.into_iter().next().unwrap().unwrap() {
            Event::ToolEnd(t) => {
                assert_eq!(t.call_id, "tu-2");
                assert!(!t.success);
            }
            other => panic!("expected ToolEnd, got {other:?}"),
        }
    }

    #[test]
    fn parse_stream_delta() {
        let line = r#"{"type":"stream_event","event":{"delta":{"type":"text_delta","text":"Hi"}}}"#;
        let events = parse_claude_line(line);
        assert_eq!(events.len(), 1);
        let event = events.into_iter().next().unwrap().unwrap();
        match event {
            Event::TextDelta(d) => assert_eq!(d.text, "Hi"),
            other => panic!("expected TextDelta, got {other:?}"),
        }
    }

    #[test]
    fn parse_stream_usage_delta() {
        let line = r#"{"type":"stream_event","event":{"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":42}}}"#;
        let events = parse_claude_line(line);
        assert_eq!(events.len(), 1);
        match events.into_iter().next().unwrap().unwrap() {
            Event::UsageDelta(u) => {
                assert_eq!(u.usage.output_tokens, Some(42));
                assert_eq!(u.usage.input_tokens, None);
            }
            other => panic!("expected UsageDelta, got {other:?}"),
        }
    }

    #[test]
    fn parse_result_success() {
        let line = r#"{"type":"result","subtype":"success","result":"Done","session_id":"s1","duration_ms":1234,"total_cost_usd":0.05}"#;
        let events = parse_claude_line(line);
        assert_eq!(events.len(), 1);
        let event = events.into_iter().next().unwrap().unwrap();
        match event {
            Event::Result(r) => {
                assert!(r.success);
                assert_eq!(r.text, "Done");
                assert_eq!(r.session_id, "s1");
                assert_eq!(r.duration_ms, Some(1234));
                assert_eq!(r.total_cost_usd, Some(0.05));
            }
            other => panic!("expected Result, got {other:?}"),
        }
    }

    #[test]
    fn parse_result_usage() {
        let line = r#"{"type":"result","subtype":"success","result":"ok","session_id":"s2","usage":{"input_tokens":10,"output_tokens":5,"cache_read_input_tokens":3}}"#;
        let events = parse_claude_line(line);
        assert_eq!(events.len(), 1);
        match events.into_iter().next().unwrap().unwrap() {
            Event::Result(r) => {
                let usage = r.usage.expect("usage should be parsed");
                assert_eq!(usage.input_tokens, Some(10));
                assert_eq!(usage.output_tokens, Some(5));
                assert_eq!(usage.cache_read_tokens, Some(3));
                assert_eq!(usage.cache_creation_tokens, None);
            }
            other => panic!("expected Result, got {other:?}"),
        }
    }

    #[test]
    fn parse_result_error() {
        let line =
            r#"{"type":"result","subtype":"error_max_turns","result":"","session_id":"s1"}"#;
        let events = parse_claude_line(line);
        assert_eq!(events.len(), 1);
        match events.into_iter().next().unwrap().unwrap() {
            Event::Result(r) => assert!(!r.success),
            other => panic!("expected Result, got {other:?}"),
        }
    }

    #[test]
    fn build_args_defaults() {
        let config = TaskConfig::new("fix the bug", crate::config::AgentKind::Claude);
        let runner = ClaudeRunner;
        let args = runner.build_args(&config);
        assert!(args.contains(&"-p".to_string()));
        assert!(args.contains(&"fix the bug".to_string()));
        assert!(args.contains(&"stream-json".to_string()));
    }

    #[test]
    fn build_args_with_model_and_full_access() {
        let mut config = TaskConfig::new("do it", crate::config::AgentKind::Claude);
        config.model = Some("opus".into());
        config.max_turns = Some(10);

        let runner = ClaudeRunner;
        let args = runner.build_args(&config);
        assert!(args.contains(&"--model".to_string()));
        assert!(args.contains(&"opus".to_string()));
        assert!(args.contains(&"--dangerously-skip-permissions".to_string()));
        assert!(args.contains(&"--max-turns".to_string()));
        assert!(args.contains(&"10".to_string()));
    }

    #[test]
    fn build_args_read_only_uses_plan_mode() {
        let mut config = TaskConfig::new("look around", crate::config::AgentKind::Claude);
        config.permission_mode = PermissionMode::ReadOnly;

        let args = ClaudeRunner.build_args(&config);
        assert_eq!(flag_value(&args, "--permission-mode"), Some("plan"));
        assert!(!args.contains(&"--dangerously-skip-permissions".to_string()));
    }

    #[test]
    fn build_args_optional_flags() {
        let mut config = TaskConfig::new("task", crate::config::AgentKind::Claude);
        config.max_budget_usd = Some(1.5);
        config.system_prompt = Some("be terse".into());
        config.append_system_prompt = Some("cite files".into());

        let args = ClaudeRunner.build_args(&config);
        assert_eq!(flag_value(&args, "--max-budget-usd"), Some("1.5"));
        assert_eq!(flag_value(&args, "--system-prompt"), Some("be terse"));
        assert_eq!(flag_value(&args, "--append-system-prompt"), Some("cite files"));
    }
}
477}