1use crate::crypto::{ClientConfig as QuicClientConfig, ServerConfig as QuicServerConfig};
8use rustls::{ClientConfig, ServerConfig};
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11
12use super::tls_extensions::{
13 CertificateTypeList, CertificateTypePreferences, NegotiationResult, TlsExtensionError,
14};
15
16pub trait TlsExtensionHooks: Send + Sync {
18 fn on_handshake_complete(&self, conn_id: &str, is_client: bool);
20
21 fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)>;
23
24 fn process_server_hello_extensions(
26 &self,
27 conn_id: &str,
28 extensions: &[(u16, Vec<u8>)],
29 ) -> Result<(), TlsExtensionError>;
30
31 fn get_negotiation_result(&self, conn_id: &str) -> Option<NegotiationResult>;
33}
34
35#[derive(Debug)]
37pub struct SimulatedExtensionContext {
38 negotiations: Arc<Mutex<HashMap<String, NegotiationState>>>,
40 local_preferences: CertificateTypePreferences,
42}
43
44#[derive(Debug, Clone)]
45struct NegotiationState {
46 local_preferences: CertificateTypePreferences,
47 remote_client_types: Option<CertificateTypeList>,
48 remote_server_types: Option<CertificateTypeList>,
49 result: Option<NegotiationResult>,
50}
51
52impl SimulatedExtensionContext {
53 pub fn new(preferences: CertificateTypePreferences) -> Self {
55 Self {
56 negotiations: Arc::new(Mutex::new(HashMap::new())),
57 local_preferences: preferences,
58 }
59 }
60
61 pub fn simulate_send_preferences(&self, conn_id: &str) -> (Option<Vec<u8>>, Option<Vec<u8>>) {
64 let mut negotiations = self.negotiations.lock().unwrap();
65
66 let state = NegotiationState {
67 local_preferences: self.local_preferences.clone(),
68 remote_client_types: None,
69 remote_server_types: None,
70 result: None,
71 };
72
73 negotiations.insert(conn_id.to_string(), state);
74
75 let client_ext_data = self.local_preferences.client_types.to_bytes();
77 let server_ext_data = self.local_preferences.server_types.to_bytes();
78
79 (Some(client_ext_data), Some(server_ext_data))
80 }
81
82 pub fn simulate_receive_preferences(
84 &self,
85 conn_id: &str,
86 client_types_data: Option<&[u8]>,
87 server_types_data: Option<&[u8]>,
88 ) -> Result<(), TlsExtensionError> {
89 let mut negotiations = self.negotiations.lock().unwrap();
90
91 let state = negotiations.get_mut(conn_id).ok_or_else(|| {
92 TlsExtensionError::InvalidExtensionData(format!(
93 "No negotiation state for connection {}",
94 conn_id
95 ))
96 })?;
97
98 if let Some(data) = client_types_data {
99 state.remote_client_types = Some(CertificateTypeList::from_bytes(data)?);
100 }
101
102 if let Some(data) = server_types_data {
103 state.remote_server_types = Some(CertificateTypeList::from_bytes(data)?);
104 }
105
106 Ok(())
107 }
108
109 pub fn complete_negotiation(
111 &self,
112 conn_id: &str,
113 ) -> Result<NegotiationResult, TlsExtensionError> {
114 let mut negotiations = self.negotiations.lock().unwrap();
115
116 let state = negotiations.get_mut(conn_id).ok_or_else(|| {
117 TlsExtensionError::InvalidExtensionData(format!(
118 "No negotiation state for connection {}",
119 conn_id
120 ))
121 })?;
122
123 if let Some(result) = &state.result {
124 return Ok(result.clone());
125 }
126
127 let result = state.local_preferences.negotiate(
128 state.remote_client_types.as_ref(),
129 state.remote_server_types.as_ref(),
130 )?;
131
132 state.result = Some(result.clone());
133 Ok(result)
134 }
135
136 pub fn cleanup_connection(&self, conn_id: &str) {
138 let mut negotiations = self.negotiations.lock().unwrap();
139 negotiations.remove(conn_id);
140 }
141}
142
143impl TlsExtensionHooks for SimulatedExtensionContext {
144 fn on_handshake_complete(&self, conn_id: &str, _is_client: bool) {
145 let _ = self.complete_negotiation(conn_id);
147 }
148
149 fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)> {
150 let (client_types, server_types) = self.simulate_send_preferences(conn_id);
151
152 let mut extensions = Vec::new();
153
154 if let Some(data) = client_types {
155 extensions.push((47, data)); }
157
158 if let Some(data) = server_types {
159 extensions.push((48, data)); }
161
162 extensions
163 }
164
165 fn process_server_hello_extensions(
166 &self,
167 conn_id: &str,
168 extensions: &[(u16, Vec<u8>)],
169 ) -> Result<(), TlsExtensionError> {
170 let mut client_types_data = None;
171 let mut server_types_data = None;
172
173 for (ext_id, data) in extensions {
174 match *ext_id {
175 47 => client_types_data = Some(data.as_slice()),
176 48 => server_types_data = Some(data.as_slice()),
177 _ => {}
178 }
179 }
180
181 self.simulate_receive_preferences(conn_id, client_types_data, server_types_data)
182 }
183
184 fn get_negotiation_result(&self, conn_id: &str) -> Option<NegotiationResult> {
185 self.complete_negotiation(conn_id).ok()
186 }
187}
188
189pub struct Rfc7250ClientConfig {
191 inner: Arc<ClientConfig>,
192 extension_context: Arc<SimulatedExtensionContext>,
193}
194
195impl Rfc7250ClientConfig {
196 pub fn new(base_config: ClientConfig, preferences: CertificateTypePreferences) -> Self {
198 Self {
199 inner: Arc::new(base_config),
200 extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
201 }
202 }
203
204 pub fn inner(&self) -> &Arc<ClientConfig> {
206 &self.inner
207 }
208
209 pub fn extension_context(&self) -> &Arc<SimulatedExtensionContext> {
211 &self.extension_context
212 }
213
214 pub fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)> {
216 let (client_types, server_types) =
217 self.extension_context.simulate_send_preferences(conn_id);
218
219 let mut extensions = Vec::new();
220
221 if let Some(data) = client_types {
222 extensions.push((47, data)); }
224
225 if let Some(data) = server_types {
226 extensions.push((48, data)); }
228
229 extensions
230 }
231}
232
233pub struct Rfc7250ServerConfig {
235 inner: Arc<ServerConfig>,
236 extension_context: Arc<SimulatedExtensionContext>,
237}
238
239impl Rfc7250ServerConfig {
240 pub fn new(base_config: ServerConfig, preferences: CertificateTypePreferences) -> Self {
242 Self {
243 inner: Arc::new(base_config),
244 extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
245 }
246 }
247
248 pub fn inner(&self) -> &Arc<ServerConfig> {
250 &self.inner
251 }
252
253 pub fn extension_context(&self) -> &Arc<SimulatedExtensionContext> {
255 &self.extension_context
256 }
257
258 pub fn process_client_hello_extensions(
260 &self,
261 conn_id: &str,
262 client_extensions: &[(u16, Vec<u8>)],
263 ) -> Result<Vec<(u16, Vec<u8>)>, TlsExtensionError> {
264 self.extension_context.simulate_send_preferences(conn_id);
266
267 let mut client_types_data = None;
269 let mut server_types_data = None;
270
271 for (ext_id, data) in client_extensions {
272 match *ext_id {
273 47 => client_types_data = Some(data.as_slice()),
274 48 => server_types_data = Some(data.as_slice()),
275 _ => {}
276 }
277 }
278
279 self.extension_context.simulate_receive_preferences(
281 conn_id,
282 client_types_data,
283 server_types_data,
284 )?;
285
286 let result = self.extension_context.complete_negotiation(conn_id)?;
288
289 let mut response_extensions = Vec::new();
291
292 response_extensions.push((47, vec![1, result.client_cert_type.to_u8()]));
294 response_extensions.push((48, vec![1, result.server_cert_type.to_u8()]));
295
296 Ok(response_extensions)
297 }
298}
299
300pub fn should_use_raw_public_key(negotiation_result: &NegotiationResult, is_client: bool) -> bool {
302 if is_client {
303 negotiation_result.client_cert_type.is_raw_public_key()
304 } else {
305 negotiation_result.server_cert_type.is_raw_public_key()
306 }
307}
308
/// Builds a connection id of the form `"{local_addr}-{remote_addr}"`.
///
/// Note: the separator is not escaped, so addresses that themselves contain
/// `-` (e.g. host names) can yield ambiguous ids.
pub fn create_connection_id(local_addr: &str, remote_addr: &str) -> String {
    [local_addr, remote_addr].join("-")
}
313
314pub struct ExtensionAwareTlsSession {
316 inner_session: Box<dyn crate::crypto::Session>,
318 extension_hooks: Arc<dyn TlsExtensionHooks>,
320 conn_id: String,
322 is_client: bool,
324 handshake_complete: bool,
326}
327
328impl ExtensionAwareTlsSession {
329 pub fn new(
331 inner_session: Box<dyn crate::crypto::Session>,
332 extension_hooks: Arc<dyn TlsExtensionHooks>,
333 conn_id: String,
334 is_client: bool,
335 ) -> Self {
336 Self {
337 inner_session,
338 extension_hooks,
339 conn_id,
340 is_client,
341 handshake_complete: false,
342 }
343 }
344
345 pub fn get_negotiation_result(&self) -> Option<NegotiationResult> {
347 self.extension_hooks.get_negotiation_result(&self.conn_id)
348 }
349}
350
351impl crate::crypto::Session for ExtensionAwareTlsSession {
353 fn initial_keys(
354 &self,
355 dst_cid: &crate::ConnectionId,
356 side: crate::Side,
357 ) -> crate::crypto::Keys {
358 self.inner_session.initial_keys(dst_cid, side)
359 }
360
361 fn handshake_data(&self) -> Option<Box<dyn std::any::Any>> {
362 self.inner_session.handshake_data()
363 }
364
365 fn peer_identity(&self) -> Option<Box<dyn std::any::Any>> {
366 self.inner_session.peer_identity()
367 }
368
369 fn early_crypto(
370 &self,
371 ) -> Option<(
372 Box<dyn crate::crypto::HeaderKey>,
373 Box<dyn crate::crypto::PacketKey>,
374 )> {
375 self.inner_session.early_crypto()
376 }
377
378 fn early_data_accepted(&self) -> Option<bool> {
379 self.inner_session.early_data_accepted()
380 }
381
382 fn is_handshaking(&self) -> bool {
383 self.inner_session.is_handshaking()
384 }
385
386 fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, crate::TransportError> {
387 let result = self.inner_session.read_handshake(buf)?;
388
389 if result && !self.handshake_complete && !self.is_handshaking() {
391 self.handshake_complete = true;
392 self.extension_hooks
393 .on_handshake_complete(&self.conn_id, self.is_client);
394 }
395
396 Ok(result)
397 }
398
399 fn transport_parameters(
400 &self,
401 ) -> Result<Option<crate::transport_parameters::TransportParameters>, crate::TransportError>
402 {
403 self.inner_session.transport_parameters()
404 }
405
406 fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<crate::crypto::Keys> {
407 self.inner_session.write_handshake(buf)
408 }
409
410 fn next_1rtt_keys(
411 &mut self,
412 ) -> Option<crate::crypto::KeyPair<Box<dyn crate::crypto::PacketKey>>> {
413 self.inner_session.next_1rtt_keys()
414 }
415
416 fn is_valid_retry(
417 &self,
418 orig_dst_cid: &crate::ConnectionId,
419 header: &[u8],
420 payload: &[u8],
421 ) -> bool {
422 self.inner_session
423 .is_valid_retry(orig_dst_cid, header, payload)
424 }
425
426 fn export_keying_material(
427 &self,
428 output: &mut [u8],
429 label: &[u8],
430 context: &[u8],
431 ) -> Result<(), crate::crypto::ExportKeyingMaterialError> {
432 self.inner_session
433 .export_keying_material(output, label, context)
434 }
435}
436
437pub struct Rfc7250QuicClientConfig {
439 base_config: Arc<dyn QuicClientConfig>,
441 extension_context: Arc<SimulatedExtensionContext>,
443}
444
445impl Rfc7250QuicClientConfig {
446 pub fn new(
448 base_config: Arc<dyn QuicClientConfig>,
449 preferences: CertificateTypePreferences,
450 ) -> Self {
451 Self {
452 base_config,
453 extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
454 }
455 }
456}
457
458impl QuicClientConfig for Rfc7250QuicClientConfig {
459 fn start_session(
460 self: Arc<Self>,
461 version: u32,
462 server_name: &str,
463 params: &crate::transport_parameters::TransportParameters,
464 ) -> Result<Box<dyn crate::crypto::Session>, crate::ConnectError> {
465 let inner_session = self
467 .base_config
468 .clone()
469 .start_session(version, server_name, params)?;
470
471 let conn_id = format!(
473 "client-{}-{}",
474 server_name,
475 std::time::SystemTime::now()
476 .duration_since(std::time::UNIX_EPOCH)
477 .unwrap()
478 .as_nanos()
479 );
480
481 Ok(Box::new(ExtensionAwareTlsSession::new(
483 inner_session,
484 self.extension_context.clone() as Arc<dyn TlsExtensionHooks>,
485 conn_id,
486 true, )))
488 }
489}
490
491pub struct Rfc7250QuicServerConfig {
493 base_config: Arc<dyn QuicServerConfig>,
495 extension_context: Arc<SimulatedExtensionContext>,
497}
498
499impl Rfc7250QuicServerConfig {
500 pub fn new(
502 base_config: Arc<dyn QuicServerConfig>,
503 preferences: CertificateTypePreferences,
504 ) -> Self {
505 Self {
506 base_config,
507 extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
508 }
509 }
510}
511
512impl QuicServerConfig for Rfc7250QuicServerConfig {
513 fn start_session(
514 self: Arc<Self>,
515 version: u32,
516 params: &crate::transport_parameters::TransportParameters,
517 ) -> Box<dyn crate::crypto::Session> {
518 let inner_session = self.base_config.clone().start_session(version, params);
520
521 let conn_id = format!(
523 "server-{}",
524 std::time::SystemTime::now()
525 .duration_since(std::time::UNIX_EPOCH)
526 .unwrap()
527 .as_nanos()
528 );
529
530 Box::new(ExtensionAwareTlsSession::new(
532 inner_session,
533 self.extension_context.clone() as Arc<dyn TlsExtensionHooks>,
534 conn_id,
535 false, ))
537 }
538
539 fn initial_keys(
540 &self,
541 version: u32,
542 dst_cid: &crate::ConnectionId,
543 ) -> Result<crate::crypto::Keys, crate::crypto::UnsupportedVersion> {
544 self.base_config.initial_keys(version, dst_cid)
545 }
546
547 fn retry_tag(
548 &self,
549 version: u32,
550 orig_dst_cid: &crate::ConnectionId,
551 packet: &[u8],
552 ) -> [u8; 16] {
553 self.base_config.retry_tag(version, orig_dst_cid, packet)
554 }
555}
556
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;

    /// Drives a full client -> server -> client exchange through two independent
    /// contexts and checks that both sides settle on the same raw-public-key result.
    #[test]
    fn test_simulated_negotiation_flow() {
        // Client prefers raw public keys; server accepts only raw public keys.
        let client = SimulatedExtensionContext::new(
            CertificateTypePreferences::prefer_raw_public_key(),
        );
        let server = SimulatedExtensionContext::new(
            CertificateTypePreferences::raw_public_key_only(),
        );

        let conn = "test-connection";

        // Client -> server: the client's advertised type lists.
        let (offered_client, offered_server) = client.simulate_send_preferences(conn);
        assert!(offered_client.is_some());
        assert!(offered_server.is_some());

        // Server registers its own state, ingests the offer, and decides.
        server.simulate_send_preferences(conn);
        server
            .simulate_receive_preferences(
                conn,
                offered_client.as_deref(),
                offered_server.as_deref(),
            )
            .unwrap();

        let server_outcome = server.complete_negotiation(conn).unwrap();
        assert!(server_outcome.is_raw_public_key_only());

        // Server -> client: one-entry lists selecting raw public keys for both roles.
        let rpk_reply = vec![1, CertificateType::RawPublicKey.to_u8()];
        client
            .simulate_receive_preferences(
                conn,
                Some(rpk_reply.as_slice()),
                Some(rpk_reply.as_slice()),
            )
            .unwrap();

        let client_outcome = client.complete_negotiation(conn).unwrap();
        assert_eq!(client_outcome, server_outcome);
    }

    /// The client wrapper emits the client-type then server-type extensions.
    #[test]
    fn test_wrapper_configs() {
        let rustls_config = ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(
                crate::crypto::raw_public_keys::RawPublicKeyVerifier::new(Vec::new()),
            ))
            .with_no_client_auth();

        let wrapper = Rfc7250ClientConfig::new(
            rustls_config,
            CertificateTypePreferences::prefer_raw_public_key(),
        );

        let ext_ids: Vec<u16> = wrapper
            .get_client_hello_extensions("test-conn")
            .iter()
            .map(|(ext_id, _)| *ext_id)
            .collect();

        // Exactly two extensions, in this order.
        assert_eq!(ext_ids, [47, 48]);
    }
}