1use std::path::Path;
28
29#[cfg(feature = "tls")]
30use std::sync::Arc;
31
/// TLS settings consumed by [`TlsAcceptor`] and [`TlsConnector`].
///
/// Paths are kept as `String`s; the constructors convert [`Path`] arguments
/// with [`Path::to_string_lossy`], so non-UTF-8 components are replaced.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsConfig {
    /// PEM file containing this endpoint's certificate chain.
    pub cert_path: Option<String>,
    /// PEM file containing the private key that matches `cert_path`.
    pub key_path: Option<String>,
    /// PEM bundle of trusted CA certificates.
    pub ca_path: Option<String>,
    /// Whether the peer's certificate is meant to be verified.
    ///
    /// `true` by default; cleared by [`TlsConfig::server`] and
    /// [`TlsConfig::with_insecure`].
    pub verify_peer: bool,
    /// Optional server name, set via [`TlsConfig::with_server_name`].
    pub server_name: Option<String>,
}

impl Default for TlsConfig {
    /// No certificate material, no server name, and `verify_peer == true`.
    fn default() -> Self {
        Self {
            cert_path: None,
            key_path: None,
            ca_path: None,
            verify_peer: true,
            server_name: None,
        }
    }
}

impl TlsConfig {
    /// Server-side configuration from a certificate chain and its private key.
    ///
    /// `verify_peer` is `false` and no CA bundle is set.
    #[must_use]
    pub fn server(cert_path: impl AsRef<Path>, key_path: impl AsRef<Path>) -> Self {
        Self {
            cert_path: Some(Self::path_to_string(cert_path.as_ref())),
            key_path: Some(Self::path_to_string(key_path.as_ref())),
            ca_path: None,
            verify_peer: false,
            server_name: None,
        }
    }

    /// Client-side configuration trusting the CA bundle at `ca_path`, if any.
    ///
    /// `verify_peer` is `true`; no client certificate or key is configured.
    #[must_use]
    pub fn client(ca_path: Option<impl AsRef<Path>>) -> Self {
        Self {
            ca_path: ca_path.map(|p| Self::path_to_string(p.as_ref())),
            ..Self::default()
        }
    }

    /// Mutual-TLS configuration: this endpoint's certificate chain and key,
    /// plus the CA bundle the peer is checked against. `verify_peer` is `true`.
    #[must_use]
    pub fn mtls(
        cert_path: impl AsRef<Path>,
        key_path: impl AsRef<Path>,
        ca_path: impl AsRef<Path>,
    ) -> Self {
        Self {
            cert_path: Some(Self::path_to_string(cert_path.as_ref())),
            key_path: Some(Self::path_to_string(key_path.as_ref())),
            ca_path: Some(Self::path_to_string(ca_path.as_ref())),
            ..Self::default()
        }
    }

    /// Returns this configuration with `server_name` set to `name`.
    #[must_use]
    pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
        self.server_name = Some(name.into());
        self
    }

    /// Returns this configuration with `verify_peer` cleared.
    #[must_use]
    pub fn with_insecure(mut self) -> Self {
        self.verify_peer = false;
        self
    }

    /// Lossily converts `path` into the owned string form stored in the fields.
    fn path_to_string(path: &Path) -> String {
        path.to_string_lossy().into_owned()
    }
}
109
/// Errors produced while building a TLS configuration or performing a handshake.
#[derive(Debug)]
pub enum TlsError {
    /// A certificate file could not be read or parsed, or contained no certificate.
    CertificateError(String),
    /// The private key file could not be read or parsed.
    KeyError(String),
    /// The CA bundle could not be read or parsed, or a CA certificate was rejected.
    CaError(String),
    /// The TLS handshake with the peer failed.
    HandshakeError(String),
    /// An underlying I/O operation failed; exposed via [`std::error::Error::source`].
    IoError(std::io::Error),
    /// The configuration is incomplete or invalid (e.g. an invalid server name).
    ConfigError(String),
}

impl std::fmt::Display for TlsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TlsError::CertificateError(msg) => write!(f, "Certificate error: {}", msg),
            TlsError::KeyError(msg) => write!(f, "Key error: {}", msg),
            TlsError::CaError(msg) => write!(f, "CA error: {}", msg),
            TlsError::HandshakeError(msg) => write!(f, "Handshake error: {}", msg),
            TlsError::IoError(e) => write!(f, "I/O error: {}", e),
            TlsError::ConfigError(msg) => write!(f, "Config error: {}", msg),
        }
    }
}

impl std::error::Error for TlsError {
    /// Returns the wrapped `io::Error` so error-chain reporters can walk to it.
    /// The string-carrying variants have no underlying cause object.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            TlsError::IoError(e) => Some(e),
            TlsError::CertificateError(_)
            | TlsError::KeyError(_)
            | TlsError::CaError(_)
            | TlsError::HandshakeError(_)
            | TlsError::ConfigError(_) => None,
        }
    }
}

impl From<std::io::Error> for TlsError {
    fn from(e: std::io::Error) -> Self {
        TlsError::IoError(e)
    }
}

/// Convenience alias for results whose error type is [`TlsError`].
pub type TlsResult<T> = std::result::Result<T, TlsError>;
150
151#[derive(Clone)]
153pub struct TlsAcceptor {
154 config: TlsConfig,
155 #[cfg(feature = "tls")]
156 inner: Arc<tokio_rustls::TlsAcceptor>,
157}
158
159impl TlsAcceptor {
160 pub fn new(config: TlsConfig) -> TlsResult<Self> {
162 if config.cert_path.is_none() {
163 return Err(TlsError::ConfigError(
164 "Server TLS config requires certificate path".into(),
165 ));
166 }
167 if config.key_path.is_none() {
168 return Err(TlsError::ConfigError(
169 "Server TLS config requires key path".into(),
170 ));
171 }
172
173 #[cfg(feature = "tls")]
174 {
175 let inner = Self::build_acceptor(&config)?;
176 Ok(Self {
177 config,
178 inner: Arc::new(inner),
179 })
180 }
181
182 #[cfg(not(feature = "tls"))]
183 {
184 Ok(Self { config })
185 }
186 }
187
188 #[cfg(feature = "tls")]
189 fn build_acceptor(config: &TlsConfig) -> TlsResult<tokio_rustls::TlsAcceptor> {
190 use rustls_pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};
191
192 let cert_path = config
193 .cert_path
194 .as_ref()
195 .ok_or_else(|| TlsError::ConfigError("Certificate path required".into()))?;
196 let key_path = config
197 .key_path
198 .as_ref()
199 .ok_or_else(|| TlsError::ConfigError("Key path required".into()))?;
200
201 let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(cert_path)
203 .map_err(|e| TlsError::CertificateError(format!("{}: {}", cert_path, e)))?
204 .filter_map(|r| r.ok())
205 .collect();
206
207 if certs.is_empty() {
208 return Err(TlsError::CertificateError("No certificates found".into()));
209 }
210
211 let key = PrivateKeyDer::from_pem_file(key_path)
213 .map_err(|e| TlsError::KeyError(format!("{}: {}", key_path, e)))?;
214
215 let server_config = rustls::ServerConfig::builder()
217 .with_no_client_auth()
218 .with_single_cert(certs, key)
219 .map_err(|e| TlsError::ConfigError(format!("Failed to build config: {}", e)))?;
220
221 Ok(tokio_rustls::TlsAcceptor::from(Arc::new(server_config)))
222 }
223
224 pub fn config(&self) -> &TlsConfig {
226 &self.config
227 }
228
229 #[cfg(feature = "tls")]
231 pub async fn accept<S>(&self, stream: S) -> TlsResult<tokio_rustls::server::TlsStream<S>>
232 where
233 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
234 {
235 self.inner
236 .accept(stream)
237 .await
238 .map_err(|e| TlsError::HandshakeError(e.to_string()))
239 }
240}
241
242#[derive(Clone)]
244pub struct TlsConnector {
245 config: TlsConfig,
246 #[cfg(feature = "tls")]
247 inner: Arc<tokio_rustls::TlsConnector>,
248}
249
250impl TlsConnector {
251 pub fn new(config: TlsConfig) -> TlsResult<Self> {
253 #[cfg(feature = "tls")]
254 {
255 let inner = Self::build_connector(&config)?;
256 Ok(Self {
257 config,
258 inner: Arc::new(inner),
259 })
260 }
261
262 #[cfg(not(feature = "tls"))]
263 {
264 Ok(Self { config })
265 }
266 }
267
268 #[cfg(feature = "tls")]
269 fn build_connector(config: &TlsConfig) -> TlsResult<tokio_rustls::TlsConnector> {
270 use rustls::RootCertStore;
271 use rustls_pki_types::{CertificateDer, pem::PemObject};
272
273 let mut root_store = RootCertStore::empty();
274
275 if let Some(ref ca_path) = config.ca_path {
277 let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(ca_path)
278 .map_err(|e| TlsError::CaError(format!("{}: {}", ca_path, e)))?
279 .filter_map(|r| r.ok())
280 .collect();
281
282 for cert in certs {
283 root_store
284 .add(cert)
285 .map_err(|e| TlsError::CaError(format!("Failed to add CA: {}", e)))?;
286 }
287 } else {
288 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
290 }
291
292 let client_config = rustls::ClientConfig::builder()
294 .with_root_certificates(root_store)
295 .with_no_client_auth();
296
297 Ok(tokio_rustls::TlsConnector::from(Arc::new(client_config)))
298 }
299
300 pub fn config(&self) -> &TlsConfig {
302 &self.config
303 }
304
305 #[cfg(feature = "tls")]
307 pub async fn connect<S>(
308 &self,
309 server_name: &str,
310 stream: S,
311 ) -> TlsResult<tokio_rustls::client::TlsStream<S>>
312 where
313 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
314 {
315 use rustls::pki_types::ServerName;
316
317 let name = ServerName::try_from(server_name.to_string())
318 .map_err(|e| TlsError::ConfigError(format!("Invalid server name: {}", e)))?;
319
320 self.inner
321 .connect(name, stream)
322 .await
323 .map_err(|e| TlsError::HandshakeError(e.to_string()))
324 }
325}
326
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// `TlsConfig::server` records both paths and disables peer verification.
    #[test]
    fn test_tls_config_server() {
        let cfg = TlsConfig::server("cert.pem", "key.pem");

        assert_eq!(cfg.cert_path.as_deref(), Some("cert.pem"));
        assert_eq!(cfg.key_path.as_deref(), Some("key.pem"));
        assert!(!cfg.verify_peer);
    }

    /// `TlsConfig::client` records only the CA path and keeps verification on.
    #[test]
    fn test_tls_config_client() {
        let cfg = TlsConfig::client(Some("ca.pem"));

        assert_eq!(cfg.cert_path, None);
        assert_eq!(cfg.ca_path.as_deref(), Some("ca.pem"));
        assert!(cfg.verify_peer);
    }

    /// `TlsConfig::mtls` records all three paths and keeps verification on.
    #[test]
    fn test_tls_config_mtls() {
        let cfg = TlsConfig::mtls("cert.pem", "key.pem", "ca.pem");

        let expected = [
            (&cfg.cert_path, "cert.pem"),
            (&cfg.key_path, "key.pem"),
            (&cfg.ca_path, "ca.pem"),
        ];
        for (actual, want) in expected {
            assert_eq!(actual.as_deref(), Some(want));
        }
        assert!(cfg.verify_peer);
    }

    /// An acceptor cannot be built without a certificate path.
    #[test]
    fn test_tls_acceptor_requires_cert() {
        assert!(TlsAcceptor::new(TlsConfig::default()).is_err());
    }

    /// An acceptor cannot be built with a certificate path but no key path.
    #[test]
    fn test_tls_acceptor_requires_key() {
        let cfg = TlsConfig {
            cert_path: Some(String::from("cert.pem")),
            ..TlsConfig::default()
        };

        assert!(TlsAcceptor::new(cfg).is_err());
    }

    /// A client config with no CA path builds a connector.
    #[test]
    fn test_tls_connector_default() {
        let connector = TlsConnector::new(TlsConfig::client(None::<&str>));

        assert!(connector.is_ok());
    }
}