1use crate::crypto::{ClientConfig as QuicClientConfig, ServerConfig as QuicServerConfig};
8use rustls::{ClientConfig, ServerConfig};
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11
12use super::tls_extensions::{
13 CertificateTypeList, CertificateTypePreferences, NegotiationResult, TlsExtensionError,
14};
15
/// Callbacks that a TLS session wrapper uses to drive RFC 7250 certificate-type
/// negotiation alongside the handshake.
///
/// Every method is keyed by a caller-chosen connection id so one implementation
/// can serve many connections at once.
pub trait TlsExtensionHooks: Send + Sync {
    /// Notifies the implementation that the handshake for `conn_id` has finished.
    ///
    /// `is_client` tells which side of the connection is reporting.
    fn on_handshake_complete(&self, conn_id: &str, is_client: bool);

    /// Returns the `(extension id, payload)` pairs to add to the ClientHello for `conn_id`.
    fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)>;

    /// Consumes the peer's certificate-type extensions for `conn_id`.
    ///
    /// # Errors
    ///
    /// Returns a [`TlsExtensionError`] when the extension data cannot be processed.
    fn process_server_hello_extensions(
        &self,
        conn_id: &str,
        extensions: &[(u16, Vec<u8>)],
    ) -> Result<(), TlsExtensionError>;

    /// Returns the negotiated certificate types for `conn_id`, or `None` if no result
    /// is available.
    fn get_negotiation_result(&self, conn_id: &str) -> Option<NegotiationResult>;
}
34
/// In-process stand-in for real RFC 7250 extension handling.
///
/// Rather than hooking into rustls, it records per-connection negotiation state
/// keyed by a caller-chosen connection id. It also encodes and decodes the
/// certificate-type lists and runs the negotiation on demand.
#[derive(Debug)]
pub struct SimulatedExtensionContext {
    /// Per-connection negotiation state, keyed by connection id. It is shared behind
    /// a mutex because the context is handed to many sessions.
    negotiations: Arc<Mutex<HashMap<String, NegotiationState>>>,
    /// Preferences copied into every new connection's state.
    local_preferences: CertificateTypePreferences,
}
43
/// Negotiation progress for a single connection.
#[derive(Debug, Clone)]
struct NegotiationState {
    /// Snapshot of the local preferences taken when the state was created.
    local_preferences: CertificateTypePreferences,
    /// Peer's client certificate-type list, once received.
    remote_client_types: Option<CertificateTypeList>,
    /// Peer's server certificate-type list, once received.
    remote_server_types: Option<CertificateTypeList>,
    /// Cached outcome of a successful negotiation, if one has been computed.
    result: Option<NegotiationResult>,
}
51
52impl SimulatedExtensionContext {
53 pub fn new(preferences: CertificateTypePreferences) -> Self {
55 Self {
56 negotiations: Arc::new(Mutex::new(HashMap::new())),
57 local_preferences: preferences,
58 }
59 }
60
61 pub fn simulate_send_preferences(&self, conn_id: &str) -> (Option<Vec<u8>>, Option<Vec<u8>>) {
64 let mut negotiations = self.negotiations.lock().unwrap();
65
66 let state = NegotiationState {
67 local_preferences: self.local_preferences.clone(),
68 remote_client_types: None,
69 remote_server_types: None,
70 result: None,
71 };
72
73 negotiations.insert(conn_id.to_string(), state);
74
75 let client_ext_data = self.local_preferences.client_types.to_bytes();
77 let server_ext_data = self.local_preferences.server_types.to_bytes();
78
79 (Some(client_ext_data), Some(server_ext_data))
80 }
81
82 pub fn simulate_receive_preferences(
84 &self,
85 conn_id: &str,
86 client_types_data: Option<&[u8]>,
87 server_types_data: Option<&[u8]>,
88 ) -> Result<(), TlsExtensionError> {
89 let mut negotiations = self.negotiations.lock().unwrap();
90
91 let state = negotiations.get_mut(conn_id).ok_or_else(|| {
92 TlsExtensionError::InvalidExtensionData(format!(
93 "No negotiation state for connection {conn_id}"
94 ))
95 })?;
96
97 if let Some(data) = client_types_data {
98 state.remote_client_types = Some(CertificateTypeList::from_bytes(data)?);
99 }
100
101 if let Some(data) = server_types_data {
102 state.remote_server_types = Some(CertificateTypeList::from_bytes(data)?);
103 }
104
105 Ok(())
106 }
107
108 pub fn complete_negotiation(
110 &self,
111 conn_id: &str,
112 ) -> Result<NegotiationResult, TlsExtensionError> {
113 let mut negotiations = self.negotiations.lock().unwrap();
114
115 let state = negotiations.get_mut(conn_id).ok_or_else(|| {
116 TlsExtensionError::InvalidExtensionData(format!(
117 "No negotiation state for connection {conn_id}"
118 ))
119 })?;
120
121 if let Some(result) = &state.result {
122 return Ok(result.clone());
123 }
124
125 let result = state.local_preferences.negotiate(
126 state.remote_client_types.as_ref(),
127 state.remote_server_types.as_ref(),
128 )?;
129
130 state.result = Some(result.clone());
131 Ok(result)
132 }
133
134 pub fn cleanup_connection(&self, conn_id: &str) {
136 let mut negotiations = self.negotiations.lock().unwrap();
137 negotiations.remove(conn_id);
138 }
139}
140
141impl TlsExtensionHooks for SimulatedExtensionContext {
142 fn on_handshake_complete(&self, conn_id: &str, _is_client: bool) {
143 let _ = self.complete_negotiation(conn_id);
145 }
146
147 fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)> {
148 let (client_types, server_types) = self.simulate_send_preferences(conn_id);
149
150 let mut extensions = Vec::new();
151
152 if let Some(data) = client_types {
153 extensions.push((47, data)); }
155
156 if let Some(data) = server_types {
157 extensions.push((48, data)); }
159
160 extensions
161 }
162
163 fn process_server_hello_extensions(
164 &self,
165 conn_id: &str,
166 extensions: &[(u16, Vec<u8>)],
167 ) -> Result<(), TlsExtensionError> {
168 let mut client_types_data = None;
169 let mut server_types_data = None;
170
171 for (ext_id, data) in extensions {
172 match *ext_id {
173 47 => client_types_data = Some(data.as_slice()),
174 48 => server_types_data = Some(data.as_slice()),
175 _ => {}
176 }
177 }
178
179 self.simulate_receive_preferences(conn_id, client_types_data, server_types_data)
180 }
181
182 fn get_negotiation_result(&self, conn_id: &str) -> Option<NegotiationResult> {
183 self.complete_negotiation(conn_id).ok()
184 }
185}
186
/// A rustls [`ClientConfig`] paired with a simulated RFC 7250 extension context.
pub struct Rfc7250ClientConfig {
    /// The wrapped rustls configuration. This wrapper only exposes it; it never
    /// modifies it.
    inner: Arc<ClientConfig>,
    /// Tracks certificate-type negotiation for connections made with this config.
    extension_context: Arc<SimulatedExtensionContext>,
}
192
193impl Rfc7250ClientConfig {
194 pub fn new(base_config: ClientConfig, preferences: CertificateTypePreferences) -> Self {
196 Self {
197 inner: Arc::new(base_config),
198 extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
199 }
200 }
201
202 pub fn inner(&self) -> &Arc<ClientConfig> {
204 &self.inner
205 }
206
207 pub fn extension_context(&self) -> &Arc<SimulatedExtensionContext> {
209 &self.extension_context
210 }
211
212 pub fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)> {
214 let (client_types, server_types) =
215 self.extension_context.simulate_send_preferences(conn_id);
216
217 let mut extensions = Vec::new();
218
219 if let Some(data) = client_types {
220 extensions.push((47, data)); }
222
223 if let Some(data) = server_types {
224 extensions.push((48, data)); }
226
227 extensions
228 }
229}
230
/// A rustls [`ServerConfig`] paired with a simulated RFC 7250 extension context.
pub struct Rfc7250ServerConfig {
    /// The wrapped rustls configuration. This wrapper only exposes it; it never
    /// modifies it.
    inner: Arc<ServerConfig>,
    /// Tracks certificate-type negotiation for connections accepted with this config.
    extension_context: Arc<SimulatedExtensionContext>,
}
236
237impl Rfc7250ServerConfig {
238 pub fn new(base_config: ServerConfig, preferences: CertificateTypePreferences) -> Self {
240 Self {
241 inner: Arc::new(base_config),
242 extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
243 }
244 }
245
246 pub fn inner(&self) -> &Arc<ServerConfig> {
248 &self.inner
249 }
250
251 pub fn extension_context(&self) -> &Arc<SimulatedExtensionContext> {
253 &self.extension_context
254 }
255
256 pub fn process_client_hello_extensions(
258 &self,
259 conn_id: &str,
260 client_extensions: &[(u16, Vec<u8>)],
261 ) -> Result<Vec<(u16, Vec<u8>)>, TlsExtensionError> {
262 self.extension_context.simulate_send_preferences(conn_id);
264
265 let mut client_types_data = None;
267 let mut server_types_data = None;
268
269 for (ext_id, data) in client_extensions {
270 match *ext_id {
271 47 => client_types_data = Some(data.as_slice()),
272 48 => server_types_data = Some(data.as_slice()),
273 _ => {}
274 }
275 }
276
277 self.extension_context.simulate_receive_preferences(
279 conn_id,
280 client_types_data,
281 server_types_data,
282 )?;
283
284 let result = self.extension_context.complete_negotiation(conn_id)?;
286
287 let mut response_extensions = Vec::new();
289
290 response_extensions.push((47, vec![1, result.client_cert_type.to_u8()]));
292 response_extensions.push((48, vec![1, result.server_cert_type.to_u8()]));
293
294 Ok(response_extensions)
295 }
296}
297
298pub fn should_use_raw_public_key(negotiation_result: &NegotiationResult, is_client: bool) -> bool {
300 if is_client {
301 negotiation_result.client_cert_type.is_raw_public_key()
302 } else {
303 negotiation_result.server_cert_type.is_raw_public_key()
304 }
305}
306
/// Builds a connection id by joining the local and remote addresses with `-`.
pub fn create_connection_id(local_addr: &str, remote_addr: &str) -> String {
    [local_addr, remote_addr].join("-")
}
311
/// A [`crate::crypto::Session`] wrapper that notifies [`TlsExtensionHooks`] when the
/// handshake finishes and forwards all cryptographic work to `inner_session`.
pub struct ExtensionAwareTlsSession {
    /// The real TLS session that every `Session` call is forwarded to.
    inner_session: Box<dyn crate::crypto::Session>,
    /// Receives the handshake-complete notification and answers result queries.
    extension_hooks: Arc<dyn TlsExtensionHooks>,
    /// Key identifying this connection inside `extension_hooks`.
    conn_id: String,
    /// Whether this session is the client side of the connection.
    is_client: bool,
    /// Set once `on_handshake_complete` has been fired, so it fires at most once.
    handshake_complete: bool,
}
325
326impl ExtensionAwareTlsSession {
327 pub fn new(
329 inner_session: Box<dyn crate::crypto::Session>,
330 extension_hooks: Arc<dyn TlsExtensionHooks>,
331 conn_id: String,
332 is_client: bool,
333 ) -> Self {
334 Self {
335 inner_session,
336 extension_hooks,
337 conn_id,
338 is_client,
339 handshake_complete: false,
340 }
341 }
342
343 pub fn get_negotiation_result(&self) -> Option<NegotiationResult> {
345 self.extension_hooks.get_negotiation_result(&self.conn_id)
346 }
347}
348
349impl crate::crypto::Session for ExtensionAwareTlsSession {
351 fn initial_keys(
352 &self,
353 dst_cid: &crate::ConnectionId,
354 side: crate::Side,
355 ) -> crate::crypto::Keys {
356 self.inner_session.initial_keys(dst_cid, side)
357 }
358
359 fn handshake_data(&self) -> Option<Box<dyn std::any::Any>> {
360 self.inner_session.handshake_data()
361 }
362
363 fn peer_identity(&self) -> Option<Box<dyn std::any::Any>> {
364 self.inner_session.peer_identity()
365 }
366
367 fn early_crypto(
368 &self,
369 ) -> Option<(
370 Box<dyn crate::crypto::HeaderKey>,
371 Box<dyn crate::crypto::PacketKey>,
372 )> {
373 self.inner_session.early_crypto()
374 }
375
376 fn early_data_accepted(&self) -> Option<bool> {
377 self.inner_session.early_data_accepted()
378 }
379
380 fn is_handshaking(&self) -> bool {
381 self.inner_session.is_handshaking()
382 }
383
384 fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, crate::TransportError> {
385 let result = self.inner_session.read_handshake(buf)?;
386
387 if result && !self.handshake_complete && !self.is_handshaking() {
389 self.handshake_complete = true;
390 self.extension_hooks
391 .on_handshake_complete(&self.conn_id, self.is_client);
392 }
393
394 Ok(result)
395 }
396
397 fn transport_parameters(
398 &self,
399 ) -> Result<Option<crate::transport_parameters::TransportParameters>, crate::TransportError>
400 {
401 self.inner_session.transport_parameters()
402 }
403
404 fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<crate::crypto::Keys> {
405 self.inner_session.write_handshake(buf)
406 }
407
408 fn next_1rtt_keys(
409 &mut self,
410 ) -> Option<crate::crypto::KeyPair<Box<dyn crate::crypto::PacketKey>>> {
411 self.inner_session.next_1rtt_keys()
412 }
413
414 fn is_valid_retry(
415 &self,
416 orig_dst_cid: &crate::ConnectionId,
417 header: &[u8],
418 payload: &[u8],
419 ) -> bool {
420 self.inner_session
421 .is_valid_retry(orig_dst_cid, header, payload)
422 }
423
424 fn export_keying_material(
425 &self,
426 output: &mut [u8],
427 label: &[u8],
428 context: &[u8],
429 ) -> Result<(), crate::crypto::ExportKeyingMaterialError> {
430 self.inner_session
431 .export_keying_material(output, label, context)
432 }
433}
434
/// A QUIC client crypto config that wraps each session in an [`ExtensionAwareTlsSession`].
pub struct Rfc7250QuicClientConfig {
    /// The underlying config that actually creates sessions.
    base_config: Arc<dyn QuicClientConfig>,
    /// Negotiation context shared by every session started from this config.
    extension_context: Arc<SimulatedExtensionContext>,
}
442
443impl Rfc7250QuicClientConfig {
444 pub fn new(
446 base_config: Arc<dyn QuicClientConfig>,
447 preferences: CertificateTypePreferences,
448 ) -> Self {
449 Self {
450 base_config,
451 extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
452 }
453 }
454}
455
456impl QuicClientConfig for Rfc7250QuicClientConfig {
457 fn start_session(
458 self: Arc<Self>,
459 version: u32,
460 server_name: &str,
461 params: &crate::transport_parameters::TransportParameters,
462 ) -> Result<Box<dyn crate::crypto::Session>, crate::ConnectError> {
463 let inner_session = self
465 .base_config
466 .clone()
467 .start_session(version, server_name, params)?;
468
469 let conn_id = format!(
471 "client-{}-{}",
472 server_name,
473 std::time::SystemTime::now()
474 .duration_since(std::time::UNIX_EPOCH)
475 .unwrap()
476 .as_nanos()
477 );
478
479 Ok(Box::new(ExtensionAwareTlsSession::new(
481 inner_session,
482 self.extension_context.clone() as Arc<dyn TlsExtensionHooks>,
483 conn_id,
484 true, )))
486 }
487}
488
/// A QUIC server crypto config that wraps each session in an [`ExtensionAwareTlsSession`].
pub struct Rfc7250QuicServerConfig {
    /// The underlying config that creates sessions and computes initial keys and
    /// retry tags.
    base_config: Arc<dyn QuicServerConfig>,
    /// Negotiation context shared by every session started from this config.
    extension_context: Arc<SimulatedExtensionContext>,
}
496
497impl Rfc7250QuicServerConfig {
498 pub fn new(
500 base_config: Arc<dyn QuicServerConfig>,
501 preferences: CertificateTypePreferences,
502 ) -> Self {
503 Self {
504 base_config,
505 extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
506 }
507 }
508}
509
510impl QuicServerConfig for Rfc7250QuicServerConfig {
511 fn start_session(
512 self: Arc<Self>,
513 version: u32,
514 params: &crate::transport_parameters::TransportParameters,
515 ) -> Box<dyn crate::crypto::Session> {
516 let inner_session = self.base_config.clone().start_session(version, params);
518
519 let conn_id = format!(
521 "server-{}",
522 std::time::SystemTime::now()
523 .duration_since(std::time::UNIX_EPOCH)
524 .unwrap()
525 .as_nanos()
526 );
527
528 Box::new(ExtensionAwareTlsSession::new(
530 inner_session,
531 self.extension_context.clone() as Arc<dyn TlsExtensionHooks>,
532 conn_id,
533 false, ))
535 }
536
537 fn initial_keys(
538 &self,
539 version: u32,
540 dst_cid: &crate::ConnectionId,
541 ) -> Result<crate::crypto::Keys, crate::crypto::UnsupportedVersion> {
542 self.base_config.initial_keys(version, dst_cid)
543 }
544
545 fn retry_tag(
546 &self,
547 version: u32,
548 orig_dst_cid: &crate::ConnectionId,
549 packet: &[u8],
550 ) -> [u8; 16] {
551 self.base_config.retry_tag(version, orig_dst_cid, packet)
552 }
553}
554
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;
    use std::sync::Once;

    static CRYPTO_PROVIDER_INIT: Once = Once::new();

    /// Installs the rustls crypto provider(s) for the enabled features, at most once
    /// per process.
    ///
    /// Install errors are ignored. If both features are enabled, the second
    /// installation fails because a default provider already exists.
    fn ensure_crypto_provider() {
        CRYPTO_PROVIDER_INIT.call_once(|| {
            #[cfg(feature = "rustls-aws-lc-rs")]
            {
                let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
            }

            #[cfg(feature = "rustls-ring")]
            {
                let _ = rustls::crypto::ring::default_provider().install_default();
            }
        });
    }

    /// Walks a full client/server exchange through two independent contexts and
    /// checks that both sides agree on raw public keys.
    #[test]
    fn test_simulated_negotiation_flow() {
        let conn_id = "test-connection";
        let client_ctx =
            SimulatedExtensionContext::new(CertificateTypePreferences::prefer_raw_public_key());
        let server_ctx =
            SimulatedExtensionContext::new(CertificateTypePreferences::raw_public_key_only());

        // Client -> server: the ClientHello offers both preference lists.
        let (offered_client, offered_server) = client_ctx.simulate_send_preferences(conn_id);
        assert!(offered_client.is_some());
        assert!(offered_server.is_some());

        // The server registers its state, ingests the offer and negotiates.
        server_ctx.simulate_send_preferences(conn_id);
        server_ctx
            .simulate_receive_preferences(
                conn_id,
                offered_client.as_deref(),
                offered_server.as_deref(),
            )
            .unwrap();
        let server_result = server_ctx.complete_negotiation(conn_id).unwrap();
        assert!(server_result.is_raw_public_key_only());

        // Server -> client: each echoed list holds the single chosen type.
        let raw_key_only = [1, CertificateType::RawPublicKey.to_u8()];
        client_ctx
            .simulate_receive_preferences(conn_id, Some(&raw_key_only[..]), Some(&raw_key_only[..]))
            .unwrap();

        let client_result = client_ctx.complete_negotiation(conn_id).unwrap();
        assert_eq!(client_result, server_result);
    }

    /// Checks that the rustls client wrapper emits exactly extensions 47 and 48, in
    /// that order.
    #[test]
    fn test_wrapper_configs() {
        ensure_crypto_provider();

        let verifier = Arc::new(crate::crypto::raw_public_keys::RawPublicKeyVerifier::new(
            Vec::new(),
        ));
        let client_config = ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(verifier)
            .with_no_client_auth();

        let wrapped = Rfc7250ClientConfig::new(
            client_config,
            CertificateTypePreferences::prefer_raw_public_key(),
        );

        let ext_ids: Vec<u16> = wrapped
            .get_client_hello_extensions("test-conn")
            .iter()
            .map(|(ext_id, _)| *ext_id)
            .collect();
        assert_eq!(ext_ids, [47u16, 48]);
    }
}