1use std::path::Path;
28
29#[cfg(feature = "tls")]
30use std::fs::File;
31#[cfg(feature = "tls")]
32use std::io::BufReader;
33#[cfg(feature = "tls")]
34use std::sync::Arc;
35
36#[derive(Debug, Clone)]
38pub struct TlsConfig {
39 pub cert_path: Option<String>,
41 pub key_path: Option<String>,
43 pub ca_path: Option<String>,
45 pub verify_peer: bool,
47 pub server_name: Option<String>,
49}
50
51impl Default for TlsConfig {
52 fn default() -> Self {
53 Self {
54 cert_path: None,
55 key_path: None,
56 ca_path: None,
57 verify_peer: true,
58 server_name: None,
59 }
60 }
61}
62
63impl TlsConfig {
64 pub fn server(cert_path: impl AsRef<Path>, key_path: impl AsRef<Path>) -> Self {
66 Self {
67 cert_path: Some(cert_path.as_ref().to_string_lossy().into_owned()),
68 key_path: Some(key_path.as_ref().to_string_lossy().into_owned()),
69 ca_path: None,
70 verify_peer: false,
71 server_name: None,
72 }
73 }
74
75 pub fn client(ca_path: Option<impl AsRef<Path>>) -> Self {
77 Self {
78 cert_path: None,
79 key_path: None,
80 ca_path: ca_path.map(|p| p.as_ref().to_string_lossy().into_owned()),
81 verify_peer: true,
82 server_name: None,
83 }
84 }
85
86 pub fn mtls(
88 cert_path: impl AsRef<Path>,
89 key_path: impl AsRef<Path>,
90 ca_path: impl AsRef<Path>,
91 ) -> Self {
92 Self {
93 cert_path: Some(cert_path.as_ref().to_string_lossy().into_owned()),
94 key_path: Some(key_path.as_ref().to_string_lossy().into_owned()),
95 ca_path: Some(ca_path.as_ref().to_string_lossy().into_owned()),
96 verify_peer: true,
97 server_name: None,
98 }
99 }
100
101 pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
103 self.server_name = Some(name.into());
104 self
105 }
106
107 pub fn with_insecure(mut self) -> Self {
109 self.verify_peer = false;
110 self
111 }
112}
113
/// Errors produced while loading TLS material or performing a TLS handshake.
#[derive(Debug)]
pub enum TlsError {
    /// The certificate file could not be used (missing, unreadable, or empty).
    CertificateError(String),
    /// The private key file could not be used (missing, unparsable, or empty).
    KeyError(String),
    /// The CA bundle could not be used or a CA certificate was rejected.
    CaError(String),
    /// The TLS handshake with the peer failed.
    HandshakeError(String),
    /// An underlying I/O operation failed.
    IoError(std::io::Error),
    /// The configuration is incomplete or could not be turned into a TLS config.
    ConfigError(String),
}

impl std::fmt::Display for TlsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TlsError::CertificateError(msg) => write!(f, "Certificate error: {}", msg),
            TlsError::KeyError(msg) => write!(f, "Key error: {}", msg),
            TlsError::CaError(msg) => write!(f, "CA error: {}", msg),
            TlsError::HandshakeError(msg) => write!(f, "Handshake error: {}", msg),
            TlsError::IoError(e) => write!(f, "I/O error: {}", e),
            TlsError::ConfigError(msg) => write!(f, "Config error: {}", msg),
        }
    }
}

impl std::error::Error for TlsError {
    /// Exposes the wrapped `io::Error` so error-chain reporters can walk to it.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            TlsError::IoError(e) => Some(e),
            // String-only variants carry no underlying error object.
            TlsError::CertificateError(_)
            | TlsError::KeyError(_)
            | TlsError::CaError(_)
            | TlsError::HandshakeError(_)
            | TlsError::ConfigError(_) => None,
        }
    }
}

impl From<std::io::Error> for TlsError {
    fn from(e: std::io::Error) -> Self {
        TlsError::IoError(e)
    }
}

/// Result alias used throughout the TLS module.
pub type TlsResult<T> = std::result::Result<T, TlsError>;
154
155#[derive(Clone)]
157pub struct TlsAcceptor {
158 config: TlsConfig,
159 #[cfg(feature = "tls")]
160 inner: Arc<tokio_rustls::TlsAcceptor>,
161}
162
163impl TlsAcceptor {
164 pub fn new(config: TlsConfig) -> TlsResult<Self> {
166 if config.cert_path.is_none() {
167 return Err(TlsError::ConfigError(
168 "Server TLS config requires certificate path".into(),
169 ));
170 }
171 if config.key_path.is_none() {
172 return Err(TlsError::ConfigError(
173 "Server TLS config requires key path".into(),
174 ));
175 }
176
177 #[cfg(feature = "tls")]
178 {
179 let inner = Self::build_acceptor(&config)?;
180 Ok(Self {
181 config,
182 inner: Arc::new(inner),
183 })
184 }
185
186 #[cfg(not(feature = "tls"))]
187 {
188 Ok(Self { config })
189 }
190 }
191
192 #[cfg(feature = "tls")]
193 fn build_acceptor(config: &TlsConfig) -> TlsResult<tokio_rustls::TlsAcceptor> {
194 use rustls::pki_types::CertificateDer;
195
196 let cert_path = config
197 .cert_path
198 .as_ref()
199 .ok_or_else(|| TlsError::ConfigError("Certificate path required".into()))?;
200 let key_path = config
201 .key_path
202 .as_ref()
203 .ok_or_else(|| TlsError::ConfigError("Key path required".into()))?;
204
205 let cert_file = File::open(cert_path)
207 .map_err(|e| TlsError::CertificateError(format!("{}: {}", cert_path, e)))?;
208 let mut cert_reader = BufReader::new(cert_file);
209 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
210 .filter_map(|r| r.ok())
211 .collect();
212
213 if certs.is_empty() {
214 return Err(TlsError::CertificateError("No certificates found".into()));
215 }
216
217 let key_file =
219 File::open(key_path).map_err(|e| TlsError::KeyError(format!("{}: {}", key_path, e)))?;
220 let mut key_reader = BufReader::new(key_file);
221 let key = rustls_pemfile::private_key(&mut key_reader)
222 .map_err(|e| TlsError::KeyError(format!("Failed to parse key: {}", e)))?
223 .ok_or_else(|| TlsError::KeyError("No private key found".into()))?;
224
225 let server_config = rustls::ServerConfig::builder()
227 .with_no_client_auth()
228 .with_single_cert(certs, key)
229 .map_err(|e| TlsError::ConfigError(format!("Failed to build config: {}", e)))?;
230
231 Ok(tokio_rustls::TlsAcceptor::from(Arc::new(server_config)))
232 }
233
234 pub fn config(&self) -> &TlsConfig {
236 &self.config
237 }
238
239 #[cfg(feature = "tls")]
241 pub async fn accept<S>(&self, stream: S) -> TlsResult<tokio_rustls::server::TlsStream<S>>
242 where
243 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
244 {
245 self.inner
246 .accept(stream)
247 .await
248 .map_err(|e| TlsError::HandshakeError(e.to_string()))
249 }
250}
251
252#[derive(Clone)]
254pub struct TlsConnector {
255 config: TlsConfig,
256 #[cfg(feature = "tls")]
257 inner: Arc<tokio_rustls::TlsConnector>,
258}
259
260impl TlsConnector {
261 pub fn new(config: TlsConfig) -> TlsResult<Self> {
263 #[cfg(feature = "tls")]
264 {
265 let inner = Self::build_connector(&config)?;
266 Ok(Self {
267 config,
268 inner: Arc::new(inner),
269 })
270 }
271
272 #[cfg(not(feature = "tls"))]
273 {
274 Ok(Self { config })
275 }
276 }
277
278 #[cfg(feature = "tls")]
279 fn build_connector(config: &TlsConfig) -> TlsResult<tokio_rustls::TlsConnector> {
280 use rustls::RootCertStore;
281 use rustls::pki_types::CertificateDer;
282
283 let mut root_store = RootCertStore::empty();
284
285 if let Some(ref ca_path) = config.ca_path {
287 let ca_file = File::open(ca_path)
288 .map_err(|e| TlsError::CaError(format!("{}: {}", ca_path, e)))?;
289 let mut ca_reader = BufReader::new(ca_file);
290 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_reader)
291 .filter_map(|r| r.ok())
292 .collect();
293
294 for cert in certs {
295 root_store
296 .add(cert)
297 .map_err(|e| TlsError::CaError(format!("Failed to add CA: {}", e)))?;
298 }
299 } else {
300 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
302 }
303
304 let client_config = rustls::ClientConfig::builder()
306 .with_root_certificates(root_store)
307 .with_no_client_auth();
308
309 Ok(tokio_rustls::TlsConnector::from(Arc::new(client_config)))
310 }
311
312 pub fn config(&self) -> &TlsConfig {
314 &self.config
315 }
316
317 #[cfg(feature = "tls")]
319 pub async fn connect<S>(
320 &self,
321 server_name: &str,
322 stream: S,
323 ) -> TlsResult<tokio_rustls::client::TlsStream<S>>
324 where
325 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
326 {
327 use rustls::pki_types::ServerName;
328
329 let name = ServerName::try_from(server_name.to_string())
330 .map_err(|e| TlsError::ConfigError(format!("Invalid server name: {}", e)))?;
331
332 self.inner
333 .connect(name, stream)
334 .await
335 .map_err(|e| TlsError::HandshakeError(e.to_string()))
336 }
337}
338
#[cfg(test)]
// unwrap keeps assertions terse; a panic is exactly the failure signal we want in tests.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_tls_config_default() {
        let config = TlsConfig::default();
        assert!(config.cert_path.is_none());
        assert!(config.key_path.is_none());
        assert!(config.ca_path.is_none());
        assert!(config.server_name.is_none());
        assert!(config.verify_peer);
    }

    #[test]
    fn test_tls_config_server() {
        let config = TlsConfig::server("cert.pem", "key.pem");
        assert_eq!(config.cert_path, Some("cert.pem".to_string()));
        assert_eq!(config.key_path, Some("key.pem".to_string()));
        assert!(config.ca_path.is_none());
        assert!(!config.verify_peer);
    }

    #[test]
    fn test_tls_config_client() {
        let config = TlsConfig::client(Some("ca.pem"));
        assert!(config.cert_path.is_none());
        assert!(config.key_path.is_none());
        assert_eq!(config.ca_path, Some("ca.pem".to_string()));
        assert!(config.verify_peer);
    }

    #[test]
    fn test_tls_config_mtls() {
        let config = TlsConfig::mtls("cert.pem", "key.pem", "ca.pem");
        assert_eq!(config.cert_path, Some("cert.pem".to_string()));
        assert_eq!(config.key_path, Some("key.pem".to_string()));
        assert_eq!(config.ca_path, Some("ca.pem".to_string()));
        assert!(config.verify_peer);
    }

    #[test]
    fn test_tls_config_builders() {
        let config = TlsConfig::client(None::<&str>)
            .with_server_name("example.com")
            .with_insecure();
        assert_eq!(config.server_name, Some("example.com".to_string()));
        assert!(!config.verify_peer);
    }

    #[test]
    fn test_tls_error_display_and_conversion() {
        let err = TlsError::ConfigError("bad".into());
        assert_eq!(err.to_string(), "Config error: bad");

        let err: TlsError = std::io::Error::new(std::io::ErrorKind::NotFound, "boom").into();
        assert!(matches!(err, TlsError::IoError(_)));
        assert_eq!(err.to_string(), "I/O error: boom");
    }

    #[test]
    fn test_tls_acceptor_requires_cert() {
        let result = TlsAcceptor::new(TlsConfig::default());
        assert!(matches!(result, Err(TlsError::ConfigError(_))));
    }

    #[test]
    fn test_tls_acceptor_requires_key() {
        let config = TlsConfig {
            cert_path: Some("cert.pem".into()),
            ..Default::default()
        };
        let result = TlsAcceptor::new(config);
        assert!(matches!(result, Err(TlsError::ConfigError(_))));
    }

    #[test]
    fn test_tls_connector_default() {
        let connector = TlsConnector::new(TlsConfig::client(None::<&str>)).unwrap();
        assert!(connector.config().verify_peer);
        assert!(connector.config().ca_path.is_none());
    }

    #[cfg(feature = "tls")]
    #[test]
    fn test_tls_acceptor_missing_cert_file() {
        let config = TlsConfig::server("does/not/exist/cert.pem", "does/not/exist/key.pem");
        assert!(matches!(
            TlsAcceptor::new(config),
            Err(TlsError::CertificateError(_))
        ));
    }

    #[cfg(feature = "tls")]
    #[test]
    fn test_tls_connector_missing_ca_file() {
        let config = TlsConfig::client(Some("does/not/exist/ca.pem"));
        assert!(matches!(
            TlsConnector::new(config),
            Err(TlsError::CaError(_))
        ));
    }
}