1use std::collections::HashMap;
11use std::io::{Read, Write};
12use std::sync::atomic::{AtomicU32, Ordering};
13use std::sync::Arc;
14
15use tokio::io::{AsyncReadExt, AsyncWriteExt};
16use tokio::sync::Mutex;
17
18use super::native_host::get_socket_path;
19
/// Version string reported to Chrome in `status_response` messages.
const NATIVE_HOST_VERSION: &str = "1.0.0";
/// Largest accepted frame payload in bytes (1 MiB), applied to frames from Chrome and from MCP
/// clients alike.
const MAX_MESSAGE_SIZE: u32 = 1024 * 1024;
24
25#[cfg(unix)]
27#[allow(dead_code)]
28struct McpClientInfo {
29 id: u32,
30 writer: tokio::net::unix::OwnedWriteHalf,
31}
32
33#[cfg(windows)]
35#[allow(dead_code)]
36struct McpClientInfo {
37 id: u32,
38 pipe: Arc<Mutex<tokio::net::windows::named_pipe::NamedPipeServer>>,
39}
40
41pub struct SocketServer {
43 mcp_clients: Arc<Mutex<HashMap<u32, McpClientInfo>>>,
44 next_client_id: AtomicU32,
45 running: Arc<Mutex<bool>>,
46}
47
48impl SocketServer {
49 pub fn new() -> Self {
51 Self {
52 mcp_clients: Arc::new(Mutex::new(HashMap::new())),
53 next_client_id: AtomicU32::new(1),
54 running: Arc::new(Mutex::new(false)),
55 }
56 }
57
58 #[cfg(unix)]
60 pub async fn start(&self) -> Result<(), String> {
61 let mut running = self.running.lock().await;
62 if *running {
63 return Ok(());
64 }
65
66 let socket_path = get_socket_path();
67 log_message(&format!("Creating socket listener: {}", socket_path));
68
69 let _ = std::fs::remove_file(&socket_path);
71
72 let listener = tokio::net::UnixListener::bind(&socket_path)
73 .map_err(|e| format!("Failed to bind socket: {}", e))?;
74
75 {
77 use std::os::unix::fs::PermissionsExt;
78 let perms = std::fs::Permissions::from_mode(0o600);
79 let _ = std::fs::set_permissions(&socket_path, perms);
80 }
81
82 *running = true;
83 log_message("Socket server listening for connections");
84
85 let clients = Arc::clone(&self.mcp_clients);
86 let next_id = &self.next_client_id;
87
88 loop {
90 match listener.accept().await {
91 Ok((stream, _)) => {
92 let id = next_id.fetch_add(1, Ordering::SeqCst);
93 self.handle_mcp_client(id, stream, Arc::clone(&clients))
94 .await;
95 }
96 Err(e) => {
97 log_message(&format!("Accept error: {}", e));
98 }
99 }
100 }
101 }
102
103 #[cfg(windows)]
105 pub async fn start(&self) -> Result<(), String> {
106 use tokio::net::windows::named_pipe::ServerOptions;
107
108 let mut running = self.running.lock().await;
109 if *running {
110 return Ok(());
111 }
112
113 let pipe_path = get_socket_path();
114 log_message(&format!("Creating named pipe server: {}", pipe_path));
115
116 *running = true;
117 log_message("Named pipe server listening for connections");
118
119 let clients = Arc::clone(&self.mcp_clients);
120 let next_id = &self.next_client_id;
121
122 loop {
124 let server = ServerOptions::new()
126 .first_pipe_instance(false)
127 .create(&pipe_path)
128 .map_err(|e| format!("Failed to create named pipe: {}", e))?;
129
130 if let Err(e) = server.connect().await {
132 log_message(&format!("Named pipe connect error: {}", e));
133 continue;
134 }
135
136 let id = next_id.fetch_add(1, Ordering::SeqCst);
137 self.handle_mcp_client_windows(id, server, Arc::clone(&clients))
138 .await;
139 }
140 }
141
142 #[cfg(unix)]
144 async fn handle_mcp_client(
145 &self,
146 id: u32,
147 stream: tokio::net::UnixStream,
148 clients: Arc<Mutex<HashMap<u32, McpClientInfo>>>,
149 ) {
150 let (mut reader, writer) = stream.into_split();
151
152 {
153 let mut clients = clients.lock().await;
154 clients.insert(id, McpClientInfo { id, writer });
155 log_message(&format!(
156 "MCP client {} connected. Total: {}",
157 id,
158 clients.len()
159 ));
160 }
161
162 send_to_chrome(&serde_json::json!({ "type": "mcp_connected" }));
164
165 let clients_clone = Arc::clone(&clients);
166
167 tokio::spawn(async move {
169 let mut buffer = Vec::new();
170 let mut read_buf = [0u8; 4096];
171
172 loop {
173 match reader.read(&mut read_buf).await {
174 Ok(0) => break,
175 Ok(n) => {
176 buffer.extend_from_slice(&read_buf[..n]);
177 Self::process_mcp_buffer(&mut buffer, id).await;
178 }
179 Err(_) => break,
180 }
181 }
182
183 let mut clients = clients_clone.lock().await;
184 clients.remove(&id);
185 log_message(&format!(
186 "MCP client {} disconnected. Total: {}",
187 id,
188 clients.len()
189 ));
190 });
191 }
192
193 #[cfg(windows)]
195 async fn handle_mcp_client_windows(
196 &self,
197 id: u32,
198 server: tokio::net::windows::named_pipe::NamedPipeServer,
199 clients: Arc<Mutex<HashMap<u32, McpClientInfo>>>,
200 ) {
201 let pipe = Arc::new(Mutex::new(server));
202
203 {
204 let mut clients = clients.lock().await;
205 clients.insert(
206 id,
207 McpClientInfo {
208 id,
209 pipe: Arc::clone(&pipe),
210 },
211 );
212 log_message(&format!(
213 "MCP client {} connected. Total: {}",
214 id,
215 clients.len()
216 ));
217 }
218
219 send_to_chrome(&serde_json::json!({ "type": "mcp_connected" }));
221
222 let clients_clone = Arc::clone(&clients);
223 let pipe_clone = Arc::clone(&pipe);
224
225 tokio::spawn(async move {
227 let mut buffer = Vec::new();
228 let mut read_buf = [0u8; 4096];
229
230 loop {
231 let read_result = {
232 let mut pipe = pipe_clone.lock().await;
233 pipe.read(&mut read_buf).await
234 };
235
236 match read_result {
237 Ok(0) => break,
238 Ok(n) => {
239 buffer.extend_from_slice(&read_buf[..n]);
240 Self::process_mcp_buffer(&mut buffer, id).await;
241 }
242 Err(_) => break,
243 }
244 }
245
246 let mut clients = clients_clone.lock().await;
247 clients.remove(&id);
248 log_message(&format!(
249 "MCP client {} disconnected. Total: {}",
250 id,
251 clients.len()
252 ));
253 });
254 }
255
256 async fn process_mcp_buffer(buffer: &mut Vec<u8>, client_id: u32) {
258 while buffer.len() >= 4 {
259 let msg_len = u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]);
260
261 if msg_len == 0 || msg_len > MAX_MESSAGE_SIZE {
262 log_message(&format!(
263 "Invalid message length from client {}: {}",
264 client_id, msg_len
265 ));
266 buffer.clear();
267 return;
268 }
269
270 let total_len = 4 + msg_len as usize;
271 if buffer.len() < total_len {
272 return;
273 }
274
275 let msg_data = &buffer[4..total_len];
276 if let Ok(msg_str) = std::str::from_utf8(msg_data) {
277 if let Ok(message) = serde_json::from_str::<serde_json::Value>(msg_str) {
278 log_message(&format!(
279 "Received from MCP client {}: {}",
280 client_id,
281 msg_str.get(..msg_str.len().min(200)).unwrap_or(msg_str)
282 ));
283 send_to_chrome(&message);
285 }
286 }
287
288 buffer.drain(..total_len);
289 }
290 }
291
292 pub async fn handle_chrome_message(&self, message: &str) -> Result<(), String> {
294 log_message(&format!(
295 "Chrome message: {}",
296 message.get(..message.len().min(300)).unwrap_or(message)
297 ));
298
299 let data: serde_json::Value =
300 serde_json::from_str(message).map_err(|e| format!("Parse error: {}", e))?;
301
302 if data.get("result").is_some() || data.get("error").is_some() {
304 log_message("Received tool response, forwarding to MCP clients");
305 self.forward_to_mcp_clients(&data).await;
306 return Ok(());
307 }
308
309 if let Some(msg_type) = data.get("type").and_then(|v| v.as_str()) {
311 match msg_type {
312 "ping" => {
313 send_to_chrome(&serde_json::json!({
314 "type": "pong",
315 "timestamp": chrono::Utc::now().timestamp_millis()
316 }));
317 }
318 "get_status" => {
319 send_to_chrome(&serde_json::json!({
320 "type": "status_response",
321 "native_host_version": NATIVE_HOST_VERSION
322 }));
323 }
324 _ => {
325 self.forward_to_mcp_clients(&data).await;
326 }
327 }
328 } else {
329 self.forward_to_mcp_clients(&data).await;
330 }
331
332 Ok(())
333 }
334
335 #[cfg(unix)]
337 async fn forward_to_mcp_clients(&self, data: &serde_json::Value) {
338 let mut clients = self.mcp_clients.lock().await;
339 if clients.is_empty() {
340 return;
341 }
342
343 log_message(&format!("Forwarding to {} MCP clients", clients.len()));
344
345 let json = serde_json::to_vec(data).unwrap_or_default();
346 let mut header = [0u8; 4];
347 header.copy_from_slice(&(json.len() as u32).to_le_bytes());
348
349 let mut failed_ids = Vec::new();
350
351 for (id, client) in clients.iter_mut() {
352 if client.writer.write_all(&header).await.is_err()
353 || client.writer.write_all(&json).await.is_err()
354 {
355 failed_ids.push(*id);
356 }
357 }
358
359 for id in failed_ids {
360 clients.remove(&id);
361 }
362 }
363
364 #[cfg(windows)]
366 async fn forward_to_mcp_clients(&self, data: &serde_json::Value) {
367 let mut clients = self.mcp_clients.lock().await;
368 if clients.is_empty() {
369 return;
370 }
371
372 log_message(&format!("Forwarding to {} MCP clients", clients.len()));
373
374 let json = serde_json::to_vec(data).unwrap_or_default();
375 let mut header = [0u8; 4];
376 header.copy_from_slice(&(json.len() as u32).to_le_bytes());
377
378 let mut failed_ids = Vec::new();
379
380 for (id, client) in clients.iter_mut() {
381 let mut pipe = client.pipe.lock().await;
382 if pipe.write_all(&header).await.is_err() || pipe.write_all(&json).await.is_err() {
383 failed_ids.push(*id);
384 }
385 }
386
387 for id in failed_ids {
388 clients.remove(&id);
389 }
390 }
391
392 #[cfg(unix)]
394 pub async fn stop(&self) {
395 let mut running = self.running.lock().await;
396 if !*running {
397 return;
398 }
399 *running = false;
400
401 let socket_path = get_socket_path();
403 let _ = std::fs::remove_file(&socket_path);
404
405 let mut clients = self.mcp_clients.lock().await;
407 clients.clear();
408
409 log_message("Socket server stopped");
410 }
411
412 #[cfg(windows)]
414 pub async fn stop(&self) {
415 let mut running = self.running.lock().await;
416 if !*running {
417 return;
418 }
419 *running = false;
420
421 let mut clients = self.mcp_clients.lock().await;
423 clients.clear();
424
425 log_message("Named pipe server stopped");
426 }
427}
428
429impl Default for SocketServer {
430 fn default() -> Self {
431 Self::new()
432 }
433}
434
435fn log_message(message: &str) {
437 let timestamp = chrono::Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ");
438 eprintln!("[{}] [Native Host] {}", timestamp, message);
439
440 if let Some(home) = dirs::home_dir() {
442 let log_file = home.join(".aster").join("native-host.log");
443 if let Ok(mut file) = std::fs::OpenOptions::new()
444 .create(true)
445 .append(true)
446 .open(&log_file)
447 {
448 let _ = writeln!(file, "[{}] {}", timestamp, message);
449 }
450 }
451}
452
453fn send_to_chrome(message: &serde_json::Value) {
455 let json_str = serde_json::to_string(message).unwrap_or_default();
456 log_message(&format!(
457 "Sending to Chrome: {}",
458 json_str.get(..json_str.len().min(200)).unwrap_or(&json_str)
459 ));
460
461 let json = json_str.as_bytes();
462 let mut header = [0u8; 4];
463 header.copy_from_slice(&(json.len() as u32).to_le_bytes());
464
465 let mut stdout = std::io::stdout().lock();
466 let _ = stdout.write_all(&header);
467 let _ = stdout.write_all(json);
468 let _ = stdout.flush();
469}
470
471#[allow(dead_code)]
473pub struct NativeMessageReader {
474 buffer: Vec<u8>,
475}
476
477impl NativeMessageReader {
478 pub fn new() -> Self {
479 Self { buffer: Vec::new() }
480 }
481
482 pub fn read(&mut self) -> Option<String> {
484 let mut stdin = std::io::stdin().lock();
485 let mut header = [0u8; 4];
486
487 if stdin.read_exact(&mut header).is_err() {
488 return None;
489 }
490
491 let msg_len = u32::from_le_bytes(header);
492 if msg_len == 0 || msg_len > MAX_MESSAGE_SIZE {
493 log_message(&format!("Invalid message length: {}", msg_len));
494 return None;
495 }
496
497 let mut msg_buf = vec![0u8; msg_len as usize];
498 if stdin.read_exact(&mut msg_buf).is_err() {
499 return None;
500 }
501
502 String::from_utf8(msg_buf).ok()
503 }
504}
505
506impl Default for NativeMessageReader {
507 fn default() -> Self {
508 Self::new()
509 }
510}
511
512pub async fn run_native_host() -> Result<(), String> {
514 log_message("Initializing Native Host...");
515
516 let server = SocketServer::new();
517 let mut reader = NativeMessageReader::new();
518
519 tokio::spawn(async move {
521 let s = SocketServer::new();
522 if let Err(e) = s.start().await {
523 log_message(&format!("Socket server error: {}", e));
524 }
525 });
526
527 log_message("Running in Native Messaging mode");
529 while let Some(message) = reader.read() {
530 if let Err(e) = server.handle_chrome_message(&message).await {
531 log_message(&format!("Handle message error: {}", e));
532 }
533 }
534
535 server.stop().await;
536 Ok(())
537}