Skip to main content

corevpn_protocol/
tls.rs

1//! TLS Integration for OpenVPN Control Channel
2//!
3//! Bridges rustls with the OpenVPN control channel transport.
4
5use std::io::{Read, Write, ErrorKind};
6use std::sync::Arc;
7
8use bytes::{Bytes, BytesMut};
9use rustls::{ServerConfig, ServerConnection};
10use rustls::pki_types::{CertificateDer, PrivateKeyDer};
11
12use crate::{ProtocolError, Result};
13
14/// TLS handler for OpenVPN connections
15pub struct TlsHandler {
16    /// Rustls server connection
17    conn: ServerConnection,
18    /// Incoming data buffer (from control channel)
19    incoming: BytesMut,
20    /// Outgoing data buffer (to control channel)
21    outgoing: BytesMut,
22    /// Whether handshake is complete
23    handshake_complete: bool,
24}
25
26impl TlsHandler {
27    /// Create a new TLS handler with server configuration
28    pub fn new(config: Arc<ServerConfig>) -> Result<Self> {
29        let conn = ServerConnection::new(config)
30            .map_err(|e| ProtocolError::TlsError(e.to_string()))?;
31
32        Ok(Self {
33            conn,
34            incoming: BytesMut::with_capacity(16384),
35            outgoing: BytesMut::with_capacity(16384),
36            handshake_complete: false,
37        })
38    }
39
40    /// Process incoming TLS data from control channel
41    pub fn process_incoming(&mut self, data: &[u8]) -> Result<()> {
42        self.incoming.extend_from_slice(data);
43        self.process_tls()
44    }
45
46    /// Process incoming TLS records (already extracted from control channel)
47    pub fn process_tls_records(&mut self, records: Vec<Bytes>) -> Result<()> {
48        for record in records {
49            self.incoming.extend_from_slice(&record);
50        }
51        self.process_tls()
52    }
53
54    /// Internal TLS processing
55    fn process_tls(&mut self) -> Result<()> {
56        // Create a cursor for reading
57        let mut reader = &self.incoming[..];
58
59        match self.conn.read_tls(&mut reader) {
60            Ok(0) => {
61                // No data read
62            }
63            Ok(n) => {
64                // Remove consumed data
65                let _ = self.incoming.split_to(n);
66            }
67            Err(e) if e.kind() == ErrorKind::WouldBlock => {
68                // Need more data
69            }
70            Err(e) => {
71                return Err(ProtocolError::TlsError(e.to_string()));
72            }
73        }
74
75        // Process any TLS state changes
76        match self.conn.process_new_packets() {
77            Ok(_state) => {
78                if !self.handshake_complete && !self.conn.is_handshaking() {
79                    self.handshake_complete = true;
80                }
81            }
82            Err(e) => {
83                return Err(ProtocolError::TlsError(e.to_string()));
84            }
85        }
86
87        Ok(())
88    }
89
90    /// Get data to send on control channel
91    pub fn get_outgoing(&mut self) -> Result<Option<Bytes>> {
92        self.outgoing.clear();
93
94        match self.conn.write_tls(&mut VecWriter(&mut self.outgoing)) {
95            Ok(0) => Ok(None),
96            Ok(_) => Ok(Some(self.outgoing.clone().freeze())),
97            Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(None),
98            Err(e) => Err(ProtocolError::TlsError(e.to_string())),
99        }
100    }
101
102    /// Check if handshake is complete
103    pub fn is_handshake_complete(&self) -> bool {
104        self.handshake_complete
105    }
106
107    /// Check if we're still handshaking
108    pub fn is_handshaking(&self) -> bool {
109        self.conn.is_handshaking()
110    }
111
112    /// Check if there's data waiting to be written
113    pub fn wants_write(&self) -> bool {
114        self.conn.wants_write()
115    }
116
117    /// Read decrypted application data
118    pub fn read_plaintext(&mut self, buf: &mut [u8]) -> Result<usize> {
119        let mut reader = self.conn.reader();
120        match reader.read(buf) {
121            Ok(n) => Ok(n),
122            Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(0),
123            Err(e) => Err(ProtocolError::TlsError(e.to_string())),
124        }
125    }
126
127    /// Write plaintext data (will be encrypted)
128    pub fn write_plaintext(&mut self, data: &[u8]) -> Result<usize> {
129        let mut writer = self.conn.writer();
130        match writer.write(data) {
131            Ok(n) => Ok(n),
132            Err(e) => Err(ProtocolError::TlsError(e.to_string())),
133        }
134    }
135
136    /// Export keying material from the TLS session (RFC 5705)
137    ///
138    /// Used by OpenVPN for data channel key derivation with TLS 1.3.
139    pub fn export_keying_material(
140        &self,
141        output: &mut [u8],
142        label: &[u8],
143        context: Option<&[u8]>,
144    ) -> Result<()> {
145        self.conn.export_keying_material(output, label, context)
146            .map_err(|e| ProtocolError::TlsError(format!("EKM export failed: {}", e)))?;
147        Ok(())
148    }
149
150    /// Get peer certificate if available
151    pub fn peer_certificates(&self) -> Option<Vec<CertificateDer<'static>>> {
152        self.conn.peer_certificates().map(|certs| {
153            certs.iter().map(|c| c.clone().into_owned()).collect()
154        })
155    }
156
157    /// Get negotiated cipher suite name
158    pub fn cipher_suite(&self) -> Option<&'static str> {
159        self.conn.negotiated_cipher_suite().map(|cs| cs.suite().as_str().unwrap_or("unknown"))
160    }
161}
162
163/// Helper to write to BytesMut
164struct VecWriter<'a>(&'a mut BytesMut);
165
166impl<'a> Write for VecWriter<'a> {
167    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
168        self.0.extend_from_slice(buf);
169        Ok(buf.len())
170    }
171
172    fn flush(&mut self) -> std::io::Result<()> {
173        Ok(())
174    }
175}
176
177/// Create TLS server config from certificates and key
178pub fn create_server_config(
179    cert_chain: Vec<CertificateDer<'static>>,
180    key: PrivateKeyDer<'static>,
181    client_cert_verifier: Option<Arc<dyn rustls::server::danger::ClientCertVerifier>>,
182) -> Result<Arc<ServerConfig>> {
183    // Security: rustls 0.23+ uses safe defaults automatically:
184    // - TLS 1.3 is the minimum (TLS 1.2 weak ciphers not available)
185    // - Only secure cipher suites are available:
186    //   - TLS13_CHACHA20_POLY1305_SHA256
187    //   - TLS13_AES_256_GCM_SHA384
188    //   - TLS13_AES_128_GCM_SHA256
189    // - Weak ciphers (RC4, DES, etc.) are not supported
190    let config = if let Some(verifier) = client_cert_verifier {
191        ServerConfig::builder()
192            .with_client_cert_verifier(verifier)
193            .with_single_cert(cert_chain, key)
194            .map_err(|e: rustls::Error| ProtocolError::TlsError(e.to_string()))?
195    } else {
196        ServerConfig::builder()
197            .with_no_client_auth()
198            .with_single_cert(cert_chain, key)
199            .map_err(|e: rustls::Error| ProtocolError::TlsError(e.to_string()))?
200    };
201
202    Ok(Arc::new(config))
203}
204
205/// Load certificate chain from PEM
206pub fn load_certs_from_pem(pem: &str) -> Result<Vec<CertificateDer<'static>>> {
207    let mut certs = Vec::new();
208    for cert in rustls_pemfile::certs(&mut pem.as_bytes()) {
209        match cert {
210            Ok(c) => certs.push(c),
211            Err(e) => return Err(ProtocolError::TlsError(format!("Failed to parse cert: {}", e))),
212        }
213    }
214    Ok(certs)
215}
216
217/// Load private key from PEM
218pub fn load_key_from_pem(pem: &str) -> Result<PrivateKeyDer<'static>> {
219    // Try PKCS8 first, then RSA, then EC
220    for item in rustls_pemfile::read_all(&mut pem.as_bytes()) {
221        match item {
222            Ok(rustls_pemfile::Item::Pkcs8Key(key)) => return Ok(PrivateKeyDer::Pkcs8(key)),
223            Ok(rustls_pemfile::Item::Pkcs1Key(key)) => return Ok(PrivateKeyDer::Pkcs1(key)),
224            Ok(rustls_pemfile::Item::Sec1Key(key)) => return Ok(PrivateKeyDer::Sec1(key)),
225            _ => continue,
226        }
227    }
228    Err(ProtocolError::TlsError("No private key found in PEM".into()))
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    // A real `TlsHandler` needs a valid certificate and key, so these tests
    // cover the helpers that can be exercised without key material.

    #[test]
    fn load_certs_from_empty_pem_is_empty() {
        assert!(matches!(load_certs_from_pem(""), Ok(ref certs) if certs.is_empty()));
    }

    #[test]
    fn load_key_from_pem_without_key_fails() {
        assert!(matches!(load_key_from_pem(""), Err(ProtocolError::TlsError(_))));
    }

    #[test]
    fn vec_writer_appends_written_bytes() {
        let mut buf = BytesMut::new();
        let mut writer = VecWriter(&mut buf);
        assert_eq!(writer.write(b"abc").ok(), Some(3));
        assert!(writer.write_all(b"de").is_ok());
        assert!(writer.flush().is_ok());
        assert_eq!(&buf[..], b"abcde");
    }
}