/// Forwards its arguments to `tracing::debug!` when the `tracing` feature is enabled.
///
/// Without the feature the invocation is compiled out entirely, so the arguments are
/// never evaluated.
macro_rules! trace_debug {
    ($($fields:tt)*) => {{
        #[cfg(feature = "tracing")]
        tracing::debug!($($fields)*);
    }};
}
/// Forwards its arguments to `tracing::error!` when the `tracing` feature is enabled.
///
/// Without the feature the invocation is compiled out entirely, so the arguments are
/// never evaluated.
macro_rules! trace_error {
    ($($fields:tt)*) => {{
        #[cfg(feature = "tracing")]
        tracing::error!($($fields)*);
    }};
}
/// Forwards its arguments to `tracing::info!` when the `tracing` feature is enabled.
///
/// Without the feature the invocation is compiled out entirely, so the arguments are
/// never evaluated.
macro_rules! trace_info {
    ($($fields:tt)*) => {{
        #[cfg(feature = "tracing")]
        tracing::info!($($fields)*);
    }};
}
20
21#[cfg(test)]
22use mockall::automock;
23
24use std::process::Output;
25
26use tokio::process::Command as TokioCommand;
27
28use crate::config::ClaudeConfig;
29use crate::conversation::Conversation;
30use crate::error::ClaudeError;
31use crate::types::{ClaudeResponse, strip_ansi};
32
33#[cfg(feature = "stream")]
34use crate::stream::{StreamEvent, parse_stream};
35#[cfg(feature = "stream")]
36use std::pin::Pin;
37#[cfg(feature = "stream")]
38use tokio::io::BufReader;
39#[cfg(feature = "stream")]
40use tokio_stream::Stream;
41
/// Abstraction over how the `claude` CLI is executed.
///
/// [`ClaudeClient`] calls [`CommandRunner::run`] instead of spawning the process itself,
/// so tests can substitute a double (the generated `MockCommandRunner`, or a hand-written
/// runner) for the real binary.
// `async fn` in a public trait triggers `async_fn_in_trait`, because callers cannot add a
// `Send` bound to the returned future. That is accepted deliberately: within this module
// the future is only awaited in place, never spawned onto another task. Implementors who
// need a `Send` future must make sure their own `run` body is `Send`.
#[allow(async_fn_in_trait)]
#[cfg_attr(test, automock)]
pub trait CommandRunner: Send + Sync {
    /// Runs the CLI with `args` (everything after the program name) and returns the
    /// process's exit status together with its captured stdout and stderr.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the process cannot be run; callers map
    /// `ErrorKind::NotFound` to a "CLI not found" error.
    async fn run(&self, args: &[String]) -> std::io::Result<Output>;
}
49
50#[derive(Debug, Clone)]
52pub struct DefaultRunner;
53
54impl CommandRunner for DefaultRunner {
55 async fn run(&self, args: &[String]) -> std::io::Result<Output> {
56 TokioCommand::new("claude").args(args).output().await
57 }
58}
59
/// Holds a spawned `claude` child and requests a kill if it is still held when dropped.
///
/// The streaming code takes the child out (leaving `None`) once it wants to wait for a
/// normal exit. Any other way of dropping the guard — the consumer dropping the stream
/// early, or the idle timeout ending it — triggers `start_kill` instead.
#[cfg(feature = "stream")]
struct ChildGuard(Option<tokio::process::Child>);

#[cfg(feature = "stream")]
impl Drop for ChildGuard {
    fn drop(&mut self) {
        // Best effort only: the process may already have exited, so a failed kill is ignored.
        if let Some(child) = self.0.as_mut() {
            let _ = child.start_kill();
        }
    }
}
76
77#[derive(Debug, Clone)]
79pub struct ClaudeClient<R: CommandRunner = DefaultRunner> {
80 config: ClaudeConfig,
81 runner: R,
82}
83
84impl ClaudeClient {
85 #[must_use]
87 pub fn new(config: ClaudeConfig) -> Self {
88 Self {
89 config,
90 runner: DefaultRunner,
91 }
92 }
93}
94
#[cfg(feature = "stream")]
#[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
impl ClaudeClient {
    /// Sends `prompt` to the `claude` CLI in streaming mode and yields parsed events.
    ///
    /// The CLI is spawned with `ClaudeConfig::to_stream_args`, and its stdout is parsed
    /// incrementally into [`StreamEvent`]s.
    ///
    /// - Stdin is set to null, so the CLI cannot consume the host process's stdin.
    /// - Stderr is drained continuously, so a chatty CLI can never block on a full
    ///   stderr pipe and stall stdout.
    /// - If `stream_idle_timeout` is set and no event arrives within that duration, the
    ///   stream yields [`ClaudeError::Timeout`] and ends, and the child is killed.
    /// - Dropping the stream before it finishes also kills the child.
    ///
    /// # Errors
    ///
    /// Returns [`ClaudeError::CliNotFound`] if the `claude` binary cannot be found, and
    /// [`ClaudeError::Io`] for any other spawn failure. Once the stream is running,
    /// failures are delivered as stream items instead:
    /// - [`ClaudeError::Timeout`] on an idle timeout.
    /// - [`ClaudeError::NonZeroExit`], carrying the captured stderr, when the CLI exits
    ///   unsuccessfully.
    /// - [`ClaudeError::Io`] if waiting on the process fails.
    pub async fn ask_stream(
        &self,
        prompt: &str,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, ClaudeError>> + Send>>, ClaudeError>
    {
        let args = self.config.to_stream_args(prompt);

        trace_debug!(args = ?args, "spawning claude CLI stream");

        let mut child = TokioCommand::new("claude")
            .args(&args)
            // `ask` (via `output()`) never hands the host's stdin to the CLI; match that here.
            .stdin(std::process::Stdio::null())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::piped())
            .spawn()
            .map_err(|e| {
                if e.kind() == std::io::ErrorKind::NotFound {
                    ClaudeError::CliNotFound
                } else {
                    ClaudeError::Io(e)
                }
            })?;

        let stdout = child.stdout.take().expect("stdout must be piped");
        let mut stderr_pipe = child.stderr.take();
        let event_stream = parse_stream(BufReader::new(stdout));
        let mut guard = ChildGuard(Some(child));
        let idle_timeout = self.config.stream_idle_timeout;

        Ok(Box::pin(async_stream::stream! {
            tokio::pin!(event_stream);
            let mut stderr_buf = Vec::new();

            loop {
                // Poll for the next event while draining stderr. Otherwise the child could
                // block writing to a full stderr pipe and never finish its stdout.
                let next = std::future::poll_fn(|cx| {
                    Self::drain_ready_stderr(&mut stderr_pipe, &mut stderr_buf, cx);
                    Stream::poll_next(event_stream.as_mut(), cx)
                });
                let maybe_event = if let Some(timeout_dur) = idle_timeout {
                    match tokio::time::timeout(timeout_dur, next).await {
                        Ok(event) => event,
                        Err(_) => {
                            trace_error!("stream idle timeout");
                            yield Err(ClaudeError::Timeout);
                            // Returning drops `guard`, which kills the child.
                            return;
                        }
                    }
                } else {
                    next.await
                };

                match maybe_event {
                    Some(event) => yield Ok(event),
                    None => break,
                }
            }

            if let Some(mut child) = guard.0.take() {
                // Keep draining stderr while waiting, because the child may still be writing to it.
                let status = {
                    let wait = child.wait();
                    tokio::pin!(wait);
                    std::future::poll_fn(|cx| {
                        Self::drain_ready_stderr(&mut stderr_pipe, &mut stderr_buf, cx);
                        std::future::Future::poll(wait.as_mut(), cx)
                    })
                    .await
                };
                match status {
                    Ok(s) if !s.success() => {
                        let code = s.code().unwrap_or(-1);
                        // Collect whatever stderr is still buffered after exit for the report.
                        if let Some(mut stderr) = stderr_pipe.take() {
                            let _ = tokio::io::AsyncReadExt::read_to_end(&mut stderr, &mut stderr_buf).await;
                        }
                        let stderr_str = String::from_utf8_lossy(&stderr_buf).into_owned();
                        yield Err(ClaudeError::NonZeroExit { code, stderr: stderr_str });
                    }
                    Err(e) => {
                        yield Err(ClaudeError::Io(e));
                    }
                    Ok(_) => {}
                }
            }
        }))
    }

    /// Moves every stderr byte that is available right now into `sink`, without blocking.
    ///
    /// Sets `pipe` to `None` on EOF or on a read error, so later calls become no-ops. When
    /// no data is ready, `poll_read` registers `cx`'s waker, so the task polling the stream
    /// is woken as soon as the child writes more to stderr.
    fn drain_ready_stderr(
        pipe: &mut Option<tokio::process::ChildStderr>,
        sink: &mut Vec<u8>,
        cx: &mut std::task::Context<'_>,
    ) {
        let mut chunk = [0u8; 4096];
        while let Some(stderr) = pipe.as_mut() {
            let mut read_buf = tokio::io::ReadBuf::new(&mut chunk);
            match tokio::io::AsyncRead::poll_read(Pin::new(stderr), cx, &mut read_buf) {
                // A successful read that filled nothing signals EOF.
                std::task::Poll::Ready(Ok(())) if read_buf.filled().is_empty() => *pipe = None,
                std::task::Poll::Ready(Ok(())) => sink.extend_from_slice(read_buf.filled()),
                // Stderr is diagnostic only: stop reading it rather than failing the stream.
                std::task::Poll::Ready(Err(_)) => *pipe = None,
                std::task::Poll::Pending => break,
            }
        }
    }
}
188
189impl<R: CommandRunner> ClaudeClient<R> {
190 #[must_use]
192 pub fn with_runner(config: ClaudeConfig, runner: R) -> Self {
193 Self { config, runner }
194 }
195
196 pub async fn ask_structured<T: serde::de::DeserializeOwned>(
202 &self,
203 prompt: &str,
204 ) -> Result<T, ClaudeError> {
205 let response = self.ask(prompt).await?;
206 response.parse_result()
207 }
208
209 pub async fn ask(&self, prompt: &str) -> Result<ClaudeResponse, ClaudeError> {
211 let args = self.config.to_args(prompt);
212
213 trace_debug!(args = ?args, "executing claude CLI");
214
215 let io_result: std::io::Result<Output> = if let Some(timeout) = self.config.timeout {
216 tokio::time::timeout(timeout, self.runner.run(&args))
217 .await
218 .map_err(|_| {
219 let err = ClaudeError::Timeout;
220 trace_error!(error = %err, "claude CLI failed");
221 err
222 })?
223 } else {
224 self.runner.run(&args).await
225 };
226
227 let output = io_result.map_err(|e| {
228 let err = if e.kind() == std::io::ErrorKind::NotFound {
229 ClaudeError::CliNotFound
230 } else {
231 ClaudeError::Io(e)
232 };
233 trace_error!(error = %err, "claude CLI failed");
234 err
235 })?;
236
237 if !output.status.success() {
238 let code = output.status.code().unwrap_or(-1);
239 let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
240 let err = ClaudeError::NonZeroExit { code, stderr };
241 trace_error!(error = %err, "claude CLI failed");
242 return Err(err);
243 }
244
245 let stdout = String::from_utf8_lossy(&output.stdout);
246 let json_str = strip_ansi(&stdout);
247 let response: ClaudeResponse = serde_json::from_str(json_str).map_err(|e| {
248 let err = ClaudeError::ParseError(e);
249 trace_error!(error = %err, "claude CLI failed");
250 err
251 })?;
252
253 trace_info!("claude CLI returned successfully");
254 Ok(response)
255 }
256}
257
258impl<R: CommandRunner + Clone> ClaudeClient<R> {
259 #[must_use]
267 pub fn conversation(&self) -> Conversation<R> {
268 Conversation::with_runner(self.config.clone(), self.runner.clone())
269 }
270
271 #[must_use]
276 pub fn conversation_resume(&self, session_id: impl Into<String>) -> Conversation<R> {
277 Conversation::with_runner_resume(self.config.clone(), self.runner.clone(), session_id)
278 }
279}
280
281pub async fn check_cli() -> Result<String, ClaudeError> {
291 let output = TokioCommand::new("claude")
292 .arg("--version")
293 .output()
294 .await
295 .map_err(|e| {
296 if e.kind() == std::io::ErrorKind::NotFound {
297 ClaudeError::CliNotFound
298 } else {
299 ClaudeError::Io(e)
300 }
301 })?;
302
303 if !output.status.success() {
304 let code = output.status.code().unwrap_or(-1);
305 let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
306 return Err(ClaudeError::NonZeroExit { code, stderr });
307 }
308
309 let version = String::from_utf8_lossy(&output.stdout).trim().to_string();
310 Ok(version)
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use std::process::ExitStatus;

    /// Builds the `ExitStatus` of a process that exited normally with `code`.
    ///
    /// The raw status encoding is platform-specific, so each OS gets its own helper; this
    /// keeps the suite compiling on Windows as well as Unix.
    #[cfg(unix)]
    fn exit_status(code: i32) -> ExitStatus {
        use std::os::unix::process::ExitStatusExt;
        // A Unix wait status keeps a normal exit code in bits 8..16, so exit code 1 is 256.
        ExitStatus::from_raw(code << 8)
    }

    /// Builds the `ExitStatus` of a process that exited normally with `code`.
    #[cfg(windows)]
    fn exit_status(code: i32) -> ExitStatus {
        use std::os::windows::process::ExitStatusExt;
        ExitStatus::from_raw(u32::try_from(code).expect("test exit codes are non-negative"))
    }

    fn success_output() -> Output {
        Output {
            status: exit_status(0),
            stdout: include_bytes!("../tests/fixtures/success.json").to_vec(),
            stderr: Vec::new(),
        }
    }

    fn non_zero_output() -> Output {
        Output {
            status: exit_status(1),
            stdout: Vec::new(),
            stderr: b"something went wrong".to_vec(),
        }
    }

    #[tokio::test]
    async fn ask_success() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(success_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let resp = client.ask("hello").await.unwrap();
        assert_eq!(resp.result, "Hello!");
        assert!(!resp.is_error);
    }

    #[tokio::test]
    async fn ask_cli_not_found() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| {
            Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "not found",
            ))
        });

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::CliNotFound));
    }

    #[tokio::test]
    async fn ask_non_zero_exit() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(non_zero_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::NonZeroExit { code: 1, .. }));
    }

    #[tokio::test]
    async fn ask_parse_error() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| {
            Ok(Output {
                status: exit_status(0),
                stdout: b"not json".to_vec(),
                stderr: Vec::new(),
            })
        });

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::ParseError(_)));
    }

    /// Runner that takes far longer than any timeout used in these tests.
    struct SlowRunner;

    impl CommandRunner for SlowRunner {
        async fn run(&self, _args: &[String]) -> std::io::Result<Output> {
            tokio::time::sleep(std::time::Duration::from_secs(10)).await;
            Ok(Output {
                status: exit_status(0),
                stdout: Vec::new(),
                stderr: Vec::new(),
            })
        }
    }

    #[tokio::test(start_paused = true)]
    async fn ask_timeout() {
        let config = ClaudeConfig::builder()
            .timeout(std::time::Duration::from_millis(10))
            .build();
        let client = ClaudeClient::with_runner(config, SlowRunner);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::Timeout));
    }

    #[tokio::test]
    async fn ask_completes_within_timeout() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(success_output()));

        let config = ClaudeConfig::builder()
            .timeout(std::time::Duration::from_secs(5))
            .build();
        let client = ClaudeClient::with_runner(config, mock);
        let resp = client.ask("hello").await.unwrap();
        assert_eq!(resp.result, "Hello!");
    }

    #[tokio::test]
    async fn ask_io_error() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| {
            Err(std::io::Error::new(
                std::io::ErrorKind::PermissionDenied,
                "denied",
            ))
        });

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::Io(_)));
    }

    #[tokio::test]
    async fn ask_with_ansi_escape() {
        let json = include_str!("../tests/fixtures/success.json");
        let stdout = format!("\x1b[?1004l{json}\x1b[?1004l");

        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(move |_| {
            Ok(Output {
                status: exit_status(0),
                stdout: stdout.clone().into_bytes(),
                stderr: Vec::new(),
            })
        });

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let resp = client.ask("hello").await.unwrap();
        assert_eq!(resp.result, "Hello!");
    }

    #[tokio::test]
    async fn ask_passes_correct_args() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run()
            .withf(|args| {
                args.contains(&"--print".to_string())
                    && args.contains(&"--model".to_string())
                    && args.contains(&"haiku".to_string())
                    && args.last() == Some(&"test prompt".to_string())
            })
            .returning(|_| Ok(success_output()));

        let config = ClaudeConfig::builder().model("haiku").build();
        let client = ClaudeClient::with_runner(config, mock);
        client.ask("test prompt").await.unwrap();
    }

    #[derive(Debug, serde::Deserialize, PartialEq)]
    struct TestAnswer {
        value: i32,
    }

    fn structured_success_output() -> Output {
        Output {
            status: exit_status(0),
            stdout: include_bytes!("../tests/fixtures/structured_success.json").to_vec(),
            stderr: Vec::new(),
        }
    }

    #[tokio::test]
    async fn ask_structured_success() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run()
            .returning(|_| Ok(structured_success_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let answer: TestAnswer = client.ask_structured("What is 6*7?").await.unwrap();
        assert_eq!(answer, TestAnswer { value: 42 });
    }

    #[tokio::test]
    async fn ask_structured_deserialization_error() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(success_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client
            .ask_structured::<TestAnswer>("hello")
            .await
            .unwrap_err();
        assert!(matches!(err, ClaudeError::StructuredOutputError { .. }));
    }

    #[tokio::test]
    async fn ask_structured_cli_error() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(non_zero_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client
            .ask_structured::<TestAnswer>("hello")
            .await
            .unwrap_err();
        assert!(matches!(err, ClaudeError::NonZeroExit { code: 1, .. }));
    }
}