1#![allow(missing_docs)]
8
9use std::{
16 collections::HashMap,
17 hash::{Hash, Hasher},
18 sync::{Arc, Mutex, RwLock},
19 time::{Duration, Instant},
20};
21
22use tracing::{Level, debug, info, span, warn};
23
24use super::tls_extensions::{
25 CertificateTypeList, CertificateTypePreferences, NegotiationResult, TlsExtensionError,
26};
27
28#[derive(Debug, Clone, PartialEq, Eq)]
30pub enum NegotiationState {
31 Pending,
33 Waiting {
35 sent_at: Instant,
36 our_preferences: CertificateTypePreferences,
37 },
38 Completed {
40 result: NegotiationResult,
41 completed_at: Instant,
42 },
43 Failed {
45 error: String,
47 failed_at: Instant,
49 },
50 TimedOut {
52 timeout_at: Instant,
54 },
55}
56
57impl NegotiationState {
58 pub fn is_complete(&self) -> bool {
60 matches!(
61 self,
62 Self::Completed { .. } | Self::Failed { .. } | Self::TimedOut { .. }
63 )
64 }
65
66 pub fn is_successful(&self) -> bool {
68 matches!(self, Self::Completed { .. })
69 }
70
71 pub fn get_result(&self) -> Option<&NegotiationResult> {
73 match self {
74 Self::Completed { result, .. } => Some(result),
75 _ => None,
76 }
77 }
78
79 pub fn get_error(&self) -> Option<&str> {
81 match self {
82 Self::Failed { error, .. } => Some(error),
83 _ => None,
84 }
85 }
86}
87
88#[derive(Debug, Clone)]
90pub struct NegotiationConfig {
91 pub timeout: Duration,
93 pub enable_caching: bool,
95 pub max_cache_size: usize,
97 pub allow_fallback: bool,
99 pub default_preferences: CertificateTypePreferences,
101}
102
103impl Default for NegotiationConfig {
104 fn default() -> Self {
105 Self {
106 timeout: Duration::from_secs(10),
107 enable_caching: true,
108 max_cache_size: 1000,
109 allow_fallback: true,
110 default_preferences: CertificateTypePreferences::prefer_raw_public_key(),
111 }
112 }
113}
114
/// Opaque identifier for one negotiation session.
///
/// Values are handed out by a process-wide counter, so two identifiers
/// obtained from [`NegotiationId::new`] are never equal.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct NegotiationId(u64);

impl NegotiationId {
    /// Allocates the next identifier. The counter starts at 1.
    pub fn new() -> Self {
        // Function-local static: shared by every caller in the process.
        static NEXT_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(1);

        let raw = NEXT_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        Self(raw)
    }

    /// Returns the underlying numeric value.
    pub fn as_u64(self) -> u64 {
        let Self(raw) = self;
        raw
    }
}

impl Default for NegotiationId {
    /// Equivalent to [`NegotiationId::new`]: every default value is a fresh identifier.
    fn default() -> Self {
        Self::new()
    }
}
138
139#[derive(Debug, Clone, PartialEq, Eq, Hash)]
141struct CacheKey {
142 local_preferences: String, remote_preferences: String, }
147
148impl CacheKey {
149 fn new(
151 local: &CertificateTypePreferences,
152 remote_client: Option<&CertificateTypeList>,
153 remote_server: Option<&CertificateTypeList>,
154 ) -> Self {
155 use std::collections::hash_map::DefaultHasher;
156
157 let mut hasher = DefaultHasher::new();
158 local.hash(&mut hasher);
159 let local_hash = hasher.finish();
160
161 let mut hasher = DefaultHasher::new();
162 if let Some(types) = remote_client {
163 types.hash(&mut hasher);
164 }
165 if let Some(types) = remote_server {
166 types.hash(&mut hasher);
167 }
168 let remote_hash = hasher.finish();
169
170 Self {
171 local_preferences: format!("{local_hash:x}"),
172 remote_preferences: format!("{remote_hash:x}"),
173 }
174 }
175}
176
177impl Hash for CertificateTypePreferences {
179 fn hash<H: Hasher>(&self, state: &mut H) {
180 self.client_types.types.hash(state);
181 self.server_types.types.hash(state);
182 self.require_extensions.hash(state);
183 self.fallback_client.hash(state);
184 self.fallback_server.hash(state);
185 }
186}
187
188impl Hash for CertificateTypeList {
190 fn hash<H: Hasher>(&self, state: &mut H) {
191 self.types.hash(state);
192 }
193}
194
195pub struct CertificateNegotiationManager {
197 config: NegotiationConfig,
199 sessions: RwLock<HashMap<NegotiationId, NegotiationState>>,
201 cache: Arc<Mutex<HashMap<CacheKey, (NegotiationResult, Instant)>>>,
203 stats: Arc<Mutex<NegotiationStats>>,
205}
206
/// Counters describing the work performed by a [`CertificateNegotiationManager`].
#[derive(Clone, Debug, Default)]
pub struct NegotiationStats {
    /// Number of negotiations that have been started.
    pub total_attempts: u64,
    /// Number of negotiations whose result was computed successfully.
    pub successful: u64,
    /// Number of negotiations recorded as failed.
    pub failed: u64,
    /// Number of negotiations marked as timed out.
    pub timed_out: u64,
    /// Number of completions answered from the result cache.
    pub cache_hits: u64,
    /// Number of cache lookups that found no usable entry.
    pub cache_misses: u64,
    /// Running average of the time spent computing successful negotiations.
    pub avg_negotiation_time: Duration,
}
225
226impl CertificateNegotiationManager {
227 pub fn new(config: NegotiationConfig) -> Self {
229 Self {
230 config,
231 sessions: RwLock::new(HashMap::new()),
232 cache: Arc::new(Mutex::new(HashMap::new())),
233 stats: Arc::new(Mutex::new(NegotiationStats::default())),
234 }
235 }
236
237 pub fn start_negotiation(
239 &self,
240 preferences: CertificateTypePreferences,
241 ) -> Result<NegotiationId, TlsExtensionError> {
242 let id = NegotiationId::new();
243 let state = NegotiationState::Waiting {
244 sent_at: Instant::now(),
245 our_preferences: preferences,
246 };
247
248 let mut sessions = self.sessions.write().map_err(|e| {
249 TlsExtensionError::InvalidExtensionData(format!("Session lock poisoned: {}", e))
250 })?;
251 sessions.insert(id, state);
252
253 let mut stats = self.stats.lock().map_err(|e| {
254 TlsExtensionError::InvalidExtensionData(format!("Stats lock poisoned: {}", e))
255 })?;
256 stats.total_attempts += 1;
257
258 debug!("Started certificate type negotiation: {:?}", id);
259 Ok(id)
260 }
261
262 pub fn complete_negotiation(
264 &self,
265 id: NegotiationId,
266 remote_client_types: Option<CertificateTypeList>,
267 remote_server_types: Option<CertificateTypeList>,
268 ) -> Result<NegotiationResult, TlsExtensionError> {
269 let _span = span!(Level::DEBUG, "complete_negotiation", id = id.as_u64()).entered();
270
271 let mut sessions = self.sessions.write().map_err(|e| {
272 TlsExtensionError::InvalidExtensionData(format!("Session lock poisoned: {}", e))
273 })?;
274 let state = sessions.get(&id).ok_or_else(|| {
275 TlsExtensionError::InvalidExtensionData(format!("Unknown negotiation ID: {id:?}"))
276 })?;
277
278 let our_preferences = match state {
279 NegotiationState::Waiting {
280 our_preferences, ..
281 } => our_preferences.clone(),
282 _ => {
283 return Err(TlsExtensionError::InvalidExtensionData(
284 "Negotiation not in waiting state".to_string(),
285 ));
286 }
287 };
288
289 if self.config.enable_caching {
291 let cache_key = CacheKey::new(
292 &our_preferences,
293 remote_client_types.as_ref(),
294 remote_server_types.as_ref(),
295 );
296
297 let mut cache = self.cache.lock().map_err(|e| {
298 TlsExtensionError::InvalidExtensionData(format!("Cache lock poisoned: {}", e))
299 })?;
300 if let Some((cached_result, cached_at)) = cache.get(&cache_key) {
301 if cached_at.elapsed() < Duration::from_secs(300) {
303 let mut stats = self.stats.lock().map_err(|e| {
305 TlsExtensionError::InvalidExtensionData(format!(
306 "Stats lock poisoned: {}",
307 e
308 ))
309 })?;
310 stats.cache_hits += 1;
311
312 sessions.insert(
314 id,
315 NegotiationState::Completed {
316 result: cached_result.clone(),
317 completed_at: Instant::now(),
318 },
319 );
320
321 debug!("Cache hit for negotiation: {:?}", id);
322 return Ok(cached_result.clone());
323 } else {
324 cache.remove(&cache_key);
326 }
327 }
328
329 let mut stats = self.stats.lock().map_err(|e| {
330 TlsExtensionError::InvalidExtensionData(format!("Stats lock poisoned: {}", e))
331 })?;
332 stats.cache_misses += 1;
333 }
334
335 let negotiation_start = Instant::now();
337 let result =
338 our_preferences.negotiate(remote_client_types.as_ref(), remote_server_types.as_ref());
339
340 match result {
341 Ok(negotiation_result) => {
342 let completed_at = Instant::now();
343 let negotiation_time = negotiation_start.elapsed();
344
345 sessions.insert(
347 id,
348 NegotiationState::Completed {
349 result: negotiation_result.clone(),
350 completed_at,
351 },
352 );
353
354 let mut stats = self.stats.lock().map_err(|e| {
356 TlsExtensionError::InvalidExtensionData(format!("Stats lock poisoned: {}", e))
357 })?;
358 stats.successful += 1;
359
360 let total_completed = stats.successful + stats.failed;
362 stats.avg_negotiation_time = if total_completed == 1 {
363 negotiation_time
364 } else {
365 Duration::from_nanos(
366 (stats.avg_negotiation_time.as_nanos() as u64 * (total_completed - 1)
367 + negotiation_time.as_nanos() as u64)
368 / total_completed,
369 )
370 };
371
372 if self.config.enable_caching {
374 let cache_key = CacheKey::new(
375 &our_preferences,
376 remote_client_types.as_ref(),
377 remote_server_types.as_ref(),
378 );
379
380 let mut cache = self.cache.lock().map_err(|e| {
381 TlsExtensionError::InvalidExtensionData(format!(
382 "Cache lock poisoned: {}",
383 e
384 ))
385 })?;
386
387 if cache.len() >= self.config.max_cache_size {
389 let mut entries: Vec<_> =
391 cache.iter().map(|(k, (_, t))| (k.clone(), *t)).collect();
392 entries.sort_by_key(|(_, timestamp)| *timestamp);
393
394 let to_remove = cache.len() - self.config.max_cache_size + 1;
395 let keys_to_remove: Vec<_> = entries
396 .iter()
397 .take(to_remove)
398 .map(|(key, _)| key.clone())
399 .collect();
400
401 for key in keys_to_remove {
402 cache.remove(&key);
403 }
404 }
405
406 cache.insert(cache_key, (negotiation_result.clone(), completed_at));
407 }
408
409 info!(
410 "Certificate type negotiation completed successfully: {:?} -> client={}, server={}",
411 id, negotiation_result.client_cert_type, negotiation_result.server_cert_type
412 );
413
414 Ok(negotiation_result)
415 }
416 Err(error) => {
417 sessions.insert(
419 id,
420 NegotiationState::Failed {
421 error: error.to_string(),
422 failed_at: Instant::now(),
423 },
424 );
425
426 let mut stats = self.stats.lock().map_err(|e| {
428 TlsExtensionError::InvalidExtensionData(format!("Stats lock poisoned: {}", e))
429 })?;
430 stats.failed += 1;
431
432 warn!("Certificate type negotiation failed: {:?} -> {}", id, error);
433 Err(error)
434 }
435 }
436 }
437
438 #[allow(clippy::unwrap_used, clippy::expect_used)]
440 pub fn fail_negotiation(&self, id: NegotiationId, error: String) {
441 let mut sessions = self
442 .sessions
443 .write()
444 .expect("Mutex poisoning is unexpected in normal operation");
445 sessions.insert(
446 id,
447 NegotiationState::Failed {
448 error,
449 failed_at: Instant::now(),
450 },
451 );
452
453 let mut stats = self
454 .stats
455 .lock()
456 .expect("Mutex poisoning is unexpected in normal operation");
457 stats.failed += 1;
458
459 warn!("Certificate type negotiation failed: {:?}", id);
460 }
461
462 #[allow(clippy::unwrap_used, clippy::expect_used)]
464 pub fn get_negotiation_state(&self, id: NegotiationId) -> Option<NegotiationState> {
465 let sessions = self
466 .sessions
467 .read()
468 .expect("Mutex poisoning is unexpected in normal operation");
469 sessions.get(&id).cloned()
470 }
471
472 #[allow(clippy::unwrap_used, clippy::expect_used)]
474 pub fn handle_timeouts(&self) {
475 let mut sessions = self
476 .sessions
477 .write()
478 .expect("Mutex poisoning is unexpected in normal operation");
479 let mut timed_out_ids = Vec::new();
480
481 for (id, state) in sessions.iter() {
482 if let NegotiationState::Waiting { sent_at, .. } = state {
483 if sent_at.elapsed() > self.config.timeout {
484 timed_out_ids.push(*id);
485 }
486 }
487 }
488
489 for id in timed_out_ids {
490 sessions.insert(
491 id,
492 NegotiationState::TimedOut {
493 timeout_at: Instant::now(),
494 },
495 );
496
497 let mut stats = self
498 .stats
499 .lock()
500 .expect("Mutex poisoning is unexpected in normal operation");
501 stats.timed_out += 1;
502
503 warn!("Certificate type negotiation timed out: {:?}", id);
504 }
505 }
506
507 #[allow(clippy::unwrap_used, clippy::expect_used)]
509 pub fn cleanup_old_sessions(&self, max_age: Duration) {
510 let mut sessions = self
511 .sessions
512 .write()
513 .expect("Mutex poisoning is unexpected in normal operation");
514 let cutoff = Instant::now() - max_age;
515
516 sessions.retain(|id, state| {
517 let should_retain = match state {
518 NegotiationState::Completed { completed_at, .. } => *completed_at > cutoff,
519 NegotiationState::Failed { failed_at, .. } => *failed_at > cutoff,
520 NegotiationState::TimedOut { timeout_at, .. } => *timeout_at > cutoff,
521 _ => true, };
523
524 if !should_retain {
525 debug!("Cleaned up old negotiation session: {:?}", id);
526 }
527
528 should_retain
529 });
530 }
531
532 #[allow(clippy::unwrap_used, clippy::expect_used)]
534 pub fn get_stats(&self) -> NegotiationStats {
535 self.stats
536 .lock()
537 .expect("Mutex poisoning is unexpected in normal operation")
538 .clone()
539 }
540
541 #[allow(clippy::unwrap_used, clippy::expect_used)]
543 pub fn clear_cache(&self) {
544 let mut cache = self
545 .cache
546 .lock()
547 .expect("Mutex poisoning is unexpected in normal operation");
548 cache.clear();
549 debug!("Cleared certificate type negotiation cache");
550 }
551
552 #[allow(clippy::unwrap_used, clippy::expect_used)]
554 pub fn get_cache_stats(&self) -> (usize, usize) {
555 let cache = self
556 .cache
557 .lock()
558 .expect("Mutex poisoning is unexpected in normal operation");
559 (cache.len(), self.config.max_cache_size)
560 }
561}
562
563impl Default for CertificateNegotiationManager {
564 fn default() -> Self {
565 Self::new(NegotiationConfig::default())
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;

    /// Identifiers handed out back to back are distinct and non-zero.
    #[test]
    fn test_negotiation_id_generation() {
        let first = NegotiationId::new();
        let second = NegotiationId::new();

        assert_ne!(first, second);
        for id in [first, second] {
            assert!(id.as_u64() > 0);
        }
    }

    /// The state predicates and accessors agree with the variant they are called on.
    #[test]
    fn test_negotiation_state_checks() {
        let pending = NegotiationState::Pending;
        assert!(!pending.is_complete());
        assert!(!pending.is_successful());

        let negotiated =
            NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509);
        let completed = NegotiationState::Completed {
            result: negotiated,
            completed_at: Instant::now(),
        };
        assert!(completed.is_complete());
        assert!(completed.is_successful());
        assert!(completed.get_result().is_some());

        let failed = NegotiationState::Failed {
            error: String::from("Test error"),
            failed_at: Instant::now(),
        };
        assert!(failed.is_complete());
        assert!(!failed.is_successful());
        assert_eq!(failed.get_error(), Some("Test error"));
    }

    /// Start, inspect, complete: the happy path through the manager.
    #[test]
    fn test_negotiation_manager_basic_flow() {
        let manager = CertificateNegotiationManager::default();

        // Starting a negotiation leaves it in the waiting state.
        let id = manager
            .start_negotiation(CertificateTypePreferences::prefer_raw_public_key())
            .unwrap();
        let before = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(before, NegotiationState::Waiting { .. }));

        // The peer offers raw public keys for both roles.
        let remote_types = CertificateTypeList::raw_public_key_only();
        let result = manager
            .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        assert_eq!(result.client_cert_type, CertificateType::RawPublicKey);
        assert_eq!(result.server_cert_type, CertificateType::RawPublicKey);

        let after = manager.get_negotiation_state(id).unwrap();
        assert!(after.is_successful());
    }

    /// A second negotiation with identical inputs is served from the cache.
    #[test]
    fn test_negotiation_caching() {
        let manager = CertificateNegotiationManager::new(NegotiationConfig {
            enable_caching: true,
            ..Default::default()
        });
        let preferences = CertificateTypePreferences::prefer_raw_public_key();
        let remote_types = CertificateTypeList::raw_public_key_only();

        // First round: computed, recorded as a cache miss.
        let id1 = manager.start_negotiation(preferences.clone()).unwrap();
        let result1 = manager
            .complete_negotiation(id1, Some(remote_types.clone()), Some(remote_types.clone()))
            .unwrap();

        // Second round: same inputs, answered from the cache.
        let id2 = manager.start_negotiation(preferences).unwrap();
        let result2 = manager
            .complete_negotiation(id2, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        assert_eq!(result1, result2);

        let stats = manager.get_stats();
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
    }

    /// A waiting session older than the timeout is moved to the timed-out state.
    #[test]
    fn test_negotiation_timeout_handling() {
        let manager = CertificateNegotiationManager::new(NegotiationConfig {
            timeout: Duration::from_millis(1),
            ..Default::default()
        });

        let id = manager
            .start_negotiation(CertificateTypePreferences::prefer_raw_public_key())
            .unwrap();

        // Sleep well past the 1 ms timeout before sweeping.
        std::thread::sleep(Duration::from_millis(10));
        manager.handle_timeouts();

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(state, NegotiationState::TimedOut { .. }));

        let stats = manager.get_stats();
        assert_eq!(stats.timed_out, 1);
    }
}