1use std::io::{self, Seek};
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use rustls::ServerConfig;
6use tokio::io::{AsyncRead, AsyncWrite};
7use tokio::net::TcpListener;
8#[cfg(unix)]
9use tokio::net::UnixListener;
10use tokio_rustls::TlsAcceptor;
11
12pub trait AsyncReadWrite: AsyncRead + AsyncWrite {}
13
14impl<T: AsyncRead + AsyncWrite> AsyncReadWrite for T {}
15
16pub type AsyncReadWriteBox = Box<dyn AsyncReadWrite + Unpin + Send>;
17
18pub struct TlsConfig {
19 pub config: Arc<ServerConfig>,
20 acceptor: TlsAcceptor,
21}
22
23impl TlsConfig {
24 pub fn from_pem(pem_path: PathBuf) -> io::Result<Self> {
25 let pem = std::fs::File::open(&pem_path).map_err(|e| {
26 io::Error::new(
27 io::ErrorKind::NotFound,
28 format!("Failed to open PEM file {}: {}", pem_path.display(), e),
29 )
30 })?;
31 let mut pem = std::io::BufReader::new(pem);
32
33 let certs = rustls_pemfile::certs(&mut pem)
34 .collect::<Result<Vec<_>, _>>()
35 .map_err(|e| {
36 io::Error::new(
37 io::ErrorKind::InvalidData,
38 format!("Invalid certificate: {e}"),
39 )
40 })?;
41
42 if certs.is_empty() {
43 return Err(io::Error::new(
44 io::ErrorKind::InvalidData,
45 "No certificates found",
46 ));
47 }
48
49 pem.seek(std::io::SeekFrom::Start(0))?;
50
51 let key = rustls_pemfile::private_key(&mut pem)
52 .map_err(|e| {
53 io::Error::new(
54 io::ErrorKind::InvalidData,
55 format!("Invalid private key: {e}"),
56 )
57 })?
58 .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "No private key found"))?;
59
60 let config = rustls::ServerConfig::builder()
61 .with_no_client_auth()
62 .with_single_cert(certs, key)
63 .map_err(|e| {
64 io::Error::new(io::ErrorKind::InvalidData, format!("TLS config error: {e}"))
65 })?;
66
67 let config = Arc::new(config);
68 let acceptor = TlsAcceptor::from(config.clone());
69 Ok(Self { config, acceptor })
70 }
71}
72
73pub enum Listener {
74 Tcp {
75 listener: Arc<TcpListener>,
76 tls_config: Option<TlsConfig>,
77 },
78 #[cfg(unix)]
79 Unix(UnixListener),
80}
81
82impl Listener {
83 pub async fn accept(
84 &mut self,
85 ) -> io::Result<(AsyncReadWriteBox, Option<std::net::SocketAddr>)> {
86 match self {
87 Listener::Tcp {
88 listener,
89 tls_config,
90 } => {
91 let (stream, addr) = listener.accept().await?;
92
93 let stream = if let Some(tls) = tls_config {
94 match tls.acceptor.accept(stream).await {
96 Ok(tls_stream) => Box::new(tls_stream) as AsyncReadWriteBox,
97 Err(e) => {
98 return Err(io::Error::new(
99 io::ErrorKind::ConnectionAborted,
100 format!("TLS error: {e}"),
101 ));
102 }
103 }
104 } else {
105 Box::new(stream)
107 };
108
109 Ok((stream, Some(addr)))
110 }
111 #[cfg(unix)]
112 Listener::Unix(listener) => {
113 let (stream, _) = listener.accept().await?;
114 Ok((Box::new(stream), None))
115 }
116 }
117 }
118
119 pub async fn bind(addr: &str, tls_config: Option<TlsConfig>) -> io::Result<Self> {
120 #[cfg(windows)]
121 {
122 let mut addr = addr.to_owned();
124 if addr.starts_with(':') {
125 addr = format!("127.0.0.1{addr}");
126 }
127 let listener = TcpListener::bind(addr).await?;
128 Ok(Listener::Tcp {
129 listener: Arc::new(listener),
130 tls_config,
131 })
132 }
133
134 #[cfg(unix)]
135 {
136 if addr.starts_with('/') || addr.starts_with('.') {
137 if tls_config.is_some() {
138 return Err(io::Error::new(
139 io::ErrorKind::InvalidInput,
140 "TLS is not supported with Unix domain sockets",
141 ));
142 }
143 let _ = std::fs::remove_file(addr);
144 let listener = UnixListener::bind(addr)?;
145 Ok(Listener::Unix(listener))
146 } else {
147 let mut addr = addr.to_owned();
148 if addr.starts_with(':') {
149 addr = format!("127.0.0.1{addr}");
150 }
151 let listener = TcpListener::bind(addr).await?;
152 Ok(Listener::Tcp {
153 listener: Arc::new(listener),
154 tls_config,
155 })
156 }
157 }
158 }
159}
160
161impl Clone for Listener {
162 fn clone(&self) -> Self {
163 match self {
164 Listener::Tcp {
165 listener,
166 tls_config,
167 } => Listener::Tcp {
168 listener: listener.clone(),
169 tls_config: tls_config.clone(),
170 },
171 #[cfg(unix)]
172 Listener::Unix(_) => {
173 panic!("Cannot clone a Unix listener")
174 }
175 }
176 }
177}
178
179impl Clone for TlsConfig {
180 fn clone(&self) -> Self {
181 TlsConfig {
182 config: self.config.clone(),
183 acceptor: TlsAcceptor::from(self.config.clone()),
184 }
185 }
186}
187
188impl std::fmt::Display for Listener {
189 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
190 match self {
191 Listener::Tcp {
192 listener,
193 tls_config,
194 } => {
195 let addr = listener.local_addr().unwrap();
196 let tls_suffix = if tls_config.is_some() { " (TLS)" } else { "" };
197 write!(f, "{}:{}{}", addr.ip(), addr.port(), tls_suffix)
198 }
199 #[cfg(unix)]
200 Listener::Unix(listener) => {
201 let addr = listener.local_addr().unwrap();
202 let path = addr.as_pathname().unwrap();
203 write!(f, "{}", path.display())
204 }
205 }
206 }
207}
208
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::AsyncReadExt;
    use tokio::io::AsyncWriteExt;
    use tokio::net::TcpStream;

    /// Binds `addr` with no TLS and connects a client to it. The server side
    /// writes a greeting and closes; the test asserts the client reads
    /// exactly that greeting.
    async fn exercise_listener(addr: &str) {
        let mut listener = Listener::bind(addr, None).await.unwrap();
        let listener_addr = match &listener {
            // `SocketAddr`'s Display brackets IPv6 hosts, so the result is
            // always a string `TcpStream::connect` can parse.
            Listener::Tcp { listener, .. } => listener.local_addr().unwrap().to_string(),
            #[cfg(unix)]
            Listener::Unix(listener) => {
                let addr = listener.local_addr().unwrap();
                addr.as_pathname().unwrap().to_string_lossy().into_owned()
            }
        };

        let client_task: tokio::task::JoinHandle<io::Result<AsyncReadWriteBox>> =
            tokio::spawn(async move {
                if listener_addr.starts_with('/') {
                    #[cfg(unix)]
                    {
                        use tokio::net::UnixStream;
                        let stream = UnixStream::connect(&listener_addr).await?;
                        Ok(Box::new(stream) as AsyncReadWriteBox)
                    }
                    #[cfg(not(unix))]
                    {
                        panic!("Unix sockets not supported on this platform");
                    }
                } else {
                    let stream = TcpStream::connect(&listener_addr).await?;
                    Ok(Box::new(stream) as AsyncReadWriteBox)
                }
            });

        let (mut serve, _) = listener.accept().await.unwrap();
        let want = b"Hello from server!";
        serve.write_all(want).await.unwrap();
        drop(serve);

        let mut client = client_task.await.unwrap().unwrap();
        let mut got = Vec::new();
        client.read_to_end(&mut got).await.unwrap();
        assert_eq!(want.to_vec(), got);
    }

    #[tokio::test]
    async fn test_bind_tcp() {
        exercise_listener("127.0.0.1:0").await;
    }

    /// A leading `:` is shorthand for a loopback address.
    #[tokio::test]
    async fn test_bind_tcp_port_shorthand() {
        exercise_listener(":0").await;
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn test_bind_unix() {
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("test.sock");
        let path = path.to_str().unwrap();
        exercise_listener(path).await;
    }
}