/// Emits a `tracing::debug!` event when the `tracing` feature is enabled.
///
/// Without the feature the invocation expands to an empty block, so call sites never need
/// their own `cfg` guards and the `tracing` crate is not referenced at all.
macro_rules! trace_debug {
    ($($tokens:tt)*) => {{
        #[cfg(feature = "tracing")]
        {
            tracing::debug!($($tokens)*);
        }
    }};
}
/// Emits a `tracing::error!` event when the `tracing` feature is enabled.
///
/// Without the feature the invocation expands to an empty block, so call sites never need
/// their own `cfg` guards and the `tracing` crate is not referenced at all.
macro_rules! trace_error {
    ($($tokens:tt)*) => {{
        #[cfg(feature = "tracing")]
        {
            tracing::error!($($tokens)*);
        }
    }};
}
/// Emits a `tracing::info!` event when the `tracing` feature is enabled.
///
/// Without the feature the invocation expands to an empty block, so call sites never need
/// their own `cfg` guards and the `tracing` crate is not referenced at all.
macro_rules! trace_info {
    ($($tokens:tt)*) => {{
        #[cfg(feature = "tracing")]
        {
            tracing::info!($($tokens)*);
        }
    }};
}
20
21#[cfg(test)]
22use mockall::automock;
23
24use std::process::Output;
25
26use tokio::process::Command as TokioCommand;
27
28use crate::config::ClaudeConfig;
29use crate::conversation::Conversation;
30use crate::error::ClaudeError;
31use crate::types::{ClaudeResponse, strip_ansi};
32
33#[cfg(feature = "stream")]
34use crate::stream::{StreamEvent, parse_stream};
35#[cfg(feature = "stream")]
36use std::pin::Pin;
37#[cfg(feature = "stream")]
38use tokio::io::BufReader;
39#[cfg(feature = "stream")]
40use tokio_stream::Stream;
41
/// Abstraction over "run the CLI with these arguments and collect its output".
///
/// `ClaudeClient` executes the CLI through this trait so that process launching can be
/// swapped out; in test builds `automock` generates `MockCommandRunner` from it.
// `async fn` in a public trait triggers the `async_fn_in_trait` lint because callers cannot
// name or bound (e.g. require `Send` on) the returned future. That limitation is accepted
// here deliberately, hence the `allow`.
#[allow(async_fn_in_trait)]
#[cfg_attr(test, automock)]
pub trait CommandRunner: Send + Sync {
    /// Runs the CLI with `args` and returns its exit status together with the captured
    /// stdout and stderr bytes.
    ///
    /// # Errors
    ///
    /// Implementations return an I/O error when the command cannot be executed.
    async fn run(&self, args: &[String]) -> std::io::Result<Output>;
}
49
/// Production runner: launches the CLI executable named by `cli_path` as a child process.
#[derive(Debug, Clone)]
pub struct DefaultRunner {
    /// Program name or path handed to the process builder.
    cli_path: String,
}

impl DefaultRunner {
    /// Creates a runner that executes the program at `cli_path`.
    #[must_use]
    pub fn new(cli_path: impl Into<String>) -> Self {
        let cli_path: String = cli_path.into();
        Self { cli_path }
    }
}

impl Default for DefaultRunner {
    /// Uses the bare program name `claude`, resolved the way the OS resolves any
    /// program name without a directory component.
    fn default() -> Self {
        Self::new("claude")
    }
}
73
74impl CommandRunner for DefaultRunner {
75 async fn run(&self, args: &[String]) -> std::io::Result<Output> {
76 TokioCommand::new(&self.cli_path).args(args).output().await
77 }
78}
79
/// Holds the streaming CLI child process and kills it if dropped while still holding it.
///
/// Code that wants to wait for the process normally takes the child out first (leaving
/// `None`); any earlier drop, such as a consumer discarding the stream, triggers a kill.
#[cfg(feature = "stream")]
struct ChildGuard(Option<tokio::process::Child>);

#[cfg(feature = "stream")]
impl Drop for ChildGuard {
    fn drop(&mut self) {
        // `start_kill` only requests termination and does not wait for it. Its error (e.g. the
        // process already exited) is ignored because nothing useful can be done in `drop`.
        if let Some(child) = self.0.as_mut() {
            let _ = child.start_kill();
        }
    }
}
96
97#[derive(Debug, Clone)]
99pub struct ClaudeClient<R: CommandRunner = DefaultRunner> {
100 config: ClaudeConfig,
101 runner: R,
102}
103
104impl ClaudeClient {
105 #[must_use]
107 pub fn new(config: ClaudeConfig) -> Self {
108 let runner = DefaultRunner::new(config.cli_path_or_default());
109 Self { config, runner }
110 }
111}
112
#[cfg(feature = "stream")]
#[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
impl ClaudeClient {
    /// Sends `prompt` to the CLI in streaming mode and returns a stream of parsed events.
    ///
    /// The CLI is spawned with stdin closed and stdout/stderr piped. stdout is decoded into
    /// [`StreamEvent`]s by `parse_stream`. stderr is drained concurrently so that a chatty CLI
    /// cannot fill the pipe buffer and stall the stream.
    ///
    /// Stream items:
    /// - `Ok(event)` for every event parsed from stdout.
    /// - `Err(ClaudeError::Timeout)` (then the stream ends) if `stream_idle_timeout` is set
    ///   and no event arrives within that duration.
    /// - After stdout ends, the process is awaited. A non-zero exit yields
    ///   `Err(ClaudeError::NonZeroExit)` with the captured stderr (`code` is `-1` when the
    ///   status carries no exit code). A failure to wait yields `Err(ClaudeError::Io)`.
    ///
    /// Dropping the stream before the process has been awaited kills the child.
    ///
    /// # Errors
    ///
    /// Returns [`ClaudeError::CliNotFound`] if the CLI binary cannot be found, or
    /// [`ClaudeError::Io`] if spawning fails for any other reason.
    pub async fn ask_stream(
        &self,
        prompt: &str,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, ClaudeError>> + Send>>, ClaudeError>
    {
        let args = self.config.to_stream_args(prompt);

        trace_debug!(args = ?args, "spawning claude CLI stream");

        let mut child = TokioCommand::new(self.config.cli_path_or_default())
            .args(&args)
            .stdin(std::process::Stdio::null())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::piped())
            .spawn()
            .map_err(|e| {
                if e.kind() == std::io::ErrorKind::NotFound {
                    ClaudeError::CliNotFound
                } else {
                    ClaudeError::Io(e)
                }
            })?;

        let stdout = child.stdout.take().expect("stdout must be piped");
        // Taken out of the child now so it can be drained concurrently with stdout.
        let stderr = child.stderr.take();
        let reader = BufReader::new(stdout);
        let event_stream = parse_stream(reader);
        let mut guard = ChildGuard(Some(child));
        let idle_timeout = self.config.stream_idle_timeout;

        Ok(Box::pin(async_stream::stream! {
            // If stderr were only read after exit, a CLI writing more than the OS pipe buffer
            // to it would block forever. stdout would then stop and the stream would hang (or
            // report a misleading idle timeout) instead of surfacing the real error.
            let stderr_drain = async move {
                let mut buf = Vec::new();
                if let Some(mut pipe) = stderr {
                    let _ = tokio::io::AsyncReadExt::read_to_end(&mut pipe, &mut buf).await;
                }
                buf
            };
            tokio::pin!(stderr_drain);
            // `Some` once the drain has completed; a finished future must not be polled again.
            let mut stderr_buf: Option<Vec<u8>> = None;

            tokio::pin!(event_stream);

            loop {
                // Make progress on the stderr drain every time we wait for the next event.
                // Only the event stream decides when this future completes.
                let next = std::future::poll_fn(|cx| {
                    if stderr_buf.is_none() {
                        if let std::task::Poll::Ready(buf) =
                            std::future::Future::poll(stderr_drain.as_mut(), cx)
                        {
                            stderr_buf = Some(buf);
                        }
                    }
                    tokio_stream::Stream::poll_next(event_stream.as_mut(), cx)
                });

                let maybe_event = match idle_timeout {
                    Some(timeout_dur) => match tokio::time::timeout(timeout_dur, next).await {
                        Ok(event) => event,
                        Err(_) => {
                            trace_error!("stream idle timeout");
                            yield Err(ClaudeError::Timeout);
                            return;
                        }
                    },
                    None => next.await,
                };

                match maybe_event {
                    Some(event) => yield Ok(event),
                    None => break,
                }
            }

            // stdout is exhausted. Collect the remaining stderr (it reaches EOF when the
            // process exits) while the guard still owns the child, so a consumer dropping
            // the stream meanwhile still kills it.
            let stderr_bytes = match stderr_buf.take() {
                Some(buf) => buf,
                None => stderr_drain.as_mut().await,
            };

            if let Some(mut child) = guard.0.take() {
                match child.wait().await {
                    Ok(status) if !status.success() => {
                        let code = status.code().unwrap_or(-1);
                        let stderr_str = String::from_utf8_lossy(&stderr_bytes).into_owned();
                        yield Err(ClaudeError::NonZeroExit { code, stderr: stderr_str });
                    }
                    Err(e) => {
                        yield Err(ClaudeError::Io(e));
                    }
                    Ok(_) => {}
                }
            }
        }))
    }
}
207
208impl<R: CommandRunner> ClaudeClient<R> {
209 #[must_use]
211 pub fn with_runner(config: ClaudeConfig, runner: R) -> Self {
212 Self { config, runner }
213 }
214
215 pub async fn ask_structured<T: serde::de::DeserializeOwned>(
221 &self,
222 prompt: &str,
223 ) -> Result<T, ClaudeError> {
224 let response = self.ask(prompt).await?;
225 response.parse_result()
226 }
227
228 pub async fn ask(&self, prompt: &str) -> Result<ClaudeResponse, ClaudeError> {
230 let args = self.config.to_args(prompt);
231
232 trace_debug!(args = ?args, "executing claude CLI");
233
234 let io_result: std::io::Result<Output> = if let Some(timeout) = self.config.timeout {
235 tokio::time::timeout(timeout, self.runner.run(&args))
236 .await
237 .map_err(|_| {
238 let err = ClaudeError::Timeout;
239 trace_error!(error = %err, "claude CLI failed");
240 err
241 })?
242 } else {
243 self.runner.run(&args).await
244 };
245
246 let output = io_result.map_err(|e| {
247 let err = if e.kind() == std::io::ErrorKind::NotFound {
248 ClaudeError::CliNotFound
249 } else {
250 ClaudeError::Io(e)
251 };
252 trace_error!(error = %err, "claude CLI failed");
253 err
254 })?;
255
256 if !output.status.success() {
257 let code = output.status.code().unwrap_or(-1);
258 let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
259 let err = ClaudeError::NonZeroExit { code, stderr };
260 trace_error!(error = %err, "claude CLI failed");
261 return Err(err);
262 }
263
264 let stdout = String::from_utf8_lossy(&output.stdout);
265 let json_str = strip_ansi(&stdout);
266 let response: ClaudeResponse = serde_json::from_str(json_str).map_err(|e| {
267 let err = ClaudeError::ParseError(e);
268 trace_error!(error = %err, "claude CLI failed");
269 err
270 })?;
271
272 trace_info!("claude CLI returned successfully");
273 Ok(response)
274 }
275}
276
277impl<R: CommandRunner + Clone> ClaudeClient<R> {
278 #[must_use]
286 pub fn conversation(&self) -> Conversation<R> {
287 Conversation::with_runner(self.config.clone(), self.runner.clone())
288 }
289
290 #[must_use]
295 pub fn conversation_resume(&self, session_id: impl Into<String>) -> Conversation<R> {
296 Conversation::with_runner_resume(self.config.clone(), self.runner.clone(), session_id)
297 }
298}
299
300pub async fn check_cli() -> Result<String, ClaudeError> {
311 check_cli_with_path("claude").await
312}
313
314pub async fn check_cli_with_path(cli_path: &str) -> Result<String, ClaudeError> {
324 let output = TokioCommand::new(cli_path)
325 .arg("--version")
326 .output()
327 .await
328 .map_err(|e| {
329 if e.kind() == std::io::ErrorKind::NotFound {
330 ClaudeError::CliNotFound
331 } else {
332 ClaudeError::Io(e)
333 }
334 })?;
335
336 if !output.status.success() {
337 let code = output.status.code().unwrap_or(-1);
338 let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
339 return Err(ClaudeError::NonZeroExit { code, stderr });
340 }
341
342 let version = String::from_utf8_lossy(&output.stdout).trim().to_string();
343 Ok(version)
344}
345
/// Extracts a `(major, minor, patch)` triple from CLI `--version` output.
///
/// Whitespace-separated tokens are scanned from the end and the last token of the exact form
/// `MAJOR.MINOR.PATCH` (optionally prefixed with `v`) wins. This accepts a bare version
/// (`"2.1.92"`), a prefixed one (`"claude-code 2.1.92"`) and a suffixed one
/// (`"2.1.92 (Claude Code)"`). Tokens with fewer or more than three numeric components are
/// rejected.
fn parse_version(version: &str) -> Option<(u64, u64, u64)> {
    version
        .split_whitespace()
        .rev()
        .find_map(parse_version_token)
}

/// Parses one `MAJOR.MINOR.PATCH` token (optional leading `v`), or returns `None`.
fn parse_version_token(token: &str) -> Option<(u64, u64, u64)> {
    let token = token.strip_prefix('v').unwrap_or(token);
    // `splitn(3, ..)` leaves any extra components in the last piece ("92.1"), which then
    // fails to parse, so four-component versions are rejected.
    let mut parts = token.splitn(3, '.');
    let major = parts.next()?.parse().ok()?;
    let minor = parts.next()?.parse().ok()?;
    let patch = parts.next()?.parse().ok()?;
    Some((major, minor, patch))
}

/// How the installed CLI version compares to the version this crate was tested against.
///
/// Every variant carries the raw version string reported by the CLI.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum CliVersionStatus {
    /// The installed version equals the tested version.
    Exact(String),
    /// The installed version is newer than the tested version.
    Newer(String),
    /// The installed version is older than the tested version.
    Older(String),
    /// The installed version string could not be parsed.
    Unknown(String),
}

/// Classifies `installed` relative to `tested` by lexicographic `(major, minor, patch)` order.
///
/// An unparseable `tested` is treated as `0.0.0`.
fn compare_version(installed: &str, tested: &str) -> CliVersionStatus {
    let baseline = parse_version(tested).unwrap_or((0, 0, 0));
    let label = installed.to_string();
    match parse_version(installed).map(|found| found.cmp(&baseline)) {
        None => CliVersionStatus::Unknown(label),
        Some(std::cmp::Ordering::Equal) => CliVersionStatus::Exact(label),
        Some(std::cmp::Ordering::Greater) => CliVersionStatus::Newer(label),
        Some(std::cmp::Ordering::Less) => CliVersionStatus::Older(label),
    }
}
387
388pub async fn check_cli_version() -> Result<CliVersionStatus, ClaudeError> {
399 check_cli_version_with_path("claude").await
400}
401
402pub async fn check_cli_version_with_path(cli_path: &str) -> Result<CliVersionStatus, ClaudeError> {
412 let version = check_cli_with_path(cli_path).await?;
413 Ok(compare_version(&version, crate::TESTED_CLI_VERSION))
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use std::process::ExitStatus;

    /// Builds an `ExitStatus` whose `code()` is `Some(code)`.
    ///
    /// Raw exit statuses are platform-specific; this helper keeps the suite compiling on
    /// both Unix and Windows instead of hard-wiring the Unix-only `ExitStatusExt`.
    #[cfg(unix)]
    fn exit_status(code: i32) -> ExitStatus {
        use std::os::unix::process::ExitStatusExt;
        // A raw Unix wait status keeps the exit code in bits 8..16 (so code 1 => 256).
        ExitStatus::from_raw(code << 8)
    }

    #[cfg(windows)]
    fn exit_status(code: i32) -> ExitStatus {
        use std::os::windows::process::ExitStatusExt;
        ExitStatus::from_raw(u32::try_from(code).expect("test exit codes are non-negative"))
    }

    fn success_output() -> Output {
        Output {
            status: exit_status(0),
            stdout: include_bytes!("../tests/fixtures/success.json").to_vec(),
            stderr: Vec::new(),
        }
    }

    fn non_zero_output() -> Output {
        Output {
            status: exit_status(1),
            stdout: Vec::new(),
            stderr: b"something went wrong".to_vec(),
        }
    }

    #[tokio::test]
    async fn ask_success() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(success_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let resp = client.ask("hello").await.unwrap();
        assert_eq!(resp.result, "Hello!");
        assert!(!resp.is_error);
    }

    #[tokio::test]
    async fn ask_cli_not_found() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| {
            Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "not found",
            ))
        });

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::CliNotFound));
    }

    #[tokio::test]
    async fn ask_non_zero_exit() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(non_zero_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::NonZeroExit { code: 1, .. }));
    }

    #[tokio::test]
    async fn ask_parse_error() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| {
            Ok(Output {
                status: exit_status(0),
                stdout: b"not json".to_vec(),
                stderr: Vec::new(),
            })
        });

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::ParseError(_)));
    }

    /// Runner that takes far longer than any test timeout before succeeding.
    struct SlowRunner;

    impl CommandRunner for SlowRunner {
        async fn run(&self, _args: &[String]) -> std::io::Result<Output> {
            tokio::time::sleep(std::time::Duration::from_secs(10)).await;
            Ok(Output {
                status: exit_status(0),
                stdout: Vec::new(),
                stderr: Vec::new(),
            })
        }
    }

    #[tokio::test(start_paused = true)]
    async fn ask_timeout() {
        let config = ClaudeConfig::builder()
            .timeout(std::time::Duration::from_millis(10))
            .build();
        let client = ClaudeClient::with_runner(config, SlowRunner);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::Timeout));
    }

    #[tokio::test]
    async fn ask_io_error() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| {
            Err(std::io::Error::new(
                std::io::ErrorKind::PermissionDenied,
                "denied",
            ))
        });

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::Io(_)));
    }

    #[tokio::test]
    async fn ask_with_ansi_escape() {
        let json = include_str!("../tests/fixtures/success.json");
        let stdout = format!("\x1b[?1004l{json}\x1b[?1004l");

        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(move |_| {
            Ok(Output {
                status: exit_status(0),
                stdout: stdout.clone().into_bytes(),
                stderr: Vec::new(),
            })
        });

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let resp = client.ask("hello").await.unwrap();
        assert_eq!(resp.result, "Hello!");
    }

    #[tokio::test]
    async fn ask_passes_correct_args() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run()
            .withf(|args| {
                args.contains(&"--print".to_string())
                    && args.contains(&"--model".to_string())
                    && args.contains(&"haiku".to_string())
                    && args.last() == Some(&"test prompt".to_string())
            })
            .returning(|_| Ok(success_output()));

        let config = ClaudeConfig::builder().model("haiku").build();
        let client = ClaudeClient::with_runner(config, mock);
        client.ask("test prompt").await.unwrap();
    }

    #[derive(Debug, serde::Deserialize, PartialEq)]
    struct TestAnswer {
        value: i32,
    }

    fn structured_success_output() -> Output {
        Output {
            status: exit_status(0),
            stdout: include_bytes!("../tests/fixtures/structured_success.json").to_vec(),
            stderr: Vec::new(),
        }
    }

    #[tokio::test]
    async fn ask_structured_success() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run()
            .returning(|_| Ok(structured_success_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let answer: TestAnswer = client.ask_structured("What is 6*7?").await.unwrap();
        assert_eq!(answer, TestAnswer { value: 42 });
    }

    #[tokio::test]
    async fn ask_structured_deserialization_error() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(success_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client
            .ask_structured::<TestAnswer>("hello")
            .await
            .unwrap_err();
        assert!(matches!(err, ClaudeError::StructuredOutputError { .. }));
    }

    #[tokio::test]
    async fn ask_structured_cli_error() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(non_zero_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client
            .ask_structured::<TestAnswer>("hello")
            .await
            .unwrap_err();
        assert!(matches!(err, ClaudeError::NonZeroExit { code: 1, .. }));
    }

    // The CLI path is executed directly (no shell), so shell syntax in it must not run.
    #[tokio::test]
    async fn cli_path_with_shell_metacharacters_is_not_interpreted() {
        let malicious = "claude; echo pwned";
        let err = check_cli_with_path(malicious).await.unwrap_err();
        assert!(matches!(err, ClaudeError::CliNotFound));
    }

    #[tokio::test]
    async fn cli_path_with_command_substitution_is_not_interpreted() {
        let malicious = "$(echo claude)";
        let err = check_cli_with_path(malicious).await.unwrap_err();
        assert!(matches!(err, ClaudeError::CliNotFound));
    }

    #[test]
    fn parse_version_semver() {
        assert_eq!(parse_version("2.1.92"), Some((2, 1, 92)));
    }

    #[test]
    fn parse_version_with_prefix() {
        assert_eq!(parse_version("claude-code 2.1.92"), Some((2, 1, 92)));
    }

    #[test]
    fn parse_version_invalid() {
        assert_eq!(parse_version("not-a-version"), None);
    }

    #[test]
    fn parse_version_empty() {
        assert_eq!(parse_version(""), None);
    }

    #[test]
    fn parse_version_two_components() {
        assert_eq!(parse_version("2.1"), None);
    }

    #[test]
    fn parse_version_four_components() {
        assert_eq!(parse_version("2.1.92.1"), None);
    }

    #[test]
    fn compare_version_exact() {
        let status = compare_version("2.1.92", "2.1.92");
        assert!(matches!(status, CliVersionStatus::Exact(_)));
    }

    #[test]
    fn compare_version_newer() {
        let status = compare_version("2.2.0", "2.1.92");
        assert!(matches!(status, CliVersionStatus::Newer(_)));
    }

    #[test]
    fn compare_version_older() {
        let status = compare_version("2.0.0", "2.1.92");
        assert!(matches!(status, CliVersionStatus::Older(_)));
    }

    #[test]
    fn compare_version_major_newer() {
        let status = compare_version("3.0.0", "2.1.92");
        assert!(matches!(status, CliVersionStatus::Newer(_)));
    }

    #[test]
    fn compare_version_major_older() {
        let status = compare_version("1.9.99", "2.1.92");
        assert!(matches!(status, CliVersionStatus::Older(_)));
    }

    #[test]
    fn compare_version_unparseable() {
        let status = compare_version("garbage", "2.1.92");
        assert!(matches!(status, CliVersionStatus::Unknown(_)));
    }

    #[test]
    fn compare_version_with_prefix() {
        let status = compare_version("claude-code 2.1.92", "2.1.92");
        assert!(matches!(status, CliVersionStatus::Exact(_)));
    }

    #[test]
    fn cli_version_status_preserves_version_string() {
        let status = compare_version("2.2.0", "2.1.92");
        match status {
            CliVersionStatus::Newer(v) => assert_eq!(v, "2.2.0"),
            other => panic!("expected Newer, got {other:?}"),
        }
    }
}