1use std::{
8 collections::HashMap,
9 hash::{Hash, Hasher},
10 sync::{Arc, Mutex, RwLock},
11 time::{Duration, Instant},
12};
13
14use tracing::{Level, debug, info, span, warn};
15
16use super::tls_extensions::{
17 CertificateTypeList, CertificateTypePreferences, NegotiationResult, TlsExtensionError,
18};
19
20#[derive(Debug, Clone, PartialEq, Eq)]
22pub enum NegotiationState {
23 Pending,
25 Waiting {
27 sent_at: Instant,
28 our_preferences: CertificateTypePreferences,
29 },
30 Completed {
32 result: NegotiationResult,
33 completed_at: Instant,
34 },
35 Failed {
37 error: String,
39 failed_at: Instant,
41 },
42 TimedOut {
44 timeout_at: Instant,
46 },
47}
48
49impl NegotiationState {
50 pub fn is_complete(&self) -> bool {
52 matches!(
53 self,
54 Self::Completed { .. } | Self::Failed { .. } | Self::TimedOut { .. }
55 )
56 }
57
58 pub fn is_successful(&self) -> bool {
60 matches!(self, Self::Completed { .. })
61 }
62
63 pub fn get_result(&self) -> Option<&NegotiationResult> {
65 match self {
66 Self::Completed { result, .. } => Some(result),
67 _ => None,
68 }
69 }
70
71 pub fn get_error(&self) -> Option<&str> {
73 match self {
74 Self::Failed { error, .. } => Some(error),
75 _ => None,
76 }
77 }
78}
79
80#[derive(Debug, Clone)]
82pub struct NegotiationConfig {
83 pub timeout: Duration,
85 pub enable_caching: bool,
87 pub max_cache_size: usize,
89 pub allow_fallback: bool,
91 pub default_preferences: CertificateTypePreferences,
93}
94
95impl Default for NegotiationConfig {
96 fn default() -> Self {
97 Self {
98 timeout: Duration::from_secs(10),
99 enable_caching: true,
100 max_cache_size: 1000,
101 allow_fallback: true,
102 default_preferences: CertificateTypePreferences::prefer_raw_public_key(),
103 }
104 }
105}
106
/// Identifier of a negotiation session, unique within the running process.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NegotiationId(u64);

impl Default for NegotiationId {
    /// Same as [`NegotiationId::new`]: every default value is a fresh id.
    fn default() -> Self {
        Self::new()
    }
}

impl NegotiationId {
    /// Allocates the next id from a process-wide counter that starts at 1.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};

        // Only uniqueness matters, so no ordering stronger than `Relaxed`
        // is needed on the counter.
        static NEXT_ID: AtomicU64 = AtomicU64::new(1);

        let raw = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        Self(raw)
    }

    /// Returns the raw numeric value of the id.
    pub fn as_u64(self) -> u64 {
        let Self(raw) = self;
        raw
    }
}
130
/// Lookup key for the negotiation result cache.
///
/// Both fields hold hex-formatted 64-bit hashes rather than the hashed
/// values themselves.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Hex-formatted hash of the local certificate type preferences.
    local_preferences: String,
    /// Hex-formatted hash of the certificate type lists offered by the peer.
    remote_preferences: String,
}
139
140impl CacheKey {
141 fn new(
143 local: &CertificateTypePreferences,
144 remote_client: Option<&CertificateTypeList>,
145 remote_server: Option<&CertificateTypeList>,
146 ) -> Self {
147 use std::collections::hash_map::DefaultHasher;
148
149 let mut hasher = DefaultHasher::new();
150 local.hash(&mut hasher);
151 let local_hash = hasher.finish();
152
153 let mut hasher = DefaultHasher::new();
154 if let Some(types) = remote_client {
155 types.hash(&mut hasher);
156 }
157 if let Some(types) = remote_server {
158 types.hash(&mut hasher);
159 }
160 let remote_hash = hasher.finish();
161
162 Self {
163 local_preferences: format!("{local_hash:x}"),
164 remote_preferences: format!("{remote_hash:x}"),
165 }
166 }
167}
168
169impl Hash for CertificateTypePreferences {
171 fn hash<H: Hasher>(&self, state: &mut H) {
172 self.client_types.types.hash(state);
173 self.server_types.types.hash(state);
174 self.require_extensions.hash(state);
175 self.fallback_client.hash(state);
176 self.fallback_server.hash(state);
177 }
178}
179
180impl Hash for CertificateTypeList {
182 fn hash<H: Hasher>(&self, state: &mut H) {
183 self.types.hash(state);
184 }
185}
186
187pub struct CertificateNegotiationManager {
189 config: NegotiationConfig,
191 sessions: RwLock<HashMap<NegotiationId, NegotiationState>>,
193 cache: Arc<Mutex<HashMap<CacheKey, (NegotiationResult, Instant)>>>,
195 stats: Arc<Mutex<NegotiationStats>>,
197}
198
/// Counters describing the negotiations handled by a
/// [`CertificateNegotiationManager`].
#[derive(Debug, Default, Clone)]
pub struct NegotiationStats {
    /// Number of sessions started.
    pub total_attempts: u64,
    /// Number of negotiations that were computed successfully
    /// (results served from the cache are counted in `cache_hits` instead).
    pub successful: u64,
    /// Number of negotiations that ended in failure.
    pub failed: u64,
    /// Number of sessions that were marked as timed out.
    pub timed_out: u64,
    /// Number of results served from the cache.
    pub cache_hits: u64,
    /// Number of cache lookups that found no usable entry.
    pub cache_misses: u64,
    /// Running average of the time spent computing a negotiation result.
    pub avg_negotiation_time: Duration,
}
217
218impl CertificateNegotiationManager {
219 pub fn new(config: NegotiationConfig) -> Self {
221 Self {
222 config,
223 sessions: RwLock::new(HashMap::new()),
224 cache: Arc::new(Mutex::new(HashMap::new())),
225 stats: Arc::new(Mutex::new(NegotiationStats::default())),
226 }
227 }
228
229 pub fn start_negotiation(&self, preferences: CertificateTypePreferences) -> NegotiationId {
231 let id = NegotiationId::new();
232 let state = NegotiationState::Waiting {
233 sent_at: Instant::now(),
234 our_preferences: preferences,
235 };
236
237 let mut sessions = self.sessions.write().unwrap();
238 sessions.insert(id, state);
239
240 let mut stats = self.stats.lock().unwrap();
241 stats.total_attempts += 1;
242
243 debug!("Started certificate type negotiation: {:?}", id);
244 id
245 }
246
247 pub fn complete_negotiation(
249 &self,
250 id: NegotiationId,
251 remote_client_types: Option<CertificateTypeList>,
252 remote_server_types: Option<CertificateTypeList>,
253 ) -> Result<NegotiationResult, TlsExtensionError> {
254 let _span = span!(Level::DEBUG, "complete_negotiation", id = id.as_u64()).entered();
255
256 let mut sessions = self.sessions.write().unwrap();
257 let state = sessions.get(&id).ok_or_else(|| {
258 TlsExtensionError::InvalidExtensionData(format!("Unknown negotiation ID: {id:?}"))
259 })?;
260
261 let our_preferences = match state {
262 NegotiationState::Waiting {
263 our_preferences, ..
264 } => our_preferences.clone(),
265 _ => {
266 return Err(TlsExtensionError::InvalidExtensionData(
267 "Negotiation not in waiting state".to_string(),
268 ));
269 }
270 };
271
272 if self.config.enable_caching {
274 let cache_key = CacheKey::new(
275 &our_preferences,
276 remote_client_types.as_ref(),
277 remote_server_types.as_ref(),
278 );
279
280 let mut cache = self.cache.lock().unwrap();
281 if let Some((cached_result, cached_at)) = cache.get(&cache_key) {
282 if cached_at.elapsed() < Duration::from_secs(300) {
284 let mut stats = self.stats.lock().unwrap();
286 stats.cache_hits += 1;
287
288 sessions.insert(
290 id,
291 NegotiationState::Completed {
292 result: cached_result.clone(),
293 completed_at: Instant::now(),
294 },
295 );
296
297 debug!("Cache hit for negotiation: {:?}", id);
298 return Ok(cached_result.clone());
299 } else {
300 cache.remove(&cache_key);
302 }
303 }
304
305 let mut stats = self.stats.lock().unwrap();
306 stats.cache_misses += 1;
307 }
308
309 let negotiation_start = Instant::now();
311 let result =
312 our_preferences.negotiate(remote_client_types.as_ref(), remote_server_types.as_ref());
313
314 match result {
315 Ok(negotiation_result) => {
316 let completed_at = Instant::now();
317 let negotiation_time = negotiation_start.elapsed();
318
319 sessions.insert(
321 id,
322 NegotiationState::Completed {
323 result: negotiation_result.clone(),
324 completed_at,
325 },
326 );
327
328 let mut stats = self.stats.lock().unwrap();
330 stats.successful += 1;
331
332 let total_completed = stats.successful + stats.failed;
334 stats.avg_negotiation_time = if total_completed == 1 {
335 negotiation_time
336 } else {
337 Duration::from_nanos(
338 (stats.avg_negotiation_time.as_nanos() as u64 * (total_completed - 1)
339 + negotiation_time.as_nanos() as u64)
340 / total_completed,
341 )
342 };
343
344 if self.config.enable_caching {
346 let cache_key = CacheKey::new(
347 &our_preferences,
348 remote_client_types.as_ref(),
349 remote_server_types.as_ref(),
350 );
351
352 let mut cache = self.cache.lock().unwrap();
353
354 if cache.len() >= self.config.max_cache_size {
356 let mut entries: Vec<_> =
358 cache.iter().map(|(k, (_, t))| (k.clone(), *t)).collect();
359 entries.sort_by_key(|(_, timestamp)| *timestamp);
360
361 let to_remove = cache.len() - self.config.max_cache_size + 1;
362 let keys_to_remove: Vec<_> = entries
363 .iter()
364 .take(to_remove)
365 .map(|(key, _)| key.clone())
366 .collect();
367
368 for key in keys_to_remove {
369 cache.remove(&key);
370 }
371 }
372
373 cache.insert(cache_key, (negotiation_result.clone(), completed_at));
374 }
375
376 info!(
377 "Certificate type negotiation completed successfully: {:?} -> client={}, server={}",
378 id, negotiation_result.client_cert_type, negotiation_result.server_cert_type
379 );
380
381 Ok(negotiation_result)
382 }
383 Err(error) => {
384 sessions.insert(
386 id,
387 NegotiationState::Failed {
388 error: error.to_string(),
389 failed_at: Instant::now(),
390 },
391 );
392
393 let mut stats = self.stats.lock().unwrap();
395 stats.failed += 1;
396
397 warn!("Certificate type negotiation failed: {:?} -> {}", id, error);
398 Err(error)
399 }
400 }
401 }
402
403 pub fn fail_negotiation(&self, id: NegotiationId, error: String) {
405 let mut sessions = self.sessions.write().unwrap();
406 sessions.insert(
407 id,
408 NegotiationState::Failed {
409 error,
410 failed_at: Instant::now(),
411 },
412 );
413
414 let mut stats = self.stats.lock().unwrap();
415 stats.failed += 1;
416
417 warn!("Certificate type negotiation failed: {:?}", id);
418 }
419
420 pub fn get_negotiation_state(&self, id: NegotiationId) -> Option<NegotiationState> {
422 let sessions = self.sessions.read().unwrap();
423 sessions.get(&id).cloned()
424 }
425
426 pub fn handle_timeouts(&self) {
428 let mut sessions = self.sessions.write().unwrap();
429 let mut timed_out_ids = Vec::new();
430
431 for (id, state) in sessions.iter() {
432 if let NegotiationState::Waiting { sent_at, .. } = state {
433 if sent_at.elapsed() > self.config.timeout {
434 timed_out_ids.push(*id);
435 }
436 }
437 }
438
439 for id in timed_out_ids {
440 sessions.insert(
441 id,
442 NegotiationState::TimedOut {
443 timeout_at: Instant::now(),
444 },
445 );
446
447 let mut stats = self.stats.lock().unwrap();
448 stats.timed_out += 1;
449
450 warn!("Certificate type negotiation timed out: {:?}", id);
451 }
452 }
453
454 pub fn cleanup_old_sessions(&self, max_age: Duration) {
456 let mut sessions = self.sessions.write().unwrap();
457 let cutoff = Instant::now() - max_age;
458
459 sessions.retain(|id, state| {
460 let should_retain = match state {
461 NegotiationState::Completed { completed_at, .. } => *completed_at > cutoff,
462 NegotiationState::Failed { failed_at, .. } => *failed_at > cutoff,
463 NegotiationState::TimedOut { timeout_at, .. } => *timeout_at > cutoff,
464 _ => true, };
466
467 if !should_retain {
468 debug!("Cleaned up old negotiation session: {:?}", id);
469 }
470
471 should_retain
472 });
473 }
474
475 pub fn get_stats(&self) -> NegotiationStats {
477 self.stats.lock().unwrap().clone()
478 }
479
480 pub fn clear_cache(&self) {
482 let mut cache = self.cache.lock().unwrap();
483 cache.clear();
484 debug!("Cleared certificate type negotiation cache");
485 }
486
487 pub fn get_cache_stats(&self) -> (usize, usize) {
489 let cache = self.cache.lock().unwrap();
490 (cache.len(), self.config.max_cache_size)
491 }
492}
493
494impl Default for CertificateNegotiationManager {
495 fn default() -> Self {
496 Self::new(NegotiationConfig::default())
497 }
498}
499
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;

    /// Two freshly created ids are distinct and non-zero.
    #[test]
    fn test_negotiation_id_generation() {
        let first = NegotiationId::new();
        let second = NegotiationId::new();

        assert!(first.as_u64() > 0);
        assert!(second.as_u64() > 0);
        assert_ne!(first, second);
    }

    /// The state predicates and accessors agree with the variant they are
    /// called on.
    #[test]
    fn test_negotiation_state_checks() {
        // Pending: neither complete nor successful.
        let state = NegotiationState::Pending;
        assert!(!state.is_complete());
        assert!(!state.is_successful());

        // Completed: complete, successful, and carries a result.
        let negotiated =
            NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509);
        let state = NegotiationState::Completed {
            result: negotiated,
            completed_at: Instant::now(),
        };
        assert!(state.is_complete());
        assert!(state.is_successful());
        assert!(state.get_result().is_some());

        // Failed: complete, not successful, and carries the error text.
        let state = NegotiationState::Failed {
            error: "Test error".to_string(),
            failed_at: Instant::now(),
        };
        assert!(state.is_complete());
        assert!(!state.is_successful());
        assert_eq!(state.get_error().unwrap(), "Test error");
    }

    /// A session goes from `Waiting` to a successful result when the peer
    /// offers raw public keys for both roles.
    #[test]
    fn test_negotiation_manager_basic_flow() {
        let manager = CertificateNegotiationManager::default();

        let id = manager.start_negotiation(CertificateTypePreferences::prefer_raw_public_key());

        let before = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(before, NegotiationState::Waiting { .. }));

        let offered = CertificateTypeList::raw_public_key_only();
        let result = manager
            .complete_negotiation(id, Some(offered.clone()), Some(offered))
            .unwrap();

        assert_eq!(result.client_cert_type, CertificateType::RawPublicKey);
        assert_eq!(result.server_cert_type, CertificateType::RawPublicKey);

        let after = manager.get_negotiation_state(id).unwrap();
        assert!(after.is_successful());
    }

    /// Repeating an identical negotiation is served from the cache.
    #[test]
    fn test_negotiation_caching() {
        let manager = CertificateNegotiationManager::new(NegotiationConfig {
            enable_caching: true,
            ..Default::default()
        });
        let preferences = CertificateTypePreferences::prefer_raw_public_key();
        let offered = CertificateTypeList::raw_public_key_only();

        // First run computes the result and fills the cache.
        let first_id = manager.start_negotiation(preferences.clone());
        let first = manager
            .complete_negotiation(first_id, Some(offered.clone()), Some(offered.clone()))
            .unwrap();

        // Second run with the same inputs must hit the cache.
        let second_id = manager.start_negotiation(preferences);
        let second = manager
            .complete_negotiation(second_id, Some(offered.clone()), Some(offered))
            .unwrap();

        assert_eq!(first, second);

        let stats = manager.get_stats();
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
    }

    /// A waiting session older than the configured timeout is marked as
    /// timed out and counted.
    #[test]
    fn test_negotiation_timeout_handling() {
        let manager = CertificateNegotiationManager::new(NegotiationConfig {
            timeout: Duration::from_millis(1),
            ..Default::default()
        });

        let id = manager.start_negotiation(CertificateTypePreferences::prefer_raw_public_key());

        // Let the 1 ms timeout elapse before sweeping.
        std::thread::sleep(Duration::from_millis(10));
        manager.handle_timeouts();

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(state, NegotiationState::TimedOut { .. }));

        let stats = manager.get_stats();
        assert_eq!(stats.timed_out, 1);
    }
}