distant_net/common/transport/unix.rs

1use std::path::{Path, PathBuf};
2use std::{fmt, io};
3
4use async_trait::async_trait;
5use tokio::net::UnixStream;
6
7use super::{Interest, Ready, Reconnectable, Transport};
8
9/// Represents a [`Transport`] that leverages a Unix socket
10pub struct UnixSocketTransport {
11    pub(crate) path: PathBuf,
12    pub(crate) inner: UnixStream,
13}
14
15impl UnixSocketTransport {
16    /// Creates a new stream by connecting to the specified path
17    pub async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
18        let stream = UnixStream::connect(path.as_ref()).await?;
19        Ok(Self {
20            path: path.as_ref().to_path_buf(),
21            inner: stream,
22        })
23    }
24
25    /// Returns the path to the socket
26    pub fn path(&self) -> &Path {
27        &self.path
28    }
29}
30
31impl fmt::Debug for UnixSocketTransport {
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        f.debug_struct("UnixSocketTransport")
34            .field("path", &self.path)
35            .finish()
36    }
37}
38
39#[async_trait]
40impl Reconnectable for UnixSocketTransport {
41    async fn reconnect(&mut self) -> io::Result<()> {
42        self.inner = UnixStream::connect(self.path.as_path()).await?;
43        Ok(())
44    }
45}
46
47#[async_trait]
48impl Transport for UnixSocketTransport {
49    fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
50        self.inner.try_read(buf)
51    }
52
53    fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
54        self.inner.try_write(buf)
55    }
56
57    async fn ready(&self, interest: Interest) -> io::Result<Ready> {
58        self.inner.ready(interest).await
59    }
60}
61
#[cfg(test)]
mod tests {
    use tempfile::NamedTempFile;
    use test_log::test;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::UnixListener;
    use tokio::sync::oneshot;
    use tokio::task::JoinHandle;

    use super::*;
    use crate::common::TransportExt;

    /// Returns a unique path in the temp directory at which nothing currently exists.
    ///
    /// A [`NamedTempFile`] is created only to reserve a unique name. It is dropped, which
    /// deletes the file, before this function returns, so the path is free for a socket.
    fn unused_socket_path() -> PathBuf {
        NamedTempFile::new()
            .expect("Failed to create temporary file")
            .path()
            .to_path_buf()
    }

    /// Binds a listener at a fresh socket path, reports that path through `tx`, and then
    /// serves a single connection via [`run_server`].
    async fn start_and_run_server(tx: oneshot::Sender<PathBuf>) -> io::Result<()> {
        let path = unused_socket_path();

        // Start listening at the socket path
        let listener = UnixListener::bind(&path)?;

        // Send the path back to our main test thread
        tx.send(path)
            .map_err(|x| io::Error::new(io::ErrorKind::Other, x.display().to_string()))?;

        run_server(listener).await
    }

    /// Accepts one connection, writes `b"hello conn"` (10 bytes) to it, then reads exactly
    /// 12 bytes back and asserts they equal `b"hello server"`.
    async fn run_server(listener: UnixListener) -> io::Result<()> {
        // Get the connection
        let (mut conn, _) = listener.accept().await?;

        // Send some data to the connection (10 bytes)
        conn.write_all(b"hello conn").await?;

        // Receive some data from the connection (12 bytes)
        let mut buf: [u8; 12] = [0; 12];
        let _ = conn.read_exact(&mut buf).await?;
        assert_eq!(&buf, b"hello server");

        Ok(())
    }

    /// Awaits a server task and fails the test if it panicked, was cancelled, or returned
    /// an I/O error. Both layers of the result are checked so a server-side error cannot
    /// be silently ignored.
    async fn wait_for_server(task: JoinHandle<io::Result<()>>) {
        task.await
            .expect("Server task failed unexpectedly")
            .expect("Server task returned an error");
    }

    #[test(tokio::test)]
    async fn should_fail_to_connect_if_socket_does_not_exist() {
        // Reserve a unique path with nothing at it
        let path = unused_socket_path();

        // Connecting must fail because no socket exists at the path
        UnixSocketTransport::connect(&path)
            .await
            .expect_err("Unexpectedly succeeded in connecting to missing socket");
    }

    #[test(tokio::test)]
    async fn should_fail_to_connect_if_path_is_not_a_socket() {
        // Generate a regular file that stays on disk until `path` is dropped
        let path = NamedTempFile::new()
            .expect("Failed to create temporary file")
            .into_temp_path();

        // Now this should fail as this file is not a socket
        UnixSocketTransport::connect(&path)
            .await
            .expect_err("Unexpectedly succeeded in connecting to regular file");
    }

    #[test(tokio::test)]
    async fn should_be_able_to_read_and_write_data() {
        let (tx, rx) = oneshot::channel();

        // Spawn a task that will wait for a connection, send data,
        // and verify the data it receives back
        let task: JoinHandle<io::Result<()>> = tokio::spawn(start_and_run_server(tx));

        // Wait for the server to be ready
        let path = rx.await.expect("Failed to get server socket path");

        // Connect to the socket and confirm the transport remembers where it connected
        let conn = UnixSocketTransport::connect(&path)
            .await
            .expect("Conn failed to connect");
        assert_eq!(conn.path(), path.as_path());

        // Receive the server's greeting (10 bytes)
        let mut buf: [u8; 10] = [0; 10];
        conn.read_exact(&mut buf)
            .await
            .expect("Conn failed to read");
        assert_eq!(&buf, b"hello conn");

        // Send our reply (12 bytes)
        conn.write_all(b"hello server")
            .await
            .expect("Conn failed to write");

        // Verify that the server task completed successfully
        wait_for_server(task).await;

        // Best-effort cleanup: the socket file is not removed when the listener is dropped
        let _ = tokio::fs::remove_file(&path).await;
    }

    #[test(tokio::test)]
    async fn should_be_able_to_reconnect() {
        let (tx, rx) = oneshot::channel();

        // Spawn a task that will wait for a connection, send data,
        // and verify the data it receives back
        let task: JoinHandle<io::Result<()>> = tokio::spawn(start_and_run_server(tx));

        // Wait for the server to be ready
        let path = rx.await.expect("Failed to get server socket path");

        // Connect to the server
        let mut conn = UnixSocketTransport::connect(&path)
            .await
            .expect("Conn failed to connect");

        // Kill the server to make the connection fail
        task.abort();

        // Verify the connection is broken: reading should hit EOF (Ok(0)) or an error
        conn.readable()
            .await
            .expect("Failed to wait for conn to be readable");
        let res = conn.read_exact(&mut [0; 10]).await;
        assert!(
            matches!(res, Ok(0) | Err(_)),
            "Unexpected read result: {res:?}"
        );

        // Restart the server; the stale socket file must be removed before rebinding
        let _ = tokio::fs::remove_file(&path).await;
        let task: JoinHandle<io::Result<()>> = tokio::spawn(run_server(
            UnixListener::bind(&path).expect("Failed to rebind server"),
        ));

        // Reconnect to the socket, then exchange data over the new stream
        conn.reconnect().await.expect("Conn failed to reconnect");

        let mut buf: [u8; 10] = [0; 10];
        conn.read_exact(&mut buf)
            .await
            .expect("Conn failed to read");
        assert_eq!(&buf, b"hello conn");

        conn.write_all(b"hello server")
            .await
            .expect("Conn failed to write");

        // Verify that the restarted server task completed successfully
        wait_for_server(task).await;

        // Best-effort cleanup: the socket file is not removed when the listener is dropped
        let _ = tokio::fs::remove_file(&path).await;
    }
}