ant_quic/crypto/
tls_extension_simulation.rs

1//! TLS Extension Simulation for RFC 7250 Raw Public Keys
2//!
3//! Since rustls 0.23.x doesn't expose APIs for custom TLS extensions,
4//! this module simulates the RFC 7250 certificate type negotiation
5//! through alternative mechanisms that work within rustls constraints.
6
7use std::sync::{Arc, Mutex};
8use std::collections::HashMap;
9use rustls::{ClientConfig, ServerConfig};
10use crate::crypto::{ClientConfig as QuicClientConfig, ServerConfig as QuicServerConfig};
11
12use super::tls_extensions::{
13    CertificateTypeList, CertificateTypePreferences,
14    NegotiationResult, TlsExtensionError,
15};
16
/// Trait for hooking into TLS handshake events.
///
/// Hooks through which a TLS session wrapper reports handshake progress and
/// exchanges simulated certificate-type extension data. Every method is keyed
/// by a caller-chosen connection identifier string (`conn_id`).
///
/// The `Send + Sync` bound allows an implementor to be shared across sessions
/// behind an `Arc<dyn TlsExtensionHooks>`.
pub trait TlsExtensionHooks: Send + Sync {
    /// Called when the handshake is complete.
    ///
    /// `is_client` tells the hook which side of the connection is reporting.
    fn on_handshake_complete(&self, conn_id: &str, is_client: bool);

    /// Called to get extension data for ClientHello.
    ///
    /// Returns `(extension_type, payload)` pairs to advertise for `conn_id`.
    fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)>;

    /// Called to process ServerHello extensions.
    ///
    /// `extensions` holds the `(extension_type, payload)` pairs received for
    /// `conn_id`.
    ///
    /// # Errors
    /// Returns a [`TlsExtensionError`] when the data cannot be accepted for
    /// this connection.
    fn process_server_hello_extensions(&self, conn_id: &str, extensions: &[(u16, Vec<u8>)]) -> Result<(), TlsExtensionError>;

    /// Get the negotiation result for a connection.
    ///
    /// Returns `None` when no result is available for `conn_id`.
    fn get_negotiation_result(&self, conn_id: &str) -> Option<NegotiationResult>;
}
31
/// Simulated TLS extension context for certificate type negotiation.
///
/// Holds this endpoint's certificate-type preferences plus one negotiation
/// record per connection identifier. The record map sits behind a `Mutex`,
/// so every method takes `&self` and a single context can be shared (for
/// example through an `Arc`) by many sessions.
#[derive(Debug)]
pub struct SimulatedExtensionContext {
    /// Active negotiations indexed by connection ID.
    /// Records are inserted by `simulate_send_preferences` and removed only by
    /// `cleanup_connection`.
    negotiations: Arc<Mutex<HashMap<String, NegotiationState>>>,
    /// Local preferences for this endpoint; cloned into every new record.
    local_preferences: CertificateTypePreferences,
}
40
/// Per-connection negotiation record kept by `SimulatedExtensionContext`.
#[derive(Debug, Clone)]
struct NegotiationState {
    /// Snapshot of the endpoint's preferences taken when the record was created.
    local_preferences: CertificateTypePreferences,
    /// Peer's client_certificate_type list, once received.
    remote_client_types: Option<CertificateTypeList>,
    /// Peer's server_certificate_type list, once received.
    remote_server_types: Option<CertificateTypeList>,
    /// Cached outcome of `complete_negotiation`, if it has already run.
    result: Option<NegotiationResult>,
}
48
49impl SimulatedExtensionContext {
50    /// Create a new simulated extension context
51    pub fn new(preferences: CertificateTypePreferences) -> Self {
52        Self {
53            negotiations: Arc::new(Mutex::new(HashMap::new())),
54            local_preferences: preferences,
55        }
56    }
57
58    /// Simulate sending certificate type preferences
59    /// In reality, this would be sent in ClientHello/ServerHello extensions
60    pub fn simulate_send_preferences(&self, conn_id: &str) -> (Option<Vec<u8>>, Option<Vec<u8>>) {
61        let mut negotiations = self.negotiations.lock().unwrap();
62        
63        let state = NegotiationState {
64            local_preferences: self.local_preferences.clone(),
65            remote_client_types: None,
66            remote_server_types: None,
67            result: None,
68        };
69        
70        negotiations.insert(conn_id.to_string(), state);
71        
72        // Simulate extension data that would be sent
73        let client_ext_data = self.local_preferences.client_types.to_bytes();
74        let server_ext_data = self.local_preferences.server_types.to_bytes();
75        
76        (Some(client_ext_data), Some(server_ext_data))
77    }
78
79    /// Simulate receiving certificate type preferences from peer
80    pub fn simulate_receive_preferences(
81        &self,
82        conn_id: &str,
83        client_types_data: Option<&[u8]>,
84        server_types_data: Option<&[u8]>,
85    ) -> Result<(), TlsExtensionError> {
86        let mut negotiations = self.negotiations.lock().unwrap();
87        
88        let state = negotiations.get_mut(conn_id)
89            .ok_or_else(|| TlsExtensionError::InvalidExtensionData(
90                format!("No negotiation state for connection {}", conn_id)
91            ))?;
92        
93        if let Some(data) = client_types_data {
94            state.remote_client_types = Some(CertificateTypeList::from_bytes(data)?);
95        }
96        
97        if let Some(data) = server_types_data {
98            state.remote_server_types = Some(CertificateTypeList::from_bytes(data)?);
99        }
100        
101        Ok(())
102    }
103
104    /// Complete the negotiation and get the result
105    pub fn complete_negotiation(&self, conn_id: &str) -> Result<NegotiationResult, TlsExtensionError> {
106        let mut negotiations = self.negotiations.lock().unwrap();
107        
108        let state = negotiations.get_mut(conn_id)
109            .ok_or_else(|| TlsExtensionError::InvalidExtensionData(
110                format!("No negotiation state for connection {}", conn_id)
111            ))?;
112        
113        if let Some(result) = &state.result {
114            return Ok(result.clone());
115        }
116        
117        let result = state.local_preferences.negotiate(
118            state.remote_client_types.as_ref(),
119            state.remote_server_types.as_ref(),
120        )?;
121        
122        state.result = Some(result.clone());
123        Ok(result)
124    }
125
126    /// Clean up negotiation state for a connection
127    pub fn cleanup_connection(&self, conn_id: &str) {
128        let mut negotiations = self.negotiations.lock().unwrap();
129        negotiations.remove(conn_id);
130    }
131}
132
133impl TlsExtensionHooks for SimulatedExtensionContext {
134    fn on_handshake_complete(&self, conn_id: &str, _is_client: bool) {
135        // Try to complete negotiation if not already done
136        let _ = self.complete_negotiation(conn_id);
137    }
138    
139    fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)> {
140        let (client_types, server_types) = self.simulate_send_preferences(conn_id);
141        
142        let mut extensions = Vec::new();
143        
144        if let Some(data) = client_types {
145            extensions.push((47, data)); // client_certificate_type
146        }
147        
148        if let Some(data) = server_types {
149            extensions.push((48, data)); // server_certificate_type
150        }
151        
152        extensions
153    }
154    
155    fn process_server_hello_extensions(&self, conn_id: &str, extensions: &[(u16, Vec<u8>)]) -> Result<(), TlsExtensionError> {
156        let mut client_types_data = None;
157        let mut server_types_data = None;
158        
159        for (ext_id, data) in extensions {
160            match *ext_id {
161                47 => client_types_data = Some(data.as_slice()),
162                48 => server_types_data = Some(data.as_slice()),
163                _ => {}
164            }
165        }
166        
167        self.simulate_receive_preferences(conn_id, client_types_data, server_types_data)
168    }
169    
170    fn get_negotiation_result(&self, conn_id: &str) -> Option<NegotiationResult> {
171        self.complete_negotiation(conn_id).ok()
172    }
173}
174
175/// Wrapper for ClientConfig that simulates RFC 7250 extension behavior
176pub struct Rfc7250ClientConfig {
177    inner: Arc<ClientConfig>,
178    extension_context: Arc<SimulatedExtensionContext>,
179}
180
181impl Rfc7250ClientConfig {
182    /// Create a new RFC 7250 aware client configuration
183    pub fn new(
184        base_config: ClientConfig,
185        preferences: CertificateTypePreferences,
186    ) -> Self {
187        Self {
188            inner: Arc::new(base_config),
189            extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
190        }
191    }
192
193    /// Get the inner rustls ClientConfig
194    pub fn inner(&self) -> &Arc<ClientConfig> {
195        &self.inner
196    }
197
198    /// Get the extension context for negotiation
199    pub fn extension_context(&self) -> &Arc<SimulatedExtensionContext> {
200        &self.extension_context
201    }
202
203    /// Simulate the ClientHello extension data
204    pub fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)> {
205        let (client_types, server_types) = self.extension_context.simulate_send_preferences(conn_id);
206        
207        let mut extensions = Vec::new();
208        
209        if let Some(data) = client_types {
210            extensions.push((47, data)); // client_certificate_type
211        }
212        
213        if let Some(data) = server_types {
214            extensions.push((48, data)); // server_certificate_type
215        }
216        
217        extensions
218    }
219}
220
221/// Wrapper for ServerConfig that simulates RFC 7250 extension behavior
222pub struct Rfc7250ServerConfig {
223    inner: Arc<ServerConfig>,
224    extension_context: Arc<SimulatedExtensionContext>,
225}
226
227impl Rfc7250ServerConfig {
228    /// Create a new RFC 7250 aware server configuration
229    pub fn new(
230        base_config: ServerConfig,
231        preferences: CertificateTypePreferences,
232    ) -> Self {
233        Self {
234            inner: Arc::new(base_config),
235            extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
236        }
237    }
238
239    /// Get the inner rustls ServerConfig
240    pub fn inner(&self) -> &Arc<ServerConfig> {
241        &self.inner
242    }
243
244    /// Get the extension context for negotiation
245    pub fn extension_context(&self) -> &Arc<SimulatedExtensionContext> {
246        &self.extension_context
247    }
248
249    /// Process ClientHello extensions and prepare ServerHello response
250    pub fn process_client_hello_extensions(
251        &self,
252        conn_id: &str,
253        client_extensions: &[(u16, Vec<u8>)],
254    ) -> Result<Vec<(u16, Vec<u8>)>, TlsExtensionError> {
255        // First, register this connection
256        self.extension_context.simulate_send_preferences(conn_id);
257        
258        // Process client's certificate type preferences
259        let mut client_types_data = None;
260        let mut server_types_data = None;
261        
262        for (ext_id, data) in client_extensions {
263            match *ext_id {
264                47 => client_types_data = Some(data.as_slice()),
265                48 => server_types_data = Some(data.as_slice()),
266                _ => {}
267            }
268        }
269        
270        // Store remote preferences
271        self.extension_context.simulate_receive_preferences(
272            conn_id,
273            client_types_data,
274            server_types_data,
275        )?;
276        
277        // Complete negotiation
278        let result = self.extension_context.complete_negotiation(conn_id)?;
279        
280        // Prepare ServerHello extensions with negotiated types
281        let mut response_extensions = Vec::new();
282        
283        // Send back single negotiated type for each extension
284        response_extensions.push((47, vec![1, result.client_cert_type.to_u8()]));
285        response_extensions.push((48, vec![1, result.server_cert_type.to_u8()]));
286        
287        Ok(response_extensions)
288    }
289}
290
291/// Helper to determine if we should use Raw Public Key based on negotiation
292pub fn should_use_raw_public_key(
293    negotiation_result: &NegotiationResult,
294    is_client: bool,
295) -> bool {
296    if is_client {
297        negotiation_result.client_cert_type.is_raw_public_key()
298    } else {
299        negotiation_result.server_cert_type.is_raw_public_key()
300    }
301}
302
/// Builds a connection identifier for simulation purposes.
///
/// The local and remote address strings are joined with a single `-`,
/// e.g. `("a", "b")` yields `"a-b"`.
pub fn create_connection_id(local_addr: &str, remote_addr: &str) -> String {
    [local_addr, remote_addr].join("-")
}
307
308/// Wrapper for TlsSession that integrates with TlsExtensionHooks
309pub struct ExtensionAwareTlsSession {
310    /// The underlying TLS session
311    inner_session: Box<dyn crate::crypto::Session>,
312    /// Extension hooks for certificate type negotiation
313    extension_hooks: Arc<dyn TlsExtensionHooks>,
314    /// Connection identifier
315    conn_id: String,
316    /// Whether this is a client session
317    is_client: bool,
318    /// Whether handshake is complete
319    handshake_complete: bool,
320}
321
322impl ExtensionAwareTlsSession {
323    /// Create a new extension-aware TLS session
324    pub fn new(
325        inner_session: Box<dyn crate::crypto::Session>,
326        extension_hooks: Arc<dyn TlsExtensionHooks>,
327        conn_id: String,
328        is_client: bool,
329    ) -> Self {
330        Self {
331            inner_session,
332            extension_hooks,
333            conn_id,
334            is_client,
335            handshake_complete: false,
336        }
337    }
338    
339    /// Get the negotiation result if available
340    pub fn get_negotiation_result(&self) -> Option<NegotiationResult> {
341        self.extension_hooks.get_negotiation_result(&self.conn_id)
342    }
343}
344
345/// Implement the crypto::Session trait for our wrapper
346impl crate::crypto::Session for ExtensionAwareTlsSession {
347    fn initial_keys(&self, dst_cid: &crate::ConnectionId, side: crate::Side) -> crate::crypto::Keys {
348        self.inner_session.initial_keys(dst_cid, side)
349    }
350    
351    fn handshake_data(&self) -> Option<Box<dyn std::any::Any>> {
352        self.inner_session.handshake_data()
353    }
354    
355    fn peer_identity(&self) -> Option<Box<dyn std::any::Any>> {
356        self.inner_session.peer_identity()
357    }
358    
359    fn early_crypto(&self) -> Option<(Box<dyn crate::crypto::HeaderKey>, Box<dyn crate::crypto::PacketKey>)> {
360        self.inner_session.early_crypto()
361    }
362    
363    fn early_data_accepted(&self) -> Option<bool> {
364        self.inner_session.early_data_accepted()
365    }
366    
367    fn is_handshaking(&self) -> bool {
368        self.inner_session.is_handshaking()
369    }
370    
371    fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, crate::TransportError> {
372        let result = self.inner_session.read_handshake(buf)?;
373        
374        // Check if handshake is complete
375        if result && !self.handshake_complete && !self.is_handshaking() {
376            self.handshake_complete = true;
377            self.extension_hooks.on_handshake_complete(&self.conn_id, self.is_client);
378        }
379        
380        Ok(result)
381    }
382    
383    fn transport_parameters(&self) -> Result<Option<crate::transport_parameters::TransportParameters>, crate::TransportError> {
384        self.inner_session.transport_parameters()
385    }
386    
387    fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<crate::crypto::Keys> {
388        self.inner_session.write_handshake(buf)
389    }
390    
391    fn next_1rtt_keys(&mut self) -> Option<crate::crypto::KeyPair<Box<dyn crate::crypto::PacketKey>>> {
392        self.inner_session.next_1rtt_keys()
393    }
394    
395    fn is_valid_retry(&self, orig_dst_cid: &crate::ConnectionId, header: &[u8], payload: &[u8]) -> bool {
396        self.inner_session.is_valid_retry(orig_dst_cid, header, payload)
397    }
398    
399    fn export_keying_material(
400        &self,
401        output: &mut [u8],
402        label: &[u8],
403        context: &[u8],
404    ) -> Result<(), crate::crypto::ExportKeyingMaterialError> {
405        self.inner_session.export_keying_material(output, label, context)
406    }
407}
408
409/// Enhanced QUIC client config with RFC 7250 support
410pub struct Rfc7250QuicClientConfig {
411    /// Base QUIC client config
412    base_config: Arc<dyn QuicClientConfig>,
413    /// Extension context for certificate type negotiation
414    extension_context: Arc<SimulatedExtensionContext>,
415}
416
417impl Rfc7250QuicClientConfig {
418    /// Create a new RFC 7250 aware QUIC client config
419    pub fn new(
420        base_config: Arc<dyn QuicClientConfig>,
421        preferences: CertificateTypePreferences,
422    ) -> Self {
423        Self {
424            base_config,
425            extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
426        }
427    }
428}
429
430impl QuicClientConfig for Rfc7250QuicClientConfig {
431    fn start_session(
432        self: Arc<Self>,
433        version: u32,
434        server_name: &str,
435        params: &crate::transport_parameters::TransportParameters,
436    ) -> Result<Box<dyn crate::crypto::Session>, crate::ConnectError> {
437        // Create the base session
438        let inner_session = self.base_config.clone().start_session(version, server_name, params)?;
439        
440        // Create connection ID for this session
441        let conn_id = format!("client-{}-{}", server_name, std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos());
442        
443        // Create wrapper with extension hooks
444        Ok(Box::new(ExtensionAwareTlsSession::new(
445            inner_session,
446            self.extension_context.clone() as Arc<dyn TlsExtensionHooks>,
447            conn_id,
448            true, // is_client
449        )))
450    }
451}
452
453/// Enhanced QUIC server config with RFC 7250 support
454pub struct Rfc7250QuicServerConfig {
455    /// Base QUIC server config
456    base_config: Arc<dyn QuicServerConfig>,
457    /// Extension context for certificate type negotiation
458    extension_context: Arc<SimulatedExtensionContext>,
459}
460
461impl Rfc7250QuicServerConfig {
462    /// Create a new RFC 7250 aware QUIC server config
463    pub fn new(
464        base_config: Arc<dyn QuicServerConfig>,
465        preferences: CertificateTypePreferences,
466    ) -> Self {
467        Self {
468            base_config,
469            extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
470        }
471    }
472}
473
474impl QuicServerConfig for Rfc7250QuicServerConfig {
475    fn start_session(
476        self: Arc<Self>,
477        version: u32,
478        params: &crate::transport_parameters::TransportParameters,
479    ) -> Box<dyn crate::crypto::Session> {
480        // Create the base session
481        let inner_session = self.base_config.clone().start_session(version, params);
482        
483        // Create connection ID for this session
484        let conn_id = format!("server-{}", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos());
485        
486        // Create wrapper with extension hooks
487        Box::new(ExtensionAwareTlsSession::new(
488            inner_session,
489            self.extension_context.clone() as Arc<dyn TlsExtensionHooks>,
490            conn_id,
491            false, // is_client = false for server
492        ))
493    }
494    
495    fn initial_keys(
496        &self,
497        version: u32,
498        dst_cid: &crate::ConnectionId,
499    ) -> Result<crate::crypto::Keys, crate::crypto::UnsupportedVersion> {
500        self.base_config.initial_keys(version, dst_cid)
501    }
502    
503    fn retry_tag(&self, version: u32, orig_dst_cid: &crate::ConnectionId, packet: &[u8]) -> [u8; 16] {
504        self.base_config.retry_tag(version, orig_dst_cid, packet)
505    }
506}
507
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::tls_extensions::CertificateType;

    /// Runs a full client/server exchange through two independent contexts and
    /// checks that both sides settle on the same raw-public-key outcome.
    #[test]
    fn test_simulated_negotiation_flow() {
        let conn_id = "test-connection";
        let client_ctx =
            SimulatedExtensionContext::new(CertificateTypePreferences::prefer_raw_public_key());
        let server_ctx =
            SimulatedExtensionContext::new(CertificateTypePreferences::raw_public_key_only());

        // ClientHello: the client advertises both certificate-type lists.
        let (offered_client, offered_server) = client_ctx.simulate_send_preferences(conn_id);
        assert!(offered_client.is_some());
        assert!(offered_server.is_some());

        // The server registers the connection, ingests the offer and negotiates.
        server_ctx.simulate_send_preferences(conn_id);
        server_ctx
            .simulate_receive_preferences(
                conn_id,
                offered_client.as_deref(),
                offered_server.as_deref(),
            )
            .unwrap();
        let server_result = server_ctx.complete_negotiation(conn_id).unwrap();
        assert!(server_result.is_raw_public_key_only());

        // Simulated ServerHello: a single negotiated type for each extension.
        let rpk_answer = vec![1, CertificateType::RawPublicKey.to_u8()];
        client_ctx
            .simulate_receive_preferences(conn_id, Some(&rpk_answer), Some(&rpk_answer))
            .unwrap();

        let client_result = client_ctx.complete_negotiation(conn_id).unwrap();
        assert_eq!(client_result, server_result);
    }

    /// The client wrapper emits the client- then server-certificate-type extension.
    #[test]
    fn test_wrapper_configs() {
        let verifier = Arc::new(
            crate::crypto::raw_public_keys::RawPublicKeyVerifier::new(Vec::new()),
        );
        let client_config = ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(verifier)
            .with_no_client_auth();

        let wrapped_client = Rfc7250ClientConfig::new(
            client_config,
            CertificateTypePreferences::prefer_raw_public_key(),
        );

        let extension_ids: Vec<u16> = wrapped_client
            .get_client_hello_extensions("test-conn")
            .iter()
            .map(|(ext_id, _)| *ext_id)
            .collect();

        // 47 = client_certificate_type, 48 = server_certificate_type (in this simulation).
        assert_eq!(extension_ids, vec![47, 48]);
    }
}