Skip to main content

construct/providers/
claude_code.rs

1//! Claude Code headless CLI provider.
2//!
3//! Integrates with the Claude Code CLI, spawning the `claude` binary
4//! as a subprocess for each inference request. This allows using Claude's AI
5//! models without an interactive UI session.
6//!
7//! # Usage
8//!
9//! The `claude` binary must be available in `PATH`, or its location must be
10//! set via the `CLAUDE_CODE_PATH` environment variable.
11//!
12//! Claude Code is invoked as:
13//! ```text
14//! claude --print -
15//! ```
16//! with prompt content written to stdin.
17//!
18//! # Limitations
19//!
20//! - **System prompt**: The system prompt is prepended to the user message with a
21//!   blank-line separator, as the CLI does not provide a dedicated system-prompt flag.
//! - **Temperature**: The CLI does not expose a temperature parameter.
//!   Finite values other than the defaults (0.7, 1.0) are clamped to the nearest
//!   default; non-finite values (NaN, infinity) return an explicit error.
24//!
25//! # Authentication
26//!
27//! Authentication is handled by Claude Code itself (its own credential store).
28//! No explicit API key is required by this provider.
29//!
30//! # Environment variables
31//!
32//! - `CLAUDE_CODE_PATH` — override the path to the `claude` binary (default: `"claude"`)
33
34use crate::providers::traits::{ChatMessage, ChatRequest, ChatResponse, Provider, TokenUsage};
35use async_trait::async_trait;
36use std::path::PathBuf;
37use tokio::io::AsyncWriteExt;
38use tokio::process::Command;
39use tokio::time::{Duration, timeout};
40
/// Environment variable for overriding the path to the `claude` binary.
pub const CLAUDE_CODE_PATH_ENV: &str = "CLAUDE_CODE_PATH";

/// Default `claude` binary name (resolved via `PATH`).
const DEFAULT_CLAUDE_CODE_BINARY: &str = "claude";

/// Model name used to signal "use the provider's own default model".
/// A model equal to this marker (or blank) is not forwarded via `--model`.
const DEFAULT_MODEL_MARKER: &str = "default";
/// Time limit applied when waiting for the Claude Code subprocess, so a hung
/// process cannot block a request indefinitely.
const CLAUDE_CODE_REQUEST_TIMEOUT: Duration = Duration::from_secs(120);
/// Maximum number of stderr characters included in an error message; longer
/// output is cut off to avoid leaking oversized payloads.
const MAX_CLAUDE_CODE_STDERR_CHARS: usize = 512;
/// The CLI does not support sampling controls. These are the only temperature
/// values accepted as-is; any other finite value is clamped to the nearest entry.
const CLAUDE_CODE_SUPPORTED_TEMPERATURES: [f64; 2] = [0.7, 1.0];
/// Absolute tolerance used when comparing a requested temperature against the
/// supported values.
const TEMP_EPSILON: f64 = 1e-9;
56
/// Provider that invokes the Claude Code CLI as a subprocess.
///
/// Each inference request spawns a fresh `claude` process. This is the
/// non-interactive approach: the process handles the prompt and exits.
///
/// The provider only stores the binary location, so it is cheap to clone and
/// safe to print in diagnostics.
#[derive(Debug, Clone)]
pub struct ClaudeCodeProvider {
    /// Path to the `claude` binary.
    binary_path: PathBuf,
}
65
66impl ClaudeCodeProvider {
67    /// Create a new `ClaudeCodeProvider`.
68    ///
69    /// The binary path is resolved from `CLAUDE_CODE_PATH` env var if set,
70    /// otherwise defaults to `"claude"` (found via `PATH`).
71    pub fn new() -> Self {
72        let binary_path = std::env::var(CLAUDE_CODE_PATH_ENV)
73            .ok()
74            .filter(|path| !path.trim().is_empty())
75            .map(PathBuf::from)
76            .unwrap_or_else(|| PathBuf::from(DEFAULT_CLAUDE_CODE_BINARY));
77
78        Self { binary_path }
79    }
80
81    /// Returns true if the model argument should be forwarded to the CLI.
82    fn should_forward_model(model: &str) -> bool {
83        let trimmed = model.trim();
84        !trimmed.is_empty() && trimmed != DEFAULT_MODEL_MARKER
85    }
86
87    fn supports_temperature(temperature: f64) -> bool {
88        CLAUDE_CODE_SUPPORTED_TEMPERATURES
89            .iter()
90            .any(|v| (temperature - v).abs() < TEMP_EPSILON)
91    }
92
93    fn validate_temperature(temperature: f64) -> anyhow::Result<f64> {
94        if !temperature.is_finite() {
95            anyhow::bail!("Claude Code provider received non-finite temperature value");
96        }
97        if Self::supports_temperature(temperature) {
98            return Ok(temperature);
99        }
100        // Clamp to the nearest supported value — the CLI ignores temperature
101        // anyway, so a hard error just blocks callers like memory consolidation
102        // that legitimately request low temperatures.
103        let clamped = *CLAUDE_CODE_SUPPORTED_TEMPERATURES
104            .iter()
105            .min_by(|a, b| {
106                (temperature - **a)
107                    .abs()
108                    .partial_cmp(&(temperature - **b).abs())
109                    .unwrap()
110            })
111            .unwrap();
112        tracing::debug!(
113            requested = temperature,
114            clamped = clamped,
115            "Clamped unsupported temperature to nearest Claude Code CLI value"
116        );
117        Ok(clamped)
118    }
119
120    fn redact_stderr(stderr: &[u8]) -> String {
121        let text = String::from_utf8_lossy(stderr);
122        let trimmed = text.trim();
123        if trimmed.is_empty() {
124            return String::new();
125        }
126        if trimmed.chars().count() <= MAX_CLAUDE_CODE_STDERR_CHARS {
127            return trimmed.to_string();
128        }
129        let clipped: String = trimmed.chars().take(MAX_CLAUDE_CODE_STDERR_CHARS).collect();
130        format!("{clipped}...")
131    }
132
133    /// Invoke the claude binary with the given prompt and optional model.
134    /// Returns the trimmed stdout output as the assistant response.
135    async fn invoke_cli(&self, message: &str, model: &str) -> anyhow::Result<String> {
136        let mut cmd = Command::new(&self.binary_path);
137        cmd.arg("--print");
138
139        if Self::should_forward_model(model) {
140            cmd.arg("--model").arg(model);
141        }
142
143        // Read prompt from stdin to avoid exposing sensitive content in process args.
144        cmd.arg("-");
145        cmd.kill_on_drop(true);
146        cmd.stdin(std::process::Stdio::piped());
147        cmd.stdout(std::process::Stdio::piped());
148        cmd.stderr(std::process::Stdio::piped());
149
150        let mut child = cmd.spawn().map_err(|err| {
151            anyhow::anyhow!(
152                "Failed to spawn Claude Code binary at {}: {err}. \
153                 Ensure `claude` is installed and in PATH, or set CLAUDE_CODE_PATH.",
154                self.binary_path.display()
155            )
156        })?;
157
158        if let Some(mut stdin) = child.stdin.take() {
159            stdin.write_all(message.as_bytes()).await.map_err(|err| {
160                anyhow::anyhow!("Failed to write prompt to Claude Code stdin: {err}")
161            })?;
162            stdin.shutdown().await.map_err(|err| {
163                anyhow::anyhow!("Failed to finalize Claude Code stdin stream: {err}")
164            })?;
165        }
166
167        let output = timeout(CLAUDE_CODE_REQUEST_TIMEOUT, child.wait_with_output())
168            .await
169            .map_err(|_| {
170                anyhow::anyhow!(
171                    "Claude Code request timed out after {:?} (binary: {})",
172                    CLAUDE_CODE_REQUEST_TIMEOUT,
173                    self.binary_path.display()
174                )
175            })?
176            .map_err(|err| anyhow::anyhow!("Claude Code process failed: {err}"))?;
177
178        if !output.status.success() {
179            let code = output.status.code().unwrap_or(-1);
180            let stderr_excerpt = Self::redact_stderr(&output.stderr);
181            let stderr_note = if stderr_excerpt.is_empty() {
182                String::new()
183            } else {
184                format!(" Stderr: {stderr_excerpt}")
185            };
186            anyhow::bail!(
187                "Claude Code exited with non-zero status {code}. \
188                 Check that Claude Code is authenticated and the CLI is supported.{stderr_note}"
189            );
190        }
191
192        let text = String::from_utf8(output.stdout)
193            .map_err(|err| anyhow::anyhow!("Claude Code produced non-UTF-8 output: {err}"))?;
194
195        Ok(text.trim().to_string())
196    }
197}
198
199impl Default for ClaudeCodeProvider {
200    fn default() -> Self {
201        Self::new()
202    }
203}
204
205#[async_trait]
206impl Provider for ClaudeCodeProvider {
207    async fn chat_with_system(
208        &self,
209        system_prompt: Option<&str>,
210        message: &str,
211        model: &str,
212        temperature: f64,
213    ) -> anyhow::Result<String> {
214        Self::validate_temperature(temperature)?;
215
216        let full_message = match system_prompt {
217            Some(system) if !system.is_empty() => {
218                format!("{system}\n\n{message}")
219            }
220            _ => message.to_string(),
221        };
222
223        self.invoke_cli(&full_message, model).await
224    }
225
226    async fn chat_with_history(
227        &self,
228        messages: &[ChatMessage],
229        model: &str,
230        temperature: f64,
231    ) -> anyhow::Result<String> {
232        Self::validate_temperature(temperature)?;
233
234        // Separate system prompt from conversation messages.
235        let system = messages
236            .iter()
237            .find(|m| m.role == "system")
238            .map(|m| m.content.as_str());
239
240        // Build conversation turns (skip system messages).
241        let turns: Vec<&ChatMessage> = messages.iter().filter(|m| m.role != "system").collect();
242
243        // If there's only one user message, use the simple path.
244        if turns.len() <= 1 {
245            let last_user = turns.first().map(|m| m.content.as_str()).unwrap_or("");
246            let full_message = match system {
247                Some(s) if !s.is_empty() => format!("{s}\n\n{last_user}"),
248                _ => last_user.to_string(),
249            };
250            return self.invoke_cli(&full_message, model).await;
251        }
252
253        // Format multi-turn conversation into a single prompt.
254        let mut parts = Vec::new();
255        if let Some(s) = system {
256            if !s.is_empty() {
257                parts.push(format!("[system]\n{s}"));
258            }
259        }
260        for msg in &turns {
261            let label = match msg.role.as_str() {
262                "user" => "[user]",
263                "assistant" => "[assistant]",
264                other => other,
265            };
266            parts.push(format!("{label}\n{}", msg.content));
267        }
268        parts.push("[assistant]".to_string());
269
270        let full_message = parts.join("\n\n");
271        self.invoke_cli(&full_message, model).await
272    }
273
274    async fn chat(
275        &self,
276        request: ChatRequest<'_>,
277        model: &str,
278        temperature: f64,
279    ) -> anyhow::Result<ChatResponse> {
280        let text = self
281            .chat_with_history(request.messages, model, temperature)
282            .await?;
283
284        Ok(ChatResponse {
285            text: Some(text),
286            tool_calls: Vec::new(),
287            usage: Some(TokenUsage::default()),
288            reasoning_content: None,
289        })
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Mutex, MutexGuard, OnceLock, PoisonError};

    /// Serialize tests that modify `CLAUDE_CODE_PATH`.
    ///
    /// The mutex protects no data, so a poisoned lock (a previous test
    /// panicked while holding it) is recovered instead of failing every
    /// later test.
    fn env_lock() -> MutexGuard<'static, ()> {
        static LOCK: OnceLock<Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| Mutex::new(()))
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    /// Scoped override of `CLAUDE_CODE_PATH`.
    ///
    /// Holds the env lock for its whole lifetime and restores the previous
    /// value on drop, so the environment is put back even when an assertion
    /// in the test panics.
    struct ScopedClaudePath {
        previous: Option<OsString>,
        // Declared last so the lock is released only after the value has
        // been restored in `Drop::drop`.
        _lock: MutexGuard<'static, ()>,
    }

    impl ScopedClaudePath {
        /// Set the variable to `value`, or remove it when `value` is `None`.
        fn apply(value: Option<&str>) -> Self {
            let lock = env_lock();
            let previous = std::env::var_os(CLAUDE_CODE_PATH_ENV);
            match value {
                // SAFETY: every writer of this variable in this module holds
                // `env_lock`, so writes never race with each other. Readers on
                // other test threads are not synchronized by this lock.
                Some(v) => unsafe { std::env::set_var(CLAUDE_CODE_PATH_ENV, v) },
                // SAFETY: same as above — writers are serialized by `env_lock`.
                None => unsafe { std::env::remove_var(CLAUDE_CODE_PATH_ENV) },
            }
            Self {
                previous,
                _lock: lock,
            }
        }
    }

    impl Drop for ScopedClaudePath {
        fn drop(&mut self) {
            match self.previous.take() {
                // SAFETY: `_lock` is still held here, so writers stay serialized.
                Some(v) => unsafe { std::env::set_var(CLAUDE_CODE_PATH_ENV, v) },
                // SAFETY: `_lock` is still held here, so writers stay serialized.
                None => unsafe { std::env::remove_var(CLAUDE_CODE_PATH_ENV) },
            }
        }
    }

    /// Serialize tests that spawn the echo-provider script.
    ///
    /// On Linux, writing a shell script and exec'ing it from parallel threads
    /// can trigger `ETXTBSY` ("Text file busy") even with unique file paths,
    /// because the kernel briefly holds `deny_write_access` on the interpreter
    /// page cache. Serializing these tests eliminates the race.
    ///
    /// Uses `tokio::sync::Mutex` so the guard can be held across `.await`.
    fn script_mutex() -> &'static tokio::sync::Mutex<()> {
        static LOCK: OnceLock<tokio::sync::Mutex<()>> = OnceLock::new();
        LOCK.get_or_init(|| tokio::sync::Mutex::new(()))
    }

    #[test]
    fn new_uses_env_override() {
        let _env = ScopedClaudePath::apply(Some("/usr/local/bin/claude"));
        let provider = ClaudeCodeProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("/usr/local/bin/claude"));
    }

    #[test]
    fn new_defaults_to_claude() {
        let _env = ScopedClaudePath::apply(None);
        let provider = ClaudeCodeProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("claude"));
    }

    #[test]
    fn new_ignores_blank_env_override() {
        let _env = ScopedClaudePath::apply(Some("   "));
        let provider = ClaudeCodeProvider::new();
        assert_eq!(provider.binary_path, PathBuf::from("claude"));
    }

    #[test]
    fn should_forward_model_standard() {
        assert!(ClaudeCodeProvider::should_forward_model(
            "claude-sonnet-4-20250514"
        ));
        assert!(ClaudeCodeProvider::should_forward_model(
            "claude-3.5-sonnet"
        ));
    }

    #[test]
    fn should_not_forward_default_model() {
        assert!(!ClaudeCodeProvider::should_forward_model(
            DEFAULT_MODEL_MARKER
        ));
        assert!(!ClaudeCodeProvider::should_forward_model(""));
        assert!(!ClaudeCodeProvider::should_forward_model("   "));
    }

    #[test]
    fn validate_temperature_allows_defaults() {
        assert!(ClaudeCodeProvider::validate_temperature(0.7).is_ok());
        assert!(ClaudeCodeProvider::validate_temperature(1.0).is_ok());
    }

    #[test]
    fn validate_temperature_clamps_custom_value() {
        let clamped = ClaudeCodeProvider::validate_temperature(0.2).unwrap();
        assert!((clamped - 0.7).abs() < 1e-9, "0.2 should clamp to 0.7");

        let clamped = ClaudeCodeProvider::validate_temperature(0.9).unwrap();
        assert!((clamped - 1.0).abs() < 1e-9, "0.9 should clamp to 1.0");
    }

    #[test]
    fn validate_temperature_rejects_non_finite() {
        assert!(ClaudeCodeProvider::validate_temperature(f64::NAN).is_err());
        assert!(ClaudeCodeProvider::validate_temperature(f64::INFINITY).is_err());
    }

    #[tokio::test]
    async fn invoke_missing_binary_returns_error() {
        let provider = ClaudeCodeProvider {
            binary_path: PathBuf::from("/nonexistent/path/to/claude"),
        };
        let result = provider.invoke_cli("hello", "default").await;
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("Failed to spawn Claude Code binary"),
            "unexpected error message: {msg}"
        );
    }

    /// Helper: create a provider that uses a shell script echoing stdin back.
    /// The script ignores CLI flags (`--print`, `--model`, `-`) and just cats stdin.
    ///
    /// Each invocation places the script in its own unique directory (keyed by
    /// process id and a per-process counter) and writes it with a single
    /// `std::fs::write` call. Tests that execute the script additionally
    /// serialize through `script_mutex` to avoid `ETXTBSY` ("Text file busy").
    fn echo_provider() -> ClaudeCodeProvider {
        static SCRIPT_ID: AtomicUsize = AtomicUsize::new(0);
        let script_id = SCRIPT_ID.fetch_add(1, Ordering::Relaxed);
        let dir = std::env::temp_dir().join(format!(
            "construct_test_claude_code_{}_{}",
            std::process::id(),
            script_id
        ));
        std::fs::create_dir_all(&dir).unwrap();

        let path = dir.join("fake_claude.sh");
        std::fs::write(&path, "#!/bin/sh\ncat /dev/stdin\n").unwrap();
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o755)).unwrap();
        }
        ClaudeCodeProvider { binary_path: path }
    }

    #[test]
    fn echo_provider_uses_unique_script_paths() {
        let first = echo_provider();
        let second = echo_provider();
        assert_ne!(first.binary_path, second.binary_path);
    }

    #[tokio::test]
    async fn chat_with_history_single_user_message() {
        let _lock = script_mutex().lock().await;
        let provider = echo_provider();
        let messages = vec![ChatMessage::user("hello")];
        let result = provider
            .chat_with_history(&messages, "default", 1.0)
            .await
            .unwrap();
        assert_eq!(result, "hello");
    }

    #[tokio::test]
    async fn chat_with_history_single_user_with_system() {
        let _lock = script_mutex().lock().await;
        let provider = echo_provider();
        let messages = vec![
            ChatMessage::system("You are helpful."),
            ChatMessage::user("hello"),
        ];
        let result = provider
            .chat_with_history(&messages, "default", 1.0)
            .await
            .unwrap();
        assert_eq!(result, "You are helpful.\n\nhello");
    }

    #[tokio::test]
    async fn chat_with_history_multi_turn_includes_all_messages() {
        let _lock = script_mutex().lock().await;
        let provider = echo_provider();
        let messages = vec![
            ChatMessage::system("Be concise."),
            ChatMessage::user("What is 2+2?"),
            ChatMessage::assistant("4"),
            ChatMessage::user("And 3+3?"),
        ];
        let result = provider
            .chat_with_history(&messages, "default", 1.0)
            .await
            .unwrap();
        assert!(result.contains("[system]\nBe concise."));
        assert!(result.contains("[user]\nWhat is 2+2?"));
        assert!(result.contains("[assistant]\n4"));
        assert!(result.contains("[user]\nAnd 3+3?"));
        assert!(result.ends_with("[assistant]"));
    }

    #[tokio::test]
    async fn chat_with_history_multi_turn_without_system() {
        let _lock = script_mutex().lock().await;
        let provider = echo_provider();
        let messages = vec![
            ChatMessage::user("hi"),
            ChatMessage::assistant("hello"),
            ChatMessage::user("bye"),
        ];
        let result = provider
            .chat_with_history(&messages, "default", 1.0)
            .await
            .unwrap();
        assert!(!result.contains("[system]"));
        assert!(result.contains("[user]\nhi"));
        assert!(result.contains("[assistant]\nhello"));
        assert!(result.contains("[user]\nbye"));
    }

    #[tokio::test]
    async fn chat_with_history_clamps_bad_temperature() {
        let _lock = script_mutex().lock().await;
        let provider = echo_provider();
        let messages = vec![ChatMessage::user("test")];
        let result = provider.chat_with_history(&messages, "default", 0.5).await;
        assert!(
            result.is_ok(),
            "unsupported temperature should be clamped, not rejected"
        );
    }
}