distant_net/common/listener/
unix.rs1use std::os::unix::fs::PermissionsExt;
2use std::path::{Path, PathBuf};
3use std::{fmt, io};
4
5use async_trait::async_trait;
6use tokio::net::{UnixListener, UnixStream};
7
8use super::Listener;
9use crate::common::UnixSocketTransport;
10
11pub struct UnixSocketListener {
13 path: PathBuf,
14 inner: tokio::net::UnixListener,
15}
16
17impl UnixSocketListener {
18 pub async fn bind(path: impl AsRef<Path>) -> io::Result<Self> {
22 Self::bind_with_permissions(path, Self::default_unix_socket_file_permissions()).await
23 }
24
25 pub async fn bind_with_permissions(path: impl AsRef<Path>, mode: u32) -> io::Result<Self> {
28 let listener = match UnixListener::bind(path.as_ref()) {
31 Ok(listener) => listener,
32 Err(_) => {
33 if UnixStream::connect(path.as_ref()).await.is_ok() {
35 return Err(io::Error::from(io::ErrorKind::AddrInUse));
36 }
37
38 tokio::fs::remove_file(path.as_ref()).await?;
40
41 UnixListener::bind(path.as_ref())?
42 }
43 };
44
45 let mut permissions = tokio::fs::metadata(path.as_ref()).await?.permissions();
51 permissions.set_mode(mode);
52 tokio::fs::set_permissions(path.as_ref(), permissions).await?;
53
54 Ok(Self {
55 path: path.as_ref().to_path_buf(),
56 inner: listener,
57 })
58 }
59
60 pub fn path(&self) -> &Path {
62 &self.path
63 }
64
65 pub const fn default_unix_socket_file_permissions() -> u32 {
67 0o600
68 }
69}
70
71impl fmt::Debug for UnixSocketListener {
72 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73 f.debug_struct("UnixSocketListener")
74 .field("path", &self.path)
75 .finish()
76 }
77}
78
79#[async_trait]
80impl Listener for UnixSocketListener {
81 type Output = UnixSocketTransport;
82
83 async fn accept(&mut self) -> io::Result<Self::Output> {
84 let (stream, _) = tokio::net::UnixListener::accept(&self.inner).await?;
88 Ok(UnixSocketTransport {
89 path: self.path.to_path_buf(),
90 inner: stream,
91 })
92 }
93}
94
#[cfg(test)]
mod tests {
    use tempfile::NamedTempFile;
    use test_log::test;
    use tokio::sync::oneshot;
    use tokio::task::JoinHandle;

    use super::*;
    use crate::common::TransportExt;

    #[test(tokio::test)]
    async fn should_succeed_to_bind_if_file_exists_at_path_but_nothing_listening() {
        // Keep the temporary file alive (as a `TempPath`) so that a regular file occupies the
        // path when binding. The listener must replace it with a socket.
        let path = NamedTempFile::new()
            .expect("Failed to create file")
            .into_temp_path();

        UnixSocketListener::bind(&path)
            .await
            .expect("Unexpectedly failed to bind to existing file");
    }

    #[test(tokio::test)]
    async fn should_fail_to_bind_if_socket_already_bound() {
        // Only a unique path is wanted here; the temporary file itself is deleted at the end
        // of this statement.
        let path = NamedTempFile::new()
            .expect("Failed to create socket file")
            .path()
            .to_path_buf();

        let _listener = UnixSocketListener::bind(&path)
            .await
            .expect("Unexpectedly failed to bind first time");

        let err = UnixSocketListener::bind(&path)
            .await
            .expect_err("Unexpectedly succeeded in binding to same socket");
        assert_eq!(err.kind(), io::ErrorKind::AddrInUse);

        // The listener does not remove its socket file, so clean it up here.
        let _ = std::fs::remove_file(&path);
    }

    #[test(tokio::test)]
    async fn should_be_able_to_receive_connections_and_read_and_write_data_with_them() {
        let (tx, rx) = oneshot::channel();

        let task: JoinHandle<io::Result<()>> = tokio::spawn(async move {
            let path = NamedTempFile::new()
                .expect("Failed to create socket file")
                .path()
                .to_path_buf();

            let mut listener = UnixSocketListener::bind(&path).await?;

            // Hand the bound path to the client side only once the listener is ready.
            tx.send(path)
                .map_err(|x| io::Error::new(io::ErrorKind::Other, x.display().to_string()))?;

            let conn_1 = listener.accept().await?;

            conn_1.write_all(b"hello conn 1").await?;

            let mut buf: [u8; 14] = [0; 14];
            let _ = conn_1.read_exact(&mut buf).await?;
            assert_eq!(&buf, b"hello server 1");

            let conn_2 = listener.accept().await?;

            conn_2.write_all(b"hello conn 2").await?;

            let mut buf: [u8; 14] = [0; 14];
            let _ = conn_2.read_exact(&mut buf).await?;
            assert_eq!(&buf, b"hello server 2");

            Ok(())
        });

        let path = rx.await.expect("Failed to get server socket path");

        let mut buf: [u8; 12] = [0; 12];

        let conn = UnixSocketTransport::connect(&path)
            .await
            .expect("Conn 1 failed to connect");
        conn.write_all(b"hello server 1")
            .await
            .expect("Conn 1 failed to write");
        conn.read_exact(&mut buf)
            .await
            .expect("Conn 1 failed to read");
        assert_eq!(&buf, b"hello conn 1");

        let conn = UnixSocketTransport::connect(&path)
            .await
            .expect("Conn 2 failed to connect");
        conn.write_all(b"hello server 2")
            .await
            .expect("Conn 2 failed to write");
        conn.read_exact(&mut buf)
            .await
            .expect("Conn 2 failed to read");
        assert_eq!(&buf, b"hello conn 2");

        // Check both layers: the join result (panics inside the task) and the task's own
        // `io::Result`. Previously the latter was discarded, so server-side I/O errors went
        // unnoticed.
        task.await
            .expect("Listener task failed unexpectedly")
            .expect("Listener task returned an error");

        // The listener does not remove its socket file, so clean it up here.
        let _ = std::fs::remove_file(&path);
    }
}