Skip to main content

roam_local/
unix.rs

1//! Unix socket implementation for local IPC.
2
3use std::io;
4use std::path::Path;
5use tokio::net::{UnixListener, UnixStream};
6
7/// A local IPC stream (Unix socket on Unix platforms).
8pub type LocalStream = UnixStream;
9
10/// A local IPC listener (Unix socket listener on Unix platforms).
11pub struct LocalListener {
12    inner: UnixListener,
13}
14
15impl LocalListener {
16    /// Bind to the given socket path.
17    ///
18    /// The path should be a filesystem path where the socket file will be created.
19    /// Parent directories must exist.
20    pub fn bind(path: impl AsRef<Path>) -> io::Result<Self> {
21        let inner = UnixListener::bind(path)?;
22        Ok(Self { inner })
23    }
24
25    /// Accept a new connection.
26    ///
27    /// Returns the stream for the new connection.
28    pub async fn accept(&self) -> io::Result<LocalStream> {
29        let (stream, _addr) = self.inner.accept().await?;
30        Ok(stream)
31    }
32}
33
34/// Connect to a local IPC endpoint.
35///
36/// On Unix, this connects to a Unix socket at the given path.
37pub async fn connect(path: impl AsRef<Path>) -> io::Result<LocalStream> {
38    UnixStream::connect(path).await
39}
40
/// Report whether anything currently occupies the local IPC endpoint `path`.
///
/// On Unix this probes the path's metadata, following symlinks. It returns
/// `true` for any reachable filesystem entry, not only sockets. It returns
/// `false` whenever the metadata cannot be read: a missing file, a permission
/// error, or a dangling symlink.
pub fn endpoint_exists(path: impl AsRef<Path>) -> bool {
    let probe = std::fs::metadata(path.as_ref());
    probe.is_ok()
}
47
/// Delete the local IPC endpoint at `path`.
///
/// On Unix this unlinks the socket file. It does not check that the entry is
/// actually a socket, so any file at `path` is removed.
///
/// # Errors
///
/// Returns the error from [`std::fs::remove_file`]. In particular, the kind is
/// [`io::ErrorKind::NotFound`] when nothing exists at `path`.
pub fn remove_endpoint(path: impl AsRef<Path>) -> io::Result<()> {
    let target: &Path = path.as_ref();
    std::fs::remove_file(target)
}
54
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use std::time::{SystemTime, UNIX_EPOCH};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// Build a socket path unique to this process and this call.
    ///
    /// Uniqueness comes from the process id plus the current wall-clock time in
    /// nanoseconds, and `tag` identifies the test. The path sits directly under
    /// `/tmp` rather than `std::env::temp_dir()`. Presumably this keeps it short,
    /// since Unix socket paths have a small length limit (`sun_path`) — TODO
    /// confirm.
    fn unique_socket_path(tag: &str) -> PathBuf {
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("clock should be after unix epoch")
            .as_nanos();
        PathBuf::from(format!("/tmp/rl-{tag}-{}-{nanos}.sock", std::process::id()))
    }

    /// Removes the wrapped path when dropped, so a panicking assertion does not
    /// leave a stale socket file behind in `/tmp`.
    struct RemoveOnDrop(PathBuf);

    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            // Best effort: the test body may already have removed the file.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[tokio::test]
    async fn endpoint_lifecycle_bind_connect_accept_remove() {
        let path = unique_socket_path("lifecycle");
        assert!(!endpoint_exists(&path));

        let listener = LocalListener::bind(&path).expect("bind should succeed");
        // From here on this test owns a socket file. Clean it up even if a later
        // assertion panics.
        let _cleanup = RemoveOnDrop(path.clone());
        assert!(endpoint_exists(&path));

        // Server side: expect "ping", answer "pong".
        let server = tokio::spawn(async move {
            let mut stream = listener.accept().await.expect("accept should succeed");
            let mut buf = [0_u8; 4];
            stream
                .read_exact(&mut buf)
                .await
                .expect("server read should succeed");
            assert_eq!(&buf, b"ping");
            stream
                .write_all(b"pong")
                .await
                .expect("server write should succeed");
        });

        // Client side: send "ping", expect "pong".
        let mut client = connect(&path).await.expect("connect should succeed");
        client
            .write_all(b"ping")
            .await
            .expect("client write should succeed");
        let mut buf = [0_u8; 4];
        client
            .read_exact(&mut buf)
            .await
            .expect("client read should succeed");
        assert_eq!(&buf, b"pong");

        server.await.expect("server task should complete");
        remove_endpoint(&path).expect("remove endpoint should succeed");
        assert!(!endpoint_exists(&path));
    }

    #[tokio::test]
    async fn connect_to_missing_endpoint_returns_not_found() {
        let path = unique_socket_path("missing-connect");
        let err = connect(&path)
            .await
            .expect_err("connect should fail for missing endpoint");
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
    }

    #[test]
    fn remove_missing_endpoint_returns_not_found() {
        let path = unique_socket_path("missing-remove");
        let err = remove_endpoint(&path).expect_err("remove should fail for missing endpoint");
        assert_eq!(err.kind(), io::ErrorKind::NotFound);
    }
}
119        let path = unique_socket_path("missing-remove");
120        let err = remove_endpoint(&path).expect_err("remove should fail for missing endpoint");
121        assert_eq!(err.kind(), io::ErrorKind::NotFound);
122    }
123}