Skip to main content

lnc_network/
tls.rs

1//! TLS support for LANCE connections
2//!
3//! This module provides TLS configuration and connection handling for secure
4//! client-server communication using rustls.
5//!
6//! # Feature Flag
7//!
8//! TLS support requires the `tls` feature:
9//! ```toml
10//! lnc-network = { version = "0.1", features = ["tls"] }
11//! ```
12//!
//! # Usage
//!
//! ```rust,ignore
//! use lnc_network::tls::{TlsConfig, TlsAcceptor, TlsConnector};
//!
//! // Server-side (the `TlsConfig` constructors are infallible; `new` may fail)
//! let config = TlsConfig::server("cert.pem", "key.pem");
//! let acceptor = TlsAcceptor::new(config)?;
//!
//! // Client-side
//! let config = TlsConfig::client(Some("ca.pem"));
//! let connector = TlsConnector::new(config)?;
//! ```
26
27use std::path::Path;
28
29#[cfg(feature = "tls")]
30use std::sync::Arc;
31
/// TLS configuration for LANCE connections
///
/// Paths are stored as `String`s. The constructors convert their path
/// arguments with [`Path::to_string_lossy`], so non-UTF-8 path components are
/// replaced with `U+FFFD` and will not round-trip.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsConfig {
    /// Path to certificate file (PEM format)
    pub cert_path: Option<String>,
    /// Path to private key file (PEM format)
    pub key_path: Option<String>,
    /// Path to CA certificate for verification
    pub ca_path: Option<String>,
    /// Whether to verify peer certificates (`true` by default)
    pub verify_peer: bool,
    /// Server name for SNI (client only)
    pub server_name: Option<String>,
}

impl Default for TlsConfig {
    /// An empty configuration: no paths, no SNI name, peer verification enabled.
    fn default() -> Self {
        Self {
            cert_path: None,
            key_path: None,
            ca_path: None,
            verify_peer: true,
            server_name: None,
        }
    }
}

impl TlsConfig {
    /// Convert a path argument into the owned string form stored in the config.
    ///
    /// Non-UTF-8 sequences are replaced lossily.
    fn path_string(path: impl AsRef<Path>) -> String {
        path.as_ref().to_string_lossy().into_owned()
    }

    /// Create a new TLS config for server mode
    ///
    /// Sets the certificate chain and private key paths and disables peer
    /// (client certificate) verification. No CA path is set.
    #[must_use]
    pub fn server(cert_path: impl AsRef<Path>, key_path: impl AsRef<Path>) -> Self {
        Self {
            cert_path: Some(Self::path_string(cert_path)),
            key_path: Some(Self::path_string(key_path)),
            verify_peer: false,
            ..Self::default()
        }
    }

    /// Create a new TLS config for client mode
    ///
    /// `ca_path` optionally names a PEM bundle of trusted CA certificates.
    /// Peer verification is enabled.
    #[must_use]
    pub fn client(ca_path: Option<impl AsRef<Path>>) -> Self {
        Self {
            ca_path: ca_path.map(Self::path_string),
            ..Self::default()
        }
    }

    /// Create a new TLS config for mutual TLS (mTLS)
    ///
    /// Sets this side's certificate chain and key, the CA used to verify the
    /// peer, and enables peer verification.
    #[must_use]
    pub fn mtls(
        cert_path: impl AsRef<Path>,
        key_path: impl AsRef<Path>,
        ca_path: impl AsRef<Path>,
    ) -> Self {
        Self {
            cert_path: Some(Self::path_string(cert_path)),
            key_path: Some(Self::path_string(key_path)),
            ca_path: Some(Self::path_string(ca_path)),
            verify_peer: true,
            server_name: None,
        }
    }

    /// Set the server name for SNI
    ///
    /// Consumes and returns the config, so the result must be used.
    #[must_use]
    pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
        self.server_name = Some(name.into());
        self
    }

    /// Disable peer certificate verification (NOT RECOMMENDED for production)
    ///
    /// Consumes and returns the config, so the result must be used.
    #[must_use]
    pub fn with_insecure(mut self) -> Self {
        self.verify_peer = false;
        self
    }
}
109
/// Errors produced while configuring TLS or performing a TLS handshake.
///
/// Every variant except [`TlsError::IoError`] carries a human-readable
/// description; `IoError` wraps the originating [`std::io::Error`].
#[derive(Debug)]
pub enum TlsError {
    /// Certificate file not found or invalid
    CertificateError(String),
    /// Private key file not found or invalid
    KeyError(String),
    /// CA certificate file not found or invalid
    CaError(String),
    /// TLS handshake failed
    HandshakeError(String),
    /// I/O error during TLS operation
    IoError(std::io::Error),
    /// Configuration error
    ConfigError(String),
}

impl std::fmt::Display for TlsError {
    /// Renders the error as `"<category> error: <detail>"`.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::CertificateError(detail) => write!(out, "Certificate error: {}", detail),
            Self::KeyError(detail) => write!(out, "Key error: {}", detail),
            Self::CaError(detail) => write!(out, "CA error: {}", detail),
            Self::HandshakeError(detail) => write!(out, "Handshake error: {}", detail),
            Self::IoError(cause) => write!(out, "I/O error: {}", cause),
            Self::ConfigError(detail) => write!(out, "Config error: {}", detail),
        }
    }
}

impl std::error::Error for TlsError {}

/// Lets `?` lift raw I/O failures into [`TlsError::IoError`].
impl From<std::io::Error> for TlsError {
    fn from(cause: std::io::Error) -> Self {
        Self::IoError(cause)
    }
}

/// Shorthand for results whose error type is [`TlsError`].
pub type TlsResult<T> = std::result::Result<T, TlsError>;
150
151/// TLS acceptor for server-side connections
152#[derive(Clone)]
153pub struct TlsAcceptor {
154    config: TlsConfig,
155    #[cfg(feature = "tls")]
156    inner: Arc<tokio_rustls::TlsAcceptor>,
157}
158
159impl TlsAcceptor {
160    /// Create a new TLS acceptor with the given configuration
161    pub fn new(config: TlsConfig) -> TlsResult<Self> {
162        if config.cert_path.is_none() {
163            return Err(TlsError::ConfigError(
164                "Server TLS config requires certificate path".into(),
165            ));
166        }
167        if config.key_path.is_none() {
168            return Err(TlsError::ConfigError(
169                "Server TLS config requires key path".into(),
170            ));
171        }
172
173        #[cfg(feature = "tls")]
174        {
175            let inner = Self::build_acceptor(&config)?;
176            Ok(Self {
177                config,
178                inner: Arc::new(inner),
179            })
180        }
181
182        #[cfg(not(feature = "tls"))]
183        {
184            Ok(Self { config })
185        }
186    }
187
188    #[cfg(feature = "tls")]
189    fn build_acceptor(config: &TlsConfig) -> TlsResult<tokio_rustls::TlsAcceptor> {
190        use rustls::crypto::ring::default_provider;
191        use rustls_pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};
192
193        let cert_path = config
194            .cert_path
195            .as_ref()
196            .ok_or_else(|| TlsError::ConfigError("Certificate path required".into()))?;
197        let key_path = config
198            .key_path
199            .as_ref()
200            .ok_or_else(|| TlsError::ConfigError("Key path required".into()))?;
201
202        // Load certificates using PemObject API
203        let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(cert_path)
204            .map_err(|e| TlsError::CertificateError(format!("{}: {}", cert_path, e)))?
205            .filter_map(|r| r.ok())
206            .collect();
207
208        if certs.is_empty() {
209            return Err(TlsError::CertificateError("No certificates found".into()));
210        }
211
212        // Load private key using PemObject API
213        let key = PrivateKeyDer::from_pem_file(key_path)
214            .map_err(|e| TlsError::KeyError(format!("{}: {}", key_path, e)))?;
215
216        // Build server config
217        let server_config = rustls::ServerConfig::builder_with_provider(default_provider().into())
218            .with_safe_default_protocol_versions()
219            .map_err(|e| TlsError::ConfigError(format!("Failed to build config: {}", e)))?
220            .with_no_client_auth()
221            .with_single_cert(certs, key)
222            .map_err(|e| TlsError::ConfigError(format!("Failed to build config: {}", e)))?;
223
224        Ok(tokio_rustls::TlsAcceptor::from(Arc::new(server_config)))
225    }
226
227    /// Get the TLS configuration
228    pub fn config(&self) -> &TlsConfig {
229        &self.config
230    }
231
232    /// Accept a TLS connection (requires `tls` feature)
233    #[cfg(feature = "tls")]
234    pub async fn accept<S>(&self, stream: S) -> TlsResult<tokio_rustls::server::TlsStream<S>>
235    where
236        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
237    {
238        self.inner
239            .accept(stream)
240            .await
241            .map_err(|e| TlsError::HandshakeError(e.to_string()))
242    }
243}
244
245/// TLS connector for client-side connections
246#[derive(Clone)]
247pub struct TlsConnector {
248    config: TlsConfig,
249    #[cfg(feature = "tls")]
250    inner: Arc<tokio_rustls::TlsConnector>,
251}
252
253impl TlsConnector {
254    /// Create a new TLS connector with the given configuration
255    pub fn new(config: TlsConfig) -> TlsResult<Self> {
256        #[cfg(feature = "tls")]
257        {
258            let inner = Self::build_connector(&config)?;
259            Ok(Self {
260                config,
261                inner: Arc::new(inner),
262            })
263        }
264
265        #[cfg(not(feature = "tls"))]
266        {
267            Ok(Self { config })
268        }
269    }
270
271    #[cfg(feature = "tls")]
272    fn build_connector(config: &TlsConfig) -> TlsResult<tokio_rustls::TlsConnector> {
273        use rustls::RootCertStore;
274        use rustls::crypto::ring::default_provider;
275        use rustls_pki_types::{CertificateDer, pem::PemObject};
276
277        let mut root_store = RootCertStore::empty();
278
279        // Load custom CA if provided
280        if let Some(ref ca_path) = config.ca_path {
281            let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(ca_path)
282                .map_err(|e| TlsError::CaError(format!("{}: {}", ca_path, e)))?
283                .filter_map(|r| r.ok())
284                .collect();
285
286            for cert in certs {
287                root_store
288                    .add(cert)
289                    .map_err(|e| TlsError::CaError(format!("Failed to add CA: {}", e)))?;
290            }
291        } else {
292            // Use system root certificates
293            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
294        }
295
296        // Build client config
297        let client_config = rustls::ClientConfig::builder_with_provider(default_provider().into())
298            .with_safe_default_protocol_versions()
299            .map_err(|e| TlsError::ConfigError(format!("Failed to build config: {}", e)))?
300            .with_root_certificates(root_store)
301            .with_no_client_auth();
302
303        Ok(tokio_rustls::TlsConnector::from(Arc::new(client_config)))
304    }
305
306    /// Get the TLS configuration
307    pub fn config(&self) -> &TlsConfig {
308        &self.config
309    }
310
311    /// Connect with TLS (requires `tls` feature)
312    #[cfg(feature = "tls")]
313    pub async fn connect<S>(
314        &self,
315        server_name: &str,
316        stream: S,
317    ) -> TlsResult<tokio_rustls::client::TlsStream<S>>
318    where
319        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
320    {
321        use rustls::pki_types::ServerName;
322
323        let name = ServerName::try_from(server_name.to_string())
324            .map_err(|e| TlsError::ConfigError(format!("Invalid server name: {}", e)))?;
325
326        self.inner
327            .connect(name, stream)
328            .await
329            .map_err(|e| TlsError::HandshakeError(e.to_string()))
330    }
331}
332
#[cfg(test)]
#[allow(clippy::unwrap_used)] // unwrap is acceptable in tests: a panic is a test failure
mod tests {
    use super::*;

    #[test]
    fn test_tls_config_default_verifies_peer() {
        let config = TlsConfig::default();
        assert!(config.verify_peer);
        assert!(config.cert_path.is_none());
        assert!(config.key_path.is_none());
        assert!(config.ca_path.is_none());
        assert!(config.server_name.is_none());
    }

    #[test]
    fn test_tls_config_server() {
        let config = TlsConfig::server("cert.pem", "key.pem");
        assert_eq!(config.cert_path, Some("cert.pem".to_string()));
        assert_eq!(config.key_path, Some("key.pem".to_string()));
        assert!(!config.verify_peer);
    }

    #[test]
    fn test_tls_config_client() {
        let config = TlsConfig::client(Some("ca.pem"));
        assert!(config.cert_path.is_none());
        assert_eq!(config.ca_path, Some("ca.pem".to_string()));
        assert!(config.verify_peer);
    }

    #[test]
    fn test_tls_config_mtls() {
        let config = TlsConfig::mtls("cert.pem", "key.pem", "ca.pem");
        assert_eq!(config.cert_path, Some("cert.pem".to_string()));
        assert_eq!(config.key_path, Some("key.pem".to_string()));
        assert_eq!(config.ca_path, Some("ca.pem".to_string()));
        assert!(config.verify_peer);
    }

    #[test]
    fn test_tls_config_builders() {
        let config = TlsConfig::client(None::<&str>)
            .with_server_name("lance.example")
            .with_insecure();
        assert_eq!(config.server_name.as_deref(), Some("lance.example"));
        assert!(!config.verify_peer);
    }

    #[test]
    fn test_tls_error_display_and_from_io() {
        let err = TlsError::from(std::io::Error::new(std::io::ErrorKind::NotFound, "boom"));
        assert!(matches!(err, TlsError::IoError(_)));
        assert_eq!(err.to_string(), "I/O error: boom");
        assert_eq!(
            TlsError::ConfigError("x".into()).to_string(),
            "Config error: x"
        );
        assert_eq!(
            TlsError::HandshakeError("y".into()).to_string(),
            "Handshake error: y"
        );
    }

    #[test]
    fn test_tls_acceptor_requires_cert() {
        let config = TlsConfig::default();
        let result = TlsAcceptor::new(config);
        assert!(result.is_err());
    }

    #[test]
    fn test_tls_acceptor_missing_cert_is_config_error() {
        let result = TlsAcceptor::new(TlsConfig::default());
        assert!(matches!(result, Err(TlsError::ConfigError(_))));
    }

    #[test]
    fn test_tls_acceptor_requires_key() {
        let config = TlsConfig {
            cert_path: Some("cert.pem".into()),
            ..Default::default()
        };
        let result = TlsAcceptor::new(config);
        assert!(result.is_err());
    }

    #[test]
    fn test_tls_connector_default() {
        let config = TlsConfig::client(None::<&str>);
        let result = TlsConnector::new(config);
        assert!(result.is_ok());
    }
}