1use std::{
8 collections::HashMap,
9 sync::{Arc, Mutex, RwLock},
10 time::{Duration, Instant},
11 hash::{Hash, Hasher},
12};
13
14use tracing::{debug, info, warn, span, Level};
15
16use super::tls_extensions::{
17 CertificateTypeList, CertificateTypePreferences,
18 NegotiationResult, TlsExtensionError,
19};
20
/// State of a single certificate type negotiation session.
///
/// `Completed`, `Failed` and `TimedOut` are the terminal states; each records the
/// instant at which it was entered.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NegotiationState {
    /// State that carries no data.
    ///
    /// NOTE(review): nothing in this file constructs it outside the tests —
    /// presumably reserved for a phase before our preferences are sent; confirm.
    Pending,
    /// Our preferences have been recorded and no outcome is known yet.
    Waiting {
        /// Instant at which the session was started; compared against the
        /// configured timeout when timeouts are processed.
        sent_at: Instant,
        /// The local preferences the session was started with.
        our_preferences: CertificateTypePreferences,
    },
    /// The negotiation produced a result.
    Completed {
        /// The negotiated certificate types.
        result: NegotiationResult,
        /// Instant at which the result was recorded.
        completed_at: Instant,
    },
    /// The negotiation failed.
    Failed {
        /// Human-readable description of the failure.
        error: String,
        /// Instant at which the failure was recorded.
        failed_at: Instant,
    },
    /// The session stayed in `Waiting` for longer than the configured timeout.
    TimedOut {
        /// Instant at which the timeout was detected (not the deadline itself).
        timeout_at: Instant,
    },
}
46
47impl NegotiationState {
48 pub fn is_complete(&self) -> bool {
50 matches!(self,
51 NegotiationState::Completed { .. } |
52 NegotiationState::Failed { .. } |
53 NegotiationState::TimedOut { .. }
54 )
55 }
56
57 pub fn is_successful(&self) -> bool {
59 matches!(self, NegotiationState::Completed { .. })
60 }
61
62 pub fn get_result(&self) -> Option<&NegotiationResult> {
64 match self {
65 NegotiationState::Completed { result, .. } => Some(result),
66 _ => None,
67 }
68 }
69
70 pub fn get_error(&self) -> Option<&str> {
72 match self {
73 NegotiationState::Failed { error, .. } => Some(error),
74 _ => None,
75 }
76 }
77}
78
/// Settings for a `CertificateNegotiationManager`.
#[derive(Debug, Clone)]
pub struct NegotiationConfig {
    /// How long a session may stay in `Waiting` before `handle_timeouts`
    /// marks it as `TimedOut`.
    pub timeout: Duration,
    /// When `true`, `complete_negotiation` consults and fills the result cache.
    pub enable_caching: bool,
    /// Number of cached results at which the oldest entries are evicted
    /// before a new result is stored.
    pub max_cache_size: usize,
    /// NOTE(review): not read anywhere in this file; presumably consumed by
    /// callers to decide whether fallback is permitted — confirm.
    pub allow_fallback: bool,
    /// NOTE(review): not read anywhere in this file; presumably the preferences
    /// callers use when they have none of their own — confirm.
    pub default_preferences: CertificateTypePreferences,
}
93
94impl Default for NegotiationConfig {
95 fn default() -> Self {
96 Self {
97 timeout: Duration::from_secs(10),
98 enable_caching: true,
99 max_cache_size: 1000,
100 allow_fallback: true,
101 default_preferences: CertificateTypePreferences::prefer_raw_public_key(),
102 }
103 }
104}
105
/// Identifier of a negotiation session, wrapping a `u64`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NegotiationId(u64);
109
110impl NegotiationId {
111 pub fn new() -> Self {
113 use std::sync::atomic::{AtomicU64, Ordering};
114 static COUNTER: AtomicU64 = AtomicU64::new(1);
115 Self(COUNTER.fetch_add(1, Ordering::Relaxed))
116 }
117
118 pub fn as_u64(self) -> u64 {
120 self.0
121 }
122}
123
/// Key of the negotiation result cache.
///
/// Both fields hold hex-formatted 64-bit hashes rather than the hashed values
/// themselves.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Hex-formatted hash of the local preferences.
    local_preferences: String,
    /// Hex-formatted hash of the remote certificate type lists.
    remote_preferences: String,
}
132
133impl CacheKey {
134 fn new(
136 local: &CertificateTypePreferences,
137 remote_client: Option<&CertificateTypeList>,
138 remote_server: Option<&CertificateTypeList>,
139 ) -> Self {
140 use std::collections::hash_map::DefaultHasher;
141
142 let mut hasher = DefaultHasher::new();
143 local.hash(&mut hasher);
144 let local_hash = hasher.finish();
145
146 let mut hasher = DefaultHasher::new();
147 if let Some(types) = remote_client {
148 types.hash(&mut hasher);
149 }
150 if let Some(types) = remote_server {
151 types.hash(&mut hasher);
152 }
153 let remote_hash = hasher.finish();
154
155 Self {
156 local_preferences: format!("{:x}", local_hash),
157 remote_preferences: format!("{:x}", remote_hash),
158 }
159 }
160}
161
162impl Hash for CertificateTypePreferences {
164 fn hash<H: Hasher>(&self, state: &mut H) {
165 self.client_types.types.hash(state);
166 self.server_types.types.hash(state);
167 self.require_extensions.hash(state);
168 self.fallback_client.hash(state);
169 self.fallback_server.hash(state);
170 }
171}
172
173impl Hash for CertificateTypeList {
175 fn hash<H: Hasher>(&self, state: &mut H) {
176 self.types.hash(state);
177 }
178}
179
/// Tracks certificate type negotiation sessions together with a cache of
/// negotiation results and running counters.
pub struct CertificateNegotiationManager {
    /// Settings supplied at construction.
    config: NegotiationConfig,
    /// Every known session keyed by its id, behind a read/write lock.
    sessions: RwLock<HashMap<NegotiationId, NegotiationState>>,
    /// Negotiation results keyed by `CacheKey`, each paired with the instant
    /// at which it was stored.
    cache: Arc<Mutex<HashMap<CacheKey, (NegotiationResult, Instant)>>>,
    /// Running counters; `get_stats` hands out a clone of them.
    stats: Arc<Mutex<NegotiationStats>>,
}
191
/// Counters maintained by `CertificateNegotiationManager`.
#[derive(Debug, Default, Clone)]
pub struct NegotiationStats {
    /// Number of sessions started via `start_negotiation`.
    pub total_attempts: u64,
    /// Number of sessions for which `complete_negotiation` ran the negotiation
    /// and obtained a result (cache hits are counted in `cache_hits` instead).
    pub successful: u64,
    /// Number of failures recorded by `complete_negotiation` or `fail_negotiation`.
    pub failed: u64,
    /// Number of sessions that `handle_timeouts` moved to `TimedOut`.
    pub timed_out: u64,
    /// Number of completions served from the result cache.
    pub cache_hits: u64,
    /// Number of completions, with caching enabled, that found no usable cache entry.
    pub cache_misses: u64,
    /// Running average of the time spent negotiating, updated on each
    /// non-cached success.
    pub avg_negotiation_time: Duration,
}
210
211impl CertificateNegotiationManager {
212 pub fn new(config: NegotiationConfig) -> Self {
214 Self {
215 config,
216 sessions: RwLock::new(HashMap::new()),
217 cache: Arc::new(Mutex::new(HashMap::new())),
218 stats: Arc::new(Mutex::new(NegotiationStats::default())),
219 }
220 }
221
222 pub fn start_negotiation(
224 &self,
225 preferences: CertificateTypePreferences,
226 ) -> NegotiationId {
227 let id = NegotiationId::new();
228 let state = NegotiationState::Waiting {
229 sent_at: Instant::now(),
230 our_preferences: preferences,
231 };
232
233 let mut sessions = self.sessions.write().unwrap();
234 sessions.insert(id, state);
235
236 let mut stats = self.stats.lock().unwrap();
237 stats.total_attempts += 1;
238
239 debug!("Started certificate type negotiation: {:?}", id);
240 id
241 }
242
243 pub fn complete_negotiation(
245 &self,
246 id: NegotiationId,
247 remote_client_types: Option<CertificateTypeList>,
248 remote_server_types: Option<CertificateTypeList>,
249 ) -> Result<NegotiationResult, TlsExtensionError> {
250 let _span = span!(Level::DEBUG, "complete_negotiation", id = id.as_u64()).entered();
251
252 let mut sessions = self.sessions.write().unwrap();
253 let state = sessions.get(&id).ok_or_else(|| {
254 TlsExtensionError::InvalidExtensionData(format!("Unknown negotiation ID: {:?}", id))
255 })?;
256
257 let our_preferences = match state {
258 NegotiationState::Waiting { our_preferences, .. } => our_preferences.clone(),
259 _ => {
260 return Err(TlsExtensionError::InvalidExtensionData(
261 "Negotiation not in waiting state".to_string()
262 ));
263 }
264 };
265
266 if self.config.enable_caching {
268 let cache_key = CacheKey::new(
269 &our_preferences,
270 remote_client_types.as_ref(),
271 remote_server_types.as_ref(),
272 );
273
274 let mut cache = self.cache.lock().unwrap();
275 if let Some((cached_result, cached_at)) = cache.get(&cache_key) {
276 if cached_at.elapsed() < Duration::from_secs(300) { let mut stats = self.stats.lock().unwrap();
279 stats.cache_hits += 1;
280
281 sessions.insert(id, NegotiationState::Completed {
283 result: cached_result.clone(),
284 completed_at: Instant::now(),
285 });
286
287 debug!("Cache hit for negotiation: {:?}", id);
288 return Ok(cached_result.clone());
289 } else {
290 cache.remove(&cache_key);
292 }
293 }
294
295 let mut stats = self.stats.lock().unwrap();
296 stats.cache_misses += 1;
297 }
298
299 let negotiation_start = Instant::now();
301 let result = our_preferences.negotiate(
302 remote_client_types.as_ref(),
303 remote_server_types.as_ref(),
304 );
305
306 match result {
307 Ok(negotiation_result) => {
308 let completed_at = Instant::now();
309 let negotiation_time = negotiation_start.elapsed();
310
311 sessions.insert(id, NegotiationState::Completed {
313 result: negotiation_result.clone(),
314 completed_at,
315 });
316
317 let mut stats = self.stats.lock().unwrap();
319 stats.successful += 1;
320
321 let total_completed = stats.successful + stats.failed;
323 stats.avg_negotiation_time = if total_completed == 1 {
324 negotiation_time
325 } else {
326 Duration::from_nanos(
327 (stats.avg_negotiation_time.as_nanos() as u64 * (total_completed - 1) +
328 negotiation_time.as_nanos() as u64) / total_completed
329 )
330 };
331
332 if self.config.enable_caching {
334 let cache_key = CacheKey::new(
335 &our_preferences,
336 remote_client_types.as_ref(),
337 remote_server_types.as_ref(),
338 );
339
340 let mut cache = self.cache.lock().unwrap();
341
342 if cache.len() >= self.config.max_cache_size {
344 let mut entries: Vec<_> = cache.iter()
346 .map(|(k, (_, t))| (k.clone(), t.clone()))
347 .collect();
348 entries.sort_by_key(|(_, timestamp)| *timestamp);
349
350 let to_remove = cache.len() - self.config.max_cache_size + 1;
351 let keys_to_remove: Vec<_> = entries.iter()
352 .take(to_remove)
353 .map(|(key, _)| key.clone())
354 .collect();
355
356 for key in keys_to_remove {
357 cache.remove(&key);
358 }
359 }
360
361 cache.insert(cache_key, (negotiation_result.clone(), completed_at));
362 }
363
364 info!("Certificate type negotiation completed successfully: {:?} -> client={}, server={}",
365 id, negotiation_result.client_cert_type, negotiation_result.server_cert_type);
366
367 Ok(negotiation_result)
368 }
369 Err(error) => {
370 sessions.insert(id, NegotiationState::Failed {
372 error: error.to_string(),
373 failed_at: Instant::now(),
374 });
375
376 let mut stats = self.stats.lock().unwrap();
378 stats.failed += 1;
379
380 warn!("Certificate type negotiation failed: {:?} -> {}", id, error);
381 Err(error)
382 }
383 }
384 }
385
386 pub fn fail_negotiation(&self, id: NegotiationId, error: String) {
388 let mut sessions = self.sessions.write().unwrap();
389 sessions.insert(id, NegotiationState::Failed {
390 error,
391 failed_at: Instant::now(),
392 });
393
394 let mut stats = self.stats.lock().unwrap();
395 stats.failed += 1;
396
397 warn!("Certificate type negotiation failed: {:?}", id);
398 }
399
400 pub fn get_negotiation_state(&self, id: NegotiationId) -> Option<NegotiationState> {
402 let sessions = self.sessions.read().unwrap();
403 sessions.get(&id).cloned()
404 }
405
406 pub fn handle_timeouts(&self) {
408 let mut sessions = self.sessions.write().unwrap();
409 let mut timed_out_ids = Vec::new();
410
411 for (id, state) in sessions.iter() {
412 if let NegotiationState::Waiting { sent_at, .. } = state {
413 if sent_at.elapsed() > self.config.timeout {
414 timed_out_ids.push(*id);
415 }
416 }
417 }
418
419 for id in timed_out_ids {
420 sessions.insert(id, NegotiationState::TimedOut {
421 timeout_at: Instant::now(),
422 });
423
424 let mut stats = self.stats.lock().unwrap();
425 stats.timed_out += 1;
426
427 warn!("Certificate type negotiation timed out: {:?}", id);
428 }
429 }
430
431 pub fn cleanup_old_sessions(&self, max_age: Duration) {
433 let mut sessions = self.sessions.write().unwrap();
434 let cutoff = Instant::now() - max_age;
435
436 sessions.retain(|id, state| {
437 let should_retain = match state {
438 NegotiationState::Completed { completed_at, .. } => *completed_at > cutoff,
439 NegotiationState::Failed { failed_at, .. } => *failed_at > cutoff,
440 NegotiationState::TimedOut { timeout_at, .. } => *timeout_at > cutoff,
441 _ => true, };
443
444 if !should_retain {
445 debug!("Cleaned up old negotiation session: {:?}", id);
446 }
447
448 should_retain
449 });
450 }
451
452 pub fn get_stats(&self) -> NegotiationStats {
454 self.stats.lock().unwrap().clone()
455 }
456
457 pub fn clear_cache(&self) {
459 let mut cache = self.cache.lock().unwrap();
460 cache.clear();
461 debug!("Cleared certificate type negotiation cache");
462 }
463
464 pub fn get_cache_stats(&self) -> (usize, usize) {
466 let cache = self.cache.lock().unwrap();
467 (cache.len(), self.config.max_cache_size)
468 }
469}
470
471impl Default for CertificateNegotiationManager {
472 fn default() -> Self {
473 Self::new(NegotiationConfig::default())
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;

    /// Consecutively allocated ids are distinct and never zero.
    #[test]
    fn test_negotiation_id_generation() {
        let first = NegotiationId::new();
        let second = NegotiationId::new();

        assert!(first.as_u64() > 0);
        assert!(second.as_u64() > 0);
        assert_ne!(first, second);
    }

    /// The state predicates and accessors agree with the variant.
    #[test]
    fn test_negotiation_state_checks() {
        // A state without an outcome is neither complete nor successful.
        let idle = NegotiationState::Pending;
        assert!(!idle.is_complete());
        assert!(!idle.is_successful());

        // A completed state is complete, successful and exposes its result.
        let done = NegotiationState::Completed {
            result: NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509),
            completed_at: Instant::now(),
        };
        assert!(done.is_complete());
        assert!(done.is_successful());
        assert!(done.get_result().is_some());

        // A failed state is complete, unsuccessful and exposes its error text.
        let broken = NegotiationState::Failed {
            error: "Test error".to_string(),
            failed_at: Instant::now(),
        };
        assert!(broken.is_complete());
        assert!(!broken.is_successful());
        assert_eq!(broken.get_error().unwrap(), "Test error");
    }

    /// Start -> waiting -> complete, with both sides offering raw public keys.
    #[test]
    fn test_negotiation_manager_basic_flow() {
        let manager = CertificateNegotiationManager::default();
        let local_prefs = CertificateTypePreferences::prefer_raw_public_key();

        let session = manager.start_negotiation(local_prefs);

        let before = manager.get_negotiation_state(session).unwrap();
        assert!(matches!(before, NegotiationState::Waiting { .. }));

        let peer_types = CertificateTypeList::raw_public_key_only();
        let outcome = manager
            .complete_negotiation(session, Some(peer_types.clone()), Some(peer_types))
            .unwrap();

        assert_eq!(outcome.client_cert_type, CertificateType::RawPublicKey);
        assert_eq!(outcome.server_cert_type, CertificateType::RawPublicKey);

        let after = manager.get_negotiation_state(session).unwrap();
        assert!(after.is_successful());
    }

    /// A second negotiation with identical inputs is served from the cache.
    #[test]
    fn test_negotiation_caching() {
        let manager = CertificateNegotiationManager::new(NegotiationConfig {
            enable_caching: true,
            ..Default::default()
        });
        let local_prefs = CertificateTypePreferences::prefer_raw_public_key();

        // First run: nothing cached yet, so this is a miss.
        let first_session = manager.start_negotiation(local_prefs.clone());
        let peer_types = CertificateTypeList::raw_public_key_only();
        let first_outcome = manager
            .complete_negotiation(
                first_session,
                Some(peer_types.clone()),
                Some(peer_types.clone()),
            )
            .unwrap();

        // Second run with the same inputs: expected to hit the cache.
        let second_session = manager.start_negotiation(local_prefs);
        let second_outcome = manager
            .complete_negotiation(second_session, Some(peer_types.clone()), Some(peer_types))
            .unwrap();

        assert_eq!(first_outcome, second_outcome);

        let counters = manager.get_stats();
        assert_eq!(counters.cache_hits, 1);
        assert_eq!(counters.cache_misses, 1);
    }

    /// A waiting session older than the timeout is moved to `TimedOut`.
    #[test]
    fn test_negotiation_timeout_handling() {
        let manager = CertificateNegotiationManager::new(NegotiationConfig {
            timeout: Duration::from_millis(1),
            ..Default::default()
        });
        let local_prefs = CertificateTypePreferences::prefer_raw_public_key();

        let session = manager.start_negotiation(local_prefs);

        // Wait well past the 1 ms timeout before sweeping.
        std::thread::sleep(Duration::from_millis(10));
        manager.handle_timeouts();

        let state = manager.get_negotiation_state(session).unwrap();
        assert!(matches!(state, NegotiationState::TimedOut { .. }));

        let counters = manager.get_stats();
        assert_eq!(counters.timed_out, 1);
    }
}