1use std::collections::HashMap;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::Arc;
9use std::time::Duration;
10
11#[cfg(unix)]
12use tokio::io::AsyncReadExt;
13use tokio::io::AsyncWriteExt;
14use tokio::sync::{mpsc, oneshot, Mutex};
15use tokio::time::timeout;
16
17use super::native_host::get_socket_path;
18use super::types::ToolCallResult;
19
/// Largest payload length (in bytes) accepted from a frame's length header: 1 MiB.
const MAX_MESSAGE_SIZE: u32 = 1024 * 1024;

/// Upper bound on how long a single connection attempt may take.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);

/// Upper bound on how long a tool call waits for its response.
const TOOL_CALL_TIMEOUT: Duration = Duration::from_secs(60);

/// Delay between reconnect attempts (not referenced by the code visible here).
#[allow(dead_code)]
const RECONNECT_DELAY: Duration = Duration::from_secs(1);

/// Cap on reconnect attempts (not referenced by the code visible here).
#[allow(dead_code)]
const MAX_RECONNECT_ATTEMPTS: u32 = 10;
32
/// Error type returned by every fallible operation of the socket client.
///
/// It carries only a human-readable description; there is no machine-readable
/// error kind.
#[derive(Debug, Clone)]
pub struct SocketConnectionError {
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl SocketConnectionError {
    /// Builds an error from anything convertible into a `String`.
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }
}

impl std::fmt::Display for SocketConnectionError {
    /// Renders the error as `SocketConnectionError: <message>`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { message } = self;
        write!(formatter, "SocketConnectionError: {}", message)
    }
}

impl std::error::Error for SocketConnectionError {}
54
/// Bookkeeping for one in-flight tool call.
struct PendingCall {
    /// One-shot channel used to hand the outcome (a tool result or a
    /// connection error) back to the task awaiting the call.
    sender: oneshot::Sender<Result<ToolCallResult, SocketConnectionError>>,
}
59
/// Mutable connection state shared between the client and its reader task.
struct ClientState {
    /// `true` once a connection attempt has succeeded; cleared on disconnect.
    connected: bool,
    /// `true` while a connection attempt is in progress.
    connecting: bool,
    /// Calls that have been registered and not yet completed, keyed by call id.
    pending_calls: HashMap<String, PendingCall>,
    /// Reset to zero on every successful connect; not otherwise read by the
    /// code visible in this file.
    reconnect_attempts: u32,
}
67
/// Client for the length-prefixed JSON protocol spoken over a Unix domain
/// socket (Unix) or a named pipe (Windows).
pub struct SocketClient {
    /// Connection flags and the table of in-flight calls.
    state: Arc<Mutex<ClientState>>,
    /// Monotonic counter used when building call ids.
    call_id: AtomicU64,
    /// Write half of the Unix socket; `None` while disconnected.
    #[cfg(unix)]
    writer: Arc<Mutex<Option<tokio::net::unix::OwnedWriteHalf>>>,
    /// Named-pipe handle; `None` while disconnected.
    #[cfg(windows)]
    writer: Arc<Mutex<Option<tokio::net::windows::named_pipe::NamedPipeClient>>>,
    /// Sender used to ask the reader task to stop; `None` when no reader is running.
    shutdown_tx: Arc<Mutex<Option<mpsc::Sender<()>>>>,
}
78
79impl SocketClient {
80 pub fn new() -> Self {
82 Self {
83 state: Arc::new(Mutex::new(ClientState {
84 connected: false,
85 connecting: false,
86 pending_calls: HashMap::new(),
87 reconnect_attempts: 0,
88 })),
89 call_id: AtomicU64::new(0),
90 writer: Arc::new(Mutex::new(None)),
91 shutdown_tx: Arc::new(Mutex::new(None)),
92 }
93 }
94
95 pub async fn is_connected(&self) -> bool {
97 self.state.lock().await.connected
98 }
99
100 pub async fn ensure_connected(&self) -> bool {
102 {
103 let state = self.state.lock().await;
104 if state.connected {
105 return true;
106 }
107 if state.connecting {
108 drop(state);
109 for _ in 0..50 {
111 tokio::time::sleep(Duration::from_millis(100)).await;
112 let state = self.state.lock().await;
113 if state.connected {
114 return true;
115 }
116 if !state.connecting {
117 return false;
118 }
119 }
120 return false;
121 }
122 }
123
124 match self.connect().await {
125 Ok(_) => self.state.lock().await.connected,
126 Err(e) => {
127 tracing::warn!("Failed to connect to socket: {}", e);
128 false
129 }
130 }
131 }
132
133 #[cfg(unix)]
135 async fn connect(&self) -> Result<(), SocketConnectionError> {
136 {
137 let mut state = self.state.lock().await;
138 if state.connected || state.connecting {
139 return Ok(());
140 }
141 state.connecting = true;
142 }
143
144 let socket_path = get_socket_path();
145
146 let connect_result = timeout(
147 CONNECT_TIMEOUT,
148 tokio::net::UnixStream::connect(&socket_path),
149 )
150 .await;
151
152 match connect_result {
153 Ok(Ok(stream)) => {
154 let (reader, writer) = stream.into_split();
155 *self.writer.lock().await = Some(writer);
156
157 let state_clone = Arc::clone(&self.state);
158 let (shutdown_tx, shutdown_rx) = mpsc::channel::<()>(1);
159 *self.shutdown_tx.lock().await = Some(shutdown_tx);
160
161 tokio::spawn(async move {
163 Self::read_loop(reader, state_clone, shutdown_rx).await;
164 });
165
166 let mut state = self.state.lock().await;
167 state.connected = true;
168 state.connecting = false;
169 state.reconnect_attempts = 0;
170 tracing::info!("Connected to socket server");
171 Ok(())
172 }
173 Ok(Err(e)) => {
174 let mut state = self.state.lock().await;
175 state.connecting = false;
176 Err(SocketConnectionError::new(format!(
177 "Connection failed: {}",
178 e
179 )))
180 }
181 Err(_) => {
182 let mut state = self.state.lock().await;
183 state.connecting = false;
184 Err(SocketConnectionError::new("Connection timeout"))
185 }
186 }
187 }
188
/// Opens the named pipe used on Windows.
///
/// Returns `Ok(())` without doing anything when the client is already
/// connected or another attempt is in flight.
///
/// NOTE(review): unlike the Unix variant, no reader task is started here and
/// nothing visible in this file reads from the pipe, so responses do not
/// appear to be consumed on Windows — confirm whether that is intended.
///
/// # Errors
///
/// Returns a [`SocketConnectionError`] when opening the pipe fails or the
/// attempt exceeds `CONNECT_TIMEOUT`; the `connecting` flag is cleared first.
#[cfg(windows)]
async fn connect(&self) -> Result<(), SocketConnectionError> {
    {
        let mut guard = self.state.lock().await;
        let busy = guard.connected || guard.connecting;
        if busy {
            return Ok(());
        }
        guard.connecting = true;
    }

    let pipe_name = get_socket_path();

    let attempt = timeout(CONNECT_TIMEOUT, async {
        tokio::net::windows::named_pipe::ClientOptions::new().open(&pipe_name)
    })
    .await;

    // The success arm returns directly; both failure arms fall through to
    // the shared cleanup below.
    let failure = match attempt {
        Ok(Ok(pipe)) => {
            *self.writer.lock().await = Some(pipe);

            let mut guard = self.state.lock().await;
            guard.connected = true;
            guard.connecting = false;
            guard.reconnect_attempts = 0;
            tracing::info!("Connected to socket server");
            return Ok(());
        }
        Ok(Err(e)) => SocketConnectionError::new(format!("Connection failed: {}", e)),
        Err(_) => SocketConnectionError::new("Connection timeout"),
    };

    self.state.lock().await.connecting = false;
    Err(failure)
}
234
235 #[cfg(unix)]
237 async fn read_loop(
238 mut reader: tokio::net::unix::OwnedReadHalf,
239 state: Arc<Mutex<ClientState>>,
240 mut shutdown_rx: mpsc::Receiver<()>,
241 ) {
242 let mut buffer = Vec::new();
243 let mut read_buf = [0u8; 4096];
244
245 loop {
246 tokio::select! {
247 _ = shutdown_rx.recv() => {
248 tracing::info!("Socket read loop shutdown");
249 break;
250 }
251 result = reader.read(&mut read_buf) => {
252 match result {
253 Ok(0) => {
254 tracing::info!("Socket connection closed");
255 Self::handle_disconnect(state).await;
256 break;
257 }
258 Ok(n) => {
259 buffer.extend_from_slice(&read_buf[..n]);
260 Self::process_buffer(&mut buffer, &state).await;
261 }
262 Err(e) => {
263 tracing::error!("Socket read error: {}", e);
264 Self::handle_disconnect(state).await;
265 break;
266 }
267 }
268 }
269 }
270 }
271 }
272
273 async fn handle_disconnect(state: Arc<Mutex<ClientState>>) {
275 let mut state = state.lock().await;
276 state.connected = false;
277 state.connecting = false;
278
279 for (_, pending) in state.pending_calls.drain() {
281 let _ = pending
282 .sender
283 .send(Err(SocketConnectionError::new("Connection closed")));
284 }
285 }
286
287 async fn process_buffer(buffer: &mut Vec<u8>, state: &Arc<Mutex<ClientState>>) {
289 while buffer.len() >= 4 {
290 let msg_len = u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]);
291
292 if msg_len == 0 || msg_len > MAX_MESSAGE_SIZE {
293 tracing::error!("Invalid message length: {}", msg_len);
294 buffer.clear();
295 return;
296 }
297
298 let total_len = 4 + msg_len as usize;
299 if buffer.len() < total_len {
300 return; }
302
303 let msg_data = &buffer[4..total_len];
304 if let Ok(msg_str) = std::str::from_utf8(msg_data) {
305 Self::handle_message(msg_str, state).await;
306 }
307
308 buffer.drain(..total_len);
309 }
310 }
311
312 async fn handle_message(msg_str: &str, state: &Arc<Mutex<ClientState>>) {
314 let msg: serde_json::Value = match serde_json::from_str(msg_str) {
315 Ok(v) => v,
316 Err(e) => {
317 tracing::error!("Failed to parse message: {}", e);
318 return;
319 }
320 };
321
322 tracing::debug!(
323 "Received message: {}",
324 msg_str.get(..msg_str.len().min(300)).unwrap_or(msg_str)
325 );
326
327 if msg.get("result").is_some() || msg.get("error").is_some() {
329 let result = super::types::ToolCallResult {
330 result: msg.get("result").and_then(|r| {
331 r.get("content").map(|c| super::types::ToolResultContent {
332 content: c.as_array().cloned().unwrap_or_default(),
333 })
334 }),
335 error: msg.get("error").and_then(|e| {
336 e.get("content").map(|c| super::types::ToolErrorContent {
337 content: c.as_array().cloned().unwrap_or_default(),
338 })
339 }),
340 };
341
342 let mut state = state.lock().await;
343 if let Some(call_id) = state.pending_calls.keys().next().cloned() {
345 if let Some(pending) = state.pending_calls.remove(&call_id) {
346 let _ = pending.sender.send(Ok(result));
347 }
348 }
349 }
350 }
351
352 pub async fn call_tool(
354 &self,
355 tool_name: &str,
356 args: serde_json::Value,
357 ) -> Result<ToolCallResult, SocketConnectionError> {
358 if !self.is_connected().await {
359 return Err(SocketConnectionError::new("Not connected"));
360 }
361
362 let call_id = format!(
363 "call_{}_{}",
364 self.call_id.fetch_add(1, Ordering::SeqCst),
365 chrono::Utc::now().timestamp_millis()
366 );
367
368 let (tx, rx) = oneshot::channel();
369
370 {
372 let mut state = self.state.lock().await;
373 state
374 .pending_calls
375 .insert(call_id.clone(), PendingCall { sender: tx });
376 }
377
378 let message = serde_json::json!({
380 "type": "tool_request",
381 "method": "execute_tool",
382 "params": {
383 "tool": tool_name,
384 "client_id": "aster",
385 "args": args
386 }
387 });
388
389 if let Err(e) = self.send_message(&message).await {
391 let mut state = self.state.lock().await;
392 state.pending_calls.remove(&call_id);
393 return Err(e);
394 }
395
396 match timeout(TOOL_CALL_TIMEOUT, rx).await {
398 Ok(Ok(result)) => result,
399 Ok(Err(_)) => Err(SocketConnectionError::new("Response channel closed")),
400 Err(_) => {
401 let mut state = self.state.lock().await;
402 state.pending_calls.remove(&call_id);
403 Err(SocketConnectionError::new("Tool call timeout"))
404 }
405 }
406 }
407
408 #[cfg(unix)]
410 async fn send_message(&self, message: &serde_json::Value) -> Result<(), SocketConnectionError> {
411 let json = serde_json::to_vec(message)
412 .map_err(|e| SocketConnectionError::new(format!("Serialize error: {}", e)))?;
413
414 let mut header = [0u8; 4];
415 header.copy_from_slice(&(json.len() as u32).to_le_bytes());
416
417 let mut writer = self.writer.lock().await;
418 if let Some(ref mut w) = *writer {
419 w.write_all(&header)
420 .await
421 .map_err(|e| SocketConnectionError::new(format!("Write error: {}", e)))?;
422 w.write_all(&json)
423 .await
424 .map_err(|e| SocketConnectionError::new(format!("Write error: {}", e)))?;
425 Ok(())
426 } else {
427 Err(SocketConnectionError::new("Not connected"))
428 }
429 }
430
/// Serializes `message` and writes it to the named pipe as one frame.
///
/// The frame is a 4-byte little-endian payload length followed by the JSON
/// bytes. The writer lock is held for the whole write, so frames from
/// concurrent callers are not interleaved.
///
/// # Errors
///
/// Returns a [`SocketConnectionError`] when serialization fails, the payload
/// length does not fit in the 4-byte header, no pipe is installed
/// (`"Not connected"`), or the write itself fails.
#[cfg(windows)]
async fn send_message(&self, message: &serde_json::Value) -> Result<(), SocketConnectionError> {
    let payload = serde_json::to_vec(message)
        .map_err(|e| SocketConnectionError::new(format!("Serialize error: {}", e)))?;

    // A plain `as u32` cast would silently truncate and corrupt the framing.
    if payload.len() > u32::MAX as usize {
        return Err(SocketConnectionError::new(format!(
            "Message too large: {} bytes",
            payload.len()
        )));
    }
    let header = (payload.len() as u32).to_le_bytes();

    // Header and payload go out in a single write, so the frame is submitted
    // as one unit.
    let mut frame = Vec::with_capacity(header.len() + payload.len());
    frame.extend_from_slice(&header);
    frame.extend_from_slice(&payload);

    let mut writer = self.writer.lock().await;
    match writer.as_mut() {
        Some(w) => w
            .write_all(&frame)
            .await
            .map_err(|e| SocketConnectionError::new(format!("Write error: {}", e))),
        None => Err(SocketConnectionError::new("Not connected")),
    }
}
453
454 pub async fn disconnect(&self) {
456 if let Some(tx) = self.shutdown_tx.lock().await.take() {
458 let _ = tx.send(()).await;
459 }
460
461 *self.writer.lock().await = None;
463
464 let mut state = self.state.lock().await;
466 state.connected = false;
467 state.connecting = false;
468
469 for (_, pending) in state.pending_calls.drain() {
471 let _ = pending
472 .sender
473 .send(Err(SocketConnectionError::new("Disconnected")));
474 }
475 }
476}
477
478impl Default for SocketClient {
479 fn default() -> Self {
480 Self::new()
481 }
482}
483
484pub fn create_socket_client() -> SocketClient {
486 SocketClient::new()
487}