Skip to main content

claude_code/
client.rs

/// Emits a `tracing::debug!` event, but only when the `tracing` feature is on.
///
/// Without the feature the `cfg` attribute strips the whole statement, so the
/// arguments are never evaluated and `tracing` does not need to be a dependency.
macro_rules! trace_debug {
    { $($tokens:tt)* } => {
        #[cfg(feature = "tracing")]
        tracing::debug!($($tokens)*);
    };
}
/// Emits a `tracing::error!` event when the `tracing` feature is enabled;
/// otherwise expands to nothing (arguments are not evaluated).
macro_rules! trace_error {
    { $($tokens:tt)* } => {
        #[cfg(feature = "tracing")]
        tracing::error!($($tokens)*);
    };
}
/// Emits a `tracing::info!` event when the `tracing` feature is enabled;
/// otherwise expands to nothing (arguments are not evaluated).
macro_rules! trace_info {
    { $($tokens:tt)* } => {
        #[cfg(feature = "tracing")]
        tracing::info!($($tokens)*);
    };
}
20
21#[cfg(test)]
22use mockall::automock;
23
24use std::process::Output;
25
26use tokio::process::Command as TokioCommand;
27
28use crate::config::ClaudeConfig;
29use crate::conversation::Conversation;
30use crate::error::ClaudeError;
31use crate::types::{ClaudeResponse, strip_ansi};
32
33#[cfg(feature = "stream")]
34use crate::stream::{StreamEvent, parse_stream};
35#[cfg(feature = "stream")]
36use std::pin::Pin;
37#[cfg(feature = "stream")]
38use tokio::io::BufReader;
39#[cfg(feature = "stream")]
40use tokio_stream::Stream;
41
/// Trait abstracting CLI execution. Mockable in tests.
///
/// [`ClaudeClient::ask`] hands the fully built argument list to
/// [`CommandRunner::run`]; [`DefaultRunner`] is the implementation used by
/// `ClaudeClient::new`, and tests substitute the `automock`-generated mock.
// `async fn` in a public trait triggers this lint because callers cannot
// require the returned future to be `Send`; the allowance accepts that.
#[allow(async_fn_in_trait)]
#[cfg_attr(test, automock)]
pub trait CommandRunner: Send + Sync {
    /// Runs the `claude` command with the given arguments.
    ///
    /// Returns the collected [`Output`] (exit status, stdout, stderr).
    /// `ClaudeClient::ask` maps an `Err` of kind `NotFound` to
    /// `ClaudeError::CliNotFound` and any other `Err` to `ClaudeError::Io`.
    async fn run(&self, args: &[String]) -> std::io::Result<Output>;
}
49
/// Production [`CommandRunner`] that launches the CLI through
/// `tokio::process::Command`.
#[derive(Debug, Clone)]
pub struct DefaultRunner {
    /// Program name or filesystem path of the CLI binary to execute.
    cli_path: String,
}

impl DefaultRunner {
    /// Builds a runner that executes the binary named by `cli_path`.
    #[must_use]
    pub fn new(cli_path: impl Into<String>) -> Self {
        let cli_path = cli_path.into();
        Self { cli_path }
    }
}

impl Default for DefaultRunner {
    /// Uses the bare program name `claude`, looked up on `PATH` when spawned.
    fn default() -> Self {
        Self::new("claude")
    }
}
73
74impl CommandRunner for DefaultRunner {
75    async fn run(&self, args: &[String]) -> std::io::Result<Output> {
76        TokioCommand::new(&self.cli_path).args(args).output().await
77    }
78}
79
/// Owns the streaming CLI child process and kills it when dropped.
///
/// Dropping a tokio `Child` detaches the process instead of killing it, so
/// without this guard an abandoned stream (e.g. a client disconnection) would
/// leave the CLI running. Holding `None` disarms the guard.
#[cfg(feature = "stream")]
struct ChildGuard(Option<tokio::process::Child>);

#[cfg(feature = "stream")]
impl Drop for ChildGuard {
    fn drop(&mut self) {
        let Some(child) = self.0.as_mut() else {
            return;
        };
        // Best effort: this only requests termination and does not wait for it.
        // Failures (e.g. the process already exited) are intentionally ignored.
        let _ = child.start_kill();
    }
}
96
/// Claude Code CLI client.
///
/// Pairs a [`ClaudeConfig`] (which builds the CLI arguments and supplies the
/// timeout and CLI path) with the [`CommandRunner`] that executes the CLI.
/// `R` defaults to [`DefaultRunner`]; tests inject a mock through
/// [`ClaudeClient::with_runner`].
#[derive(Debug, Clone)]
pub struct ClaudeClient<R: CommandRunner = DefaultRunner> {
    // Settings used to build CLI arguments (`to_args` / `to_stream_args`),
    // plus the timeouts and CLI path.
    config: ClaudeConfig,
    // Executes the CLI for `ask`; `ask_stream` spawns the process directly instead.
    runner: R,
}
103
104impl ClaudeClient {
105    /// Creates a new client with the default [`DefaultRunner`].
106    #[must_use]
107    pub fn new(config: ClaudeConfig) -> Self {
108        let runner = DefaultRunner::new(config.cli_path_or_default());
109        Self { config, runner }
110    }
111}
112
#[cfg(feature = "stream")]
#[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
impl ClaudeClient {
    /// Sends a prompt and returns a stream of events.
    ///
    /// Spawns the CLI with `--output-format stream-json` and streams events
    /// in real-time. The stream ends with a [`StreamEvent::Result`] on success.
    ///
    /// For real-time token-level streaming, enable
    /// [`crate::ClaudeConfigBuilder::include_partial_messages`]. This produces
    /// [`StreamEvent::Text`] / [`StreamEvent::Thinking`] delta chunks.
    /// Without it, only complete [`StreamEvent::AssistantText`] /
    /// [`StreamEvent::AssistantThinking`] messages are emitted.
    ///
    /// Use [`crate::ClaudeConfigBuilder::stream_idle_timeout`] to set an idle timeout.
    /// If no event arrives within the specified duration, the stream yields
    /// [`ClaudeError::Timeout`] and terminates.
    ///
    /// The child's stdin is connected to the null device, which matches what
    /// `output()` does for the default runner used by [`ClaudeClient::ask`]. As a
    /// result, the CLI never reads from or blocks on this process's stdin.
    /// Dropping the stream at any point kills the child process.
    ///
    /// # Errors
    ///
    /// Returns [`ClaudeError::CliNotFound`] if the binary cannot be found and
    /// [`ClaudeError::Io`] for other spawn failures. Failures after spawning are
    /// yielded as stream items instead:
    /// - the idle timeout;
    /// - a non-zero exit ([`ClaudeError::NonZeroExit`] with the captured stderr);
    /// - an I/O error while waiting on the child.
    pub async fn ask_stream(
        &self,
        prompt: &str,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent, ClaudeError>> + Send>>, ClaudeError>
    {
        let args = self.config.to_stream_args(prompt);

        trace_debug!(args = ?args, "spawning claude CLI stream");

        let mut child = TokioCommand::new(self.config.cli_path_or_default())
            .args(&args)
            // `spawn()` inherits the parent's stdin by default (unlike `output()`),
            // which would let the CLI consume or wait on it.
            .stdin(std::process::Stdio::null())
            .stdout(std::process::Stdio::piped())
            .stderr(std::process::Stdio::piped())
            .spawn()
            .map_err(|e| {
                if e.kind() == std::io::ErrorKind::NotFound {
                    ClaudeError::CliNotFound
                } else {
                    ClaudeError::Io(e)
                }
            })?;

        let stdout = child.stdout.take().expect("stdout must be piped");
        let reader = BufReader::new(stdout);
        let event_stream = parse_stream(reader);
        let mut guard = ChildGuard(Some(child));
        let idle_timeout = self.config.stream_idle_timeout;

        Ok(Box::pin(async_stream::stream! {
            tokio::pin!(event_stream);

            loop {
                let next = tokio_stream::StreamExt::next(&mut event_stream);
                let maybe_event = if let Some(timeout_dur) = idle_timeout {
                    match tokio::time::timeout(timeout_dur, next).await {
                        Ok(Some(event)) => Some(event),
                        Ok(None) => None,
                        Err(_) => {
                            trace_error!("stream idle timeout");
                            yield Err(ClaudeError::Timeout);
                            // Returning drops `guard`, which kills the child.
                            return;
                        }
                    }
                } else {
                    next.await
                };

                match maybe_event {
                    Some(event) => yield Ok(event),
                    None => break,
                }
            }

            // The child stays inside the guard during the remaining awaits, so
            // dropping the stream here still kills it. Once `wait()` has
            // succeeded, the guard's `start_kill` fails harmlessly (already exited).
            if let Some(child) = guard.0.as_mut() {
                // Drain stderr *before* waiting. If the child blocks writing to a
                // full stderr pipe, it cannot exit, and `wait()` would hang forever.
                let mut stderr_buf = Vec::new();
                if let Some(mut stderr) = child.stderr.take() {
                    // Best effort: stderr is only used for the error message.
                    let _ = tokio::io::AsyncReadExt::read_to_end(&mut stderr, &mut stderr_buf).await;
                }
                match child.wait().await {
                    Ok(s) if !s.success() => {
                        let code = s.code().unwrap_or(-1);
                        let stderr_str = String::from_utf8_lossy(&stderr_buf).into_owned();
                        yield Err(ClaudeError::NonZeroExit { code, stderr: stderr_str });
                    }
                    Err(e) => {
                        yield Err(ClaudeError::Io(e));
                    }
                    Ok(_) => {}
                }
            }
        }))
    }
}
206
207impl<R: CommandRunner> ClaudeClient<R> {
208    /// Creates a new client with a custom [`CommandRunner`] for testing.
209    #[must_use]
210    pub fn with_runner(config: ClaudeConfig, runner: R) -> Self {
211        Self { config, runner }
212    }
213
214    /// Sends a prompt and deserializes the result into `T`.
215    ///
216    /// Requires `json_schema` to be set on the config beforehand.
217    /// Use [`generate_schema`](crate::generate_schema) to auto-generate it
218    /// (requires the `structured` feature).
219    pub async fn ask_structured<T: serde::de::DeserializeOwned>(
220        &self,
221        prompt: &str,
222    ) -> Result<T, ClaudeError> {
223        let response = self.ask(prompt).await?;
224        response.parse_result()
225    }
226
227    /// Sends a prompt and returns the response.
228    pub async fn ask(&self, prompt: &str) -> Result<ClaudeResponse, ClaudeError> {
229        let args = self.config.to_args(prompt);
230
231        trace_debug!(args = ?args, "executing claude CLI");
232
233        let io_result: std::io::Result<Output> = if let Some(timeout) = self.config.timeout {
234            tokio::time::timeout(timeout, self.runner.run(&args))
235                .await
236                .map_err(|_| {
237                    let err = ClaudeError::Timeout;
238                    trace_error!(error = %err, "claude CLI failed");
239                    err
240                })?
241        } else {
242            self.runner.run(&args).await
243        };
244
245        let output = io_result.map_err(|e| {
246            let err = if e.kind() == std::io::ErrorKind::NotFound {
247                ClaudeError::CliNotFound
248            } else {
249                ClaudeError::Io(e)
250            };
251            trace_error!(error = %err, "claude CLI failed");
252            err
253        })?;
254
255        if !output.status.success() {
256            let code = output.status.code().unwrap_or(-1);
257            let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
258            let err = ClaudeError::NonZeroExit { code, stderr };
259            trace_error!(error = %err, "claude CLI failed");
260            return Err(err);
261        }
262
263        let stdout = String::from_utf8_lossy(&output.stdout);
264        let json_str = strip_ansi(&stdout);
265        let response: ClaudeResponse = serde_json::from_str(json_str).map_err(|e| {
266            let err = ClaudeError::ParseError(e);
267            trace_error!(error = %err, "claude CLI failed");
268            err
269        })?;
270
271        trace_info!("claude CLI returned successfully");
272        Ok(response)
273    }
274}
275
276impl<R: CommandRunner + Clone> ClaudeClient<R> {
277    /// Creates a new [`Conversation`] for multi-turn interaction.
278    ///
279    /// The conversation manages `session_id` automatically, injecting
280    /// `--resume` from the second turn onwards.
281    ///
282    /// Callers must set [`crate::ClaudeConfigBuilder::no_session_persistence`]`(false)`
283    /// for multi-turn to work.
284    #[must_use]
285    pub fn conversation(&self) -> Conversation<R> {
286        Conversation::with_runner(self.config.clone(), self.runner.clone())
287    }
288
289    /// Creates a [`Conversation`] that resumes an existing session.
290    ///
291    /// The first `ask()` / `ask_stream()` call will include `--resume`
292    /// with the given session ID.
293    #[must_use]
294    pub fn conversation_resume(&self, session_id: impl Into<String>) -> Conversation<R> {
295        Conversation::with_runner_resume(self.config.clone(), self.runner.clone(), session_id)
296    }
297}
298
299/// Checks that the `claude` CLI is available and returns its version string.
300///
301/// Runs `claude --version` and returns the trimmed stdout on success.
302/// To check a binary at a custom path, use [`check_cli_with_path`].
303///
304/// # Errors
305///
306/// - [`ClaudeError::CliNotFound`] if `claude` is not in PATH.
307/// - [`ClaudeError::NonZeroExit`] if the command fails.
308/// - [`ClaudeError::Io`] for other I/O errors.
309pub async fn check_cli() -> Result<String, ClaudeError> {
310    check_cli_with_path("claude").await
311}
312
313/// Checks that the CLI at the given path is available and returns its version string.
314///
315/// Runs `<cli_path> --version` and returns the trimmed stdout on success.
316///
317/// # Errors
318///
319/// - [`ClaudeError::CliNotFound`] if the binary is not found.
320/// - [`ClaudeError::NonZeroExit`] if the command fails.
321/// - [`ClaudeError::Io`] for other I/O errors.
322pub async fn check_cli_with_path(cli_path: &str) -> Result<String, ClaudeError> {
323    let output = TokioCommand::new(cli_path)
324        .arg("--version")
325        .output()
326        .await
327        .map_err(|e| {
328            if e.kind() == std::io::ErrorKind::NotFound {
329                ClaudeError::CliNotFound
330            } else {
331                ClaudeError::Io(e)
332            }
333        })?;
334
335    if !output.status.success() {
336        let code = output.status.code().unwrap_or(-1);
337        let stderr = String::from_utf8_lossy(&output.stderr).into_owned();
338        return Err(ClaudeError::NonZeroExit { code, stderr });
339    }
340
341    let version = String::from_utf8_lossy(&output.stdout).trim().to_string();
342    Ok(version)
343}
344
#[cfg(test)]
mod tests {
    // Unit tests for `ClaudeClient::ask` / `ask_structured` driven by a mocked
    // `CommandRunner`, plus `check_cli_with_path` argument-safety checks.
    use super::*;
    // Unix-only: `ExitStatusExt::from_raw` builds synthetic wait statuses.
    use std::os::unix::process::ExitStatusExt;
    use std::process::ExitStatus;

    // Successful run whose stdout is the `success.json` fixture.
    fn success_output() -> Output {
        Output {
            status: ExitStatus::from_raw(0),
            stdout: include_bytes!("../tests/fixtures/success.json").to_vec(),
            stderr: Vec::new(),
        }
    }

    // Failed run: exit code 1 with a stderr message and no stdout.
    fn non_zero_output() -> Output {
        Output {
            status: ExitStatus::from_raw(256), // exit code 1
            stdout: Vec::new(),
            stderr: b"something went wrong".to_vec(),
        }
    }

    #[tokio::test]
    async fn ask_success() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(success_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let resp = client.ask("hello").await.unwrap();
        assert_eq!(resp.result, "Hello!");
        assert!(!resp.is_error);
    }

    // A `NotFound` I/O error from the runner must surface as `CliNotFound`.
    #[tokio::test]
    async fn ask_cli_not_found() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| {
            Err(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                "not found",
            ))
        });

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::CliNotFound));
    }

    #[tokio::test]
    async fn ask_non_zero_exit() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(non_zero_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::NonZeroExit { code: 1, .. }));
    }

    #[tokio::test]
    async fn ask_parse_error() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| {
            Ok(Output {
                status: ExitStatus::from_raw(0),
                stdout: b"not json".to_vec(),
                stderr: Vec::new(),
            })
        });

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::ParseError(_)));
    }

    /// Custom CommandRunner that always sleeps (for timeout tests).
    struct SlowRunner;

    impl CommandRunner for SlowRunner {
        async fn run(&self, _args: &[String]) -> std::io::Result<Output> {
            tokio::time::sleep(std::time::Duration::from_secs(10)).await;
            Ok(Output {
                status: std::os::unix::process::ExitStatusExt::from_raw(0),
                stdout: Vec::new(),
                stderr: Vec::new(),
            })
        }
    }

    // With the clock paused, tokio auto-advances time, so the 10 ms timeout
    // fires without actually waiting for SlowRunner's 10 s sleep.
    #[tokio::test(start_paused = true)]
    async fn ask_timeout() {
        let config = ClaudeConfig::builder()
            .timeout(std::time::Duration::from_millis(10))
            .build();
        let client = ClaudeClient::with_runner(config, SlowRunner);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::Timeout));
    }

    // Any I/O error other than `NotFound` is passed through as `Io`.
    #[tokio::test]
    async fn ask_io_error() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| {
            Err(std::io::Error::new(
                std::io::ErrorKind::PermissionDenied,
                "denied",
            ))
        });

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client.ask("hello").await.unwrap_err();
        assert!(matches!(err, ClaudeError::Io(_)));
    }

    // The JSON is wrapped in terminal escape sequences; `ask` must strip them
    // before parsing.
    #[tokio::test]
    async fn ask_with_ansi_escape() {
        let json = include_str!("../tests/fixtures/success.json");
        let stdout = format!("\x1b[?1004l{json}\x1b[?1004l");

        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(move |_| {
            Ok(Output {
                status: ExitStatus::from_raw(0),
                stdout: stdout.clone().into_bytes(),
                stderr: Vec::new(),
            })
        });

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let resp = client.ask("hello").await.unwrap();
        assert_eq!(resp.result, "Hello!");
    }

    // The mock only matches when the built args carry `--print`, the model
    // flag/value and end with the prompt; otherwise mockall panics.
    #[tokio::test]
    async fn ask_passes_correct_args() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run()
            .withf(|args| {
                args.contains(&"--print".to_string())
                    && args.contains(&"--model".to_string())
                    && args.contains(&"haiku".to_string())
                    && args.last() == Some(&"test prompt".to_string())
            })
            .returning(|_| Ok(success_output()));

        let config = ClaudeConfig::builder().model("haiku").build();
        let client = ClaudeClient::with_runner(config, mock);
        client.ask("test prompt").await.unwrap();
    }

    // Target type for the structured-output tests.
    #[derive(Debug, serde::Deserialize, PartialEq)]
    struct TestAnswer {
        value: i32,
    }

    // Successful run whose stdout is the `structured_success.json` fixture.
    fn structured_success_output() -> Output {
        Output {
            status: ExitStatus::from_raw(0),
            stdout: include_bytes!("../tests/fixtures/structured_success.json").to_vec(),
            stderr: Vec::new(),
        }
    }

    #[tokio::test]
    async fn ask_structured_success() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run()
            .returning(|_| Ok(structured_success_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let answer: TestAnswer = client.ask_structured("What is 6*7?").await.unwrap();
        assert_eq!(answer, TestAnswer { value: 42 });
    }

    // The plain `success.json` response (result "Hello!") presumably carries
    // no `TestAnswer`-shaped payload, so `parse_result` must fail.
    #[tokio::test]
    async fn ask_structured_deserialization_error() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(success_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client
            .ask_structured::<TestAnswer>("hello")
            .await
            .unwrap_err();
        assert!(matches!(err, ClaudeError::StructuredOutputError { .. }));
    }

    // CLI failures propagate through `ask_structured` unchanged.
    #[tokio::test]
    async fn ask_structured_cli_error() {
        let mut mock = MockCommandRunner::new();
        mock.expect_run().returning(|_| Ok(non_zero_output()));

        let client = ClaudeClient::with_runner(ClaudeConfig::default(), mock);
        let err = client
            .ask_structured::<TestAnswer>("hello")
            .await
            .unwrap_err();
        assert!(matches!(err, ClaudeError::NonZeroExit { code: 1, .. }));
    }

    /// Verifies that shell metacharacters in `cli_path` are not interpreted.
    ///
    /// `Command::new()` uses `execvp` directly (no shell), so a path like
    /// `"claude; echo pwned"` is treated as a literal filename lookup and
    /// fails with `NotFound` — not as a shell command.
    #[tokio::test]
    async fn cli_path_with_shell_metacharacters_is_not_interpreted() {
        let malicious = "claude; echo pwned";
        let err = check_cli_with_path(malicious).await.unwrap_err();
        assert!(matches!(err, ClaudeError::CliNotFound));
    }

    // Same guarantee for `$(...)` command substitution syntax.
    #[tokio::test]
    async fn cli_path_with_command_substitution_is_not_interpreted() {
        let malicious = "$(echo claude)";
        let err = check_cli_with_path(malicious).await.unwrap_err();
        assert!(matches!(err, ClaudeError::CliNotFound));
    }
}