1use crate::cli::ClaudeCliBuilder;
4use crate::error::{Error, Result};
5use crate::io::{
6 ClaudeInput, ClaudeOutput, ContentBlock, ControlRequestMessage, ControlResponse,
7 ControlResponseMessage,
8};
9use crate::protocol::Protocol;
10use log::{debug, error, info, warn};
11use serde::{Deserialize, Serialize};
12use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufReader as AsyncBufReader};
13use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout};
14use uuid::Uuid;
15
16pub struct AsyncClient {
18 child: Child,
19 stdin: ChildStdin,
20 stdout: BufReader<ChildStdout>,
21 stderr: Option<BufReader<ChildStderr>>,
22 session_uuid: Option<Uuid>,
23 tool_approval_enabled: bool,
25}
26
/// Capacity in bytes (10 MiB) of the buffered reader wrapped around the Claude
/// process's stdout.
const STDOUT_BUFFER_SIZE: usize = 10 << 20;
29
30impl AsyncClient {
31 pub fn new(mut child: Child) -> Result<Self> {
33 let stdin = child
34 .stdin
35 .take()
36 .ok_or_else(|| Error::Io(std::io::Error::other("Failed to get stdin handle")))?;
37
38 let stdout = BufReader::with_capacity(
39 STDOUT_BUFFER_SIZE,
40 child
41 .stdout
42 .take()
43 .ok_or_else(|| Error::Io(std::io::Error::other("Failed to get stdout handle")))?,
44 );
45
46 let stderr = child.stderr.take().map(BufReader::new);
47
48 Ok(Self {
49 child,
50 stdin,
51 stdout,
52 stderr,
53 session_uuid: None,
54 tool_approval_enabled: false,
55 })
56 }
57
58 pub async fn with_defaults() -> Result<Self> {
60 crate::version::check_claude_version_async().await?;
66 Self::with_model("sonnet").await
67 }
68
69 pub async fn with_model(model: &str) -> Result<Self> {
71 let child = ClaudeCliBuilder::new().model(model).spawn().await?;
72
73 info!("Started Claude process with model: {}", model);
74 Self::new(child)
75 }
76
77 pub async fn from_builder(builder: ClaudeCliBuilder) -> Result<Self> {
79 let child = builder.spawn().await?;
80 info!("Started Claude process from custom builder");
81 Self::new(child)
82 }
83
84 pub async fn resume_session(session_uuid: Uuid) -> Result<Self> {
87 let child = ClaudeCliBuilder::new()
88 .resume(Some(session_uuid.to_string()))
89 .spawn()
90 .await?;
91
92 info!("Resuming Claude session with UUID: {}", session_uuid);
93 let mut client = Self::new(child)?;
94 client.session_uuid = Some(session_uuid);
96 Ok(client)
97 }
98
99 pub async fn resume_session_with_model(session_uuid: Uuid, model: &str) -> Result<Self> {
101 let child = ClaudeCliBuilder::new()
102 .model(model)
103 .resume(Some(session_uuid.to_string()))
104 .spawn()
105 .await?;
106
107 info!(
108 "Resuming Claude session with UUID: {} and model: {}",
109 session_uuid, model
110 );
111 let mut client = Self::new(child)?;
112 client.session_uuid = Some(session_uuid);
114 Ok(client)
115 }
116
117 pub async fn query(&mut self, text: &str) -> Result<Vec<ClaudeOutput>> {
120 let session_id = Uuid::new_v4();
121 self.query_with_session(text, session_id).await
122 }
123
124 pub async fn query_with_session(
126 &mut self,
127 text: &str,
128 session_id: Uuid,
129 ) -> Result<Vec<ClaudeOutput>> {
130 let input = ClaudeInput::user_message(text, session_id);
132 self.send(&input).await?;
133
134 let mut responses = Vec::new();
136
137 loop {
138 let output = self.receive().await?;
139 let is_result = matches!(&output, ClaudeOutput::Result(_));
140 responses.push(output);
141
142 if is_result {
143 break;
144 }
145 }
146
147 Ok(responses)
148 }
149
150 pub async fn query_stream(&mut self, text: &str) -> Result<ResponseStream<'_>> {
153 let session_id = Uuid::new_v4();
154 self.query_stream_with_session(text, session_id).await
155 }
156
157 pub async fn query_stream_with_session(
159 &mut self,
160 text: &str,
161 session_id: Uuid,
162 ) -> Result<ResponseStream<'_>> {
163 let input = ClaudeInput::user_message(text, session_id);
165 self.send(&input).await?;
166
167 Ok(ResponseStream {
169 client: self,
170 finished: false,
171 })
172 }
173
174 pub async fn send(&mut self, input: &ClaudeInput) -> Result<()> {
176 let json_line = Protocol::serialize(input)?;
177 debug!("[OUTGOING] Sending JSON to Claude: {}", json_line.trim());
178
179 self.stdin
180 .write_all(json_line.as_bytes())
181 .await
182 .map_err(Error::Io)?;
183
184 self.stdin.flush().await.map_err(Error::Io)?;
185 Ok(())
186 }
187
188 pub async fn receive(&mut self) -> Result<ClaudeOutput> {
206 let mut line = String::new();
207
208 loop {
209 line.clear();
210 let bytes_read = self.stdout.read_line(&mut line).await.map_err(Error::Io)?;
211
212 if bytes_read == 0 {
213 return Err(Error::ConnectionClosed);
214 }
215
216 let trimmed = line.trim();
217 if trimmed.is_empty() {
218 continue;
219 }
220
221 debug!("[INCOMING] Received JSON from Claude: {}", trimmed);
222
223 match ClaudeOutput::parse_json_tolerant(trimmed) {
225 Ok(output) => {
226 debug!("[INCOMING] Parsed output type: {}", output.message_type());
227
228 if self.session_uuid.is_none() {
230 if let ClaudeOutput::Assistant(ref msg) = output {
231 if let Some(ref uuid_str) = msg.uuid {
232 if let Ok(uuid) = Uuid::parse_str(uuid_str) {
233 debug!("[INCOMING] Captured session UUID: {}", uuid);
234 self.session_uuid = Some(uuid);
235 }
236 }
237 } else if let ClaudeOutput::Result(ref msg) = output {
238 if let Some(ref uuid_str) = msg.uuid {
239 if let Ok(uuid) = Uuid::parse_str(uuid_str) {
240 debug!("[INCOMING] Captured session UUID: {}", uuid);
241 self.session_uuid = Some(uuid);
242 }
243 }
244 }
245 }
246
247 return Ok(output);
248 }
249 Err(parse_error) => {
250 warn!("[INCOMING] Failed to deserialize message from Claude CLI. Please report this at https://github.com/meawoppl/rust-claude-codes/issues with the raw message below.");
251 warn!("[INCOMING] Parse error: {}", parse_error);
252 warn!("[INCOMING] Raw message: {}", trimmed);
253 return Err(Error::Deserialization(format!(
254 "{} (raw: {})",
255 parse_error.error_message, trimmed
256 )));
257 }
258 }
259 }
260 }
261
262 pub fn is_alive(&mut self) -> bool {
264 self.child.try_wait().ok().flatten().is_none()
265 }
266
267 pub async fn shutdown(mut self) -> Result<()> {
269 info!("Shutting down Claude process...");
270 self.child.kill().await.map_err(Error::Io)?;
271 Ok(())
272 }
273
274 pub fn pid(&self) -> Option<u32> {
276 self.child.id()
277 }
278
279 pub fn take_stderr(&mut self) -> Option<BufReader<ChildStderr>> {
281 self.stderr.take()
282 }
283
284 pub fn session_uuid(&self) -> Result<Uuid> {
287 self.session_uuid.ok_or(Error::SessionNotInitialized)
288 }
289
290 pub async fn ping(&mut self) -> bool {
293 let ping_input = ClaudeInput::user_message(
295 "ping - respond with just the word 'pong' and nothing else",
296 self.session_uuid.unwrap_or_else(Uuid::new_v4),
297 );
298
299 if let Err(e) = self.send(&ping_input).await {
301 debug!("Ping failed to send: {}", e);
302 return false;
303 }
304
305 let mut found_pong = false;
307 let mut message_count = 0;
308 const MAX_MESSAGES: usize = 10;
309
310 loop {
311 match self.receive().await {
312 Ok(output) => {
313 message_count += 1;
314
315 if let ClaudeOutput::Assistant(msg) = &output {
317 for content in &msg.message.content {
318 if let ContentBlock::Text(text) = content {
319 if text.text.to_lowercase().contains("pong") {
320 found_pong = true;
321 }
322 }
323 }
324 }
325
326 if matches!(output, ClaudeOutput::Result(_)) {
328 break;
329 }
330
331 if message_count >= MAX_MESSAGES {
333 debug!("Ping exceeded message limit");
334 break;
335 }
336 }
337 Err(e) => {
338 debug!("Ping failed to receive response: {}", e);
339 break;
340 }
341 }
342 }
343
344 found_pong
345 }
346
347 pub async fn enable_tool_approval(&mut self) -> Result<()> {
381 if self.tool_approval_enabled {
382 debug!("[TOOL_APPROVAL] Already enabled, skipping initialization");
383 return Ok(());
384 }
385
386 let request_id = format!("init-{}", Uuid::new_v4());
387 let init_request = ControlRequestMessage::initialize(&request_id);
388
389 debug!("[TOOL_APPROVAL] Sending initialization handshake");
390 let json_line = Protocol::serialize(&init_request)?;
391 self.stdin
392 .write_all(json_line.as_bytes())
393 .await
394 .map_err(Error::Io)?;
395 self.stdin.flush().await.map_err(Error::Io)?;
396
397 loop {
399 let mut line = String::new();
400 let bytes_read = self.stdout.read_line(&mut line).await.map_err(Error::Io)?;
401
402 if bytes_read == 0 {
403 return Err(Error::ConnectionClosed);
404 }
405
406 let trimmed = line.trim();
407 if trimmed.is_empty() {
408 continue;
409 }
410
411 debug!("[TOOL_APPROVAL] Received: {}", trimmed);
412
413 match ClaudeOutput::parse_json_tolerant(trimmed) {
415 Ok(ClaudeOutput::ControlResponse(resp)) => {
416 use crate::io::ControlResponsePayload;
417 match &resp.response {
418 ControlResponsePayload::Success {
419 request_id: rid, ..
420 } if rid == &request_id => {
421 debug!("[TOOL_APPROVAL] Initialization successful");
422 self.tool_approval_enabled = true;
423 return Ok(());
424 }
425 ControlResponsePayload::Error { error, .. } => {
426 return Err(Error::Protocol(format!(
427 "Tool approval initialization failed: {}",
428 error
429 )));
430 }
431 _ => {
432 continue;
434 }
435 }
436 }
437 Ok(_) => {
438 continue;
440 }
441 Err(e) => {
442 return Err(Error::Deserialization(e.to_string()));
443 }
444 }
445 }
446 }
447
448 pub async fn send_control_response(&mut self, response: ControlResponse) -> Result<()> {
476 let message: ControlResponseMessage = response.into();
477 let json_line = Protocol::serialize(&message)?;
478 debug!(
479 "[TOOL_APPROVAL] Sending control response: {}",
480 json_line.trim()
481 );
482
483 self.stdin
484 .write_all(json_line.as_bytes())
485 .await
486 .map_err(Error::Io)?;
487 self.stdin.flush().await.map_err(Error::Io)?;
488 Ok(())
489 }
490
491 pub fn is_tool_approval_enabled(&self) -> bool {
493 self.tool_approval_enabled
494 }
495}
496
497pub struct ResponseStream<'a> {
500 client: &'a mut AsyncClient,
501 finished: bool,
502}
503
504impl ResponseStream<'_> {
505 pub async fn collect(mut self) -> Result<Vec<ClaudeOutput>> {
507 let mut responses = Vec::new();
508
509 while !self.finished {
510 let output = self.client.receive().await?;
511 let is_result = matches!(&output, ClaudeOutput::Result(_));
512 responses.push(output);
513
514 if is_result {
515 self.finished = true;
516 break;
517 }
518 }
519
520 Ok(responses)
521 }
522
523 pub async fn next(&mut self) -> Option<Result<ClaudeOutput>> {
525 if self.finished {
526 return None;
527 }
528
529 match self.client.receive().await {
530 Ok(output) => {
531 if matches!(&output, ClaudeOutput::Result(_)) {
532 self.finished = true;
533 }
534 Some(Ok(output))
535 }
536 Err(e) => {
537 self.finished = true;
538 Some(Err(e))
539 }
540 }
541 }
542}
543
544impl Drop for AsyncClient {
545 fn drop(&mut self) {
546 if self.is_alive() {
547 if let Err(e) = self.child.start_kill() {
549 error!("Failed to kill Claude process on drop: {}", e);
550 }
551 }
552 }
553}
554
555impl Protocol {
557 pub async fn write_async<W: AsyncWriteExt + Unpin, T: Serialize>(
559 writer: &mut W,
560 message: &T,
561 ) -> Result<()> {
562 let line = Self::serialize(message)?;
563 debug!("[PROTOCOL] Sending async: {}", line.trim());
564 writer.write_all(line.as_bytes()).await?;
565 writer.flush().await?;
566 Ok(())
567 }
568
569 pub async fn read_async<R: AsyncBufReadExt + Unpin, T: for<'de> Deserialize<'de>>(
571 reader: &mut R,
572 ) -> Result<T> {
573 let mut line = String::new();
574 let bytes_read = reader.read_line(&mut line).await?;
575 if bytes_read == 0 {
576 return Err(Error::ConnectionClosed);
577 }
578 debug!("[PROTOCOL] Received async: {}", line.trim());
579 Self::deserialize(&line)
580 }
581}
582
583pub struct AsyncStreamProcessor<R> {
585 reader: AsyncBufReader<R>,
586}
587
588impl<R: tokio::io::AsyncRead + Unpin> AsyncStreamProcessor<R> {
589 pub fn new(reader: R) -> Self {
591 Self {
592 reader: AsyncBufReader::new(reader),
593 }
594 }
595
596 pub async fn next_message<T: for<'de> Deserialize<'de>>(&mut self) -> Result<T> {
598 Protocol::read_async(&mut self.reader).await
599 }
600
601 pub async fn process_all<T, F, Fut>(&mut self, mut handler: F) -> Result<()>
603 where
604 T: for<'de> Deserialize<'de>,
605 F: FnMut(T) -> Fut,
606 Fut: std::future::Future<Output = Result<()>>,
607 {
608 loop {
609 match self.next_message().await {
610 Ok(message) => handler(message).await?,
611 Err(Error::ConnectionClosed) => break,
612 Err(e) => return Err(e),
613 }
614 }
615 Ok(())
616 }
617}