1use std::{
8 collections::HashMap,
9 hash::{Hash, Hasher},
10 sync::{Arc, Mutex, RwLock},
11 time::{Duration, Instant},
12};
13
14use tracing::{Level, debug, info, span, warn};
15
16use super::tls_extensions::{
17 CertificateTypeList, CertificateTypePreferences, NegotiationResult, TlsExtensionError,
18};
19
20#[derive(Debug, Clone, PartialEq, Eq)]
22pub enum NegotiationState {
23 Pending,
25 Waiting {
27 sent_at: Instant,
28 our_preferences: CertificateTypePreferences,
29 },
30 Completed {
32 result: NegotiationResult,
33 completed_at: Instant,
34 },
35 Failed {
37 error: String,
39 failed_at: Instant,
41 },
42 TimedOut {
44 timeout_at: Instant,
46 },
47}
48
49impl NegotiationState {
50 pub fn is_complete(&self) -> bool {
52 matches!(
53 self,
54 Self::Completed { .. } | Self::Failed { .. } | Self::TimedOut { .. }
55 )
56 }
57
58 pub fn is_successful(&self) -> bool {
60 matches!(self, Self::Completed { .. })
61 }
62
63 pub fn get_result(&self) -> Option<&NegotiationResult> {
65 match self {
66 Self::Completed { result, .. } => Some(result),
67 _ => None,
68 }
69 }
70
71 pub fn get_error(&self) -> Option<&str> {
73 match self {
74 Self::Failed { error, .. } => Some(error),
75 _ => None,
76 }
77 }
78}
79
80#[derive(Debug, Clone)]
82pub struct NegotiationConfig {
83 pub timeout: Duration,
85 pub enable_caching: bool,
87 pub max_cache_size: usize,
89 pub allow_fallback: bool,
91 pub default_preferences: CertificateTypePreferences,
93}
94
95impl Default for NegotiationConfig {
96 fn default() -> Self {
97 Self {
98 timeout: Duration::from_secs(10),
99 enable_caching: true,
100 max_cache_size: 1000,
101 allow_fallback: true,
102 default_preferences: CertificateTypePreferences::prefer_raw_public_key(),
103 }
104 }
105}
106
/// Process-unique identifier of a single certificate-type negotiation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NegotiationId(u64);

impl Default for NegotiationId {
    /// Same as [`NegotiationId::new`]: allocates a fresh identifier.
    fn default() -> Self {
        NegotiationId::new()
    }
}

impl NegotiationId {
    /// Allocates the next identifier from a process-wide counter that starts at 1.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};

        // Relaxed is enough: uniqueness only needs the increment itself to be atomic;
        // no other memory is published through this counter.
        static NEXT_ID: AtomicU64 = AtomicU64::new(1);

        let raw = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        NegotiationId(raw)
    }

    /// Returns the raw numeric value (used, for example, as a tracing span field).
    pub fn as_u64(self) -> u64 {
        let NegotiationId(raw) = self;
        raw
    }
}
130
131#[derive(Debug, Clone, PartialEq, Eq, Hash)]
133struct CacheKey {
134 local_preferences: String, remote_preferences: String, }
139
140impl CacheKey {
141 fn new(
143 local: &CertificateTypePreferences,
144 remote_client: Option<&CertificateTypeList>,
145 remote_server: Option<&CertificateTypeList>,
146 ) -> Self {
147 use std::collections::hash_map::DefaultHasher;
148
149 let mut hasher = DefaultHasher::new();
150 local.hash(&mut hasher);
151 let local_hash = hasher.finish();
152
153 let mut hasher = DefaultHasher::new();
154 if let Some(types) = remote_client {
155 types.hash(&mut hasher);
156 }
157 if let Some(types) = remote_server {
158 types.hash(&mut hasher);
159 }
160 let remote_hash = hasher.finish();
161
162 Self {
163 local_preferences: format!("{local_hash:x}"),
164 remote_preferences: format!("{remote_hash:x}"),
165 }
166 }
167}
168
169impl Hash for CertificateTypePreferences {
171 fn hash<H: Hasher>(&self, state: &mut H) {
172 self.client_types.types.hash(state);
173 self.server_types.types.hash(state);
174 self.require_extensions.hash(state);
175 self.fallback_client.hash(state);
176 self.fallback_server.hash(state);
177 }
178}
179
180impl Hash for CertificateTypeList {
182 fn hash<H: Hasher>(&self, state: &mut H) {
183 self.types.hash(state);
184 }
185}
186
187pub struct CertificateNegotiationManager {
189 config: NegotiationConfig,
191 sessions: RwLock<HashMap<NegotiationId, NegotiationState>>,
193 cache: Arc<Mutex<HashMap<CacheKey, (NegotiationResult, Instant)>>>,
195 stats: Arc<Mutex<NegotiationStats>>,
197}
198
/// Aggregate counters reported by [`CertificateNegotiationManager::get_stats`].
#[derive(Clone, Debug, Default)]
pub struct NegotiationStats {
    /// Number of negotiations started.
    pub total_attempts: u64,
    /// Negotiations that produced a result by actually running the negotiation
    /// (cache hits are counted separately).
    pub successful: u64,
    /// Negotiations recorded as failed.
    pub failed: u64,
    /// Negotiations that exceeded the configured timeout.
    pub timed_out: u64,
    /// Completions served from the result cache.
    pub cache_hits: u64,
    /// Completions that had to negotiate although caching was enabled.
    pub cache_misses: u64,
    /// Running mean of the time spent negotiating.
    pub avg_negotiation_time: Duration,
}
217
218impl CertificateNegotiationManager {
219 pub fn new(config: NegotiationConfig) -> Self {
221 Self {
222 config,
223 sessions: RwLock::new(HashMap::new()),
224 cache: Arc::new(Mutex::new(HashMap::new())),
225 stats: Arc::new(Mutex::new(NegotiationStats::default())),
226 }
227 }
228
229 pub fn start_negotiation(
231 &self,
232 preferences: CertificateTypePreferences,
233 ) -> Result<NegotiationId, TlsExtensionError> {
234 let id = NegotiationId::new();
235 let state = NegotiationState::Waiting {
236 sent_at: Instant::now(),
237 our_preferences: preferences,
238 };
239
240 let mut sessions = self.sessions.write().map_err(|e| {
241 TlsExtensionError::InvalidExtensionData(format!("Session lock poisoned: {}", e))
242 })?;
243 sessions.insert(id, state);
244
245 let mut stats = self.stats.lock().map_err(|e| {
246 TlsExtensionError::InvalidExtensionData(format!("Stats lock poisoned: {}", e))
247 })?;
248 stats.total_attempts += 1;
249
250 debug!("Started certificate type negotiation: {:?}", id);
251 Ok(id)
252 }
253
254 pub fn complete_negotiation(
256 &self,
257 id: NegotiationId,
258 remote_client_types: Option<CertificateTypeList>,
259 remote_server_types: Option<CertificateTypeList>,
260 ) -> Result<NegotiationResult, TlsExtensionError> {
261 let _span = span!(Level::DEBUG, "complete_negotiation", id = id.as_u64()).entered();
262
263 let mut sessions = self.sessions.write().map_err(|e| {
264 TlsExtensionError::InvalidExtensionData(format!("Session lock poisoned: {}", e))
265 })?;
266 let state = sessions.get(&id).ok_or_else(|| {
267 TlsExtensionError::InvalidExtensionData(format!("Unknown negotiation ID: {id:?}"))
268 })?;
269
270 let our_preferences = match state {
271 NegotiationState::Waiting {
272 our_preferences, ..
273 } => our_preferences.clone(),
274 _ => {
275 return Err(TlsExtensionError::InvalidExtensionData(
276 "Negotiation not in waiting state".to_string(),
277 ));
278 }
279 };
280
281 if self.config.enable_caching {
283 let cache_key = CacheKey::new(
284 &our_preferences,
285 remote_client_types.as_ref(),
286 remote_server_types.as_ref(),
287 );
288
289 let mut cache = self.cache.lock().map_err(|e| {
290 TlsExtensionError::InvalidExtensionData(format!("Cache lock poisoned: {}", e))
291 })?;
292 if let Some((cached_result, cached_at)) = cache.get(&cache_key) {
293 if cached_at.elapsed() < Duration::from_secs(300) {
295 let mut stats = self.stats.lock().map_err(|e| {
297 TlsExtensionError::InvalidExtensionData(format!(
298 "Stats lock poisoned: {}",
299 e
300 ))
301 })?;
302 stats.cache_hits += 1;
303
304 sessions.insert(
306 id,
307 NegotiationState::Completed {
308 result: cached_result.clone(),
309 completed_at: Instant::now(),
310 },
311 );
312
313 debug!("Cache hit for negotiation: {:?}", id);
314 return Ok(cached_result.clone());
315 } else {
316 cache.remove(&cache_key);
318 }
319 }
320
321 let mut stats = self.stats.lock().map_err(|e| {
322 TlsExtensionError::InvalidExtensionData(format!("Stats lock poisoned: {}", e))
323 })?;
324 stats.cache_misses += 1;
325 }
326
327 let negotiation_start = Instant::now();
329 let result =
330 our_preferences.negotiate(remote_client_types.as_ref(), remote_server_types.as_ref());
331
332 match result {
333 Ok(negotiation_result) => {
334 let completed_at = Instant::now();
335 let negotiation_time = negotiation_start.elapsed();
336
337 sessions.insert(
339 id,
340 NegotiationState::Completed {
341 result: negotiation_result.clone(),
342 completed_at,
343 },
344 );
345
346 let mut stats = self.stats.lock().map_err(|e| {
348 TlsExtensionError::InvalidExtensionData(format!("Stats lock poisoned: {}", e))
349 })?;
350 stats.successful += 1;
351
352 let total_completed = stats.successful + stats.failed;
354 stats.avg_negotiation_time = if total_completed == 1 {
355 negotiation_time
356 } else {
357 Duration::from_nanos(
358 (stats.avg_negotiation_time.as_nanos() as u64 * (total_completed - 1)
359 + negotiation_time.as_nanos() as u64)
360 / total_completed,
361 )
362 };
363
364 if self.config.enable_caching {
366 let cache_key = CacheKey::new(
367 &our_preferences,
368 remote_client_types.as_ref(),
369 remote_server_types.as_ref(),
370 );
371
372 let mut cache = self.cache.lock().map_err(|e| {
373 TlsExtensionError::InvalidExtensionData(format!(
374 "Cache lock poisoned: {}",
375 e
376 ))
377 })?;
378
379 if cache.len() >= self.config.max_cache_size {
381 let mut entries: Vec<_> =
383 cache.iter().map(|(k, (_, t))| (k.clone(), *t)).collect();
384 entries.sort_by_key(|(_, timestamp)| *timestamp);
385
386 let to_remove = cache.len() - self.config.max_cache_size + 1;
387 let keys_to_remove: Vec<_> = entries
388 .iter()
389 .take(to_remove)
390 .map(|(key, _)| key.clone())
391 .collect();
392
393 for key in keys_to_remove {
394 cache.remove(&key);
395 }
396 }
397
398 cache.insert(cache_key, (negotiation_result.clone(), completed_at));
399 }
400
401 info!(
402 "Certificate type negotiation completed successfully: {:?} -> client={}, server={}",
403 id, negotiation_result.client_cert_type, negotiation_result.server_cert_type
404 );
405
406 Ok(negotiation_result)
407 }
408 Err(error) => {
409 sessions.insert(
411 id,
412 NegotiationState::Failed {
413 error: error.to_string(),
414 failed_at: Instant::now(),
415 },
416 );
417
418 let mut stats = self.stats.lock().map_err(|e| {
420 TlsExtensionError::InvalidExtensionData(format!("Stats lock poisoned: {}", e))
421 })?;
422 stats.failed += 1;
423
424 warn!("Certificate type negotiation failed: {:?} -> {}", id, error);
425 Err(error)
426 }
427 }
428 }
429
430 pub fn fail_negotiation(&self, id: NegotiationId, error: String) {
432 let mut sessions = self
433 .sessions
434 .write()
435 .expect("Session lock should not be poisoned");
436 sessions.insert(
437 id,
438 NegotiationState::Failed {
439 error,
440 failed_at: Instant::now(),
441 },
442 );
443
444 let mut stats = self
445 .stats
446 .lock()
447 .expect("Stats lock should not be poisoned");
448 stats.failed += 1;
449
450 warn!("Certificate type negotiation failed: {:?}", id);
451 }
452
453 pub fn get_negotiation_state(&self, id: NegotiationId) -> Option<NegotiationState> {
455 let sessions = self
456 .sessions
457 .read()
458 .expect("Session lock should not be poisoned");
459 sessions.get(&id).cloned()
460 }
461
462 pub fn handle_timeouts(&self) {
464 let mut sessions = self
465 .sessions
466 .write()
467 .expect("Session lock should not be poisoned");
468 let mut timed_out_ids = Vec::new();
469
470 for (id, state) in sessions.iter() {
471 if let NegotiationState::Waiting { sent_at, .. } = state {
472 if sent_at.elapsed() > self.config.timeout {
473 timed_out_ids.push(*id);
474 }
475 }
476 }
477
478 for id in timed_out_ids {
479 sessions.insert(
480 id,
481 NegotiationState::TimedOut {
482 timeout_at: Instant::now(),
483 },
484 );
485
486 let mut stats = self
487 .stats
488 .lock()
489 .expect("Stats lock should not be poisoned");
490 stats.timed_out += 1;
491
492 warn!("Certificate type negotiation timed out: {:?}", id);
493 }
494 }
495
496 pub fn cleanup_old_sessions(&self, max_age: Duration) {
498 let mut sessions = self
499 .sessions
500 .write()
501 .expect("Session lock should not be poisoned");
502 let cutoff = Instant::now() - max_age;
503
504 sessions.retain(|id, state| {
505 let should_retain = match state {
506 NegotiationState::Completed { completed_at, .. } => *completed_at > cutoff,
507 NegotiationState::Failed { failed_at, .. } => *failed_at > cutoff,
508 NegotiationState::TimedOut { timeout_at, .. } => *timeout_at > cutoff,
509 _ => true, };
511
512 if !should_retain {
513 debug!("Cleaned up old negotiation session: {:?}", id);
514 }
515
516 should_retain
517 });
518 }
519
520 pub fn get_stats(&self) -> NegotiationStats {
522 self.stats
523 .lock()
524 .expect("Stats lock should not be poisoned")
525 .clone()
526 }
527
528 pub fn clear_cache(&self) {
530 let mut cache = self
531 .cache
532 .lock()
533 .expect("Cache lock should not be poisoned");
534 cache.clear();
535 debug!("Cleared certificate type negotiation cache");
536 }
537
538 pub fn get_cache_stats(&self) -> (usize, usize) {
540 let cache = self
541 .cache
542 .lock()
543 .expect("Cache lock should not be poisoned");
544 (cache.len(), self.config.max_cache_size)
545 }
546}
547
548impl Default for CertificateNegotiationManager {
549 fn default() -> Self {
550 Self::new(NegotiationConfig::default())
551 }
552}
553
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;

    /// Starts a raw-public-key negotiation and completes it against a peer that offers
    /// only raw public keys for both roles, returning the session ID.
    fn complete_raw_public_key(manager: &CertificateNegotiationManager) -> NegotiationId {
        let id = manager
            .start_negotiation(CertificateTypePreferences::prefer_raw_public_key())
            .unwrap();
        let remote_types = CertificateTypeList::raw_public_key_only();
        manager
            .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types))
            .unwrap();
        id
    }

    #[test]
    fn test_negotiation_id_generation() {
        let id1 = NegotiationId::new();
        let id2 = NegotiationId::new();

        assert_ne!(id1, id2);
        assert!(id1.as_u64() > 0);
        assert!(id2.as_u64() > 0);
    }

    #[test]
    fn test_negotiation_state_checks() {
        let pending = NegotiationState::Pending;
        assert!(!pending.is_complete());
        assert!(!pending.is_successful());

        let completed = NegotiationState::Completed {
            result: NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509),
            completed_at: Instant::now(),
        };
        assert!(completed.is_complete());
        assert!(completed.is_successful());
        assert!(completed.get_result().is_some());

        let failed = NegotiationState::Failed {
            error: "Test error".to_string(),
            failed_at: Instant::now(),
        };
        assert!(failed.is_complete());
        assert!(!failed.is_successful());
        assert_eq!(failed.get_error().unwrap(), "Test error");

        let timed_out = NegotiationState::TimedOut {
            timeout_at: Instant::now(),
        };
        assert!(timed_out.is_complete());
        assert!(!timed_out.is_successful());
        assert!(timed_out.get_result().is_none());
        assert!(timed_out.get_error().is_none());
    }

    #[test]
    fn test_negotiation_manager_basic_flow() {
        let manager = CertificateNegotiationManager::default();
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        let id = manager.start_negotiation(preferences).unwrap();

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(state, NegotiationState::Waiting { .. }));

        let remote_types = CertificateTypeList::raw_public_key_only();
        let result = manager
            .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        assert_eq!(result.client_cert_type, CertificateType::RawPublicKey);
        assert_eq!(result.server_cert_type, CertificateType::RawPublicKey);

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(state.is_successful());
    }

    #[test]
    fn test_negotiation_caching() {
        let config = NegotiationConfig {
            enable_caching: true,
            ..Default::default()
        };
        let manager = CertificateNegotiationManager::new(config);
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        let id1 = manager.start_negotiation(preferences.clone()).unwrap();
        let remote_types = CertificateTypeList::raw_public_key_only();
        let result1 = manager
            .complete_negotiation(id1, Some(remote_types.clone()), Some(remote_types.clone()))
            .unwrap();

        let id2 = manager.start_negotiation(preferences).unwrap();
        let result2 = manager
            .complete_negotiation(id2, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        assert_eq!(result1, result2);

        let stats = manager.get_stats();
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
    }

    #[test]
    fn test_negotiation_timeout_handling() {
        let config = NegotiationConfig {
            timeout: Duration::from_millis(1),
            ..Default::default()
        };
        let manager = CertificateNegotiationManager::new(config);
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        let id = manager.start_negotiation(preferences).unwrap();

        std::thread::sleep(Duration::from_millis(10));
        manager.handle_timeouts();

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(state, NegotiationState::TimedOut { .. }));

        let stats = manager.get_stats();
        assert_eq!(stats.timed_out, 1);
    }

    #[test]
    fn test_complete_unknown_id_fails() {
        let manager = CertificateNegotiationManager::default();
        let unknown = NegotiationId::new();

        assert!(manager.complete_negotiation(unknown, None, None).is_err());
        assert!(manager.get_negotiation_state(unknown).is_none());
    }

    #[test]
    fn test_complete_twice_fails() {
        let manager = CertificateNegotiationManager::default();
        let id = complete_raw_public_key(&manager);

        let remote_types = CertificateTypeList::raw_public_key_only();
        let second = manager.complete_negotiation(id, Some(remote_types.clone()), Some(remote_types));
        assert!(second.is_err());
        // The rejected second attempt must not disturb the completed state.
        assert!(manager.get_negotiation_state(id).unwrap().is_successful());
    }

    #[test]
    fn test_fail_negotiation_records_error() {
        let manager = CertificateNegotiationManager::default();
        let id = manager
            .start_negotiation(CertificateTypePreferences::prefer_raw_public_key())
            .unwrap();

        manager.fail_negotiation(id, "boom".to_string());

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(state.is_complete());
        assert!(!state.is_successful());
        assert_eq!(state.get_error(), Some("boom"));
        assert_eq!(manager.get_stats().failed, 1);
    }

    #[test]
    fn test_cleanup_removes_only_finished_sessions() {
        let manager = CertificateNegotiationManager::default();
        let finished = complete_raw_public_key(&manager);
        let waiting = manager
            .start_negotiation(CertificateTypePreferences::prefer_raw_public_key())
            .unwrap();

        manager.cleanup_old_sessions(Duration::ZERO);

        assert!(manager.get_negotiation_state(finished).is_none());
        assert!(matches!(
            manager.get_negotiation_state(waiting),
            Some(NegotiationState::Waiting { .. })
        ));
    }

    #[test]
    fn test_clear_cache() {
        let manager = CertificateNegotiationManager::default();
        complete_raw_public_key(&manager);
        assert_eq!(manager.get_cache_stats(), (1, 1000));

        manager.clear_cache();
        assert_eq!(manager.get_cache_stats(), (0, 1000));
    }

    #[test]
    fn test_caching_disabled() {
        let config = NegotiationConfig {
            enable_caching: false,
            ..Default::default()
        };
        let manager = CertificateNegotiationManager::new(config);

        complete_raw_public_key(&manager);
        complete_raw_public_key(&manager);

        let stats = manager.get_stats();
        assert_eq!(stats.total_attempts, 2);
        assert_eq!(stats.successful, 2);
        assert_eq!(stats.cache_hits, 0);
        assert_eq!(stats.cache_misses, 0);
        assert_eq!(manager.get_cache_stats().0, 0);
    }
}