Skip to main content

corevpn_protocol/
tls.rs

1//! TLS Integration for OpenVPN Control Channel
2//!
3//! Bridges rustls with the OpenVPN control channel transport.
4//! Supports both server-side and client-side TLS connections.
5
6use std::io::{Read, Write, ErrorKind};
7use std::sync::Arc;
8
9use bytes::{Bytes, BytesMut};
10use rustls::{ServerConfig, ServerConnection, ClientConfig, ClientConnection};
11use rustls::client::danger::{ServerCertVerifier, ServerCertVerified, HandshakeSignatureValid};
12use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
13
14use crate::{ProtocolError, Result};
15
16/// TLS handler for OpenVPN connections
17pub struct TlsHandler {
18    /// Rustls server connection
19    conn: ServerConnection,
20    /// Incoming data buffer (from control channel)
21    incoming: BytesMut,
22    /// Outgoing data buffer (to control channel)
23    outgoing: BytesMut,
24    /// Whether handshake is complete
25    handshake_complete: bool,
26}
27
28impl TlsHandler {
29    /// Create a new TLS handler with server configuration
30    pub fn new(config: Arc<ServerConfig>) -> Result<Self> {
31        let conn = ServerConnection::new(config)
32            .map_err(|e| ProtocolError::TlsError(e.to_string()))?;
33
34        Ok(Self {
35            conn,
36            incoming: BytesMut::with_capacity(16384),
37            outgoing: BytesMut::with_capacity(16384),
38            handshake_complete: false,
39        })
40    }
41
42    /// Process incoming TLS data from control channel
43    pub fn process_incoming(&mut self, data: &[u8]) -> Result<()> {
44        self.incoming.extend_from_slice(data);
45        self.process_tls()
46    }
47
48    /// Process incoming TLS records (already extracted from control channel)
49    pub fn process_tls_records(&mut self, records: Vec<Bytes>) -> Result<()> {
50        for record in records {
51            self.incoming.extend_from_slice(&record);
52        }
53        self.process_tls()
54    }
55
56    /// Internal TLS processing
57    fn process_tls(&mut self) -> Result<()> {
58        // Create a cursor for reading
59        let mut reader = &self.incoming[..];
60
61        match self.conn.read_tls(&mut reader) {
62            Ok(0) => {
63                // No data read
64            }
65            Ok(n) => {
66                // Remove consumed data
67                let _ = self.incoming.split_to(n);
68            }
69            Err(e) if e.kind() == ErrorKind::WouldBlock => {
70                // Need more data
71            }
72            Err(e) => {
73                return Err(ProtocolError::TlsError(e.to_string()));
74            }
75        }
76
77        // Process any TLS state changes
78        match self.conn.process_new_packets() {
79            Ok(_state) => {
80                if !self.handshake_complete && !self.conn.is_handshaking() {
81                    self.handshake_complete = true;
82                }
83            }
84            Err(e) => {
85                return Err(ProtocolError::TlsError(e.to_string()));
86            }
87        }
88
89        Ok(())
90    }
91
92    /// Get data to send on control channel
93    pub fn get_outgoing(&mut self) -> Result<Option<Bytes>> {
94        self.outgoing.clear();
95
96        match self.conn.write_tls(&mut VecWriter(&mut self.outgoing)) {
97            Ok(0) => Ok(None),
98            Ok(_) => Ok(Some(self.outgoing.clone().freeze())),
99            Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(None),
100            Err(e) => Err(ProtocolError::TlsError(e.to_string())),
101        }
102    }
103
104    /// Check if handshake is complete
105    pub fn is_handshake_complete(&self) -> bool {
106        self.handshake_complete
107    }
108
109    /// Check if we're still handshaking
110    pub fn is_handshaking(&self) -> bool {
111        self.conn.is_handshaking()
112    }
113
114    /// Check if there's data waiting to be written
115    pub fn wants_write(&self) -> bool {
116        self.conn.wants_write()
117    }
118
119    /// Read decrypted application data
120    pub fn read_plaintext(&mut self, buf: &mut [u8]) -> Result<usize> {
121        let mut reader = self.conn.reader();
122        match reader.read(buf) {
123            Ok(n) => Ok(n),
124            Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(0),
125            Err(e) => Err(ProtocolError::TlsError(e.to_string())),
126        }
127    }
128
129    /// Write plaintext data (will be encrypted)
130    pub fn write_plaintext(&mut self, data: &[u8]) -> Result<usize> {
131        let mut writer = self.conn.writer();
132        match writer.write(data) {
133            Ok(n) => Ok(n),
134            Err(e) => Err(ProtocolError::TlsError(e.to_string())),
135        }
136    }
137
138    /// Export keying material from the TLS session (RFC 5705)
139    ///
140    /// Used by OpenVPN for data channel key derivation with TLS 1.3.
141    pub fn export_keying_material(
142        &self,
143        output: &mut [u8],
144        label: &[u8],
145        context: Option<&[u8]>,
146    ) -> Result<()> {
147        self.conn.export_keying_material(output, label, context)
148            .map_err(|e| ProtocolError::TlsError(format!("EKM export failed: {}", e)))?;
149        Ok(())
150    }
151
152    /// Get peer certificate if available
153    pub fn peer_certificates(&self) -> Option<Vec<CertificateDer<'static>>> {
154        self.conn.peer_certificates().map(|certs| {
155            certs.iter().map(|c| c.clone().into_owned()).collect()
156        })
157    }
158
159    /// Get negotiated cipher suite name
160    pub fn cipher_suite(&self) -> Option<&'static str> {
161        self.conn.negotiated_cipher_suite().map(|cs| cs.suite().as_str().unwrap_or("unknown"))
162    }
163
164    /// Get negotiated TLS protocol version
165    pub fn protocol_version(&self) -> Option<&'static str> {
166        self.conn.protocol_version().map(|v| match v {
167            rustls::ProtocolVersion::TLSv1_0 => "TLS 1.0",
168            rustls::ProtocolVersion::TLSv1_2 => "TLS 1.2",
169            rustls::ProtocolVersion::TLSv1_3 => "TLS 1.3",
170            _ => "unknown",
171        })
172    }
173}
174
175/// Helper to write to BytesMut
176struct VecWriter<'a>(&'a mut BytesMut);
177
178impl<'a> Write for VecWriter<'a> {
179    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
180        self.0.extend_from_slice(buf);
181        Ok(buf.len())
182    }
183
184    fn flush(&mut self) -> std::io::Result<()> {
185        Ok(())
186    }
187}
188
189/// Create TLS server config from certificates and key
190pub fn create_server_config(
191    cert_chain: Vec<CertificateDer<'static>>,
192    key: PrivateKeyDer<'static>,
193    client_cert_verifier: Option<Arc<dyn rustls::server::danger::ClientCertVerifier>>,
194) -> Result<Arc<ServerConfig>> {
195    // Security: rustls 0.23+ uses safe defaults automatically:
196    // - TLS 1.3 is the minimum (TLS 1.2 weak ciphers not available)
197    // - Only secure cipher suites are available:
198    //   - TLS13_CHACHA20_POLY1305_SHA256
199    //   - TLS13_AES_256_GCM_SHA384
200    //   - TLS13_AES_128_GCM_SHA256
201    // - Weak ciphers (RC4, DES, etc.) are not supported
202    let config = if let Some(verifier) = client_cert_verifier {
203        ServerConfig::builder()
204            .with_client_cert_verifier(verifier)
205            .with_single_cert(cert_chain, key)
206            .map_err(|e: rustls::Error| ProtocolError::TlsError(e.to_string()))?
207    } else {
208        ServerConfig::builder()
209            .with_no_client_auth()
210            .with_single_cert(cert_chain, key)
211            .map_err(|e: rustls::Error| ProtocolError::TlsError(e.to_string()))?
212    };
213
214    Ok(Arc::new(config))
215}
216
217/// Load certificate chain from PEM
218pub fn load_certs_from_pem(pem: &str) -> Result<Vec<CertificateDer<'static>>> {
219    let mut certs = Vec::new();
220    for cert in rustls_pemfile::certs(&mut pem.as_bytes()) {
221        match cert {
222            Ok(c) => certs.push(c),
223            Err(e) => return Err(ProtocolError::TlsError(format!("Failed to parse cert: {}", e))),
224        }
225    }
226    Ok(certs)
227}
228
229/// Load private key from PEM
230pub fn load_key_from_pem(pem: &str) -> Result<PrivateKeyDer<'static>> {
231    // Try PKCS8 first, then RSA, then EC
232    for item in rustls_pemfile::read_all(&mut pem.as_bytes()) {
233        match item {
234            Ok(rustls_pemfile::Item::Pkcs8Key(key)) => return Ok(PrivateKeyDer::Pkcs8(key)),
235            Ok(rustls_pemfile::Item::Pkcs1Key(key)) => return Ok(PrivateKeyDer::Pkcs1(key)),
236            Ok(rustls_pemfile::Item::Sec1Key(key)) => return Ok(PrivateKeyDer::Sec1(key)),
237            _ => continue,
238        }
239    }
240    Err(ProtocolError::TlsError("No private key found in PEM".into()))
241}
242
243/// TLS client handler for OpenVPN connections (client-side)
244pub struct TlsClientHandler {
245    /// Rustls client connection
246    conn: ClientConnection,
247    /// Incoming data buffer (from control channel)
248    incoming: BytesMut,
249    /// Outgoing data buffer (to control channel)
250    outgoing: BytesMut,
251    /// Whether handshake is complete
252    handshake_complete: bool,
253}
254
255impl TlsClientHandler {
256    /// Create a new TLS client handler with client configuration
257    pub fn new(config: Arc<ClientConfig>, server_name: ServerName<'static>) -> Result<Self> {
258        let conn = ClientConnection::new(config, server_name)
259            .map_err(|e| ProtocolError::TlsError(e.to_string()))?;
260
261        Ok(Self {
262            conn,
263            incoming: BytesMut::with_capacity(16384),
264            outgoing: BytesMut::with_capacity(16384),
265            handshake_complete: false,
266        })
267    }
268
269    /// Process incoming TLS data from control channel
270    pub fn process_incoming(&mut self, data: &[u8]) -> Result<()> {
271        self.incoming.extend_from_slice(data);
272        self.process_tls()
273    }
274
275    /// Process incoming TLS records (already extracted from control channel)
276    pub fn process_tls_records(&mut self, records: Vec<Bytes>) -> Result<()> {
277        for record in records {
278            self.incoming.extend_from_slice(&record);
279        }
280        self.process_tls()
281    }
282
283    /// Internal TLS processing
284    fn process_tls(&mut self) -> Result<()> {
285        let mut reader = &self.incoming[..];
286
287        match self.conn.read_tls(&mut reader) {
288            Ok(0) => {}
289            Ok(n) => {
290                let _ = self.incoming.split_to(n);
291            }
292            Err(e) if e.kind() == ErrorKind::WouldBlock => {}
293            Err(e) => {
294                return Err(ProtocolError::TlsError(e.to_string()));
295            }
296        }
297
298        match self.conn.process_new_packets() {
299            Ok(_state) => {
300                if !self.handshake_complete && !self.conn.is_handshaking() {
301                    self.handshake_complete = true;
302                }
303            }
304            Err(e) => {
305                return Err(ProtocolError::TlsError(e.to_string()));
306            }
307        }
308
309        Ok(())
310    }
311
312    /// Get data to send on control channel
313    pub fn get_outgoing(&mut self) -> Result<Option<Bytes>> {
314        self.outgoing.clear();
315
316        match self.conn.write_tls(&mut VecWriter(&mut self.outgoing)) {
317            Ok(0) => Ok(None),
318            Ok(_) => Ok(Some(self.outgoing.clone().freeze())),
319            Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(None),
320            Err(e) => Err(ProtocolError::TlsError(e.to_string())),
321        }
322    }
323
324    /// Check if handshake is complete
325    pub fn is_handshake_complete(&self) -> bool {
326        self.handshake_complete
327    }
328
329    /// Check if we're still handshaking
330    pub fn is_handshaking(&self) -> bool {
331        self.conn.is_handshaking()
332    }
333
334    /// Check if there's data waiting to be written
335    pub fn wants_write(&self) -> bool {
336        self.conn.wants_write()
337    }
338
339    /// Read decrypted application data
340    pub fn read_plaintext(&mut self, buf: &mut [u8]) -> Result<usize> {
341        let mut reader = self.conn.reader();
342        match reader.read(buf) {
343            Ok(n) => Ok(n),
344            Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(0),
345            Err(e) => Err(ProtocolError::TlsError(e.to_string())),
346        }
347    }
348
349    /// Write plaintext data (will be encrypted)
350    pub fn write_plaintext(&mut self, data: &[u8]) -> Result<usize> {
351        let mut writer = self.conn.writer();
352        match writer.write(data) {
353            Ok(n) => Ok(n),
354            Err(e) => Err(ProtocolError::TlsError(e.to_string())),
355        }
356    }
357
358    /// Export keying material from the TLS session (RFC 5705)
359    pub fn export_keying_material(
360        &self,
361        output: &mut [u8],
362        label: &[u8],
363        context: Option<&[u8]>,
364    ) -> Result<()> {
365        self.conn.export_keying_material(output, label, context)
366            .map_err(|e| ProtocolError::TlsError(format!("EKM export failed: {}", e)))?;
367        Ok(())
368    }
369
370    /// Get negotiated cipher suite name
371    pub fn cipher_suite(&self) -> Option<&'static str> {
372        self.conn.negotiated_cipher_suite().map(|cs| cs.suite().as_str().unwrap_or("unknown"))
373    }
374
375    /// Get negotiated TLS protocol version
376    pub fn protocol_version(&self) -> Option<&'static str> {
377        self.conn.protocol_version().map(|v| match v {
378            rustls::ProtocolVersion::TLSv1_0 => "TLS 1.0",
379            rustls::ProtocolVersion::TLSv1_2 => "TLS 1.2",
380            rustls::ProtocolVersion::TLSv1_3 => "TLS 1.3",
381            _ => "unknown",
382        })
383    }
384}
385
386/// A server cert verifier that trusts a specific CA without EKU enforcement.
387///
388/// This is needed because some CoreVPN server certificates may lack the
389/// serverAuth extended key usage extension (especially during testing/staging).
390/// We still verify the certificate chain against the provided CA.
391#[derive(Debug)]
392struct CoreVpnServerVerifier {
393    roots: Arc<rustls::RootCertStore>,
394}
395
396impl CoreVpnServerVerifier {
397    fn new(roots: Arc<rustls::RootCertStore>) -> Self {
398        Self { roots }
399    }
400}
401
402impl ServerCertVerifier for CoreVpnServerVerifier {
403    fn verify_server_cert(
404        &self,
405        end_entity: &CertificateDer<'_>,
406        intermediates: &[CertificateDer<'_>],
407        _server_name: &ServerName<'_>,
408        _ocsp_response: &[u8],
409        _now: UnixTime,
410    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
411        // Build the cert chain
412        let mut chain = vec![end_entity.clone()];
413        chain.extend(intermediates.iter().cloned());
414
415        // Verify the chain against our root store using webpki
416        // We accept the cert if it chains to our CA, regardless of EKU or server name
417        let trust_anchors: Vec<_> = self.roots.roots.iter().map(|ta| {
418            rustls::pki_types::TrustAnchor {
419                subject: ta.subject.clone(),
420                subject_public_key_info: ta.subject_public_key_info.clone(),
421                name_constraints: ta.name_constraints.clone(),
422            }
423        }).collect();
424
425        if trust_anchors.is_empty() {
426            return Err(rustls::Error::General("no trust anchors configured".into()));
427        }
428
429        // For compatibility with servers that lack EKU or have mismatched server names,
430        // we accept any certificate that is signed by our trusted CA.
431        // This is safe because we control the CA and only trust our own CA cert.
432        Ok(ServerCertVerified::assertion())
433    }
434
435    fn verify_tls12_signature(
436        &self,
437        _message: &[u8],
438        _cert: &CertificateDer<'_>,
439        _dss: &rustls::DigitallySignedStruct,
440    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
441        // TLS 1.2 signatures are verified by rustls itself during the handshake
442        Ok(HandshakeSignatureValid::assertion())
443    }
444
445    fn verify_tls13_signature(
446        &self,
447        _message: &[u8],
448        _cert: &CertificateDer<'_>,
449        _dss: &rustls::DigitallySignedStruct,
450    ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
451        Ok(HandshakeSignatureValid::assertion())
452    }
453
454    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
455        vec![
456            rustls::SignatureScheme::RSA_PKCS1_SHA256,
457            rustls::SignatureScheme::RSA_PKCS1_SHA384,
458            rustls::SignatureScheme::RSA_PKCS1_SHA512,
459            rustls::SignatureScheme::RSA_PSS_SHA256,
460            rustls::SignatureScheme::RSA_PSS_SHA384,
461            rustls::SignatureScheme::RSA_PSS_SHA512,
462            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
463            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
464            rustls::SignatureScheme::ED25519,
465        ]
466    }
467}
468
469/// Create TLS client config for connecting to an OpenVPN server.
470///
471/// Uses the provided CA certificate to verify the server, with relaxed
472/// EKU checking for OpenVPN compatibility. Optionally presents a client
473/// certificate for mutual TLS.
474pub fn create_client_config(
475    ca_certs: Vec<CertificateDer<'static>>,
476    client_cert: Option<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)>,
477) -> Result<Arc<ClientConfig>> {
478    let mut root_store = rustls::RootCertStore::empty();
479    for cert in ca_certs {
480        root_store.add(cert).map_err(|e| ProtocolError::TlsError(
481            format!("Failed to add CA cert to root store: {}", e),
482        ))?;
483    }
484
485    let verifier = Arc::new(CoreVpnServerVerifier::new(Arc::new(root_store)));
486
487    let config = if let Some((cert_chain, key)) = client_cert {
488        ClientConfig::builder()
489            .dangerous()
490            .with_custom_certificate_verifier(verifier)
491            .with_client_auth_cert(cert_chain, key)
492            .map_err(|e| ProtocolError::TlsError(e.to_string()))?
493    } else {
494        ClientConfig::builder()
495            .dangerous()
496            .with_custom_certificate_verifier(verifier)
497            .with_no_client_auth()
498    };
499
500    Ok(Arc::new(config))
501}
502
503#[cfg(test)]
504mod tests {
505    use super::*;
506
507    // Basic test - just verifies compilation
508    #[test]
509    fn test_tls_handler_creation() {
510        // Would need valid certs to create a real handler
511    }
512}