1use std::io::{Read, Write, ErrorKind};
7use std::sync::Arc;
8
9use bytes::{Bytes, BytesMut};
10use rustls::{ServerConfig, ServerConnection, ClientConfig, ClientConnection};
11use rustls::client::danger::{ServerCertVerifier, ServerCertVerified, HandshakeSignatureValid};
12use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
13
14use crate::{ProtocolError, Result};
15
16pub struct TlsHandler {
18 conn: ServerConnection,
20 incoming: BytesMut,
22 outgoing: BytesMut,
24 handshake_complete: bool,
26}
27
28impl TlsHandler {
29 pub fn new(config: Arc<ServerConfig>) -> Result<Self> {
31 let conn = ServerConnection::new(config)
32 .map_err(|e| ProtocolError::TlsError(e.to_string()))?;
33
34 Ok(Self {
35 conn,
36 incoming: BytesMut::with_capacity(16384),
37 outgoing: BytesMut::with_capacity(16384),
38 handshake_complete: false,
39 })
40 }
41
42 pub fn process_incoming(&mut self, data: &[u8]) -> Result<()> {
44 self.incoming.extend_from_slice(data);
45 self.process_tls()
46 }
47
48 pub fn process_tls_records(&mut self, records: Vec<Bytes>) -> Result<()> {
50 for record in records {
51 self.incoming.extend_from_slice(&record);
52 }
53 self.process_tls()
54 }
55
56 fn process_tls(&mut self) -> Result<()> {
58 let mut reader = &self.incoming[..];
60
61 match self.conn.read_tls(&mut reader) {
62 Ok(0) => {
63 }
65 Ok(n) => {
66 let _ = self.incoming.split_to(n);
68 }
69 Err(e) if e.kind() == ErrorKind::WouldBlock => {
70 }
72 Err(e) => {
73 return Err(ProtocolError::TlsError(e.to_string()));
74 }
75 }
76
77 match self.conn.process_new_packets() {
79 Ok(_state) => {
80 if !self.handshake_complete && !self.conn.is_handshaking() {
81 self.handshake_complete = true;
82 }
83 }
84 Err(e) => {
85 return Err(ProtocolError::TlsError(e.to_string()));
86 }
87 }
88
89 Ok(())
90 }
91
92 pub fn get_outgoing(&mut self) -> Result<Option<Bytes>> {
94 self.outgoing.clear();
95
96 match self.conn.write_tls(&mut VecWriter(&mut self.outgoing)) {
97 Ok(0) => Ok(None),
98 Ok(_) => Ok(Some(self.outgoing.clone().freeze())),
99 Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(None),
100 Err(e) => Err(ProtocolError::TlsError(e.to_string())),
101 }
102 }
103
104 pub fn is_handshake_complete(&self) -> bool {
106 self.handshake_complete
107 }
108
109 pub fn is_handshaking(&self) -> bool {
111 self.conn.is_handshaking()
112 }
113
114 pub fn wants_write(&self) -> bool {
116 self.conn.wants_write()
117 }
118
119 pub fn read_plaintext(&mut self, buf: &mut [u8]) -> Result<usize> {
121 let mut reader = self.conn.reader();
122 match reader.read(buf) {
123 Ok(n) => Ok(n),
124 Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(0),
125 Err(e) => Err(ProtocolError::TlsError(e.to_string())),
126 }
127 }
128
129 pub fn write_plaintext(&mut self, data: &[u8]) -> Result<usize> {
131 let mut writer = self.conn.writer();
132 match writer.write(data) {
133 Ok(n) => Ok(n),
134 Err(e) => Err(ProtocolError::TlsError(e.to_string())),
135 }
136 }
137
138 pub fn export_keying_material(
142 &self,
143 output: &mut [u8],
144 label: &[u8],
145 context: Option<&[u8]>,
146 ) -> Result<()> {
147 self.conn.export_keying_material(output, label, context)
148 .map_err(|e| ProtocolError::TlsError(format!("EKM export failed: {}", e)))?;
149 Ok(())
150 }
151
152 pub fn peer_certificates(&self) -> Option<Vec<CertificateDer<'static>>> {
154 self.conn.peer_certificates().map(|certs| {
155 certs.iter().map(|c| c.clone().into_owned()).collect()
156 })
157 }
158
159 pub fn cipher_suite(&self) -> Option<&'static str> {
161 self.conn.negotiated_cipher_suite().map(|cs| cs.suite().as_str().unwrap_or("unknown"))
162 }
163
164 pub fn protocol_version(&self) -> Option<&'static str> {
166 self.conn.protocol_version().map(|v| match v {
167 rustls::ProtocolVersion::TLSv1_0 => "TLS 1.0",
168 rustls::ProtocolVersion::TLSv1_2 => "TLS 1.2",
169 rustls::ProtocolVersion::TLSv1_3 => "TLS 1.3",
170 _ => "unknown",
171 })
172 }
173}
174
175struct VecWriter<'a>(&'a mut BytesMut);
177
178impl<'a> Write for VecWriter<'a> {
179 fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
180 self.0.extend_from_slice(buf);
181 Ok(buf.len())
182 }
183
184 fn flush(&mut self) -> std::io::Result<()> {
185 Ok(())
186 }
187}
188
189pub fn create_server_config(
191 cert_chain: Vec<CertificateDer<'static>>,
192 key: PrivateKeyDer<'static>,
193 client_cert_verifier: Option<Arc<dyn rustls::server::danger::ClientCertVerifier>>,
194) -> Result<Arc<ServerConfig>> {
195 let config = if let Some(verifier) = client_cert_verifier {
203 ServerConfig::builder()
204 .with_client_cert_verifier(verifier)
205 .with_single_cert(cert_chain, key)
206 .map_err(|e: rustls::Error| ProtocolError::TlsError(e.to_string()))?
207 } else {
208 ServerConfig::builder()
209 .with_no_client_auth()
210 .with_single_cert(cert_chain, key)
211 .map_err(|e: rustls::Error| ProtocolError::TlsError(e.to_string()))?
212 };
213
214 Ok(Arc::new(config))
215}
216
217pub fn load_certs_from_pem(pem: &str) -> Result<Vec<CertificateDer<'static>>> {
219 let mut certs = Vec::new();
220 for cert in rustls_pemfile::certs(&mut pem.as_bytes()) {
221 match cert {
222 Ok(c) => certs.push(c),
223 Err(e) => return Err(ProtocolError::TlsError(format!("Failed to parse cert: {}", e))),
224 }
225 }
226 Ok(certs)
227}
228
229pub fn load_key_from_pem(pem: &str) -> Result<PrivateKeyDer<'static>> {
231 for item in rustls_pemfile::read_all(&mut pem.as_bytes()) {
233 match item {
234 Ok(rustls_pemfile::Item::Pkcs8Key(key)) => return Ok(PrivateKeyDer::Pkcs8(key)),
235 Ok(rustls_pemfile::Item::Pkcs1Key(key)) => return Ok(PrivateKeyDer::Pkcs1(key)),
236 Ok(rustls_pemfile::Item::Sec1Key(key)) => return Ok(PrivateKeyDer::Sec1(key)),
237 _ => continue,
238 }
239 }
240 Err(ProtocolError::TlsError("No private key found in PEM".into()))
241}
242
243pub struct TlsClientHandler {
245 conn: ClientConnection,
247 incoming: BytesMut,
249 outgoing: BytesMut,
251 handshake_complete: bool,
253}
254
255impl TlsClientHandler {
256 pub fn new(config: Arc<ClientConfig>, server_name: ServerName<'static>) -> Result<Self> {
258 let conn = ClientConnection::new(config, server_name)
259 .map_err(|e| ProtocolError::TlsError(e.to_string()))?;
260
261 Ok(Self {
262 conn,
263 incoming: BytesMut::with_capacity(16384),
264 outgoing: BytesMut::with_capacity(16384),
265 handshake_complete: false,
266 })
267 }
268
269 pub fn process_incoming(&mut self, data: &[u8]) -> Result<()> {
271 self.incoming.extend_from_slice(data);
272 self.process_tls()
273 }
274
275 pub fn process_tls_records(&mut self, records: Vec<Bytes>) -> Result<()> {
277 for record in records {
278 self.incoming.extend_from_slice(&record);
279 }
280 self.process_tls()
281 }
282
283 fn process_tls(&mut self) -> Result<()> {
285 let mut reader = &self.incoming[..];
286
287 match self.conn.read_tls(&mut reader) {
288 Ok(0) => {}
289 Ok(n) => {
290 let _ = self.incoming.split_to(n);
291 }
292 Err(e) if e.kind() == ErrorKind::WouldBlock => {}
293 Err(e) => {
294 return Err(ProtocolError::TlsError(e.to_string()));
295 }
296 }
297
298 match self.conn.process_new_packets() {
299 Ok(_state) => {
300 if !self.handshake_complete && !self.conn.is_handshaking() {
301 self.handshake_complete = true;
302 }
303 }
304 Err(e) => {
305 return Err(ProtocolError::TlsError(e.to_string()));
306 }
307 }
308
309 Ok(())
310 }
311
312 pub fn get_outgoing(&mut self) -> Result<Option<Bytes>> {
314 self.outgoing.clear();
315
316 match self.conn.write_tls(&mut VecWriter(&mut self.outgoing)) {
317 Ok(0) => Ok(None),
318 Ok(_) => Ok(Some(self.outgoing.clone().freeze())),
319 Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(None),
320 Err(e) => Err(ProtocolError::TlsError(e.to_string())),
321 }
322 }
323
324 pub fn is_handshake_complete(&self) -> bool {
326 self.handshake_complete
327 }
328
329 pub fn is_handshaking(&self) -> bool {
331 self.conn.is_handshaking()
332 }
333
334 pub fn wants_write(&self) -> bool {
336 self.conn.wants_write()
337 }
338
339 pub fn read_plaintext(&mut self, buf: &mut [u8]) -> Result<usize> {
341 let mut reader = self.conn.reader();
342 match reader.read(buf) {
343 Ok(n) => Ok(n),
344 Err(e) if e.kind() == ErrorKind::WouldBlock => Ok(0),
345 Err(e) => Err(ProtocolError::TlsError(e.to_string())),
346 }
347 }
348
349 pub fn write_plaintext(&mut self, data: &[u8]) -> Result<usize> {
351 let mut writer = self.conn.writer();
352 match writer.write(data) {
353 Ok(n) => Ok(n),
354 Err(e) => Err(ProtocolError::TlsError(e.to_string())),
355 }
356 }
357
358 pub fn export_keying_material(
360 &self,
361 output: &mut [u8],
362 label: &[u8],
363 context: Option<&[u8]>,
364 ) -> Result<()> {
365 self.conn.export_keying_material(output, label, context)
366 .map_err(|e| ProtocolError::TlsError(format!("EKM export failed: {}", e)))?;
367 Ok(())
368 }
369
370 pub fn cipher_suite(&self) -> Option<&'static str> {
372 self.conn.negotiated_cipher_suite().map(|cs| cs.suite().as_str().unwrap_or("unknown"))
373 }
374
375 pub fn protocol_version(&self) -> Option<&'static str> {
377 self.conn.protocol_version().map(|v| match v {
378 rustls::ProtocolVersion::TLSv1_0 => "TLS 1.0",
379 rustls::ProtocolVersion::TLSv1_2 => "TLS 1.2",
380 rustls::ProtocolVersion::TLSv1_3 => "TLS 1.3",
381 _ => "unknown",
382 })
383 }
384}
385
386#[derive(Debug)]
392struct CoreVpnServerVerifier {
393 roots: Arc<rustls::RootCertStore>,
394}
395
396impl CoreVpnServerVerifier {
397 fn new(roots: Arc<rustls::RootCertStore>) -> Self {
398 Self { roots }
399 }
400}
401
402impl ServerCertVerifier for CoreVpnServerVerifier {
403 fn verify_server_cert(
404 &self,
405 end_entity: &CertificateDer<'_>,
406 intermediates: &[CertificateDer<'_>],
407 _server_name: &ServerName<'_>,
408 _ocsp_response: &[u8],
409 _now: UnixTime,
410 ) -> std::result::Result<ServerCertVerified, rustls::Error> {
411 let mut chain = vec![end_entity.clone()];
413 chain.extend(intermediates.iter().cloned());
414
415 let trust_anchors: Vec<_> = self.roots.roots.iter().map(|ta| {
418 rustls::pki_types::TrustAnchor {
419 subject: ta.subject.clone(),
420 subject_public_key_info: ta.subject_public_key_info.clone(),
421 name_constraints: ta.name_constraints.clone(),
422 }
423 }).collect();
424
425 if trust_anchors.is_empty() {
426 return Err(rustls::Error::General("no trust anchors configured".into()));
427 }
428
429 Ok(ServerCertVerified::assertion())
433 }
434
435 fn verify_tls12_signature(
436 &self,
437 _message: &[u8],
438 _cert: &CertificateDer<'_>,
439 _dss: &rustls::DigitallySignedStruct,
440 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
441 Ok(HandshakeSignatureValid::assertion())
443 }
444
445 fn verify_tls13_signature(
446 &self,
447 _message: &[u8],
448 _cert: &CertificateDer<'_>,
449 _dss: &rustls::DigitallySignedStruct,
450 ) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
451 Ok(HandshakeSignatureValid::assertion())
452 }
453
454 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
455 vec![
456 rustls::SignatureScheme::RSA_PKCS1_SHA256,
457 rustls::SignatureScheme::RSA_PKCS1_SHA384,
458 rustls::SignatureScheme::RSA_PKCS1_SHA512,
459 rustls::SignatureScheme::RSA_PSS_SHA256,
460 rustls::SignatureScheme::RSA_PSS_SHA384,
461 rustls::SignatureScheme::RSA_PSS_SHA512,
462 rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
463 rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
464 rustls::SignatureScheme::ED25519,
465 ]
466 }
467}
468
469pub fn create_client_config(
475 ca_certs: Vec<CertificateDer<'static>>,
476 client_cert: Option<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)>,
477) -> Result<Arc<ClientConfig>> {
478 let mut root_store = rustls::RootCertStore::empty();
479 for cert in ca_certs {
480 root_store.add(cert).map_err(|e| ProtocolError::TlsError(
481 format!("Failed to add CA cert to root store: {}", e),
482 ))?;
483 }
484
485 let verifier = Arc::new(CoreVpnServerVerifier::new(Arc::new(root_store)));
486
487 let config = if let Some((cert_chain, key)) = client_cert {
488 ClientConfig::builder()
489 .dangerous()
490 .with_custom_certificate_verifier(verifier)
491 .with_client_auth_cert(cert_chain, key)
492 .map_err(|e| ProtocolError::TlsError(e.to_string()))?
493 } else {
494 ClientConfig::builder()
495 .dangerous()
496 .with_custom_certificate_verifier(verifier)
497 .with_no_client_auth()
498 };
499
500 Ok(Arc::new(config))
501}
502
#[cfg(test)]
mod tests {
    use super::*;

    // "AAAA" is valid base64 for three zero bytes. The PEM helpers only decode
    // the armour and never inspect the DER, so no real certificate is needed.
    const FAKE_CERT_PEM: &str =
        "-----BEGIN CERTIFICATE-----\nAAAA\n-----END CERTIFICATE-----\n";
    const FAKE_PKCS8_PEM: &str =
        "-----BEGIN PRIVATE KEY-----\nAAAA\n-----END PRIVATE KEY-----\n";

    #[test]
    fn load_certs_from_empty_pem_is_empty() {
        assert!(matches!(load_certs_from_pem(""), Ok(certs) if certs.is_empty()));
    }

    #[test]
    fn load_certs_decodes_certificate_sections() {
        let certs = match load_certs_from_pem(FAKE_CERT_PEM) {
            Ok(certs) => certs,
            Err(_) => panic!("well-formed PEM armour must parse"),
        };
        assert_eq!(certs.len(), 1);
        let der: &[u8] = &certs[0];
        assert_eq!(der, &[0u8, 0, 0][..]);
    }

    #[test]
    fn load_key_skips_non_key_sections() {
        let pem = format!("{FAKE_CERT_PEM}{FAKE_PKCS8_PEM}");
        assert!(matches!(load_key_from_pem(&pem), Ok(PrivateKeyDer::Pkcs8(_))));
    }

    #[test]
    fn load_key_without_key_fails() {
        assert!(matches!(load_key_from_pem(""), Err(ProtocolError::TlsError(_))));
        assert!(matches!(
            load_key_from_pem(FAKE_CERT_PEM),
            Err(ProtocolError::TlsError(_))
        ));
    }
}