Skip to main content

ipckit/
local_socket.rs

1//! Local Socket implementation for IPC
2//!
3//! This module provides a cross-platform local socket abstraction for IPC.
4//! When the `backend-interprocess` feature is enabled, it uses the `interprocess` crate
5//! for a more robust implementation. Otherwise, it falls back to the native implementation.
6//!
7//! # Features
8//! - Unix Domain Sockets on Unix systems
9//! - Named Pipes on Windows
10//! - Server/Client architecture
//! - Async support (requires both the `async` and `backend-interprocess` features)
12
13use crate::error::Result;
14use std::io::{Read, Write};
15
16// ============================================================================
17// Backend: interprocess
18// ============================================================================
19
#[cfg(feature = "backend-interprocess")]
mod interprocess_backend {
    use super::*;
    use crate::error::IpcError;
    use interprocess::local_socket::{
        prelude::*, GenericFilePath, GenericNamespaced, ListenerOptions, Stream, ToFsName, ToNsName,
    };

    /// A local socket listener that accepts incoming connections.
    pub struct LocalSocketListener {
        listener: interprocess::local_socket::Listener,
        /// The name exactly as passed to `bind`.
        name: String,
    }

    /// A local socket stream for bidirectional communication.
    pub struct LocalSocketStream {
        inner: Stream,
        /// The user-facing name this stream was connected or accepted on.
        name: String,
    }

    impl LocalSocketListener {
        /// Create a new local socket listener bound to the given name.
        ///
        /// `name` may be a bare identifier or an explicit path (`/...` on Unix,
        /// `\\.\pipe\...` on Windows); see `get_socket_name` for the mapping.
        ///
        /// # Errors
        ///
        /// Returns [`IpcError::Io`] if the name cannot be converted or the
        /// listener cannot be created. The original [`std::io::Error`] is kept
        /// intact, so its [`std::io::ErrorKind`] remains inspectable.
        pub fn bind(name: &str) -> Result<Self> {
            let socket_name = get_socket_name(name)?;

            let listener = ListenerOptions::new()
                .name(socket_name)
                .create_sync()
                .map_err(IpcError::Io)?;

            Ok(Self {
                listener,
                name: name.to_string(),
            })
        }

        /// Accept a new incoming connection.
        ///
        /// # Errors
        ///
        /// Returns [`IpcError::Io`] carrying the original I/O error if
        /// accepting fails.
        pub fn accept(&self) -> Result<LocalSocketStream> {
            let stream = self.listener.accept().map_err(IpcError::Io)?;

            Ok(LocalSocketStream {
                inner: stream,
                name: self.name.clone(),
            })
        }

        /// Get the name of this listener.
        pub fn name(&self) -> &str {
            &self.name
        }

        /// Returns an iterator over incoming connections.
        ///
        /// Each call to `next` calls [`accept`](Self::accept); the iterator
        /// never yields `None`.
        pub fn incoming(&self) -> impl Iterator<Item = Result<LocalSocketStream>> + '_ {
            std::iter::from_fn(move || Some(self.accept()))
        }
    }

    impl LocalSocketStream {
        /// Connect to a local socket server.
        ///
        /// `name` is resolved the same way as in [`LocalSocketListener::bind`].
        ///
        /// # Errors
        ///
        /// Returns [`IpcError::Io`] carrying the original I/O error (e.g. with
        /// kind `NotFound` when no server is listening).
        pub fn connect(name: &str) -> Result<Self> {
            let socket_name = get_socket_name(name)?;

            let stream = Stream::connect(socket_name).map_err(IpcError::Io)?;

            Ok(Self {
                inner: stream,
                name: name.to_string(),
            })
        }

        /// Get the name of this stream.
        pub fn name(&self) -> &str {
            &self.name
        }
    }

    impl Read for LocalSocketStream {
        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
            self.inner.read(buf)
        }
    }

    impl Write for LocalSocketStream {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            self.inner.write(buf)
        }

        fn flush(&mut self) -> std::io::Result<()> {
            self.inner.flush()
        }
    }

    /// Resolve a user-supplied name to an `interprocess` socket name.
    ///
    /// Resolution order:
    /// 1. An explicit path (starting with `/` on Unix, `\\.\pipe\` otherwise)
    ///    is used verbatim as a filesystem name. This is checked first because
    ///    the namespaced conversion can succeed for such strings too (e.g. as
    ///    a Linux abstract name), in which case the path would never be used.
    /// 2. Otherwise the namespaced form (`GenericNamespaced`) is tried.
    /// 3. If that conversion fails, fall back to `/tmp/<name>.sock` on Unix or
    ///    `\\.\pipe\<name>` otherwise.
    ///
    /// # Errors
    ///
    /// Returns [`IpcError::Io`] if the filesystem name cannot be constructed.
    fn get_socket_name(name: &str) -> Result<interprocess::local_socket::Name<'static>> {
        let is_explicit_path = if cfg!(unix) {
            name.starts_with('/')
        } else {
            name.starts_with(r"\\.\pipe\")
        };

        if !is_explicit_path {
            if let Ok(ns_name) = name.to_string().to_ns_name::<GenericNamespaced>() {
                return Ok(ns_name);
            }
        }

        let path = if is_explicit_path {
            name.to_string()
        } else if cfg!(unix) {
            format!("/tmp/{}.sock", name)
        } else {
            format!(r"\\.\pipe\{}", name)
        };

        path.to_fs_name::<GenericFilePath>().map_err(IpcError::Io)
    }
}
143
144#[cfg(feature = "backend-interprocess")]
145pub use interprocess_backend::{LocalSocketListener, LocalSocketStream};
146
147// ============================================================================
148// Backend: Native (fallback)
149// ============================================================================
150
151#[cfg(not(feature = "backend-interprocess"))]
152mod native_backend {
153    use super::*;
154    #[cfg(unix)]
155    use crate::error::IpcError;
156
157    #[cfg(unix)]
158    use std::os::unix::net::{UnixListener, UnixStream};
159
160    /// A local socket listener that accepts incoming connections.
161    pub struct LocalSocketListener {
162        #[cfg(unix)]
163        listener: UnixListener,
164        #[cfg(unix)]
165        path: String,
166        #[cfg(windows)]
167        pipe_name: String,
168        name: String,
169    }
170
171    /// A local socket stream for bidirectional communication.
172    pub struct LocalSocketStream {
173        #[cfg(unix)]
174        stream: UnixStream,
175        #[cfg(windows)]
176        handle: crate::windows::PipeHandle,
177        name: String,
178    }
179
180    impl LocalSocketListener {
181        /// Create a new local socket listener bound to the given name.
182        pub fn bind(name: &str) -> Result<Self> {
183            #[cfg(unix)]
184            {
185                let path = if name.starts_with('/') {
186                    name.to_string()
187                } else {
188                    format!("/tmp/{}.sock", name)
189                };
190
191                // Remove existing socket if any
192                let _ = std::fs::remove_file(&path);
193
194                let listener = UnixListener::bind(&path).map_err(|e| match e.kind() {
195                    std::io::ErrorKind::PermissionDenied => {
196                        IpcError::PermissionDenied(path.clone())
197                    }
198                    _ => IpcError::Io(e),
199                })?;
200
201                Ok(Self {
202                    listener,
203                    path,
204                    name: name.to_string(),
205                })
206            }
207
208            #[cfg(windows)]
209            {
210                let pipe_name = if name.starts_with(r"\\.\pipe\") {
211                    name.to_string()
212                } else {
213                    format!(r"\\.\pipe\{}", name)
214                };
215
216                Ok(Self {
217                    pipe_name,
218                    name: name.to_string(),
219                })
220            }
221        }
222
223        /// Accept a new incoming connection.
224        pub fn accept(&self) -> Result<LocalSocketStream> {
225            #[cfg(unix)]
226            {
227                let (stream, _) = self.listener.accept()?;
228                Ok(LocalSocketStream {
229                    stream,
230                    name: self.name.clone(),
231                })
232            }
233
234            #[cfg(windows)]
235            {
236                use crate::windows;
237                let handle = windows::create_named_pipe_for_server(&self.pipe_name)?;
238                windows::wait_for_client_handle(&handle)?;
239                Ok(LocalSocketStream {
240                    handle,
241                    name: self.name.clone(),
242                })
243            }
244        }
245
246        /// Get the name of this listener.
247        pub fn name(&self) -> &str {
248            &self.name
249        }
250
251        /// Returns an iterator over incoming connections.
252        pub fn incoming(&self) -> impl Iterator<Item = Result<LocalSocketStream>> + '_ {
253            std::iter::from_fn(move || Some(self.accept()))
254        }
255    }
256
257    #[cfg(unix)]
258    impl Drop for LocalSocketListener {
259        fn drop(&mut self) {
260            let _ = std::fs::remove_file(&self.path);
261        }
262    }
263
264    impl LocalSocketStream {
265        /// Connect to a local socket server.
266        pub fn connect(name: &str) -> Result<Self> {
267            #[cfg(unix)]
268            {
269                let path = if name.starts_with('/') {
270                    name.to_string()
271                } else {
272                    format!("/tmp/{}.sock", name)
273                };
274
275                let stream = UnixStream::connect(&path).map_err(|e| match e.kind() {
276                    std::io::ErrorKind::NotFound => IpcError::NotFound(path.clone()),
277                    std::io::ErrorKind::PermissionDenied => {
278                        IpcError::PermissionDenied(path.clone())
279                    }
280                    std::io::ErrorKind::ConnectionRefused => {
281                        IpcError::NotFound(format!("Connection refused: {}", path))
282                    }
283                    _ => IpcError::Io(e),
284                })?;
285
286                Ok(Self {
287                    stream,
288                    name: name.to_string(),
289                })
290            }
291
292            #[cfg(windows)]
293            {
294                use crate::windows;
295                let pipe_name = if name.starts_with(r"\\.\pipe\") {
296                    name.to_string()
297                } else {
298                    format!(r"\\.\pipe\{}", name)
299                };
300
301                let handle = windows::connect_to_named_pipe(&pipe_name)?;
302                Ok(Self {
303                    handle,
304                    name: name.to_string(),
305                })
306            }
307        }
308
309        /// Get the name of this stream.
310        pub fn name(&self) -> &str {
311            &self.name
312        }
313    }
314
315    impl Read for LocalSocketStream {
316        fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
317            #[cfg(unix)]
318            {
319                self.stream.read(buf)
320            }
321            #[cfg(windows)]
322            {
323                crate::windows::read_pipe(&self.handle, buf)
324            }
325        }
326    }
327
328    impl Write for LocalSocketStream {
329        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
330            #[cfg(unix)]
331            {
332                self.stream.write(buf)
333            }
334            #[cfg(windows)]
335            {
336                crate::windows::write_pipe(&self.handle, buf)
337            }
338        }
339
340        fn flush(&mut self) -> std::io::Result<()> {
341            #[cfg(unix)]
342            {
343                self.stream.flush()
344            }
345            #[cfg(windows)]
346            {
347                Ok(())
348            }
349        }
350    }
351}
352
353#[cfg(not(feature = "backend-interprocess"))]
354pub use native_backend::{LocalSocketListener, LocalSocketStream};
355
356// ============================================================================
357// Async support
358// ============================================================================
359
#[cfg(all(feature = "async", feature = "backend-interprocess"))]
pub mod async_socket {
    //! Async local socket support using tokio.
    //!
    //! Built on `interprocess`'s tokio integration; only compiled when both
    //! the `async` and `backend-interprocess` features are enabled.

    use super::*;
    use crate::error::IpcError;
    use interprocess::local_socket::{
        tokio::prelude::*, GenericFilePath, GenericNamespaced, ListenerOptions, ToFsName, ToNsName,
    };
    use tokio::io::{AsyncRead, AsyncWrite};

    /// Async local socket listener.
    pub struct AsyncLocalSocketListener {
        inner: interprocess::local_socket::tokio::Listener,
        name: String,
    }

    /// Async local socket stream.
    pub struct AsyncLocalSocketStream {
        inner: interprocess::local_socket::tokio::Stream,
        name: String,
    }

    impl AsyncLocalSocketListener {
        /// Create a new async local socket listener bound to `name`.
        ///
        /// # Errors
        ///
        /// Returns [`IpcError::Io`] carrying the original I/O error (its
        /// `ErrorKind` is preserved) if the name is invalid or the listener
        /// cannot be created.
        pub async fn bind(name: &str) -> Result<Self> {
            let socket_name = get_async_socket_name(name)?;

            let listener = ListenerOptions::new()
                .name(socket_name)
                .create_tokio()
                .map_err(IpcError::Io)?;

            Ok(Self {
                inner: listener,
                name: name.to_string(),
            })
        }

        /// Accept a new incoming connection asynchronously.
        ///
        /// # Errors
        ///
        /// Returns [`IpcError::Io`] carrying the original I/O error.
        pub async fn accept(&self) -> Result<AsyncLocalSocketStream> {
            let stream = self.inner.accept().await.map_err(IpcError::Io)?;

            Ok(AsyncLocalSocketStream {
                inner: stream,
                name: self.name.clone(),
            })
        }

        /// Get the name of this listener.
        pub fn name(&self) -> &str {
            &self.name
        }
    }

    impl AsyncLocalSocketStream {
        /// Connect to a local socket server asynchronously.
        ///
        /// # Errors
        ///
        /// Returns [`IpcError::Io`] carrying the original I/O error (e.g. with
        /// kind `NotFound` when no server is listening).
        pub async fn connect(name: &str) -> Result<Self> {
            let socket_name = get_async_socket_name(name)?;

            let stream = interprocess::local_socket::tokio::Stream::connect(socket_name)
                .await
                .map_err(IpcError::Io)?;

            Ok(Self {
                inner: stream,
                name: name.to_string(),
            })
        }

        /// Get the name of this stream.
        pub fn name(&self) -> &str {
            &self.name
        }

        /// Split into read and write halves using `tokio::io::split`.
        pub fn into_split(
            self,
        ) -> (
            tokio::io::ReadHalf<interprocess::local_socket::tokio::Stream>,
            tokio::io::WriteHalf<interprocess::local_socket::tokio::Stream>,
        ) {
            tokio::io::split(self.inner)
        }
    }

    impl AsyncRead for AsyncLocalSocketStream {
        fn poll_read(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &mut tokio::io::ReadBuf<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            std::pin::Pin::new(&mut self.inner).poll_read(cx, buf)
        }
    }

    impl AsyncWrite for AsyncLocalSocketStream {
        fn poll_write(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &[u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            std::pin::Pin::new(&mut self.inner).poll_write(cx, buf)
        }

        fn poll_flush(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            std::pin::Pin::new(&mut self.inner).poll_flush(cx)
        }

        fn poll_shutdown(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            std::pin::Pin::new(&mut self.inner).poll_shutdown(cx)
        }
    }

    /// Resolve a user-supplied name to an `interprocess` socket name.
    ///
    /// An explicit path (`/...` on Unix, `\\.\pipe\...` otherwise) is used
    /// verbatim as a filesystem name. It is checked first because the
    /// namespaced conversion can accept such strings too (e.g. as a Linux
    /// abstract name). Other names try `GenericNamespaced` first and fall back
    /// to `/tmp/<name>.sock` on Unix or `\\.\pipe\<name>` otherwise.
    fn get_async_socket_name(name: &str) -> Result<interprocess::local_socket::Name<'static>> {
        let is_explicit_path = if cfg!(unix) {
            name.starts_with('/')
        } else {
            name.starts_with(r"\\.\pipe\")
        };

        if !is_explicit_path {
            if let Ok(ns_name) = name.to_string().to_ns_name::<GenericNamespaced>() {
                return Ok(ns_name);
            }
        }

        let path = if is_explicit_path {
            name.to_string()
        } else if cfg!(unix) {
            format!("/tmp/{}.sock", name)
        } else {
            format!(r"\\.\pipe\{}", name)
        };

        path.to_fs_name::<GenericFilePath>().map_err(IpcError::Io)
    }
}
505
506#[cfg(all(feature = "async", feature = "backend-interprocess"))]
507pub use async_socket::{AsyncLocalSocketListener, AsyncLocalSocketStream};
508
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::mpsc;
    use std::thread;
    use std::time::Duration;

    /// Connect to `name`, retrying briefly.
    ///
    /// The retry covers backends where the endpoint only appears once the
    /// server is inside `accept` (the native Windows backend creates its pipe
    /// instance there).
    fn connect_with_retry(name: &str) -> LocalSocketStream {
        const ATTEMPTS: u32 = 50;
        let mut last_err = None;
        for _ in 0..ATTEMPTS {
            match LocalSocketStream::connect(name) {
                Ok(stream) => return stream,
                Err(e) => {
                    last_err = Some(e);
                    thread::sleep(Duration::from_millis(20));
                }
            }
        }
        panic!("failed to connect to {}: {:?}", name, last_err);
    }

    #[test]
    fn test_local_socket_communication() {
        const REQUEST: &[u8] = b"Hello, Server!";
        const RESPONSE: &[u8] = b"Hello, Client!";

        let server_name = format!("test_socket_{}", std::process::id());
        let (ready_tx, ready_rx) = mpsc::channel();

        // Run the server in its own thread; it signals once it is bound so
        // the client does not rely on a fixed sleep.
        let server_name_clone = server_name.clone();
        let server_thread = thread::spawn(move || {
            let listener = LocalSocketListener::bind(&server_name_clone).unwrap();
            ready_tx.send(()).unwrap();
            let mut stream = listener.accept().unwrap();

            // read_exact: a single read() may legally return a partial message.
            let mut buf = vec![0u8; REQUEST.len()];
            stream.read_exact(&mut buf).unwrap();
            assert_eq!(&buf[..], REQUEST);

            stream.write_all(RESPONSE).unwrap();
            stream.flush().unwrap();
        });

        ready_rx
            .recv()
            .expect("server thread exited before binding the listener");

        let mut client = connect_with_retry(&server_name);
        client.write_all(REQUEST).unwrap();
        client.flush().unwrap();

        let mut buf = vec![0u8; RESPONSE.len()];
        client.read_exact(&mut buf).unwrap();
        assert_eq!(&buf[..], RESPONSE);

        server_thread.join().unwrap();
    }
}