Skip to main content

claude_codes/
client_async.rs

1//! Asynchronous client for Claude communication
2
3use crate::cli::ClaudeCliBuilder;
4use crate::error::{Error, Result};
5use crate::io::{
6    ClaudeInput, ClaudeOutput, ContentBlock, ControlRequestMessage, ControlResponse,
7    ControlResponseMessage,
8};
9use crate::protocol::Protocol;
10use log::{debug, error, info, warn};
11use serde::{Deserialize, Serialize};
12use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufReader as AsyncBufReader};
13use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout};
14use uuid::Uuid;
15
16/// Asynchronous client for communicating with Claude
17pub struct AsyncClient {
18    child: Child,
19    stdin: ChildStdin,
20    stdout: BufReader<ChildStdout>,
21    stderr: Option<BufReader<ChildStderr>>,
22    session_uuid: Option<Uuid>,
23    /// Whether tool approval protocol has been initialized
24    tool_approval_enabled: bool,
25}
26
/// Capacity, in bytes, of the buffered reader wrapping Claude's stdout
/// (10 MiB = 10 << 20), sized for very large single-line JSON messages.
const STDOUT_BUFFER_SIZE: usize = 10 << 20;
29
30impl AsyncClient {
31    /// Create a new async client from a tokio Child process
32    pub fn new(mut child: Child) -> Result<Self> {
33        let stdin = child
34            .stdin
35            .take()
36            .ok_or_else(|| Error::Io(std::io::Error::other("Failed to get stdin handle")))?;
37
38        let stdout = BufReader::with_capacity(
39            STDOUT_BUFFER_SIZE,
40            child
41                .stdout
42                .take()
43                .ok_or_else(|| Error::Io(std::io::Error::other("Failed to get stdout handle")))?,
44        );
45
46        let stderr = child.stderr.take().map(BufReader::new);
47
48        Ok(Self {
49            child,
50            stdin,
51            stdout,
52            stderr,
53            session_uuid: None,
54            tool_approval_enabled: false,
55        })
56    }
57
58    /// Create a client with default settings (using logic from start_claude)
59    pub async fn with_defaults() -> Result<Self> {
60        // Check Claude version (only warns once per session)
61        // NOTE: The claude-codes API is in high flux. If you wish to work around
62        // this version check, you can use AsyncClient::new() directly with:
63        //   let child = ClaudeCliBuilder::new().model("sonnet").spawn().await?;
64        //   AsyncClient::new(child)
65        crate::version::check_claude_version_async().await?;
66        Self::with_model("sonnet").await
67    }
68
69    /// Create a client with a specific model
70    pub async fn with_model(model: &str) -> Result<Self> {
71        let child = ClaudeCliBuilder::new().model(model).spawn().await?;
72
73        info!("Started Claude process with model: {}", model);
74        Self::new(child)
75    }
76
77    /// Create a client from a custom builder
78    pub async fn from_builder(builder: ClaudeCliBuilder) -> Result<Self> {
79        let child = builder.spawn().await?;
80        info!("Started Claude process from custom builder");
81        Self::new(child)
82    }
83
84    /// Resume a previous session by UUID
85    /// This creates a new client that resumes an existing session
86    pub async fn resume_session(session_uuid: Uuid) -> Result<Self> {
87        let child = ClaudeCliBuilder::new()
88            .resume(Some(session_uuid.to_string()))
89            .spawn()
90            .await?;
91
92        info!("Resuming Claude session with UUID: {}", session_uuid);
93        let mut client = Self::new(child)?;
94        // Pre-populate the session UUID since we're resuming
95        client.session_uuid = Some(session_uuid);
96        Ok(client)
97    }
98
99    /// Resume a previous session with a specific model
100    pub async fn resume_session_with_model(session_uuid: Uuid, model: &str) -> Result<Self> {
101        let child = ClaudeCliBuilder::new()
102            .model(model)
103            .resume(Some(session_uuid.to_string()))
104            .spawn()
105            .await?;
106
107        info!(
108            "Resuming Claude session with UUID: {} and model: {}",
109            session_uuid, model
110        );
111        let mut client = Self::new(child)?;
112        // Pre-populate the session UUID since we're resuming
113        client.session_uuid = Some(session_uuid);
114        Ok(client)
115    }
116
117    /// Send a query and collect all responses until Result message
118    /// This is the simplified version that collects all responses
119    pub async fn query(&mut self, text: &str) -> Result<Vec<ClaudeOutput>> {
120        let session_id = Uuid::new_v4();
121        self.query_with_session(text, session_id).await
122    }
123
124    /// Send a query with a custom session ID and collect all responses
125    pub async fn query_with_session(
126        &mut self,
127        text: &str,
128        session_id: Uuid,
129    ) -> Result<Vec<ClaudeOutput>> {
130        // Send the query
131        let input = ClaudeInput::user_message(text, session_id);
132        self.send(&input).await?;
133
134        // Collect responses until we get a Result message
135        let mut responses = Vec::new();
136
137        loop {
138            let output = self.receive().await?;
139            let is_result = matches!(&output, ClaudeOutput::Result(_));
140            responses.push(output);
141
142            if is_result {
143                break;
144            }
145        }
146
147        Ok(responses)
148    }
149
150    /// Send a query and return an async iterator over responses
151    /// Returns a stream that yields ClaudeOutput until Result message is received
152    pub async fn query_stream(&mut self, text: &str) -> Result<ResponseStream<'_>> {
153        let session_id = Uuid::new_v4();
154        self.query_stream_with_session(text, session_id).await
155    }
156
157    /// Send a query with session ID and return an async iterator over responses
158    pub async fn query_stream_with_session(
159        &mut self,
160        text: &str,
161        session_id: Uuid,
162    ) -> Result<ResponseStream<'_>> {
163        // Send the query first
164        let input = ClaudeInput::user_message(text, session_id);
165        self.send(&input).await?;
166
167        // Return a stream that will read responses
168        Ok(ResponseStream {
169            client: self,
170            finished: false,
171        })
172    }
173
174    /// Send a ClaudeInput directly
175    pub async fn send(&mut self, input: &ClaudeInput) -> Result<()> {
176        let json_line = Protocol::serialize(input)?;
177        debug!("[OUTGOING] Sending JSON to Claude: {}", json_line.trim());
178
179        self.stdin
180            .write_all(json_line.as_bytes())
181            .await
182            .map_err(Error::Io)?;
183
184        self.stdin.flush().await.map_err(Error::Io)?;
185        Ok(())
186    }
187
188    /// Receive a single response from Claude.
189    ///
190    /// # Important: Polling Frequency
191    ///
192    /// This method should be polled frequently to prevent the OS pipe buffer from
193    /// filling up. Claude can emit very large JSON messages (hundreds of KB), and
194    /// if the pipe buffer overflows, data may be truncated.
195    ///
196    /// In a `tokio::select!` loop with other async operations, ensure `receive()`
197    /// is given priority or called frequently. For high-throughput scenarios,
198    /// consider spawning a dedicated task to drain stdout into an unbounded channel.
199    ///
200    /// # Returns
201    ///
202    /// - `Ok(ClaudeOutput)` - A parsed message from Claude
203    /// - `Err(Error::ConnectionClosed)` - Claude process has exited
204    /// - `Err(Error::Deserialization)` - Failed to parse the message
205    pub async fn receive(&mut self) -> Result<ClaudeOutput> {
206        let mut line = String::new();
207
208        loop {
209            line.clear();
210            let bytes_read = self.stdout.read_line(&mut line).await.map_err(Error::Io)?;
211
212            if bytes_read == 0 {
213                return Err(Error::ConnectionClosed);
214            }
215
216            let trimmed = line.trim();
217            if trimmed.is_empty() {
218                continue;
219            }
220
221            debug!("[INCOMING] Received JSON from Claude: {}", trimmed);
222
223            // Use the parse_json_tolerant method which handles ANSI escape codes
224            match ClaudeOutput::parse_json_tolerant(trimmed) {
225                Ok(output) => {
226                    debug!("[INCOMING] Parsed output type: {}", output.message_type());
227
228                    // Capture UUID from first response if not already set
229                    if self.session_uuid.is_none() {
230                        if let ClaudeOutput::Assistant(ref msg) = output {
231                            if let Some(ref uuid_str) = msg.uuid {
232                                if let Ok(uuid) = Uuid::parse_str(uuid_str) {
233                                    debug!("[INCOMING] Captured session UUID: {}", uuid);
234                                    self.session_uuid = Some(uuid);
235                                }
236                            }
237                        } else if let ClaudeOutput::Result(ref msg) = output {
238                            if let Some(ref uuid_str) = msg.uuid {
239                                if let Ok(uuid) = Uuid::parse_str(uuid_str) {
240                                    debug!("[INCOMING] Captured session UUID: {}", uuid);
241                                    self.session_uuid = Some(uuid);
242                                }
243                            }
244                        }
245                    }
246
247                    return Ok(output);
248                }
249                Err(parse_error) => {
250                    warn!("[INCOMING] Failed to deserialize message from Claude CLI. Please report this at https://github.com/meawoppl/rust-claude-codes/issues with the raw message below.");
251                    warn!("[INCOMING] Parse error: {}", parse_error);
252                    warn!("[INCOMING] Raw message: {}", trimmed);
253                    return Err(Error::Deserialization(format!(
254                        "{} (raw: {})",
255                        parse_error.error_message, trimmed
256                    )));
257                }
258            }
259        }
260    }
261
262    /// Check if the Claude process is still running
263    pub fn is_alive(&mut self) -> bool {
264        self.child.try_wait().ok().flatten().is_none()
265    }
266
267    /// Gracefully shutdown the client
268    pub async fn shutdown(mut self) -> Result<()> {
269        info!("Shutting down Claude process...");
270        self.child.kill().await.map_err(Error::Io)?;
271        Ok(())
272    }
273
274    /// Get the process ID
275    pub fn pid(&self) -> Option<u32> {
276        self.child.id()
277    }
278
279    /// Take the stderr reader (can only be called once)
280    pub fn take_stderr(&mut self) -> Option<BufReader<ChildStderr>> {
281        self.stderr.take()
282    }
283
284    /// Get the session UUID if available
285    /// Returns an error if no response has been received yet
286    pub fn session_uuid(&self) -> Result<Uuid> {
287        self.session_uuid.ok_or(Error::SessionNotInitialized)
288    }
289
290    /// Test if the Claude connection is working by sending a ping message
291    /// Returns true if Claude responds with "pong", false otherwise
292    pub async fn ping(&mut self) -> bool {
293        // Send a simple ping request
294        let ping_input = ClaudeInput::user_message(
295            "ping - respond with just the word 'pong' and nothing else",
296            self.session_uuid.unwrap_or_else(Uuid::new_v4),
297        );
298
299        // Try to send the ping
300        if let Err(e) = self.send(&ping_input).await {
301            debug!("Ping failed to send: {}", e);
302            return false;
303        }
304
305        // Try to receive responses until we get a result or error
306        let mut found_pong = false;
307        let mut message_count = 0;
308        const MAX_MESSAGES: usize = 10;
309
310        loop {
311            match self.receive().await {
312                Ok(output) => {
313                    message_count += 1;
314
315                    // Check if it's an assistant message containing "pong"
316                    if let ClaudeOutput::Assistant(msg) = &output {
317                        for content in &msg.message.content {
318                            if let ContentBlock::Text(text) = content {
319                                if text.text.to_lowercase().contains("pong") {
320                                    found_pong = true;
321                                }
322                            }
323                        }
324                    }
325
326                    // Stop on result message
327                    if matches!(output, ClaudeOutput::Result(_)) {
328                        break;
329                    }
330
331                    // Safety limit
332                    if message_count >= MAX_MESSAGES {
333                        debug!("Ping exceeded message limit");
334                        break;
335                    }
336                }
337                Err(e) => {
338                    debug!("Ping failed to receive response: {}", e);
339                    break;
340                }
341            }
342        }
343
344        found_pong
345    }
346
347    // =========================================================================
348    // Tool Approval Protocol
349    // =========================================================================
350
351    /// Enable the tool approval protocol by performing the initialization handshake.
352    ///
353    /// After calling this method, the CLI will send `ControlRequest` messages when
354    /// Claude wants to use a tool. You must handle these by calling
355    /// `send_control_response()` with an appropriate response.
356    ///
357    /// **Important**: The client must have been created with
358    /// `ClaudeCliBuilder::permission_prompt_tool("stdio")` for this to work.
359    ///
360    /// # Example
361    ///
362    /// ```no_run
363    /// use claude_codes::{AsyncClient, ClaudeCliBuilder, ClaudeOutput, ControlRequestPayload};
364    ///
365    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
366    /// let child = ClaudeCliBuilder::new()
367    ///     .model("sonnet")
368    ///     .permission_prompt_tool("stdio")
369    ///     .spawn()
370    ///     .await?;
371    ///
372    /// let mut client = AsyncClient::new(child)?;
373    /// client.enable_tool_approval().await?;
374    ///
375    /// // Now when you receive messages, you may get ControlRequest messages
376    /// // that need responses
377    /// # Ok(())
378    /// # }
379    /// ```
380    pub async fn enable_tool_approval(&mut self) -> Result<()> {
381        if self.tool_approval_enabled {
382            debug!("[TOOL_APPROVAL] Already enabled, skipping initialization");
383            return Ok(());
384        }
385
386        let request_id = format!("init-{}", Uuid::new_v4());
387        let init_request = ControlRequestMessage::initialize(&request_id);
388
389        debug!("[TOOL_APPROVAL] Sending initialization handshake");
390        let json_line = Protocol::serialize(&init_request)?;
391        self.stdin
392            .write_all(json_line.as_bytes())
393            .await
394            .map_err(Error::Io)?;
395        self.stdin.flush().await.map_err(Error::Io)?;
396
397        // Wait for the initialization response
398        loop {
399            let mut line = String::new();
400            let bytes_read = self.stdout.read_line(&mut line).await.map_err(Error::Io)?;
401
402            if bytes_read == 0 {
403                return Err(Error::ConnectionClosed);
404            }
405
406            let trimmed = line.trim();
407            if trimmed.is_empty() {
408                continue;
409            }
410
411            debug!("[TOOL_APPROVAL] Received: {}", trimmed);
412
413            // Try to parse as ClaudeOutput
414            match ClaudeOutput::parse_json_tolerant(trimmed) {
415                Ok(ClaudeOutput::ControlResponse(resp)) => {
416                    use crate::io::ControlResponsePayload;
417                    match &resp.response {
418                        ControlResponsePayload::Success {
419                            request_id: rid, ..
420                        } if rid == &request_id => {
421                            debug!("[TOOL_APPROVAL] Initialization successful");
422                            self.tool_approval_enabled = true;
423                            return Ok(());
424                        }
425                        ControlResponsePayload::Error { error, .. } => {
426                            return Err(Error::Protocol(format!(
427                                "Tool approval initialization failed: {}",
428                                error
429                            )));
430                        }
431                        _ => {
432                            // Different request_id, keep waiting
433                            continue;
434                        }
435                    }
436                }
437                Ok(_) => {
438                    // Got a different message type (system, etc.), keep waiting
439                    continue;
440                }
441                Err(e) => {
442                    return Err(Error::Deserialization(e.to_string()));
443                }
444            }
445        }
446    }
447
448    /// Send a control response back to the CLI.
449    ///
450    /// Use this to respond to `ControlRequest` messages received during tool approval.
451    /// The easiest way to create responses is using the helper methods on
452    /// `ToolPermissionRequest`:
453    ///
454    /// # Example
455    ///
456    /// ```no_run
457    /// use claude_codes::{AsyncClient, ClaudeOutput, ControlRequestPayload};
458    ///
459    /// # async fn example(client: &mut AsyncClient) -> Result<(), Box<dyn std::error::Error>> {
460    /// # let output = client.receive().await?;
461    /// if let ClaudeOutput::ControlRequest(req) = output {
462    ///     if let ControlRequestPayload::CanUseTool(perm_req) = &req.request {
463    ///         // Use the ergonomic helpers on ToolPermissionRequest
464    ///         let response = if perm_req.tool_name == "Bash" {
465    ///             perm_req.deny("Bash commands not allowed", &req.request_id)
466    ///         } else {
467    ///             perm_req.allow(&req.request_id)
468    ///         };
469    ///         client.send_control_response(response).await?;
470    ///     }
471    /// }
472    /// # Ok(())
473    /// # }
474    /// ```
475    pub async fn send_control_response(&mut self, response: ControlResponse) -> Result<()> {
476        let message: ControlResponseMessage = response.into();
477        let json_line = Protocol::serialize(&message)?;
478        debug!(
479            "[TOOL_APPROVAL] Sending control response: {}",
480            json_line.trim()
481        );
482
483        self.stdin
484            .write_all(json_line.as_bytes())
485            .await
486            .map_err(Error::Io)?;
487        self.stdin.flush().await.map_err(Error::Io)?;
488        Ok(())
489    }
490
491    /// Check if tool approval protocol is enabled
492    pub fn is_tool_approval_enabled(&self) -> bool {
493        self.tool_approval_enabled
494    }
495}
496
497/// A response stream that yields ClaudeOutput messages
498/// Holds a reference to the client to read from
499pub struct ResponseStream<'a> {
500    client: &'a mut AsyncClient,
501    finished: bool,
502}
503
504impl ResponseStream<'_> {
505    /// Convert to a vector by collecting all responses
506    pub async fn collect(mut self) -> Result<Vec<ClaudeOutput>> {
507        let mut responses = Vec::new();
508
509        while !self.finished {
510            let output = self.client.receive().await?;
511            let is_result = matches!(&output, ClaudeOutput::Result(_));
512            responses.push(output);
513
514            if is_result {
515                self.finished = true;
516                break;
517            }
518        }
519
520        Ok(responses)
521    }
522
523    /// Get the next response
524    pub async fn next(&mut self) -> Option<Result<ClaudeOutput>> {
525        if self.finished {
526            return None;
527        }
528
529        match self.client.receive().await {
530            Ok(output) => {
531                if matches!(&output, ClaudeOutput::Result(_)) {
532                    self.finished = true;
533                }
534                Some(Ok(output))
535            }
536            Err(e) => {
537                self.finished = true;
538                Some(Err(e))
539            }
540        }
541    }
542}
543
544impl Drop for AsyncClient {
545    fn drop(&mut self) {
546        if self.is_alive() {
547            // Try to kill the process
548            if let Err(e) = self.child.start_kill() {
549                error!("Failed to kill Claude process on drop: {}", e);
550            }
551        }
552    }
553}
554
555// Protocol extension methods for asynchronous I/O
556impl Protocol {
557    /// Write a message to an async writer
558    pub async fn write_async<W: AsyncWriteExt + Unpin, T: Serialize>(
559        writer: &mut W,
560        message: &T,
561    ) -> Result<()> {
562        let line = Self::serialize(message)?;
563        debug!("[PROTOCOL] Sending async: {}", line.trim());
564        writer.write_all(line.as_bytes()).await?;
565        writer.flush().await?;
566        Ok(())
567    }
568
569    /// Read a message from an async reader
570    pub async fn read_async<R: AsyncBufReadExt + Unpin, T: for<'de> Deserialize<'de>>(
571        reader: &mut R,
572    ) -> Result<T> {
573        let mut line = String::new();
574        let bytes_read = reader.read_line(&mut line).await?;
575        if bytes_read == 0 {
576            return Err(Error::ConnectionClosed);
577        }
578        debug!("[PROTOCOL] Received async: {}", line.trim());
579        Self::deserialize(&line)
580    }
581}
582
583/// Async stream processor for handling continuous message streams
584pub struct AsyncStreamProcessor<R> {
585    reader: AsyncBufReader<R>,
586}
587
588impl<R: tokio::io::AsyncRead + Unpin> AsyncStreamProcessor<R> {
589    /// Create a new async stream processor
590    pub fn new(reader: R) -> Self {
591        Self {
592            reader: AsyncBufReader::new(reader),
593        }
594    }
595
596    /// Process the next message from the stream
597    pub async fn next_message<T: for<'de> Deserialize<'de>>(&mut self) -> Result<T> {
598        Protocol::read_async(&mut self.reader).await
599    }
600
601    /// Process all messages in the stream
602    pub async fn process_all<T, F, Fut>(&mut self, mut handler: F) -> Result<()>
603    where
604        T: for<'de> Deserialize<'de>,
605        F: FnMut(T) -> Fut,
606        Fut: std::future::Future<Output = Result<()>>,
607    {
608        loop {
609            match self.next_message().await {
610                Ok(message) => handler(message).await?,
611                Err(Error::ConnectionClosed) => break,
612                Err(e) => return Err(e),
613            }
614        }
615        Ok(())
616    }
617}