1use std::io::Write;
2use std::process::Stdio;
3
4use crate::error::provider_error::ApiSnafu;
5use crate::error::ProviderError;
6use crate::provider::{
7 AuthStatus, CompletionRequest, CompletionResponse, ContentBlock, LlmProvider, StopReason,
8 TokenUsage,
9};
10
/// LLM provider that shells out to the locally installed Claude Code CLI (`claude`).
///
/// Each completion spawns `claude --print`, writes a flattened text prompt to its stdin,
/// and parses tool calls out of the plain-text reply.
#[derive(Debug, Clone, Default)]
pub struct ClaudeCodeProvider {
    /// Model forwarded to the CLI as `--model <model>`; `None` omits the flag so the CLI
    /// picks its own default.
    model: Option<String>,
}
18
19impl ClaudeCodeProvider {
20 pub fn new(model: Option<String>) -> Self {
21 Self { model }
22 }
23}
24
25#[derive(Debug, Clone)]
27struct ExtractedToolCall {
28 name: String,
29 input: serde_json::Value,
30 start: usize,
32 end: usize,
34}
35
36fn extract_tool_calls(text: &str) -> Vec<ExtractedToolCall> {
41 let mut results = Vec::new();
42 let bytes = text.as_bytes();
43 let mut pos = 0;
44
45 while pos < bytes.len() {
46 if bytes[pos] != b'{' {
47 pos += 1;
48 continue;
49 }
50
51 let slice = &text[pos..];
52 let mut de = serde_json::Deserializer::from_str(slice).into_iter::<serde_json::Value>();
53
54 if let Some(Ok(value)) = de.next() {
55 let consumed = de.byte_offset();
56 if let Some(obj) = value.as_object() {
57 if obj.contains_key("tool") && obj.contains_key("input") {
58 if let (Some(name), Some(input)) =
59 (obj.get("tool").and_then(|v| v.as_str()), obj.get("input"))
60 {
61 results.push(ExtractedToolCall {
62 name: name.to_string(),
63 input: input.clone(),
64 start: pos,
65 end: pos + consumed,
66 });
67 pos += consumed;
69 continue;
70 }
71 }
72 }
73 pos += 1;
75 } else {
76 pos += 1;
77 }
78 }
79
80 results
81}
82
83fn first_batch(calls: &[ExtractedToolCall], text: &str) -> Vec<ExtractedToolCall> {
90 if calls.is_empty() {
91 return Vec::new();
92 }
93
94 let mut batch = vec![calls[0].clone()];
95
96 for window in calls.windows(2) {
97 let prev = &window[0];
98 let next = &window[1];
99
100 let gap = &text[prev.end..next.start];
101 let non_ws = gap.chars().filter(|c| !c.is_whitespace()).count();
102
103 if non_ws > 40 {
104 break;
105 }
106 batch.push(next.clone());
107 }
108
109 batch
110}
111
112fn build_content_blocks(
115 text: &str,
116 batch: &[ExtractedToolCall],
117 counter: &mut u32,
118) -> (Vec<ContentBlock>, StopReason) {
119 if batch.is_empty() {
120 return (
121 vec![ContentBlock::Text {
122 text: text.to_string(),
123 }],
124 StopReason::EndTurn,
125 );
126 }
127
128 let mut blocks = Vec::new();
129
130 let leading = text[..batch[0].start].trim();
132 if !leading.is_empty() {
133 blocks.push(ContentBlock::Text {
134 text: leading.to_string(),
135 });
136 }
137
138 for call in batch {
139 *counter += 1;
140 blocks.push(ContentBlock::ToolUse {
141 id: format!("toolu_cc_{counter}"),
142 name: call.name.clone(),
143 input: call.input.clone(),
144 });
145 }
146
147 (blocks, StopReason::ToolUse)
148}
149
150fn build_prompt(request: &CompletionRequest) -> String {
152 let mut prompt = String::new();
153
154 if !request.system.is_empty() {
155 prompt.push_str("System: ");
156 prompt.push_str(&request.system);
157 prompt.push_str("\n\n");
158 }
159
160 if !request.tools.is_empty() {
162 prompt.push_str("Available tools:\n");
163 for tool in &request.tools {
164 prompt.push_str(&format!("- {}: {}\n", tool.name, tool.description));
165 prompt.push_str(&format!(
166 " Input schema: {}\n",
167 serde_json::to_string(&tool.input_schema).unwrap_or_default()
168 ));
169 }
170 prompt.push_str(
171 "\nTo use a tool, output a JSON block with {\"tool\": \"name\", \"input\": {...}}\n\n",
172 );
173 }
174
175 for msg in &request.messages {
177 let role = match msg.role {
178 crate::provider::Role::User => "User",
179 crate::provider::Role::Assistant => "Assistant",
180 };
181 for block in &msg.content {
182 match block {
183 ContentBlock::Text { text } => {
184 prompt.push_str(&format!("{role}: {text}\n\n"));
185 }
186 ContentBlock::ToolUse { name, input, .. } => {
187 prompt.push_str(&format!(
188 "{role}: [tool_use: {} {}]\n\n",
189 name,
190 serde_json::to_string(input).unwrap_or_default()
191 ));
192 }
193 ContentBlock::ToolResult {
194 content, is_error, ..
195 } => {
196 let prefix = if *is_error == Some(true) {
197 "Error"
198 } else {
199 "Result"
200 };
201 prompt.push_str(&format!("{role}: [tool_result: {prefix}] {content}\n\n"));
202 }
203 }
204 }
205 }
206
207 prompt
208}
209
210impl LlmProvider for ClaudeCodeProvider {
211 fn complete(&self, request: &CompletionRequest) -> Result<CompletionResponse, ProviderError> {
212 let prompt = build_prompt(request);
213
214 let mut cmd = std::process::Command::new("claude");
216 cmd.arg("--print");
217
218 if let Some(ref model) = self.model {
219 cmd.arg("--model");
220 cmd.arg(model);
221 }
222
223 let mut child = cmd
224 .stdin(Stdio::piped())
225 .stdout(Stdio::piped())
226 .stderr(Stdio::piped())
227 .spawn()
228 .map_err(|e| {
229 if e.kind() == std::io::ErrorKind::NotFound {
230 ProviderError::Api {
231 message: "Claude CLI not found. Install Claude Code or run 'git chronicle reconfigure' to select a different provider.".to_string(),
232 location: snafu::Location::default(),
233 }
234 } else {
235 ProviderError::Api {
236 message: format!("Failed to spawn claude CLI: {e}"),
237 location: snafu::Location::default(),
238 }
239 }
240 })?;
241
242 if let Some(mut stdin) = child.stdin.take() {
244 stdin
245 .write_all(prompt.as_bytes())
246 .map_err(|e| ProviderError::Api {
247 message: format!("Failed to write to claude CLI stdin: {e}"),
248 location: snafu::Location::default(),
249 })?;
250 }
251
252 let output = child.wait_with_output().map_err(|e| ProviderError::Api {
253 message: format!("Failed to wait for claude CLI: {e}"),
254 location: snafu::Location::default(),
255 })?;
256
257 if !output.status.success() {
258 let stderr = String::from_utf8_lossy(&output.stderr);
259 return ApiSnafu {
260 message: format!("claude CLI failed: {stderr}"),
261 }
262 .fail();
263 }
264
265 let response_text = String::from_utf8_lossy(&output.stdout).to_string();
266
267 let all_calls = extract_tool_calls(&response_text);
269 let batch = first_batch(&all_calls, &response_text);
270 let mut counter = 0u32;
271 let (content, stop_reason) = build_content_blocks(&response_text, &batch, &mut counter);
272
273 Ok(CompletionResponse {
274 content,
275 stop_reason,
276 usage: TokenUsage::default(),
277 })
278 }
279
280 fn check_auth(&self) -> Result<AuthStatus, ProviderError> {
281 match std::process::Command::new("claude")
282 .arg("--version")
283 .output()
284 {
285 Ok(output) if output.status.success() => Ok(AuthStatus::Valid),
286 Ok(_) => Ok(AuthStatus::Invalid(
287 "claude CLI returned non-zero exit code".to_string(),
288 )),
289 Err(e) => Ok(AuthStatus::Invalid(format!("claude CLI not found: {e}"))),
290 }
291 }
292
293 fn name(&self) -> &str {
294 "claude-code"
295 }
296
297 fn model(&self) -> &str {
298 self.model
299 .as_deref()
300 .unwrap_or("claude-sonnet-4-5-20250929")
301 }
302}
303
#[cfg(test)]
mod tests {
    // Unit tests for the text-protocol helpers: tool-call extraction, batching, content
    // block construction, and prompt assembly, plus the provider's static metadata.
    use super::*;

    #[test]
    fn test_extract_tool_calls_basic() {
        let text = r#"I'll get the diff now.
{"tool": "get_diff", "input": {}}"#;
        let calls = extract_tool_calls(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_diff");
        assert_eq!(calls[0].input, serde_json::json!({}));
    }

    #[test]
    fn test_extract_tool_calls_nested() {
        let text = r#"{"tool": "emit_narrative", "input": {"summary": "Refactored auth", "rejected_alternatives": [{"approach": "JWT", "reason": "overkill"}]}}"#;
        let calls = extract_tool_calls(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "emit_narrative");
        let input = &calls[0].input;
        assert_eq!(input["summary"], "Refactored auth");
        assert_eq!(input["rejected_alternatives"][0]["approach"], "JWT");
    }

    #[test]
    fn test_extract_tool_calls_multiple() {
        let text = r#"Let me gather info.
{"tool": "get_diff", "input": {}}
{"tool": "get_commit_info", "input": {}}"#;
        let calls = extract_tool_calls(text);
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].name, "get_diff");
        assert_eq!(calls[1].name, "get_commit_info");
    }

    #[test]
    fn test_extract_tool_calls_empty_text() {
        assert!(extract_tool_calls("").is_empty());
    }

    #[test]
    fn test_extract_tool_calls_after_non_ascii_text() {
        // Offsets are byte offsets, so multi-byte characters before the call must not
        // break slicing.
        let text = "Voilà — résumé ✓ {\"tool\": \"get_diff\", \"input\": {\"n\": 1}} fin";
        let calls = extract_tool_calls(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].start, text.find('{').unwrap());
        assert_eq!(
            &text[calls[0].start..calls[0].end],
            "{\"tool\": \"get_diff\", \"input\": {\"n\": 1}}"
        );
        assert_eq!(calls[0].input, serde_json::json!({"n": 1}));
    }

    #[test]
    fn test_extract_tool_calls_requires_string_tool_and_input() {
        let text = r#"{"tool": 5, "input": {}} {"tool": "no_input"} {"tool": "ok", "input": null}"#;
        let calls = extract_tool_calls(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "ok");
        assert_eq!(calls[0].input, serde_json::Value::Null);
    }

    #[test]
    fn test_extract_tool_calls_inside_other_json() {
        let text = r#"{"calls": [{"tool": "get_diff", "input": {}}]}"#;
        let calls = extract_tool_calls(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_diff");
    }

    #[test]
    fn test_extract_tool_calls_skips_stray_brace() {
        let text = "Use a `{` literal, then {\"tool\": \"a\", \"input\": 1}";
        let calls = extract_tool_calls(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "a");
        assert_eq!(calls[0].input, serde_json::json!(1));
    }

    #[test]
    fn test_first_batch_stops_at_prose() {
        let text = r#"{"tool": "get_diff", "input": {}}
{"tool": "get_commit_info", "input": {}}

Okay, now I can see the diff shows a refactored authentication module with several important changes to the token validation flow.

{"tool": "emit_narrative", "input": {"summary": "test"}}"#;
        let calls = extract_tool_calls(text);
        assert_eq!(calls.len(), 3);

        let batch = first_batch(&calls, text);
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0].name, "get_diff");
        assert_eq!(batch[1].name, "get_commit_info");
    }

    #[test]
    fn test_first_batch_empty() {
        assert!(first_batch(&[], "").is_empty());
    }

    #[test]
    fn test_first_batch_gap_threshold() {
        let call = r#"{"tool": "a", "input": {}}"#;

        // Exactly 40 non-whitespace characters between calls still counts as one batch.
        let at_limit = format!("{call}\n{}\n{call}", "x".repeat(40));
        let calls = extract_tool_calls(&at_limit);
        assert_eq!(calls.len(), 2);
        assert_eq!(first_batch(&calls, &at_limit).len(), 2);

        // One more character splits the batch.
        let over = format!("{call}\n{}\n{call}", "x".repeat(41));
        let calls = extract_tool_calls(&over);
        assert_eq!(calls.len(), 2);
        assert_eq!(first_batch(&calls, &over).len(), 1);
    }

    #[test]
    fn test_no_tool_calls() {
        let text = "This is just a plain text response with no tool calls at all.";
        let calls = extract_tool_calls(text);
        assert!(calls.is_empty());
    }

    #[test]
    fn test_ignores_non_tool_json() {
        let text = r#"Here is some JSON: {"name": "foo", "value": 42} and more text.
{"tool": "get_diff", "input": {}}"#;
        let calls = extract_tool_calls(text);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "get_diff");
    }

    #[test]
    fn test_realistic_output() {
        let text = r#"I'll analyze this commit. Let me start by getting the diff and commit info.

{"tool": "get_diff", "input": {}}
{"tool": "get_commit_info", "input": {}}

Here's the diff output showing the changes:
```
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,5 +1,10 @@
+use serde::Serialize;
```

And the commit info:
SHA: abc123
Message: Add serialization support

Now I'll emit the narrative and a decision.

{"tool": "emit_narrative", "input": {"summary": "Added serde serialization to core types", "motivation": "Needed for JSON export feature"}}
{"tool": "emit_decision", "input": {"what": "Use serde for serialization", "why": "Industry standard", "stability": "permanent"}}"#;

        let calls = extract_tool_calls(text);
        assert_eq!(calls.len(), 4);

        let batch = first_batch(&calls, text);
        assert_eq!(batch.len(), 2);
        assert_eq!(batch[0].name, "get_diff");
        assert_eq!(batch[1].name, "get_commit_info");
    }

    #[test]
    fn test_build_content_blocks_no_calls() {
        let text = "Just a plain response.";
        let batch = Vec::new();
        let mut counter = 0;
        let (blocks, reason) = build_content_blocks(text, &batch, &mut counter);
        assert_eq!(blocks.len(), 1);
        assert!(
            matches!(&blocks[0], ContentBlock::Text { text } if text == "Just a plain response.")
        );
        assert_eq!(reason, StopReason::EndTurn);
    }

    #[test]
    fn test_build_content_blocks_with_calls() {
        let text = r#"Let me check.
{"tool": "get_diff", "input": {}}
{"tool": "get_commit_info", "input": {}}"#;
        let calls = extract_tool_calls(text);
        let batch = first_batch(&calls, text);
        let mut counter = 0;
        let (blocks, reason) = build_content_blocks(text, &batch, &mut counter);

        assert_eq!(blocks.len(), 3);
        assert!(matches!(&blocks[0], ContentBlock::Text { text } if text == "Let me check."));
        assert!(
            matches!(&blocks[1], ContentBlock::ToolUse { id, name, .. } if id == "toolu_cc_1" && name == "get_diff")
        );
        assert!(
            matches!(&blocks[2], ContentBlock::ToolUse { id, name, .. } if id == "toolu_cc_2" && name == "get_commit_info")
        );
        assert_eq!(reason, StopReason::ToolUse);
    }

    #[test]
    fn test_build_content_blocks_counter_continues() {
        // A call at the very start produces no leading text block, and ids continue from
        // the caller-supplied counter.
        let text = r#"{"tool": "get_diff", "input": {}}"#;
        let calls = extract_tool_calls(text);
        let batch = first_batch(&calls, text);
        let mut counter = 10;
        let (blocks, reason) = build_content_blocks(text, &batch, &mut counter);
        assert_eq!(blocks.len(), 1);
        assert!(matches!(&blocks[0], ContentBlock::ToolUse { id, .. } if id == "toolu_cc_11"));
        assert_eq!(counter, 11);
        assert_eq!(reason, StopReason::ToolUse);
    }

    #[test]
    fn test_build_prompt_includes_system_and_tools() {
        use crate::provider::{Message, Role, ToolDefinition};

        let request = CompletionRequest {
            system: "You are a test assistant.".to_string(),
            messages: vec![Message {
                role: Role::User,
                content: vec![ContentBlock::Text {
                    text: "Hello".to_string(),
                }],
            }],
            tools: vec![ToolDefinition {
                name: "get_diff".to_string(),
                description: "Get the diff.".to_string(),
                input_schema: serde_json::json!({"type": "object"}),
            }],
            max_tokens: 4096,
        };

        let prompt = build_prompt(&request);
        assert!(prompt.contains("System: You are a test assistant."));
        assert!(prompt.contains("get_diff"));
        assert!(prompt.contains("User: Hello"));
    }

    #[test]
    fn test_provider_name_and_model() {
        let default_model = ClaudeCodeProvider::new(None);
        assert_eq!(default_model.name(), "claude-code");
        assert_eq!(default_model.model(), "claude-sonnet-4-5-20250929");

        let custom = ClaudeCodeProvider::new(Some("claude-opus-4".to_string()));
        assert_eq!(custom.model(), "claude-opus-4");
    }
}
472}