1use std::path::Path;
28
29#[cfg(feature = "tls")]
30use std::sync::Arc;
31
/// Certificate paths and verification options used to build a [`TlsAcceptor`] or a
/// [`TlsConnector`].
///
/// Paths are stored as `String`s; the constructors convert with
/// `Path::to_string_lossy`, so non-UTF-8 path components are replaced lossily.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// PEM file containing the certificate chain to present to the peer.
    pub cert_path: Option<String>,
    /// PEM file containing the private key that matches `cert_path`.
    pub key_path: Option<String>,
    /// PEM bundle of trusted CA certificates.
    pub ca_path: Option<String>,
    /// Whether the peer's certificate should be verified.
    ///
    /// NOTE(review): `TlsConnector` in this file does not read this flag; confirm it has
    /// the intended effect for client configurations.
    pub verify_peer: bool,
    /// Server name recorded by [`TlsConfig::with_server_name`].
    ///
    /// NOTE(review): `TlsConnector::connect` takes the server name as an argument and does
    /// not read this field; confirm the intended use.
    pub server_name: Option<String>,
}

impl Default for TlsConfig {
    /// An empty configuration: no certificate, key, CA or server name, with peer
    /// verification enabled.
    fn default() -> Self {
        Self {
            cert_path: None,
            key_path: None,
            ca_path: None,
            verify_peer: true,
            server_name: None,
        }
    }
}

impl TlsConfig {
    /// Converts a path into the owned string form stored in this config.
    fn path_string(path: impl AsRef<Path>) -> String {
        path.as_ref().to_string_lossy().into_owned()
    }

    /// Server configuration presenting the certificate chain at `cert_path` with the
    /// private key at `key_path`. Peer verification is disabled and no CA is set.
    #[must_use]
    pub fn server(cert_path: impl AsRef<Path>, key_path: impl AsRef<Path>) -> Self {
        Self {
            cert_path: Some(Self::path_string(cert_path)),
            key_path: Some(Self::path_string(key_path)),
            verify_peer: false,
            ..Self::default()
        }
    }

    /// Client configuration trusting the CA bundle at `ca_path`.
    ///
    /// Pass `None::<&str>` to leave `ca_path` unset; `TlsConnector` then falls back to
    /// its built-in root set. Peer verification is enabled.
    #[must_use]
    pub fn client(ca_path: Option<impl AsRef<Path>>) -> Self {
        Self {
            ca_path: ca_path.map(Self::path_string),
            verify_peer: true,
            ..Self::default()
        }
    }

    /// Mutual-TLS configuration: presents the chain at `cert_path` with the key at
    /// `key_path`, and verifies the peer against the CA bundle at `ca_path`.
    #[must_use]
    pub fn mtls(
        cert_path: impl AsRef<Path>,
        key_path: impl AsRef<Path>,
        ca_path: impl AsRef<Path>,
    ) -> Self {
        Self {
            cert_path: Some(Self::path_string(cert_path)),
            key_path: Some(Self::path_string(key_path)),
            ca_path: Some(Self::path_string(ca_path)),
            verify_peer: true,
            server_name: None,
        }
    }

    /// Returns this configuration with `server_name` set to `name`.
    ///
    /// This consumes `self`; the returned value must be used or the change is lost.
    #[must_use]
    pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
        self.server_name = Some(name.into());
        self
    }

    /// Returns this configuration with `verify_peer` set to `false`.
    ///
    /// This consumes `self`; the returned value must be used or the change is lost.
    #[must_use]
    pub fn with_insecure(mut self) -> Self {
        self.verify_peer = false;
        self
    }
}
109
/// Errors produced while configuring TLS or performing a handshake.
#[derive(Debug)]
pub enum TlsError {
    /// A certificate file could not be read or parsed, or contained no certificates.
    CertificateError(String),
    /// The private key file could not be read or parsed.
    KeyError(String),
    /// The CA bundle could not be read or parsed, or a CA could not be trusted.
    CaError(String),
    /// The TLS handshake with the peer failed.
    HandshakeError(String),
    /// An underlying I/O operation failed.
    IoError(std::io::Error),
    /// The configuration is incomplete, invalid, or was rejected by the TLS backend.
    ConfigError(String),
}

impl std::fmt::Display for TlsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TlsError::CertificateError(msg) => write!(f, "Certificate error: {}", msg),
            TlsError::KeyError(msg) => write!(f, "Key error: {}", msg),
            TlsError::CaError(msg) => write!(f, "CA error: {}", msg),
            TlsError::HandshakeError(msg) => write!(f, "Handshake error: {}", msg),
            TlsError::IoError(e) => write!(f, "I/O error: {}", e),
            TlsError::ConfigError(msg) => write!(f, "Config error: {}", msg),
        }
    }
}

impl std::error::Error for TlsError {
    /// Exposes the wrapped `std::io::Error` so error-chain walkers can reach it; the
    /// string-carrying variants have no underlying source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            TlsError::IoError(e) => Some(e),
            TlsError::CertificateError(_)
            | TlsError::KeyError(_)
            | TlsError::CaError(_)
            | TlsError::HandshakeError(_)
            | TlsError::ConfigError(_) => None,
        }
    }
}

impl From<std::io::Error> for TlsError {
    /// Wraps an I/O error so it can be propagated with `?`.
    fn from(e: std::io::Error) -> Self {
        TlsError::IoError(e)
    }
}

/// Result alias used throughout the TLS module.
pub type TlsResult<T> = std::result::Result<T, TlsError>;
150
151#[derive(Clone)]
153pub struct TlsAcceptor {
154 config: TlsConfig,
155 #[cfg(feature = "tls")]
156 inner: Arc<tokio_rustls::TlsAcceptor>,
157}
158
159impl TlsAcceptor {
160 pub fn new(config: TlsConfig) -> TlsResult<Self> {
162 if config.cert_path.is_none() {
163 return Err(TlsError::ConfigError(
164 "Server TLS config requires certificate path".into(),
165 ));
166 }
167 if config.key_path.is_none() {
168 return Err(TlsError::ConfigError(
169 "Server TLS config requires key path".into(),
170 ));
171 }
172
173 #[cfg(feature = "tls")]
174 {
175 let inner = Self::build_acceptor(&config)?;
176 Ok(Self {
177 config,
178 inner: Arc::new(inner),
179 })
180 }
181
182 #[cfg(not(feature = "tls"))]
183 {
184 Ok(Self { config })
185 }
186 }
187
188 #[cfg(feature = "tls")]
189 fn build_acceptor(config: &TlsConfig) -> TlsResult<tokio_rustls::TlsAcceptor> {
190 use rustls::crypto::ring::default_provider;
191 use rustls_pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};
192
193 let cert_path = config
194 .cert_path
195 .as_ref()
196 .ok_or_else(|| TlsError::ConfigError("Certificate path required".into()))?;
197 let key_path = config
198 .key_path
199 .as_ref()
200 .ok_or_else(|| TlsError::ConfigError("Key path required".into()))?;
201
202 let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(cert_path)
204 .map_err(|e| TlsError::CertificateError(format!("{}: {}", cert_path, e)))?
205 .filter_map(|r| r.ok())
206 .collect();
207
208 if certs.is_empty() {
209 return Err(TlsError::CertificateError("No certificates found".into()));
210 }
211
212 let key = PrivateKeyDer::from_pem_file(key_path)
214 .map_err(|e| TlsError::KeyError(format!("{}: {}", key_path, e)))?;
215
216 let server_config = rustls::ServerConfig::builder_with_provider(default_provider().into())
218 .with_safe_default_protocol_versions()
219 .map_err(|e| TlsError::ConfigError(format!("Failed to build config: {}", e)))?
220 .with_no_client_auth()
221 .with_single_cert(certs, key)
222 .map_err(|e| TlsError::ConfigError(format!("Failed to build config: {}", e)))?;
223
224 Ok(tokio_rustls::TlsAcceptor::from(Arc::new(server_config)))
225 }
226
227 pub fn config(&self) -> &TlsConfig {
229 &self.config
230 }
231
232 #[cfg(feature = "tls")]
234 pub async fn accept<S>(&self, stream: S) -> TlsResult<tokio_rustls::server::TlsStream<S>>
235 where
236 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
237 {
238 self.inner
239 .accept(stream)
240 .await
241 .map_err(|e| TlsError::HandshakeError(e.to_string()))
242 }
243}
244
245#[derive(Clone)]
247pub struct TlsConnector {
248 config: TlsConfig,
249 #[cfg(feature = "tls")]
250 inner: Arc<tokio_rustls::TlsConnector>,
251}
252
253impl TlsConnector {
254 pub fn new(config: TlsConfig) -> TlsResult<Self> {
256 #[cfg(feature = "tls")]
257 {
258 let inner = Self::build_connector(&config)?;
259 Ok(Self {
260 config,
261 inner: Arc::new(inner),
262 })
263 }
264
265 #[cfg(not(feature = "tls"))]
266 {
267 Ok(Self { config })
268 }
269 }
270
271 #[cfg(feature = "tls")]
272 fn build_connector(config: &TlsConfig) -> TlsResult<tokio_rustls::TlsConnector> {
273 use rustls::RootCertStore;
274 use rustls::crypto::ring::default_provider;
275 use rustls_pki_types::{CertificateDer, pem::PemObject};
276
277 let mut root_store = RootCertStore::empty();
278
279 if let Some(ref ca_path) = config.ca_path {
281 let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(ca_path)
282 .map_err(|e| TlsError::CaError(format!("{}: {}", ca_path, e)))?
283 .filter_map(|r| r.ok())
284 .collect();
285
286 for cert in certs {
287 root_store
288 .add(cert)
289 .map_err(|e| TlsError::CaError(format!("Failed to add CA: {}", e)))?;
290 }
291 } else {
292 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
294 }
295
296 let client_config = rustls::ClientConfig::builder_with_provider(default_provider().into())
298 .with_safe_default_protocol_versions()
299 .map_err(|e| TlsError::ConfigError(format!("Failed to build config: {}", e)))?
300 .with_root_certificates(root_store)
301 .with_no_client_auth();
302
303 Ok(tokio_rustls::TlsConnector::from(Arc::new(client_config)))
304 }
305
306 pub fn config(&self) -> &TlsConfig {
308 &self.config
309 }
310
311 #[cfg(feature = "tls")]
313 pub async fn connect<S>(
314 &self,
315 server_name: &str,
316 stream: S,
317 ) -> TlsResult<tokio_rustls::client::TlsStream<S>>
318 where
319 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
320 {
321 use rustls::pki_types::ServerName;
322
323 let name = ServerName::try_from(server_name.to_string())
324 .map_err(|e| TlsError::ConfigError(format!("Invalid server name: {}", e)))?;
325
326 self.inner
327 .connect(name, stream)
328 .await
329 .map_err(|e| TlsError::HandshakeError(e.to_string()))
330 }
331}
332
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_tls_config_default() {
        let config = TlsConfig::default();
        assert!(config.cert_path.is_none());
        assert!(config.key_path.is_none());
        assert!(config.ca_path.is_none());
        assert!(config.server_name.is_none());
        assert!(config.verify_peer);
    }

    #[test]
    fn test_tls_config_server() {
        let config = TlsConfig::server("cert.pem", "key.pem");
        assert_eq!(config.cert_path, Some("cert.pem".to_string()));
        assert_eq!(config.key_path, Some("key.pem".to_string()));
        assert!(!config.verify_peer);
    }

    #[test]
    fn test_tls_config_client() {
        let config = TlsConfig::client(Some("ca.pem"));
        assert!(config.cert_path.is_none());
        assert_eq!(config.ca_path, Some("ca.pem".to_string()));
        assert!(config.verify_peer);
    }

    #[test]
    fn test_tls_config_mtls() {
        let config = TlsConfig::mtls("cert.pem", "key.pem", "ca.pem");
        assert_eq!(config.cert_path, Some("cert.pem".to_string()));
        assert_eq!(config.key_path, Some("key.pem".to_string()));
        assert_eq!(config.ca_path, Some("ca.pem".to_string()));
        assert!(config.verify_peer);
    }

    #[test]
    fn test_tls_config_builders() {
        let config = TlsConfig::client(None::<&str>)
            .with_server_name("example.com")
            .with_insecure();
        assert_eq!(config.server_name.as_deref(), Some("example.com"));
        assert!(!config.verify_peer);
        assert!(config.ca_path.is_none());
    }

    #[test]
    fn test_tls_error_display() {
        assert_eq!(
            TlsError::KeyError("bad key".into()).to_string(),
            "Key error: bad key"
        );
        let io = std::io::Error::new(std::io::ErrorKind::NotFound, "missing");
        assert_eq!(TlsError::from(io).to_string(), "I/O error: missing");
    }

    #[test]
    fn test_tls_acceptor_requires_cert() {
        let config = TlsConfig::default();
        let result = TlsAcceptor::new(config);
        assert!(matches!(result, Err(TlsError::ConfigError(_))));
    }

    #[test]
    fn test_tls_acceptor_requires_key() {
        let config = TlsConfig {
            cert_path: Some("cert.pem".into()),
            ..Default::default()
        };
        let result = TlsAcceptor::new(config);
        assert!(matches!(result, Err(TlsError::ConfigError(_))));
    }

    #[test]
    fn test_tls_connector_default() {
        let config = TlsConfig::client(None::<&str>);
        let result = TlsConnector::new(config);
        assert!(result.is_ok());
    }
}