claude_codes/
client_async.rs

1//! Asynchronous client for Claude communication
2
3use crate::cli::ClaudeCliBuilder;
4use crate::error::{Error, Result};
5use crate::io::{
6    ClaudeInput, ClaudeOutput, ContentBlock, ControlRequestMessage, ControlResponse,
7    ControlResponseMessage,
8};
9use crate::protocol::Protocol;
10use log::{debug, error, info};
11use serde::{Deserialize, Serialize};
12use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufReader as AsyncBufReader};
13use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout};
14use uuid::Uuid;
15
16/// Asynchronous client for communicating with Claude
17pub struct AsyncClient {
18    child: Child,
19    stdin: ChildStdin,
20    stdout: BufReader<ChildStdout>,
21    stderr: Option<BufReader<ChildStderr>>,
22    session_uuid: Option<Uuid>,
23    /// Whether tool approval protocol has been initialized
24    tool_approval_enabled: bool,
25}
26
27impl AsyncClient {
28    /// Create a new async client from a tokio Child process
29    pub fn new(mut child: Child) -> Result<Self> {
30        let stdin = child
31            .stdin
32            .take()
33            .ok_or_else(|| Error::Io(std::io::Error::other("Failed to get stdin handle")))?;
34
35        let stdout = BufReader::new(
36            child
37                .stdout
38                .take()
39                .ok_or_else(|| Error::Io(std::io::Error::other("Failed to get stdout handle")))?,
40        );
41
42        let stderr = child.stderr.take().map(BufReader::new);
43
44        Ok(Self {
45            child,
46            stdin,
47            stdout,
48            stderr,
49            session_uuid: None,
50            tool_approval_enabled: false,
51        })
52    }
53
54    /// Create a client with default settings (using logic from start_claude)
55    pub async fn with_defaults() -> Result<Self> {
56        // Check Claude version (only warns once per session)
57        // NOTE: The claude-codes API is in high flux. If you wish to work around
58        // this version check, you can use AsyncClient::new() directly with:
59        //   let child = ClaudeCliBuilder::new().model("sonnet").spawn().await?;
60        //   AsyncClient::new(child)
61        crate::version::check_claude_version_async().await?;
62        Self::with_model("sonnet").await
63    }
64
65    /// Create a client with a specific model
66    pub async fn with_model(model: &str) -> Result<Self> {
67        let child = ClaudeCliBuilder::new().model(model).spawn().await?;
68
69        info!("Started Claude process with model: {}", model);
70        Self::new(child)
71    }
72
73    /// Create a client from a custom builder
74    pub async fn from_builder(builder: ClaudeCliBuilder) -> Result<Self> {
75        let child = builder.spawn().await?;
76        info!("Started Claude process from custom builder");
77        Self::new(child)
78    }
79
80    /// Resume a previous session by UUID
81    /// This creates a new client that resumes an existing session
82    pub async fn resume_session(session_uuid: Uuid) -> Result<Self> {
83        let child = ClaudeCliBuilder::new()
84            .resume(Some(session_uuid.to_string()))
85            .spawn()
86            .await?;
87
88        info!("Resuming Claude session with UUID: {}", session_uuid);
89        let mut client = Self::new(child)?;
90        // Pre-populate the session UUID since we're resuming
91        client.session_uuid = Some(session_uuid);
92        Ok(client)
93    }
94
95    /// Resume a previous session with a specific model
96    pub async fn resume_session_with_model(session_uuid: Uuid, model: &str) -> Result<Self> {
97        let child = ClaudeCliBuilder::new()
98            .model(model)
99            .resume(Some(session_uuid.to_string()))
100            .spawn()
101            .await?;
102
103        info!(
104            "Resuming Claude session with UUID: {} and model: {}",
105            session_uuid, model
106        );
107        let mut client = Self::new(child)?;
108        // Pre-populate the session UUID since we're resuming
109        client.session_uuid = Some(session_uuid);
110        Ok(client)
111    }
112
113    /// Send a query and collect all responses until Result message
114    /// This is the simplified version that collects all responses
115    pub async fn query(&mut self, text: &str) -> Result<Vec<ClaudeOutput>> {
116        let session_id = Uuid::new_v4();
117        self.query_with_session(text, session_id).await
118    }
119
120    /// Send a query with a custom session ID and collect all responses
121    pub async fn query_with_session(
122        &mut self,
123        text: &str,
124        session_id: Uuid,
125    ) -> Result<Vec<ClaudeOutput>> {
126        // Send the query
127        let input = ClaudeInput::user_message(text, session_id);
128        self.send(&input).await?;
129
130        // Collect responses until we get a Result message
131        let mut responses = Vec::new();
132
133        loop {
134            let output = self.receive().await?;
135            let is_result = matches!(&output, ClaudeOutput::Result(_));
136            responses.push(output);
137
138            if is_result {
139                break;
140            }
141        }
142
143        Ok(responses)
144    }
145
146    /// Send a query and return an async iterator over responses
147    /// Returns a stream that yields ClaudeOutput until Result message is received
148    pub async fn query_stream(&mut self, text: &str) -> Result<ResponseStream<'_>> {
149        let session_id = Uuid::new_v4();
150        self.query_stream_with_session(text, session_id).await
151    }
152
153    /// Send a query with session ID and return an async iterator over responses
154    pub async fn query_stream_with_session(
155        &mut self,
156        text: &str,
157        session_id: Uuid,
158    ) -> Result<ResponseStream<'_>> {
159        // Send the query first
160        let input = ClaudeInput::user_message(text, session_id);
161        self.send(&input).await?;
162
163        // Return a stream that will read responses
164        Ok(ResponseStream {
165            client: self,
166            finished: false,
167        })
168    }
169
170    /// Send a ClaudeInput directly
171    pub async fn send(&mut self, input: &ClaudeInput) -> Result<()> {
172        let json_line = Protocol::serialize(input)?;
173        debug!("[OUTGOING] Sending JSON to Claude: {}", json_line.trim());
174
175        self.stdin
176            .write_all(json_line.as_bytes())
177            .await
178            .map_err(Error::Io)?;
179
180        self.stdin.flush().await.map_err(Error::Io)?;
181        Ok(())
182    }
183
184    /// Try to receive a single response
185    pub async fn receive(&mut self) -> Result<ClaudeOutput> {
186        let mut line = String::new();
187
188        loop {
189            line.clear();
190            let bytes_read = self.stdout.read_line(&mut line).await.map_err(Error::Io)?;
191
192            if bytes_read == 0 {
193                return Err(Error::ConnectionClosed);
194            }
195
196            let trimmed = line.trim();
197            if trimmed.is_empty() {
198                continue;
199            }
200
201            debug!("[INCOMING] Received JSON from Claude: {}", trimmed);
202
203            // Use the parse_json_tolerant method which handles ANSI escape codes
204            match ClaudeOutput::parse_json_tolerant(trimmed) {
205                Ok(output) => {
206                    debug!("[INCOMING] Parsed output type: {}", output.message_type());
207
208                    // Capture UUID from first response if not already set
209                    if self.session_uuid.is_none() {
210                        if let ClaudeOutput::Assistant(ref msg) = output {
211                            if let Some(ref uuid_str) = msg.uuid {
212                                if let Ok(uuid) = Uuid::parse_str(uuid_str) {
213                                    debug!("[INCOMING] Captured session UUID: {}", uuid);
214                                    self.session_uuid = Some(uuid);
215                                }
216                            }
217                        } else if let ClaudeOutput::Result(ref msg) = output {
218                            if let Some(ref uuid_str) = msg.uuid {
219                                if let Ok(uuid) = Uuid::parse_str(uuid_str) {
220                                    debug!("[INCOMING] Captured session UUID: {}", uuid);
221                                    self.session_uuid = Some(uuid);
222                                }
223                            }
224                        }
225                    }
226
227                    return Ok(output);
228                }
229                Err(parse_error) => {
230                    error!("[INCOMING] Failed to deserialize: {}", parse_error);
231                    error!("[INCOMING] Raw JSON that failed: {}", trimmed);
232                    // Convert ParseError to our Error type
233                    return Err(Error::Deserialization(format!(
234                        "{} (raw: {})",
235                        parse_error.error_message, trimmed
236                    )));
237                }
238            }
239        }
240    }
241
242    /// Check if the Claude process is still running
243    pub fn is_alive(&mut self) -> bool {
244        self.child.try_wait().ok().flatten().is_none()
245    }
246
247    /// Gracefully shutdown the client
248    pub async fn shutdown(mut self) -> Result<()> {
249        info!("Shutting down Claude process...");
250        self.child.kill().await.map_err(Error::Io)?;
251        Ok(())
252    }
253
254    /// Get the process ID
255    pub fn pid(&self) -> Option<u32> {
256        self.child.id()
257    }
258
259    /// Take the stderr reader (can only be called once)
260    pub fn take_stderr(&mut self) -> Option<BufReader<ChildStderr>> {
261        self.stderr.take()
262    }
263
264    /// Get the session UUID if available
265    /// Returns an error if no response has been received yet
266    pub fn session_uuid(&self) -> Result<Uuid> {
267        self.session_uuid.ok_or(Error::SessionNotInitialized)
268    }
269
270    /// Test if the Claude connection is working by sending a ping message
271    /// Returns true if Claude responds with "pong", false otherwise
272    pub async fn ping(&mut self) -> bool {
273        // Send a simple ping request
274        let ping_input = ClaudeInput::user_message(
275            "ping - respond with just the word 'pong' and nothing else",
276            self.session_uuid.unwrap_or_else(Uuid::new_v4),
277        );
278
279        // Try to send the ping
280        if let Err(e) = self.send(&ping_input).await {
281            debug!("Ping failed to send: {}", e);
282            return false;
283        }
284
285        // Try to receive responses until we get a result or error
286        let mut found_pong = false;
287        let mut message_count = 0;
288        const MAX_MESSAGES: usize = 10;
289
290        loop {
291            match self.receive().await {
292                Ok(output) => {
293                    message_count += 1;
294
295                    // Check if it's an assistant message containing "pong"
296                    if let ClaudeOutput::Assistant(msg) = &output {
297                        for content in &msg.message.content {
298                            if let ContentBlock::Text(text) = content {
299                                if text.text.to_lowercase().contains("pong") {
300                                    found_pong = true;
301                                }
302                            }
303                        }
304                    }
305
306                    // Stop on result message
307                    if matches!(output, ClaudeOutput::Result(_)) {
308                        break;
309                    }
310
311                    // Safety limit
312                    if message_count >= MAX_MESSAGES {
313                        debug!("Ping exceeded message limit");
314                        break;
315                    }
316                }
317                Err(e) => {
318                    debug!("Ping failed to receive response: {}", e);
319                    break;
320                }
321            }
322        }
323
324        found_pong
325    }
326
327    // =========================================================================
328    // Tool Approval Protocol
329    // =========================================================================
330
331    /// Enable the tool approval protocol by performing the initialization handshake.
332    ///
333    /// After calling this method, the CLI will send `ControlRequest` messages when
334    /// Claude wants to use a tool. You must handle these by calling
335    /// `send_control_response()` with an appropriate response.
336    ///
337    /// **Important**: The client must have been created with
338    /// `ClaudeCliBuilder::permission_prompt_tool("stdio")` for this to work.
339    ///
340    /// # Example
341    ///
342    /// ```no_run
343    /// use claude_codes::{AsyncClient, ClaudeCliBuilder, ClaudeOutput, ControlRequestPayload};
344    ///
345    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
346    /// let child = ClaudeCliBuilder::new()
347    ///     .model("sonnet")
348    ///     .permission_prompt_tool("stdio")
349    ///     .spawn()
350    ///     .await?;
351    ///
352    /// let mut client = AsyncClient::new(child)?;
353    /// client.enable_tool_approval().await?;
354    ///
355    /// // Now when you receive messages, you may get ControlRequest messages
356    /// // that need responses
357    /// # Ok(())
358    /// # }
359    /// ```
360    pub async fn enable_tool_approval(&mut self) -> Result<()> {
361        if self.tool_approval_enabled {
362            debug!("[TOOL_APPROVAL] Already enabled, skipping initialization");
363            return Ok(());
364        }
365
366        let request_id = format!("init-{}", Uuid::new_v4());
367        let init_request = ControlRequestMessage::initialize(&request_id);
368
369        debug!("[TOOL_APPROVAL] Sending initialization handshake");
370        let json_line = Protocol::serialize(&init_request)?;
371        self.stdin
372            .write_all(json_line.as_bytes())
373            .await
374            .map_err(Error::Io)?;
375        self.stdin.flush().await.map_err(Error::Io)?;
376
377        // Wait for the initialization response
378        loop {
379            let mut line = String::new();
380            let bytes_read = self.stdout.read_line(&mut line).await.map_err(Error::Io)?;
381
382            if bytes_read == 0 {
383                return Err(Error::ConnectionClosed);
384            }
385
386            let trimmed = line.trim();
387            if trimmed.is_empty() {
388                continue;
389            }
390
391            debug!("[TOOL_APPROVAL] Received: {}", trimmed);
392
393            // Try to parse as ClaudeOutput
394            match ClaudeOutput::parse_json_tolerant(trimmed) {
395                Ok(ClaudeOutput::ControlResponse(resp)) => {
396                    use crate::io::ControlResponsePayload;
397                    match &resp.response {
398                        ControlResponsePayload::Success {
399                            request_id: rid, ..
400                        } if rid == &request_id => {
401                            debug!("[TOOL_APPROVAL] Initialization successful");
402                            self.tool_approval_enabled = true;
403                            return Ok(());
404                        }
405                        ControlResponsePayload::Error { error, .. } => {
406                            return Err(Error::Protocol(format!(
407                                "Tool approval initialization failed: {}",
408                                error
409                            )));
410                        }
411                        _ => {
412                            // Different request_id, keep waiting
413                            continue;
414                        }
415                    }
416                }
417                Ok(_) => {
418                    // Got a different message type (system, etc.), keep waiting
419                    continue;
420                }
421                Err(e) => {
422                    return Err(Error::Deserialization(e.to_string()));
423                }
424            }
425        }
426    }
427
428    /// Send a control response back to the CLI.
429    ///
430    /// Use this to respond to `ControlRequest` messages received during tool approval.
431    /// The easiest way to create responses is using the helper methods on
432    /// `ToolPermissionRequest`:
433    ///
434    /// # Example
435    ///
436    /// ```no_run
437    /// use claude_codes::{AsyncClient, ClaudeOutput, ControlRequestPayload};
438    ///
439    /// # async fn example(client: &mut AsyncClient) -> Result<(), Box<dyn std::error::Error>> {
440    /// # let output = client.receive().await?;
441    /// if let ClaudeOutput::ControlRequest(req) = output {
442    ///     if let ControlRequestPayload::CanUseTool(perm_req) = &req.request {
443    ///         // Use the ergonomic helpers on ToolPermissionRequest
444    ///         let response = if perm_req.tool_name == "Bash" {
445    ///             perm_req.deny("Bash commands not allowed", &req.request_id)
446    ///         } else {
447    ///             perm_req.allow(&req.request_id)
448    ///         };
449    ///         client.send_control_response(response).await?;
450    ///     }
451    /// }
452    /// # Ok(())
453    /// # }
454    /// ```
455    pub async fn send_control_response(&mut self, response: ControlResponse) -> Result<()> {
456        let message: ControlResponseMessage = response.into();
457        let json_line = Protocol::serialize(&message)?;
458        debug!(
459            "[TOOL_APPROVAL] Sending control response: {}",
460            json_line.trim()
461        );
462
463        self.stdin
464            .write_all(json_line.as_bytes())
465            .await
466            .map_err(Error::Io)?;
467        self.stdin.flush().await.map_err(Error::Io)?;
468        Ok(())
469    }
470
471    /// Check if tool approval protocol is enabled
472    pub fn is_tool_approval_enabled(&self) -> bool {
473        self.tool_approval_enabled
474    }
475}
476
477/// A response stream that yields ClaudeOutput messages
478/// Holds a reference to the client to read from
479pub struct ResponseStream<'a> {
480    client: &'a mut AsyncClient,
481    finished: bool,
482}
483
484impl ResponseStream<'_> {
485    /// Convert to a vector by collecting all responses
486    pub async fn collect(mut self) -> Result<Vec<ClaudeOutput>> {
487        let mut responses = Vec::new();
488
489        while !self.finished {
490            let output = self.client.receive().await?;
491            let is_result = matches!(&output, ClaudeOutput::Result(_));
492            responses.push(output);
493
494            if is_result {
495                self.finished = true;
496                break;
497            }
498        }
499
500        Ok(responses)
501    }
502
503    /// Get the next response
504    pub async fn next(&mut self) -> Option<Result<ClaudeOutput>> {
505        if self.finished {
506            return None;
507        }
508
509        match self.client.receive().await {
510            Ok(output) => {
511                if matches!(&output, ClaudeOutput::Result(_)) {
512                    self.finished = true;
513                }
514                Some(Ok(output))
515            }
516            Err(e) => {
517                self.finished = true;
518                Some(Err(e))
519            }
520        }
521    }
522}
523
524impl Drop for AsyncClient {
525    fn drop(&mut self) {
526        if self.is_alive() {
527            // Try to kill the process
528            if let Err(e) = self.child.start_kill() {
529                error!("Failed to kill Claude process on drop: {}", e);
530            }
531        }
532    }
533}
534
535// Protocol extension methods for asynchronous I/O
536impl Protocol {
537    /// Write a message to an async writer
538    pub async fn write_async<W: AsyncWriteExt + Unpin, T: Serialize>(
539        writer: &mut W,
540        message: &T,
541    ) -> Result<()> {
542        let line = Self::serialize(message)?;
543        debug!("[PROTOCOL] Sending async: {}", line.trim());
544        writer.write_all(line.as_bytes()).await?;
545        writer.flush().await?;
546        Ok(())
547    }
548
549    /// Read a message from an async reader
550    pub async fn read_async<R: AsyncBufReadExt + Unpin, T: for<'de> Deserialize<'de>>(
551        reader: &mut R,
552    ) -> Result<T> {
553        let mut line = String::new();
554        let bytes_read = reader.read_line(&mut line).await?;
555        if bytes_read == 0 {
556            return Err(Error::ConnectionClosed);
557        }
558        debug!("[PROTOCOL] Received async: {}", line.trim());
559        Self::deserialize(&line)
560    }
561}
562
563/// Async stream processor for handling continuous message streams
564pub struct AsyncStreamProcessor<R> {
565    reader: AsyncBufReader<R>,
566}
567
568impl<R: tokio::io::AsyncRead + Unpin> AsyncStreamProcessor<R> {
569    /// Create a new async stream processor
570    pub fn new(reader: R) -> Self {
571        Self {
572            reader: AsyncBufReader::new(reader),
573        }
574    }
575
576    /// Process the next message from the stream
577    pub async fn next_message<T: for<'de> Deserialize<'de>>(&mut self) -> Result<T> {
578        Protocol::read_async(&mut self.reader).await
579    }
580
581    /// Process all messages in the stream
582    pub async fn process_all<T, F, Fut>(&mut self, mut handler: F) -> Result<()>
583    where
584        T: for<'de> Deserialize<'de>,
585        F: FnMut(T) -> Fut,
586        Fut: std::future::Future<Output = Result<()>>,
587    {
588        loop {
589            match self.next_message().await {
590                Ok(message) => handler(message).await?,
591                Err(Error::ConnectionClosed) => break,
592                Err(e) => return Err(e),
593            }
594        }
595        Ok(())
596    }
597}