1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
use crate::{Listener, UnixSocketTransport};
use async_trait::async_trait;
use std::{
    fmt, io,
    os::unix::fs::PermissionsExt,
    path::{Path, PathBuf},
};
use tokio::net::{UnixListener, UnixStream};

/// Represents a [`Listener`] for incoming connections over a Unix socket
pub struct UnixSocketListener {
    path: PathBuf,
    inner: tokio::net::UnixListener,
}

impl UnixSocketListener {
    /// Creates a new listener by binding to the specified path using the default socket file
    /// permissions (`0o600`), where only the owner can read from and write to the socket.
    ///
    /// This is a shorthand for [`Self::bind_with_permissions`] with
    /// [`Self::default_unix_socket_file_permissions`] as the mode; see that method for how an
    /// already-occupied path is handled.
    ///
    /// # Errors
    ///
    /// Returns any error produced by [`Self::bind_with_permissions`].
    pub async fn bind(path: impl AsRef<Path>) -> io::Result<Self> {
        Self::bind_with_permissions(path, Self::default_unix_socket_file_permissions()).await
    }

    /// Creates a new listener by binding to the specified path and then sets the unix socket
    /// file permissions to `mode`.
    ///
    /// If the path is already occupied, a connection attempt is made to find out whether
    /// something is still listening there:
    ///
    /// * If the connection succeeds, the path is considered in use and binding fails.
    /// * Otherwise, the file at the path is treated as stale, removed, and the bind is
    ///   attempted a second time.
    ///
    /// # Errors
    ///
    /// * [`io::ErrorKind::AddrInUse`] if another listener is accepting connections at `path`.
    /// * The original bind error if binding fails for any reason other than the address being
    ///   in use; in that case nothing at `path` is removed.
    /// * Any error from removing the stale file, from the second bind attempt, or from reading
    ///   and updating the permissions of the socket file.
    pub async fn bind_with_permissions(path: impl AsRef<Path>, mode: u32) -> io::Result<Self> {
        let path = path.as_ref();

        let listener = match UnixListener::bind(path) {
            Ok(listener) => listener,

            // Something already exists at the path: either a live socket or a leftover file
            Err(x) if x.kind() == io::ErrorKind::AddrInUse => {
                // If we can connect to the path, then it's already in use
                if UnixStream::connect(path).await.is_ok() {
                    return Err(io::Error::from(io::ErrorKind::AddrInUse));
                }

                // Otherwise, nothing is listening, so remove the file and try again
                tokio::fs::remove_file(path).await?;

                UnixListener::bind(path)?
            }

            // Any other failure (bad path, missing directory, lack of access, ...) cannot be
            // resolved by deleting the file, so report it untouched instead of removing
            // whatever may be at the path
            Err(x) => return Err(x),
        };

        // TODO: We should be setting this permission during bind, but neither std library nor
        //       tokio have support for this. We would need to create our own raw socket and
        //       use libc to change the permissions via the raw file descriptor
        //
        // See https://github.com/chipsenkbeil/distant/issues/111
        let mut permissions = tokio::fs::metadata(path).await?.permissions();
        permissions.set_mode(mode);
        tokio::fs::set_permissions(path, permissions).await?;

        Ok(Self {
            path: path.to_path_buf(),
            inner: listener,
        })
    }

    /// Returns the path to the socket
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Returns the default unix socket file permissions as an octal (e.g. `0o600`)
    pub const fn default_unix_socket_file_permissions() -> u32 {
        0o600
    }
}

impl fmt::Debug for UnixSocketListener {
    /// Writes a struct-style debug representation that contains only the `path` field; the
    /// `inner` listener is not included in the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("UnixSocketListener");
        builder.field("path", &self.path);
        builder.finish()
    }
}

#[async_trait]
impl Listener for UnixSocketListener {
    type Output = UnixSocketTransport;

    /// Waits for the next incoming connection and wraps the accepted stream in a
    /// [`UnixSocketTransport`] that carries this listener's socket path.
    async fn accept(&mut self) -> io::Result<Self::Output> {
        // NOTE: The peer address handed back by accept is unnamed (or at least its
        //       `as_pathname()` yields none), so it is discarded and the listener's own path
        //       is used instead, which is the path the peer connected to, anyway
        let (inner, _peer_addr) = self.inner.accept().await?;
        let path = self.path.clone();

        Ok(UnixSocketTransport { path, inner })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        sync::oneshot,
        task::JoinHandle,
    };

    /// Binding over a leftover file that nobody is listening on must succeed.
    #[tokio::test]
    async fn should_succeed_to_bind_if_file_exists_at_path_but_nothing_listening() {
        // Create a regular file and keep it on disk so the socket path is already occupied
        let path = NamedTempFile::new()
            .expect("Failed to create file")
            .into_temp_path();

        // This should succeed even though a file is at the path, as nothing is listening there
        UnixSocketListener::bind(&path)
            .await
            .expect("Unexpectedly failed to bind to existing file");
    }

    /// Binding to a path that already has an active listener must fail.
    #[tokio::test]
    async fn should_fail_to_bind_if_socket_already_bound() {
        // Generate a unique socket path; the temporary file itself is dropped (and thereby
        // deleted) at the end of this statement, so only the path is kept
        let path = NamedTempFile::new()
            .expect("Failed to create socket file")
            .path()
            .to_path_buf();

        // Listen at the socket
        let _listener = UnixSocketListener::bind(&path)
            .await
            .expect("Unexpectedly failed to bind first time");

        // Now this should fail as we're already bound to the path
        UnixSocketListener::bind(&path)
            .await
            .expect_err("Unexpectedly succeeded in binding to same socket");
    }

    /// Accepts two connections in sequence and exchanges a fixed message in each direction.
    #[tokio::test]
    async fn should_be_able_to_receive_connections_and_send_and_receive_data_with_them() {
        let (tx, rx) = oneshot::channel();

        // Spawn a task that will wait for two connections and then
        // return the success or failure
        let task: JoinHandle<io::Result<()>> = tokio::spawn(async move {
            // Generate a unique socket path; the temporary file itself is dropped (and thereby
            // deleted) at the end of this statement, so only the path is kept
            let path = NamedTempFile::new()
                .expect("Failed to create socket file")
                .path()
                .to_path_buf();

            // Listen at the socket
            let mut listener = UnixSocketListener::bind(&path).await?;

            // Send the socket path to our main test task
            tx.send(path)
                .map_err(|x| io::Error::new(io::ErrorKind::Other, x.display().to_string()))?;

            // Get first connection
            let mut conn_1 = listener.accept().await?;

            // Send some data to the first connection (12 bytes)
            conn_1.write_all(b"hello conn 1").await?;

            // Get some data from the first connection (14 bytes)
            let mut buf: [u8; 14] = [0; 14];
            let _ = conn_1.read_exact(&mut buf).await?;
            assert_eq!(&buf, b"hello server 1");

            // Get second connection
            let mut conn_2 = listener.accept().await?;

            // Send some data on to second connection (12 bytes)
            conn_2.write_all(b"hello conn 2").await?;

            // Get some data from the second connection (14 bytes)
            let mut buf: [u8; 14] = [0; 14];
            let _ = conn_2.read_exact(&mut buf).await?;
            assert_eq!(&buf, b"hello server 2");

            Ok(())
        });

        // Wait for the server to be ready
        let path = rx.await.expect("Failed to get server socket path");

        // Connect to the listener twice, sending some bytes and receiving some bytes from each
        let mut buf: [u8; 12] = [0; 12];

        let mut conn = UnixSocketTransport::connect(&path)
            .await
            .expect("Conn 1 failed to connect");
        conn.write_all(b"hello server 1")
            .await
            .expect("Conn 1 failed to write");
        conn.read_exact(&mut buf)
            .await
            .expect("Conn 1 failed to read");
        assert_eq!(&buf, b"hello conn 1");

        let mut conn = UnixSocketTransport::connect(&path)
            .await
            .expect("Conn 2 failed to connect");
        conn.write_all(b"hello server 2")
            .await
            .expect("Conn 2 failed to write");
        conn.read_exact(&mut buf)
            .await
            .expect("Conn 2 failed to read");
        assert_eq!(&buf, b"hello conn 2");

        // Verify that the task has completed by waiting on it. The first `expect` covers the
        // task panicking or being cancelled; the second covers an I/O error returned by the
        // task, which would otherwise be silently discarded
        task.await
            .expect("Listener task failed unexpectedly")
            .expect("Listener task returned an I/O error");
    }
}