1use crate::{
4 errors::{Result, SdkError},
5 transport::{InputMessage, SubprocessTransport, Transport},
6 types::{ClaudeCodeOptions, Message},
7};
8use futures::StreamExt;
9use std::sync::Arc;
10use tokio::sync::{Mutex, RwLock, mpsc};
11use tracing::{debug, error, info};
12
/// Lifecycle state of a [`ClaudeSDKClientWorking`] connection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ClientState {
    /// No active connection: the state a new client starts in and the state
    /// `disconnect` leaves behind.
    Disconnected,
    /// `connect` completed and the transport has been installed.
    Connected,
    /// The background reader task failed to receive a message from the transport.
    Error,
}
20
21pub struct ClaudeSDKClientWorking {
23 options: ClaudeCodeOptions,
25 transport: Arc<Mutex<Option<SubprocessTransport>>>,
27 message_rx: Arc<Mutex<Option<mpsc::Receiver<Message>>>>,
29 state: Arc<RwLock<ClientState>>,
31}
32
33impl ClaudeSDKClientWorking {
34 pub fn new(options: ClaudeCodeOptions) -> Self {
36 Self {
37 options,
38 transport: Arc::new(Mutex::new(None)),
39 message_rx: Arc::new(Mutex::new(None)),
40 state: Arc::new(RwLock::new(ClientState::Disconnected)),
41 }
42 }
43
44 pub async fn connect(&mut self, initial_prompt: Option<String>) -> Result<()> {
46 {
48 let state = self.state.read().await;
49 if *state == ClientState::Connected {
50 return Ok(());
51 }
52 }
53
54 let mut new_transport = SubprocessTransport::new(self.options.clone())?;
56 new_transport.connect().await?;
57
58 let (tx, rx) = mpsc::channel::<Message>(100);
60
61 {
63 let mut transport = self.transport.lock().await;
64 *transport = Some(new_transport);
65 }
66
67 {
69 let mut message_rx = self.message_rx.lock().await;
70 *message_rx = Some(rx);
71 }
72
73 {
75 let mut state = self.state.write().await;
76 *state = ClientState::Connected;
77 }
78
79 let transport_clone = self.transport.clone();
81 let state_clone = self.state.clone();
82 let tx_clone = tx.clone();
83
84 tokio::spawn(async move {
85 loop {
86 let msg_result = {
88 let mut transport_guard = transport_clone.lock().await;
89 if let Some(transport) = transport_guard.as_mut() {
90 let mut stream = transport.receive_messages();
92 stream.next().await
93 } else {
94 break;
95 }
96 };
97
98 if let Some(result) = msg_result {
100 match result {
101 Ok(msg) => {
102 debug!("Received message: {:?}", msg);
103 if tx_clone.send(msg).await.is_err() {
104 break;
105 }
106 }
107 Err(e) => {
108 error!("Error receiving message: {}", e);
109 let mut state = state_clone.write().await;
110 *state = ClientState::Error;
111 break;
112 }
113 }
114 } else {
115 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
117 }
118
119 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
121
122 let should_continue = {
123 let state = state_clone.read().await;
124 *state == ClientState::Connected
125 };
126
127 if !should_continue {
128 break;
129 }
130 }
131
132 debug!("Message reader task ended");
133 });
134
135 info!("Connected to Claude CLI");
136
137 if let Some(prompt) = initial_prompt {
139 self.send_user_message(prompt).await?;
140 }
141
142 Ok(())
143 }
144
145 pub async fn send_user_message(&mut self, prompt: String) -> Result<()> {
147 {
149 let state = self.state.read().await;
150 if *state != ClientState::Connected {
151 return Err(SdkError::InvalidState {
152 message: "Not connected".into(),
153 });
154 }
155 }
156
157 let message = InputMessage::user(prompt, "default".to_string());
159
160 {
162 let mut transport_guard = self.transport.lock().await;
163 if let Some(transport) = transport_guard.as_mut() {
164 transport.send_message(message).await?;
165 debug!("User message sent");
166 } else {
167 return Err(SdkError::InvalidState {
168 message: "Transport not available".into(),
169 });
170 }
171 }
172
173 Ok(())
174 }
175
176 pub async fn receive_message(&mut self) -> Result<Option<Message>> {
178 let mut rx_guard = self.message_rx.lock().await;
179 if let Some(rx) = rx_guard.as_mut() {
180 Ok(rx.recv().await)
181 } else {
182 Err(SdkError::InvalidState {
183 message: "Not connected".into(),
184 })
185 }
186 }
187
188 pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
190 let mut messages = Vec::new();
191
192 while let Some(msg) = self.receive_message().await? {
193 let is_result = matches!(msg, Message::Result { .. });
194 messages.push(msg);
195 if is_result {
196 break;
197 }
198 }
199
200 Ok(messages)
201 }
202
203 pub async fn disconnect(&mut self) -> Result<()> {
205 {
207 let mut state = self.state.write().await;
208 if *state == ClientState::Disconnected {
209 return Ok(());
210 }
211 *state = ClientState::Disconnected;
212 }
213
214 {
216 let mut transport_guard = self.transport.lock().await;
217 if let Some(mut transport) = transport_guard.take() {
218 transport.disconnect().await?;
219 }
220 }
221
222 {
224 let mut rx_guard = self.message_rx.lock().await;
225 rx_guard.take();
226 }
227
228 info!("Disconnected from Claude CLI");
229 Ok(())
230 }
231
232 pub async fn is_connected(&self) -> bool {
234 let state = self.state.read().await;
235 *state == ClientState::Connected
236 }
237}