// ant_quic/crypto/certificate_negotiation.rs

//! Certificate Type Negotiation Protocol Implementation
//!
//! This module implements the certificate type negotiation protocol defined in
//! RFC 7250, including state management, caching, and integration with both
//! client and server sides of TLS connections.

use std::{
    collections::HashMap,
    hash::{Hash, Hasher},
    sync::{Arc, Mutex, RwLock},
    time::{Duration, Instant},
};

use tracing::{debug, info, span, warn, Level};

use super::tls_extensions::{
    CertificateTypeList, CertificateTypePreferences, NegotiationResult, TlsExtensionError,
};

21/// Negotiation state for a single TLS connection
22#[derive(Debug, Clone, PartialEq, Eq)]
23pub enum NegotiationState {
24    /// Negotiation not yet started
25    Pending,
26    /// Extensions sent, waiting for response
27    Waiting {
28        sent_at: Instant,
29        our_preferences: CertificateTypePreferences,
30    },
31    /// Negotiation completed successfully
32    Completed {
33        result: NegotiationResult,
34        completed_at: Instant,
35    },
36    /// Negotiation failed
37    Failed {
38        error: String,
39        failed_at: Instant,
40    },
41    /// Timed out waiting for response
42    TimedOut {
43        timeout_at: Instant,
44    },
45}
46
47impl NegotiationState {
48    /// Check if negotiation is complete (either succeeded or failed)
49    pub fn is_complete(&self) -> bool {
50        matches!(self, 
51            NegotiationState::Completed { .. } | 
52            NegotiationState::Failed { .. } | 
53            NegotiationState::TimedOut { .. }
54        )
55    }
56
57    /// Check if negotiation succeeded
58    pub fn is_successful(&self) -> bool {
59        matches!(self, NegotiationState::Completed { .. })
60    }
61
62    /// Get the negotiation result if successful
63    pub fn get_result(&self) -> Option<&NegotiationResult> {
64        match self {
65            NegotiationState::Completed { result, .. } => Some(result),
66            _ => None,
67        }
68    }
69
70    /// Get error message if failed
71    pub fn get_error(&self) -> Option<&str> {
72        match self {
73            NegotiationState::Failed { error, .. } => Some(error),
74            _ => None,
75        }
76    }
77}
78
79/// Configuration for certificate type negotiation
80#[derive(Debug, Clone)]
81pub struct NegotiationConfig {
82    /// Timeout for waiting for negotiation response
83    pub timeout: Duration,
84    /// Whether to cache negotiation results
85    pub enable_caching: bool,
86    /// Maximum cache size
87    pub max_cache_size: usize,
88    /// Whether to allow fallback to X.509 if RPK negotiation fails
89    pub allow_fallback: bool,
90    /// Default preferences if none specified
91    pub default_preferences: CertificateTypePreferences,
92}
93
94impl Default for NegotiationConfig {
95    fn default() -> Self {
96        Self {
97            timeout: Duration::from_secs(10),
98            enable_caching: true,
99            max_cache_size: 1000,
100            allow_fallback: true,
101            default_preferences: CertificateTypePreferences::prefer_raw_public_key(),
102        }
103    }
104}
105
/// Opaque identifier for one negotiation session, unique within the process.
///
/// Values come from a process-wide counter that starts at 1.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NegotiationId(u64);

impl NegotiationId {
    /// Allocate the next identifier from the shared counter.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};

        // Relaxed ordering is enough: only the uniqueness of each fetch_add
        // result matters, not ordering relative to other memory operations.
        static NEXT_ID: AtomicU64 = AtomicU64::new(1);
        let raw = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        NegotiationId(raw)
    }

    /// Return the numeric value of this identifier.
    pub fn as_u64(self) -> u64 {
        let NegotiationId(raw) = self;
        raw
    }
}

124/// Cache key for negotiation results
125#[derive(Debug, Clone, PartialEq, Eq, Hash)]
126struct CacheKey {
127    /// Our certificate type preferences
128    local_preferences: String, // Serialized preferences for hashing
129    /// Remote certificate type preferences  
130    remote_preferences: String, // Serialized preferences for hashing
131}
132
133impl CacheKey {
134    /// Create a cache key from preferences
135    fn new(
136        local: &CertificateTypePreferences,
137        remote_client: Option<&CertificateTypeList>,
138        remote_server: Option<&CertificateTypeList>,
139    ) -> Self {
140        use std::collections::hash_map::DefaultHasher;
141        
142        let mut hasher = DefaultHasher::new();
143        local.hash(&mut hasher);
144        let local_hash = hasher.finish();
145
146        let mut hasher = DefaultHasher::new();
147        if let Some(types) = remote_client {
148            types.hash(&mut hasher);
149        }
150        if let Some(types) = remote_server {
151            types.hash(&mut hasher);
152        }
153        let remote_hash = hasher.finish();
154
155        Self {
156            local_preferences: format!("{:x}", local_hash),
157            remote_preferences: format!("{:x}", remote_hash),
158        }
159    }
160}
161
162/// Hash implementation for CertificateTypePreferences
163impl Hash for CertificateTypePreferences {
164    fn hash<H: Hasher>(&self, state: &mut H) {
165        self.client_types.types.hash(state);
166        self.server_types.types.hash(state);
167        self.require_extensions.hash(state);
168        self.fallback_client.hash(state);
169        self.fallback_server.hash(state);
170    }
171}
172
173/// Hash implementation for CertificateTypeList  
174impl Hash for CertificateTypeList {
175    fn hash<H: Hasher>(&self, state: &mut H) {
176        self.types.hash(state);
177    }
178}
179
180/// Certificate type negotiation manager
181pub struct CertificateNegotiationManager {
182    /// Configuration for negotiation behavior
183    config: NegotiationConfig,
184    /// Active negotiation sessions
185    sessions: RwLock<HashMap<NegotiationId, NegotiationState>>,
186    /// Result cache for performance optimization
187    cache: Arc<Mutex<HashMap<CacheKey, (NegotiationResult, Instant)>>>,
188    /// Negotiation statistics
189    stats: Arc<Mutex<NegotiationStats>>,
190}
191
/// Aggregate counters describing certificate type negotiation activity.
///
/// `Default` yields all counters at zero and a zero average duration.
#[derive(Debug, Default, Clone)]
pub struct NegotiationStats {
    /// Negotiations started.
    pub total_attempts: u64,
    /// Negotiations that completed successfully.
    pub successful: u64,
    /// Negotiations that ended in failure.
    pub failed: u64,
    /// Negotiations that timed out waiting for the peer.
    pub timed_out: u64,
    /// Completions served from the result cache.
    pub cache_hits: u64,
    /// Cache lookups that found no usable entry.
    pub cache_misses: u64,
    /// Running average of the time spent computing a negotiation result.
    pub avg_negotiation_time: Duration,
}

211impl CertificateNegotiationManager {
212    /// Create a new negotiation manager
213    pub fn new(config: NegotiationConfig) -> Self {
214        Self {
215            config,
216            sessions: RwLock::new(HashMap::new()),
217            cache: Arc::new(Mutex::new(HashMap::new())),
218            stats: Arc::new(Mutex::new(NegotiationStats::default())),
219        }
220    }
221
222    /// Start a new certificate type negotiation
223    pub fn start_negotiation(
224        &self,
225        preferences: CertificateTypePreferences,
226    ) -> NegotiationId {
227        let id = NegotiationId::new();
228        let state = NegotiationState::Waiting {
229            sent_at: Instant::now(),
230            our_preferences: preferences,
231        };
232
233        let mut sessions = self.sessions.write().unwrap();
234        sessions.insert(id, state);
235
236        let mut stats = self.stats.lock().unwrap();
237        stats.total_attempts += 1;
238
239        debug!("Started certificate type negotiation: {:?}", id);
240        id
241    }
242
243    /// Complete a negotiation with remote preferences
244    pub fn complete_negotiation(
245        &self,
246        id: NegotiationId,
247        remote_client_types: Option<CertificateTypeList>,
248        remote_server_types: Option<CertificateTypeList>,
249    ) -> Result<NegotiationResult, TlsExtensionError> {
250        let _span = span!(Level::DEBUG, "complete_negotiation", id = id.as_u64()).entered();
251
252        let mut sessions = self.sessions.write().unwrap();
253        let state = sessions.get(&id).ok_or_else(|| {
254            TlsExtensionError::InvalidExtensionData(format!("Unknown negotiation ID: {:?}", id))
255        })?;
256
257        let our_preferences = match state {
258            NegotiationState::Waiting { our_preferences, .. } => our_preferences.clone(),
259            _ => {
260                return Err(TlsExtensionError::InvalidExtensionData(
261                    "Negotiation not in waiting state".to_string()
262                ));
263            }
264        };
265
266        // Check cache first if enabled
267        if self.config.enable_caching {
268            let cache_key = CacheKey::new(
269                &our_preferences,
270                remote_client_types.as_ref(),
271                remote_server_types.as_ref(),
272            );
273
274            let mut cache = self.cache.lock().unwrap();
275            if let Some((cached_result, cached_at)) = cache.get(&cache_key) {
276                // Check if cache entry is still valid (not expired)
277                if cached_at.elapsed() < Duration::from_secs(300) { // 5 minute cache
278                    let mut stats = self.stats.lock().unwrap();
279                    stats.cache_hits += 1;
280                    
281                    // Update session state
282                    sessions.insert(id, NegotiationState::Completed {
283                        result: cached_result.clone(),
284                        completed_at: Instant::now(),
285                    });
286
287                    debug!("Cache hit for negotiation: {:?}", id);
288                    return Ok(cached_result.clone());
289                } else {
290                    // Remove expired entry
291                    cache.remove(&cache_key);
292                }
293            }
294
295            let mut stats = self.stats.lock().unwrap();
296            stats.cache_misses += 1;
297        }
298
299        // Perform actual negotiation
300        let negotiation_start = Instant::now();
301        let result = our_preferences.negotiate(
302            remote_client_types.as_ref(),
303            remote_server_types.as_ref(),
304        );
305
306        match result {
307            Ok(negotiation_result) => {
308                let completed_at = Instant::now();
309                let negotiation_time = negotiation_start.elapsed();
310
311                // Update session state
312                sessions.insert(id, NegotiationState::Completed {
313                    result: negotiation_result.clone(),
314                    completed_at,
315                });
316
317                // Update statistics
318                let mut stats = self.stats.lock().unwrap();
319                stats.successful += 1;
320                
321                // Update average negotiation time (simple moving average)
322                let total_completed = stats.successful + stats.failed;
323                stats.avg_negotiation_time = if total_completed == 1 {
324                    negotiation_time
325                } else {
326                    Duration::from_nanos(
327                        (stats.avg_negotiation_time.as_nanos() as u64 * (total_completed - 1) + 
328                         negotiation_time.as_nanos() as u64) / total_completed
329                    )
330                };
331
332                // Cache the result if caching is enabled
333                if self.config.enable_caching {
334                    let cache_key = CacheKey::new(
335                        &our_preferences,
336                        remote_client_types.as_ref(),
337                        remote_server_types.as_ref(),
338                    );
339
340                    let mut cache = self.cache.lock().unwrap();
341                    
342                    // Evict old entries if cache is full
343                    if cache.len() >= self.config.max_cache_size {
344                        // Simple eviction: remove oldest entries
345                        let mut entries: Vec<_> = cache.iter()
346                            .map(|(k, (_, t))| (k.clone(), t.clone()))
347                            .collect();
348                        entries.sort_by_key(|(_, timestamp)| *timestamp);
349                        
350                        let to_remove = cache.len() - self.config.max_cache_size + 1;
351                        let keys_to_remove: Vec<_> = entries.iter()
352                            .take(to_remove)
353                            .map(|(key, _)| key.clone())
354                            .collect();
355                        
356                        for key in keys_to_remove {
357                            cache.remove(&key);
358                        }
359                    }
360
361                    cache.insert(cache_key, (negotiation_result.clone(), completed_at));
362                }
363
364                info!("Certificate type negotiation completed successfully: {:?} -> client={}, server={}", 
365                      id, negotiation_result.client_cert_type, negotiation_result.server_cert_type);
366
367                Ok(negotiation_result)
368            }
369            Err(error) => {
370                // Update session state
371                sessions.insert(id, NegotiationState::Failed {
372                    error: error.to_string(),
373                    failed_at: Instant::now(),
374                });
375
376                // Update statistics
377                let mut stats = self.stats.lock().unwrap();
378                stats.failed += 1;
379
380                warn!("Certificate type negotiation failed: {:?} -> {}", id, error);
381                Err(error)
382            }
383        }
384    }
385
386    /// Fail a negotiation with an error
387    pub fn fail_negotiation(&self, id: NegotiationId, error: String) {
388        let mut sessions = self.sessions.write().unwrap();
389        sessions.insert(id, NegotiationState::Failed {
390            error,
391            failed_at: Instant::now(),
392        });
393
394        let mut stats = self.stats.lock().unwrap();
395        stats.failed += 1;
396
397        warn!("Certificate type negotiation failed: {:?}", id);
398    }
399
400    /// Get the current state of a negotiation
401    pub fn get_negotiation_state(&self, id: NegotiationId) -> Option<NegotiationState> {
402        let sessions = self.sessions.read().unwrap();
403        sessions.get(&id).cloned()
404    }
405
406    /// Check for and handle timed out negotiations
407    pub fn handle_timeouts(&self) {
408        let mut sessions = self.sessions.write().unwrap();
409        let mut timed_out_ids = Vec::new();
410
411        for (id, state) in sessions.iter() {
412            if let NegotiationState::Waiting { sent_at, .. } = state {
413                if sent_at.elapsed() > self.config.timeout {
414                    timed_out_ids.push(*id);
415                }
416            }
417        }
418
419        for id in timed_out_ids {
420            sessions.insert(id, NegotiationState::TimedOut {
421                timeout_at: Instant::now(),
422            });
423
424            let mut stats = self.stats.lock().unwrap();
425            stats.timed_out += 1;
426
427            warn!("Certificate type negotiation timed out: {:?}", id);
428        }
429    }
430
431    /// Clean up completed negotiations older than the specified duration
432    pub fn cleanup_old_sessions(&self, max_age: Duration) {
433        let mut sessions = self.sessions.write().unwrap();
434        let cutoff = Instant::now() - max_age;
435
436        sessions.retain(|id, state| {
437            let should_retain = match state {
438                NegotiationState::Completed { completed_at, .. } => *completed_at > cutoff,
439                NegotiationState::Failed { failed_at, .. } => *failed_at > cutoff,
440                NegotiationState::TimedOut { timeout_at, .. } => *timeout_at > cutoff,
441                _ => true, // Keep pending and waiting sessions
442            };
443
444            if !should_retain {
445                debug!("Cleaned up old negotiation session: {:?}", id);
446            }
447
448            should_retain
449        });
450    }
451
452    /// Get current negotiation statistics
453    pub fn get_stats(&self) -> NegotiationStats {
454        self.stats.lock().unwrap().clone()
455    }
456
457    /// Clear all cached results
458    pub fn clear_cache(&self) {
459        let mut cache = self.cache.lock().unwrap();
460        cache.clear();
461        debug!("Cleared certificate type negotiation cache");
462    }
463
464    /// Get cache statistics
465    pub fn get_cache_stats(&self) -> (usize, usize) {
466        let cache = self.cache.lock().unwrap();
467        (cache.len(), self.config.max_cache_size)
468    }
469}
470
471impl Default for CertificateNegotiationManager {
472    fn default() -> Self {
473        Self::new(NegotiationConfig::default())
474    }
475}
476
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;

    /// Raw-public-key-first preferences shared by the manager tests.
    fn rpk_first() -> CertificateTypePreferences {
        CertificateTypePreferences::prefer_raw_public_key()
    }

    #[test]
    fn test_negotiation_id_generation() {
        let ids = [NegotiationId::new(), NegotiationId::new()];

        assert_ne!(ids[0], ids[1]);
        assert!(ids.iter().all(|id| id.as_u64() > 0));
    }

    #[test]
    fn test_negotiation_state_checks() {
        let pending = NegotiationState::Pending;
        assert!(!pending.is_complete());
        assert!(!pending.is_successful());

        let completed = NegotiationState::Completed {
            result: NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509),
            completed_at: Instant::now(),
        };
        assert!(completed.is_complete());
        assert!(completed.is_successful());
        assert!(completed.get_result().is_some());

        let failed = NegotiationState::Failed {
            error: String::from("Test error"),
            failed_at: Instant::now(),
        };
        assert!(failed.is_complete());
        assert!(!failed.is_successful());
        assert_eq!(failed.get_error(), Some("Test error"));
    }

    #[test]
    fn test_negotiation_manager_basic_flow() {
        let manager = CertificateNegotiationManager::default();
        let id = manager.start_negotiation(rpk_first());

        let before = manager.get_negotiation_state(id).expect("session registered");
        assert!(matches!(before, NegotiationState::Waiting { .. }));

        let offered = CertificateTypeList::raw_public_key_only();
        let result = manager
            .complete_negotiation(id, Some(offered.clone()), Some(offered))
            .expect("RPK on both sides should negotiate");

        assert_eq!(result.client_cert_type, CertificateType::RawPublicKey);
        assert_eq!(result.server_cert_type, CertificateType::RawPublicKey);

        let after = manager.get_negotiation_state(id).expect("session still tracked");
        assert!(after.is_successful());
    }

    #[test]
    fn test_negotiation_caching() {
        let manager = CertificateNegotiationManager::new(NegotiationConfig {
            enable_caching: true,
            ..Default::default()
        });
        let offered = CertificateTypeList::raw_public_key_only();

        let run = |prefs: CertificateTypePreferences| {
            let id = manager.start_negotiation(prefs);
            manager
                .complete_negotiation(id, Some(offered.clone()), Some(offered.clone()))
                .unwrap()
        };

        // The first call populates the cache; the identical second call should hit it.
        let first = run(rpk_first());
        let second = run(rpk_first());
        assert_eq!(first, second);

        let stats = manager.get_stats();
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
    }

    #[test]
    fn test_negotiation_timeout_handling() {
        let manager = CertificateNegotiationManager::new(NegotiationConfig {
            timeout: Duration::from_millis(1),
            ..Default::default()
        });
        let id = manager.start_negotiation(rpk_first());

        // Sleep well past the 1 ms timeout before sweeping.
        std::thread::sleep(Duration::from_millis(10));
        manager.handle_timeouts();

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(state, NegotiationState::TimedOut { .. }));
        assert_eq!(manager.get_stats().timed_out, 1);
    }
}