// http_nu/listener.rs

1use std::io::{self, Seek};
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use rustls::ServerConfig;
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio::net::TcpListener;
8#[cfg(unix)]
9use tokio::net::UnixListener;
10use tokio_rustls::TlsAcceptor;
11
#[cfg(windows)]
mod win_uds_compat {
    //! Thin adapters that make `win_uds` AF_UNIX sockets usable through
    //! tokio's I/O traits, exposing the small subset of the
    //! `tokio::net::UnixListener` / `UnixStream` API this file relies on.

    use std::io;
    use std::path::{Path, PathBuf};
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
    use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt};
    use win_uds::net::{AsyncListener, AsyncStream};

    /// A connected Windows Unix-domain stream, bridged from the `futures`
    /// I/O traits to tokio's via [`Compat`].
    pub struct WinUnixStream(Compat<AsyncStream>);

    impl WinUnixStream {
        /// Opens a client connection to the socket at `path`.
        pub async fn connect<P: AsRef<Path>>(path: P) -> io::Result<Self> {
            let stream = AsyncStream::connect(path).await?;
            Ok(Self(stream.compat()))
        }

        /// Projects the pinned wrapper onto its inner adapter (the wrapper is
        /// `Unpin` whenever the adapter is, so no unsafe projection is needed).
        fn pinned_inner(self: Pin<&mut Self>) -> Pin<&mut Compat<AsyncStream>> {
            Pin::new(&mut self.get_mut().0)
        }
    }

    impl AsyncRead for WinUnixStream {
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            self.pinned_inner().poll_read(cx, buf)
        }
    }

    impl AsyncWrite for WinUnixStream {
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            self.pinned_inner().poll_write(cx, buf)
        }

        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.pinned_inner().poll_flush(cx)
        }

        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            self.pinned_inner().poll_shutdown(cx)
        }
    }

    /// A Windows Unix-domain socket listener that remembers the path it was
    /// bound to, so it can be reported by `local_addr`.
    pub struct WinUnixListener {
        inner: AsyncListener,
        path: PathBuf,
    }

    impl WinUnixListener {
        /// Binds a new listener at `path`.
        pub fn bind<P: AsRef<Path>>(path: P) -> io::Result<Self> {
            let remembered = path.as_ref().to_path_buf();
            let inner = AsyncListener::bind(path)?;
            Ok(Self {
                inner,
                path: remembered,
            })
        }

        /// Waits for the next client. The unit in the tuple stands in for the
        /// peer address, which callers here never use.
        pub async fn accept(&self) -> io::Result<(WinUnixStream, ())> {
            let (stream, _peer) = self.inner.accept().await?;
            Ok((WinUnixStream(stream.compat()), ()))
        }

        /// Returns the path passed to [`WinUnixListener::bind`].
        pub fn local_addr(&self) -> io::Result<PathBuf> {
            Ok(self.path.clone())
        }
    }
}
84
85#[cfg(windows)]
86use win_uds_compat::WinUnixListener;
87
88pub trait AsyncReadWrite: AsyncRead + AsyncWrite {}
89
90impl<T: AsyncRead + AsyncWrite> AsyncReadWrite for T {}
91
92pub type AsyncReadWriteBox = Box<dyn AsyncReadWrite + Unpin + Send>;
93
94pub struct TlsConfig {
95    pub config: Arc<ServerConfig>,
96    acceptor: TlsAcceptor,
97}
98
99impl TlsConfig {
100    pub fn from_pem(pem_path: PathBuf) -> io::Result<Self> {
101        let pem = std::fs::File::open(&pem_path).map_err(|e| {
102            io::Error::new(
103                io::ErrorKind::NotFound,
104                format!("Failed to open PEM file {}: {}", pem_path.display(), e),
105            )
106        })?;
107        let mut pem = std::io::BufReader::new(pem);
108
109        let certs = rustls_pemfile::certs(&mut pem)
110            .collect::<Result<Vec<_>, _>>()
111            .map_err(|e| {
112                io::Error::new(
113                    io::ErrorKind::InvalidData,
114                    format!("Invalid certificate: {e}"),
115                )
116            })?;
117
118        if certs.is_empty() {
119            return Err(io::Error::new(
120                io::ErrorKind::InvalidData,
121                "No certificates found",
122            ));
123        }
124
125        pem.seek(std::io::SeekFrom::Start(0))?;
126
127        let key = rustls_pemfile::private_key(&mut pem)
128            .map_err(|e| {
129                io::Error::new(
130                    io::ErrorKind::InvalidData,
131                    format!("Invalid private key: {e}"),
132                )
133            })?
134            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "No private key found"))?;
135
136        let mut config = rustls::ServerConfig::builder()
137            .with_no_client_auth()
138            .with_single_cert(certs, key)
139            .map_err(|e| {
140                io::Error::new(io::ErrorKind::InvalidData, format!("TLS config error: {e}"))
141            })?;
142
143        // Enable HTTP/2 via ALPN (advertise h2 first, then http/1.1)
144        config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
145
146        let config = Arc::new(config);
147        let acceptor = TlsAcceptor::from(config.clone());
148        Ok(Self { config, acceptor })
149    }
150}
151
152pub enum Listener {
153    Tcp {
154        listener: Arc<TcpListener>,
155        tls_config: Option<TlsConfig>,
156    },
157    #[cfg(unix)]
158    Unix(UnixListener),
159    #[cfg(windows)]
160    Unix(WinUnixListener),
161}
162
163impl Listener {
164    pub async fn accept(
165        &mut self,
166    ) -> io::Result<(AsyncReadWriteBox, Option<std::net::SocketAddr>)> {
167        match self {
168            Listener::Tcp {
169                listener,
170                tls_config,
171            } => {
172                let (stream, addr) = listener.accept().await?;
173
174                let stream = if let Some(tls) = tls_config {
175                    // Handle TLS connection
176                    match tls.acceptor.accept(stream).await {
177                        Ok(tls_stream) => Box::new(tls_stream) as AsyncReadWriteBox,
178                        Err(e) => {
179                            return Err(io::Error::new(
180                                io::ErrorKind::ConnectionAborted,
181                                format!("TLS error: {e}"),
182                            ));
183                        }
184                    }
185                } else {
186                    // Handle plain TCP connection
187                    Box::new(stream)
188                };
189
190                Ok((stream, Some(addr)))
191            }
192            #[cfg(unix)]
193            Listener::Unix(listener) => {
194                let (stream, _) = listener.accept().await?;
195                Ok((Box::new(stream), None))
196            }
197            #[cfg(windows)]
198            Listener::Unix(listener) => {
199                let (stream, _) = listener.accept().await?;
200                Ok((Box::new(stream), None))
201            }
202        }
203    }
204
205    pub async fn bind(addr: &str, tls_config: Option<TlsConfig>) -> io::Result<Self> {
206        // Check if address looks like a Unix socket path
207        fn is_unix_path(addr: &str) -> bool {
208            addr.starts_with('/') || addr.starts_with('.')
209        }
210
211        #[cfg(windows)]
212        fn is_windows_path(s: &str) -> bool {
213            let bytes = s.as_bytes();
214            bytes.len() >= 3
215                && bytes[0].is_ascii_alphabetic()
216                && bytes[1] == b':'
217                && (bytes[2] == b'\\' || bytes[2] == b'/')
218        }
219
220        #[cfg(windows)]
221        {
222            if is_unix_path(addr) || is_windows_path(addr) {
223                if tls_config.is_some() {
224                    return Err(io::Error::new(
225                        io::ErrorKind::InvalidInput,
226                        "TLS is not supported with Unix domain sockets",
227                    ));
228                }
229                let _ = std::fs::remove_file(addr);
230                let listener = WinUnixListener::bind(addr)?;
231                Ok(Listener::Unix(listener))
232            } else {
233                let mut addr = addr.to_owned();
234                if addr.starts_with(':') {
235                    addr = format!("127.0.0.1{addr}");
236                }
237                let listener = TcpListener::bind(addr).await?;
238                Ok(Listener::Tcp {
239                    listener: Arc::new(listener),
240                    tls_config,
241                })
242            }
243        }
244
245        #[cfg(unix)]
246        {
247            if is_unix_path(addr) {
248                if tls_config.is_some() {
249                    return Err(io::Error::new(
250                        io::ErrorKind::InvalidInput,
251                        "TLS is not supported with Unix domain sockets",
252                    ));
253                }
254                let _ = std::fs::remove_file(addr);
255                let listener = UnixListener::bind(addr)?;
256                Ok(Listener::Unix(listener))
257            } else {
258                let mut addr = addr.to_owned();
259                if addr.starts_with(':') {
260                    addr = format!("127.0.0.1{addr}");
261                }
262                let listener = TcpListener::bind(addr).await?;
263                Ok(Listener::Tcp {
264                    listener: Arc::new(listener),
265                    tls_config,
266                })
267            }
268        }
269    }
270}
271
272impl Clone for Listener {
273    fn clone(&self) -> Self {
274        match self {
275            Listener::Tcp {
276                listener,
277                tls_config,
278            } => Listener::Tcp {
279                listener: listener.clone(),
280                tls_config: tls_config.clone(),
281            },
282            #[cfg(unix)]
283            Listener::Unix(_) => {
284                panic!("Cannot clone a Unix listener")
285            }
286            #[cfg(windows)]
287            Listener::Unix(_) => {
288                panic!("Cannot clone a Unix listener")
289            }
290        }
291    }
292}
293
294impl Clone for TlsConfig {
295    fn clone(&self) -> Self {
296        TlsConfig {
297            config: self.config.clone(),
298            acceptor: TlsAcceptor::from(self.config.clone()),
299        }
300    }
301}
302
303impl std::fmt::Display for Listener {
304    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305        match self {
306            Listener::Tcp {
307                listener,
308                tls_config,
309            } => {
310                let addr = listener.local_addr().unwrap();
311                let tls_suffix = if tls_config.is_some() { " (TLS)" } else { "" };
312                write!(f, "{}:{}{}", addr.ip(), addr.port(), tls_suffix)
313            }
314            #[cfg(unix)]
315            Listener::Unix(listener) => {
316                let addr = listener.local_addr().unwrap();
317                let path = addr.as_pathname().unwrap();
318                write!(f, "{}", path.display())
319            }
320            #[cfg(windows)]
321            Listener::Unix(listener) => {
322                let path = listener.local_addr().unwrap();
323                write!(f, "{}", path.display())
324            }
325        }
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;

    #[cfg(windows)]
    use super::win_uds_compat::WinUnixStream;

    /// How the test client should reach the listener under test.
    enum Target {
        Tcp(std::net::SocketAddr),
        #[cfg(any(unix, windows))]
        Unix(std::path::PathBuf),
    }

    /// Binds `addr`, connects a client, and checks that bytes written by the
    /// server side arrive intact at the client.
    async fn exercise_listener(addr: &str) {
        let mut listener = Listener::bind(addr, None).await.unwrap();

        // Pick the client transport from the listener variant rather than by
        // sniffing the address string, so IPv6 TCP addresses work as well.
        let target = match &listener {
            Listener::Tcp { listener, .. } => Target::Tcp(listener.local_addr().unwrap()),
            #[cfg(unix)]
            Listener::Unix(listener) => Target::Unix(
                listener
                    .local_addr()
                    .unwrap()
                    .as_pathname()
                    .unwrap()
                    .to_path_buf(),
            ),
            #[cfg(windows)]
            Listener::Unix(listener) => Target::Unix(listener.local_addr().unwrap()),
        };

        let client_task: tokio::task::JoinHandle<std::io::Result<AsyncReadWriteBox>> =
            tokio::spawn(async move {
                let stream = match target {
                    Target::Tcp(addr) => {
                        Box::new(TcpStream::connect(addr).await?) as AsyncReadWriteBox
                    }
                    #[cfg(unix)]
                    Target::Unix(path) => Box::new(tokio::net::UnixStream::connect(path).await?)
                        as AsyncReadWriteBox,
                    #[cfg(windows)]
                    Target::Unix(path) => {
                        Box::new(WinUnixStream::connect(path).await?) as AsyncReadWriteBox
                    }
                };
                Ok(stream)
            });

        let (mut serve, _) = listener.accept().await.unwrap();
        let want = b"Hello from server!";
        serve.write_all(want).await.unwrap();
        // Closing the server side lets the client's read_to_end terminate.
        drop(serve);

        let mut client = client_task.await.unwrap().unwrap();
        let mut got = Vec::new();
        client.read_to_end(&mut got).await.unwrap();
        assert_eq!(want.to_vec(), got);
    }

    #[tokio::test]
    async fn test_bind_tcp() {
        exercise_listener("127.0.0.1:0").await;
    }

    #[tokio::test]
    async fn test_bind_port_shorthand_uses_loopback() {
        let listener = Listener::bind(":0", None).await.unwrap();
        assert!(listener.to_string().starts_with("127.0.0.1:"));
        match &listener {
            Listener::Tcp {
                listener,
                tls_config,
            } => {
                let addr = listener.local_addr().unwrap();
                assert_eq!(addr.ip(), std::net::IpAddr::from([127, 0, 0, 1]));
                assert!(tls_config.is_none());
            }
            #[cfg(any(unix, windows))]
            Listener::Unix(_) => panic!("`:0` should bind a TCP listener"),
        }
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_bind_unix() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.sock");
        let path = path.to_str().unwrap();
        exercise_listener(path).await;
    }
}