1use std::io::{self, Seek};
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use rustls::ServerConfig;
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio::net::TcpListener;
8#[cfg(unix)]
9use tokio::net::UnixListener;
10use tokio_rustls::TlsAcceptor;
11
12pub trait AsyncReadWrite: AsyncRead + AsyncWrite {}
13
14impl<T: AsyncRead + AsyncWrite> AsyncReadWrite for T {}
15
16pub type AsyncReadWriteBox = Box<dyn AsyncReadWrite + Unpin + Send>;
17
18pub struct TlsConfig {
19 pub config: Arc<ServerConfig>,
20 acceptor: TlsAcceptor,
21}
22
23impl TlsConfig {
24 pub fn from_pem(pem_path: PathBuf) -> io::Result<Self> {
25 let pem = std::fs::File::open(&pem_path).map_err(|e| {
26 io::Error::new(
27 io::ErrorKind::NotFound,
28 format!("Failed to open PEM file {}: {}", pem_path.display(), e),
29 )
30 })?;
31 let mut pem = std::io::BufReader::new(pem);
32
33 let certs = rustls_pemfile::certs(&mut pem)
34 .collect::<Result<Vec<_>, _>>()
35 .map_err(|e| {
36 io::Error::new(
37 io::ErrorKind::InvalidData,
38 format!("Invalid certificate: {e}"),
39 )
40 })?;
41
42 if certs.is_empty() {
43 return Err(io::Error::new(
44 io::ErrorKind::InvalidData,
45 "No certificates found",
46 ));
47 }
48
49 pem.seek(std::io::SeekFrom::Start(0))?;
50
51 let key = rustls_pemfile::private_key(&mut pem)
52 .map_err(|e| {
53 io::Error::new(
54 io::ErrorKind::InvalidData,
55 format!("Invalid private key: {e}"),
56 )
57 })?
58 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "No private key found"))?;
59
60 let mut config = rustls::ServerConfig::builder()
61 .with_no_client_auth()
62 .with_single_cert(certs, key)
63 .map_err(|e| {
64 io::Error::new(io::ErrorKind::InvalidData, format!("TLS config error: {e}"))
65 })?;
66
67 config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
69
70 let config = Arc::new(config);
71 let acceptor = TlsAcceptor::from(config.clone());
72 Ok(Self { config, acceptor })
73 }
74}
75
76pub enum Listener {
77 Tcp {
78 listener: Arc<TcpListener>,
79 tls_config: Option<TlsConfig>,
80 },
81 #[cfg(unix)]
82 Unix(UnixListener),
83}
84
85impl Listener {
86 pub async fn accept(
87 &mut self,
88 ) -> io::Result<(AsyncReadWriteBox, Option<std::net::SocketAddr>)> {
89 match self {
90 Listener::Tcp {
91 listener,
92 tls_config,
93 } => {
94 let (stream, addr) = listener.accept().await?;
95
96 let stream = if let Some(tls) = tls_config {
97 match tls.acceptor.accept(stream).await {
99 Ok(tls_stream) => Box::new(tls_stream) as AsyncReadWriteBox,
100 Err(e) => {
101 return Err(io::Error::new(
102 io::ErrorKind::ConnectionAborted,
103 format!("TLS error: {e}"),
104 ));
105 }
106 }
107 } else {
108 Box::new(stream)
110 };
111
112 Ok((stream, Some(addr)))
113 }
114 #[cfg(unix)]
115 Listener::Unix(listener) => {
116 let (stream, _) = listener.accept().await?;
117 Ok((Box::new(stream), None))
118 }
119 }
120 }
121
122 pub async fn bind(addr: &str, tls_config: Option<TlsConfig>) -> io::Result<Self> {
123 #[cfg(windows)]
124 {
125 let mut addr = addr.to_owned();
127 if addr.starts_with(':') {
128 addr = format!("127.0.0.1{addr}");
129 }
130 let listener = TcpListener::bind(addr).await?;
131 Ok(Listener::Tcp {
132 listener: Arc::new(listener),
133 tls_config,
134 })
135 }
136
137 #[cfg(unix)]
138 {
139 if addr.starts_with('/') || addr.starts_with('.') {
140 if tls_config.is_some() {
141 return Err(io::Error::new(
142 io::ErrorKind::InvalidInput,
143 "TLS is not supported with Unix domain sockets",
144 ));
145 }
146 let _ = std::fs::remove_file(addr);
147 let listener = UnixListener::bind(addr)?;
148 Ok(Listener::Unix(listener))
149 } else {
150 let mut addr = addr.to_owned();
151 if addr.starts_with(':') {
152 addr = format!("127.0.0.1{addr}");
153 }
154 let listener = TcpListener::bind(addr).await?;
155 Ok(Listener::Tcp {
156 listener: Arc::new(listener),
157 tls_config,
158 })
159 }
160 }
161 }
162}
163
164impl Clone for Listener {
165 fn clone(&self) -> Self {
166 match self {
167 Listener::Tcp {
168 listener,
169 tls_config,
170 } => Listener::Tcp {
171 listener: listener.clone(),
172 tls_config: tls_config.clone(),
173 },
174 #[cfg(unix)]
175 Listener::Unix(_) => {
176 panic!("Cannot clone a Unix listener")
177 }
178 }
179 }
180}
181
182impl Clone for TlsConfig {
183 fn clone(&self) -> Self {
184 TlsConfig {
185 config: self.config.clone(),
186 acceptor: TlsAcceptor::from(self.config.clone()),
187 }
188 }
189}
190
191impl std::fmt::Display for Listener {
192 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
193 match self {
194 Listener::Tcp {
195 listener,
196 tls_config,
197 } => {
198 let addr = listener.local_addr().unwrap();
199 let tls_suffix = if tls_config.is_some() { " (TLS)" } else { "" };
200 write!(f, "{}:{}{}", addr.ip(), addr.port(), tls_suffix)
201 }
202 #[cfg(unix)]
203 Listener::Unix(listener) => {
204 let addr = listener.local_addr().unwrap();
205 let path = addr.as_pathname().unwrap();
206 write!(f, "{}", path.display())
207 }
208 }
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpStream;

    /// Round-trips a payload over a plaintext listener bound to `addr`.
    ///
    /// A spawned task connects as the client. The server side of the accepted
    /// connection writes a fixed payload and is then dropped. The test asserts that
    /// the client reads back exactly that payload.
    async fn exercise_listener(addr: &str) {
        let mut server = Listener::bind(addr, None).await.unwrap();

        // Work out the concrete address to dial. For TCP this picks up the real
        // port when binding to port 0; for Unix it is the socket path.
        let dial_target = match &server {
            Listener::Tcp { listener, .. } => {
                let bound = listener.local_addr().unwrap();
                format!("{}:{}", bound.ip(), bound.port())
            }
            #[cfg(unix)]
            Listener::Unix(listener) => {
                let bound = listener.local_addr().unwrap();
                let path = bound.as_pathname().unwrap();
                path.to_string_lossy().to_string()
            }
        };

        let connector: tokio::task::JoinHandle<std::io::Result<AsyncReadWriteBox>> =
            tokio::spawn(async move {
                if !dial_target.starts_with('/') {
                    let tcp = TcpStream::connect(&dial_target).await?;
                    return Ok(Box::new(tcp) as AsyncReadWriteBox);
                }
                #[cfg(unix)]
                {
                    let unix = tokio::net::UnixStream::connect(&dial_target).await?;
                    Ok(Box::new(unix) as AsyncReadWriteBox)
                }
                #[cfg(not(unix))]
                {
                    panic!("Unix sockets not supported on this platform");
                }
            });

        let (mut accepted, _) = server.accept().await.unwrap();
        let payload = b"Hello from server!";
        accepted.write_all(payload).await.unwrap();
        // Closing the server side lets the client's read_to_end reach EOF.
        drop(accepted);

        let mut client = connector.await.unwrap().unwrap();
        let mut received = Vec::new();
        client.read_to_end(&mut received).await.unwrap();
        assert_eq!(payload.to_vec(), received);
    }

    #[tokio::test]
    async fn test_bind_tcp() {
        exercise_listener("127.0.0.1:0").await;
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_bind_unix() {
        // `dir` must outlive the listener: dropping it deletes the directory.
        let dir = tempfile::tempdir().unwrap();
        let socket_path = dir.path().join("test.sock");
        exercise_listener(socket_path.to_str().unwrap()).await;
    }
}