1use crate::{
4 errors::{Result, SdkError},
5 transport::{InputMessage, SubprocessTransport, Transport},
6 types::{ClaudeCodeOptions, Message},
7};
8use futures::StreamExt;
9use std::sync::Arc;
10use tokio::sync::{Mutex, RwLock, mpsc};
11use tracing::{debug, error, info};
12
/// Connection state of the client.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ClientState {
    /// No active connection (the initial state, and the state after `disconnect`).
    Disconnected,
    /// Set once `connect` has established a transport.
    Connected,
    /// Receiving a message from the transport failed.
    Error,
}
20
21pub struct ClaudeSDKClientWorking {
23 options: ClaudeCodeOptions,
25 transport: Arc<Mutex<Option<SubprocessTransport>>>,
27 message_rx: Arc<Mutex<Option<mpsc::Receiver<Message>>>>,
29 state: Arc<RwLock<ClientState>>,
31}
32
33impl ClaudeSDKClientWorking {
34 pub fn new(options: ClaudeCodeOptions) -> Self {
36 unsafe {
37 std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
38 }
39
40 Self {
41 options,
42 transport: Arc::new(Mutex::new(None)),
43 message_rx: Arc::new(Mutex::new(None)),
44 state: Arc::new(RwLock::new(ClientState::Disconnected)),
45 }
46 }
47
48 pub async fn connect(&mut self, initial_prompt: Option<String>) -> Result<()> {
50 {
52 let state = self.state.read().await;
53 if *state == ClientState::Connected {
54 return Ok(());
55 }
56 }
57
58 let mut new_transport = SubprocessTransport::new(self.options.clone())?;
60 new_transport.connect().await?;
61
62 let (tx, rx) = mpsc::channel::<Message>(100);
64
65 {
67 let mut transport = self.transport.lock().await;
68 *transport = Some(new_transport);
69 }
70
71 {
73 let mut message_rx = self.message_rx.lock().await;
74 *message_rx = Some(rx);
75 }
76
77 {
79 let mut state = self.state.write().await;
80 *state = ClientState::Connected;
81 }
82
83 let transport_clone = self.transport.clone();
85 let state_clone = self.state.clone();
86 let tx_clone = tx.clone();
87
88 tokio::spawn(async move {
89 loop {
90 let msg_result = {
92 let mut transport_guard = transport_clone.lock().await;
93 if let Some(transport) = transport_guard.as_mut() {
94 let mut stream = transport.receive_messages();
96 stream.next().await
97 } else {
98 break;
99 }
100 };
101
102 if let Some(result) = msg_result {
104 match result {
105 Ok(msg) => {
106 debug!("Received message: {:?}", msg);
107 if tx_clone.send(msg).await.is_err() {
108 break;
109 }
110 }
111 Err(e) => {
112 error!("Error receiving message: {}", e);
113 let mut state = state_clone.write().await;
114 *state = ClientState::Error;
115 break;
116 }
117 }
118 } else {
119 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
121 }
122
123 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
125
126 let should_continue = {
127 let state = state_clone.read().await;
128 *state == ClientState::Connected
129 };
130
131 if !should_continue {
132 break;
133 }
134 }
135
136 debug!("Message reader task ended");
137 });
138
139 info!("Connected to Claude CLI");
140
141 if let Some(prompt) = initial_prompt {
143 self.send_user_message(prompt).await?;
144 }
145
146 Ok(())
147 }
148
149 pub async fn send_user_message(&mut self, prompt: String) -> Result<()> {
151 {
153 let state = self.state.read().await;
154 if *state != ClientState::Connected {
155 return Err(SdkError::InvalidState {
156 message: "Not connected".into(),
157 });
158 }
159 }
160
161 let message = InputMessage::user(prompt, "default".to_string());
163
164 {
166 let mut transport_guard = self.transport.lock().await;
167 if let Some(transport) = transport_guard.as_mut() {
168 transport.send_message(message).await?;
169 debug!("User message sent");
170 } else {
171 return Err(SdkError::InvalidState {
172 message: "Transport not available".into(),
173 });
174 }
175 }
176
177 Ok(())
178 }
179
180 pub async fn receive_message(&mut self) -> Result<Option<Message>> {
182 let mut rx_guard = self.message_rx.lock().await;
183 if let Some(rx) = rx_guard.as_mut() {
184 Ok(rx.recv().await)
185 } else {
186 Err(SdkError::InvalidState {
187 message: "Not connected".into(),
188 })
189 }
190 }
191
192 pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
194 let mut messages = Vec::new();
195
196 while let Some(msg) = self.receive_message().await? {
197 let is_result = matches!(msg, Message::Result { .. });
198 messages.push(msg);
199 if is_result {
200 break;
201 }
202 }
203
204 Ok(messages)
205 }
206
207 pub async fn disconnect(&mut self) -> Result<()> {
209 {
211 let mut state = self.state.write().await;
212 if *state == ClientState::Disconnected {
213 return Ok(());
214 }
215 *state = ClientState::Disconnected;
216 }
217
218 {
220 let mut transport_guard = self.transport.lock().await;
221 if let Some(mut transport) = transport_guard.take() {
222 transport.disconnect().await?;
223 }
224 }
225
226 {
228 let mut rx_guard = self.message_rx.lock().await;
229 rx_guard.take();
230 }
231
232 info!("Disconnected from Claude CLI");
233 Ok(())
234 }
235
236 pub async fn is_connected(&self) -> bool {
238 let state = self.state.read().await;
239 *state == ClientState::Connected
240 }
241}