corevpn_protocol/
tls.rs

1//! TLS Integration for OpenVPN Control Channel
2//!
3//! Bridges rustls with the OpenVPN control channel transport.
4
5use std::io::{Read, Write, ErrorKind};
6use std::sync::Arc;
7
8use bytes::{Bytes, BytesMut};
9use rustls::{ServerConfig, ServerConnection};
10use rustls::pki_types::{CertificateDer, PrivateKeyDer};
11
12use crate::{ProtocolError, Result};
13
14/// TLS handler for OpenVPN connections
15pub struct TlsHandler {
16    /// Rustls server connection
17    conn: ServerConnection,
18    /// Incoming data buffer (from control channel)
19    incoming: BytesMut,
20    /// Outgoing data buffer (to control channel)
21    outgoing: BytesMut,
22    /// Whether handshake is complete
23    handshake_complete: bool,
24}
25
26impl TlsHandler {
27    /// Create a new TLS handler with server configuration
28    pub fn new(config: Arc<ServerConfig>) -> Result<Self> {
29        let conn = ServerConnection::new(config)
30            .map_err(|e| ProtocolError::TlsError(e.to_string()))?;
31
32        Ok(Self {
33            conn,
34            incoming: BytesMut::with_capacity(16384),
35            outgoing: BytesMut::with_capacity(16384),
36            handshake_complete: false,
37        })
38    }
39
40    /// Process incoming TLS data from control channel
41    pub fn process_incoming(&mut self, data: &[u8]) -> Result<()> {
42        self.incoming.extend_from_slice(data);
43        self.process_tls()
44    }
45
46    /// Process incoming TLS records (already extracted from control channel)
47    pub fn process_tls_records(&mut self, records: Vec<Bytes>) -> Result<()> {
48        for record in records {
49            self.incoming.extend_from_slice(&record);
50        }
51        self.process_tls()
52    }
53
54    /// Internal TLS processing
55    fn process_tls(&mut self) -> Result<()> {
56        // Create a cursor for reading
57        let mut reader = &self.incoming[..];
58
59        match self.conn.read_tls(&mut reader) {
60            Ok(0) => {
61                // No data read
62            }
63            Ok(n) => {
64                // Remove consumed data
65                let _ = self.incoming.split_to(n);
66            }
67            Err(e) if e.kind() == ErrorKind::WouldBlock => {
68                // Need more data
69            }
70            Err(e) => {
71                return Err(ProtocolError::TlsError(e.to_string()));
72            }
73        }
74
75        // Process any TLS state changes
76        match self.conn.process_new_packets() {
77            Ok(_state) => {
78                if !self.handshake_complete && !self.conn.is_handshaking() {
79                    self.handshake_complete = true;
80                }
81            }
82            Err(e) => {
83                return Err(ProtocolError::TlsError(e.to_string()));
84            }
85        }
86
87        Ok(())
88    }
89
90    /// Get data to send on control channel
91    pub fn get_outgoing(&mut self) -> Result<Option<Bytes>> {
92        self.outgoing.clear();
93
94        match self.conn.write_tls(&mut VecWriter(&mut self.outgoing)) {
95            Ok(0) => Ok(None),
96            Ok(_) => Ok(Some(self.outgoing.clone().freeze())),
97            Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(None),
98            Err(e) => Err(ProtocolError::TlsError(e.to_string())),
99        }
100    }
101
102    /// Check if handshake is complete
103    pub fn is_handshake_complete(&self) -> bool {
104        self.handshake_complete
105    }
106
107    /// Check if we're still handshaking
108    pub fn is_handshaking(&self) -> bool {
109        self.conn.is_handshaking()
110    }
111
112    /// Check if there's data waiting to be written
113    pub fn wants_write(&self) -> bool {
114        self.conn.wants_write()
115    }
116
117    /// Read decrypted application data
118    pub fn read_plaintext(&mut self, buf: &mut [u8]) -> Result<usize> {
119        let mut reader = self.conn.reader();
120        match reader.read(buf) {
121            Ok(n) => Ok(n),
122            Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(0),
123            Err(e) => Err(ProtocolError::TlsError(e.to_string())),
124        }
125    }
126
127    /// Write plaintext data (will be encrypted)
128    pub fn write_plaintext(&mut self, data: &[u8]) -> Result<usize> {
129        let mut writer = self.conn.writer();
130        match writer.write(data) {
131            Ok(n) => Ok(n),
132            Err(e) => Err(ProtocolError::TlsError(e.to_string())),
133        }
134    }
135
136    /// Get peer certificate if available
137    pub fn peer_certificates(&self) -> Option<Vec<CertificateDer<'static>>> {
138        self.conn.peer_certificates().map(|certs| {
139            certs.iter().map(|c| c.clone().into_owned()).collect()
140        })
141    }
142
143    /// Get negotiated cipher suite name
144    pub fn cipher_suite(&self) -> Option<&'static str> {
145        self.conn.negotiated_cipher_suite().map(|cs| cs.suite().as_str().unwrap_or("unknown"))
146    }
147}
148
149/// Helper to write to BytesMut
150struct VecWriter<'a>(&'a mut BytesMut);
151
152impl<'a> Write for VecWriter<'a> {
153    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
154        self.0.extend_from_slice(buf);
155        Ok(buf.len())
156    }
157
158    fn flush(&mut self) -> std::io::Result<()> {
159        Ok(())
160    }
161}
162
163/// Create TLS server config from certificates and key
164pub fn create_server_config(
165    cert_chain: Vec<CertificateDer<'static>>,
166    key: PrivateKeyDer<'static>,
167    client_cert_verifier: Option<Arc<dyn rustls::server::danger::ClientCertVerifier>>,
168) -> Result<Arc<ServerConfig>> {
169
170
171
172    let config = if let Some(verifier) = client_cert_verifier {
173        ServerConfig::builder()
174            .with_client_cert_verifier(verifier)
175            .with_single_cert(cert_chain, key)
176            .map_err(|e| ProtocolError::TlsError(e.to_string()))?
177    } else {
178        ServerConfig::builder()
179            .with_no_client_auth()
180            .with_single_cert(cert_chain, key)
181            .map_err(|e| ProtocolError::TlsError(e.to_string()))?
182    };
183
184    // Configure for maximum security
185    // TLS 1.3 is already the minimum in rustls 0.23
186
187    Ok(Arc::new(config))
188}
189
190/// Load certificate chain from PEM
191pub fn load_certs_from_pem(pem: &str) -> Result<Vec<CertificateDer<'static>>> {
192    let mut certs = Vec::new();
193    for cert in rustls_pemfile::certs(&mut pem.as_bytes()) {
194        match cert {
195            Ok(c) => certs.push(c),
196            Err(e) => return Err(ProtocolError::TlsError(format!("Failed to parse cert: {}", e))),
197        }
198    }
199    Ok(certs)
200}
201
202/// Load private key from PEM
203pub fn load_key_from_pem(pem: &str) -> Result<PrivateKeyDer<'static>> {
204    // Try PKCS8 first, then RSA, then EC
205    for item in rustls_pemfile::read_all(&mut pem.as_bytes()) {
206        match item {
207            Ok(rustls_pemfile::Item::Pkcs8Key(key)) => return Ok(PrivateKeyDer::Pkcs8(key)),
208            Ok(rustls_pemfile::Item::Pkcs1Key(key)) => return Ok(PrivateKeyDer::Pkcs1(key)),
209            Ok(rustls_pemfile::Item::Sec1Key(key)) => return Ok(PrivateKeyDer::Sec1(key)),
210            _ => continue,
211        }
212    }
213    Err(ProtocolError::TlsError("No private key found in PEM".into()))
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    // Building a `TlsHandler` needs a valid certificate/key pair (see
    // `create_server_config`), so these tests cover the PEM loaders only.

    /// A well-formed PEM certificate section whose payload is a few zero bytes.
    const DUMMY_CERT_PEM: &str =
        "-----BEGIN CERTIFICATE-----\nAAAA\n-----END CERTIFICATE-----\n";

    #[test]
    fn certs_from_empty_pem_is_empty_chain() {
        assert!(matches!(load_certs_from_pem(""), Ok(certs) if certs.is_empty()));
    }

    #[test]
    fn certs_from_single_section() {
        assert!(matches!(load_certs_from_pem(DUMMY_CERT_PEM), Ok(certs) if certs.len() == 1));
    }

    #[test]
    fn key_from_empty_pem_is_error() {
        assert!(load_key_from_pem("").is_err());
    }

    #[test]
    fn key_from_certificate_only_pem_is_error() {
        assert!(load_key_from_pem(DUMMY_CERT_PEM).is_err());
    }
}