1use std::io::{self, Seek};
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use rustls::ServerConfig;
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio::net::TcpListener;
8#[cfg(unix)]
9use tokio::net::UnixListener;
10use tokio_rustls::TlsAcceptor;
11
#[cfg(windows)]
mod win_uds_compat {
    //! Windows AF_UNIX support: wraps the futures-based `win_uds` types so they can be
    //! driven through Tokio's `AsyncRead`/`AsyncWrite` traits.

    use std::io;
    use std::path::{Path, PathBuf};
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt};
    use win_uds::net::{AsyncListener, AsyncStream};

    /// A connected AF_UNIX stream, adapted to Tokio's I/O traits via `tokio_util::compat`.
    pub struct WinUnixStream(Compat<AsyncStream>);

    impl WinUnixStream {
        /// Connects to the socket bound at `path`.
        pub async fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
            let raw = AsyncStream::connect(path).await?;
            Ok(Self(raw.compat()))
        }

        /// Projects the pinned wrapper onto the pinned compat adapter it owns.
        fn adapter<'a>(self: Pin<&'a mut Self>) -> Pin<&'a mut Compat<AsyncStream>> {
            Pin::new(&mut self.get_mut().0)
        }
    }

    impl AsyncRead for WinUnixStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            self.adapter().poll_read(cx, buf)
        }
    }

    impl AsyncWrite for WinUnixStream {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.adapter().poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.adapter().poll_flush(cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.adapter().poll_shutdown(cx)
        }
    }

    /// An AF_UNIX listener that remembers the path it was bound to; `local_addr` reports
    /// that stored path rather than querying the socket.
    pub struct WinUnixListener {
        inner: AsyncListener,
        path: PathBuf,
    }

    impl WinUnixListener {
        /// Binds a listener at `path`.
        pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
            let remembered = path.as_ref().to_path_buf();
            let inner = AsyncListener::bind(path)?;
            Ok(Self {
                inner,
                path: remembered,
            })
        }

        /// Accepts the next connection. The `()` stands in for a peer address so callers can
        /// destructure the result as a `(stream, addr)` pair.
        pub async fn accept(&self) -> io::Result<(WinUnixStream, ())> {
            let (raw, _peer) = self.inner.accept().await?;
            Ok((WinUnixStream(raw.compat()), ()))
        }

        /// Returns the path that was passed to [`WinUnixListener::bind`]; never fails.
        pub fn local_addr(&self) -> io::Result<PathBuf> {
            Ok(self.path.clone())
        }
    }
}
84
85#[cfg(windows)]
86use win_uds_compat::WinUnixListener;
87
88pub trait AsyncReadWrite: AsyncRead + AsyncWrite {}
89
90impl<T: AsyncRead + AsyncWrite> AsyncReadWrite for T {}
91
92pub type AsyncReadWriteBox = Box<dyn AsyncReadWrite + Unpin + Send>;
93
94pub struct TlsConfig {
95 pub config: Arc<ServerConfig>,
96 acceptor: TlsAcceptor,
97}
98
99impl TlsConfig {
100 pub fn from_pem(pem_path: PathBuf) -> io::Result<Self> {
101 let pem = std::fs::File::open(&pem_path).map_err(|e| {
102 io::Error::new(
103 io::ErrorKind::NotFound,
104 format!("Failed to open PEM file {}: {}", pem_path.display(), e),
105 )
106 })?;
107 let mut pem = std::io::BufReader::new(pem);
108
109 let certs = rustls_pemfile::certs(&mut pem)
110 .collect::<Result<Vec<_>, _>>()
111 .map_err(|e| {
112 io::Error::new(
113 io::ErrorKind::InvalidData,
114 format!("Invalid certificate: {e}"),
115 )
116 })?;
117
118 if certs.is_empty() {
119 return Err(io::Error::new(
120 io::ErrorKind::InvalidData,
121 "No certificates found",
122 ));
123 }
124
125 pem.seek(std::io::SeekFrom::Start(0))?;
126
127 let key = rustls_pemfile::private_key(&mut pem)
128 .map_err(|e| {
129 io::Error::new(
130 io::ErrorKind::InvalidData,
131 format!("Invalid private key: {e}"),
132 )
133 })?
134 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "No private key found"))?;
135
136 let mut config = rustls::ServerConfig::builder()
137 .with_no_client_auth()
138 .with_single_cert(certs, key)
139 .map_err(|e| {
140 io::Error::new(io::ErrorKind::InvalidData, format!("TLS config error: {e}"))
141 })?;
142
143 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
145
146 let config = Arc::new(config);
147 let acceptor = TlsAcceptor::from(config.clone());
148 Ok(Self { config, acceptor })
149 }
150}
151
152pub enum Listener {
153 Tcp {
154 listener: Arc<TcpListener>,
155 tls_config: Option<TlsConfig>,
156 },
157 #[cfg(unix)]
158 Unix(UnixListener),
159 #[cfg(windows)]
160 Unix(WinUnixListener),
161}
162
163impl Listener {
164 pub async fn accept(
165 &mut self,
166 ) -> io::Result<(AsyncReadWriteBox, Option<std::net::SocketAddr>)> {
167 match self {
168 Listener::Tcp {
169 listener,
170 tls_config,
171 } => {
172 let (stream, addr) = listener.accept().await?;
173
174 let stream = if let Some(tls) = tls_config {
175 match tls.acceptor.accept(stream).await {
177 Ok(tls_stream) => Box::new(tls_stream) as AsyncReadWriteBox,
178 Err(e) => {
179 return Err(io::Error::new(
180 io::ErrorKind::ConnectionAborted,
181 format!("TLS error: {e}"),
182 ));
183 }
184 }
185 } else {
186 Box::new(stream)
188 };
189
190 Ok((stream, Some(addr)))
191 }
192 #[cfg(unix)]
193 Listener::Unix(listener) => {
194 let (stream, _) = listener.accept().await?;
195 Ok((Box::new(stream), None))
196 }
197 #[cfg(windows)]
198 Listener::Unix(listener) => {
199 let (stream, _) = listener.accept().await?;
200 Ok((Box::new(stream), None))
201 }
202 }
203 }
204
205 pub async fn bind(addr: &str, tls_config: Option<TlsConfig>) -> io::Result<Self> {
206 fn is_unix_path(addr: &str) -> bool {
208 addr.starts_with('/') || addr.starts_with('.')
209 }
210
211 #[cfg(windows)]
212 fn is_windows_path(s: &str) -> bool {
213 let bytes = s.as_bytes();
214 bytes.len() >= 3
215 && bytes[0].is_ascii_alphabetic()
216 && bytes[1] == b':'
217 && (bytes[2] == b'\\' || bytes[2] == b'/')
218 }
219
220 #[cfg(windows)]
221 {
222 if is_unix_path(addr) || is_windows_path(addr) {
223 if tls_config.is_some() {
224 return Err(io::Error::new(
225 io::ErrorKind::InvalidInput,
226 "TLS is not supported with Unix domain sockets",
227 ));
228 }
229 let _ = std::fs::remove_file(addr);
230 let listener = WinUnixListener::bind(addr)?;
231 Ok(Listener::Unix(listener))
232 } else {
233 let mut addr = addr.to_owned();
234 if addr.starts_with(':') {
235 addr = format!("127.0.0.1{addr}");
236 }
237 let listener = TcpListener::bind(addr).await?;
238 Ok(Listener::Tcp {
239 listener: Arc::new(listener),
240 tls_config,
241 })
242 }
243 }
244
245 #[cfg(unix)]
246 {
247 if is_unix_path(addr) {
248 if tls_config.is_some() {
249 return Err(io::Error::new(
250 io::ErrorKind::InvalidInput,
251 "TLS is not supported with Unix domain sockets",
252 ));
253 }
254 let _ = std::fs::remove_file(addr);
255 let listener = UnixListener::bind(addr)?;
256 Ok(Listener::Unix(listener))
257 } else {
258 let mut addr = addr.to_owned();
259 if addr.starts_with(':') {
260 addr = format!("127.0.0.1{addr}");
261 }
262 let listener = TcpListener::bind(addr).await?;
263 Ok(Listener::Tcp {
264 listener: Arc::new(listener),
265 tls_config,
266 })
267 }
268 }
269 }
270}
271
272impl Clone for Listener {
273 fn clone(&self) -> Self {
274 match self {
275 Listener::Tcp {
276 listener,
277 tls_config,
278 } => Listener::Tcp {
279 listener: listener.clone(),
280 tls_config: tls_config.clone(),
281 },
282 #[cfg(unix)]
283 Listener::Unix(_) => {
284 panic!("Cannot clone a Unix listener")
285 }
286 #[cfg(windows)]
287 Listener::Unix(_) => {
288 panic!("Cannot clone a Unix listener")
289 }
290 }
291 }
292}
293
294impl Clone for TlsConfig {
295 fn clone(&self) -> Self {
296 TlsConfig {
297 config: self.config.clone(),
298 acceptor: TlsAcceptor::from(self.config.clone()),
299 }
300 }
301}
302
303impl std::fmt::Display for Listener {
304 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305 match self {
306 Listener::Tcp {
307 listener,
308 tls_config,
309 } => {
310 let addr = listener.local_addr().unwrap();
311 let tls_suffix = if tls_config.is_some() { " (TLS)" } else { "" };
312 write!(f, "{}:{}{}", addr.ip(), addr.port(), tls_suffix)
313 }
314 #[cfg(unix)]
315 Listener::Unix(listener) => {
316 let addr = listener.local_addr().unwrap();
317 let path = addr.as_pathname().unwrap();
318 write!(f, "{}", path.display())
319 }
320 #[cfg(windows)]
321 Listener::Unix(listener) => {
322 let path = listener.local_addr().unwrap();
323 write!(f, "{}", path.display())
324 }
325 }
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;

    #[cfg(windows)]
    use super::win_uds_compat::WinUnixStream;

    /// Mirrors `Listener::bind`'s rule for which addresses are socket paths, so the client
    /// connects with the same transport the listener was bound with (including `.`-relative
    /// paths, which the old `/`-only check sent to TCP).
    #[cfg(any(unix, windows))]
    fn is_socket_path(addr: &str) -> bool {
        let bytes = addr.as_bytes();
        let drive_path = cfg!(windows)
            && bytes.len() >= 3
            && bytes[0].is_ascii_alphabetic()
            && bytes[1] == b':'
            && (bytes[2] == b'\\' || bytes[2] == b'/');
        addr.starts_with('/') || addr.starts_with('.') || drive_path
    }

    /// Opens a client connection to `addr` using the matching transport.
    async fn connect_client(addr: &str) -> std::io::Result<AsyncReadWriteBox> {
        #[cfg(unix)]
        if is_socket_path(addr) {
            let stream = tokio::net::UnixStream::connect(addr).await?;
            return Ok(Box::new(stream) as AsyncReadWriteBox);
        }
        #[cfg(windows)]
        if is_socket_path(addr) {
            let stream = WinUnixStream::connect(addr).await?;
            return Ok(Box::new(stream) as AsyncReadWriteBox);
        }
        let stream = TcpStream::connect(addr).await?;
        Ok(Box::new(stream) as AsyncReadWriteBox)
    }

    /// Binds `addr`, connects a client, writes a greeting from the server side, closes it,
    /// and checks the client reads exactly that greeting before EOF.
    async fn exercise_listener(addr: &str) {
        let mut listener = Listener::bind(addr, None).await.unwrap();
        let listener_addr = match &listener {
            Listener::Tcp { listener, .. } => {
                let addr = listener.local_addr().unwrap();
                format!("{}:{}", addr.ip(), addr.port())
            }
            #[cfg(unix)]
            Listener::Unix(listener) => {
                let addr = listener.local_addr().unwrap();
                addr.as_pathname().unwrap().to_string_lossy().to_string()
            }
            #[cfg(windows)]
            Listener::Unix(listener) => {
                let path = listener.local_addr().unwrap();
                path.to_string_lossy().to_string()
            }
        };

        let client_task =
            tokio::spawn(async move { connect_client(&listener_addr).await });

        let (mut server_side, _) = listener.accept().await.unwrap();
        let greeting = b"Hello from server!";
        server_side.write_all(greeting).await.unwrap();
        drop(server_side);

        let mut client = client_task.await.unwrap().unwrap();
        let mut received = Vec::new();
        client.read_to_end(&mut received).await.unwrap();
        assert_eq!(greeting.to_vec(), received);
    }

    #[tokio::test]
    async fn test_bind_tcp() {
        exercise_listener("127.0.0.1:0").await;
    }

    #[tokio::test]
    async fn test_bind_tcp_port_shorthand() {
        // `:0` expands to `127.0.0.1:0`.
        exercise_listener(":0").await;
    }

    #[tokio::test]
    async fn test_display_tcp() {
        let listener = Listener::bind("127.0.0.1:0", None).await.unwrap();
        let shown = listener.to_string();
        assert!(shown.starts_with("127.0.0.1:"), "unexpected display: {shown}");
        assert!(!shown.ends_with(" (TLS)"), "unexpected display: {shown}");
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_bind_unix() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.sock");
        let path = path.to_str().unwrap();
        exercise_listener(path).await;
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_display_unix() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("display.sock");
        let path = path.to_str().unwrap();
        let listener = Listener::bind(path, None).await.unwrap();
        assert_eq!(listener.to_string(), path);
    }
}