ant_quic/crypto/
tls_extension_simulation.rs

1//! TLS Extension Simulation for RFC 7250 Raw Public Keys
2//!
3//! Since rustls 0.23.x doesn't expose APIs for custom TLS extensions,
4//! this module simulates the RFC 7250 certificate type negotiation
5//! through alternative mechanisms that work within rustls constraints.
6
7use crate::crypto::{ClientConfig as QuicClientConfig, ServerConfig as QuicServerConfig};
8use rustls::{ClientConfig, ServerConfig};
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11
12use super::tls_extensions::{
13    CertificateTypeList, CertificateTypePreferences, NegotiationResult, TlsExtensionError,
14};
15
16/// Trait for hooking into TLS handshake events
17pub trait TlsExtensionHooks: Send + Sync {
18    /// Called when the handshake is complete
19    fn on_handshake_complete(&self, conn_id: &str, is_client: bool);
20
21    /// Called to get extension data for ClientHello
22    fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)>;
23
24    /// Called to process ServerHello extensions
25    fn process_server_hello_extensions(
26        &self,
27        conn_id: &str,
28        extensions: &[(u16, Vec<u8>)],
29    ) -> Result<(), TlsExtensionError>;
30
31    /// Get the negotiation result for a connection
32    fn get_negotiation_result(&self, conn_id: &str) -> Option<NegotiationResult>;
33}
34
35/// Simulated TLS extension context for certificate type negotiation
36#[derive(Debug)]
37pub struct SimulatedExtensionContext {
38    /// Active negotiations indexed by connection ID
39    negotiations: Arc<Mutex<HashMap<String, NegotiationState>>>,
40    /// Local preferences for this endpoint
41    local_preferences: CertificateTypePreferences,
42}
43
44#[derive(Debug, Clone)]
45struct NegotiationState {
46    local_preferences: CertificateTypePreferences,
47    remote_client_types: Option<CertificateTypeList>,
48    remote_server_types: Option<CertificateTypeList>,
49    result: Option<NegotiationResult>,
50}
51
52impl SimulatedExtensionContext {
53    /// Create a new simulated extension context
54    pub fn new(preferences: CertificateTypePreferences) -> Self {
55        Self {
56            negotiations: Arc::new(Mutex::new(HashMap::new())),
57            local_preferences: preferences,
58        }
59    }
60
61    /// Simulate sending certificate type preferences
62    /// In reality, this would be sent in ClientHello/ServerHello extensions
63    pub fn simulate_send_preferences(&self, conn_id: &str) -> (Option<Vec<u8>>, Option<Vec<u8>>) {
64        let mut negotiations = self.negotiations.lock().unwrap();
65
66        let state = NegotiationState {
67            local_preferences: self.local_preferences.clone(),
68            remote_client_types: None,
69            remote_server_types: None,
70            result: None,
71        };
72
73        negotiations.insert(conn_id.to_string(), state);
74
75        // Simulate extension data that would be sent
76        let client_ext_data = self.local_preferences.client_types.to_bytes();
77        let server_ext_data = self.local_preferences.server_types.to_bytes();
78
79        (Some(client_ext_data), Some(server_ext_data))
80    }
81
82    /// Simulate receiving certificate type preferences from peer
83    pub fn simulate_receive_preferences(
84        &self,
85        conn_id: &str,
86        client_types_data: Option<&[u8]>,
87        server_types_data: Option<&[u8]>,
88    ) -> Result<(), TlsExtensionError> {
89        let mut negotiations = self.negotiations.lock().unwrap();
90
91        let state = negotiations.get_mut(conn_id).ok_or_else(|| {
92            TlsExtensionError::InvalidExtensionData(format!(
93                "No negotiation state for connection {conn_id}"
94            ))
95        })?;
96
97        if let Some(data) = client_types_data {
98            state.remote_client_types = Some(CertificateTypeList::from_bytes(data)?);
99        }
100
101        if let Some(data) = server_types_data {
102            state.remote_server_types = Some(CertificateTypeList::from_bytes(data)?);
103        }
104
105        Ok(())
106    }
107
108    /// Complete the negotiation and get the result
109    pub fn complete_negotiation(
110        &self,
111        conn_id: &str,
112    ) -> Result<NegotiationResult, TlsExtensionError> {
113        let mut negotiations = self.negotiations.lock().unwrap();
114
115        let state = negotiations.get_mut(conn_id).ok_or_else(|| {
116            TlsExtensionError::InvalidExtensionData(format!(
117                "No negotiation state for connection {conn_id}"
118            ))
119        })?;
120
121        if let Some(result) = &state.result {
122            return Ok(result.clone());
123        }
124
125        let result = state.local_preferences.negotiate(
126            state.remote_client_types.as_ref(),
127            state.remote_server_types.as_ref(),
128        )?;
129
130        state.result = Some(result.clone());
131        Ok(result)
132    }
133
134    /// Clean up negotiation state for a connection
135    pub fn cleanup_connection(&self, conn_id: &str) {
136        let mut negotiations = self.negotiations.lock().unwrap();
137        negotiations.remove(conn_id);
138    }
139}
140
141impl TlsExtensionHooks for SimulatedExtensionContext {
142    fn on_handshake_complete(&self, conn_id: &str, _is_client: bool) {
143        // Try to complete negotiation if not already done
144        let _ = self.complete_negotiation(conn_id);
145    }
146
147    fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)> {
148        let (client_types, server_types) = self.simulate_send_preferences(conn_id);
149
150        let mut extensions = Vec::new();
151
152        if let Some(data) = client_types {
153            extensions.push((47, data)); // client_certificate_type
154        }
155
156        if let Some(data) = server_types {
157            extensions.push((48, data)); // server_certificate_type
158        }
159
160        extensions
161    }
162
163    fn process_server_hello_extensions(
164        &self,
165        conn_id: &str,
166        extensions: &[(u16, Vec<u8>)],
167    ) -> Result<(), TlsExtensionError> {
168        let mut client_types_data = None;
169        let mut server_types_data = None;
170
171        for (ext_id, data) in extensions {
172            match *ext_id {
173                47 => client_types_data = Some(data.as_slice()),
174                48 => server_types_data = Some(data.as_slice()),
175                _ => {}
176            }
177        }
178
179        self.simulate_receive_preferences(conn_id, client_types_data, server_types_data)
180    }
181
182    fn get_negotiation_result(&self, conn_id: &str) -> Option<NegotiationResult> {
183        self.complete_negotiation(conn_id).ok()
184    }
185}
186
187/// Wrapper for ClientConfig that simulates RFC 7250 extension behavior
188pub struct Rfc7250ClientConfig {
189    inner: Arc<ClientConfig>,
190    extension_context: Arc<SimulatedExtensionContext>,
191}
192
193impl Rfc7250ClientConfig {
194    /// Create a new RFC 7250 aware client configuration
195    pub fn new(base_config: ClientConfig, preferences: CertificateTypePreferences) -> Self {
196        Self {
197            inner: Arc::new(base_config),
198            extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
199        }
200    }
201
202    /// Get the inner rustls ClientConfig
203    pub fn inner(&self) -> &Arc<ClientConfig> {
204        &self.inner
205    }
206
207    /// Get the extension context for negotiation
208    pub fn extension_context(&self) -> &Arc<SimulatedExtensionContext> {
209        &self.extension_context
210    }
211
212    /// Simulate the ClientHello extension data
213    pub fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)> {
214        let (client_types, server_types) =
215            self.extension_context.simulate_send_preferences(conn_id);
216
217        let mut extensions = Vec::new();
218
219        if let Some(data) = client_types {
220            extensions.push((47, data)); // client_certificate_type
221        }
222
223        if let Some(data) = server_types {
224            extensions.push((48, data)); // server_certificate_type
225        }
226
227        extensions
228    }
229}
230
231/// Wrapper for ServerConfig that simulates RFC 7250 extension behavior
232pub struct Rfc7250ServerConfig {
233    inner: Arc<ServerConfig>,
234    extension_context: Arc<SimulatedExtensionContext>,
235}
236
237impl Rfc7250ServerConfig {
238    /// Create a new RFC 7250 aware server configuration
239    pub fn new(base_config: ServerConfig, preferences: CertificateTypePreferences) -> Self {
240        Self {
241            inner: Arc::new(base_config),
242            extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
243        }
244    }
245
246    /// Get the inner rustls ServerConfig
247    pub fn inner(&self) -> &Arc<ServerConfig> {
248        &self.inner
249    }
250
251    /// Get the extension context for negotiation
252    pub fn extension_context(&self) -> &Arc<SimulatedExtensionContext> {
253        &self.extension_context
254    }
255
256    /// Process ClientHello extensions and prepare ServerHello response
257    pub fn process_client_hello_extensions(
258        &self,
259        conn_id: &str,
260        client_extensions: &[(u16, Vec<u8>)],
261    ) -> Result<Vec<(u16, Vec<u8>)>, TlsExtensionError> {
262        // First, register this connection
263        self.extension_context.simulate_send_preferences(conn_id);
264
265        // Process client's certificate type preferences
266        let mut client_types_data = None;
267        let mut server_types_data = None;
268
269        for (ext_id, data) in client_extensions {
270            match *ext_id {
271                47 => client_types_data = Some(data.as_slice()),
272                48 => server_types_data = Some(data.as_slice()),
273                _ => {}
274            }
275        }
276
277        // Store remote preferences
278        self.extension_context.simulate_receive_preferences(
279            conn_id,
280            client_types_data,
281            server_types_data,
282        )?;
283
284        // Complete negotiation
285        let result = self.extension_context.complete_negotiation(conn_id)?;
286
287        // Prepare ServerHello extensions with negotiated types
288        let mut response_extensions = Vec::new();
289
290        // Send back single negotiated type for each extension
291        response_extensions.push((47, vec![1, result.client_cert_type.to_u8()]));
292        response_extensions.push((48, vec![1, result.server_cert_type.to_u8()]));
293
294        Ok(response_extensions)
295    }
296}
297
298/// Helper to determine if we should use Raw Public Key based on negotiation
299pub fn should_use_raw_public_key(negotiation_result: &NegotiationResult, is_client: bool) -> bool {
300    if is_client {
301        negotiation_result.client_cert_type.is_raw_public_key()
302    } else {
303        negotiation_result.server_cert_type.is_raw_public_key()
304    }
305}
306
/// Build a simulation connection identifier of the form `"{local_addr}-{remote_addr}"`.
pub fn create_connection_id(local_addr: &str, remote_addr: &str) -> String {
    let mut id = String::with_capacity(local_addr.len() + 1 + remote_addr.len());
    id.push_str(local_addr);
    id.push('-');
    id.push_str(remote_addr);
    id
}
311
312/// Wrapper for TlsSession that integrates with TlsExtensionHooks
313pub struct ExtensionAwareTlsSession {
314    /// The underlying TLS session
315    inner_session: Box<dyn crate::crypto::Session>,
316    /// Extension hooks for certificate type negotiation
317    extension_hooks: Arc<dyn TlsExtensionHooks>,
318    /// Connection identifier
319    conn_id: String,
320    /// Whether this is a client session
321    is_client: bool,
322    /// Whether handshake is complete
323    handshake_complete: bool,
324}
325
326impl ExtensionAwareTlsSession {
327    /// Create a new extension-aware TLS session
328    pub fn new(
329        inner_session: Box<dyn crate::crypto::Session>,
330        extension_hooks: Arc<dyn TlsExtensionHooks>,
331        conn_id: String,
332        is_client: bool,
333    ) -> Self {
334        Self {
335            inner_session,
336            extension_hooks,
337            conn_id,
338            is_client,
339            handshake_complete: false,
340        }
341    }
342
343    /// Get the negotiation result if available
344    pub fn get_negotiation_result(&self) -> Option<NegotiationResult> {
345        self.extension_hooks.get_negotiation_result(&self.conn_id)
346    }
347}
348
349/// Implement the crypto::Session trait for our wrapper
350impl crate::crypto::Session for ExtensionAwareTlsSession {
351    fn initial_keys(
352        &self,
353        dst_cid: &crate::ConnectionId,
354        side: crate::Side,
355    ) -> crate::crypto::Keys {
356        self.inner_session.initial_keys(dst_cid, side)
357    }
358
359    fn handshake_data(&self) -> Option<Box<dyn std::any::Any>> {
360        self.inner_session.handshake_data()
361    }
362
363    fn peer_identity(&self) -> Option<Box<dyn std::any::Any>> {
364        self.inner_session.peer_identity()
365    }
366
367    fn early_crypto(
368        &self,
369    ) -> Option<(
370        Box<dyn crate::crypto::HeaderKey>,
371        Box<dyn crate::crypto::PacketKey>,
372    )> {
373        self.inner_session.early_crypto()
374    }
375
376    fn early_data_accepted(&self) -> Option<bool> {
377        self.inner_session.early_data_accepted()
378    }
379
380    fn is_handshaking(&self) -> bool {
381        self.inner_session.is_handshaking()
382    }
383
384    fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, crate::TransportError> {
385        let result = self.inner_session.read_handshake(buf)?;
386
387        // Check if handshake is complete
388        if result && !self.handshake_complete && !self.is_handshaking() {
389            self.handshake_complete = true;
390            self.extension_hooks
391                .on_handshake_complete(&self.conn_id, self.is_client);
392        }
393
394        Ok(result)
395    }
396
397    fn transport_parameters(
398        &self,
399    ) -> Result<Option<crate::transport_parameters::TransportParameters>, crate::TransportError>
400    {
401        self.inner_session.transport_parameters()
402    }
403
404    fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<crate::crypto::Keys> {
405        self.inner_session.write_handshake(buf)
406    }
407
408    fn next_1rtt_keys(
409        &mut self,
410    ) -> Option<crate::crypto::KeyPair<Box<dyn crate::crypto::PacketKey>>> {
411        self.inner_session.next_1rtt_keys()
412    }
413
414    fn is_valid_retry(
415        &self,
416        orig_dst_cid: &crate::ConnectionId,
417        header: &[u8],
418        payload: &[u8],
419    ) -> bool {
420        self.inner_session
421            .is_valid_retry(orig_dst_cid, header, payload)
422    }
423
424    fn export_keying_material(
425        &self,
426        output: &mut [u8],
427        label: &[u8],
428        context: &[u8],
429    ) -> Result<(), crate::crypto::ExportKeyingMaterialError> {
430        self.inner_session
431            .export_keying_material(output, label, context)
432    }
433}
434
435/// Enhanced QUIC client config with RFC 7250 support
436pub struct Rfc7250QuicClientConfig {
437    /// Base QUIC client config
438    base_config: Arc<dyn QuicClientConfig>,
439    /// Extension context for certificate type negotiation
440    extension_context: Arc<SimulatedExtensionContext>,
441}
442
443impl Rfc7250QuicClientConfig {
444    /// Create a new RFC 7250 aware QUIC client config
445    pub fn new(
446        base_config: Arc<dyn QuicClientConfig>,
447        preferences: CertificateTypePreferences,
448    ) -> Self {
449        Self {
450            base_config,
451            extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
452        }
453    }
454}
455
456impl QuicClientConfig for Rfc7250QuicClientConfig {
457    fn start_session(
458        self: Arc<Self>,
459        version: u32,
460        server_name: &str,
461        params: &crate::transport_parameters::TransportParameters,
462    ) -> Result<Box<dyn crate::crypto::Session>, crate::ConnectError> {
463        // Create the base session
464        let inner_session = self
465            .base_config
466            .clone()
467            .start_session(version, server_name, params)?;
468
469        // Create connection ID for this session
470        let conn_id = format!(
471            "client-{}-{}",
472            server_name,
473            std::time::SystemTime::now()
474                .duration_since(std::time::UNIX_EPOCH)
475                .unwrap()
476                .as_nanos()
477        );
478
479        // Create wrapper with extension hooks
480        Ok(Box::new(ExtensionAwareTlsSession::new(
481            inner_session,
482            self.extension_context.clone() as Arc<dyn TlsExtensionHooks>,
483            conn_id,
484            true, // is_client
485        )))
486    }
487}
488
489/// Enhanced QUIC server config with RFC 7250 support
490pub struct Rfc7250QuicServerConfig {
491    /// Base QUIC server config
492    base_config: Arc<dyn QuicServerConfig>,
493    /// Extension context for certificate type negotiation
494    extension_context: Arc<SimulatedExtensionContext>,
495}
496
497impl Rfc7250QuicServerConfig {
498    /// Create a new RFC 7250 aware QUIC server config
499    pub fn new(
500        base_config: Arc<dyn QuicServerConfig>,
501        preferences: CertificateTypePreferences,
502    ) -> Self {
503        Self {
504            base_config,
505            extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
506        }
507    }
508}
509
510impl QuicServerConfig for Rfc7250QuicServerConfig {
511    fn start_session(
512        self: Arc<Self>,
513        version: u32,
514        params: &crate::transport_parameters::TransportParameters,
515    ) -> Box<dyn crate::crypto::Session> {
516        // Create the base session
517        let inner_session = self.base_config.clone().start_session(version, params);
518
519        // Create connection ID for this session
520        let conn_id = format!(
521            "server-{}",
522            std::time::SystemTime::now()
523                .duration_since(std::time::UNIX_EPOCH)
524                .unwrap()
525                .as_nanos()
526        );
527
528        // Create wrapper with extension hooks
529        Box::new(ExtensionAwareTlsSession::new(
530            inner_session,
531            self.extension_context.clone() as Arc<dyn TlsExtensionHooks>,
532            conn_id,
533            false, // is_client = false for server
534        ))
535    }
536
537    fn initial_keys(
538        &self,
539        version: u32,
540        dst_cid: &crate::ConnectionId,
541    ) -> Result<crate::crypto::Keys, crate::crypto::UnsupportedVersion> {
542        self.base_config.initial_keys(version, dst_cid)
543    }
544
545    fn retry_tag(
546        &self,
547        version: u32,
548        orig_dst_cid: &crate::ConnectionId,
549        packet: &[u8],
550    ) -> [u8; 16] {
551        self.base_config.retry_tag(version, orig_dst_cid, packet)
552    }
553}
554
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;
    use std::sync::Once;

    static INIT: Once = Once::new();

    /// Install a process-wide rustls crypto provider once, for tests that build rustls
    /// configs. Errors from `install_default` (a provider is already installed) are ignored.
    fn ensure_crypto_provider() {
        INIT.call_once(|| {
            #[cfg(feature = "rustls-aws-lc-rs")]
            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();

            #[cfg(feature = "rustls-ring")]
            let _ = rustls::crypto::ring::default_provider().install_default();
        });
    }

    #[test]
    fn test_simulated_negotiation_flow() {
        // Client side
        let client_prefs = CertificateTypePreferences::prefer_raw_public_key();
        let client_ctx = SimulatedExtensionContext::new(client_prefs);

        // Server side
        let server_prefs = CertificateTypePreferences::raw_public_key_only();
        let server_ctx = SimulatedExtensionContext::new(server_prefs);

        let conn_id = "test-connection";

        // Client sends preferences
        let (client_types, server_types) = client_ctx.simulate_send_preferences(conn_id);
        assert!(client_types.is_some());
        assert!(server_types.is_some());

        // Server receives and processes
        server_ctx.simulate_send_preferences(conn_id);
        server_ctx
            .simulate_receive_preferences(conn_id, client_types.as_deref(), server_types.as_deref())
            .unwrap();

        // Server completes negotiation
        let server_result = server_ctx.complete_negotiation(conn_id).unwrap();
        assert!(server_result.is_raw_public_key_only());

        // Client receives server's response (simulated)
        let server_response_client = vec![1, CertificateType::RawPublicKey.to_u8()];
        let server_response_server = vec![1, CertificateType::RawPublicKey.to_u8()];

        client_ctx
            .simulate_receive_preferences(
                conn_id,
                Some(&server_response_client),
                Some(&server_response_server),
            )
            .unwrap();

        // Client completes negotiation
        let client_result = client_ctx.complete_negotiation(conn_id).unwrap();
        assert_eq!(client_result, server_result);
    }

    #[test]
    fn test_wrapper_configs() {
        ensure_crypto_provider();
        let client_config = ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(
                crate::crypto::raw_public_keys::RawPublicKeyVerifier::new(Vec::new()),
            ))
            .with_no_client_auth();

        let client_prefs = CertificateTypePreferences::prefer_raw_public_key();
        let wrapped_client = Rfc7250ClientConfig::new(client_config, client_prefs);

        let conn_id = "test-conn";
        let extensions = wrapped_client.get_client_hello_extensions(conn_id);

        assert_eq!(extensions.len(), 2);
        assert_eq!(extensions[0].0, 47); // client_certificate_type (simulated ID)
        assert_eq!(extensions[1].0, 48); // server_certificate_type (simulated ID)
    }

    #[test]
    fn test_unregistered_connection_is_rejected() {
        let ctx =
            SimulatedExtensionContext::new(CertificateTypePreferences::prefer_raw_public_key());

        assert!(ctx.simulate_receive_preferences("unknown", None, None).is_err());
        assert!(ctx.complete_negotiation("unknown").is_err());
    }

    #[test]
    fn test_cleanup_connection_discards_state() {
        let ctx =
            SimulatedExtensionContext::new(CertificateTypePreferences::prefer_raw_public_key());

        ctx.simulate_send_preferences("conn");
        ctx.cleanup_connection("conn");

        assert!(ctx.complete_negotiation("conn").is_err());
        // Cleaning up an unknown connection is a no-op.
        ctx.cleanup_connection("never-registered");
    }

    #[test]
    fn test_hooks_negotiation_roundtrip() {
        let client_ctx =
            SimulatedExtensionContext::new(CertificateTypePreferences::prefer_raw_public_key());
        let conn_id = "hooks-connection";

        let offered = client_ctx.get_client_hello_extensions(conn_id);
        let offered_ids: Vec<u16> = offered.iter().map(|(ext_id, _)| *ext_id).collect();
        assert_eq!(offered_ids, vec![47, 48]);

        let rpk = CertificateType::RawPublicKey.to_u8();
        client_ctx
            .process_server_hello_extensions(conn_id, &[(47, vec![1, rpk]), (48, vec![1, rpk])])
            .unwrap();

        let result = client_ctx
            .get_negotiation_result(conn_id)
            .expect("negotiation should succeed");
        assert!(result.is_raw_public_key_only());
    }

    #[test]
    fn test_create_connection_id_format() {
        assert_eq!(
            create_connection_id("127.0.0.1:4433", "10.0.0.2:9000"),
            "127.0.0.1:4433-10.0.0.2:9000"
        );
    }
}