// ant_quic/crypto/certificate_negotiation.rs

//! Certificate Type Negotiation Protocol Implementation
//!
//! This module implements certificate type negotiation as defined in RFC 7250,
//! including per-session state management, result caching, and integration
//! with both the client and server sides of TLS connections.
6
7use std::{
8    collections::HashMap,
9    hash::{Hash, Hasher},
10    sync::{Arc, Mutex, RwLock},
11    time::{Duration, Instant},
12};
13
14use tracing::{Level, debug, info, span, warn};
15
16use super::tls_extensions::{
17    CertificateTypeList, CertificateTypePreferences, NegotiationResult, TlsExtensionError,
18};
19
20/// Negotiation state for a single TLS connection
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub enum NegotiationState {
23    /// Negotiation not yet started
24    Pending,
25    /// Extensions sent, waiting for response
26    Waiting {
27        sent_at: Instant,
28        our_preferences: CertificateTypePreferences,
29    },
30    /// Negotiation completed successfully
31    Completed {
32        result: NegotiationResult,
33        completed_at: Instant,
34    },
35    /// Negotiation failed
36    Failed {
37        /// The error message
38        error: String,
39        /// When the failure occurred
40        failed_at: Instant,
41    },
42    /// Timed out waiting for response
43    TimedOut {
44        /// When the timeout occurred
45        timeout_at: Instant,
46    },
47}
48
49impl NegotiationState {
50    /// Check if negotiation is complete (either succeeded or failed)
51    pub fn is_complete(&self) -> bool {
52        matches!(
53            self,
54            Self::Completed { .. } | Self::Failed { .. } | Self::TimedOut { .. }
55        )
56    }
57
58    /// Check if negotiation succeeded
59    pub fn is_successful(&self) -> bool {
60        matches!(self, Self::Completed { .. })
61    }
62
63    /// Get the negotiation result if successful
64    pub fn get_result(&self) -> Option<&NegotiationResult> {
65        match self {
66            Self::Completed { result, .. } => Some(result),
67            _ => None,
68        }
69    }
70
71    /// Get error message if failed
72    pub fn get_error(&self) -> Option<&str> {
73        match self {
74            Self::Failed { error, .. } => Some(error),
75            _ => None,
76        }
77    }
78}
79
80/// Configuration for certificate type negotiation
81#[derive(Debug, Clone)]
82pub struct NegotiationConfig {
83    /// Timeout for waiting for negotiation response
84    pub timeout: Duration,
85    /// Whether to cache negotiation results
86    pub enable_caching: bool,
87    /// Maximum cache size
88    pub max_cache_size: usize,
89    /// Whether to allow fallback to X.509 if RPK negotiation fails
90    pub allow_fallback: bool,
91    /// Default preferences if none specified
92    pub default_preferences: CertificateTypePreferences,
93}
94
95impl Default for NegotiationConfig {
96    fn default() -> Self {
97        Self {
98            timeout: Duration::from_secs(10),
99            enable_caching: true,
100            max_cache_size: 1000,
101            allow_fallback: true,
102            default_preferences: CertificateTypePreferences::prefer_raw_public_key(),
103        }
104    }
105}
106
/// Process-wide unique handle for one negotiation session.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NegotiationId(u64);

impl Default for NegotiationId {
    /// Same as [`NegotiationId::new`]: every default value is a fresh ID.
    fn default() -> Self {
        NegotiationId::new()
    }
}

impl NegotiationId {
    /// Allocates the next ID from a global counter that starts at 1.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};

        // Only uniqueness is required, so relaxed ordering is sufficient.
        static NEXT_ID: AtomicU64 = AtomicU64::new(1);
        let raw = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        NegotiationId(raw)
    }

    /// Returns the numeric value behind this ID.
    pub fn as_u64(self) -> u64 {
        let NegotiationId(raw) = self;
        raw
    }
}
130
131/// Cache key for negotiation results
132#[derive(Debug, Clone, PartialEq, Eq, Hash)]
133struct CacheKey {
134    /// Our certificate type preferences
135    local_preferences: String, // Serialized preferences for hashing
136    /// Remote certificate type preferences  
137    remote_preferences: String, // Serialized preferences for hashing
138}
139
140impl CacheKey {
141    /// Create a cache key from preferences
142    fn new(
143        local: &CertificateTypePreferences,
144        remote_client: Option<&CertificateTypeList>,
145        remote_server: Option<&CertificateTypeList>,
146    ) -> Self {
147        use std::collections::hash_map::DefaultHasher;
148
149        let mut hasher = DefaultHasher::new();
150        local.hash(&mut hasher);
151        let local_hash = hasher.finish();
152
153        let mut hasher = DefaultHasher::new();
154        if let Some(types) = remote_client {
155            types.hash(&mut hasher);
156        }
157        if let Some(types) = remote_server {
158            types.hash(&mut hasher);
159        }
160        let remote_hash = hasher.finish();
161
162        Self {
163            local_preferences: format!("{local_hash:x}"),
164            remote_preferences: format!("{remote_hash:x}"),
165        }
166    }
167}
168
169/// Hash implementation for CertificateTypePreferences
170impl Hash for CertificateTypePreferences {
171    fn hash<H: Hasher>(&self, state: &mut H) {
172        self.client_types.types.hash(state);
173        self.server_types.types.hash(state);
174        self.require_extensions.hash(state);
175        self.fallback_client.hash(state);
176        self.fallback_server.hash(state);
177    }
178}
179
180/// Hash implementation for CertificateTypeList  
181impl Hash for CertificateTypeList {
182    fn hash<H: Hasher>(&self, state: &mut H) {
183        self.types.hash(state);
184    }
185}
186
187/// Certificate type negotiation manager
188pub struct CertificateNegotiationManager {
189    /// Configuration for negotiation behavior
190    config: NegotiationConfig,
191    /// Active negotiation sessions
192    sessions: RwLock<HashMap<NegotiationId, NegotiationState>>,
193    /// Result cache for performance optimization
194    cache: Arc<Mutex<HashMap<CacheKey, (NegotiationResult, Instant)>>>,
195    /// Negotiation statistics
196    stats: Arc<Mutex<NegotiationStats>>,
197}
198
/// Counters describing how certificate type negotiations have gone so far.
#[derive(Debug, Default, Clone)]
pub struct NegotiationStats {
    /// Negotiations started.
    pub total_attempts: u64,
    /// Negotiations that produced a result.
    pub successful: u64,
    /// Negotiations that ended in an error.
    pub failed: u64,
    /// Negotiations abandoned because no answer arrived in time.
    pub timed_out: u64,
    /// Completions served from the result cache.
    pub cache_hits: u64,
    /// Completions that had to compute a fresh result while caching was on.
    pub cache_misses: u64,
    /// Average time spent computing successful, non-cached negotiations.
    pub avg_negotiation_time: Duration,
}
217
218impl CertificateNegotiationManager {
219    /// Create a new negotiation manager
220    pub fn new(config: NegotiationConfig) -> Self {
221        Self {
222            config,
223            sessions: RwLock::new(HashMap::new()),
224            cache: Arc::new(Mutex::new(HashMap::new())),
225            stats: Arc::new(Mutex::new(NegotiationStats::default())),
226        }
227    }
228
229    /// Start a new certificate type negotiation
230    pub fn start_negotiation(
231        &self,
232        preferences: CertificateTypePreferences,
233    ) -> Result<NegotiationId, TlsExtensionError> {
234        let id = NegotiationId::new();
235        let state = NegotiationState::Waiting {
236            sent_at: Instant::now(),
237            our_preferences: preferences,
238        };
239
240        let mut sessions = self.sessions.write().map_err(|e| {
241            TlsExtensionError::InvalidExtensionData(format!("Session lock poisoned: {}", e))
242        })?;
243        sessions.insert(id, state);
244
245        let mut stats = self.stats.lock().map_err(|e| {
246            TlsExtensionError::InvalidExtensionData(format!("Stats lock poisoned: {}", e))
247        })?;
248        stats.total_attempts += 1;
249
250        debug!("Started certificate type negotiation: {:?}", id);
251        Ok(id)
252    }
253
254    /// Complete a negotiation with remote preferences
255    pub fn complete_negotiation(
256        &self,
257        id: NegotiationId,
258        remote_client_types: Option<CertificateTypeList>,
259        remote_server_types: Option<CertificateTypeList>,
260    ) -> Result<NegotiationResult, TlsExtensionError> {
261        let _span = span!(Level::DEBUG, "complete_negotiation", id = id.as_u64()).entered();
262
263        let mut sessions = self.sessions.write().map_err(|e| {
264            TlsExtensionError::InvalidExtensionData(format!("Session lock poisoned: {}", e))
265        })?;
266        let state = sessions.get(&id).ok_or_else(|| {
267            TlsExtensionError::InvalidExtensionData(format!("Unknown negotiation ID: {id:?}"))
268        })?;
269
270        let our_preferences = match state {
271            NegotiationState::Waiting {
272                our_preferences, ..
273            } => our_preferences.clone(),
274            _ => {
275                return Err(TlsExtensionError::InvalidExtensionData(
276                    "Negotiation not in waiting state".to_string(),
277                ));
278            }
279        };
280
281        // Check cache first if enabled
282        if self.config.enable_caching {
283            let cache_key = CacheKey::new(
284                &our_preferences,
285                remote_client_types.as_ref(),
286                remote_server_types.as_ref(),
287            );
288
289            let mut cache = self.cache.lock().map_err(|e| {
290                TlsExtensionError::InvalidExtensionData(format!("Cache lock poisoned: {}", e))
291            })?;
292            if let Some((cached_result, cached_at)) = cache.get(&cache_key) {
293                // Check if cache entry is still valid (not expired)
294                if cached_at.elapsed() < Duration::from_secs(300) {
295                    // 5 minute cache
296                    let mut stats = self.stats.lock().map_err(|e| {
297                        TlsExtensionError::InvalidExtensionData(format!(
298                            "Stats lock poisoned: {}",
299                            e
300                        ))
301                    })?;
302                    stats.cache_hits += 1;
303
304                    // Update session state
305                    sessions.insert(
306                        id,
307                        NegotiationState::Completed {
308                            result: cached_result.clone(),
309                            completed_at: Instant::now(),
310                        },
311                    );
312
313                    debug!("Cache hit for negotiation: {:?}", id);
314                    return Ok(cached_result.clone());
315                } else {
316                    // Remove expired entry
317                    cache.remove(&cache_key);
318                }
319            }
320
321            let mut stats = self.stats.lock().map_err(|e| {
322                TlsExtensionError::InvalidExtensionData(format!("Stats lock poisoned: {}", e))
323            })?;
324            stats.cache_misses += 1;
325        }
326
327        // Perform actual negotiation
328        let negotiation_start = Instant::now();
329        let result =
330            our_preferences.negotiate(remote_client_types.as_ref(), remote_server_types.as_ref());
331
332        match result {
333            Ok(negotiation_result) => {
334                let completed_at = Instant::now();
335                let negotiation_time = negotiation_start.elapsed();
336
337                // Update session state
338                sessions.insert(
339                    id,
340                    NegotiationState::Completed {
341                        result: negotiation_result.clone(),
342                        completed_at,
343                    },
344                );
345
346                // Update statistics
347                let mut stats = self.stats.lock().map_err(|e| {
348                    TlsExtensionError::InvalidExtensionData(format!("Stats lock poisoned: {}", e))
349                })?;
350                stats.successful += 1;
351
352                // Update average negotiation time (simple moving average)
353                let total_completed = stats.successful + stats.failed;
354                stats.avg_negotiation_time = if total_completed == 1 {
355                    negotiation_time
356                } else {
357                    Duration::from_nanos(
358                        (stats.avg_negotiation_time.as_nanos() as u64 * (total_completed - 1)
359                            + negotiation_time.as_nanos() as u64)
360                            / total_completed,
361                    )
362                };
363
364                // Cache the result if caching is enabled
365                if self.config.enable_caching {
366                    let cache_key = CacheKey::new(
367                        &our_preferences,
368                        remote_client_types.as_ref(),
369                        remote_server_types.as_ref(),
370                    );
371
372                    let mut cache = self.cache.lock().map_err(|e| {
373                        TlsExtensionError::InvalidExtensionData(format!(
374                            "Cache lock poisoned: {}",
375                            e
376                        ))
377                    })?;
378
379                    // Evict old entries if cache is full
380                    if cache.len() >= self.config.max_cache_size {
381                        // Simple eviction: remove oldest entries
382                        let mut entries: Vec<_> =
383                            cache.iter().map(|(k, (_, t))| (k.clone(), *t)).collect();
384                        entries.sort_by_key(|(_, timestamp)| *timestamp);
385
386                        let to_remove = cache.len() - self.config.max_cache_size + 1;
387                        let keys_to_remove: Vec<_> = entries
388                            .iter()
389                            .take(to_remove)
390                            .map(|(key, _)| key.clone())
391                            .collect();
392
393                        for key in keys_to_remove {
394                            cache.remove(&key);
395                        }
396                    }
397
398                    cache.insert(cache_key, (negotiation_result.clone(), completed_at));
399                }
400
401                info!(
402                    "Certificate type negotiation completed successfully: {:?} -> client={}, server={}",
403                    id, negotiation_result.client_cert_type, negotiation_result.server_cert_type
404                );
405
406                Ok(negotiation_result)
407            }
408            Err(error) => {
409                // Update session state
410                sessions.insert(
411                    id,
412                    NegotiationState::Failed {
413                        error: error.to_string(),
414                        failed_at: Instant::now(),
415                    },
416                );
417
418                // Update statistics
419                let mut stats = self.stats.lock().map_err(|e| {
420                    TlsExtensionError::InvalidExtensionData(format!("Stats lock poisoned: {}", e))
421                })?;
422                stats.failed += 1;
423
424                warn!("Certificate type negotiation failed: {:?} -> {}", id, error);
425                Err(error)
426            }
427        }
428    }
429
430    /// Fail a negotiation with an error
431    pub fn fail_negotiation(&self, id: NegotiationId, error: String) {
432        let mut sessions = self
433            .sessions
434            .write()
435            .expect("Session lock should not be poisoned");
436        sessions.insert(
437            id,
438            NegotiationState::Failed {
439                error,
440                failed_at: Instant::now(),
441            },
442        );
443
444        let mut stats = self
445            .stats
446            .lock()
447            .expect("Stats lock should not be poisoned");
448        stats.failed += 1;
449
450        warn!("Certificate type negotiation failed: {:?}", id);
451    }
452
453    /// Get the current state of a negotiation
454    pub fn get_negotiation_state(&self, id: NegotiationId) -> Option<NegotiationState> {
455        let sessions = self
456            .sessions
457            .read()
458            .expect("Session lock should not be poisoned");
459        sessions.get(&id).cloned()
460    }
461
462    /// Check for and handle timed out negotiations
463    pub fn handle_timeouts(&self) {
464        let mut sessions = self
465            .sessions
466            .write()
467            .expect("Session lock should not be poisoned");
468        let mut timed_out_ids = Vec::new();
469
470        for (id, state) in sessions.iter() {
471            if let NegotiationState::Waiting { sent_at, .. } = state {
472                if sent_at.elapsed() > self.config.timeout {
473                    timed_out_ids.push(*id);
474                }
475            }
476        }
477
478        for id in timed_out_ids {
479            sessions.insert(
480                id,
481                NegotiationState::TimedOut {
482                    timeout_at: Instant::now(),
483                },
484            );
485
486            let mut stats = self
487                .stats
488                .lock()
489                .expect("Stats lock should not be poisoned");
490            stats.timed_out += 1;
491
492            warn!("Certificate type negotiation timed out: {:?}", id);
493        }
494    }
495
496    /// Clean up completed negotiations older than the specified duration
497    pub fn cleanup_old_sessions(&self, max_age: Duration) {
498        let mut sessions = self
499            .sessions
500            .write()
501            .expect("Session lock should not be poisoned");
502        let cutoff = Instant::now() - max_age;
503
504        sessions.retain(|id, state| {
505            let should_retain = match state {
506                NegotiationState::Completed { completed_at, .. } => *completed_at > cutoff,
507                NegotiationState::Failed { failed_at, .. } => *failed_at > cutoff,
508                NegotiationState::TimedOut { timeout_at, .. } => *timeout_at > cutoff,
509                _ => true, // Keep pending and waiting sessions
510            };
511
512            if !should_retain {
513                debug!("Cleaned up old negotiation session: {:?}", id);
514            }
515
516            should_retain
517        });
518    }
519
520    /// Get current negotiation statistics
521    pub fn get_stats(&self) -> NegotiationStats {
522        self.stats
523            .lock()
524            .expect("Stats lock should not be poisoned")
525            .clone()
526    }
527
528    /// Clear all cached results
529    pub fn clear_cache(&self) {
530        let mut cache = self
531            .cache
532            .lock()
533            .expect("Cache lock should not be poisoned");
534        cache.clear();
535        debug!("Cleared certificate type negotiation cache");
536    }
537
538    /// Get cache statistics
539    pub fn get_cache_stats(&self) -> (usize, usize) {
540        let cache = self
541            .cache
542            .lock()
543            .expect("Cache lock should not be poisoned");
544        (cache.len(), self.config.max_cache_size)
545    }
546}
547
548impl Default for CertificateNegotiationManager {
549    fn default() -> Self {
550        Self::new(NegotiationConfig::default())
551    }
552}
553
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;

    #[test]
    fn test_negotiation_id_generation() {
        let id1 = NegotiationId::new();
        let id2 = NegotiationId::new();

        assert_ne!(id1, id2);
        assert!(id1.as_u64() > 0);
        assert!(id2.as_u64() > 0);
    }

    #[test]
    fn test_negotiation_state_checks() {
        let pending = NegotiationState::Pending;
        assert!(!pending.is_complete());
        assert!(!pending.is_successful());

        let completed = NegotiationState::Completed {
            result: NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509),
            completed_at: Instant::now(),
        };
        assert!(completed.is_complete());
        assert!(completed.is_successful());
        assert!(completed.get_result().is_some());

        let failed = NegotiationState::Failed {
            error: "Test error".to_string(),
            failed_at: Instant::now(),
        };
        assert!(failed.is_complete());
        assert!(!failed.is_successful());
        assert_eq!(failed.get_error().unwrap(), "Test error");
    }

    #[test]
    fn test_negotiation_manager_basic_flow() {
        let manager = CertificateNegotiationManager::default();
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        // Start negotiation
        let id = manager.start_negotiation(preferences).unwrap();

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(state, NegotiationState::Waiting { .. }));

        // Complete negotiation
        let remote_types = CertificateTypeList::raw_public_key_only();
        let result = manager
            .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        assert_eq!(result.client_cert_type, CertificateType::RawPublicKey);
        assert_eq!(result.server_cert_type, CertificateType::RawPublicKey);

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(state.is_successful());
    }

    #[test]
    fn test_negotiation_caching() {
        let config = NegotiationConfig {
            enable_caching: true,
            ..Default::default()
        };
        let manager = CertificateNegotiationManager::new(config);
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        // First negotiation
        let id1 = manager.start_negotiation(preferences.clone()).unwrap();
        let remote_types = CertificateTypeList::raw_public_key_only();
        let result1 = manager
            .complete_negotiation(id1, Some(remote_types.clone()), Some(remote_types.clone()))
            .unwrap();

        // Second negotiation with same preferences should hit cache
        let id2 = manager.start_negotiation(preferences).unwrap();
        let result2 = manager
            .complete_negotiation(id2, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        assert_eq!(result1, result2);

        let stats = manager.get_stats();
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
    }

    #[test]
    fn test_negotiation_timeout_handling() {
        let config = NegotiationConfig {
            timeout: Duration::from_millis(1),
            ..Default::default()
        };
        let manager = CertificateNegotiationManager::new(config);
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        let id = manager.start_negotiation(preferences).unwrap();

        // Wait for timeout
        std::thread::sleep(Duration::from_millis(10));
        manager.handle_timeouts();

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(state, NegotiationState::TimedOut { .. }));

        let stats = manager.get_stats();
        assert_eq!(stats.timed_out, 1);
    }

    #[test]
    fn test_complete_negotiation_rejects_unknown_and_finished_sessions() {
        let manager = CertificateNegotiationManager::default();
        let remote_types = CertificateTypeList::raw_public_key_only();

        // An ID that was never registered with this manager.
        let unknown = NegotiationId::new();
        assert!(
            manager
                .complete_negotiation(
                    unknown,
                    Some(remote_types.clone()),
                    Some(remote_types.clone())
                )
                .is_err()
        );

        let id = manager
            .start_negotiation(CertificateTypePreferences::prefer_raw_public_key())
            .unwrap();
        manager
            .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types.clone()))
            .unwrap();

        // A session can only be completed once.
        assert!(
            manager
                .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types))
                .is_err()
        );
    }

    #[test]
    fn test_fail_negotiation_records_error() {
        let manager = CertificateNegotiationManager::default();
        let id = manager
            .start_negotiation(CertificateTypePreferences::prefer_raw_public_key())
            .unwrap();

        manager.fail_negotiation(id, "peer rejected".to_string());

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(state.is_complete());
        assert!(!state.is_successful());
        assert_eq!(state.get_error(), Some("peer rejected"));
        assert_eq!(manager.get_stats().failed, 1);
    }

    #[test]
    fn test_cleanup_removes_only_finished_sessions() {
        let manager = CertificateNegotiationManager::default();
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        let waiting = manager.start_negotiation(preferences.clone()).unwrap();
        let failed = manager.start_negotiation(preferences).unwrap();
        manager.fail_negotiation(failed, "boom".to_string());

        manager.cleanup_old_sessions(Duration::ZERO);

        assert!(manager.get_negotiation_state(waiting).is_some());
        assert!(manager.get_negotiation_state(failed).is_none());
    }

    #[test]
    fn test_caching_disabled_records_no_cache_activity() {
        let config = NegotiationConfig {
            enable_caching: false,
            ..Default::default()
        };
        let manager = CertificateNegotiationManager::new(config);
        let preferences = CertificateTypePreferences::prefer_raw_public_key();
        let remote_types = CertificateTypeList::raw_public_key_only();

        for _ in 0..2 {
            let id = manager.start_negotiation(preferences.clone()).unwrap();
            manager
                .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types.clone()))
                .unwrap();
        }

        let stats = manager.get_stats();
        assert_eq!(stats.cache_hits, 0);
        assert_eq!(stats.cache_misses, 0);
        assert_eq!(stats.successful, 2);
        assert_eq!(manager.get_cache_stats().0, 0);
    }
}