mcpkit_transport/
unix.rs

1//! Unix domain socket transport for MCP.
2//!
3//! This module provides Unix domain socket transport for local inter-process
4//! communication. This is only available on Unix-like systems (Linux, macOS, BSDs).
5//!
6//! # Features
7//!
8//! - Low-latency local IPC
9//! - File system-based addressing
//! - Abstract socket name helper (Linux only; [`AbstractSocket`] builds the name but is not yet accepted by `connect`/`bind`)
11//! - Automatic cleanup of socket files
12//! - Newline-delimited JSON message framing
13//!
14//! # Example
15//!
16//! ```rust
17//! #[cfg(unix)]
18//! fn example() {
19//!     use mcpkit_transport::unix::UnixSocketConfig;
20//!
21//!     // Configure a Unix socket
22//!     let config = UnixSocketConfig::new("/tmp/mcp.sock")
23//!         .with_cleanup_on_close(true);
24//!
25//!     assert_eq!(config.path.to_str().unwrap(), "/tmp/mcp.sock");
26//!     assert!(config.cleanup_on_close);
27//! }
28//! ```
29
30#![cfg(unix)]
31
32use crate::error::TransportError;
33use crate::runtime::AsyncMutex;
34use crate::traits::{Transport, TransportListener, TransportMetadata};
35use mcpkit_core::protocol::Message;
36use std::path::{Path, PathBuf};
37use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
38
39#[cfg(feature = "tokio-runtime")]
40use tokio::{
41    io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter},
42    net::{UnixListener as TokioUnixListener, UnixStream},
43};
44
/// Largest message, in bytes, that a transport accepts unless configured otherwise (16 MiB).
pub const DEFAULT_MAX_MESSAGE_SIZE: usize = 16 << 20;

/// Settings shared by Unix socket transports and listeners.
#[derive(Debug, Clone)]
pub struct UnixSocketConfig {
    /// Filesystem path of the socket.
    pub path: PathBuf,
    /// When `true`, the socket file is deleted on close.
    pub cleanup_on_close: bool,
    /// Capacity, in bytes, of the read buffer.
    pub read_buffer_size: usize,
    /// Capacity, in bytes, of the write buffer.
    pub write_buffer_size: usize,
    /// Upper bound, in bytes, on the size of a single message.
    pub max_message_size: usize,
}

impl UnixSocketConfig {
    /// Builds a configuration for the socket at `path`.
    ///
    /// Defaults: cleanup enabled, 64 KiB read and write buffers, and a
    /// [`DEFAULT_MAX_MESSAGE_SIZE`] message limit.
    pub fn new(path: impl AsRef<Path>) -> Self {
        const BUFFER_SIZE: usize = 64 * 1024;

        let path = path.as_ref().to_owned();
        Self {
            path,
            cleanup_on_close: true,
            read_buffer_size: BUFFER_SIZE,
            write_buffer_size: BUFFER_SIZE,
            max_message_size: DEFAULT_MAX_MESSAGE_SIZE,
        }
    }

    /// Returns the configuration with socket-file cleanup turned on or off.
    #[must_use]
    pub const fn with_cleanup_on_close(mut self, cleanup: bool) -> Self {
        self.cleanup_on_close = cleanup;
        self
    }

    /// Returns the configuration with a read buffer of `size` bytes.
    #[must_use]
    pub const fn with_read_buffer_size(mut self, size: usize) -> Self {
        self.read_buffer_size = size;
        self
    }

    /// Returns the configuration with a write buffer of `size` bytes.
    #[must_use]
    pub const fn with_write_buffer_size(mut self, size: usize) -> Self {
        self.write_buffer_size = size;
        self
    }

    /// Returns the configuration with a per-message limit of `size` bytes.
    #[must_use]
    pub const fn with_max_message_size(mut self, size: usize) -> Self {
        self.max_message_size = size;
        self
    }
}
103
/// Read side of a split Unix stream, wrapped in a tokio `BufReader`.
#[cfg(feature = "tokio-runtime")]
type UnixReader = BufReader<tokio::net::unix::OwnedReadHalf>;

/// Write side of a split Unix stream, wrapped in a tokio `BufWriter`.
#[cfg(feature = "tokio-runtime")]
type UnixWriter = BufWriter<tokio::net::unix::OwnedWriteHalf>;
111
112/// Internal state for Unix socket transport.
113struct UnixTransportState {
114    /// Reader half of the Unix stream.
115    #[cfg(feature = "tokio-runtime")]
116    reader: Option<UnixReader>,
117    /// Writer half of the Unix stream.
118    #[cfg(feature = "tokio-runtime")]
119    writer: Option<UnixWriter>,
120    /// Line buffer for reading complete messages.
121    line_buffer: String,
122}
123
124/// Unix domain socket transport.
125///
126/// Provides low-latency local IPC using Unix domain sockets.
127pub struct UnixTransport {
128    config: UnixSocketConfig,
129    state: AsyncMutex<UnixTransportState>,
130    connected: AtomicBool,
131    messages_sent: AtomicU64,
132    messages_received: AtomicU64,
133    is_server_side: bool,
134}
135
136impl UnixTransport {
137    /// Create a new Unix socket transport from an existing stream.
138    #[cfg(feature = "tokio-runtime")]
139    fn from_stream(config: UnixSocketConfig, stream: UnixStream, is_server_side: bool) -> Self {
140        let (read_half, write_half) = stream.into_split();
141        let reader = BufReader::new(read_half);
142        let writer = BufWriter::new(write_half);
143
144        Self {
145            state: AsyncMutex::new(UnixTransportState {
146                reader: Some(reader),
147                writer: Some(writer),
148                line_buffer: String::with_capacity(4096),
149            }),
150            config,
151            connected: AtomicBool::new(true),
152            messages_sent: AtomicU64::new(0),
153            messages_received: AtomicU64::new(0),
154            is_server_side,
155        }
156    }
157
158    /// Create disconnected transport (for non-tokio runtimes or testing).
159    #[cfg(not(feature = "tokio-runtime"))]
160    fn new_disconnected(config: UnixSocketConfig, is_server_side: bool) -> Self {
161        Self {
162            state: AsyncMutex::new(UnixTransportState {
163                line_buffer: String::with_capacity(4096),
164            }),
165            config,
166            connected: AtomicBool::new(false),
167            messages_sent: AtomicU64::new(0),
168            messages_received: AtomicU64::new(0),
169            is_server_side,
170        }
171    }
172
173    /// Connect to a Unix socket server.
174    #[cfg(feature = "tokio-runtime")]
175    pub async fn connect(path: impl AsRef<Path>) -> Result<Self, TransportError> {
176        let config = UnixSocketConfig::new(path);
177        Self::connect_with_config(config).await
178    }
179
180    /// Connect to a Unix socket server (stub for non-tokio runtimes).
181    #[cfg(not(feature = "tokio-runtime"))]
182    pub async fn connect(path: impl AsRef<Path>) -> Result<Self, TransportError> {
183        Err(TransportError::Connection {
184            message: "Unix socket transport requires 'tokio-runtime' feature".to_string(),
185        })
186    }
187
188    /// Connect with custom configuration.
189    #[cfg(feature = "tokio-runtime")]
190    pub async fn connect_with_config(config: UnixSocketConfig) -> Result<Self, TransportError> {
191        let stream =
192            UnixStream::connect(&config.path)
193                .await
194                .map_err(|e| TransportError::Connection {
195                    message: format!(
196                        "Failed to connect to Unix socket '{}': {}",
197                        config.path.display(),
198                        e
199                    ),
200                })?;
201
202        tracing::debug!(path = %config.path.display(), "Connected to Unix socket");
203        Ok(Self::from_stream(config, stream, false))
204    }
205
206    /// Connect with custom configuration (stub for non-tokio runtimes).
207    #[cfg(not(feature = "tokio-runtime"))]
208    pub async fn connect_with_config(config: UnixSocketConfig) -> Result<Self, TransportError> {
209        Err(TransportError::Connection {
210            message: "Unix socket transport requires 'tokio-runtime' feature".to_string(),
211        })
212    }
213
214    /// Get the socket path.
215    pub fn path(&self) -> &Path {
216        &self.config.path
217    }
218
219    /// Get the number of messages sent.
220    pub fn messages_sent(&self) -> u64 {
221        self.messages_sent.load(Ordering::Relaxed)
222    }
223
224    /// Get the number of messages received.
225    pub fn messages_received(&self) -> u64 {
226        self.messages_received.load(Ordering::Relaxed)
227    }
228}
229
230impl Transport for UnixTransport {
231    type Error = TransportError;
232
233    #[cfg(feature = "tokio-runtime")]
234    async fn send(&self, msg: Message) -> Result<(), Self::Error> {
235        if !self.connected.load(Ordering::Acquire) {
236            return Err(TransportError::Connection {
237                message: "Unix socket not connected".to_string(),
238            });
239        }
240
241        // Serialize the message with newline delimiter
242        let mut data = serde_json::to_vec(&msg).map_err(|e| TransportError::Serialization {
243            message: format!("Failed to serialize message: {e}"),
244        })?;
245
246        // Check message size limit
247        if data.len() > self.config.max_message_size {
248            return Err(TransportError::MessageTooLarge {
249                size: data.len(),
250                max: self.config.max_message_size,
251            });
252        }
253
254        data.push(b'\n');
255
256        // Write to the socket
257        let mut state = self.state.lock().await;
258        if let Some(writer) = state.writer.as_mut() {
259            writer
260                .write_all(&data)
261                .await
262                .map_err(|e| TransportError::Io {
263                    message: format!("Failed to write to Unix socket: {e}"),
264                })?;
265            writer.flush().await.map_err(|e| TransportError::Io {
266                message: format!("Failed to flush Unix socket: {e}"),
267            })?;
268        } else {
269            return Err(TransportError::Connection {
270                message: "Unix socket writer not available".to_string(),
271            });
272        }
273
274        self.messages_sent.fetch_add(1, Ordering::Relaxed);
275        Ok(())
276    }
277
278    #[cfg(not(feature = "tokio-runtime"))]
279    async fn send(&self, _msg: Message) -> Result<(), Self::Error> {
280        Err(TransportError::Connection {
281            message: "Unix socket transport requires 'tokio-runtime' feature".to_string(),
282        })
283    }
284
285    #[cfg(feature = "tokio-runtime")]
286    async fn recv(&self) -> Result<Option<Message>, Self::Error> {
287        if !self.connected.load(Ordering::Acquire) {
288            return Ok(None);
289        }
290
291        let mut state = self.state.lock().await;
292
293        // Take the reader temporarily to avoid borrowing issues
294        let reader = match state.reader.take() {
295            Some(r) => r,
296            None => return Ok(None),
297        };
298
299        // Clear the buffer and read a line
300        state.line_buffer.clear();
301
302        // We need to read into a separate buffer to avoid borrow issues
303        let (result, reader) = {
304            let mut reader = reader;
305            let result = reader.read_line(&mut state.line_buffer).await;
306            (result, reader)
307        };
308
309        // Put the reader back
310        state.reader = Some(reader);
311
312        match result {
313            Ok(0) => {
314                // EOF - connection closed
315                self.connected.store(false, Ordering::Release);
316                Ok(None)
317            }
318            Ok(_) => {
319                // Parse the message (trim the newline)
320                let line = state.line_buffer.trim_end();
321                if line.is_empty() {
322                    return Ok(None);
323                }
324
325                // Check message size limit
326                if line.len() > self.config.max_message_size {
327                    return Err(TransportError::MessageTooLarge {
328                        size: line.len(),
329                        max: self.config.max_message_size,
330                    });
331                }
332
333                let msg: Message =
334                    serde_json::from_str(line).map_err(|e| TransportError::Deserialization {
335                        message: format!("Failed to deserialize message: {e}"),
336                    })?;
337
338                self.messages_received.fetch_add(1, Ordering::Relaxed);
339                Ok(Some(msg))
340            }
341            Err(e) => {
342                self.connected.store(false, Ordering::Release);
343                Err(TransportError::Io {
344                    message: format!("Failed to read from Unix socket: {e}"),
345                })
346            }
347        }
348    }
349
350    #[cfg(not(feature = "tokio-runtime"))]
351    async fn recv(&self) -> Result<Option<Message>, Self::Error> {
352        Ok(None)
353    }
354
355    #[cfg(feature = "tokio-runtime")]
356    async fn close(&self) -> Result<(), Self::Error> {
357        self.connected.store(false, Ordering::Release);
358
359        // Drop the stream parts
360        let mut state = self.state.lock().await;
361        state.reader = None;
362        state.writer = None;
363
364        // Cleanup socket file if this is server-side and cleanup is enabled
365        if self.is_server_side && self.config.cleanup_on_close && self.config.path.exists() {
366            let _ = std::fs::remove_file(&self.config.path);
367        }
368
369        Ok(())
370    }
371
372    #[cfg(not(feature = "tokio-runtime"))]
373    async fn close(&self) -> Result<(), Self::Error> {
374        self.connected.store(false, Ordering::Release);
375        Ok(())
376    }
377
378    fn is_connected(&self) -> bool {
379        self.connected.load(Ordering::Acquire)
380    }
381
382    fn metadata(&self) -> TransportMetadata {
383        TransportMetadata::new("unix").remote_addr(self.config.path.display().to_string())
384    }
385}
386
387/// Unix domain socket listener.
388///
389/// Listens for incoming connections on a Unix domain socket.
390pub struct UnixListener {
391    config: UnixSocketConfig,
392    #[cfg(feature = "tokio-runtime")]
393    listener: AsyncMutex<Option<TokioUnixListener>>,
394    running: AtomicBool,
395}
396
397impl UnixListener {
398    /// Bind to a Unix socket path.
399    pub async fn bind(path: impl AsRef<Path>) -> Result<Self, TransportError> {
400        let config = UnixSocketConfig::new(path);
401        Self::bind_with_config(config).await
402    }
403
404    /// Bind with custom configuration.
405    #[cfg(feature = "tokio-runtime")]
406    pub async fn bind_with_config(config: UnixSocketConfig) -> Result<Self, TransportError> {
407        // Remove existing socket file if it exists
408        if config.path.exists() {
409            std::fs::remove_file(&config.path).map_err(|e| TransportError::Io {
410                message: format!("Failed to remove existing socket file: {e}"),
411            })?;
412        }
413
414        // Bind the socket
415        let listener =
416            TokioUnixListener::bind(&config.path).map_err(|e| TransportError::Connection {
417                message: format!(
418                    "Failed to bind Unix socket '{}': {}",
419                    config.path.display(),
420                    e
421                ),
422            })?;
423
424        tracing::info!(path = %config.path.display(), "Unix socket listener bound");
425
426        Ok(Self {
427            config,
428            listener: AsyncMutex::new(Some(listener)),
429            running: AtomicBool::new(true),
430        })
431    }
432
433    /// Bind with custom configuration (stub for non-tokio runtimes).
434    #[cfg(not(feature = "tokio-runtime"))]
435    pub async fn bind_with_config(config: UnixSocketConfig) -> Result<Self, TransportError> {
436        Err(TransportError::Connection {
437            message: "Unix socket listener requires 'tokio-runtime' feature".to_string(),
438        })
439    }
440
441    /// Get the socket path.
442    pub fn path(&self) -> &Path {
443        &self.config.path
444    }
445
446    /// Check if the listener is running.
447    pub fn is_running(&self) -> bool {
448        self.running.load(Ordering::Acquire)
449    }
450
451    /// Stop the listener.
452    #[cfg(feature = "tokio-runtime")]
453    pub async fn stop(&self) {
454        self.running.store(false, Ordering::Release);
455        // Drop the listener
456        let mut guard = self.listener.lock().await;
457        *guard = None;
458    }
459
460    /// Stop the listener (non-tokio version).
461    #[cfg(not(feature = "tokio-runtime"))]
462    pub fn stop(&self) {
463        self.running.store(false, Ordering::Release);
464    }
465}
466
467impl TransportListener for UnixListener {
468    type Transport = UnixTransport;
469    type Error = TransportError;
470
471    #[cfg(feature = "tokio-runtime")]
472    async fn accept(&self) -> Result<Self::Transport, Self::Error> {
473        if !self.running.load(Ordering::Acquire) {
474            return Err(TransportError::Connection {
475                message: "Listener not running".to_string(),
476            });
477        }
478
479        let mut guard = self.listener.lock().await;
480        if let Some(listener) = guard.as_mut() {
481            let (stream, addr) =
482                listener
483                    .accept()
484                    .await
485                    .map_err(|e| TransportError::Connection {
486                        message: format!("Failed to accept connection: {e}"),
487                    })?;
488
489            tracing::debug!(addr = ?addr, "Accepted Unix socket connection");
490
491            Ok(UnixTransport::from_stream(
492                self.config.clone(),
493                stream,
494                true,
495            ))
496        } else {
497            Err(TransportError::Connection {
498                message: "Listener has been stopped".to_string(),
499            })
500        }
501    }
502
503    #[cfg(not(feature = "tokio-runtime"))]
504    async fn accept(&self) -> Result<Self::Transport, Self::Error> {
505        Err(TransportError::Connection {
506            message: "Unix socket listener requires 'tokio-runtime' feature".to_string(),
507        })
508    }
509
510    fn local_addr(&self) -> Option<String> {
511        Some(self.config.path.display().to_string())
512    }
513}
514
515impl Drop for UnixListener {
516    fn drop(&mut self) {
517        if self.config.cleanup_on_close && self.config.path.exists() {
518            let _ = std::fs::remove_file(&self.config.path);
519        }
520    }
521}
522
/// Abstract Unix socket address (Linux-only).
///
/// Abstract sockets don't create files in the filesystem and are
/// automatically cleaned up when all references are closed.
#[cfg(target_os = "linux")]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AbstractSocket {
    name: String,
}

#[cfg(target_os = "linux")]
impl AbstractSocket {
    /// Creates a new abstract socket name.
    ///
    /// The name should not start with a null byte, because
    /// [`AbstractSocket::to_path`] adds one. The name is stored as given and
    /// is not validated.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self { name: name.into() }
    }

    /// Returns the socket name, without the leading null byte.
    #[must_use]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns the raw address bytes: a null byte followed by the name.
    ///
    /// The leading null byte is what marks an address as abstract.
    #[must_use]
    pub fn to_path(&self) -> Vec<u8> {
        let mut path = Vec::with_capacity(self.name.len() + 1);
        path.push(0u8);
        path.extend_from_slice(self.name.as_bytes());
        path
    }
}
557
558/// Builder for Unix socket transport.
559pub struct UnixTransportBuilder {
560    config: UnixSocketConfig,
561}
562
563impl UnixTransportBuilder {
564    /// Create a new builder with the given socket path.
565    pub fn new(path: impl AsRef<Path>) -> Self {
566        Self {
567            config: UnixSocketConfig::new(path),
568        }
569    }
570
571    /// Set whether to cleanup the socket file on close.
572    #[must_use]
573    pub const fn cleanup_on_close(mut self, cleanup: bool) -> Self {
574        self.config.cleanup_on_close = cleanup;
575        self
576    }
577
578    /// Set the read buffer size.
579    #[must_use]
580    pub const fn read_buffer_size(mut self, size: usize) -> Self {
581        self.config.read_buffer_size = size;
582        self
583    }
584
585    /// Set the write buffer size.
586    #[must_use]
587    pub const fn write_buffer_size(mut self, size: usize) -> Self {
588        self.config.write_buffer_size = size;
589        self
590    }
591
592    /// Connect to the socket.
593    pub async fn connect(self) -> Result<UnixTransport, TransportError> {
594        UnixTransport::connect_with_config(self.config).await
595    }
596
597    /// Create a listener on the socket.
598    pub async fn listen(self) -> Result<UnixListener, TransportError> {
599        UnixListener::bind_with_config(self.config).await
600    }
601}
602
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_creation() {
        let config = UnixSocketConfig::new("/tmp/test.sock")
            .with_cleanup_on_close(false)
            .with_read_buffer_size(128 * 1024);

        assert_eq!(config.path, PathBuf::from("/tmp/test.sock"));
        assert!(!config.cleanup_on_close);
        assert_eq!(config.read_buffer_size, 128 * 1024);
    }

    #[test]
    fn test_config_defaults() {
        let config = UnixSocketConfig::new("/tmp/defaults.sock");

        assert!(config.cleanup_on_close);
        assert_eq!(config.read_buffer_size, 64 * 1024);
        assert_eq!(config.write_buffer_size, 64 * 1024);
        assert_eq!(config.max_message_size, DEFAULT_MAX_MESSAGE_SIZE);
    }

    #[test]
    fn test_builder() {
        let builder = UnixTransportBuilder::new("/tmp/mcp.sock")
            .cleanup_on_close(true)
            .read_buffer_size(32 * 1024)
            .write_buffer_size(32 * 1024);

        assert_eq!(builder.config.read_buffer_size, 32 * 1024);
        assert_eq!(builder.config.write_buffer_size, 32 * 1024);
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn test_abstract_socket() {
        let socket = AbstractSocket::new("mcp-test");
        assert_eq!(socket.name(), "mcp-test");

        let path = socket.to_path();
        assert_eq!(path[0], 0u8);
        assert_eq!(&path[1..], b"mcp-test");
    }

    /// Integration test: a client sends a request and the server echoes it back.
    #[cfg(feature = "tokio-runtime")]
    #[tokio::test]
    async fn test_unix_socket_communication() {
        use mcpkit_core::protocol::Request;

        let socket_path = format!("/tmp/mcp-test-{}.sock", std::process::id());
        let _ = std::fs::remove_file(&socket_path);

        // The listener is bound before either side runs, so the client's
        // connect simply waits in the backlog until the server accepts. No
        // barrier or sleep is needed.
        let listener = UnixListener::bind(&socket_path).await.unwrap();
        assert!(listener.is_running());

        let server = tokio::spawn(async move {
            let transport = listener.accept().await.unwrap();
            assert!(transport.is_connected());

            let msg = transport
                .recv()
                .await
                .unwrap()
                .expect("client should have sent a message");
            transport.send(msg).await.unwrap();

            transport.close().await.unwrap();
        });

        let transport = UnixTransport::connect(&socket_path).await.unwrap();
        assert!(transport.is_connected());

        let sent = Message::Request(Request::new("test/echo", 1));
        transport.send(sent.clone()).await.unwrap();

        let echoed = transport
            .recv()
            .await
            .unwrap()
            .expect("server should have echoed the message");
        // Compare the wire representation, which is what the transport carries.
        assert_eq!(
            serde_json::to_value(&echoed).unwrap(),
            serde_json::to_value(&sent).unwrap()
        );
        assert_eq!(transport.messages_sent(), 1);
        assert_eq!(transport.messages_received(), 1);

        transport.close().await.unwrap();
        server.await.unwrap();

        let _ = std::fs::remove_file(&socket_path);
    }
}