http_nu/
listener.rs

1use std::io::{self, Seek};
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use rustls::ServerConfig;
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio::net::TcpListener;
8#[cfg(unix)]
9use tokio::net::UnixListener;
10use tokio_rustls::TlsAcceptor;
11
12pub trait AsyncReadWrite: AsyncRead + AsyncWrite {}
13
14impl<T: AsyncRead + AsyncWrite> AsyncReadWrite for T {}
15
16pub type AsyncReadWriteBox = Box<dyn AsyncReadWrite + Unpin + Send>;
17
18pub struct TlsConfig {
19    pub config: Arc<ServerConfig>,
20    acceptor: TlsAcceptor,
21}
22
23impl TlsConfig {
24    pub fn from_pem(pem_path: PathBuf) -> io::Result<Self> {
25        let pem = std::fs::File::open(&pem_path).map_err(|e| {
26            io::Error::new(
27                io::ErrorKind::NotFound,
28                format!("Failed to open PEM file {}: {}", pem_path.display(), e),
29            )
30        })?;
31        let mut pem = std::io::BufReader::new(pem);
32
33        let certs = rustls_pemfile::certs(&mut pem)
34            .collect::<Result<Vec<_>, _>>()
35            .map_err(|e| {
36                io::Error::new(
37                    io::ErrorKind::InvalidData,
38                    format!("Invalid certificate: {e}"),
39                )
40            })?;
41
42        if certs.is_empty() {
43            return Err(io::Error::new(
44                io::ErrorKind::InvalidData,
45                "No certificates found",
46            ));
47        }
48
49        pem.seek(std::io::SeekFrom::Start(0))?;
50
51        let key = rustls_pemfile::private_key(&mut pem)
52            .map_err(|e| {
53                io::Error::new(
54                    io::ErrorKind::InvalidData,
55                    format!("Invalid private key: {e}"),
56                )
57            })?
58            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "No private key found"))?;
59
60        let config = rustls::ServerConfig::builder()
61            .with_no_client_auth()
62            .with_single_cert(certs, key)
63            .map_err(|e| {
64                io::Error::new(io::ErrorKind::InvalidData, format!("TLS config error: {e}"))
65            })?;
66
67        let config = Arc::new(config);
68        let acceptor = TlsAcceptor::from(config.clone());
69        Ok(Self { config, acceptor })
70    }
71}
72
73pub enum Listener {
74    Tcp {
75        listener: Arc<TcpListener>,
76        tls_config: Option<TlsConfig>,
77    },
78    #[cfg(unix)]
79    Unix(UnixListener),
80}
81
82impl Listener {
83    pub async fn accept(
84        &mut self,
85    ) -> io::Result<(AsyncReadWriteBox, Option<std::net::SocketAddr>)> {
86        match self {
87            Listener::Tcp {
88                listener,
89                tls_config,
90            } => {
91                let (stream, addr) = listener.accept().await?;
92
93                let stream = if let Some(tls) = tls_config {
94                    // Handle TLS connection
95                    match tls.acceptor.accept(stream).await {
96                        Ok(tls_stream) => Box::new(tls_stream) as AsyncReadWriteBox,
97                        Err(e) => {
98                            return Err(io::Error::new(
99                                io::ErrorKind::ConnectionAborted,
100                                format!("TLS error: {e}"),
101                            ));
102                        }
103                    }
104                } else {
105                    // Handle plain TCP connection
106                    Box::new(stream)
107                };
108
109                Ok((stream, Some(addr)))
110            }
111            #[cfg(unix)]
112            Listener::Unix(listener) => {
113                let (stream, _) = listener.accept().await?;
114                Ok((Box::new(stream), None))
115            }
116        }
117    }
118
119    pub async fn bind(addr: &str, tls_config: Option<TlsConfig>) -> io::Result<Self> {
120        #[cfg(windows)]
121        {
122            // On Windows, treat all addresses as TCP
123            let mut addr = addr.to_owned();
124            if addr.starts_with(':') {
125                addr = format!("127.0.0.1{addr}");
126            }
127            let listener = TcpListener::bind(addr).await?;
128            Ok(Listener::Tcp {
129                listener: Arc::new(listener),
130                tls_config,
131            })
132        }
133
134        #[cfg(unix)]
135        {
136            if addr.starts_with('/') || addr.starts_with('.') {
137                if tls_config.is_some() {
138                    return Err(io::Error::new(
139                        io::ErrorKind::InvalidInput,
140                        "TLS is not supported with Unix domain sockets",
141                    ));
142                }
143                let _ = std::fs::remove_file(addr);
144                let listener = UnixListener::bind(addr)?;
145                Ok(Listener::Unix(listener))
146            } else {
147                let mut addr = addr.to_owned();
148                if addr.starts_with(':') {
149                    addr = format!("127.0.0.1{addr}");
150                }
151                let listener = TcpListener::bind(addr).await?;
152                Ok(Listener::Tcp {
153                    listener: Arc::new(listener),
154                    tls_config,
155                })
156            }
157        }
158    }
159}
160
161impl Clone for Listener {
162    fn clone(&self) -> Self {
163        match self {
164            Listener::Tcp {
165                listener,
166                tls_config,
167            } => Listener::Tcp {
168                listener: listener.clone(),
169                tls_config: tls_config.clone(),
170            },
171            #[cfg(unix)]
172            Listener::Unix(_) => {
173                panic!("Cannot clone a Unix listener")
174            }
175        }
176    }
177}
178
179impl Clone for TlsConfig {
180    fn clone(&self) -> Self {
181        TlsConfig {
182            config: self.config.clone(),
183            acceptor: TlsAcceptor::from(self.config.clone()),
184        }
185    }
186}
187
188impl std::fmt::Display for Listener {
189    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190        match self {
191            Listener::Tcp {
192                listener,
193                tls_config,
194            } => {
195                let addr = listener.local_addr().unwrap();
196                let tls_suffix = if tls_config.is_some() { " (TLS)" } else { "" };
197                write!(f, "{}:{}{}", addr.ip(), addr.port(), tls_suffix)
198            }
199            #[cfg(unix)]
200            Listener::Unix(listener) => {
201                let addr = listener.local_addr().unwrap();
202                let path = addr.as_pathname().unwrap();
203                write!(f, "{}", path.display())
204            }
205        }
206    }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;

    /// Dials `target`: a Unix socket when it is an absolute path, otherwise a
    /// TCP `host:port`.
    async fn connect_client(target: String) -> std::io::Result<AsyncReadWriteBox> {
        if target.starts_with('/') {
            #[cfg(unix)]
            {
                let stream = tokio::net::UnixStream::connect(&target).await?;
                Ok(Box::new(stream) as AsyncReadWriteBox)
            }
            #[cfg(not(unix))]
            {
                panic!("Unix sockets not supported on this platform");
            }
        } else {
            let stream = TcpStream::connect(&target).await?;
            Ok(Box::new(stream) as AsyncReadWriteBox)
        }
    }

    /// Binds `addr`, then sends a fixed greeting from the server side and checks
    /// that the client receives exactly those bytes.
    async fn exercise_listener(addr: &str) {
        let mut listener = Listener::bind(addr, None).await.unwrap();

        // Where the client should connect: the actually-bound TCP address (so a
        // requested port 0 resolves to the real port), or the socket path.
        let target = match &listener {
            Listener::Tcp { listener: tcp, .. } => {
                let bound = tcp.local_addr().unwrap();
                format!("{}:{}", bound.ip(), bound.port())
            }
            #[cfg(unix)]
            Listener::Unix(unix) => {
                let bound = unix.local_addr().unwrap();
                bound.as_pathname().unwrap().to_string_lossy().into_owned()
            }
        };

        let client_task = tokio::spawn(connect_client(target));

        let (mut server_side, _) = listener.accept().await.unwrap();
        let greeting = b"Hello from server!";
        server_side.write_all(greeting).await.unwrap();
        // Closing the server half lets the client's `read_to_end` reach EOF.
        drop(server_side);

        let mut client_side = client_task.await.unwrap().unwrap();
        let mut received = Vec::new();
        client_side.read_to_end(&mut received).await.unwrap();
        assert_eq!(greeting.to_vec(), received);
    }

    #[tokio::test]
    async fn test_bind_tcp() {
        exercise_listener("127.0.0.1:0").await;
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_bind_unix() {
        let dir = tempfile::tempdir().unwrap();
        let sock_path = dir.path().join("test.sock");
        exercise_listener(sock_path.to_str().unwrap()).await;
    }
}