distant_net/common/listener/unix.rs

1use std::os::unix::fs::PermissionsExt;
2use std::path::{Path, PathBuf};
3use std::{fmt, io};
4
5use async_trait::async_trait;
6use tokio::net::{UnixListener, UnixStream};
7
8use super::Listener;
9use crate::common::UnixSocketTransport;
10
11/// Represents a [`Listener`] for incoming connections over a Unix socket
12pub struct UnixSocketListener {
13    path: PathBuf,
14    inner: tokio::net::UnixListener,
15}
16
17impl UnixSocketListener {
18    /// Creates a new listener by binding to the specified path, failing if the path already
19    /// exists. Sets permission of unix socket to `0o600` where only the owner can read from and
20    /// write to the socket.
21    pub async fn bind(path: impl AsRef<Path>) -> io::Result<Self> {
22        Self::bind_with_permissions(path, Self::default_unix_socket_file_permissions()).await
23    }
24
25    /// Creates a new listener by binding to the specified path, failing if the path already
26    /// exists. Sets the unix socket file permissions to `mode`.
27    pub async fn bind_with_permissions(path: impl AsRef<Path>, mode: u32) -> io::Result<Self> {
28        // Attempt to bind to the path, and if we fail, we see if we can connect
29        // to the path -- if not, we can try to delete the path and start again
30        let listener = match UnixListener::bind(path.as_ref()) {
31            Ok(listener) => listener,
32            Err(_) => {
33                // If we can connect to the path, then it's already in use
34                if UnixStream::connect(path.as_ref()).await.is_ok() {
35                    return Err(io::Error::from(io::ErrorKind::AddrInUse));
36                }
37
38                // Otherwise, remove the file and try again
39                tokio::fs::remove_file(path.as_ref()).await?;
40
41                UnixListener::bind(path.as_ref())?
42            }
43        };
44
45        // TODO: We should be setting this permission during bind, but neither std library nor
46        //       tokio have support for this. We would need to create our own raw socket and
47        //       use libc to change the permissions via the raw file descriptor
48        //
49        // See https://github.com/chipsenkbeil/distant/issues/111
50        let mut permissions = tokio::fs::metadata(path.as_ref()).await?.permissions();
51        permissions.set_mode(mode);
52        tokio::fs::set_permissions(path.as_ref(), permissions).await?;
53
54        Ok(Self {
55            path: path.as_ref().to_path_buf(),
56            inner: listener,
57        })
58    }
59
60    /// Returns the path to the socket
61    pub fn path(&self) -> &Path {
62        &self.path
63    }
64
65    /// Returns the default unix socket file permissions as an octal (e.g. `0o600`)
66    pub const fn default_unix_socket_file_permissions() -> u32 {
67        0o600
68    }
69}
70
71impl fmt::Debug for UnixSocketListener {
72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73        f.debug_struct("UnixSocketListener")
74            .field("path", &self.path)
75            .finish()
76    }
77}
78
79#[async_trait]
80impl Listener for UnixSocketListener {
81    type Output = UnixSocketTransport;
82
83    async fn accept(&mut self) -> io::Result<Self::Output> {
84        // NOTE: Address provided is unnamed, or at least the `as_pathname()` method is
85        //       returning none, so we use our listener's path, which is the same as
86        //       what is being connected, anyway
87        let (stream, _) = tokio::net::UnixListener::accept(&self.inner).await?;
88        Ok(UnixSocketTransport {
89            path: self.path.to_path_buf(),
90            inner: stream,
91        })
92    }
93}
94
#[cfg(test)]
mod tests {
    use tempfile::NamedTempFile;
    use test_log::test;
    use tokio::sync::oneshot;
    use tokio::task::JoinHandle;

    use super::*;
    use crate::common::TransportExt;

    /// Returns a unique path in the temp directory with nothing on it. The temporary file used
    /// to reserve the name is dropped (and therefore deleted) before this returns.
    fn unused_socket_path() -> PathBuf {
        NamedTempFile::new()
            .expect("Failed to create socket file")
            .path()
            .to_path_buf()
    }

    /// Returns the permission bits (without file-type bits) of the file at `path`.
    async fn permission_bits(path: &Path) -> u32 {
        tokio::fs::metadata(path)
            .await
            .expect("Failed to read socket metadata")
            .permissions()
            .mode()
            & 0o777
    }

    #[test(tokio::test)]
    async fn should_succeed_to_bind_if_file_exists_at_path_but_nothing_listening() {
        // Generate a path that already has a regular file on it
        let path = NamedTempFile::new()
            .expect("Failed to create file")
            .into_temp_path();

        // This should succeed: nothing is listening at the path, so the existing file is
        // treated as stale and replaced by the socket
        UnixSocketListener::bind(&path)
            .await
            .expect("Unexpectedly failed to bind to existing file");
    }

    #[test(tokio::test)]
    async fn should_fail_to_bind_if_socket_already_bound() {
        let path = unused_socket_path();

        // Listen at the socket
        let _listener = UnixSocketListener::bind(&path)
            .await
            .expect("Unexpectedly failed to bind first time");

        // Now this should fail as a listener is already accepting connections at the path
        let err = UnixSocketListener::bind(&path)
            .await
            .expect_err("Unexpectedly succeeded in binding to same socket");
        assert_eq!(err.kind(), io::ErrorKind::AddrInUse);
    }

    #[test(tokio::test)]
    async fn should_apply_default_permissions_to_socket_file() {
        let path = unused_socket_path();

        let _listener = UnixSocketListener::bind(&path)
            .await
            .expect("Failed to bind");

        assert_eq!(
            permission_bits(&path).await,
            UnixSocketListener::default_unix_socket_file_permissions()
        );
    }

    #[test(tokio::test)]
    async fn should_apply_custom_permissions_to_socket_file() {
        let path = unused_socket_path();

        let _listener = UnixSocketListener::bind_with_permissions(&path, 0o640)
            .await
            .expect("Failed to bind");

        assert_eq!(permission_bits(&path).await, 0o640);
    }

    #[test(tokio::test)]
    async fn should_be_able_to_receive_connections_and_read_and_write_data_with_them() {
        let (tx, rx) = oneshot::channel();

        // Spawn a task that will wait for two connections and then
        // return the success or failure
        let task: JoinHandle<io::Result<()>> = tokio::spawn(async move {
            let path = unused_socket_path();

            // Listen at the socket
            let mut listener = UnixSocketListener::bind(&path).await?;

            // Send the socket path to our main test thread
            tx.send(path)
                .map_err(|x| io::Error::new(io::ErrorKind::Other, x.display().to_string()))?;

            // Get first connection
            let conn_1 = listener.accept().await?;

            // Send some data to the first connection (12 bytes)
            conn_1.write_all(b"hello conn 1").await?;

            // Get some data from the first connection (14 bytes)
            let mut buf: [u8; 14] = [0; 14];
            conn_1.read_exact(&mut buf).await?;
            assert_eq!(&buf, b"hello server 1");

            // Get second connection
            let conn_2 = listener.accept().await?;

            // Send some data on to second connection (12 bytes)
            conn_2.write_all(b"hello conn 2").await?;

            // Get some data from the second connection (14 bytes)
            let mut buf: [u8; 14] = [0; 14];
            conn_2.read_exact(&mut buf).await?;
            assert_eq!(&buf, b"hello server 2");

            Ok(())
        });

        // Wait for the server to be ready
        let path = rx.await.expect("Failed to get server socket path");

        // Connect to the listener twice, sending some bytes and receiving some bytes from each
        let mut buf: [u8; 12] = [0; 12];

        let conn = UnixSocketTransport::connect(&path)
            .await
            .expect("Conn 1 failed to connect");
        conn.write_all(b"hello server 1")
            .await
            .expect("Conn 1 failed to write");
        conn.read_exact(&mut buf)
            .await
            .expect("Conn 1 failed to read");
        assert_eq!(&buf, b"hello conn 1");

        let conn = UnixSocketTransport::connect(&path)
            .await
            .expect("Conn 2 failed to connect");
        conn.write_all(b"hello server 2")
            .await
            .expect("Conn 2 failed to write");
        conn.read_exact(&mut buf)
            .await
            .expect("Conn 2 failed to read");
        assert_eq!(&buf, b"hello conn 2");

        // Verify that the task completed successfully. The outer `expect` catches a panic
        // (e.g. a failed assertion inside the task). The inner one catches an I/O error the
        // task returned, which would otherwise be silently discarded.
        task.await
            .expect("Listener task panicked")
            .expect("Listener task failed unexpectedly");
    }
}