Skip to main content

lnc_network/
tls.rs

1//! TLS support for LANCE connections
2//!
3//! This module provides TLS configuration and connection handling for secure
4//! client-server communication using rustls.
5//!
6//! # Feature Flag
7//!
8//! TLS support requires the `tls` feature:
9//! ```toml
10//! lnc-network = { version = "0.1", features = ["tls"] }
11//! ```
12//!
//! # Usage
//!
//! ```rust,ignore
//! use lnc_network::tls::{TlsConfig, TlsAcceptor, TlsConnector};
//!
//! // Server-side. The `TlsConfig` constructors are infallible; the fallible
//! // work (reading the PEM files) happens in `TlsAcceptor::new` / `TlsConnector::new`.
//! let config = TlsConfig::server("cert.pem", "key.pem");
//! let acceptor = TlsAcceptor::new(config)?;
//!
//! // Client-side
//! let config = TlsConfig::client(Some("ca.pem"));
//! let connector = TlsConnector::new(config)?;
//! ```
26
27use std::path::Path;
28
29#[cfg(feature = "tls")]
30use std::sync::Arc;
31
/// Settings describing how a LANCE endpoint sets up TLS.
///
/// Each path field, when present, names a PEM-encoded file on disk. The
/// [`Default`] value leaves every path and the SNI name unset and turns
/// `verify_peer` on.
#[derive(Debug, Clone)]
pub struct TlsConfig {
    /// Location of the PEM certificate file.
    pub cert_path: Option<String>,
    /// Location of the PEM private-key file that belongs to `cert_path`.
    pub key_path: Option<String>,
    /// Location of a PEM file holding the CA certificate(s) to trust.
    pub ca_path: Option<String>,
    /// Whether peer certificates should be verified. How (and whether) this
    /// is honoured depends on the acceptor/connector consuming the config.
    pub verify_peer: bool,
    /// Host name to send for SNI (client side only).
    pub server_name: Option<String>,
}

impl Default for TlsConfig {
    /// No files configured, no SNI name, peer verification switched on.
    fn default() -> Self {
        let verify_peer = true;
        Self {
            verify_peer,
            server_name: None,
            ca_path: None,
            key_path: None,
            cert_path: None,
        }
    }
}
58
59impl TlsConfig {
60    /// Create a new TLS config for server mode
61    pub fn server(cert_path: impl AsRef<Path>, key_path: impl AsRef<Path>) -> Self {
62        Self {
63            cert_path: Some(cert_path.as_ref().to_string_lossy().into_owned()),
64            key_path: Some(key_path.as_ref().to_string_lossy().into_owned()),
65            ca_path: None,
66            verify_peer: false,
67            server_name: None,
68        }
69    }
70
71    /// Create a new TLS config for client mode
72    pub fn client(ca_path: Option<impl AsRef<Path>>) -> Self {
73        Self {
74            cert_path: None,
75            key_path: None,
76            ca_path: ca_path.map(|p| p.as_ref().to_string_lossy().into_owned()),
77            verify_peer: true,
78            server_name: None,
79        }
80    }
81
82    /// Create a new TLS config for mutual TLS (mTLS)
83    pub fn mtls(
84        cert_path: impl AsRef<Path>,
85        key_path: impl AsRef<Path>,
86        ca_path: impl AsRef<Path>,
87    ) -> Self {
88        Self {
89            cert_path: Some(cert_path.as_ref().to_string_lossy().into_owned()),
90            key_path: Some(key_path.as_ref().to_string_lossy().into_owned()),
91            ca_path: Some(ca_path.as_ref().to_string_lossy().into_owned()),
92            verify_peer: true,
93            server_name: None,
94        }
95    }
96
97    /// Set the server name for SNI
98    pub fn with_server_name(mut self, name: impl Into<String>) -> Self {
99        self.server_name = Some(name.into());
100        self
101    }
102
103    /// Disable peer certificate verification (NOT RECOMMENDED for production)
104    pub fn with_insecure(mut self) -> Self {
105        self.verify_peer = false;
106        self
107    }
108}
109
/// Error type for TLS operations
#[derive(Debug)]
pub enum TlsError {
    /// Certificate file not found or invalid
    CertificateError(String),
    /// Private key file not found or invalid
    KeyError(String),
    /// CA certificate file not found or invalid
    CaError(String),
    /// TLS handshake failed
    HandshakeError(String),
    /// I/O error during TLS operation
    IoError(std::io::Error),
    /// Configuration error
    ConfigError(String),
}

impl std::fmt::Display for TlsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TlsError::CertificateError(msg) => write!(f, "Certificate error: {}", msg),
            TlsError::KeyError(msg) => write!(f, "Key error: {}", msg),
            TlsError::CaError(msg) => write!(f, "CA error: {}", msg),
            TlsError::HandshakeError(msg) => write!(f, "Handshake error: {}", msg),
            TlsError::IoError(e) => write!(f, "I/O error: {}", e),
            TlsError::ConfigError(msg) => write!(f, "Config error: {}", msg),
        }
    }
}

impl std::error::Error for TlsError {
    /// Exposes the wrapped [`std::io::Error`] of [`TlsError::IoError`], so code
    /// walking the `source()` chain (loggers, `anyhow`-style reporters) can
    /// reach the underlying cause. The other variants carry only a message
    /// string and have no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            TlsError::IoError(e) => Some(e),
            TlsError::CertificateError(_)
            | TlsError::KeyError(_)
            | TlsError::CaError(_)
            | TlsError::HandshakeError(_)
            | TlsError::ConfigError(_) => None,
        }
    }
}

/// Wraps an I/O error as [`TlsError::IoError`], enabling `?` on `io::Result`.
impl From<std::io::Error> for TlsError {
    fn from(e: std::io::Error) -> Self {
        TlsError::IoError(e)
    }
}

/// Result type for TLS operations
pub type TlsResult<T> = std::result::Result<T, TlsError>;
150
151/// TLS acceptor for server-side connections
152#[derive(Clone)]
153pub struct TlsAcceptor {
154    config: TlsConfig,
155    #[cfg(feature = "tls")]
156    inner: Arc<tokio_rustls::TlsAcceptor>,
157}
158
159impl TlsAcceptor {
160    /// Create a new TLS acceptor with the given configuration
161    pub fn new(config: TlsConfig) -> TlsResult<Self> {
162        if config.cert_path.is_none() {
163            return Err(TlsError::ConfigError(
164                "Server TLS config requires certificate path".into(),
165            ));
166        }
167        if config.key_path.is_none() {
168            return Err(TlsError::ConfigError(
169                "Server TLS config requires key path".into(),
170            ));
171        }
172
173        #[cfg(feature = "tls")]
174        {
175            let inner = Self::build_acceptor(&config)?;
176            Ok(Self {
177                config,
178                inner: Arc::new(inner),
179            })
180        }
181
182        #[cfg(not(feature = "tls"))]
183        {
184            Ok(Self { config })
185        }
186    }
187
188    #[cfg(feature = "tls")]
189    fn build_acceptor(config: &TlsConfig) -> TlsResult<tokio_rustls::TlsAcceptor> {
190        use rustls_pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject};
191
192        let cert_path = config
193            .cert_path
194            .as_ref()
195            .ok_or_else(|| TlsError::ConfigError("Certificate path required".into()))?;
196        let key_path = config
197            .key_path
198            .as_ref()
199            .ok_or_else(|| TlsError::ConfigError("Key path required".into()))?;
200
201        // Load certificates using PemObject API
202        let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(cert_path)
203            .map_err(|e| TlsError::CertificateError(format!("{}: {}", cert_path, e)))?
204            .filter_map(|r| r.ok())
205            .collect();
206
207        if certs.is_empty() {
208            return Err(TlsError::CertificateError("No certificates found".into()));
209        }
210
211        // Load private key using PemObject API
212        let key = PrivateKeyDer::from_pem_file(key_path)
213            .map_err(|e| TlsError::KeyError(format!("{}: {}", key_path, e)))?;
214
215        // Build server config
216        let server_config = rustls::ServerConfig::builder()
217            .with_no_client_auth()
218            .with_single_cert(certs, key)
219            .map_err(|e| TlsError::ConfigError(format!("Failed to build config: {}", e)))?;
220
221        Ok(tokio_rustls::TlsAcceptor::from(Arc::new(server_config)))
222    }
223
224    /// Get the TLS configuration
225    pub fn config(&self) -> &TlsConfig {
226        &self.config
227    }
228
229    /// Accept a TLS connection (requires `tls` feature)
230    #[cfg(feature = "tls")]
231    pub async fn accept<S>(&self, stream: S) -> TlsResult<tokio_rustls::server::TlsStream<S>>
232    where
233        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
234    {
235        self.inner
236            .accept(stream)
237            .await
238            .map_err(|e| TlsError::HandshakeError(e.to_string()))
239    }
240}
241
242/// TLS connector for client-side connections
243#[derive(Clone)]
244pub struct TlsConnector {
245    config: TlsConfig,
246    #[cfg(feature = "tls")]
247    inner: Arc<tokio_rustls::TlsConnector>,
248}
249
250impl TlsConnector {
251    /// Create a new TLS connector with the given configuration
252    pub fn new(config: TlsConfig) -> TlsResult<Self> {
253        #[cfg(feature = "tls")]
254        {
255            let inner = Self::build_connector(&config)?;
256            Ok(Self {
257                config,
258                inner: Arc::new(inner),
259            })
260        }
261
262        #[cfg(not(feature = "tls"))]
263        {
264            Ok(Self { config })
265        }
266    }
267
268    #[cfg(feature = "tls")]
269    fn build_connector(config: &TlsConfig) -> TlsResult<tokio_rustls::TlsConnector> {
270        use rustls::RootCertStore;
271        use rustls_pki_types::{CertificateDer, pem::PemObject};
272
273        let mut root_store = RootCertStore::empty();
274
275        // Load custom CA if provided
276        if let Some(ref ca_path) = config.ca_path {
277            let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_file_iter(ca_path)
278                .map_err(|e| TlsError::CaError(format!("{}: {}", ca_path, e)))?
279                .filter_map(|r| r.ok())
280                .collect();
281
282            for cert in certs {
283                root_store
284                    .add(cert)
285                    .map_err(|e| TlsError::CaError(format!("Failed to add CA: {}", e)))?;
286            }
287        } else {
288            // Use system root certificates
289            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
290        }
291
292        // Build client config
293        let client_config = rustls::ClientConfig::builder()
294            .with_root_certificates(root_store)
295            .with_no_client_auth();
296
297        Ok(tokio_rustls::TlsConnector::from(Arc::new(client_config)))
298    }
299
300    /// Get the TLS configuration
301    pub fn config(&self) -> &TlsConfig {
302        &self.config
303    }
304
305    /// Connect with TLS (requires `tls` feature)
306    #[cfg(feature = "tls")]
307    pub async fn connect<S>(
308        &self,
309        server_name: &str,
310        stream: S,
311    ) -> TlsResult<tokio_rustls::client::TlsStream<S>>
312    where
313        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
314    {
315        use rustls::pki_types::ServerName;
316
317        let name = ServerName::try_from(server_name.to_string())
318            .map_err(|e| TlsError::ConfigError(format!("Invalid server name: {}", e)))?;
319
320        self.inner
321            .connect(name, stream)
322            .await
323            .map_err(|e| TlsError::HandshakeError(e.to_string()))
324    }
325}
326
#[cfg(test)]
// Permits `unwrap` in test code; no test below currently needs it.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_tls_config_server() {
        let config = TlsConfig::server("cert.pem", "key.pem");
        assert_eq!(config.cert_path, Some("cert.pem".to_string()));
        assert_eq!(config.key_path, Some("key.pem".to_string()));
        assert!(config.ca_path.is_none());
        assert!(!config.verify_peer);
    }

    #[test]
    fn test_tls_config_client() {
        let config = TlsConfig::client(Some("ca.pem"));
        assert!(config.cert_path.is_none());
        assert!(config.key_path.is_none());
        assert_eq!(config.ca_path, Some("ca.pem".to_string()));
        assert!(config.verify_peer);
    }

    #[test]
    fn test_tls_config_mtls() {
        let config = TlsConfig::mtls("cert.pem", "key.pem", "ca.pem");
        assert_eq!(config.cert_path, Some("cert.pem".to_string()));
        assert_eq!(config.key_path, Some("key.pem".to_string()));
        assert_eq!(config.ca_path, Some("ca.pem".to_string()));
        assert!(config.verify_peer);
    }

    #[test]
    fn test_tls_config_builders() {
        let config = TlsConfig::client(None::<&str>)
            .with_server_name("example.com")
            .with_insecure();
        assert_eq!(config.server_name, Some("example.com".to_string()));
        assert!(!config.verify_peer);
        assert!(config.ca_path.is_none());
    }

    #[test]
    fn test_tls_acceptor_requires_cert() {
        let result = TlsAcceptor::new(TlsConfig::default());
        // Must fail in up-front validation, not later while reading files.
        assert!(matches!(result, Err(TlsError::ConfigError(_))));
    }

    #[test]
    fn test_tls_acceptor_requires_key() {
        let config = TlsConfig {
            cert_path: Some("cert.pem".into()),
            ..Default::default()
        };
        let result = TlsAcceptor::new(config);
        // A `CertificateError` here would mean the key-path check was skipped
        // and the (nonexistent) cert file was read instead.
        assert!(matches!(result, Err(TlsError::ConfigError(_))));
    }

    #[test]
    fn test_tls_connector_default() {
        let config = TlsConfig::client(None::<&str>);
        let result = TlsConnector::new(config);
        assert!(
            result.is_ok(),
            "connector construction failed: {:?}",
            result.err()
        );
    }
}