distant_net/common/transport/
unix.rs1use std::path::{Path, PathBuf};
2use std::{fmt, io};
3
4use async_trait::async_trait;
5use tokio::net::UnixStream;
6
7use super::{Interest, Ready, Reconnectable, Transport};
8
9pub struct UnixSocketTransport {
11 pub(crate) path: PathBuf,
12 pub(crate) inner: UnixStream,
13}
14
15impl UnixSocketTransport {
16 pub async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
18 let stream = UnixStream::connect(path.as_ref()).await?;
19 Ok(Self {
20 path: path.as_ref().to_path_buf(),
21 inner: stream,
22 })
23 }
24
25 pub fn path(&self) -> &Path {
27 &self.path
28 }
29}
30
31impl fmt::Debug for UnixSocketTransport {
32 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33 f.debug_struct("UnixSocketTransport")
34 .field("path", &self.path)
35 .finish()
36 }
37}
38
39#[async_trait]
40impl Reconnectable for UnixSocketTransport {
41 async fn reconnect(&mut self) -> io::Result<()> {
42 self.inner = UnixStream::connect(self.path.as_path()).await?;
43 Ok(())
44 }
45}
46
47#[async_trait]
48impl Transport for UnixSocketTransport {
49 fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
50 self.inner.try_read(buf)
51 }
52
53 fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
54 self.inner.try_write(buf)
55 }
56
57 async fn ready(&self, interest: Interest) -> io::Result<Ready> {
58 self.inner.ready(interest).await
59 }
60}
61
#[cfg(test)]
mod tests {
    use tempfile::NamedTempFile;
    use test_log::test;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::UnixListener;
    use tokio::sync::oneshot;
    use tokio::task::JoinHandle;

    use super::*;
    use crate::common::TransportExt;

    /// Binds a listener at a fresh temporary path, sends that path through `tx`, and then
    /// serves a single connection via [`run_server`].
    async fn start_and_run_server(tx: oneshot::Sender<PathBuf>) -> io::Result<()> {
        // The `NamedTempFile` is dropped at the end of this statement, which deletes the file
        // and leaves a unique, currently unused path for the listener to bind to.
        let path = NamedTempFile::new()
            .expect("Failed to create socket file")
            .path()
            .to_path_buf();

        let listener = UnixListener::bind(&path)?;

        // Publish the path only after the bind succeeds, so the client never connects early.
        tx.send(path)
            .map_err(|x| io::Error::new(io::ErrorKind::Other, x.display().to_string()))?;

        run_server(listener).await
    }

    /// Accepts exactly one connection, writes `b"hello conn"` to it, and asserts that the peer
    /// replies with `b"hello server"`.
    async fn run_server(listener: UnixListener) -> io::Result<()> {
        let (mut conn, _) = listener.accept().await?;

        conn.write_all(b"hello conn").await?;

        let mut buf: [u8; 12] = [0; 12];
        let _ = conn.read_exact(&mut buf).await?;
        assert_eq!(&buf, b"hello server");

        Ok(())
    }

    #[test(tokio::test)]
    async fn should_fail_to_connect_if_socket_does_not_exist() {
        // Dropping the temp file deletes it, so nothing exists at `path` when connecting.
        let path = NamedTempFile::new()
            .expect("Failed to create socket file")
            .path()
            .to_path_buf();

        UnixSocketTransport::connect(&path)
            .await
            .expect_err("Unexpectedly succeeded in connecting to missing socket");
    }

    #[test(tokio::test)]
    async fn should_fail_to_connect_if_path_is_not_a_socket() {
        // `into_temp_path` keeps a regular file on disk for as long as `path` is alive.
        let path = NamedTempFile::new()
            .expect("Failed to create socket file")
            .into_temp_path();

        UnixSocketTransport::connect(&path)
            .await
            .expect_err("Unexpectedly succeeded in connecting to regular file");
    }

    #[test(tokio::test)]
    async fn should_be_able_to_read_and_write_data() {
        let (tx, rx) = oneshot::channel();

        let task: JoinHandle<io::Result<()>> = tokio::spawn(start_and_run_server(tx));

        let path = rx.await.expect("Failed to get server socket path");

        let mut buf: [u8; 10] = [0; 10];

        let conn = UnixSocketTransport::connect(&path)
            .await
            .expect("Conn failed to connect");
        conn.read_exact(&mut buf)
            .await
            .expect("Conn failed to read");
        assert_eq!(&buf, b"hello conn");

        conn.write_all(b"hello server")
            .await
            .expect("Conn failed to write");

        // Check both layers: the join result (panic/cancellation) and the server's own
        // `io::Result`, so that server-side I/O errors also fail the test.
        task.await
            .expect("Server task panicked or was cancelled")
            .expect("Server task failed unexpectedly");
    }

    #[test(tokio::test)]
    async fn should_be_able_to_reconnect() {
        let (tx, rx) = oneshot::channel();

        let task: JoinHandle<io::Result<()>> = tokio::spawn(start_and_run_server(tx));

        let path = rx.await.expect("Failed to get server socket path");

        let mut conn = UnixSocketTransport::connect(&path)
            .await
            .expect("Conn failed to connect");

        // Kill the first server. Aborting drops its connection, so the client should observe
        // EOF or an error rather than data.
        task.abort();

        conn.readable()
            .await
            .expect("Failed to wait for conn to be readable");
        let res = conn.read_exact(&mut [0; 10]).await;
        assert!(
            matches!(res, Ok(0) | Err(_)),
            "Unexpected read result: {res:?}"
        );

        // Unix socket files are not removed when a listener closes. Delete the stale file so
        // the same path can be bound again.
        let _ = tokio::fs::remove_file(&path).await;
        let task: JoinHandle<io::Result<()>> = tokio::spawn(run_server(
            UnixListener::bind(&path).expect("Failed to rebind server"),
        ));

        let mut buf: [u8; 10] = [0; 10];
        conn.reconnect().await.expect("Conn failed to reconnect");

        conn.read_exact(&mut buf)
            .await
            .expect("Conn failed to read");
        assert_eq!(&buf, b"hello conn");

        conn.write_all(b"hello server")
            .await
            .expect("Conn failed to write");

        // As above, propagate the server's own I/O result instead of discarding it.
        task.await
            .expect("Server task panicked or was cancelled")
            .expect("Server task failed unexpectedly");
    }
}