1use crate::crypto::{ClientConfig as QuicClientConfig, ServerConfig as QuicServerConfig};
15use rustls::{ClientConfig, ServerConfig};
16use std::collections::HashMap;
17use std::sync::{Arc, Mutex};
18
19use super::tls_extensions::{
20 CertificateTypeList, CertificateTypePreferences, NegotiationResult, TlsExtensionError,
21};
22
23pub trait TlsExtensionHooks: Send + Sync {
25 fn on_handshake_complete(&self, conn_id: &str, is_client: bool);
27
28 fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)>;
30
31 fn process_server_hello_extensions(
33 &self,
34 conn_id: &str,
35 extensions: &[(u16, Vec<u8>)],
36 ) -> Result<(), TlsExtensionError>;
37
38 fn get_negotiation_result(&self, conn_id: &str) -> Option<NegotiationResult>;
40}
41
42#[derive(Debug)]
44pub struct SimulatedExtensionContext {
45 negotiations: Arc<Mutex<HashMap<String, NegotiationState>>>,
47 local_preferences: CertificateTypePreferences,
49}
50
51#[derive(Debug, Clone)]
52struct NegotiationState {
53 local_preferences: CertificateTypePreferences,
54 remote_client_types: Option<CertificateTypeList>,
55 remote_server_types: Option<CertificateTypeList>,
56 result: Option<NegotiationResult>,
57}
58
59impl SimulatedExtensionContext {
60 pub fn new(preferences: CertificateTypePreferences) -> Self {
62 Self {
63 negotiations: Arc::new(Mutex::new(HashMap::new())),
64 local_preferences: preferences,
65 }
66 }
67
68 #[allow(clippy::unwrap_used, clippy::expect_used)]
71 pub fn simulate_send_preferences(&self, conn_id: &str) -> (Option<Vec<u8>>, Option<Vec<u8>>) {
72 let mut negotiations = self
73 .negotiations
74 .lock()
75 .expect("Mutex poisoning is unexpected in normal operation");
76
77 let state = NegotiationState {
78 local_preferences: self.local_preferences.clone(),
79 remote_client_types: None,
80 remote_server_types: None,
81 result: None,
82 };
83
84 negotiations.insert(conn_id.to_string(), state);
85
86 let client_ext_data = self.local_preferences.client_types.to_bytes();
88 let server_ext_data = self.local_preferences.server_types.to_bytes();
89
90 (Some(client_ext_data), Some(server_ext_data))
91 }
92
93 #[allow(clippy::unwrap_used, clippy::expect_used)]
95 pub fn simulate_receive_preferences(
96 &self,
97 conn_id: &str,
98 client_types_data: Option<&[u8]>,
99 server_types_data: Option<&[u8]>,
100 ) -> Result<(), TlsExtensionError> {
101 let mut negotiations = self
102 .negotiations
103 .lock()
104 .expect("Mutex poisoning is unexpected in normal operation");
105
106 let state = negotiations.get_mut(conn_id).ok_or_else(|| {
107 TlsExtensionError::InvalidExtensionData(format!(
108 "No negotiation state for connection {conn_id}"
109 ))
110 })?;
111
112 if let Some(data) = client_types_data {
113 state.remote_client_types = Some(CertificateTypeList::from_bytes(data)?);
114 }
115
116 if let Some(data) = server_types_data {
117 state.remote_server_types = Some(CertificateTypeList::from_bytes(data)?);
118 }
119
120 Ok(())
121 }
122
123 #[allow(clippy::unwrap_used, clippy::expect_used)]
125 pub fn complete_negotiation(
126 &self,
127 conn_id: &str,
128 ) -> Result<NegotiationResult, TlsExtensionError> {
129 let mut negotiations = self
130 .negotiations
131 .lock()
132 .expect("Mutex poisoning is unexpected in normal operation");
133
134 let state = negotiations.get_mut(conn_id).ok_or_else(|| {
135 TlsExtensionError::InvalidExtensionData(format!(
136 "No negotiation state for connection {conn_id}"
137 ))
138 })?;
139
140 if let Some(result) = &state.result {
141 return Ok(result.clone());
142 }
143
144 let result = state.local_preferences.negotiate(
145 state.remote_client_types.as_ref(),
146 state.remote_server_types.as_ref(),
147 )?;
148
149 state.result = Some(result.clone());
150 Ok(result)
151 }
152
153 #[allow(clippy::unwrap_used, clippy::expect_used)]
155 pub fn cleanup_connection(&self, conn_id: &str) {
156 let mut negotiations = self
157 .negotiations
158 .lock()
159 .expect("Mutex poisoning is unexpected in normal operation");
160 negotiations.remove(conn_id);
161 }
162}
163
164impl TlsExtensionHooks for SimulatedExtensionContext {
165 fn on_handshake_complete(&self, conn_id: &str, _is_client: bool) {
166 let _ = self.complete_negotiation(conn_id);
168 }
169
170 fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)> {
171 let (client_types, server_types) = self.simulate_send_preferences(conn_id);
172
173 let mut extensions = Vec::new();
174
175 if let Some(data) = client_types {
176 extensions.push((47, data)); }
178
179 if let Some(data) = server_types {
180 extensions.push((48, data)); }
182
183 extensions
184 }
185
186 fn process_server_hello_extensions(
187 &self,
188 conn_id: &str,
189 extensions: &[(u16, Vec<u8>)],
190 ) -> Result<(), TlsExtensionError> {
191 let mut client_types_data = None;
192 let mut server_types_data = None;
193
194 for (ext_id, data) in extensions {
195 match *ext_id {
196 47 => client_types_data = Some(data.as_slice()),
197 48 => server_types_data = Some(data.as_slice()),
198 _ => {}
199 }
200 }
201
202 self.simulate_receive_preferences(conn_id, client_types_data, server_types_data)
203 }
204
205 fn get_negotiation_result(&self, conn_id: &str) -> Option<NegotiationResult> {
206 self.complete_negotiation(conn_id).ok()
207 }
208}
209
210pub struct Rfc7250ClientConfig {
212 inner: Arc<ClientConfig>,
213 extension_context: Arc<SimulatedExtensionContext>,
214}
215
216impl Rfc7250ClientConfig {
217 pub fn new(base_config: ClientConfig, preferences: CertificateTypePreferences) -> Self {
219 Self {
220 inner: Arc::new(base_config),
221 extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
222 }
223 }
224
225 pub fn inner(&self) -> &Arc<ClientConfig> {
227 &self.inner
228 }
229
230 pub fn extension_context(&self) -> &Arc<SimulatedExtensionContext> {
232 &self.extension_context
233 }
234
235 pub fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)> {
237 let (client_types, server_types) =
238 self.extension_context.simulate_send_preferences(conn_id);
239
240 let mut extensions = Vec::new();
241
242 if let Some(data) = client_types {
243 extensions.push((47, data)); }
245
246 if let Some(data) = server_types {
247 extensions.push((48, data)); }
249
250 extensions
251 }
252}
253
254pub struct Rfc7250ServerConfig {
256 inner: Arc<ServerConfig>,
257 extension_context: Arc<SimulatedExtensionContext>,
258}
259
260impl Rfc7250ServerConfig {
261 pub fn new(base_config: ServerConfig, preferences: CertificateTypePreferences) -> Self {
263 Self {
264 inner: Arc::new(base_config),
265 extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
266 }
267 }
268
269 pub fn inner(&self) -> &Arc<ServerConfig> {
271 &self.inner
272 }
273
274 pub fn extension_context(&self) -> &Arc<SimulatedExtensionContext> {
276 &self.extension_context
277 }
278
279 pub fn process_client_hello_extensions(
281 &self,
282 conn_id: &str,
283 client_extensions: &[(u16, Vec<u8>)],
284 ) -> Result<Vec<(u16, Vec<u8>)>, TlsExtensionError> {
285 self.extension_context.simulate_send_preferences(conn_id);
287
288 let mut client_types_data = None;
290 let mut server_types_data = None;
291
292 for (ext_id, data) in client_extensions {
293 match *ext_id {
294 47 => client_types_data = Some(data.as_slice()),
295 48 => server_types_data = Some(data.as_slice()),
296 _ => {}
297 }
298 }
299
300 self.extension_context.simulate_receive_preferences(
302 conn_id,
303 client_types_data,
304 server_types_data,
305 )?;
306
307 let result = self.extension_context.complete_negotiation(conn_id)?;
309
310 let mut response_extensions = Vec::new();
312
313 response_extensions.push((47, vec![1, result.client_cert_type.to_u8()]));
315 response_extensions.push((48, vec![1, result.server_cert_type.to_u8()]));
316
317 Ok(response_extensions)
318 }
319}
320
321pub fn should_use_raw_public_key(negotiation_result: &NegotiationResult, is_client: bool) -> bool {
323 if is_client {
324 negotiation_result.client_cert_type.is_raw_public_key()
325 } else {
326 negotiation_result.server_cert_type.is_raw_public_key()
327 }
328}
329
/// Builds a connection key of the form `"{local_addr}-{remote_addr}"`.
///
/// The parts are not escaped, so the key is only unambiguous when neither address
/// contains `-`.
pub fn create_connection_id(local_addr: &str, remote_addr: &str) -> String {
    [local_addr, remote_addr].join("-")
}
334
335pub struct ExtensionAwareTlsSession {
337 inner_session: Box<dyn crate::crypto::Session>,
339 extension_hooks: Arc<dyn TlsExtensionHooks>,
341 conn_id: String,
343 is_client: bool,
345 handshake_complete: bool,
347}
348
349impl ExtensionAwareTlsSession {
350 pub fn new(
352 inner_session: Box<dyn crate::crypto::Session>,
353 extension_hooks: Arc<dyn TlsExtensionHooks>,
354 conn_id: String,
355 is_client: bool,
356 ) -> Self {
357 Self {
358 inner_session,
359 extension_hooks,
360 conn_id,
361 is_client,
362 handshake_complete: false,
363 }
364 }
365
366 pub fn get_negotiation_result(&self) -> Option<NegotiationResult> {
368 self.extension_hooks.get_negotiation_result(&self.conn_id)
369 }
370}
371
372impl crate::crypto::Session for ExtensionAwareTlsSession {
374 fn initial_keys(
375 &self,
376 dst_cid: &crate::ConnectionId,
377 side: crate::Side,
378 ) -> crate::crypto::Keys {
379 self.inner_session.initial_keys(dst_cid, side)
380 }
381
382 fn handshake_data(&self) -> Option<Box<dyn std::any::Any>> {
383 self.inner_session.handshake_data()
384 }
385
386 fn peer_identity(&self) -> Option<Box<dyn std::any::Any>> {
387 self.inner_session.peer_identity()
388 }
389
390 fn early_crypto(
391 &self,
392 ) -> Option<(
393 Box<dyn crate::crypto::HeaderKey>,
394 Box<dyn crate::crypto::PacketKey>,
395 )> {
396 self.inner_session.early_crypto()
397 }
398
399 fn early_data_accepted(&self) -> Option<bool> {
400 self.inner_session.early_data_accepted()
401 }
402
403 fn is_handshaking(&self) -> bool {
404 self.inner_session.is_handshaking()
405 }
406
407 fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, crate::TransportError> {
408 let result = self.inner_session.read_handshake(buf)?;
409
410 if result && !self.handshake_complete && !self.is_handshaking() {
412 self.handshake_complete = true;
413 self.extension_hooks
414 .on_handshake_complete(&self.conn_id, self.is_client);
415 }
416
417 Ok(result)
418 }
419
420 fn transport_parameters(
421 &self,
422 ) -> Result<Option<crate::transport_parameters::TransportParameters>, crate::TransportError>
423 {
424 self.inner_session.transport_parameters()
425 }
426
427 fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<crate::crypto::Keys> {
428 self.inner_session.write_handshake(buf)
429 }
430
431 fn next_1rtt_keys(
432 &mut self,
433 ) -> Option<crate::crypto::KeyPair<Box<dyn crate::crypto::PacketKey>>> {
434 self.inner_session.next_1rtt_keys()
435 }
436
437 fn is_valid_retry(
438 &self,
439 orig_dst_cid: &crate::ConnectionId,
440 header: &[u8],
441 payload: &[u8],
442 ) -> bool {
443 self.inner_session
444 .is_valid_retry(orig_dst_cid, header, payload)
445 }
446
447 fn export_keying_material(
448 &self,
449 output: &mut [u8],
450 label: &[u8],
451 context: &[u8],
452 ) -> Result<(), crate::crypto::ExportKeyingMaterialError> {
453 self.inner_session
454 .export_keying_material(output, label, context)
455 }
456}
457
458pub struct Rfc7250QuicClientConfig {
460 base_config: Arc<dyn QuicClientConfig>,
462 extension_context: Arc<SimulatedExtensionContext>,
464}
465
466impl Rfc7250QuicClientConfig {
467 pub fn new(
469 base_config: Arc<dyn QuicClientConfig>,
470 preferences: CertificateTypePreferences,
471 ) -> Self {
472 Self {
473 base_config,
474 extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
475 }
476 }
477}
478
479impl QuicClientConfig for Rfc7250QuicClientConfig {
480 fn start_session(
481 self: Arc<Self>,
482 version: u32,
483 server_name: &str,
484 params: &crate::transport_parameters::TransportParameters,
485 ) -> Result<Box<dyn crate::crypto::Session>, crate::ConnectError> {
486 let inner_session = self
488 .base_config
489 .clone()
490 .start_session(version, server_name, params)?;
491
492 let conn_id = format!(
494 "client-{}-{}",
495 server_name,
496 std::time::SystemTime::now()
497 .duration_since(std::time::UNIX_EPOCH)
498 .unwrap_or_else(|_| std::time::Duration::from_secs(0))
499 .as_nanos()
500 );
501
502 Ok(Box::new(ExtensionAwareTlsSession::new(
504 inner_session,
505 self.extension_context.clone() as Arc<dyn TlsExtensionHooks>,
506 conn_id,
507 true, )))
509 }
510}
511
512pub struct Rfc7250QuicServerConfig {
514 base_config: Arc<dyn QuicServerConfig>,
516 extension_context: Arc<SimulatedExtensionContext>,
518}
519
520impl Rfc7250QuicServerConfig {
521 pub fn new(
523 base_config: Arc<dyn QuicServerConfig>,
524 preferences: CertificateTypePreferences,
525 ) -> Self {
526 Self {
527 base_config,
528 extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
529 }
530 }
531}
532
533impl QuicServerConfig for Rfc7250QuicServerConfig {
534 fn start_session(
535 self: Arc<Self>,
536 version: u32,
537 params: &crate::transport_parameters::TransportParameters,
538 ) -> Box<dyn crate::crypto::Session> {
539 let inner_session = self.base_config.clone().start_session(version, params);
541
542 let conn_id = format!(
544 "server-{}",
545 std::time::SystemTime::now()
546 .duration_since(std::time::UNIX_EPOCH)
547 .unwrap_or_else(|_| std::time::Duration::from_secs(0))
548 .as_nanos()
549 );
550
551 Box::new(ExtensionAwareTlsSession::new(
553 inner_session,
554 self.extension_context.clone() as Arc<dyn TlsExtensionHooks>,
555 conn_id,
556 false, ))
558 }
559
560 fn initial_keys(
561 &self,
562 version: u32,
563 dst_cid: &crate::ConnectionId,
564 ) -> Result<crate::crypto::Keys, crate::crypto::UnsupportedVersion> {
565 self.base_config.initial_keys(version, dst_cid)
566 }
567
568 fn retry_tag(
569 &self,
570 version: u32,
571 orig_dst_cid: &crate::ConnectionId,
572 packet: &[u8],
573 ) -> [u8; 16] {
574 self.base_config.retry_tag(version, orig_dst_cid, packet)
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;
    use std::sync::Once;

    static PROVIDER_INIT: Once = Once::new();

    /// Installs the rustls crypto provider for whichever backend feature is enabled, at most
    /// once per process. Installation errors are ignored because a provider may already be
    /// installed.
    fn ensure_crypto_provider() {
        PROVIDER_INIT.call_once(|| {
            #[cfg(feature = "rustls-aws-lc-rs")]
            let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();

            #[cfg(feature = "rustls-ring")]
            let _ = rustls::crypto::ring::default_provider().install_default();
        });
    }

    #[test]
    fn test_simulated_negotiation_flow() {
        const CONN_ID: &str = "test-connection";

        let client_ctx =
            SimulatedExtensionContext::new(CertificateTypePreferences::prefer_raw_public_key());
        let server_ctx =
            SimulatedExtensionContext::new(CertificateTypePreferences::raw_public_key_only());

        // Client advertises both certificate type lists.
        let (offered_client_types, offered_server_types) =
            client_ctx.simulate_send_preferences(CONN_ID);
        assert!(offered_client_types.is_some());
        assert!(offered_server_types.is_some());

        // Server registers its own state, consumes the offer and negotiates.
        server_ctx.simulate_send_preferences(CONN_ID);
        server_ctx
            .simulate_receive_preferences(
                CONN_ID,
                offered_client_types.as_deref(),
                offered_server_types.as_deref(),
            )
            .unwrap();
        let server_result = server_ctx.complete_negotiation(CONN_ID).unwrap();
        assert!(server_result.is_raw_public_key_only());

        // Server answers with single-entry lists selecting raw public keys for both roles.
        let rpk_selection = vec![1, CertificateType::RawPublicKey.to_u8()];
        client_ctx
            .simulate_receive_preferences(CONN_ID, Some(&rpk_selection), Some(&rpk_selection))
            .unwrap();

        let client_result = client_ctx.complete_negotiation(CONN_ID).unwrap();
        assert_eq!(client_result, server_result);
    }

    #[test]
    fn test_wrapper_configs() {
        ensure_crypto_provider();

        let verifier = Arc::new(
            crate::crypto::raw_public_keys::RawPublicKeyVerifier::new(Vec::new()),
        );
        let client_config = ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(verifier)
            .with_no_client_auth();

        let wrapped_client = Rfc7250ClientConfig::new(
            client_config,
            CertificateTypePreferences::prefer_raw_public_key(),
        );

        let extension_ids: Vec<u16> = wrapped_client
            .get_client_hello_extensions("test-conn")
            .iter()
            .map(|(id, _)| *id)
            .collect();
        assert_eq!(extension_ids, [47, 48]);
    }
}