Skip to main content

lnc_network/
tls.rs

//! TLS support for LANCE connections
//!
//! This module provides TLS configuration and connection handling for secure
//! client-server communication using rustls.
//!
//! # Feature Flag
//!
//! TLS support requires the `tls` feature:
//! ```toml
//! lnc-network = { version = "0.1", features = ["tls"] }
//! ```
//!
//! Without the feature, `TlsAcceptor::accept` and `TlsConnector::connect` are
//! not compiled; only configuration types and validation are available.
//!
//! # Usage
//!
//! ```rust,ignore
//! use lnc_network::tls::{TlsConfig, TlsAcceptor, TlsConnector};
//!
//! // Server-side (`TlsConfig::server` is infallible; `TlsAcceptor::new` may fail)
//! let config = TlsConfig::server("cert.pem", "key.pem");
//! let acceptor = TlsAcceptor::new(config)?;
//!
//! // Client-side (`TlsConfig::client` is infallible; `TlsConnector::new` may fail)
//! let config = TlsConfig::client(Some("ca.pem"));
//! let connector = TlsConnector::new(config)?;
//! ```
26
27use std::path::Path;
28
29#[cfg(feature = "tls")]
30use std::fs::File;
31#[cfg(feature = "tls")]
32use std::io::BufReader;
33#[cfg(feature = "tls")]
34use std::sync::Arc;
35
/// Paths and options used to build LANCE TLS acceptors and connectors.
///
/// Paths are stored as owned strings (converted lossily from [`Path`]) that
/// point at PEM-encoded files. Prefer the [`TlsConfig::server`],
/// [`TlsConfig::client`] and [`TlsConfig::mtls`] constructors for the common
/// setups.
#[derive(Clone, Debug)]
pub struct TlsConfig {
    /// PEM certificate (chain) file.
    pub cert_path: Option<String>,
    /// PEM private key file.
    pub key_path: Option<String>,
    /// PEM file holding the CA certificate(s) used for peer verification.
    pub ca_path: Option<String>,
    /// Whether peer certificates should be verified.
    pub verify_peer: bool,
    /// Server name for SNI (client only).
    pub server_name: Option<String>,
}

impl Default for TlsConfig {
    /// No key material, no CA, no SNI name, and peer verification enabled.
    fn default() -> Self {
        Self {
            verify_peer: true,
            cert_path: None,
            key_path: None,
            ca_path: None,
            server_name: None,
        }
    }
}

impl TlsConfig {
    /// Converts a path into the owned string form stored in this config.
    fn path_string(path: impl AsRef<Path>) -> String {
        path.as_ref().to_string_lossy().into_owned()
    }

    /// Config for a plain TLS server: certificate + key, `verify_peer` off,
    /// no CA.
    pub fn server(cert_path: impl AsRef<Path>, key_path: impl AsRef<Path>) -> Self {
        Self {
            cert_path: Some(Self::path_string(cert_path)),
            key_path: Some(Self::path_string(key_path)),
            verify_peer: false,
            ..Self::default()
        }
    }

    /// Config for a TLS client with `verify_peer` on and an optional CA file.
    pub fn client(ca_path: Option<impl AsRef<Path>>) -> Self {
        Self {
            ca_path: ca_path.map(Self::path_string),
            ..Self::default()
        }
    }

    /// Config for mutual TLS (mTLS): certificate, key and CA are all set and
    /// `verify_peer` is on.
    pub fn mtls(
        cert_path: impl AsRef<Path>,
        key_path: impl AsRef<Path>,
        ca_path: impl AsRef<Path>,
    ) -> Self {
        Self {
            cert_path: Some(Self::path_string(cert_path)),
            key_path: Some(Self::path_string(key_path)),
            ca_path: Some(Self::path_string(ca_path)),
            ..Self::default()
        }
    }

    /// Returns this config with `server_name` (used for SNI) set to `name`.
    pub fn with_server_name(self, name: impl Into<String>) -> Self {
        Self {
            server_name: Some(name.into()),
            ..self
        }
    }

    /// Returns this config with `verify_peer` cleared (NOT RECOMMENDED for
    /// production).
    ///
    /// This only flips the flag; whether it has any effect is up to the
    /// acceptor/connector that consumes the config.
    pub fn with_insecure(self) -> Self {
        Self {
            verify_peer: false,
            ..self
        }
    }
}
113
/// Error type for TLS operations.
///
/// Every variant except [`TlsError::IoError`] carries a human-readable
/// message. `IoError` wraps the underlying [`std::io::Error`], which is also
/// exposed through [`std::error::Error::source`] so error reporters can walk
/// the chain.
#[derive(Debug)]
pub enum TlsError {
    /// Certificate file not found or invalid
    CertificateError(String),
    /// Private key file not found or invalid
    KeyError(String),
    /// CA certificate file not found or invalid
    CaError(String),
    /// TLS handshake failed
    HandshakeError(String),
    /// I/O error during TLS operation
    IoError(std::io::Error),
    /// Configuration error
    ConfigError(String),
}

impl std::fmt::Display for TlsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TlsError::CertificateError(msg) => write!(f, "Certificate error: {}", msg),
            TlsError::KeyError(msg) => write!(f, "Key error: {}", msg),
            TlsError::CaError(msg) => write!(f, "CA error: {}", msg),
            TlsError::HandshakeError(msg) => write!(f, "Handshake error: {}", msg),
            TlsError::IoError(e) => write!(f, "I/O error: {}", e),
            TlsError::ConfigError(msg) => write!(f, "Config error: {}", msg),
        }
    }
}

impl std::error::Error for TlsError {
    /// Returns the wrapped I/O error for [`TlsError::IoError`]. The other
    /// variants only carry a message and have no underlying source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        // Exhaustive on purpose: adding a variant should force a decision here.
        match self {
            TlsError::IoError(e) => Some(e),
            TlsError::CertificateError(_)
            | TlsError::KeyError(_)
            | TlsError::CaError(_)
            | TlsError::HandshakeError(_)
            | TlsError::ConfigError(_) => None,
        }
    }
}

impl From<std::io::Error> for TlsError {
    fn from(e: std::io::Error) -> Self {
        TlsError::IoError(e)
    }
}

/// Result type for TLS operations
pub type TlsResult<T> = std::result::Result<T, TlsError>;
154
155/// TLS acceptor for server-side connections
156#[derive(Clone)]
157pub struct TlsAcceptor {
158    config: TlsConfig,
159    #[cfg(feature = "tls")]
160    inner: Arc<tokio_rustls::TlsAcceptor>,
161}
162
163impl TlsAcceptor {
164    /// Create a new TLS acceptor with the given configuration
165    pub fn new(config: TlsConfig) -> TlsResult<Self> {
166        if config.cert_path.is_none() {
167            return Err(TlsError::ConfigError(
168                "Server TLS config requires certificate path".into(),
169            ));
170        }
171        if config.key_path.is_none() {
172            return Err(TlsError::ConfigError(
173                "Server TLS config requires key path".into(),
174            ));
175        }
176
177        #[cfg(feature = "tls")]
178        {
179            let inner = Self::build_acceptor(&config)?;
180            Ok(Self {
181                config,
182                inner: Arc::new(inner),
183            })
184        }
185
186        #[cfg(not(feature = "tls"))]
187        {
188            Ok(Self { config })
189        }
190    }
191
192    #[cfg(feature = "tls")]
193    fn build_acceptor(config: &TlsConfig) -> TlsResult<tokio_rustls::TlsAcceptor> {
194        use rustls::pki_types::CertificateDer;
195
196        let cert_path = config
197            .cert_path
198            .as_ref()
199            .ok_or_else(|| TlsError::ConfigError("Certificate path required".into()))?;
200        let key_path = config
201            .key_path
202            .as_ref()
203            .ok_or_else(|| TlsError::ConfigError("Key path required".into()))?;
204
205        // Load certificates
206        let cert_file = File::open(cert_path)
207            .map_err(|e| TlsError::CertificateError(format!("{}: {}", cert_path, e)))?;
208        let mut cert_reader = BufReader::new(cert_file);
209        let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut cert_reader)
210            .filter_map(|r| r.ok())
211            .collect();
212
213        if certs.is_empty() {
214            return Err(TlsError::CertificateError("No certificates found".into()));
215        }
216
217        // Load private key
218        let key_file =
219            File::open(key_path).map_err(|e| TlsError::KeyError(format!("{}: {}", key_path, e)))?;
220        let mut key_reader = BufReader::new(key_file);
221        let key = rustls_pemfile::private_key(&mut key_reader)
222            .map_err(|e| TlsError::KeyError(format!("Failed to parse key: {}", e)))?
223            .ok_or_else(|| TlsError::KeyError("No private key found".into()))?;
224
225        // Build server config
226        let server_config = rustls::ServerConfig::builder()
227            .with_no_client_auth()
228            .with_single_cert(certs, key)
229            .map_err(|e| TlsError::ConfigError(format!("Failed to build config: {}", e)))?;
230
231        Ok(tokio_rustls::TlsAcceptor::from(Arc::new(server_config)))
232    }
233
234    /// Get the TLS configuration
235    pub fn config(&self) -> &TlsConfig {
236        &self.config
237    }
238
239    /// Accept a TLS connection (requires `tls` feature)
240    #[cfg(feature = "tls")]
241    pub async fn accept<S>(&self, stream: S) -> TlsResult<tokio_rustls::server::TlsStream<S>>
242    where
243        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
244    {
245        self.inner
246            .accept(stream)
247            .await
248            .map_err(|e| TlsError::HandshakeError(e.to_string()))
249    }
250}
251
252/// TLS connector for client-side connections
253#[derive(Clone)]
254pub struct TlsConnector {
255    config: TlsConfig,
256    #[cfg(feature = "tls")]
257    inner: Arc<tokio_rustls::TlsConnector>,
258}
259
260impl TlsConnector {
261    /// Create a new TLS connector with the given configuration
262    pub fn new(config: TlsConfig) -> TlsResult<Self> {
263        #[cfg(feature = "tls")]
264        {
265            let inner = Self::build_connector(&config)?;
266            Ok(Self {
267                config,
268                inner: Arc::new(inner),
269            })
270        }
271
272        #[cfg(not(feature = "tls"))]
273        {
274            Ok(Self { config })
275        }
276    }
277
278    #[cfg(feature = "tls")]
279    fn build_connector(config: &TlsConfig) -> TlsResult<tokio_rustls::TlsConnector> {
280        use rustls::RootCertStore;
281        use rustls::pki_types::CertificateDer;
282
283        let mut root_store = RootCertStore::empty();
284
285        // Load custom CA if provided
286        if let Some(ref ca_path) = config.ca_path {
287            let ca_file = File::open(ca_path)
288                .map_err(|e| TlsError::CaError(format!("{}: {}", ca_path, e)))?;
289            let mut ca_reader = BufReader::new(ca_file);
290            let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut ca_reader)
291                .filter_map(|r| r.ok())
292                .collect();
293
294            for cert in certs {
295                root_store
296                    .add(cert)
297                    .map_err(|e| TlsError::CaError(format!("Failed to add CA: {}", e)))?;
298            }
299        } else {
300            // Use system root certificates
301            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
302        }
303
304        // Build client config
305        let client_config = rustls::ClientConfig::builder()
306            .with_root_certificates(root_store)
307            .with_no_client_auth();
308
309        Ok(tokio_rustls::TlsConnector::from(Arc::new(client_config)))
310    }
311
312    /// Get the TLS configuration
313    pub fn config(&self) -> &TlsConfig {
314        &self.config
315    }
316
317    /// Connect with TLS (requires `tls` feature)
318    #[cfg(feature = "tls")]
319    pub async fn connect<S>(
320        &self,
321        server_name: &str,
322        stream: S,
323    ) -> TlsResult<tokio_rustls::client::TlsStream<S>>
324    where
325        S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
326    {
327        use rustls::pki_types::ServerName;
328
329        let name = ServerName::try_from(server_name.to_string())
330            .map_err(|e| TlsError::ConfigError(format!("Invalid server name: {}", e)))?;
331
332        self.inner
333            .connect(name, stream)
334            .await
335            .map_err(|e| TlsError::HandshakeError(e.to_string()))
336    }
337}
338
#[cfg(test)]
// Test code is allowed to unwrap freely if assertions grow to need it.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    #[test]
    fn test_tls_config_server() {
        let cfg = TlsConfig::server("cert.pem", "key.pem");
        assert_eq!(cfg.cert_path.as_deref(), Some("cert.pem"));
        assert_eq!(cfg.key_path.as_deref(), Some("key.pem"));
        assert!(!cfg.verify_peer);
    }

    #[test]
    fn test_tls_config_client() {
        let cfg = TlsConfig::client(Some("ca.pem"));
        assert_eq!(cfg.cert_path, None);
        assert_eq!(cfg.ca_path.as_deref(), Some("ca.pem"));
        assert!(cfg.verify_peer);
    }

    #[test]
    fn test_tls_config_mtls() {
        let cfg = TlsConfig::mtls("cert.pem", "key.pem", "ca.pem");
        assert_eq!(cfg.cert_path.as_deref(), Some("cert.pem"));
        assert_eq!(cfg.key_path.as_deref(), Some("key.pem"));
        assert_eq!(cfg.ca_path.as_deref(), Some("ca.pem"));
        assert!(cfg.verify_peer);
    }

    #[test]
    fn test_tls_acceptor_requires_cert() {
        assert!(TlsAcceptor::new(TlsConfig::default()).is_err());
    }

    #[test]
    fn test_tls_acceptor_requires_key() {
        let cfg = TlsConfig {
            cert_path: Some(String::from("cert.pem")),
            ..TlsConfig::default()
        };
        assert!(TlsAcceptor::new(cfg).is_err());
    }

    #[test]
    fn test_tls_connector_default() {
        assert!(TlsConnector::new(TlsConfig::client(None::<&str>)).is_ok());
    }
}