ant_quic/crypto/
tls_extension_simulation.rs

1//! TLS Extension Simulation for RFC 7250 Raw Public Keys
2//!
3//! Since rustls 0.23.x doesn't expose APIs for custom TLS extensions,
4//! this module simulates the RFC 7250 certificate type negotiation
5//! through alternative mechanisms that work within rustls constraints.
6
7use crate::crypto::{ClientConfig as QuicClientConfig, ServerConfig as QuicServerConfig};
8use rustls::{ClientConfig, ServerConfig};
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11
12use super::tls_extensions::{
13    CertificateTypeList, CertificateTypePreferences, NegotiationResult, TlsExtensionError,
14};
15
16/// Trait for hooking into TLS handshake events
17pub trait TlsExtensionHooks: Send + Sync {
18    /// Called when the handshake is complete
19    fn on_handshake_complete(&self, conn_id: &str, is_client: bool);
20
21    /// Called to get extension data for ClientHello
22    fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)>;
23
24    /// Called to process ServerHello extensions
25    fn process_server_hello_extensions(
26        &self,
27        conn_id: &str,
28        extensions: &[(u16, Vec<u8>)],
29    ) -> Result<(), TlsExtensionError>;
30
31    /// Get the negotiation result for a connection
32    fn get_negotiation_result(&self, conn_id: &str) -> Option<NegotiationResult>;
33}
34
35/// Simulated TLS extension context for certificate type negotiation
36#[derive(Debug)]
37pub struct SimulatedExtensionContext {
38    /// Active negotiations indexed by connection ID
39    negotiations: Arc<Mutex<HashMap<String, NegotiationState>>>,
40    /// Local preferences for this endpoint
41    local_preferences: CertificateTypePreferences,
42}
43
44#[derive(Debug, Clone)]
45struct NegotiationState {
46    local_preferences: CertificateTypePreferences,
47    remote_client_types: Option<CertificateTypeList>,
48    remote_server_types: Option<CertificateTypeList>,
49    result: Option<NegotiationResult>,
50}
51
52impl SimulatedExtensionContext {
53    /// Create a new simulated extension context
54    pub fn new(preferences: CertificateTypePreferences) -> Self {
55        Self {
56            negotiations: Arc::new(Mutex::new(HashMap::new())),
57            local_preferences: preferences,
58        }
59    }
60
61    /// Simulate sending certificate type preferences
62    /// In reality, this would be sent in ClientHello/ServerHello extensions
63    pub fn simulate_send_preferences(&self, conn_id: &str) -> (Option<Vec<u8>>, Option<Vec<u8>>) {
64        let mut negotiations = self.negotiations.lock().unwrap();
65
66        let state = NegotiationState {
67            local_preferences: self.local_preferences.clone(),
68            remote_client_types: None,
69            remote_server_types: None,
70            result: None,
71        };
72
73        negotiations.insert(conn_id.to_string(), state);
74
75        // Simulate extension data that would be sent
76        let client_ext_data = self.local_preferences.client_types.to_bytes();
77        let server_ext_data = self.local_preferences.server_types.to_bytes();
78
79        (Some(client_ext_data), Some(server_ext_data))
80    }
81
82    /// Simulate receiving certificate type preferences from peer
83    pub fn simulate_receive_preferences(
84        &self,
85        conn_id: &str,
86        client_types_data: Option<&[u8]>,
87        server_types_data: Option<&[u8]>,
88    ) -> Result<(), TlsExtensionError> {
89        let mut negotiations = self.negotiations.lock().unwrap();
90
91        let state = negotiations.get_mut(conn_id).ok_or_else(|| {
92            TlsExtensionError::InvalidExtensionData(format!(
93                "No negotiation state for connection {}",
94                conn_id
95            ))
96        })?;
97
98        if let Some(data) = client_types_data {
99            state.remote_client_types = Some(CertificateTypeList::from_bytes(data)?);
100        }
101
102        if let Some(data) = server_types_data {
103            state.remote_server_types = Some(CertificateTypeList::from_bytes(data)?);
104        }
105
106        Ok(())
107    }
108
109    /// Complete the negotiation and get the result
110    pub fn complete_negotiation(
111        &self,
112        conn_id: &str,
113    ) -> Result<NegotiationResult, TlsExtensionError> {
114        let mut negotiations = self.negotiations.lock().unwrap();
115
116        let state = negotiations.get_mut(conn_id).ok_or_else(|| {
117            TlsExtensionError::InvalidExtensionData(format!(
118                "No negotiation state for connection {}",
119                conn_id
120            ))
121        })?;
122
123        if let Some(result) = &state.result {
124            return Ok(result.clone());
125        }
126
127        let result = state.local_preferences.negotiate(
128            state.remote_client_types.as_ref(),
129            state.remote_server_types.as_ref(),
130        )?;
131
132        state.result = Some(result.clone());
133        Ok(result)
134    }
135
136    /// Clean up negotiation state for a connection
137    pub fn cleanup_connection(&self, conn_id: &str) {
138        let mut negotiations = self.negotiations.lock().unwrap();
139        negotiations.remove(conn_id);
140    }
141}
142
143impl TlsExtensionHooks for SimulatedExtensionContext {
144    fn on_handshake_complete(&self, conn_id: &str, _is_client: bool) {
145        // Try to complete negotiation if not already done
146        let _ = self.complete_negotiation(conn_id);
147    }
148
149    fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)> {
150        let (client_types, server_types) = self.simulate_send_preferences(conn_id);
151
152        let mut extensions = Vec::new();
153
154        if let Some(data) = client_types {
155            extensions.push((47, data)); // client_certificate_type
156        }
157
158        if let Some(data) = server_types {
159            extensions.push((48, data)); // server_certificate_type
160        }
161
162        extensions
163    }
164
165    fn process_server_hello_extensions(
166        &self,
167        conn_id: &str,
168        extensions: &[(u16, Vec<u8>)],
169    ) -> Result<(), TlsExtensionError> {
170        let mut client_types_data = None;
171        let mut server_types_data = None;
172
173        for (ext_id, data) in extensions {
174            match *ext_id {
175                47 => client_types_data = Some(data.as_slice()),
176                48 => server_types_data = Some(data.as_slice()),
177                _ => {}
178            }
179        }
180
181        self.simulate_receive_preferences(conn_id, client_types_data, server_types_data)
182    }
183
184    fn get_negotiation_result(&self, conn_id: &str) -> Option<NegotiationResult> {
185        self.complete_negotiation(conn_id).ok()
186    }
187}
188
189/// Wrapper for ClientConfig that simulates RFC 7250 extension behavior
190pub struct Rfc7250ClientConfig {
191    inner: Arc<ClientConfig>,
192    extension_context: Arc<SimulatedExtensionContext>,
193}
194
195impl Rfc7250ClientConfig {
196    /// Create a new RFC 7250 aware client configuration
197    pub fn new(base_config: ClientConfig, preferences: CertificateTypePreferences) -> Self {
198        Self {
199            inner: Arc::new(base_config),
200            extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
201        }
202    }
203
204    /// Get the inner rustls ClientConfig
205    pub fn inner(&self) -> &Arc<ClientConfig> {
206        &self.inner
207    }
208
209    /// Get the extension context for negotiation
210    pub fn extension_context(&self) -> &Arc<SimulatedExtensionContext> {
211        &self.extension_context
212    }
213
214    /// Simulate the ClientHello extension data
215    pub fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)> {
216        let (client_types, server_types) =
217            self.extension_context.simulate_send_preferences(conn_id);
218
219        let mut extensions = Vec::new();
220
221        if let Some(data) = client_types {
222            extensions.push((47, data)); // client_certificate_type
223        }
224
225        if let Some(data) = server_types {
226            extensions.push((48, data)); // server_certificate_type
227        }
228
229        extensions
230    }
231}
232
233/// Wrapper for ServerConfig that simulates RFC 7250 extension behavior
234pub struct Rfc7250ServerConfig {
235    inner: Arc<ServerConfig>,
236    extension_context: Arc<SimulatedExtensionContext>,
237}
238
239impl Rfc7250ServerConfig {
240    /// Create a new RFC 7250 aware server configuration
241    pub fn new(base_config: ServerConfig, preferences: CertificateTypePreferences) -> Self {
242        Self {
243            inner: Arc::new(base_config),
244            extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
245        }
246    }
247
248    /// Get the inner rustls ServerConfig
249    pub fn inner(&self) -> &Arc<ServerConfig> {
250        &self.inner
251    }
252
253    /// Get the extension context for negotiation
254    pub fn extension_context(&self) -> &Arc<SimulatedExtensionContext> {
255        &self.extension_context
256    }
257
258    /// Process ClientHello extensions and prepare ServerHello response
259    pub fn process_client_hello_extensions(
260        &self,
261        conn_id: &str,
262        client_extensions: &[(u16, Vec<u8>)],
263    ) -> Result<Vec<(u16, Vec<u8>)>, TlsExtensionError> {
264        // First, register this connection
265        self.extension_context.simulate_send_preferences(conn_id);
266
267        // Process client's certificate type preferences
268        let mut client_types_data = None;
269        let mut server_types_data = None;
270
271        for (ext_id, data) in client_extensions {
272            match *ext_id {
273                47 => client_types_data = Some(data.as_slice()),
274                48 => server_types_data = Some(data.as_slice()),
275                _ => {}
276            }
277        }
278
279        // Store remote preferences
280        self.extension_context.simulate_receive_preferences(
281            conn_id,
282            client_types_data,
283            server_types_data,
284        )?;
285
286        // Complete negotiation
287        let result = self.extension_context.complete_negotiation(conn_id)?;
288
289        // Prepare ServerHello extensions with negotiated types
290        let mut response_extensions = Vec::new();
291
292        // Send back single negotiated type for each extension
293        response_extensions.push((47, vec![1, result.client_cert_type.to_u8()]));
294        response_extensions.push((48, vec![1, result.server_cert_type.to_u8()]));
295
296        Ok(response_extensions)
297    }
298}
299
300/// Helper to determine if we should use Raw Public Key based on negotiation
301pub fn should_use_raw_public_key(negotiation_result: &NegotiationResult, is_client: bool) -> bool {
302    if is_client {
303        negotiation_result.client_cert_type.is_raw_public_key()
304    } else {
305        negotiation_result.server_cert_type.is_raw_public_key()
306    }
307}
308
/// Builds the simulation's connection identifier by joining the local and
/// remote address strings with a single `-`.
pub fn create_connection_id(local_addr: &str, remote_addr: &str) -> String {
    [local_addr, remote_addr].join("-")
}
313
314/// Wrapper for TlsSession that integrates with TlsExtensionHooks
315pub struct ExtensionAwareTlsSession {
316    /// The underlying TLS session
317    inner_session: Box<dyn crate::crypto::Session>,
318    /// Extension hooks for certificate type negotiation
319    extension_hooks: Arc<dyn TlsExtensionHooks>,
320    /// Connection identifier
321    conn_id: String,
322    /// Whether this is a client session
323    is_client: bool,
324    /// Whether handshake is complete
325    handshake_complete: bool,
326}
327
328impl ExtensionAwareTlsSession {
329    /// Create a new extension-aware TLS session
330    pub fn new(
331        inner_session: Box<dyn crate::crypto::Session>,
332        extension_hooks: Arc<dyn TlsExtensionHooks>,
333        conn_id: String,
334        is_client: bool,
335    ) -> Self {
336        Self {
337            inner_session,
338            extension_hooks,
339            conn_id,
340            is_client,
341            handshake_complete: false,
342        }
343    }
344
345    /// Get the negotiation result if available
346    pub fn get_negotiation_result(&self) -> Option<NegotiationResult> {
347        self.extension_hooks.get_negotiation_result(&self.conn_id)
348    }
349}
350
351/// Implement the crypto::Session trait for our wrapper
352impl crate::crypto::Session for ExtensionAwareTlsSession {
353    fn initial_keys(
354        &self,
355        dst_cid: &crate::ConnectionId,
356        side: crate::Side,
357    ) -> crate::crypto::Keys {
358        self.inner_session.initial_keys(dst_cid, side)
359    }
360
361    fn handshake_data(&self) -> Option<Box<dyn std::any::Any>> {
362        self.inner_session.handshake_data()
363    }
364
365    fn peer_identity(&self) -> Option<Box<dyn std::any::Any>> {
366        self.inner_session.peer_identity()
367    }
368
369    fn early_crypto(
370        &self,
371    ) -> Option<(
372        Box<dyn crate::crypto::HeaderKey>,
373        Box<dyn crate::crypto::PacketKey>,
374    )> {
375        self.inner_session.early_crypto()
376    }
377
378    fn early_data_accepted(&self) -> Option<bool> {
379        self.inner_session.early_data_accepted()
380    }
381
382    fn is_handshaking(&self) -> bool {
383        self.inner_session.is_handshaking()
384    }
385
386    fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, crate::TransportError> {
387        let result = self.inner_session.read_handshake(buf)?;
388
389        // Check if handshake is complete
390        if result && !self.handshake_complete && !self.is_handshaking() {
391            self.handshake_complete = true;
392            self.extension_hooks
393                .on_handshake_complete(&self.conn_id, self.is_client);
394        }
395
396        Ok(result)
397    }
398
399    fn transport_parameters(
400        &self,
401    ) -> Result<Option<crate::transport_parameters::TransportParameters>, crate::TransportError>
402    {
403        self.inner_session.transport_parameters()
404    }
405
406    fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<crate::crypto::Keys> {
407        self.inner_session.write_handshake(buf)
408    }
409
410    fn next_1rtt_keys(
411        &mut self,
412    ) -> Option<crate::crypto::KeyPair<Box<dyn crate::crypto::PacketKey>>> {
413        self.inner_session.next_1rtt_keys()
414    }
415
416    fn is_valid_retry(
417        &self,
418        orig_dst_cid: &crate::ConnectionId,
419        header: &[u8],
420        payload: &[u8],
421    ) -> bool {
422        self.inner_session
423            .is_valid_retry(orig_dst_cid, header, payload)
424    }
425
426    fn export_keying_material(
427        &self,
428        output: &mut [u8],
429        label: &[u8],
430        context: &[u8],
431    ) -> Result<(), crate::crypto::ExportKeyingMaterialError> {
432        self.inner_session
433            .export_keying_material(output, label, context)
434    }
435}
436
437/// Enhanced QUIC client config with RFC 7250 support
438pub struct Rfc7250QuicClientConfig {
439    /// Base QUIC client config
440    base_config: Arc<dyn QuicClientConfig>,
441    /// Extension context for certificate type negotiation
442    extension_context: Arc<SimulatedExtensionContext>,
443}
444
445impl Rfc7250QuicClientConfig {
446    /// Create a new RFC 7250 aware QUIC client config
447    pub fn new(
448        base_config: Arc<dyn QuicClientConfig>,
449        preferences: CertificateTypePreferences,
450    ) -> Self {
451        Self {
452            base_config,
453            extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
454        }
455    }
456}
457
458impl QuicClientConfig for Rfc7250QuicClientConfig {
459    fn start_session(
460        self: Arc<Self>,
461        version: u32,
462        server_name: &str,
463        params: &crate::transport_parameters::TransportParameters,
464    ) -> Result<Box<dyn crate::crypto::Session>, crate::ConnectError> {
465        // Create the base session
466        let inner_session = self
467            .base_config
468            .clone()
469            .start_session(version, server_name, params)?;
470
471        // Create connection ID for this session
472        let conn_id = format!(
473            "client-{}-{}",
474            server_name,
475            std::time::SystemTime::now()
476                .duration_since(std::time::UNIX_EPOCH)
477                .unwrap()
478                .as_nanos()
479        );
480
481        // Create wrapper with extension hooks
482        Ok(Box::new(ExtensionAwareTlsSession::new(
483            inner_session,
484            self.extension_context.clone() as Arc<dyn TlsExtensionHooks>,
485            conn_id,
486            true, // is_client
487        )))
488    }
489}
490
491/// Enhanced QUIC server config with RFC 7250 support
492pub struct Rfc7250QuicServerConfig {
493    /// Base QUIC server config
494    base_config: Arc<dyn QuicServerConfig>,
495    /// Extension context for certificate type negotiation
496    extension_context: Arc<SimulatedExtensionContext>,
497}
498
499impl Rfc7250QuicServerConfig {
500    /// Create a new RFC 7250 aware QUIC server config
501    pub fn new(
502        base_config: Arc<dyn QuicServerConfig>,
503        preferences: CertificateTypePreferences,
504    ) -> Self {
505        Self {
506            base_config,
507            extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
508        }
509    }
510}
511
512impl QuicServerConfig for Rfc7250QuicServerConfig {
513    fn start_session(
514        self: Arc<Self>,
515        version: u32,
516        params: &crate::transport_parameters::TransportParameters,
517    ) -> Box<dyn crate::crypto::Session> {
518        // Create the base session
519        let inner_session = self.base_config.clone().start_session(version, params);
520
521        // Create connection ID for this session
522        let conn_id = format!(
523            "server-{}",
524            std::time::SystemTime::now()
525                .duration_since(std::time::UNIX_EPOCH)
526                .unwrap()
527                .as_nanos()
528        );
529
530        // Create wrapper with extension hooks
531        Box::new(ExtensionAwareTlsSession::new(
532            inner_session,
533            self.extension_context.clone() as Arc<dyn TlsExtensionHooks>,
534            conn_id,
535            false, // is_client = false for server
536        ))
537    }
538
539    fn initial_keys(
540        &self,
541        version: u32,
542        dst_cid: &crate::ConnectionId,
543    ) -> Result<crate::crypto::Keys, crate::crypto::UnsupportedVersion> {
544        self.base_config.initial_keys(version, dst_cid)
545    }
546
547    fn retry_tag(
548        &self,
549        version: u32,
550        orig_dst_cid: &crate::ConnectionId,
551        packet: &[u8],
552    ) -> [u8; 16] {
553        self.base_config.retry_tag(version, orig_dst_cid, packet)
554    }
555}
556
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;

    /// Drives a client context and a server context through one complete
    /// simulated exchange and checks that both reach the same outcome.
    #[test]
    fn test_simulated_negotiation_flow() {
        let conn_id = "test-connection";

        // One context per endpoint, each with its own preferences
        let client_ctx =
            SimulatedExtensionContext::new(CertificateTypePreferences::prefer_raw_public_key());
        let server_ctx =
            SimulatedExtensionContext::new(CertificateTypePreferences::raw_public_key_only());

        // Client -> server: advertised preference lists
        let (offered_client_types, offered_server_types) =
            client_ctx.simulate_send_preferences(conn_id);
        assert!(offered_client_types.is_some());
        assert!(offered_server_types.is_some());

        // Server registers the connection, then records what the client sent
        server_ctx.simulate_send_preferences(conn_id);
        server_ctx
            .simulate_receive_preferences(
                conn_id,
                offered_client_types.as_deref(),
                offered_server_types.as_deref(),
            )
            .unwrap();

        // Server side of the negotiation
        let server_result = server_ctx.complete_negotiation(conn_id).unwrap();
        assert!(server_result.is_raw_public_key_only());

        // Server -> client (simulated): a one-entry list per direction
        let client_type_reply = vec![1, CertificateType::RawPublicKey.to_u8()];
        let server_type_reply = vec![1, CertificateType::RawPublicKey.to_u8()];
        client_ctx
            .simulate_receive_preferences(
                conn_id,
                Some(&client_type_reply),
                Some(&server_type_reply),
            )
            .unwrap();

        // Client side of the negotiation must agree with the server
        let client_result = client_ctx.complete_negotiation(conn_id).unwrap();
        assert_eq!(client_result, server_result);
    }

    /// Checks that the client wrapper emits both certificate type extensions,
    /// client list first.
    #[test]
    fn test_wrapper_configs() {
        let verifier = Arc::new(crate::crypto::raw_public_keys::RawPublicKeyVerifier::new(
            Vec::new(),
        ));
        let client_config = ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(verifier)
            .with_no_client_auth();

        let wrapped_client = Rfc7250ClientConfig::new(
            client_config,
            CertificateTypePreferences::prefer_raw_public_key(),
        );

        let extensions = wrapped_client.get_client_hello_extensions("test-conn");
        let extension_ids: Vec<u16> = extensions.iter().map(|(ext_id, _)| *ext_id).collect();

        assert_eq!(extensions.len(), 2);
        assert_eq!(extension_ids[0], 47); // client_certificate_type
        assert_eq!(extension_ids[1], 48); // server_certificate_type
    }
}
625}