ant_quic/crypto/
tls_extension_simulation.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8//! TLS Extension Simulation for RFC 7250 Raw Public Keys
9//!
10//! Since rustls 0.23.x doesn't expose APIs for custom TLS extensions,
11//! this module simulates the RFC 7250 certificate type negotiation
12//! through alternative mechanisms that work within rustls constraints.
13
14use crate::crypto::{ClientConfig as QuicClientConfig, ServerConfig as QuicServerConfig};
15use rustls::{ClientConfig, ServerConfig};
16use std::collections::HashMap;
17use std::sync::{Arc, Mutex};
18
19use super::tls_extensions::{
20    CertificateTypeList, CertificateTypePreferences, NegotiationResult, TlsExtensionError,
21};
22
/// Hooks into TLS handshake events that drive the simulated RFC 7250
/// certificate type negotiation.
///
/// Implementors are shared as `Arc<dyn TlsExtensionHooks>` (for example by
/// `ExtensionAwareTlsSession`), hence the `Send + Sync` bound. Every method is
/// keyed by a caller-chosen connection identifier.
pub trait TlsExtensionHooks: Send + Sync {
    /// Called when the handshake for `conn_id` has completed.
    ///
    /// `is_client` is `true` when the local endpoint is the TLS client.
    fn on_handshake_complete(&self, conn_id: &str, is_client: bool);

    /// Returns the `(extension_type, payload)` pairs to place in the
    /// ClientHello for `conn_id`.
    fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)>;

    /// Processes the `(extension_type, payload)` pairs received in the
    /// ServerHello for `conn_id`.
    ///
    /// # Errors
    ///
    /// Returns a [`TlsExtensionError`] when the extensions cannot be applied to
    /// this connection.
    fn process_server_hello_extensions(
        &self,
        conn_id: &str,
        extensions: &[(u16, Vec<u8>)],
    ) -> Result<(), TlsExtensionError>;

    /// Returns the negotiated certificate types for `conn_id`, if available.
    fn get_negotiation_result(&self, conn_id: &str) -> Option<NegotiationResult>;
}
41
/// Simulated TLS extension context for RFC 7250 certificate type negotiation.
///
/// Keeps one negotiation record per connection identifier in a mutex-guarded
/// map. Every new negotiation starts from a copy of this endpoint's local
/// preferences.
#[derive(Debug)]
pub struct SimulatedExtensionContext {
    /// Active negotiations indexed by connection ID.
    negotiations: Arc<Mutex<HashMap<String, NegotiationState>>>,
    /// Local preferences for this endpoint, cloned into each new negotiation.
    local_preferences: CertificateTypePreferences,
}
50
/// Per-connection negotiation record held by `SimulatedExtensionContext`.
#[derive(Debug, Clone)]
struct NegotiationState {
    /// Local preferences captured when the connection was registered.
    local_preferences: CertificateTypePreferences,
    /// Peer's `client_certificate_type` list, once received.
    remote_client_types: Option<CertificateTypeList>,
    /// Peer's `server_certificate_type` list, once received.
    remote_server_types: Option<CertificateTypeList>,
    /// Cached outcome of `complete_negotiation`, if it has already run.
    result: Option<NegotiationResult>,
}
58
59impl SimulatedExtensionContext {
60    /// Create a new simulated extension context
61    pub fn new(preferences: CertificateTypePreferences) -> Self {
62        Self {
63            negotiations: Arc::new(Mutex::new(HashMap::new())),
64            local_preferences: preferences,
65        }
66    }
67
68    /// Simulate sending certificate type preferences
69    /// In reality, this would be sent in ClientHello/ServerHello extensions
70    #[allow(clippy::unwrap_used, clippy::expect_used)]
71    pub fn simulate_send_preferences(&self, conn_id: &str) -> (Option<Vec<u8>>, Option<Vec<u8>>) {
72        let mut negotiations = self
73            .negotiations
74            .lock()
75            .expect("Mutex poisoning is unexpected in normal operation");
76
77        let state = NegotiationState {
78            local_preferences: self.local_preferences.clone(),
79            remote_client_types: None,
80            remote_server_types: None,
81            result: None,
82        };
83
84        negotiations.insert(conn_id.to_string(), state);
85
86        // Simulate extension data that would be sent
87        let client_ext_data = self.local_preferences.client_types.to_bytes();
88        let server_ext_data = self.local_preferences.server_types.to_bytes();
89
90        (Some(client_ext_data), Some(server_ext_data))
91    }
92
93    /// Simulate receiving certificate type preferences from peer
94    #[allow(clippy::unwrap_used, clippy::expect_used)]
95    pub fn simulate_receive_preferences(
96        &self,
97        conn_id: &str,
98        client_types_data: Option<&[u8]>,
99        server_types_data: Option<&[u8]>,
100    ) -> Result<(), TlsExtensionError> {
101        let mut negotiations = self
102            .negotiations
103            .lock()
104            .expect("Mutex poisoning is unexpected in normal operation");
105
106        let state = negotiations.get_mut(conn_id).ok_or_else(|| {
107            TlsExtensionError::InvalidExtensionData(format!(
108                "No negotiation state for connection {conn_id}"
109            ))
110        })?;
111
112        if let Some(data) = client_types_data {
113            state.remote_client_types = Some(CertificateTypeList::from_bytes(data)?);
114        }
115
116        if let Some(data) = server_types_data {
117            state.remote_server_types = Some(CertificateTypeList::from_bytes(data)?);
118        }
119
120        Ok(())
121    }
122
123    /// Complete the negotiation and get the result
124    #[allow(clippy::unwrap_used, clippy::expect_used)]
125    pub fn complete_negotiation(
126        &self,
127        conn_id: &str,
128    ) -> Result<NegotiationResult, TlsExtensionError> {
129        let mut negotiations = self
130            .negotiations
131            .lock()
132            .expect("Mutex poisoning is unexpected in normal operation");
133
134        let state = negotiations.get_mut(conn_id).ok_or_else(|| {
135            TlsExtensionError::InvalidExtensionData(format!(
136                "No negotiation state for connection {conn_id}"
137            ))
138        })?;
139
140        if let Some(result) = &state.result {
141            return Ok(result.clone());
142        }
143
144        let result = state.local_preferences.negotiate(
145            state.remote_client_types.as_ref(),
146            state.remote_server_types.as_ref(),
147        )?;
148
149        state.result = Some(result.clone());
150        Ok(result)
151    }
152
153    /// Clean up negotiation state for a connection
154    #[allow(clippy::unwrap_used, clippy::expect_used)]
155    pub fn cleanup_connection(&self, conn_id: &str) {
156        let mut negotiations = self
157            .negotiations
158            .lock()
159            .expect("Mutex poisoning is unexpected in normal operation");
160        negotiations.remove(conn_id);
161    }
162}
163
164impl TlsExtensionHooks for SimulatedExtensionContext {
165    fn on_handshake_complete(&self, conn_id: &str, _is_client: bool) {
166        // Try to complete negotiation if not already done
167        let _ = self.complete_negotiation(conn_id);
168    }
169
170    fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)> {
171        let (client_types, server_types) = self.simulate_send_preferences(conn_id);
172
173        let mut extensions = Vec::new();
174
175        if let Some(data) = client_types {
176            extensions.push((47, data)); // client_certificate_type
177        }
178
179        if let Some(data) = server_types {
180            extensions.push((48, data)); // server_certificate_type
181        }
182
183        extensions
184    }
185
186    fn process_server_hello_extensions(
187        &self,
188        conn_id: &str,
189        extensions: &[(u16, Vec<u8>)],
190    ) -> Result<(), TlsExtensionError> {
191        let mut client_types_data = None;
192        let mut server_types_data = None;
193
194        for (ext_id, data) in extensions {
195            match *ext_id {
196                47 => client_types_data = Some(data.as_slice()),
197                48 => server_types_data = Some(data.as_slice()),
198                _ => {}
199            }
200        }
201
202        self.simulate_receive_preferences(conn_id, client_types_data, server_types_data)
203    }
204
205    fn get_negotiation_result(&self, conn_id: &str) -> Option<NegotiationResult> {
206        self.complete_negotiation(conn_id).ok()
207    }
208}
209
/// Wrapper for a rustls `ClientConfig` that simulates RFC 7250 extension behavior.
///
/// Pairs the shared, immutable rustls config with its own
/// `SimulatedExtensionContext` for certificate type negotiation.
pub struct Rfc7250ClientConfig {
    /// Wrapped rustls client configuration.
    inner: Arc<ClientConfig>,
    /// Negotiation context created from the preferences passed to `new`.
    extension_context: Arc<SimulatedExtensionContext>,
}
215
216impl Rfc7250ClientConfig {
217    /// Create a new RFC 7250 aware client configuration
218    pub fn new(base_config: ClientConfig, preferences: CertificateTypePreferences) -> Self {
219        Self {
220            inner: Arc::new(base_config),
221            extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
222        }
223    }
224
225    /// Get the inner rustls ClientConfig
226    pub fn inner(&self) -> &Arc<ClientConfig> {
227        &self.inner
228    }
229
230    /// Get the extension context for negotiation
231    pub fn extension_context(&self) -> &Arc<SimulatedExtensionContext> {
232        &self.extension_context
233    }
234
235    /// Simulate the ClientHello extension data
236    pub fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)> {
237        let (client_types, server_types) =
238            self.extension_context.simulate_send_preferences(conn_id);
239
240        let mut extensions = Vec::new();
241
242        if let Some(data) = client_types {
243            extensions.push((47, data)); // client_certificate_type
244        }
245
246        if let Some(data) = server_types {
247            extensions.push((48, data)); // server_certificate_type
248        }
249
250        extensions
251    }
252}
253
/// Wrapper for a rustls `ServerConfig` that simulates RFC 7250 extension behavior.
///
/// Pairs the shared, immutable rustls config with its own
/// `SimulatedExtensionContext` for certificate type negotiation.
pub struct Rfc7250ServerConfig {
    /// Wrapped rustls server configuration.
    inner: Arc<ServerConfig>,
    /// Negotiation context created from the preferences passed to `new`.
    extension_context: Arc<SimulatedExtensionContext>,
}
259
260impl Rfc7250ServerConfig {
261    /// Create a new RFC 7250 aware server configuration
262    pub fn new(base_config: ServerConfig, preferences: CertificateTypePreferences) -> Self {
263        Self {
264            inner: Arc::new(base_config),
265            extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
266        }
267    }
268
269    /// Get the inner rustls ServerConfig
270    pub fn inner(&self) -> &Arc<ServerConfig> {
271        &self.inner
272    }
273
274    /// Get the extension context for negotiation
275    pub fn extension_context(&self) -> &Arc<SimulatedExtensionContext> {
276        &self.extension_context
277    }
278
279    /// Process ClientHello extensions and prepare ServerHello response
280    pub fn process_client_hello_extensions(
281        &self,
282        conn_id: &str,
283        client_extensions: &[(u16, Vec<u8>)],
284    ) -> Result<Vec<(u16, Vec<u8>)>, TlsExtensionError> {
285        // First, register this connection
286        self.extension_context.simulate_send_preferences(conn_id);
287
288        // Process client's certificate type preferences
289        let mut client_types_data = None;
290        let mut server_types_data = None;
291
292        for (ext_id, data) in client_extensions {
293            match *ext_id {
294                47 => client_types_data = Some(data.as_slice()),
295                48 => server_types_data = Some(data.as_slice()),
296                _ => {}
297            }
298        }
299
300        // Store remote preferences
301        self.extension_context.simulate_receive_preferences(
302            conn_id,
303            client_types_data,
304            server_types_data,
305        )?;
306
307        // Complete negotiation
308        let result = self.extension_context.complete_negotiation(conn_id)?;
309
310        // Prepare ServerHello extensions with negotiated types
311        let mut response_extensions = Vec::new();
312
313        // Send back single negotiated type for each extension
314        response_extensions.push((47, vec![1, result.client_cert_type.to_u8()]));
315        response_extensions.push((48, vec![1, result.server_cert_type.to_u8()]));
316
317        Ok(response_extensions)
318    }
319}
320
321/// Helper to determine if we should use Raw Public Key based on negotiation
322pub fn should_use_raw_public_key(negotiation_result: &NegotiationResult, is_client: bool) -> bool {
323    if is_client {
324        negotiation_result.client_cert_type.is_raw_public_key()
325    } else {
326        negotiation_result.server_cert_type.is_raw_public_key()
327    }
328}
329
/// Builds a simulation connection identifier of the form `"{local}-{remote}"`.
pub fn create_connection_id(local_addr: &str, remote_addr: &str) -> String {
    [local_addr, remote_addr].join("-")
}
334
/// Wrapper around a `crypto::Session` that reports handshake completion to a
/// [`TlsExtensionHooks`] implementation.
///
/// Every `crypto::Session` method is delegated to `inner_session`; the wrapper
/// only adds completion tracking in `read_handshake`.
pub struct ExtensionAwareTlsSession {
    /// The underlying TLS session
    inner_session: Box<dyn crate::crypto::Session>,
    /// Extension hooks for certificate type negotiation
    extension_hooks: Arc<dyn TlsExtensionHooks>,
    /// Connection identifier passed to every hook call
    conn_id: String,
    /// Whether this is a client session
    is_client: bool,
    /// Set once `on_handshake_complete` has been invoked, so it fires only once
    handshake_complete: bool,
}
348
349impl ExtensionAwareTlsSession {
350    /// Create a new extension-aware TLS session
351    pub fn new(
352        inner_session: Box<dyn crate::crypto::Session>,
353        extension_hooks: Arc<dyn TlsExtensionHooks>,
354        conn_id: String,
355        is_client: bool,
356    ) -> Self {
357        Self {
358            inner_session,
359            extension_hooks,
360            conn_id,
361            is_client,
362            handshake_complete: false,
363        }
364    }
365
366    /// Get the negotiation result if available
367    pub fn get_negotiation_result(&self) -> Option<NegotiationResult> {
368        self.extension_hooks.get_negotiation_result(&self.conn_id)
369    }
370}
371
372/// Implement the crypto::Session trait for our wrapper
373impl crate::crypto::Session for ExtensionAwareTlsSession {
374    fn initial_keys(
375        &self,
376        dst_cid: &crate::ConnectionId,
377        side: crate::Side,
378    ) -> crate::crypto::Keys {
379        self.inner_session.initial_keys(dst_cid, side)
380    }
381
382    fn handshake_data(&self) -> Option<Box<dyn std::any::Any>> {
383        self.inner_session.handshake_data()
384    }
385
386    fn peer_identity(&self) -> Option<Box<dyn std::any::Any>> {
387        self.inner_session.peer_identity()
388    }
389
390    fn early_crypto(
391        &self,
392    ) -> Option<(
393        Box<dyn crate::crypto::HeaderKey>,
394        Box<dyn crate::crypto::PacketKey>,
395    )> {
396        self.inner_session.early_crypto()
397    }
398
399    fn early_data_accepted(&self) -> Option<bool> {
400        self.inner_session.early_data_accepted()
401    }
402
403    fn is_handshaking(&self) -> bool {
404        self.inner_session.is_handshaking()
405    }
406
407    fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, crate::TransportError> {
408        let result = self.inner_session.read_handshake(buf)?;
409
410        // Check if handshake is complete
411        if result && !self.handshake_complete && !self.is_handshaking() {
412            self.handshake_complete = true;
413            self.extension_hooks
414                .on_handshake_complete(&self.conn_id, self.is_client);
415        }
416
417        Ok(result)
418    }
419
420    fn transport_parameters(
421        &self,
422    ) -> Result<Option<crate::transport_parameters::TransportParameters>, crate::TransportError>
423    {
424        self.inner_session.transport_parameters()
425    }
426
427    fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<crate::crypto::Keys> {
428        self.inner_session.write_handshake(buf)
429    }
430
431    fn next_1rtt_keys(
432        &mut self,
433    ) -> Option<crate::crypto::KeyPair<Box<dyn crate::crypto::PacketKey>>> {
434        self.inner_session.next_1rtt_keys()
435    }
436
437    fn is_valid_retry(
438        &self,
439        orig_dst_cid: &crate::ConnectionId,
440        header: &[u8],
441        payload: &[u8],
442    ) -> bool {
443        self.inner_session
444            .is_valid_retry(orig_dst_cid, header, payload)
445    }
446
447    fn export_keying_material(
448        &self,
449        output: &mut [u8],
450        label: &[u8],
451        context: &[u8],
452    ) -> Result<(), crate::crypto::ExportKeyingMaterialError> {
453        self.inner_session
454            .export_keying_material(output, label, context)
455    }
456}
457
/// QUIC client config that wraps every session in an `ExtensionAwareTlsSession`
/// for RFC 7250 certificate type negotiation.
pub struct Rfc7250QuicClientConfig {
    /// Base QUIC client config that creates the underlying sessions
    base_config: Arc<dyn QuicClientConfig>,
    /// Extension context shared by all sessions started from this config
    extension_context: Arc<SimulatedExtensionContext>,
}
465
466impl Rfc7250QuicClientConfig {
467    /// Create a new RFC 7250 aware QUIC client config
468    pub fn new(
469        base_config: Arc<dyn QuicClientConfig>,
470        preferences: CertificateTypePreferences,
471    ) -> Self {
472        Self {
473            base_config,
474            extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
475        }
476    }
477}
478
479impl QuicClientConfig for Rfc7250QuicClientConfig {
480    fn start_session(
481        self: Arc<Self>,
482        version: u32,
483        server_name: &str,
484        params: &crate::transport_parameters::TransportParameters,
485    ) -> Result<Box<dyn crate::crypto::Session>, crate::ConnectError> {
486        // Create the base session
487        let inner_session = self
488            .base_config
489            .clone()
490            .start_session(version, server_name, params)?;
491
492        // Create connection ID for this session
493        let conn_id = format!(
494            "client-{}-{}",
495            server_name,
496            std::time::SystemTime::now()
497                .duration_since(std::time::UNIX_EPOCH)
498                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
499                .as_nanos()
500        );
501
502        // Create wrapper with extension hooks
503        Ok(Box::new(ExtensionAwareTlsSession::new(
504            inner_session,
505            self.extension_context.clone() as Arc<dyn TlsExtensionHooks>,
506            conn_id,
507            true, // is_client
508        )))
509    }
510}
511
/// QUIC server config that wraps every session in an `ExtensionAwareTlsSession`
/// for RFC 7250 certificate type negotiation.
pub struct Rfc7250QuicServerConfig {
    /// Base QUIC server config that creates sessions and derives keys/tags
    base_config: Arc<dyn QuicServerConfig>,
    /// Extension context shared by all sessions started from this config
    extension_context: Arc<SimulatedExtensionContext>,
}
519
520impl Rfc7250QuicServerConfig {
521    /// Create a new RFC 7250 aware QUIC server config
522    pub fn new(
523        base_config: Arc<dyn QuicServerConfig>,
524        preferences: CertificateTypePreferences,
525    ) -> Self {
526        Self {
527            base_config,
528            extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
529        }
530    }
531}
532
533impl QuicServerConfig for Rfc7250QuicServerConfig {
534    fn start_session(
535        self: Arc<Self>,
536        version: u32,
537        params: &crate::transport_parameters::TransportParameters,
538    ) -> Box<dyn crate::crypto::Session> {
539        // Create the base session
540        let inner_session = self.base_config.clone().start_session(version, params);
541
542        // Create connection ID for this session
543        let conn_id = format!(
544            "server-{}",
545            std::time::SystemTime::now()
546                .duration_since(std::time::UNIX_EPOCH)
547                .unwrap_or_else(|_| std::time::Duration::from_secs(0))
548                .as_nanos()
549        );
550
551        // Create wrapper with extension hooks
552        Box::new(ExtensionAwareTlsSession::new(
553            inner_session,
554            self.extension_context.clone() as Arc<dyn TlsExtensionHooks>,
555            conn_id,
556            false, // is_client = false for server
557        ))
558    }
559
560    fn initial_keys(
561        &self,
562        version: u32,
563        dst_cid: &crate::ConnectionId,
564    ) -> Result<crate::crypto::Keys, crate::crypto::UnsupportedVersion> {
565        self.base_config.initial_keys(version, dst_cid)
566    }
567
568    fn retry_tag(
569        &self,
570        version: u32,
571        orig_dst_cid: &crate::ConnectionId,
572        packet: &[u8],
573    ) -> [u8; 16] {
574        self.base_config.retry_tag(version, orig_dst_cid, packet)
575    }
576}
577
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;
    use std::sync::Once;

    static INIT: Once = Once::new();

    /// Install a process-wide rustls crypto provider exactly once, for tests
    /// that build rustls configs. Errors (a provider already installed) are
    /// ignored on purpose.
    fn ensure_crypto_provider() {
        INIT.call_once(|| {
            #[cfg(feature = "rustls-aws-lc-rs")]
            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();

            #[cfg(feature = "rustls-ring")]
            let _ = rustls::crypto::ring::default_provider().install_default();
        });
    }

    #[test]
    fn test_simulated_negotiation_flow() {
        const CONN: &str = "test-connection";

        let client =
            SimulatedExtensionContext::new(CertificateTypePreferences::prefer_raw_public_key());
        let server =
            SimulatedExtensionContext::new(CertificateTypePreferences::raw_public_key_only());

        // ClientHello: the client advertises both certificate type lists.
        let (offered_client, offered_server) = client.simulate_send_preferences(CONN);
        assert!(offered_client.is_some());
        assert!(offered_server.is_some());

        // The server registers the connection, ingests the offer and negotiates.
        server.simulate_send_preferences(CONN);
        server
            .simulate_receive_preferences(CONN, offered_client.as_deref(), offered_server.as_deref())
            .unwrap();
        let negotiated_by_server = server.complete_negotiation(CONN).unwrap();
        assert!(negotiated_by_server.is_raw_public_key_only());

        // ServerHello (simulated): a single selected type per extension.
        let rpk_selection = vec![1, CertificateType::RawPublicKey.to_u8()];
        client
            .simulate_receive_preferences(CONN, Some(&rpk_selection), Some(&rpk_selection))
            .unwrap();

        let negotiated_by_client = client.complete_negotiation(CONN).unwrap();
        assert_eq!(negotiated_by_client, negotiated_by_server);
    }

    #[test]
    fn test_wrapper_configs() {
        ensure_crypto_provider();

        let verifier = crate::crypto::raw_public_keys::RawPublicKeyVerifier::new(Vec::new());
        let base = ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(verifier))
            .with_no_client_auth();

        let wrapped =
            Rfc7250ClientConfig::new(base, CertificateTypePreferences::prefer_raw_public_key());

        let extension_ids: Vec<u16> = wrapped
            .get_client_hello_extensions("test-conn")
            .iter()
            .map(|(id, _)| *id)
            .collect();

        // client_certificate_type, then server_certificate_type
        assert_eq!(extension_ids, vec![47, 48]);
    }
}