/// Emits a `tracing::debug!` event when the `tracing` feature is enabled.
///
/// Without the feature the invocation is compiled out entirely, so call sites do not
/// need their own `cfg` guards.
macro_rules! trace_debug {
    ($($tokens:tt)*) => {
        #[cfg(feature = "tracing")]
        tracing::debug!($($tokens)*);
    };
}
/// Emits a `tracing::error!` event when the `tracing` feature is enabled.
///
/// Without the feature the invocation is compiled out entirely, so call sites do not
/// need their own `cfg` guards.
macro_rules! trace_error {
    ($($tokens:tt)*) => {
        #[cfg(feature = "tracing")]
        tracing::error!($($tokens)*);
    };
}
/// Emits a `tracing::info!` event when the `tracing` feature is enabled.
///
/// Without the feature the invocation is compiled out entirely, so call sites do not
/// need their own `cfg` guards.
macro_rules! trace_info {
    ($($tokens:tt)*) => {
        #[cfg(feature = "tracing")]
        tracing::info!($($tokens)*);
    };
}
20
21#[cfg(test)]
22use mockall::automock;
23
24use std::process::Output;
25
26use tokio::process::Command as TokioCommand;
27
28use crate::config::ClaudeConfig;
29use crate::conversation::Conversation;
30use crate::error::ClaudeError;
31use crate::types::{ClaudeResponse, strip_ansi};
32
33#[cfg(feature = "stream")]
34use crate::stream::{StreamEvent, parse_stream};
35#[cfg(feature = "stream")]
36use std::pin::Pin;
37#[cfg(feature = "stream")]
38use tokio::io::BufReader;
39#[cfg(feature = "stream")]
40use tokio_stream::Stream;
41
/// Runs the Claude CLI with a given argument list and collects its [`Output`].
///
/// [`ClaudeClient`] sends every non-streaming request through this trait. That lets a
/// different implementation be injected with `ClaudeClient::with_runner`. Under
/// `cfg(test)` a `MockCommandRunner` is generated by `mockall::automock`.
// `async fn` in a public trait triggers the `async_fn_in_trait` lint because implementors'
// futures cannot be required to be `Send` from the trait signature; that trade-off is
// accepted here deliberately.
#[allow(async_fn_in_trait)]
#[cfg_attr(test, automock)]
pub trait CommandRunner: Sync + Send {
    /// Runs the CLI with `args` and returns its exit status plus captured stdout/stderr.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the command could not be run. `ClaudeClient::ask` reports an
    /// error of kind `NotFound` as `ClaudeError::CliNotFound` and anything else as
    /// `ClaudeError::Io`.
    async fn run(&self, args: &[String]) -> std::io::Result<Output>;
}
49
/// The production [`CommandRunner`]: spawns the executable named by `cli_path` as a child
/// process.
#[derive(Clone, Debug)]
pub struct DefaultRunner {
    /// Program name or path handed to the process spawner.
    cli_path: String,
}
55
56impl DefaultRunner {
57 #[must_use]
59 pub fn new(cli_path: impl Into<String>) -> Self {
60 Self {
61 cli_path: cli_path.into(),
62 }
63 }
64}
65
66impl Default for DefaultRunner {
67 fn default() -> Self {
68 Self {
69 cli_path: "claude".into(),
70 }
71 }
72}
73
74impl CommandRunner for DefaultRunner {
75 async fn run(&self, args: &[String]) -> std::io::Result<Output> {
76 TokioCommand::new(&self.cli_path).args(args).output().await
77 }
78}
79
/// Owns the child process of a streaming request and kills it when dropped.
///
/// Keeping the child in this guard means that dropping the event stream early does not
/// leave the CLI running.
#[cfg(feature = "stream")]
struct ChildGuard(Option<tokio::process::Child>);

#[cfg(feature = "stream")]
impl Drop for ChildGuard {
    fn drop(&mut self) {
        // `Drop` cannot `.await`, so only request the kill without waiting for it.
        // Failures are ignored on purpose: this is best-effort cleanup.
        if let Some(child) = self.0.as_mut() {
            let _ = child.start_kill();
        }
    }
}
96
97#[derive(Debug, Clone)]
99pub struct ClaudeClient<R: CommandRunner = DefaultRunner> {
100 config: ClaudeConfig,
101 runner: R,
102}
103
104impl ClaudeClient {
105 #[must_use]
107 pub fn new(config: ClaudeConfig) -> Self {
108 let runner = DefaultRunner::new(config.cli_path_or_default());
109 Self { config, runner }
110 }
111}
112
#[cfg(feature = "stream")]
#[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
impl ClaudeClient {
    /// Runs the CLI in streaming mode and yields its events as they are parsed.
    ///
    /// Arguments come from `ClaudeConfig::to_stream_args` and the executable from
    /// `ClaudeConfig::cli_path_or_default`. The child's stdout is decoded with `parse_stream`.
    /// Its stderr is drained concurrently with stdout and with the final `wait()`. Otherwise
    /// a child that fills the stderr pipe would block, stop writing stdout, and hang the
    /// stream. The captured stderr is reported if the CLI exits unsuccessfully.
    ///
    /// Dropping the returned stream kills the child process.
    ///
    /// # Errors
    ///
    /// The outer `Result` fails when the CLI cannot be spawned. It returns
    /// [`ClaudeError::CliNotFound`] if the executable does not exist, and
    /// [`ClaudeError::Io`] otherwise.
    ///
    /// Stream items may also be errors:
    /// - [`ClaudeError::Timeout`] when `stream_idle_timeout` is set and no event arrives
    ///   within it. The stream ends after yielding it.
    /// - [`ClaudeError::NonZeroExit`] after the last event if the CLI exits unsuccessfully.
    ///   `code` is `-1` when no exit code is available.
    /// - [`ClaudeError::Io`] if waiting for the child fails.
    pub async fn ask_stream(
        &self,
        prompt: &str,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, ClaudeError>> + Send>>, ClaudeError>
    {
        let args = self.config.to_stream_args(prompt);

        trace_debug!(args = ?args, "spawning claude CLI stream");

        let mut child = TokioCommand::new(self.config.cli_path_or_default())
            .args(&args)
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::piped())
            .spawn()
            .map_err(|e| {
                if e.kind() == std::io::ErrorKind::NotFound {
                    ClaudeError::CliNotFound
                } else {
                    ClaudeError::Io(e)
                }
            })?;

        let stdout = child.stdout.take().expect("stdout must be piped");
        let stderr = child.stderr.take();
        let event_stream = parse_stream(BufReader::new(stdout));
        // The child stays inside the guard for the stream's whole life (including the final
        // `wait`), so dropping the stream at any point kills it.
        let mut guard = ChildGuard(Some(child));
        let idle_timeout = self.config.stream_idle_timeout;

        // Reads all of the child's stderr. It is polled alongside stdout and `wait()` (see
        // `drive_with_stderr_drain`) instead of only after the child has exited.
        let stderr_drain = async move {
            let mut bytes = Vec::new();
            if let Some(mut pipe) = stderr {
                let _ = tokio::io::AsyncReadExt::read_to_end(&mut pipe, &mut bytes).await;
            }
            bytes
        };

        Ok(Box::pin(async_stream::stream! {
            tokio::pin!(event_stream);
            tokio::pin!(stderr_drain);
            // Set once `stderr_drain` completes; after that it must not be polled again.
            let mut stderr_bytes: Option<Vec<u8>> = None;

            loop {
                let next = drive_with_stderr_drain(
                    tokio_stream::StreamExt::next(&mut event_stream),
                    stderr_drain.as_mut(),
                    &mut stderr_bytes,
                );
                let maybe_event = if let Some(timeout_dur) = idle_timeout {
                    match tokio::time::timeout(timeout_dur, next).await {
                        Ok(event) => event,
                        Err(_) => {
                            trace_error!("stream idle timeout");
                            yield Err(ClaudeError::Timeout);
                            return;
                        }
                    }
                } else {
                    next.await
                };

                match maybe_event {
                    Some(event) => yield Ok(event),
                    None => break,
                }
            }

            if let Some(child) = guard.0.as_mut() {
                let status =
                    drive_with_stderr_drain(child.wait(), stderr_drain.as_mut(), &mut stderr_bytes)
                        .await;
                match status {
                    Ok(s) if !s.success() => {
                        let code = s.code().unwrap_or(-1);
                        let captured = match stderr_bytes.take() {
                            Some(bytes) => bytes,
                            None => stderr_drain.await,
                        };
                        let stderr_str = String::from_utf8_lossy(&captured).into_owned();
                        yield Err(ClaudeError::NonZeroExit { code, stderr: stderr_str });
                    }
                    Err(e) => {
                        yield Err(ClaudeError::Io(e));
                    }
                    Ok(_) => {}
                }
            }
        }))
    }
}

/// Awaits `primary` while also polling `drain`, storing `drain`'s output in `drained`
/// when it finishes.
///
/// `drain` is polled only while `*drained` is `None`, so a completed drain is never polled
/// again. Returns `primary`'s output; `drain` may still be unfinished at that point.
#[cfg(feature = "stream")]
async fn drive_with_stderr_drain<F, D>(
    primary: F,
    mut drain: Pin<&mut D>,
    drained: &mut Option<Vec<u8>>,
) -> F::Output
where
    F: std::future::Future,
    D: std::future::Future<Output = Vec<u8>>,
{
    tokio::pin!(primary);
    std::future::poll_fn(|cx| {
        if drained.is_none() {
            if let std::task::Poll::Ready(bytes) = std::future::Future::poll(drain.as_mut(), cx) {
                *drained = Some(bytes);
            }
        }
        std::future::Future::poll(primary.as_mut(), cx)
    })
    .await
}
206
207impl<R: CommandRunner> ClaudeClient<R> {
208 #[must_use]
210 pub fn with_runner(config: ClaudeConfig, runner: R) -> Self {
211 Self { config, runner }
212 }
213
214 pub async fn ask_structured<T: serde::de::DeserializeOwned>(
220 &self,
221 prompt: &str,
222 ) -> Result<T, ClaudeError> {
223 let response = self.ask(prompt).await?;
224 response.parse_result()
225 }
226
227 pub async fn ask(&self, prompt: &str) -> Result<ClaudeResponse, ClaudeError> {
229 let args = self.config.to_args(prompt);
230
231 trace_debug!(args = ?args, "executing claude CLI");
232
233 let io_result: std::io::Result<Output> = if let Some(timeout) = self.config.timeout {
234 tokio::time::timeout(timeout, self.runner.run(&args))
235 .await
236 .map_err(|_| {
237 let err = ClaudeError::Timeout;
238 trace_error!(error = %err, "claude CLI failed");
239 err
240 })?
241 } else {
242 self.runner.run(&args).await
243 };
244
245 let output = io_result.map_err(|e| {
246 let err = if e.kind() == std::io::ErrorKind::NotFound {
247 ClaudeError::CliNotFound
248 } else {
249 ClaudeError::Io(e)
250 };
251 trace_error!(error = %err, "claude CLI failed");
252 err
253 })?;
254
255 if !output.status.success() {
256 let code = output.status.code().unwrap_or(-1);
257 let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
258 let err = ClaudeError::NonZeroExit { code, stderr };
259 trace_error!(error = %err, "claude CLI failed");
260 return Err(err);
261 }
262
263 let stdout = String::from_utf8_lossy(&output.stdout);
264 let json_str = strip_ansi(&stdout);
265 let response: ClaudeResponse = serde_json::from_str(json_str).map_err(|e| {
266 let err = ClaudeError::ParseError(e);
267 trace_error!(error = %err, "claude CLI failed");
268 err
269 })?;
270
271 trace_info!("claude CLI returned successfully");
272 Ok(response)
273 }
274}
275
276impl<R: CommandRunner + Clone> ClaudeClient<R> {
277 #[must_use]
285 pub fn conversation(&self) -> Conversation<R> {
286 Conversation::with_runner(self.config.clone(), self.runner.clone())
287 }
288
289 #[must_use]
294 pub fn conversation_resume(&self, session_id: impl Into<String>) -> Conversation<R> {
295 Conversation::with_runner_resume(self.config.clone(), self.runner.clone(), session_id)
296 }
297}
298
299pub async fn check_cli() -> Result<String, ClaudeError> {
310 check_cli_with_path("claude").await
311}
312
313pub async fn check_cli_with_path(cli_path: &str) -> Result<String, ClaudeError> {
323 let output = TokioCommand::new(cli_path)
324 .arg("--version")
325 .output()
326 .await
327 .map_err(|e| {
328 if e.kind() == std::io::ErrorKind::NotFound {
329 ClaudeError::CliNotFound
330 } else {
331 ClaudeError::Io(e)
332 }
333 })?;
334
335 if !output.status.success() {
336 let code = output.status.code().unwrap_or(-1);
337 let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
338 return Err(ClaudeError::NonZeroExit { code, stderr });
339 }
340
341 let version = String::from_utf8_lossy(&output.stdout).trim().to_string();
342 Ok(version)
343}
344
#[cfg(test)]
mod tests {
    use super::*;
    use std::process::ExitStatus;

    /// Builds an [`ExitStatus`] whose `code()` is `Some(code)`.
    ///
    /// On Unix, `ExitStatusExt::from_raw` takes a raw `wait(2)` status in which the exit code
    /// occupies bits 8..16, hence the shift.
    #[cfg(unix)]
    fn exit_status(code: i32) -> ExitStatus {
        use std::os::unix::process::ExitStatusExt;
        ExitStatus::from_raw(code << 8)
    }

    /// Builds an [`ExitStatus`] whose `code()` is `Some(code)`.
    ///
    /// On Windows, the raw value passed to `ExitStatusExt::from_raw` is the exit code itself.
    #[cfg(windows)]
    fn exit_status(code: i32) -> ExitStatus {
        use std::os::windows::process::ExitStatusExt;
        // Only small, non-negative codes are used in these tests, so the cast is lossless.
        ExitStatus::from_raw(code as u32)
    }

    /// Successful CLI output whose stdout is the `success.json` fixture.
    fn success_output() -> Output {
        Output {
            status: exit_status(0),
            stdout: include_bytes!("../tests/fixtures/success.json").to_vec(),
            stderr: Vec::new(),
        }
    }

    /// Failed CLI output: exit code 1 with a short stderr message.
    fn non_zero_output() -> Output {
        Output {
            status: exit_status(1),
            stdout: Vec::new(),
            stderr: b"something went wrong".to_vec(),
        }
    }

    #[tokio::test]
    async fn ask_success() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(success_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let resp = client.ask("hello").await.unwrap();
        assert_eq!(resp.result, "Hello!");
        assert!(!resp.is_error);
    }

    #[tokio::test]
    async fn ask_cli_not_found() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| {
            Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "not found",
            ))
        });

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::CliNotFound));
    }

    #[tokio::test]
    async fn ask_non_zero_exit() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(non_zero_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::NonZeroExit { code: 1, .. }));
    }

    #[tokio::test]
    async fn ask_parse_error() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| {
            Ok(Output {
                status: exit_status(0),
                stdout: b"not json".to_vec(),
                stderr: Vec::new(),
            })
        });

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::ParseError(_)));
    }

    /// A runner that takes 10 seconds (of paused test time) before succeeding.
    struct SlowRunner;

    impl CommandRunner for SlowRunner {
        async fn run(&self, _args: &[String]) -> std::io::Result<Output> {
            tokio::time::sleep(std::time::Duration::from_secs(10)).await;
            Ok(Output {
                status: exit_status(0),
                stdout: Vec::new(),
                stderr: Vec::new(),
            })
        }
    }

    #[tokio::test(start_paused = true)]
    async fn ask_timeout() {
        let config = ClaudeConfig::builder()
            .timeout(std::time::Duration::from_millis(10))
            .build();
        let client = ClaudeClient::with_runner(config, SlowRunner);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::Timeout));
    }

    #[tokio::test]
    async fn ask_io_error() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| {
            Err(std::io::Error::new(
                std::io::ErrorKind::PermissionDenied,
                "denied",
            ))
        });

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::Io(_)));
    }

    #[tokio::test]
    async fn ask_with_ansi_escape() {
        let json = include_str!("../tests/fixtures/success.json");
        let stdout = format!("\x1b[?1004l{json}\x1b[?1004l");

        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(move |_| {
            Ok(Output {
                status: exit_status(0),
                stdout: stdout.clone().into_bytes(),
                stderr: Vec::new(),
            })
        });

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let resp = client.ask("hello").await.unwrap();
        assert_eq!(resp.result, "Hello!");
    }

    #[tokio::test]
    async fn ask_passes_correct_args() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run()
            .withf(|args| {
                args.contains(&"--print".to_string())
                    && args.contains(&"--model".to_string())
                    && args.contains(&"haiku".to_string())
                    && args.last() == Some(&"test prompt".to_string())
            })
            .returning(|_| Ok(success_output()));

        let config = ClaudeConfig::builder().model("haiku").build();
        let client = ClaudeClient::with_runner(config, mock);
        client.ask("test prompt").await.unwrap();
    }

    #[derive(Debug, serde::Deserialize, PartialEq)]
    struct TestAnswer {
        value: i32,
    }

    /// Successful CLI output whose stdout is the `structured_success.json` fixture.
    fn structured_success_output() -> Output {
        Output {
            status: exit_status(0),
            stdout: include_bytes!("../tests/fixtures/structured_success.json").to_vec(),
            stderr: Vec::new(),
        }
    }

    #[tokio::test]
    async fn ask_structured_success() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run()
            .returning(|_| Ok(structured_success_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let answer: TestAnswer = client.ask_structured("What is 6*7?").await.unwrap();
        assert_eq!(answer, TestAnswer { value: 42 });
    }

    #[tokio::test]
    async fn ask_structured_deserialization_error() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(success_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client
            .ask_structured::<TestAnswer>("hello")
            .await
            .unwrap_err();
        assert!(matches!(err, ClaudeError::StructuredOutputError { .. }));
    }

    #[tokio::test]
    async fn ask_structured_cli_error() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(non_zero_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client
            .ask_structured::<TestAnswer>("hello")
            .await
            .unwrap_err();
        assert!(matches!(err, ClaudeError::NonZeroExit { code: 1, .. }));
    }

    #[tokio::test]
    async fn cli_path_with_shell_metacharacters_is_not_interpreted() {
        let malicious = "claude; echo pwned";
        let err = check_cli_with_path(malicious).await.unwrap_err();
        assert!(matches!(err, ClaudeError::CliNotFound));
    }

    #[tokio::test]
    async fn cli_path_with_command_substitution_is_not_interpreted() {
        let malicious = "$(echo claude)";
        let err = check_cli_with_path(malicious).await.unwrap_err();
        assert!(matches!(err, ClaudeError::CliNotFound));
    }
}