1use crate::providers::traits::{ChatMessage, ChatRequest, ChatResponse, Provider, TokenUsage};
35use async_trait::async_trait;
36use std::path::PathBuf;
37use tokio::io::AsyncWriteExt;
38use tokio::process::Command;
39use tokio::time::{Duration, timeout};
40
/// Environment variable that overrides the path of the Claude Code executable.
pub const CLAUDE_CODE_PATH_ENV: &str = "CLAUDE_CODE_PATH";

/// Executable name used when `CLAUDE_CODE_PATH` is unset or blank (looked up on `PATH`).
const DEFAULT_CLAUDE_CODE_BINARY: &str = "claude";

/// Model name meaning "let the CLI choose"; it is never forwarded as `--model`.
const DEFAULT_MODEL_MARKER: &str = "default";

/// Time limit for a single Claude Code CLI request (two minutes).
const CLAUDE_CODE_REQUEST_TIMEOUT: Duration = Duration::from_secs(2 * 60);

/// Maximum number of stderr characters quoted back in error messages.
const MAX_CLAUDE_CODE_STDERR_CHARS: usize = 512;

/// Temperatures accepted as-is; other finite values are clamped to the nearest entry.
const CLAUDE_CODE_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];

/// Tolerance when comparing a requested temperature against the supported values.
const TEMP_EPSILON: f64 = 1e-9;
56
/// Provider that answers chat requests by running the Claude Code CLI in `--print` mode.
///
/// Each request spawns one CLI process, writes the prompt to its stdin and returns the
/// trimmed stdout.
#[derive(Debug, Clone)]
pub struct ClaudeCodeProvider {
    /// Path (or bare command name) of the Claude Code executable to spawn.
    binary_path: PathBuf,
}
65
66impl ClaudeCodeProvider {
67 pub fn new() -> Self {
72 let binary_path = std::env::var(CLAUDE_CODE_PATH_ENV)
73 .ok()
74 .filter(|path| !path.trim().is_empty())
75 .map(PathBuf::from)
76 .unwrap_or_else(|| PathBuf::from(DEFAULT_CLAUDE_CODE_BINARY));
77
78 Self { binary_path }
79 }
80
81 fn should_forward_model(model: &str) -> bool {
83 let trimmed = model.trim();
84 !trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
85 }
86
87 fn supports_temperature(temperature: f64) -> bool {
88 CLAUDE_CODE_SUPPORTED_TEMPERATURES
89 .iter()
90 .any(|v| (temperature - v).abs() < TEMP_EPSILON)
91 }
92
93 fn validate_temperature(temperature: f64) -> anyhow::Result<f64> {
94 if !temperature.is_finite() {
95 anyhow::bail!("Claude Code provider received non-finite temperature value");
96 }
97 if Self::supports_temperature(temperature) {
98 return Ok(temperature);
99 }
100 let clamped = *CLAUDE_CODE_SUPPORTED_TEMPERATURES
104 .iter()
105 .min_by(|a, b| {
106 (temperature - **a)
107 .abs()
108 .partial_cmp(&(temperature - **b).abs())
109 .unwrap()
110 })
111 .unwrap();
112 tracing::debug!(
113 requested = temperature,
114 clamped = clamped,
115 "Clamped unsupported temperature to nearest Claude Code CLI value"
116 );
117 Ok(clamped)
118 }
119
120 fn redact_stderr(stderr: &[u8]) -> String {
121 let text = String::from_utf8_lossy(stderr);
122 let trimmed = text.trim();
123 if trimmed.is_empty() {
124 return String::new();
125 }
126 if trimmed.chars().count() <= MAX_CLAUDE_CODE_STDERR_CHARS {
127 return trimmed.to_string();
128 }
129 let clipped: String = trimmed.chars().take(MAX_CLAUDE_CODE_STDERR_CHARS).collect();
130 format!("{clipped}...")
131 }
132
133 async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
136 let mut cmd = Command::new(&self.binary_path);
137 cmd.arg("--print");
138
139 if Self::should_forward_model(model) {
140 cmd.arg("--model").arg(model);
141 }
142
143 cmd.arg("-");
145 cmd.kill_on_drop(true);
146 cmd.stdin(std::process::Stdio::piped());
147 cmd.stdout(std::process::Stdio::piped());
148 cmd.stderr(std::process::Stdio::piped());
149
150 let mut child = cmd.spawn().map_err(|err| {
151 anyhow::anyhow!(
152 "Failed to spawn Claude Code binary at {}: {err}. \
153 Ensure `claude` is installed and in PATH, or set CLAUDE_CODE_PATH.",
154 self.binary_path.display()
155 )
156 })?;
157
158 if let Some(mut stdin) = child.stdin.take() {
159 stdin.write_all(message.as_bytes()).await.map_err(|err| {
160 anyhow::anyhow!("Failed to write prompt to Claude Code stdin: {err}")
161 })?;
162 stdin.shutdown().await.map_err(|err| {
163 anyhow::anyhow!("Failed to finalize Claude Code stdin stream: {err}")
164 })?;
165 }
166
167 let output = timeout(CLAUDE_CODE_REQUEST_TIMEOUT, child.wait_with_output())
168 .await
169 .map_err(|_| {
170 anyhow::anyhow!(
171 "Claude Code request timed out after {:?} (binary: {})",
172 CLAUDE_CODE_REQUEST_TIMEOUT,
173 self.binary_path.display()
174 )
175 })?
176 .map_err(|err| anyhow::anyhow!("Claude Code process failed: {err}"))?;
177
178 if !output.status.success() {
179 let code = output.status.code().unwrap_or(-1);
180 let stderr_excerpt = Self::redact_stderr(&output.stderr);
181 let stderr_note = if stderr_excerpt.is_empty() {
182 String::new()
183 } else {
184 format!(" Stderr: {stderr_excerpt}")
185 };
186 anyhow::bail!(
187 "Claude Code exited with non-zero status {code}. \
188 Check that Claude Code is authenticated and the CLI is supported.{stderr_note}"
189 );
190 }
191
192 let text = String::from_utf8(output.stdout)
193 .map_err(|err| anyhow::anyhow!("Claude Code produced non-UTF-8 output: {err}"))?;
194
195 Ok(text.trim().to_string())
196 }
197}
198
199impl Default for ClaudeCodeProvider {
200 fn default() -> Self {
201 Self::new()
202 }
203}
204
205#[async_trait]
206impl Provider for ClaudeCodeProvider {
207 async fn chat_with_system(
208 &self,
209 system_prompt: Option<&str>,
210 message: &str,
211 model: &str,
212 temperature: f64,
213 ) -> anyhow::Result<String> {
214 Self::validate_temperature(temperature)?;
215
216 let full_message = match system_prompt {
217 Some(system) if !system.is_empty() => {
218 format!("{system}\n\n{message}")
219 }
220 _ => message.to_string(),
221 };
222
223 self.invoke_cli(&full_message, model).await
224 }
225
226 async fn chat_with_history(
227 &self,
228 messages: &[ChatMessage],
229 model: &str,
230 temperature: f64,
231 ) -> anyhow::Result<String> {
232 Self::validate_temperature(temperature)?;
233
234 let system = messages
236 .iter()
237 .find(|m| m.role == "system")
238 .map(|m| m.content.as_str());
239
240 let turns: Vec<&ChatMessage> = messages.iter().filter(|m| m.role != "system").collect();
242
243 if turns.len() <= 1 {
245 let last_user = turns.first().map(|m| m.content.as_str()).unwrap_or("");
246 let full_message = match system {
247 Some(s) if !s.is_empty() => format!("{s}\n\n{last_user}"),
248 _ => last_user.to_string(),
249 };
250 return self.invoke_cli(&full_message, model).await;
251 }
252
253 let mut parts = Vec::new();
255 if let Some(s) = system {
256 if !s.is_empty() {
257 parts.push(format!("[system]\n{s}"));
258 }
259 }
260 for msg in &turns {
261 let label = match msg.role.as_str() {
262 "user" => "[user]",
263 "assistant" => "[assistant]",
264 other => other,
265 };
266 parts.push(format!("{label}\n{}", msg.content));
267 }
268 parts.push("[assistant]".to_string());
269
270 let full_message = parts.join("\n\n");
271 self.invoke_cli(&full_message, model).await
272 }
273
274 async fn chat(
275 &self,
276 request: ChatRequest<'_>,
277 model: &str,
278 temperature: f64,
279 ) -> anyhow::Result<ChatResponse> {
280 let text = self
281 .chat_with_history(request.messages, model, temperature)
282 .await?;
283
284 Ok(ChatResponse {
285 text: Some(text),
286 tool_calls: Vec::new(),
287 usage: Some(TokenUsage::default()),
288 reasoning_content: None,
289 })
290 }
291}
292
293#[cfg(test)]
294mod tests {
295 use super::*;
296 use std::sync::atomic::{AtomicUsize, Ordering};
297 use std::sync::{Mutex, OnceLock};
298
299 fn env_lock() -> std::sync::MutexGuard<'static, ()> {
300 static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
301 LOCK.get_or_init(|| Mutex::new(()))
302 .lock()
303 .expect("env lock poisoned")
304 }
305
306 fn script_mutex() -> &'static tokio::sync::Mutex<()> {
315 static LOCK: OnceLock<tokio::sync::Mutex<()>> = OnceLock::new();
316 LOCK.get_or_init(|| tokio::sync::Mutex::new(()))
317 }
318
319 #[test]
320 fn new_uses_env_override() {
321 let _guard = env_lock();
322 let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
323 unsafe { std::env::set_var(CLAUDE_CODE_PATH_ENV, "/usr/local/bin/claude") };
325 let provider = ClaudeCodeProvider::new();
326 assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/claude"));
327 match orig {
328 Some(v) => unsafe { std::env::set_var(CLAUDE_CODE_PATH_ENV, v) },
330 None => unsafe { std::env::remove_var(CLAUDE_CODE_PATH_ENV) },
332 }
333 }
334
335 #[test]
336 fn new_defaults_to_claude() {
337 let _guard = env_lock();
338 let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
339 unsafe { std::env::remove_var(CLAUDE_CODE_PATH_ENV) };
341 let provider = ClaudeCodeProvider::new();
342 assert_eq!(provider.binary_path, PathBuf::from("claude"));
343 if let Some(v) = orig {
344 unsafe { std::env::set_var(CLAUDE_CODE_PATH_ENV, v) };
346 }
347 }
348
349 #[test]
350 fn new_ignores_blank_env_override() {
351 let _guard = env_lock();
352 let orig = std::env::var(CLAUDE_CODE_PATH_ENV).ok();
353 unsafe { std::env::set_var(CLAUDE_CODE_PATH_ENV, " ") };
355 let provider = ClaudeCodeProvider::new();
356 assert_eq!(provider.binary_path, PathBuf::from("claude"));
357 match orig {
358 Some(v) => unsafe { std::env::set_var(CLAUDE_CODE_PATH_ENV, v) },
360 None => unsafe { std::env::remove_var(CLAUDE_CODE_PATH_ENV) },
362 }
363 }
364
365 #[test]
366 fn should_forward_model_standard() {
367 assert!(ClaudeCodeProvider::should_forward_model(
368 "claude-sonnet-4-20250514"
369 ));
370 assert!(ClaudeCodeProvider::should_forward_model(
371 "claude-3.5-sonnet"
372 ));
373 }
374
375 #[test]
376 fn should_not_forward_default_model() {
377 assert!(!ClaudeCodeProvider::should_forward_model(
378 DEFAULT_MODEL_MARKER
379 ));
380 assert!(!ClaudeCodeProvider::should_forward_model(""));
381 assert!(!ClaudeCodeProvider::should_forward_model(" "));
382 }
383
384 #[test]
385 fn validate_temperature_allows_defaults() {
386 assert!(ClaudeCodeProvider::validate_temperature(0.7).is_ok());
387 assert!(ClaudeCodeProvider::validate_temperature(1.0).is_ok());
388 }
389
390 #[test]
391 fn validate_temperature_clamps_custom_value() {
392 let clamped = ClaudeCodeProvider::validate_temperature(0.2).unwrap();
393 assert!((clamped - 0.7).abs() < 1e-9, "0.2 should clamp to 0.7");
394
395 let clamped = ClaudeCodeProvider::validate_temperature(0.9).unwrap();
396 assert!((clamped - 1.0).abs() < 1e-9, "0.9 should clamp to 1.0");
397 }
398
399 #[test]
400 fn validate_temperature_rejects_non_finite() {
401 assert!(ClaudeCodeProvider::validate_temperature(f64::NAN).is_err());
402 assert!(ClaudeCodeProvider::validate_temperature(f64::INFINITY).is_err());
403 }
404
405 #[tokio::test]
406 async fn invoke_missing_binary_returns_error() {
407 let provider = ClaudeCodeProvider {
408 binary_path: PathBuf::from("/nonexistent/path/to/claude"),
409 };
410 let result = provider.invoke_cli("hello", "default").await;
411 assert!(result.is_err());
412 let msg = result.unwrap_err().to_string();
413 assert!(
414 msg.contains("Failed to spawn Claude Code binary"),
415 "unexpected error message: {msg}"
416 );
417 }
418
419 fn echo_provider() -> ClaudeCodeProvider {
427 static SCRIPT_ID: AtomicUsize = AtomicUsize::new(0);
428 let script_id = SCRIPT_ID.fetch_add(1, Ordering::Relaxed);
429 let dir = std::env::temp_dir().join(format!(
430 "construct_test_claude_code_{}_{}",
431 std::process::id(),
432 script_id
433 ));
434 std::fs::create_dir_all(&dir).unwrap();
435
436 let path = dir.join("fake_claude.sh");
437 std::fs::write(&path, "#!/bin/sh\ncat /dev/stdin\n").unwrap();
438 #[cfg(unix)]
439 {
440 use std::os::unix::fs::PermissionsExt;
441 std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o755)).unwrap();
442 }
443 ClaudeCodeProvider { binary_path: path }
444 }
445
446 #[test]
447 fn echo_provider_uses_unique_script_paths() {
448 let first = echo_provider();
449 let second = echo_provider();
450 assert_ne!(first.binary_path, second.binary_path);
451 }
452
453 #[tokio::test]
454 async fn chat_with_history_single_user_message() {
455 let _lock = script_mutex().lock().await;
456 let provider = echo_provider();
457 let messages = vec![ChatMessage::user("hello")];
458 let result = provider
459 .chat_with_history(&messages, "default", 1.0)
460 .await
461 .unwrap();
462 assert_eq!(result, "hello");
463 }
464
465 #[tokio::test]
466 async fn chat_with_history_single_user_with_system() {
467 let _lock = script_mutex().lock().await;
468 let provider = echo_provider();
469 let messages = vec![
470 ChatMessage::system("You are helpful."),
471 ChatMessage::user("hello"),
472 ];
473 let result = provider
474 .chat_with_history(&messages, "default", 1.0)
475 .await
476 .unwrap();
477 assert_eq!(result, "You are helpful.\n\nhello");
478 }
479
480 #[tokio::test]
481 async fn chat_with_history_multi_turn_includes_all_messages() {
482 let _lock = script_mutex().lock().await;
483 let provider = echo_provider();
484 let messages = vec![
485 ChatMessage::system("Be concise."),
486 ChatMessage::user("What is 2+2?"),
487 ChatMessage::assistant("4"),
488 ChatMessage::user("And 3+3?"),
489 ];
490 let result = provider
491 .chat_with_history(&messages, "default", 1.0)
492 .await
493 .unwrap();
494 assert!(result.contains("[system]\nBe concise."));
495 assert!(result.contains("[user]\nWhat is 2+2?"));
496 assert!(result.contains("[assistant]\n4"));
497 assert!(result.contains("[user]\nAnd 3+3?"));
498 assert!(result.ends_with("[assistant]"));
499 }
500
501 #[tokio::test]
502 async fn chat_with_history_multi_turn_without_system() {
503 let _lock = script_mutex().lock().await;
504 let provider = echo_provider();
505 let messages = vec![
506 ChatMessage::user("hi"),
507 ChatMessage::assistant("hello"),
508 ChatMessage::user("bye"),
509 ];
510 let result = provider
511 .chat_with_history(&messages, "default", 1.0)
512 .await
513 .unwrap();
514 assert!(!result.contains("[system]"));
515 assert!(result.contains("[user]\nhi"));
516 assert!(result.contains("[assistant]\nhello"));
517 assert!(result.contains("[user]\nbye"));
518 }
519
520 #[tokio::test]
521 async fn chat_with_history_clamps_bad_temperature() {
522 let _lock = script_mutex().lock().await;
523 let provider = echo_provider();
524 let messages = vec![ChatMessage::user("test")];
525 let result = provider.chat_with_history(&messages, "default", 0.5).await;
526 assert!(
527 result.is_ok(),
528 "unsupported temperature should be clamped, not rejected"
529 );
530 }
531}