1use std::{
8 collections::HashMap,
9 hash::{Hash, Hasher},
10 sync::{Arc, Mutex, RwLock},
11 time::{Duration, Instant},
12};
13
14use tracing::{Level, debug, info, span, warn};
15
16use super::tls_extensions::{
17 CertificateTypeList, CertificateTypePreferences, NegotiationResult, TlsExtensionError,
18};
19
/// Lifecycle state of one certificate-type negotiation tracked by
/// [`CertificateNegotiationManager`].
///
/// `Completed`, `Failed`, and `TimedOut` are terminal states
/// (see [`NegotiationState::is_complete`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NegotiationState {
    /// Not started. The manager in this module begins sessions directly in
    /// `Waiting`, so it does not create this state itself.
    Pending,
    /// Local preferences are recorded; waiting for the peer's certificate type
    /// lists to be supplied to `complete_negotiation`.
    Waiting {
        /// When the negotiation was started; `handle_timeouts` compares its
        /// elapsed time against `NegotiationConfig::timeout`.
        sent_at: Instant,
        /// Local preferences that will be negotiated against the peer's lists.
        our_preferences: CertificateTypePreferences,
    },
    /// Negotiation succeeded (either computed or served from the result cache).
    Completed {
        /// The negotiated certificate types.
        result: NegotiationResult,
        /// When the session was marked completed.
        completed_at: Instant,
    },
    /// Negotiation failed; `error` is the failure message.
    Failed {
        error: String,
        failed_at: Instant,
    },
    /// The session stayed `Waiting` longer than the configured timeout.
    TimedOut { timeout_at: Instant },
}
40
41impl NegotiationState {
42 pub fn is_complete(&self) -> bool {
44 matches!(
45 self,
46 NegotiationState::Completed { .. }
47 | NegotiationState::Failed { .. }
48 | NegotiationState::TimedOut { .. }
49 )
50 }
51
52 pub fn is_successful(&self) -> bool {
54 matches!(self, NegotiationState::Completed { .. })
55 }
56
57 pub fn get_result(&self) -> Option<&NegotiationResult> {
59 match self {
60 NegotiationState::Completed { result, .. } => Some(result),
61 _ => None,
62 }
63 }
64
65 pub fn get_error(&self) -> Option<&str> {
67 match self {
68 NegotiationState::Failed { error, .. } => Some(error),
69 _ => None,
70 }
71 }
72}
73
/// Settings for [`CertificateNegotiationManager`].
#[derive(Debug, Clone)]
pub struct NegotiationConfig {
    /// How long a session may stay `Waiting` before `handle_timeouts` marks it `TimedOut`.
    pub timeout: Duration,
    /// Whether `complete_negotiation` consults and populates the result cache.
    pub enable_caching: bool,
    /// Upper bound on the number of cached results; older entries are evicted to make room.
    pub max_cache_size: usize,
    /// NOTE(review): not read by the manager in this file; presumably consulted by
    /// callers to decide whether fallback certificate types are acceptable — confirm.
    pub allow_fallback: bool,
    /// NOTE(review): not read by the manager in this file; presumably the preferences
    /// callers pass to `start_negotiation` when they have none of their own — confirm.
    pub default_preferences: CertificateTypePreferences,
}
88
89impl Default for NegotiationConfig {
90 fn default() -> Self {
91 Self {
92 timeout: Duration::from_secs(10),
93 enable_caching: true,
94 max_cache_size: 1000,
95 allow_fallback: true,
96 default_preferences: CertificateTypePreferences::prefer_raw_public_key(),
97 }
98 }
99}
100
/// Opaque identifier of one negotiation session.
///
/// Allocated by [`NegotiationId::new`] from a process-wide counter; the raw value is
/// available through [`NegotiationId::as_u64`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NegotiationId(u64);
104
105impl NegotiationId {
106 pub fn new() -> Self {
108 use std::sync::atomic::{AtomicU64, Ordering};
109 static COUNTER: AtomicU64 = AtomicU64::new(1);
110 Self(COUNTER.fetch_add(1, Ordering::Relaxed))
111 }
112
113 pub fn as_u64(self) -> u64 {
115 self.0
116 }
117}
118
/// Key of the negotiation result cache, built by `CacheKey::new`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Lowercase-hex `DefaultHasher` digest of the local preferences.
    local_preferences: String,
    /// Lowercase-hex `DefaultHasher` digest of the remote client/server type lists.
    remote_preferences: String, }
127
128impl CacheKey {
129 fn new(
131 local: &CertificateTypePreferences,
132 remote_client: Option<&CertificateTypeList>,
133 remote_server: Option<&CertificateTypeList>,
134 ) -> Self {
135 use std::collections::hash_map::DefaultHasher;
136
137 let mut hasher = DefaultHasher::new();
138 local.hash(&mut hasher);
139 let local_hash = hasher.finish();
140
141 let mut hasher = DefaultHasher::new();
142 if let Some(types) = remote_client {
143 types.hash(&mut hasher);
144 }
145 if let Some(types) = remote_server {
146 types.hash(&mut hasher);
147 }
148 let remote_hash = hasher.finish();
149
150 Self {
151 local_preferences: format!("{:x}", local_hash),
152 remote_preferences: format!("{:x}", remote_hash),
153 }
154 }
155}
156
157impl Hash for CertificateTypePreferences {
159 fn hash<H: Hasher>(&self, state: &mut H) {
160 self.client_types.types.hash(state);
161 self.server_types.types.hash(state);
162 self.require_extensions.hash(state);
163 self.fallback_client.hash(state);
164 self.fallback_server.hash(state);
165 }
166}
167
168impl Hash for CertificateTypeList {
170 fn hash<H: Hasher>(&self, state: &mut H) {
171 self.types.hash(state);
172 }
173}
174
/// Tracks certificate-type negotiations: per-session state, an optional cache of
/// negotiation results, and aggregate statistics. All methods take `&self`; the
/// mutable state is guarded by the locks below.
pub struct CertificateNegotiationManager {
    /// Behaviour settings (timeout, caching, cache size).
    config: NegotiationConfig,
    /// State of every known session, keyed by its id.
    sessions: RwLock<HashMap<NegotiationId, NegotiationState>>,
    /// Cached results with the instant they were stored; entries older than the
    /// manager's cache TTL are discarded when looked up.
    cache: Arc<Mutex<HashMap<CacheKey, (NegotiationResult, Instant)>>>,
    /// Aggregate counters, returned by value from `get_stats`.
    stats: Arc<Mutex<NegotiationStats>>,
}
186
/// Snapshot of aggregate negotiation counters (see `get_stats`).
#[derive(Debug, Default, Clone)]
pub struct NegotiationStats {
    /// Sessions started via `start_negotiation`.
    pub total_attempts: u64,
    /// Negotiations that succeeded by actually running the negotiation; results
    /// served from the cache are counted in `cache_hits` instead.
    pub successful: u64,
    /// Failures recorded by `complete_negotiation` or `fail_negotiation`.
    pub failed: u64,
    /// Sessions moved to `TimedOut` by `handle_timeouts`.
    pub timed_out: u64,
    /// Completions served from a still-valid cache entry.
    pub cache_hits: u64,
    /// Cache lookups that found no entry or only an expired one (only counted while
    /// caching is enabled).
    pub cache_misses: u64,
    /// Running average of negotiation time, updated on each computed (non-cached)
    /// success.
    pub avg_negotiation_time: Duration,
}
205
206impl CertificateNegotiationManager {
207 pub fn new(config: NegotiationConfig) -> Self {
209 Self {
210 config,
211 sessions: RwLock::new(HashMap::new()),
212 cache: Arc::new(Mutex::new(HashMap::new())),
213 stats: Arc::new(Mutex::new(NegotiationStats::default())),
214 }
215 }
216
217 pub fn start_negotiation(&self, preferences: CertificateTypePreferences) -> NegotiationId {
219 let id = NegotiationId::new();
220 let state = NegotiationState::Waiting {
221 sent_at: Instant::now(),
222 our_preferences: preferences,
223 };
224
225 let mut sessions = self.sessions.write().unwrap();
226 sessions.insert(id, state);
227
228 let mut stats = self.stats.lock().unwrap();
229 stats.total_attempts += 1;
230
231 debug!("Started certificate type negotiation: {:?}", id);
232 id
233 }
234
235 pub fn complete_negotiation(
237 &self,
238 id: NegotiationId,
239 remote_client_types: Option<CertificateTypeList>,
240 remote_server_types: Option<CertificateTypeList>,
241 ) -> Result<NegotiationResult, TlsExtensionError> {
242 let _span = span!(Level::DEBUG, "complete_negotiation", id = id.as_u64()).entered();
243
244 let mut sessions = self.sessions.write().unwrap();
245 let state = sessions.get(&id).ok_or_else(|| {
246 TlsExtensionError::InvalidExtensionData(format!("Unknown negotiation ID: {:?}", id))
247 })?;
248
249 let our_preferences = match state {
250 NegotiationState::Waiting {
251 our_preferences, ..
252 } => our_preferences.clone(),
253 _ => {
254 return Err(TlsExtensionError::InvalidExtensionData(
255 "Negotiation not in waiting state".to_string(),
256 ));
257 }
258 };
259
260 if self.config.enable_caching {
262 let cache_key = CacheKey::new(
263 &our_preferences,
264 remote_client_types.as_ref(),
265 remote_server_types.as_ref(),
266 );
267
268 let mut cache = self.cache.lock().unwrap();
269 if let Some((cached_result, cached_at)) = cache.get(&cache_key) {
270 if cached_at.elapsed() < Duration::from_secs(300) {
272 let mut stats = self.stats.lock().unwrap();
274 stats.cache_hits += 1;
275
276 sessions.insert(
278 id,
279 NegotiationState::Completed {
280 result: cached_result.clone(),
281 completed_at: Instant::now(),
282 },
283 );
284
285 debug!("Cache hit for negotiation: {:?}", id);
286 return Ok(cached_result.clone());
287 } else {
288 cache.remove(&cache_key);
290 }
291 }
292
293 let mut stats = self.stats.lock().unwrap();
294 stats.cache_misses += 1;
295 }
296
297 let negotiation_start = Instant::now();
299 let result =
300 our_preferences.negotiate(remote_client_types.as_ref(), remote_server_types.as_ref());
301
302 match result {
303 Ok(negotiation_result) => {
304 let completed_at = Instant::now();
305 let negotiation_time = negotiation_start.elapsed();
306
307 sessions.insert(
309 id,
310 NegotiationState::Completed {
311 result: negotiation_result.clone(),
312 completed_at,
313 },
314 );
315
316 let mut stats = self.stats.lock().unwrap();
318 stats.successful += 1;
319
320 let total_completed = stats.successful + stats.failed;
322 stats.avg_negotiation_time = if total_completed == 1 {
323 negotiation_time
324 } else {
325 Duration::from_nanos(
326 (stats.avg_negotiation_time.as_nanos() as u64 * (total_completed - 1)
327 + negotiation_time.as_nanos() as u64)
328 / total_completed,
329 )
330 };
331
332 if self.config.enable_caching {
334 let cache_key = CacheKey::new(
335 &our_preferences,
336 remote_client_types.as_ref(),
337 remote_server_types.as_ref(),
338 );
339
340 let mut cache = self.cache.lock().unwrap();
341
342 if cache.len() >= self.config.max_cache_size {
344 let mut entries: Vec<_> = cache
346 .iter()
347 .map(|(k, (_, t))| (k.clone(), t.clone()))
348 .collect();
349 entries.sort_by_key(|(_, timestamp)| *timestamp);
350
351 let to_remove = cache.len() - self.config.max_cache_size + 1;
352 let keys_to_remove: Vec<_> = entries
353 .iter()
354 .take(to_remove)
355 .map(|(key, _)| key.clone())
356 .collect();
357
358 for key in keys_to_remove {
359 cache.remove(&key);
360 }
361 }
362
363 cache.insert(cache_key, (negotiation_result.clone(), completed_at));
364 }
365
366 info!(
367 "Certificate type negotiation completed successfully: {:?} -> client={}, server={}",
368 id, negotiation_result.client_cert_type, negotiation_result.server_cert_type
369 );
370
371 Ok(negotiation_result)
372 }
373 Err(error) => {
374 sessions.insert(
376 id,
377 NegotiationState::Failed {
378 error: error.to_string(),
379 failed_at: Instant::now(),
380 },
381 );
382
383 let mut stats = self.stats.lock().unwrap();
385 stats.failed += 1;
386
387 warn!("Certificate type negotiation failed: {:?} -> {}", id, error);
388 Err(error)
389 }
390 }
391 }
392
393 pub fn fail_negotiation(&self, id: NegotiationId, error: String) {
395 let mut sessions = self.sessions.write().unwrap();
396 sessions.insert(
397 id,
398 NegotiationState::Failed {
399 error,
400 failed_at: Instant::now(),
401 },
402 );
403
404 let mut stats = self.stats.lock().unwrap();
405 stats.failed += 1;
406
407 warn!("Certificate type negotiation failed: {:?}", id);
408 }
409
410 pub fn get_negotiation_state(&self, id: NegotiationId) -> Option<NegotiationState> {
412 let sessions = self.sessions.read().unwrap();
413 sessions.get(&id).cloned()
414 }
415
416 pub fn handle_timeouts(&self) {
418 let mut sessions = self.sessions.write().unwrap();
419 let mut timed_out_ids = Vec::new();
420
421 for (id, state) in sessions.iter() {
422 if let NegotiationState::Waiting { sent_at, .. } = state {
423 if sent_at.elapsed() > self.config.timeout {
424 timed_out_ids.push(*id);
425 }
426 }
427 }
428
429 for id in timed_out_ids {
430 sessions.insert(
431 id,
432 NegotiationState::TimedOut {
433 timeout_at: Instant::now(),
434 },
435 );
436
437 let mut stats = self.stats.lock().unwrap();
438 stats.timed_out += 1;
439
440 warn!("Certificate type negotiation timed out: {:?}", id);
441 }
442 }
443
444 pub fn cleanup_old_sessions(&self, max_age: Duration) {
446 let mut sessions = self.sessions.write().unwrap();
447 let cutoff = Instant::now() - max_age;
448
449 sessions.retain(|id, state| {
450 let should_retain = match state {
451 NegotiationState::Completed { completed_at, .. } => *completed_at > cutoff,
452 NegotiationState::Failed { failed_at, .. } => *failed_at > cutoff,
453 NegotiationState::TimedOut { timeout_at, .. } => *timeout_at > cutoff,
454 _ => true, };
456
457 if !should_retain {
458 debug!("Cleaned up old negotiation session: {:?}", id);
459 }
460
461 should_retain
462 });
463 }
464
465 pub fn get_stats(&self) -> NegotiationStats {
467 self.stats.lock().unwrap().clone()
468 }
469
470 pub fn clear_cache(&self) {
472 let mut cache = self.cache.lock().unwrap();
473 cache.clear();
474 debug!("Cleared certificate type negotiation cache");
475 }
476
477 pub fn get_cache_stats(&self) -> (usize, usize) {
479 let cache = self.cache.lock().unwrap();
480 (cache.len(), self.config.max_cache_size)
481 }
482}
483
484impl Default for CertificateNegotiationManager {
485 fn default() -> Self {
486 Self::new(NegotiationConfig::default())
487 }
488}
489
/// Unit tests for negotiation ids, state queries, and the manager's main flows
/// (basic completion, result caching, and timeout handling).
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;

    /// Consecutive ids differ and are non-zero (the counter starts at 1).
    #[test]
    fn test_negotiation_id_generation() {
        let id1 = NegotiationId::new();
        let id2 = NegotiationId::new();

        assert_ne!(id1, id2);
        assert!(id1.as_u64() > 0);
        assert!(id2.as_u64() > 0);
    }

    /// `Pending` is non-terminal; `Completed` and `Failed` are terminal, and only
    /// `Completed` counts as successful.
    #[test]
    fn test_negotiation_state_checks() {
        let pending = NegotiationState::Pending;
        assert!(!pending.is_complete());
        assert!(!pending.is_successful());

        let completed = NegotiationState::Completed {
            result: NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509),
            completed_at: Instant::now(),
        };
        assert!(completed.is_complete());
        assert!(completed.is_successful());
        assert!(completed.get_result().is_some());

        let failed = NegotiationState::Failed {
            error: "Test error".to_string(),
            failed_at: Instant::now(),
        };
        assert!(failed.is_complete());
        assert!(!failed.is_successful());
        assert_eq!(failed.get_error().unwrap(), "Test error");
    }

    /// start -> Waiting -> complete with a raw-public-key-only peer -> Completed.
    #[test]
    fn test_negotiation_manager_basic_flow() {
        let manager = CertificateNegotiationManager::default();
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        let id = manager.start_negotiation(preferences);

        // A freshly started session waits for the peer's lists.
        let state = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(state, NegotiationState::Waiting { .. }));

        // Same list used for both the client and the server side.
        let remote_types = CertificateTypeList::raw_public_key_only();
        let result = manager
            .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        assert_eq!(result.client_cert_type, CertificateType::RawPublicKey);
        assert_eq!(result.server_cert_type, CertificateType::RawPublicKey);

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(state.is_successful());
    }

    /// Two negotiations with identical inputs: the first misses the cache and
    /// populates it, the second is served from it. The stats assertions depend on
    /// this exact call order.
    #[test]
    fn test_negotiation_caching() {
        let config = NegotiationConfig {
            enable_caching: true,
            ..Default::default()
        };
        let manager = CertificateNegotiationManager::new(config);
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        // First completion: cache miss, result computed and cached.
        let id1 = manager.start_negotiation(preferences.clone());
        let remote_types = CertificateTypeList::raw_public_key_only();
        let result1 = manager
            .complete_negotiation(id1, Some(remote_types.clone()), Some(remote_types.clone()))
            .unwrap();

        // Second completion with the same inputs: expected cache hit.
        let id2 = manager.start_negotiation(preferences);
        let result2 = manager
            .complete_negotiation(id2, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        assert_eq!(result1, result2);

        let stats = manager.get_stats();
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
    }

    /// A `Waiting` session older than the (1 ms) timeout becomes `TimedOut`.
    #[test]
    fn test_negotiation_timeout_handling() {
        let config = NegotiationConfig {
            timeout: Duration::from_millis(1),
            ..Default::default()
        };
        let manager = CertificateNegotiationManager::new(config);
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        let id = manager.start_negotiation(preferences);

        // Sleep well past the 1 ms timeout before sweeping.
        std::thread::sleep(Duration::from_millis(10));
        manager.handle_timeouts();

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(state, NegotiationState::TimedOut { .. }));

        let stats = manager.get_stats();
        assert_eq!(stats.timed_out, 1);
    }
}