1#![cfg(unix)]
31
32use crate::error::TransportError;
33use crate::runtime::AsyncMutex;
34use crate::traits::{Transport, TransportListener, TransportMetadata};
35use mcpkit_core::protocol::Message;
36use std::path::{Path, PathBuf};
37use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
38
39#[cfg(feature = "tokio-runtime")]
40use tokio::{
41 io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter},
42 net::{UnixListener as TokioUnixListener, UnixStream},
43};
44
/// Default upper bound, in bytes, on a single serialized message (16 MiB).
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 16 * 1024 * 1024;

/// Settings shared by Unix domain socket clients and listeners.
#[derive(Debug, Clone)]
pub struct UnixSocketConfig {
    /// Filesystem path of the socket.
    pub path: PathBuf,
    /// Whether the socket file at `path` should be removed during cleanup.
    pub cleanup_on_close: bool,
    /// Read-side buffer size, in bytes.
    pub read_buffer_size: usize,
    /// Write-side buffer size, in bytes.
    pub write_buffer_size: usize,
    /// Largest message, in bytes, that may be sent or received.
    pub max_message_size: usize,
}

impl UnixSocketConfig {
    /// Creates a configuration for `path` with defaults: cleanup enabled,
    /// 64 KiB read and write buffers, and a message cap of
    /// [`DEFAULT_MAX_MESSAGE_SIZE`].
    pub fn new(path: impl AsRef<Path>) -> Self {
        const BUFFER_BYTES: usize = 64 * 1024;

        let socket_path = path.as_ref().to_path_buf();
        Self {
            path: socket_path,
            cleanup_on_close: true,
            read_buffer_size: BUFFER_BYTES,
            write_buffer_size: BUFFER_BYTES,
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
        }
    }

    /// Enables or disables removal of the socket file during cleanup.
    #[must_use]
    pub const fn with_cleanup_on_close(mut self, enabled: bool) -> Self {
        self.cleanup_on_close = enabled;
        self
    }

    /// Overrides the read-side buffer size, in bytes.
    #[must_use]
    pub const fn with_read_buffer_size(mut self, bytes: usize) -> Self {
        self.read_buffer_size = bytes;
        self
    }

    /// Overrides the write-side buffer size, in bytes.
    #[must_use]
    pub const fn with_write_buffer_size(mut self, bytes: usize) -> Self {
        self.write_buffer_size = bytes;
        self
    }

    /// Overrides the per-message size limit, in bytes.
    #[must_use]
    pub const fn with_max_message_size(mut self, bytes: usize) -> Self {
        self.max_message_size = bytes;
        self
    }
}
103
/// Buffered read half of a connected tokio Unix stream.
#[cfg(feature = "tokio-runtime")]
type UnixReader = BufReader<tokio::net::unix::OwnedReadHalf>;

/// Buffered write half of a connected tokio Unix stream.
#[cfg(feature = "tokio-runtime")]
type UnixWriter = BufWriter<tokio::net::unix::OwnedWriteHalf>;

/// Mutable per-connection state; `UnixTransport` keeps it behind an async mutex.
struct UnixTransportState {
    /// Read half of the stream; set to `None` by `close`.
    #[cfg(feature = "tokio-runtime")]
    reader: Option<UnixReader>,
    /// Write half of the stream; set to `None` by `close`.
    #[cfg(feature = "tokio-runtime")]
    writer: Option<UnixWriter>,
    /// Scratch buffer that `recv` clears and reuses for each incoming line.
    line_buffer: String,
}
123
124pub struct UnixTransport {
128 config: UnixSocketConfig,
129 state: AsyncMutex<UnixTransportState>,
130 connected: AtomicBool,
131 messages_sent: AtomicU64,
132 messages_received: AtomicU64,
133 is_server_side: bool,
134}
135
136impl UnixTransport {
137 #[cfg(feature = "tokio-runtime")]
139 fn from_stream(config: UnixSocketConfig, stream: UnixStream, is_server_side: bool) -> Self {
140 let (read_half, write_half) = stream.into_split();
141 let reader = BufReader::new(read_half);
142 let writer = BufWriter::new(write_half);
143
144 Self {
145 state: AsyncMutex::new(UnixTransportState {
146 reader: Some(reader),
147 writer: Some(writer),
148 line_buffer: String::with_capacity(4096),
149 }),
150 config,
151 connected: AtomicBool::new(true),
152 messages_sent: AtomicU64::new(0),
153 messages_received: AtomicU64::new(0),
154 is_server_side,
155 }
156 }
157
158 #[cfg(not(feature = "tokio-runtime"))]
160 fn new_disconnected(config: UnixSocketConfig, is_server_side: bool) -> Self {
161 Self {
162 state: AsyncMutex::new(UnixTransportState {
163 line_buffer: String::with_capacity(4096),
164 }),
165 config,
166 connected: AtomicBool::new(false),
167 messages_sent: AtomicU64::new(0),
168 messages_received: AtomicU64::new(0),
169 is_server_side,
170 }
171 }
172
173 #[cfg(feature = "tokio-runtime")]
175 pub async fn connect(path: impl AsRef<Path>) -> Result<Self, TransportError> {
176 let config = UnixSocketConfig::new(path);
177 Self::connect_with_config(config).await
178 }
179
180 #[cfg(not(feature = "tokio-runtime"))]
182 pub async fn connect(path: impl AsRef<Path>) -> Result<Self, TransportError> {
183 Err(TransportError::Connection {
184 message: "Unix socket transport requires 'tokio-runtime' feature".to_string(),
185 })
186 }
187
188 #[cfg(feature = "tokio-runtime")]
190 pub async fn connect_with_config(config: UnixSocketConfig) -> Result<Self, TransportError> {
191 let stream =
192 UnixStream::connect(&config.path)
193 .await
194 .map_err(|e| TransportError::Connection {
195 message: format!(
196 "Failed to connect to Unix socket '{}': {}",
197 config.path.display(),
198 e
199 ),
200 })?;
201
202 tracing::debug!(path = %config.path.display(), "Connected to Unix socket");
203 Ok(Self::from_stream(config, stream, false))
204 }
205
206 #[cfg(not(feature = "tokio-runtime"))]
208 pub async fn connect_with_config(config: UnixSocketConfig) -> Result<Self, TransportError> {
209 Err(TransportError::Connection {
210 message: "Unix socket transport requires 'tokio-runtime' feature".to_string(),
211 })
212 }
213
214 pub fn path(&self) -> &Path {
216 &self.config.path
217 }
218
219 pub fn messages_sent(&self) -> u64 {
221 self.messages_sent.load(Ordering::Relaxed)
222 }
223
224 pub fn messages_received(&self) -> u64 {
226 self.messages_received.load(Ordering::Relaxed)
227 }
228}
229
230impl Transport for UnixTransport {
231 type Error = TransportError;
232
233 #[cfg(feature = "tokio-runtime")]
234 async fn send(&self, msg: Message) -> Result<(), Self::Error> {
235 if !self.connected.load(Ordering::Acquire) {
236 return Err(TransportError::Connection {
237 message: "Unix socket not connected".to_string(),
238 });
239 }
240
241 let mut data = serde_json::to_vec(&msg).map_err(|e| TransportError::Serialization {
243 message: format!("Failed to serialize message: {e}"),
244 })?;
245
246 if data.len() > self.config.max_message_size {
248 return Err(TransportError::MessageTooLarge {
249 size: data.len(),
250 max: self.config.max_message_size,
251 });
252 }
253
254 data.push(b'\n');
255
256 let mut state = self.state.lock().await;
258 if let Some(writer) = state.writer.as_mut() {
259 writer
260 .write_all(&data)
261 .await
262 .map_err(|e| TransportError::Io {
263 message: format!("Failed to write to Unix socket: {e}"),
264 })?;
265 writer.flush().await.map_err(|e| TransportError::Io {
266 message: format!("Failed to flush Unix socket: {e}"),
267 })?;
268 } else {
269 return Err(TransportError::Connection {
270 message: "Unix socket writer not available".to_string(),
271 });
272 }
273
274 self.messages_sent.fetch_add(1, Ordering::Relaxed);
275 Ok(())
276 }
277
278 #[cfg(not(feature = "tokio-runtime"))]
279 async fn send(&self, _msg: Message) -> Result<(), Self::Error> {
280 Err(TransportError::Connection {
281 message: "Unix socket transport requires 'tokio-runtime' feature".to_string(),
282 })
283 }
284
285 #[cfg(feature = "tokio-runtime")]
286 async fn recv(&self) -> Result<Option<Message>, Self::Error> {
287 if !self.connected.load(Ordering::Acquire) {
288 return Ok(None);
289 }
290
291 let mut state = self.state.lock().await;
292
293 let reader = match state.reader.take() {
295 Some(r) => r,
296 None => return Ok(None),
297 };
298
299 state.line_buffer.clear();
301
302 let (result, reader) = {
304 let mut reader = reader;
305 let result = reader.read_line(&mut state.line_buffer).await;
306 (result, reader)
307 };
308
309 state.reader = Some(reader);
311
312 match result {
313 Ok(0) => {
314 self.connected.store(false, Ordering::Release);
316 Ok(None)
317 }
318 Ok(_) => {
319 let line = state.line_buffer.trim_end();
321 if line.is_empty() {
322 return Ok(None);
323 }
324
325 if line.len() > self.config.max_message_size {
327 return Err(TransportError::MessageTooLarge {
328 size: line.len(),
329 max: self.config.max_message_size,
330 });
331 }
332
333 let msg: Message =
334 serde_json::from_str(line).map_err(|e| TransportError::Deserialization {
335 message: format!("Failed to deserialize message: {e}"),
336 })?;
337
338 self.messages_received.fetch_add(1, Ordering::Relaxed);
339 Ok(Some(msg))
340 }
341 Err(e) => {
342 self.connected.store(false, Ordering::Release);
343 Err(TransportError::Io {
344 message: format!("Failed to read from Unix socket: {e}"),
345 })
346 }
347 }
348 }
349
350 #[cfg(not(feature = "tokio-runtime"))]
351 async fn recv(&self) -> Result<Option<Message>, Self::Error> {
352 Ok(None)
353 }
354
355 #[cfg(feature = "tokio-runtime")]
356 async fn close(&self) -> Result<(), Self::Error> {
357 self.connected.store(false, Ordering::Release);
358
359 let mut state = self.state.lock().await;
361 state.reader = None;
362 state.writer = None;
363
364 if self.is_server_side && self.config.cleanup_on_close && self.config.path.exists() {
366 let _ = std::fs::remove_file(&self.config.path);
367 }
368
369 Ok(())
370 }
371
372 #[cfg(not(feature = "tokio-runtime"))]
373 async fn close(&self) -> Result<(), Self::Error> {
374 self.connected.store(false, Ordering::Release);
375 Ok(())
376 }
377
378 fn is_connected(&self) -> bool {
379 self.connected.load(Ordering::Acquire)
380 }
381
382 fn metadata(&self) -> TransportMetadata {
383 TransportMetadata::new("unix").remote_addr(self.config.path.display().to_string())
384 }
385}
386
387pub struct UnixListener {
391 config: UnixSocketConfig,
392 #[cfg(feature = "tokio-runtime")]
393 listener: AsyncMutex<Option<TokioUnixListener>>,
394 running: AtomicBool,
395}
396
397impl UnixListener {
398 pub async fn bind(path: impl AsRef<Path>) -> Result<Self, TransportError> {
400 let config = UnixSocketConfig::new(path);
401 Self::bind_with_config(config).await
402 }
403
404 #[cfg(feature = "tokio-runtime")]
406 pub async fn bind_with_config(config: UnixSocketConfig) -> Result<Self, TransportError> {
407 if config.path.exists() {
409 std::fs::remove_file(&config.path).map_err(|e| TransportError::Io {
410 message: format!("Failed to remove existing socket file: {e}"),
411 })?;
412 }
413
414 let listener =
416 TokioUnixListener::bind(&config.path).map_err(|e| TransportError::Connection {
417 message: format!(
418 "Failed to bind Unix socket '{}': {}",
419 config.path.display(),
420 e
421 ),
422 })?;
423
424 tracing::info!(path = %config.path.display(), "Unix socket listener bound");
425
426 Ok(Self {
427 config,
428 listener: AsyncMutex::new(Some(listener)),
429 running: AtomicBool::new(true),
430 })
431 }
432
433 #[cfg(not(feature = "tokio-runtime"))]
435 pub async fn bind_with_config(config: UnixSocketConfig) -> Result<Self, TransportError> {
436 Err(TransportError::Connection {
437 message: "Unix socket listener requires 'tokio-runtime' feature".to_string(),
438 })
439 }
440
441 pub fn path(&self) -> &Path {
443 &self.config.path
444 }
445
446 pub fn is_running(&self) -> bool {
448 self.running.load(Ordering::Acquire)
449 }
450
451 #[cfg(feature = "tokio-runtime")]
453 pub async fn stop(&self) {
454 self.running.store(false, Ordering::Release);
455 let mut guard = self.listener.lock().await;
457 *guard = None;
458 }
459
460 #[cfg(not(feature = "tokio-runtime"))]
462 pub fn stop(&self) {
463 self.running.store(false, Ordering::Release);
464 }
465}
466
467impl TransportListener for UnixListener {
468 type Transport = UnixTransport;
469 type Error = TransportError;
470
471 #[cfg(feature = "tokio-runtime")]
472 async fn accept(&self) -> Result<Self::Transport, Self::Error> {
473 if !self.running.load(Ordering::Acquire) {
474 return Err(TransportError::Connection {
475 message: "Listener not running".to_string(),
476 });
477 }
478
479 let mut guard = self.listener.lock().await;
480 if let Some(listener) = guard.as_mut() {
481 let (stream, addr) =
482 listener
483 .accept()
484 .await
485 .map_err(|e| TransportError::Connection {
486 message: format!("Failed to accept connection: {e}"),
487 })?;
488
489 tracing::debug!(addr = ?addr, "Accepted Unix socket connection");
490
491 Ok(UnixTransport::from_stream(
492 self.config.clone(),
493 stream,
494 true,
495 ))
496 } else {
497 Err(TransportError::Connection {
498 message: "Listener has been stopped".to_string(),
499 })
500 }
501 }
502
503 #[cfg(not(feature = "tokio-runtime"))]
504 async fn accept(&self) -> Result<Self::Transport, Self::Error> {
505 Err(TransportError::Connection {
506 message: "Unix socket listener requires 'tokio-runtime' feature".to_string(),
507 })
508 }
509
510 fn local_addr(&self) -> Option<String> {
511 Some(self.config.path.display().to_string())
512 }
513}
514
515impl Drop for UnixListener {
516 fn drop(&mut self) {
517 if self.config.cleanup_on_close && self.config.path.exists() {
518 let _ = std::fs::remove_file(&self.config.path);
519 }
520 }
521}
522
/// A Linux abstract-namespace Unix socket name.
///
/// Abstract addresses are distinguished by a leading NUL byte instead of a
/// filesystem path.
#[cfg(target_os = "linux")]
pub struct AbstractSocket {
    name: String,
}

#[cfg(target_os = "linux")]
impl AbstractSocket {
    /// Creates a handle for the abstract socket called `name`.
    pub fn new(name: impl Into<String>) -> Self {
        let owned: String = name.into();
        Self { name: owned }
    }

    /// The socket name, without the leading NUL byte.
    #[must_use]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Raw address bytes: a `0` byte followed by the UTF-8 bytes of the name.
    #[must_use]
    pub fn to_path(&self) -> Vec<u8> {
        std::iter::once(0u8).chain(self.name.bytes()).collect()
    }
}
557
558pub struct UnixTransportBuilder {
560 config: UnixSocketConfig,
561}
562
563impl UnixTransportBuilder {
564 pub fn new(path: impl AsRef<Path>) -> Self {
566 Self {
567 config: UnixSocketConfig::new(path),
568 }
569 }
570
571 #[must_use]
573 pub const fn cleanup_on_close(mut self, cleanup: bool) -> Self {
574 self.config.cleanup_on_close = cleanup;
575 self
576 }
577
578 #[must_use]
580 pub const fn read_buffer_size(mut self, size: usize) -> Self {
581 self.config.read_buffer_size = size;
582 self
583 }
584
585 #[must_use]
587 pub const fn write_buffer_size(mut self, size: usize) -> Self {
588 self.config.write_buffer_size = size;
589 self
590 }
591
592 pub async fn connect(self) -> Result<UnixTransport, TransportError> {
594 UnixTransport::connect_with_config(self.config).await
595 }
596
597 pub async fn listen(self) -> Result<UnixListener, TransportError> {
599 UnixListener::bind_with_config(self.config).await
600 }
601}
602
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_creation() {
        let config = UnixSocketConfig::new("/tmp/test.sock")
            .with_cleanup_on_close(false)
            .with_read_buffer_size(128 * 1024);

        assert_eq!(config.path, PathBuf::from("/tmp/test.sock"));
        assert!(!config.cleanup_on_close);
        assert_eq!(config.read_buffer_size, 128 * 1024);
    }

    #[test]
    fn test_builder() {
        let builder = UnixTransportBuilder::new("/tmp/mcp.sock")
            .cleanup_on_close(true)
            .read_buffer_size(32 * 1024)
            .write_buffer_size(32 * 1024);

        assert_eq!(builder.config.read_buffer_size, 32 * 1024);
        assert_eq!(builder.config.write_buffer_size, 32 * 1024);
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_abstract_socket() {
        let socket = AbstractSocket::new("mcp-test");
        assert_eq!(socket.name(), "mcp-test");

        let path = socket.to_path();
        assert_eq!(path[0], 0u8);
        assert_eq!(&path[1..], b"mcp-test");
    }

    /// End-to-end echo: the client sends one request, the server echoes it back,
    /// and the client checks that the echo matches what it sent.
    #[cfg(feature = "tokio-runtime")]
    #[tokio::test]
    async fn test_unix_socket_communication() {
        use mcpkit_core::protocol::Request;

        let socket_path = format!("/tmp/mcp-test-{}.sock", std::process::id());
        let _ = std::fs::remove_file(&socket_path);

        // The listener is bound (and listening) before either task starts, so the
        // kernel queues the client's connect even if `accept` has not been polled
        // yet. No barrier or sleep is needed to order the two tasks.
        let listener = UnixListener::bind(&socket_path).await.unwrap();
        assert!(listener.is_running());

        let server_handle = tokio::spawn(async move {
            let transport = listener.accept().await.unwrap();
            assert!(transport.is_connected());

            let msg = transport
                .recv()
                .await
                .unwrap()
                .expect("client sends exactly one message");
            transport.send(msg).await.unwrap();

            transport.close().await.unwrap();
        });

        let client_path = socket_path.clone();
        let client_handle = tokio::spawn(async move {
            let transport = UnixTransport::connect(&client_path).await.unwrap();
            assert!(transport.is_connected());

            let msg = Message::Request(Request::new("test/echo", 1));
            transport.send(msg.clone()).await.unwrap();

            let response = transport
                .recv()
                .await
                .unwrap()
                .expect("server echoes the message back");
            // Compare wire representations: the echo must round-trip unchanged.
            assert_eq!(
                serde_json::to_value(&response).unwrap(),
                serde_json::to_value(&msg).unwrap()
            );
            assert_eq!(transport.messages_sent(), 1);
            assert_eq!(transport.messages_received(), 1);

            transport.close().await.unwrap();
            assert!(!transport.is_connected());
        });

        let (server_result, client_result) = tokio::join!(server_handle, client_handle);
        server_result.unwrap();
        client_result.unwrap();

        let _ = std::fs::remove_file(&socket_path);
    }
}