use crate::error::ProviderError;
use crate::provider::{CompletionConfig, LlmProvider, resolve_claude_alias};
use serde::Deserialize;
use std::process::Stdio;
use tokio::io::AsyncWriteExt;
use tokio::process::Command;
/// LLM provider that obtains completions by running the `claude` command-line
/// tool as a child process.
#[derive(Debug)]
pub struct ClaudeCliProvider {
    /// Model identifier produced by `resolve_claude_alias` in `new`; forwarded
    /// to the CLI through the `--model` flag.
    model_id: String,
}
impl ClaudeCliProvider {
    /// Creates a provider for `model`, which is passed through
    /// `resolve_claude_alias` to obtain the identifier handed to the CLI.
    ///
    /// # Errors
    ///
    /// * [`ProviderError::NestedSession`] when the `CLAUDECODE` environment
    ///   variable can be read, i.e. we appear to be running inside a Claude
    ///   Code session already.
    /// * Whatever error `resolve_claude_alias` reports for `model`.
    pub fn new(model: impl Into<String>) -> Result<Self, ProviderError> {
        match std::env::var("CLAUDECODE") {
            Ok(_) => Err(ProviderError::NestedSession),
            Err(_) => {
                let requested: String = model.into();
                let model_id = resolve_claude_alias(&requested)?;
                Ok(Self { model_id })
            }
        }
    }

    /// Fixed identifier of this provider.
    pub fn name(&self) -> &str {
        "claude-cli"
    }

    /// Resolved model identifier this provider was built with.
    pub fn model(&self) -> &str {
        self.model_id.as_str()
    }

    /// Builds the argument list for the `claude` invocation.
    ///
    /// Only the system prompt travels on the command line; the user prompt is
    /// deliberately not part of the arguments.
    fn build_args(&self, system_prompt: &str) -> Vec<String> {
        let flags: [&str; 7] = [
            "--print",
            "--output-format",
            "json",
            "--model",
            self.model_id.as_str(),
            "--system-prompt",
            system_prompt,
        ];
        flags.iter().map(|flag| String::from(*flag)).collect()
    }
}
fn parse_cli_output(raw: &str) -> Result<String, ProviderError> {
let output: CliOutput = serde_json::from_str(raw).map_err(|e| ProviderError::Process {
exit_code: None,
stderr: format!("failed to parse CLI output: {e}"),
})?;
if output.is_error {
return Err(ProviderError::Process {
exit_code: None,
stderr: output.result,
});
}
Ok(output.result)
}
/// Removes a surrounding Markdown code fence (```` ```json ```` or a bare
/// ```` ``` ````) from `text`, returning the fenced content.
///
/// Whitespace around the fenced block is ignored when looking for the fence,
/// and both `\n` and `\r\n` line endings are accepted, so a response that ends
/// with a trailing newline after the closing fence is still unwrapped.
///
/// * If `text` does not start (after trimming) with a recognised opening fence
///   followed by a line break, `text` is returned completely unchanged. This
///   includes fences with any other language tag.
/// * If the opening fence is present but the closing one is missing, only the
///   opening fence (and surrounding whitespace) is removed.
fn strip_code_fences(text: &str) -> &str {
    let trimmed = text.trim();

    // Content after the opening fence line, if there is one. "```json" is
    // tried first so that the bare "```" prefix does not leave "json" behind.
    let body = trimmed
        .strip_prefix("```json")
        .or_else(|| trimmed.strip_prefix("```"))
        .and_then(|rest| {
            rest.strip_prefix("\r\n")
                .or_else(|| rest.strip_prefix('\n'))
        });

    match body {
        Some(body) => {
            let inner = body.strip_suffix("\n```").unwrap_or(body);
            // `body` comes from a trimmed string, so a trailing '\r' can only
            // be the first half of a CRLF that preceded the closing fence.
            inner.strip_suffix('\r').unwrap_or(inner)
        }
        None => text,
    }
}
/// The fields of the CLI's JSON output that this module reads.
///
/// Only these two keys are deserialized; both must be present for parsing to
/// succeed.
#[derive(Deserialize)]
struct CliOutput {
    /// `true` when the CLI reports the run as failed; `result` is then used as
    /// the error message by `parse_cli_output`.
    is_error: bool,
    /// Text payload of the run: the model's answer on success, or the error
    /// description when `is_error` is set.
    result: String,
}
#[async_trait::async_trait]
impl LlmProvider for ClaudeCliProvider {
    /// Runs `claude` once and returns the text of its answer.
    ///
    /// The system prompt is passed as a command-line argument (see
    /// `build_args`), the user prompt is written to the child's stdin, and the
    /// child's stdout is parsed with `parse_cli_output`. A Markdown code fence
    /// wrapping the answer is removed with `strip_code_fences`.
    ///
    /// `_config` is currently not consulted.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::Process`] when:
    /// * the process cannot be spawned, written to, or waited for
    ///   (`exit_code: None`, message describing the failing step);
    /// * the process exits unsuccessfully (`exit_code` and captured stderr of
    ///   the child);
    /// * stdout is not a valid CLI JSON envelope, or the envelope is flagged
    ///   as an error.
    async fn complete(
        &self,
        system_prompt: &str,
        user_prompt: &str,
        _config: &CompletionConfig,
    ) -> Result<String, ProviderError> {
        let args = self.build_args(system_prompt);
        let mut child = Command::new("claude")
            .args(&args)
            .stdin(Stdio::piped())
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            // If this future is dropped (or we return early) the child must
            // not be left running.
            .kill_on_drop(true)
            .spawn()
            .map_err(|e| ProviderError::Process {
                exit_code: None,
                stderr: format!("failed to spawn claude process: {e}"),
            })?;

        // Feed the prompt. The handle is dropped at the end of the match arm,
        // which closes the pipe and signals end-of-input to the child.
        let write_result = match child.stdin.take() {
            Some(mut stdin) => stdin.write_all(user_prompt.as_bytes()).await,
            None => Ok(()),
        };

        // A broken pipe means the child stopped reading — typically because it
        // already exited. Its exit status and stderr explain the failure far
        // better than "Broken pipe", so defer that error until the child has
        // been reaped. Any other write error is reported immediately.
        let broken_pipe = match write_result {
            Ok(()) => None,
            Err(e) if e.kind() == std::io::ErrorKind::BrokenPipe => Some(e),
            Err(e) => {
                return Err(ProviderError::Process {
                    exit_code: None,
                    stderr: format!("failed to write to stdin: {e}"),
                });
            }
        };

        let output = child
            .wait_with_output()
            .await
            .map_err(|e| ProviderError::Process {
                exit_code: None,
                stderr: format!("failed to wait for claude process: {e}"),
            })?;

        if !output.status.success() {
            return Err(ProviderError::Process {
                exit_code: output.status.code(),
                stderr: String::from_utf8_lossy(&output.stderr).to_string(),
            });
        }

        // The child exited successfully but never received the whole prompt;
        // its output cannot be trusted, so surface the write failure.
        if let Some(e) = broken_pipe {
            return Err(ProviderError::Process {
                exit_code: None,
                stderr: format!("failed to write to stdin: {e}"),
            });
        }

        let stdout = String::from_utf8_lossy(&output.stdout);
        let result = parse_cli_output(&stdout)?;
        Ok(strip_code_fences(&result).to_string())
    }

    /// Delegates to the inherent [`ClaudeCliProvider::name`].
    fn name(&self) -> &str {
        ClaudeCliProvider::name(self)
    }

    /// Delegates to the inherent [`ClaudeCliProvider::model`].
    fn model(&self) -> &str {
        ClaudeCliProvider::model(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    /// Remembers the value of `CLAUDECODE` at construction time and puts it
    /// back when dropped.
    ///
    /// Restoring in `Drop` means the environment is reset even when the test
    /// closure panics on a failed assertion, so one failing test cannot leak
    /// its environment into the next one.
    struct ClaudeCodeEnvGuard {
        original: Option<std::ffi::OsString>,
    }

    impl ClaudeCodeEnvGuard {
        /// Captures the current value (`var_os` keeps non-Unicode values too).
        fn capture() -> Self {
            Self {
                original: std::env::var_os("CLAUDECODE"),
            }
        }
    }

    impl Drop for ClaudeCodeEnvGuard {
        fn drop(&mut self) {
            // SAFETY: every test in this module that touches the environment
            // is `#[serial]`, so no two of them run concurrently; the other
            // tests in this module do not read the environment.
            unsafe {
                match self.original.take() {
                    Some(val) => std::env::set_var("CLAUDECODE", val),
                    None => std::env::remove_var("CLAUDECODE"),
                }
            }
        }
    }

    /// Runs `f` with `CLAUDECODE` unset, restoring the previous value after.
    fn without_claudecode<F: FnOnce()>(f: F) {
        let _restore = ClaudeCodeEnvGuard::capture();
        // SAFETY: callers are `#[serial]` tests; see `ClaudeCodeEnvGuard::drop`.
        unsafe {
            std::env::remove_var("CLAUDECODE");
        }
        f();
    }

    /// Runs `f` with `CLAUDECODE=1`, restoring the previous state after.
    fn with_claudecode<F: FnOnce()>(f: F) {
        let _restore = ClaudeCodeEnvGuard::capture();
        // SAFETY: callers are `#[serial]` tests; see `ClaudeCodeEnvGuard::drop`.
        unsafe {
            std::env::set_var("CLAUDECODE", "1");
        }
        f();
    }

    #[test]
    #[serial]
    fn test_new_with_claudecode_env_returns_nested_session_error() {
        with_claudecode(|| {
            let result = ClaudeCliProvider::new("sonnet");
            assert!(result.is_err());
            let err = result.unwrap_err();
            assert!(
                matches!(err, ProviderError::NestedSession),
                "expected NestedSession, got: {err}"
            );
        });
    }

    #[test]
    #[serial]
    fn test_new_sonnet_maps_to_claude_sonnet_model() {
        without_claudecode(|| {
            let provider = ClaudeCliProvider::new("sonnet").unwrap();
            assert_eq!(provider.model(), "claude-sonnet-4-6");
        });
    }

    #[test]
    #[serial]
    fn test_new_opus_maps_to_claude_opus_model() {
        without_claudecode(|| {
            let provider = ClaudeCliProvider::new("opus").unwrap();
            assert_eq!(provider.model(), "claude-opus-4-7");
        });
    }

    #[test]
    #[serial]
    fn test_new_haiku_maps_to_claude_haiku_model() {
        without_claudecode(|| {
            let provider = ClaudeCliProvider::new("haiku").unwrap();
            assert_eq!(provider.model(), "claude-haiku-4-5-20251001");
        });
    }

    #[test]
    #[serial]
    fn test_new_claude_prefix_passes_through() {
        without_claudecode(|| {
            let provider = ClaudeCliProvider::new("claude-custom-model").unwrap();
            assert_eq!(provider.model(), "claude-custom-model");
        });
    }

    #[test]
    #[serial]
    fn test_new_invalid_model_returns_auth_error() {
        without_claudecode(|| {
            let result = ClaudeCliProvider::new("invalid");
            assert!(result.is_err());
            let err = result.unwrap_err();
            assert!(
                matches!(err, ProviderError::Auth { .. }),
                "expected Auth, got: {err}"
            );
        });
    }

    #[test]
    #[serial]
    fn test_new_mixed_case_alias_returns_auth_error() {
        without_claudecode(|| {
            let result_upper = ClaudeCliProvider::new("Sonnet");
            let result_all_caps = ClaudeCliProvider::new("SONNET");
            assert!(
                matches!(result_upper.unwrap_err(), ProviderError::Auth { .. }),
                "expected Auth for 'Sonnet'"
            );
            assert!(
                matches!(result_all_caps.unwrap_err(), ProviderError::Auth { .. }),
                "expected Auth for 'SONNET'"
            );
        });
    }

    #[test]
    #[serial]
    fn test_provider_name_returns_claude_cli() {
        without_claudecode(|| {
            let provider = ClaudeCliProvider::new("sonnet").unwrap();
            assert_eq!(provider.name(), "claude-cli");
        });
    }

    #[test]
    #[serial]
    fn test_provider_model_returns_resolved_model_id() {
        without_claudecode(|| {
            let provider = ClaudeCliProvider::new("opus").unwrap();
            assert_eq!(provider.model(), "claude-opus-4-7");
        });
    }

    #[test]
    #[serial]
    fn test_build_args_includes_required_cli_flags() {
        without_claudecode(|| {
            let provider = ClaudeCliProvider::new("sonnet").unwrap();
            let args = provider.build_args("You are an analyst.");
            assert!(args.contains(&"--print".to_string()));
            assert!(args.contains(&"--output-format".to_string()));
            assert!(args.contains(&"json".to_string()));
            assert!(args.contains(&"--model".to_string()));
            assert!(args.contains(&"claude-sonnet-4-6".to_string()));
            assert!(args.contains(&"--system-prompt".to_string()));
            assert!(args.contains(&"You are an analyst.".to_string()));
        });
    }

    #[test]
    #[serial]
    fn test_user_prompt_not_in_build_args() {
        without_claudecode(|| {
            let provider = ClaudeCliProvider::new("sonnet").unwrap();
            let user_prompt = "Analyze this code for security issues";
            let args = provider.build_args("System prompt");
            assert!(
                !args.contains(&user_prompt.to_string()),
                "user prompt should not be in CLI args"
            );
        });
    }

    #[test]
    fn test_parse_cli_output_extracts_inner_result() {
        let outer = r#"{"type":"result","subtype":"success","is_error":false,"result":"{\"agent\":\"melchior\",\"verdict\":\"approve\"}","usage":{"input_tokens":100}}"#;
        let result = parse_cli_output(outer).unwrap();
        assert_eq!(result, r#"{"agent":"melchior","verdict":"approve"}"#);
    }

    #[test]
    fn test_parse_cli_output_error_flag_returns_process_error() {
        let outer =
            r#"{"type":"result","subtype":"error","is_error":true,"result":"Rate limit exceeded"}"#;
        let result = parse_cli_output(outer);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, ProviderError::Process { .. }),
            "expected Process, got: {err}"
        );
    }

    #[test]
    fn test_parse_cli_output_malformed_json_returns_process_error() {
        let result = parse_cli_output("not valid json");
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, ProviderError::Process { .. }),
            "expected Process, got: {err}"
        );
    }

    #[test]
    fn test_strip_code_fences_removes_json_fence() {
        let input = "```json\n{\"key\": \"value\"}\n```";
        let result = strip_code_fences(input);
        assert_eq!(result, "{\"key\": \"value\"}");
    }

    #[test]
    fn test_strip_code_fences_removes_plain_fence() {
        let input = "```\n{\"key\": \"value\"}\n```";
        let result = strip_code_fences(input);
        assert_eq!(result, "{\"key\": \"value\"}");
    }

    #[test]
    fn test_strip_code_fences_no_fences_returns_unchanged() {
        let input = r#"{"key": "value"}"#;
        let result = strip_code_fences(input);
        assert_eq!(result, input);
    }
}