1use crate::{
4 errors::{Result, SdkError},
5 transport::{InputMessage, SubprocessTransport, Transport},
6 types::{ClaudeCodeOptions, ControlRequest, Message},
7};
8use futures::{Stream, StreamExt};
9use std::sync::Arc;
10use tokio::sync::Mutex;
11use tokio_stream::wrappers::ReceiverStream;
12use tracing::{debug, info};
13
14pub struct InteractiveClient {
19 transport: Arc<Mutex<Box<dyn Transport + Send>>>,
20 connected: bool,
21}
22
23impl InteractiveClient {
24 pub fn new(options: ClaudeCodeOptions) -> Result<Self> {
26 unsafe {
27 std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
28 }
29 let transport: Box<dyn Transport + Send> = Box::new(SubprocessTransport::new(options)?);
30 Ok(Self {
31 transport: Arc::new(Mutex::new(transport)),
32 connected: false,
33 })
34 }
35
36 pub async fn connect(&mut self) -> Result<()> {
38 if self.connected {
39 return Ok(());
40 }
41
42 let mut transport = self.transport.lock().await;
43 transport.connect().await?;
44 drop(transport); self.connected = true;
47 info!("Connected to Claude CLI");
48 Ok(())
49 }
50
51 pub async fn send_and_receive(&mut self, prompt: String) -> Result<Vec<Message>> {
53 if !self.connected {
54 return Err(SdkError::InvalidState {
55 message: "Not connected".into(),
56 });
57 }
58
59 {
61 let mut transport = self.transport.lock().await;
62 let message = InputMessage::user(prompt, "default".to_string());
63 transport.send_message(message).await?;
64 } debug!("Message sent, waiting for response");
67
68 let mut messages = Vec::new();
70 loop {
71 let msg_result = {
73 let mut transport = self.transport.lock().await;
74 let mut stream = transport.receive_messages();
75 stream.next().await
76 }; if let Some(result) = msg_result {
80 match result {
81 Ok(msg) => {
82 debug!("Received: {:?}", msg);
83 let is_result = matches!(msg, Message::Result { .. });
84 messages.push(msg);
85 if is_result {
86 break;
87 }
88 }
89 Err(e) => return Err(e),
90 }
91 } else {
92 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
94 }
95 }
96
97 Ok(messages)
98 }
99
100 pub async fn send_message(&mut self, prompt: String) -> Result<()> {
102 if !self.connected {
103 return Err(SdkError::InvalidState {
104 message: "Not connected".into(),
105 });
106 }
107
108 let mut transport = self.transport.lock().await;
109 let message = InputMessage::user(prompt, "default".to_string());
110 transport.send_message(message).await?;
111 drop(transport);
112
113 debug!("Message sent");
114 Ok(())
115 }
116
117 pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
119 if !self.connected {
120 return Err(SdkError::InvalidState {
121 message: "Not connected".into(),
122 });
123 }
124
125 let mut messages = Vec::new();
126 loop {
127 let msg_result = {
129 let mut transport = self.transport.lock().await;
130 let mut stream = transport.receive_messages();
131 stream.next().await
132 }; if let Some(result) = msg_result {
136 match result {
137 Ok(msg) => {
138 debug!("Received: {:?}", msg);
139 let is_result = matches!(msg, Message::Result { .. });
140 messages.push(msg);
141 if is_result {
142 break;
143 }
144 }
145 Err(e) => return Err(e),
146 }
147 } else {
148 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
150 }
151 }
152
153 Ok(messages)
154 }
155
156 pub async fn receive_messages_stream(&mut self) -> impl Stream<Item = Result<Message>> + '_ {
188 let (tx, rx) = tokio::sync::mpsc::channel(100);
190 let transport = self.transport.clone();
191
192 tokio::spawn(async move {
194 let mut transport = transport.lock().await;
195 let mut stream = transport.receive_messages();
196
197 while let Some(result) = stream.next().await {
198 if tx.send(result).await.is_err() {
200 break;
202 }
203 }
204 });
205
206 ReceiverStream::new(rx)
208 }
209
210 pub async fn receive_response_stream(&mut self) -> impl Stream<Item = Result<Message>> + '_ {
215 async_stream::stream! {
217 let mut stream = self.receive_messages_stream().await;
218
219 while let Some(result) = stream.next().await {
220 match &result {
221 Ok(msg) => {
222 let is_result = matches!(msg, Message::Result { .. });
223 yield result;
224 if is_result {
225 break;
226 }
227 }
228 Err(_) => {
229 yield result;
230 break;
231 }
232 }
233 }
234 }
235 }
236
237 pub async fn interrupt(&mut self) -> Result<()> {
239 if !self.connected {
240 return Err(SdkError::InvalidState {
241 message: "Not connected".into(),
242 });
243 }
244
245 let mut transport = self.transport.lock().await;
246 let request = ControlRequest::Interrupt {
247 request_id: uuid::Uuid::new_v4().to_string(),
248 };
249 transport.send_control_request(request).await?;
250 drop(transport);
251
252 info!("Interrupt sent");
253 Ok(())
254 }
255
256 pub async fn disconnect(&mut self) -> Result<()> {
258 if !self.connected {
259 return Ok(());
260 }
261
262 let mut transport = self.transport.lock().await;
263 transport.disconnect().await?;
264 drop(transport);
265
266 self.connected = false;
267 info!("Disconnected from Claude CLI");
268 Ok(())
269 }
270}