1use std::sync::{Arc, Mutex};
8use std::collections::HashMap;
9use rustls::{ClientConfig, ServerConfig};
10use crate::crypto::{ClientConfig as QuicClientConfig, ServerConfig as QuicServerConfig};
11
12use super::tls_extensions::{
13 CertificateTypeList, CertificateTypePreferences,
14 NegotiationResult, TlsExtensionError,
15};
16
17pub trait TlsExtensionHooks: Send + Sync {
19 fn on_handshake_complete(&self, conn_id: &str, is_client: bool);
21
22 fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)>;
24
25 fn process_server_hello_extensions(&self, conn_id: &str, extensions: &[(u16, Vec<u8>)]) -> Result<(), TlsExtensionError>;
27
28 fn get_negotiation_result(&self, conn_id: &str) -> Option<NegotiationResult>;
30}
31
32#[derive(Debug)]
34pub struct SimulatedExtensionContext {
35 negotiations: Arc<Mutex<HashMap<String, NegotiationState>>>,
37 local_preferences: CertificateTypePreferences,
39}
40
41#[derive(Debug, Clone)]
42struct NegotiationState {
43 local_preferences: CertificateTypePreferences,
44 remote_client_types: Option<CertificateTypeList>,
45 remote_server_types: Option<CertificateTypeList>,
46 result: Option<NegotiationResult>,
47}
48
49impl SimulatedExtensionContext {
50 pub fn new(preferences: CertificateTypePreferences) -> Self {
52 Self {
53 negotiations: Arc::new(Mutex::new(HashMap::new())),
54 local_preferences: preferences,
55 }
56 }
57
58 pub fn simulate_send_preferences(&self, conn_id: &str) -> (Option<Vec<u8>>, Option<Vec<u8>>) {
61 let mut negotiations = self.negotiations.lock().unwrap();
62
63 let state = NegotiationState {
64 local_preferences: self.local_preferences.clone(),
65 remote_client_types: None,
66 remote_server_types: None,
67 result: None,
68 };
69
70 negotiations.insert(conn_id.to_string(), state);
71
72 let client_ext_data = self.local_preferences.client_types.to_bytes();
74 let server_ext_data = self.local_preferences.server_types.to_bytes();
75
76 (Some(client_ext_data), Some(server_ext_data))
77 }
78
79 pub fn simulate_receive_preferences(
81 &self,
82 conn_id: &str,
83 client_types_data: Option<&[u8]>,
84 server_types_data: Option<&[u8]>,
85 ) -> Result<(), TlsExtensionError> {
86 let mut negotiations = self.negotiations.lock().unwrap();
87
88 let state = negotiations.get_mut(conn_id)
89 .ok_or_else(|| TlsExtensionError::InvalidExtensionData(
90 format!("No negotiation state for connection {}", conn_id)
91 ))?;
92
93 if let Some(data) = client_types_data {
94 state.remote_client_types = Some(CertificateTypeList::from_bytes(data)?);
95 }
96
97 if let Some(data) = server_types_data {
98 state.remote_server_types = Some(CertificateTypeList::from_bytes(data)?);
99 }
100
101 Ok(())
102 }
103
104 pub fn complete_negotiation(&self, conn_id: &str) -> Result<NegotiationResult, TlsExtensionError> {
106 let mut negotiations = self.negotiations.lock().unwrap();
107
108 let state = negotiations.get_mut(conn_id)
109 .ok_or_else(|| TlsExtensionError::InvalidExtensionData(
110 format!("No negotiation state for connection {}", conn_id)
111 ))?;
112
113 if let Some(result) = &state.result {
114 return Ok(result.clone());
115 }
116
117 let result = state.local_preferences.negotiate(
118 state.remote_client_types.as_ref(),
119 state.remote_server_types.as_ref(),
120 )?;
121
122 state.result = Some(result.clone());
123 Ok(result)
124 }
125
126 pub fn cleanup_connection(&self, conn_id: &str) {
128 let mut negotiations = self.negotiations.lock().unwrap();
129 negotiations.remove(conn_id);
130 }
131}
132
133impl TlsExtensionHooks for SimulatedExtensionContext {
134 fn on_handshake_complete(&self, conn_id: &str, _is_client: bool) {
135 let _ = self.complete_negotiation(conn_id);
137 }
138
139 fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)> {
140 let (client_types, server_types) = self.simulate_send_preferences(conn_id);
141
142 let mut extensions = Vec::new();
143
144 if let Some(data) = client_types {
145 extensions.push((47, data)); }
147
148 if let Some(data) = server_types {
149 extensions.push((48, data)); }
151
152 extensions
153 }
154
155 fn process_server_hello_extensions(&self, conn_id: &str, extensions: &[(u16, Vec<u8>)]) -> Result<(), TlsExtensionError> {
156 let mut client_types_data = None;
157 let mut server_types_data = None;
158
159 for (ext_id, data) in extensions {
160 match *ext_id {
161 47 => client_types_data = Some(data.as_slice()),
162 48 => server_types_data = Some(data.as_slice()),
163 _ => {}
164 }
165 }
166
167 self.simulate_receive_preferences(conn_id, client_types_data, server_types_data)
168 }
169
170 fn get_negotiation_result(&self, conn_id: &str) -> Option<NegotiationResult> {
171 self.complete_negotiation(conn_id).ok()
172 }
173}
174
175pub struct Rfc7250ClientConfig {
177 inner: Arc<ClientConfig>,
178 extension_context: Arc<SimulatedExtensionContext>,
179}
180
181impl Rfc7250ClientConfig {
182 pub fn new(
184 base_config: ClientConfig,
185 preferences: CertificateTypePreferences,
186 ) -> Self {
187 Self {
188 inner: Arc::new(base_config),
189 extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
190 }
191 }
192
193 pub fn inner(&self) -> &Arc<ClientConfig> {
195 &self.inner
196 }
197
198 pub fn extension_context(&self) -> &Arc<SimulatedExtensionContext> {
200 &self.extension_context
201 }
202
203 pub fn get_client_hello_extensions(&self, conn_id: &str) -> Vec<(u16, Vec<u8>)> {
205 let (client_types, server_types) = self.extension_context.simulate_send_preferences(conn_id);
206
207 let mut extensions = Vec::new();
208
209 if let Some(data) = client_types {
210 extensions.push((47, data)); }
212
213 if let Some(data) = server_types {
214 extensions.push((48, data)); }
216
217 extensions
218 }
219}
220
221pub struct Rfc7250ServerConfig {
223 inner: Arc<ServerConfig>,
224 extension_context: Arc<SimulatedExtensionContext>,
225}
226
227impl Rfc7250ServerConfig {
228 pub fn new(
230 base_config: ServerConfig,
231 preferences: CertificateTypePreferences,
232 ) -> Self {
233 Self {
234 inner: Arc::new(base_config),
235 extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
236 }
237 }
238
239 pub fn inner(&self) -> &Arc<ServerConfig> {
241 &self.inner
242 }
243
244 pub fn extension_context(&self) -> &Arc<SimulatedExtensionContext> {
246 &self.extension_context
247 }
248
249 pub fn process_client_hello_extensions(
251 &self,
252 conn_id: &str,
253 client_extensions: &[(u16, Vec<u8>)],
254 ) -> Result<Vec<(u16, Vec<u8>)>, TlsExtensionError> {
255 self.extension_context.simulate_send_preferences(conn_id);
257
258 let mut client_types_data = None;
260 let mut server_types_data = None;
261
262 for (ext_id, data) in client_extensions {
263 match *ext_id {
264 47 => client_types_data = Some(data.as_slice()),
265 48 => server_types_data = Some(data.as_slice()),
266 _ => {}
267 }
268 }
269
270 self.extension_context.simulate_receive_preferences(
272 conn_id,
273 client_types_data,
274 server_types_data,
275 )?;
276
277 let result = self.extension_context.complete_negotiation(conn_id)?;
279
280 let mut response_extensions = Vec::new();
282
283 response_extensions.push((47, vec![1, result.client_cert_type.to_u8()]));
285 response_extensions.push((48, vec![1, result.server_cert_type.to_u8()]));
286
287 Ok(response_extensions)
288 }
289}
290
291pub fn should_use_raw_public_key(
293 negotiation_result: &NegotiationResult,
294 is_client: bool,
295) -> bool {
296 if is_client {
297 negotiation_result.client_cert_type.is_raw_public_key()
298 } else {
299 negotiation_result.server_cert_type.is_raw_public_key()
300 }
301}
302
/// Builds a connection id of the form `<local_addr>-<remote_addr>`.
///
/// The inputs are joined verbatim; no validation or escaping is performed.
pub fn create_connection_id(local_addr: &str, remote_addr: &str) -> String {
    [local_addr, remote_addr].join("-")
}
307
308pub struct ExtensionAwareTlsSession {
310 inner_session: Box<dyn crate::crypto::Session>,
312 extension_hooks: Arc<dyn TlsExtensionHooks>,
314 conn_id: String,
316 is_client: bool,
318 handshake_complete: bool,
320}
321
322impl ExtensionAwareTlsSession {
323 pub fn new(
325 inner_session: Box<dyn crate::crypto::Session>,
326 extension_hooks: Arc<dyn TlsExtensionHooks>,
327 conn_id: String,
328 is_client: bool,
329 ) -> Self {
330 Self {
331 inner_session,
332 extension_hooks,
333 conn_id,
334 is_client,
335 handshake_complete: false,
336 }
337 }
338
339 pub fn get_negotiation_result(&self) -> Option<NegotiationResult> {
341 self.extension_hooks.get_negotiation_result(&self.conn_id)
342 }
343}
344
345impl crate::crypto::Session for ExtensionAwareTlsSession {
347 fn initial_keys(&self, dst_cid: &crate::ConnectionId, side: crate::Side) -> crate::crypto::Keys {
348 self.inner_session.initial_keys(dst_cid, side)
349 }
350
351 fn handshake_data(&self) -> Option<Box<dyn std::any::Any>> {
352 self.inner_session.handshake_data()
353 }
354
355 fn peer_identity(&self) -> Option<Box<dyn std::any::Any>> {
356 self.inner_session.peer_identity()
357 }
358
359 fn early_crypto(&self) -> Option<(Box<dyn crate::crypto::HeaderKey>, Box<dyn crate::crypto::PacketKey>)> {
360 self.inner_session.early_crypto()
361 }
362
363 fn early_data_accepted(&self) -> Option<bool> {
364 self.inner_session.early_data_accepted()
365 }
366
367 fn is_handshaking(&self) -> bool {
368 self.inner_session.is_handshaking()
369 }
370
371 fn read_handshake(&mut self, buf: &[u8]) -> Result<bool, crate::TransportError> {
372 let result = self.inner_session.read_handshake(buf)?;
373
374 if result && !self.handshake_complete && !self.is_handshaking() {
376 self.handshake_complete = true;
377 self.extension_hooks.on_handshake_complete(&self.conn_id, self.is_client);
378 }
379
380 Ok(result)
381 }
382
383 fn transport_parameters(&self) -> Result<Option<crate::transport_parameters::TransportParameters>, crate::TransportError> {
384 self.inner_session.transport_parameters()
385 }
386
387 fn write_handshake(&mut self, buf: &mut Vec<u8>) -> Option<crate::crypto::Keys> {
388 self.inner_session.write_handshake(buf)
389 }
390
391 fn next_1rtt_keys(&mut self) -> Option<crate::crypto::KeyPair<Box<dyn crate::crypto::PacketKey>>> {
392 self.inner_session.next_1rtt_keys()
393 }
394
395 fn is_valid_retry(&self, orig_dst_cid: &crate::ConnectionId, header: &[u8], payload: &[u8]) -> bool {
396 self.inner_session.is_valid_retry(orig_dst_cid, header, payload)
397 }
398
399 fn export_keying_material(
400 &self,
401 output: &mut [u8],
402 label: &[u8],
403 context: &[u8],
404 ) -> Result<(), crate::crypto::ExportKeyingMaterialError> {
405 self.inner_session.export_keying_material(output, label, context)
406 }
407}
408
409pub struct Rfc7250QuicClientConfig {
411 base_config: Arc<dyn QuicClientConfig>,
413 extension_context: Arc<SimulatedExtensionContext>,
415}
416
417impl Rfc7250QuicClientConfig {
418 pub fn new(
420 base_config: Arc<dyn QuicClientConfig>,
421 preferences: CertificateTypePreferences,
422 ) -> Self {
423 Self {
424 base_config,
425 extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
426 }
427 }
428}
429
430impl QuicClientConfig for Rfc7250QuicClientConfig {
431 fn start_session(
432 self: Arc<Self>,
433 version: u32,
434 server_name: &str,
435 params: &crate::transport_parameters::TransportParameters,
436 ) -> Result<Box<dyn crate::crypto::Session>, crate::ConnectError> {
437 let inner_session = self.base_config.clone().start_session(version, server_name, params)?;
439
440 let conn_id = format!("client-{}-{}", server_name, std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos());
442
443 Ok(Box::new(ExtensionAwareTlsSession::new(
445 inner_session,
446 self.extension_context.clone() as Arc<dyn TlsExtensionHooks>,
447 conn_id,
448 true, )))
450 }
451}
452
453pub struct Rfc7250QuicServerConfig {
455 base_config: Arc<dyn QuicServerConfig>,
457 extension_context: Arc<SimulatedExtensionContext>,
459}
460
461impl Rfc7250QuicServerConfig {
462 pub fn new(
464 base_config: Arc<dyn QuicServerConfig>,
465 preferences: CertificateTypePreferences,
466 ) -> Self {
467 Self {
468 base_config,
469 extension_context: Arc::new(SimulatedExtensionContext::new(preferences)),
470 }
471 }
472}
473
474impl QuicServerConfig for Rfc7250QuicServerConfig {
475 fn start_session(
476 self: Arc<Self>,
477 version: u32,
478 params: &crate::transport_parameters::TransportParameters,
479 ) -> Box<dyn crate::crypto::Session> {
480 let inner_session = self.base_config.clone().start_session(version, params);
482
483 let conn_id = format!("server-{}", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos());
485
486 Box::new(ExtensionAwareTlsSession::new(
488 inner_session,
489 self.extension_context.clone() as Arc<dyn TlsExtensionHooks>,
490 conn_id,
491 false, ))
493 }
494
495 fn initial_keys(
496 &self,
497 version: u32,
498 dst_cid: &crate::ConnectionId,
499 ) -> Result<crate::crypto::Keys, crate::crypto::UnsupportedVersion> {
500 self.base_config.initial_keys(version, dst_cid)
501 }
502
503 fn retry_tag(&self, version: u32, orig_dst_cid: &crate::ConnectionId, packet: &[u8]) -> [u8; 16] {
504 self.base_config.retry_tag(version, orig_dst_cid, packet)
505 }
506}
507
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;

    /// Walks a full client/server exchange through two independent contexts
    /// and checks that both sides arrive at the same result.
    #[test]
    fn test_simulated_negotiation_flow() {
        let conn_id = "test-connection";

        let client_ctx =
            SimulatedExtensionContext::new(CertificateTypePreferences::prefer_raw_public_key());
        let server_ctx =
            SimulatedExtensionContext::new(CertificateTypePreferences::raw_public_key_only());

        // The client registers the connection and offers its preferences.
        let (offered_client_types, offered_server_types) =
            client_ctx.simulate_send_preferences(conn_id);
        assert!(offered_client_types.is_some());
        assert!(offered_server_types.is_some());

        // The server registers the connection, ingests the offer and negotiates.
        server_ctx.simulate_send_preferences(conn_id);
        server_ctx
            .simulate_receive_preferences(
                conn_id,
                offered_client_types.as_deref(),
                offered_server_types.as_deref(),
            )
            .unwrap();

        let server_result = server_ctx.complete_negotiation(conn_id).unwrap();
        assert!(server_result.is_raw_public_key_only());

        // The server answers with `[1, RawPublicKey]` for both extensions.
        let server_response_client = vec![1, CertificateType::RawPublicKey.to_u8()];
        let server_response_server = vec![1, CertificateType::RawPublicKey.to_u8()];

        client_ctx
            .simulate_receive_preferences(
                conn_id,
                Some(server_response_client.as_slice()),
                Some(server_response_server.as_slice()),
            )
            .unwrap();

        let client_result = client_ctx.complete_negotiation(conn_id).unwrap();
        assert_eq!(client_result, server_result);
    }

    /// Checks that the client wrapper emits extensions 47 and 48, in that order.
    #[test]
    fn test_wrapper_configs() {
        let verifier = Arc::new(crate::crypto::raw_public_keys::RawPublicKeyVerifier::new(Vec::new()));
        let client_config = ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(verifier)
            .with_no_client_auth();

        let wrapped_client = Rfc7250ClientConfig::new(
            client_config,
            CertificateTypePreferences::prefer_raw_public_key(),
        );

        let extensions = wrapped_client.get_client_hello_extensions("test-conn");
        let extension_ids: Vec<u16> = extensions.iter().map(|entry| entry.0).collect();

        assert_eq!(extension_ids, vec![47, 48]);
    }
}