// ant_quic/crypto/certificate_negotiation.rs

//! Certificate Type Negotiation Protocol Implementation
//!
//! This module implements the complete certificate type negotiation protocol
//! as defined in RFC 7250, including state management, caching, and integration
//! with both client and server sides of TLS connections.

7use std::{
8    collections::HashMap,
9    hash::{Hash, Hasher},
10    sync::{Arc, Mutex, RwLock},
11    time::{Duration, Instant},
12};
13
14use tracing::{Level, debug, info, span, warn};
15
16use super::tls_extensions::{
17    CertificateTypeList, CertificateTypePreferences, NegotiationResult, TlsExtensionError,
18};
19
20/// Negotiation state for a single TLS connection
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub enum NegotiationState {
23    /// Negotiation not yet started
24    Pending,
25    /// Extensions sent, waiting for response
26    Waiting {
27        sent_at: Instant,
28        our_preferences: CertificateTypePreferences,
29    },
30    /// Negotiation completed successfully
31    Completed {
32        result: NegotiationResult,
33        completed_at: Instant,
34    },
35    /// Negotiation failed
36    Failed {
37        /// The error message
38        error: String,
39        /// When the failure occurred
40        failed_at: Instant,
41    },
42    /// Timed out waiting for response
43    TimedOut {
44        /// When the timeout occurred
45        timeout_at: Instant,
46    },
47}
48
49impl NegotiationState {
50    /// Check if negotiation is complete (either succeeded or failed)
51    pub fn is_complete(&self) -> bool {
52        matches!(
53            self,
54            Self::Completed { .. } | Self::Failed { .. } | Self::TimedOut { .. }
55        )
56    }
57
58    /// Check if negotiation succeeded
59    pub fn is_successful(&self) -> bool {
60        matches!(self, Self::Completed { .. })
61    }
62
63    /// Get the negotiation result if successful
64    pub fn get_result(&self) -> Option<&NegotiationResult> {
65        match self {
66            Self::Completed { result, .. } => Some(result),
67            _ => None,
68        }
69    }
70
71    /// Get error message if failed
72    pub fn get_error(&self) -> Option<&str> {
73        match self {
74            Self::Failed { error, .. } => Some(error),
75            _ => None,
76        }
77    }
78}
79
80/// Configuration for certificate type negotiation
81#[derive(Debug, Clone)]
82pub struct NegotiationConfig {
83    /// Timeout for waiting for negotiation response
84    pub timeout: Duration,
85    /// Whether to cache negotiation results
86    pub enable_caching: bool,
87    /// Maximum cache size
88    pub max_cache_size: usize,
89    /// Whether to allow fallback to X.509 if RPK negotiation fails
90    pub allow_fallback: bool,
91    /// Default preferences if none specified
92    pub default_preferences: CertificateTypePreferences,
93}
94
95impl Default for NegotiationConfig {
96    fn default() -> Self {
97        Self {
98            timeout: Duration::from_secs(10),
99            enable_caching: true,
100            max_cache_size: 1000,
101            allow_fallback: true,
102            default_preferences: CertificateTypePreferences::prefer_raw_public_key(),
103        }
104    }
105}
106
/// Process-unique identifier for a single negotiation session.
///
/// Values come from a global counter that starts at 1 and increases by one per
/// allocation, so `0` is not produced in practice.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NegotiationId(u64);

impl NegotiationId {
    /// Allocates the next identifier from the process-wide counter.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};

        // Only uniqueness is required (no ordering with other memory
        // operations), so a relaxed atomic increment is sufficient.
        static NEXT_ID: AtomicU64 = AtomicU64::new(1);
        let raw = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        NegotiationId(raw)
    }

    /// Returns the numeric value of this identifier.
    pub fn as_u64(self) -> u64 {
        let NegotiationId(raw) = self;
        raw
    }
}

impl Default for NegotiationId {
    /// Same as [`NegotiationId::new`]: every call yields a fresh identifier.
    fn default() -> Self {
        Self::new()
    }
}

131/// Cache key for negotiation results
132#[derive(Debug, Clone, PartialEq, Eq, Hash)]
133struct CacheKey {
134    /// Our certificate type preferences
135    local_preferences: String, // Serialized preferences for hashing
136    /// Remote certificate type preferences  
137    remote_preferences: String, // Serialized preferences for hashing
138}
139
140impl CacheKey {
141    /// Create a cache key from preferences
142    fn new(
143        local: &CertificateTypePreferences,
144        remote_client: Option<&CertificateTypeList>,
145        remote_server: Option<&CertificateTypeList>,
146    ) -> Self {
147        use std::collections::hash_map::DefaultHasher;
148
149        let mut hasher = DefaultHasher::new();
150        local.hash(&mut hasher);
151        let local_hash = hasher.finish();
152
153        let mut hasher = DefaultHasher::new();
154        if let Some(types) = remote_client {
155            types.hash(&mut hasher);
156        }
157        if let Some(types) = remote_server {
158            types.hash(&mut hasher);
159        }
160        let remote_hash = hasher.finish();
161
162        Self {
163            local_preferences: format!("{local_hash:x}"),
164            remote_preferences: format!("{remote_hash:x}"),
165        }
166    }
167}
168
169/// Hash implementation for CertificateTypePreferences
170impl Hash for CertificateTypePreferences {
171    fn hash<H: Hasher>(&self, state: &mut H) {
172        self.client_types.types.hash(state);
173        self.server_types.types.hash(state);
174        self.require_extensions.hash(state);
175        self.fallback_client.hash(state);
176        self.fallback_server.hash(state);
177    }
178}
179
180/// Hash implementation for CertificateTypeList  
181impl Hash for CertificateTypeList {
182    fn hash<H: Hasher>(&self, state: &mut H) {
183        self.types.hash(state);
184    }
185}
186
187/// Certificate type negotiation manager
188pub struct CertificateNegotiationManager {
189    /// Configuration for negotiation behavior
190    config: NegotiationConfig,
191    /// Active negotiation sessions
192    sessions: RwLock<HashMap<NegotiationId, NegotiationState>>,
193    /// Result cache for performance optimization
194    cache: Arc<Mutex<HashMap<CacheKey, (NegotiationResult, Instant)>>>,
195    /// Negotiation statistics
196    stats: Arc<Mutex<NegotiationStats>>,
197}
198
/// Aggregate counters describing negotiation activity.
///
/// A snapshot is returned by `CertificateNegotiationManager::get_stats`.
#[derive(Debug, Default, Clone)]
pub struct NegotiationStats {
    /// Negotiations started.
    pub total_attempts: u64,
    /// Negotiations that ended successfully.
    pub successful: u64,
    /// Negotiations that ended in failure.
    pub failed: u64,
    /// Negotiations that were still waiting when their timeout elapsed.
    pub timed_out: u64,
    /// Completions answered from the result cache.
    pub cache_hits: u64,
    /// Completions that found no usable cache entry.
    pub cache_misses: u64,
    /// Running average duration of successful negotiations.
    pub avg_negotiation_time: Duration,
}

218impl CertificateNegotiationManager {
219    /// Create a new negotiation manager
220    pub fn new(config: NegotiationConfig) -> Self {
221        Self {
222            config,
223            sessions: RwLock::new(HashMap::new()),
224            cache: Arc::new(Mutex::new(HashMap::new())),
225            stats: Arc::new(Mutex::new(NegotiationStats::default())),
226        }
227    }
228
229    /// Start a new certificate type negotiation
230    pub fn start_negotiation(&self, preferences: CertificateTypePreferences) -> NegotiationId {
231        let id = NegotiationId::new();
232        let state = NegotiationState::Waiting {
233            sent_at: Instant::now(),
234            our_preferences: preferences,
235        };
236
237        let mut sessions = self.sessions.write().unwrap();
238        sessions.insert(id, state);
239
240        let mut stats = self.stats.lock().unwrap();
241        stats.total_attempts += 1;
242
243        debug!("Started certificate type negotiation: {:?}", id);
244        id
245    }
246
247    /// Complete a negotiation with remote preferences
248    pub fn complete_negotiation(
249        &self,
250        id: NegotiationId,
251        remote_client_types: Option<CertificateTypeList>,
252        remote_server_types: Option<CertificateTypeList>,
253    ) -> Result<NegotiationResult, TlsExtensionError> {
254        let _span = span!(Level::DEBUG, "complete_negotiation", id = id.as_u64()).entered();
255
256        let mut sessions = self.sessions.write().unwrap();
257        let state = sessions.get(&id).ok_or_else(|| {
258            TlsExtensionError::InvalidExtensionData(format!("Unknown negotiation ID: {id:?}"))
259        })?;
260
261        let our_preferences = match state {
262            NegotiationState::Waiting {
263                our_preferences, ..
264            } => our_preferences.clone(),
265            _ => {
266                return Err(TlsExtensionError::InvalidExtensionData(
267                    "Negotiation not in waiting state".to_string(),
268                ));
269            }
270        };
271
272        // Check cache first if enabled
273        if self.config.enable_caching {
274            let cache_key = CacheKey::new(
275                &our_preferences,
276                remote_client_types.as_ref(),
277                remote_server_types.as_ref(),
278            );
279
280            let mut cache = self.cache.lock().unwrap();
281            if let Some((cached_result, cached_at)) = cache.get(&cache_key) {
282                // Check if cache entry is still valid (not expired)
283                if cached_at.elapsed() < Duration::from_secs(300) {
284                    // 5 minute cache
285                    let mut stats = self.stats.lock().unwrap();
286                    stats.cache_hits += 1;
287
288                    // Update session state
289                    sessions.insert(
290                        id,
291                        NegotiationState::Completed {
292                            result: cached_result.clone(),
293                            completed_at: Instant::now(),
294                        },
295                    );
296
297                    debug!("Cache hit for negotiation: {:?}", id);
298                    return Ok(cached_result.clone());
299                } else {
300                    // Remove expired entry
301                    cache.remove(&cache_key);
302                }
303            }
304
305            let mut stats = self.stats.lock().unwrap();
306            stats.cache_misses += 1;
307        }
308
309        // Perform actual negotiation
310        let negotiation_start = Instant::now();
311        let result =
312            our_preferences.negotiate(remote_client_types.as_ref(), remote_server_types.as_ref());
313
314        match result {
315            Ok(negotiation_result) => {
316                let completed_at = Instant::now();
317                let negotiation_time = negotiation_start.elapsed();
318
319                // Update session state
320                sessions.insert(
321                    id,
322                    NegotiationState::Completed {
323                        result: negotiation_result.clone(),
324                        completed_at,
325                    },
326                );
327
328                // Update statistics
329                let mut stats = self.stats.lock().unwrap();
330                stats.successful += 1;
331
332                // Update average negotiation time (simple moving average)
333                let total_completed = stats.successful + stats.failed;
334                stats.avg_negotiation_time = if total_completed == 1 {
335                    negotiation_time
336                } else {
337                    Duration::from_nanos(
338                        (stats.avg_negotiation_time.as_nanos() as u64 * (total_completed - 1)
339                            + negotiation_time.as_nanos() as u64)
340                            / total_completed,
341                    )
342                };
343
344                // Cache the result if caching is enabled
345                if self.config.enable_caching {
346                    let cache_key = CacheKey::new(
347                        &our_preferences,
348                        remote_client_types.as_ref(),
349                        remote_server_types.as_ref(),
350                    );
351
352                    let mut cache = self.cache.lock().unwrap();
353
354                    // Evict old entries if cache is full
355                    if cache.len() >= self.config.max_cache_size {
356                        // Simple eviction: remove oldest entries
357                        let mut entries: Vec<_> =
358                            cache.iter().map(|(k, (_, t))| (k.clone(), *t)).collect();
359                        entries.sort_by_key(|(_, timestamp)| *timestamp);
360
361                        let to_remove = cache.len() - self.config.max_cache_size + 1;
362                        let keys_to_remove: Vec<_> = entries
363                            .iter()
364                            .take(to_remove)
365                            .map(|(key, _)| key.clone())
366                            .collect();
367
368                        for key in keys_to_remove {
369                            cache.remove(&key);
370                        }
371                    }
372
373                    cache.insert(cache_key, (negotiation_result.clone(), completed_at));
374                }
375
376                info!(
377                    "Certificate type negotiation completed successfully: {:?} -> client={}, server={}",
378                    id, negotiation_result.client_cert_type, negotiation_result.server_cert_type
379                );
380
381                Ok(negotiation_result)
382            }
383            Err(error) => {
384                // Update session state
385                sessions.insert(
386                    id,
387                    NegotiationState::Failed {
388                        error: error.to_string(),
389                        failed_at: Instant::now(),
390                    },
391                );
392
393                // Update statistics
394                let mut stats = self.stats.lock().unwrap();
395                stats.failed += 1;
396
397                warn!("Certificate type negotiation failed: {:?} -> {}", id, error);
398                Err(error)
399            }
400        }
401    }
402
403    /// Fail a negotiation with an error
404    pub fn fail_negotiation(&self, id: NegotiationId, error: String) {
405        let mut sessions = self.sessions.write().unwrap();
406        sessions.insert(
407            id,
408            NegotiationState::Failed {
409                error,
410                failed_at: Instant::now(),
411            },
412        );
413
414        let mut stats = self.stats.lock().unwrap();
415        stats.failed += 1;
416
417        warn!("Certificate type negotiation failed: {:?}", id);
418    }
419
420    /// Get the current state of a negotiation
421    pub fn get_negotiation_state(&self, id: NegotiationId) -> Option<NegotiationState> {
422        let sessions = self.sessions.read().unwrap();
423        sessions.get(&id).cloned()
424    }
425
426    /// Check for and handle timed out negotiations
427    pub fn handle_timeouts(&self) {
428        let mut sessions = self.sessions.write().unwrap();
429        let mut timed_out_ids = Vec::new();
430
431        for (id, state) in sessions.iter() {
432            if let NegotiationState::Waiting { sent_at, .. } = state {
433                if sent_at.elapsed() > self.config.timeout {
434                    timed_out_ids.push(*id);
435                }
436            }
437        }
438
439        for id in timed_out_ids {
440            sessions.insert(
441                id,
442                NegotiationState::TimedOut {
443                    timeout_at: Instant::now(),
444                },
445            );
446
447            let mut stats = self.stats.lock().unwrap();
448            stats.timed_out += 1;
449
450            warn!("Certificate type negotiation timed out: {:?}", id);
451        }
452    }
453
454    /// Clean up completed negotiations older than the specified duration
455    pub fn cleanup_old_sessions(&self, max_age: Duration) {
456        let mut sessions = self.sessions.write().unwrap();
457        let cutoff = Instant::now() - max_age;
458
459        sessions.retain(|id, state| {
460            let should_retain = match state {
461                NegotiationState::Completed { completed_at, .. } => *completed_at > cutoff,
462                NegotiationState::Failed { failed_at, .. } => *failed_at > cutoff,
463                NegotiationState::TimedOut { timeout_at, .. } => *timeout_at > cutoff,
464                _ => true, // Keep pending and waiting sessions
465            };
466
467            if !should_retain {
468                debug!("Cleaned up old negotiation session: {:?}", id);
469            }
470
471            should_retain
472        });
473    }
474
475    /// Get current negotiation statistics
476    pub fn get_stats(&self) -> NegotiationStats {
477        self.stats.lock().unwrap().clone()
478    }
479
480    /// Clear all cached results
481    pub fn clear_cache(&self) {
482        let mut cache = self.cache.lock().unwrap();
483        cache.clear();
484        debug!("Cleared certificate type negotiation cache");
485    }
486
487    /// Get cache statistics
488    pub fn get_cache_stats(&self) -> (usize, usize) {
489        let cache = self.cache.lock().unwrap();
490        (cache.len(), self.config.max_cache_size)
491    }
492}
493
494impl Default for CertificateNegotiationManager {
495    fn default() -> Self {
496        Self::new(NegotiationConfig::default())
497    }
498}
499
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;

    #[test]
    fn test_negotiation_id_generation() {
        let first = NegotiationId::new();
        let second = NegotiationId::new();

        assert_ne!(first, second);
        for id in [first, second] {
            assert!(id.as_u64() > 0);
        }
    }

    #[test]
    fn test_negotiation_state_checks() {
        let pending = NegotiationState::Pending;
        assert!(!pending.is_complete());
        assert!(!pending.is_successful());

        let completed = NegotiationState::Completed {
            result: NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509),
            completed_at: Instant::now(),
        };
        assert!(completed.is_complete());
        assert!(completed.is_successful());
        assert!(completed.get_result().is_some());

        let failed = NegotiationState::Failed {
            error: "Test error".to_string(),
            failed_at: Instant::now(),
        };
        assert!(failed.is_complete());
        assert!(!failed.is_successful());
        assert_eq!(failed.get_error(), Some("Test error"));

        // A timeout is terminal but neither a success nor a reported error.
        let timed_out = NegotiationState::TimedOut {
            timeout_at: Instant::now(),
        };
        assert!(timed_out.is_complete());
        assert!(!timed_out.is_successful());
        assert!(timed_out.get_result().is_none());
        assert!(timed_out.get_error().is_none());
    }

    #[test]
    fn test_negotiation_manager_basic_flow() {
        let manager = CertificateNegotiationManager::default();
        let id = manager.start_negotiation(CertificateTypePreferences::prefer_raw_public_key());

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(state, NegotiationState::Waiting { .. }));

        let remote_types = CertificateTypeList::raw_public_key_only();
        let result = manager
            .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        assert_eq!(result.client_cert_type, CertificateType::RawPublicKey);
        assert_eq!(result.server_cert_type, CertificateType::RawPublicKey);
        assert!(manager.get_negotiation_state(id).unwrap().is_successful());
    }

    #[test]
    fn test_complete_negotiation_rejects_unknown_and_repeated_ids() {
        let manager = CertificateNegotiationManager::default();
        let remote_types = CertificateTypeList::raw_public_key_only();

        // An id this manager never issued is rejected.
        assert!(manager
            .complete_negotiation(NegotiationId::new(), None, None)
            .is_err());

        let id = manager.start_negotiation(CertificateTypePreferences::prefer_raw_public_key());
        manager
            .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types.clone()))
            .unwrap();
        // The session is no longer Waiting, so a second completion must fail.
        assert!(manager
            .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types))
            .is_err());
    }

    #[test]
    fn test_negotiation_caching() {
        let config = NegotiationConfig {
            enable_caching: true,
            ..Default::default()
        };
        let manager = CertificateNegotiationManager::new(config);
        let preferences = CertificateTypePreferences::prefer_raw_public_key();
        let remote_types = CertificateTypeList::raw_public_key_only();

        let first_id = manager.start_negotiation(preferences.clone());
        let first = manager
            .complete_negotiation(first_id, Some(remote_types.clone()), Some(remote_types.clone()))
            .unwrap();

        // Identical inputs on a second negotiation must be served from the cache.
        let second_id = manager.start_negotiation(preferences);
        let second = manager
            .complete_negotiation(second_id, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        assert_eq!(first, second);

        let stats = manager.get_stats();
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
    }

    #[test]
    fn test_negotiation_timeout_handling() {
        let config = NegotiationConfig {
            timeout: Duration::from_millis(1),
            ..Default::default()
        };
        let manager = CertificateNegotiationManager::new(config);
        let id = manager.start_negotiation(CertificateTypePreferences::prefer_raw_public_key());

        // Sleep well past the 1 ms timeout before sweeping.
        std::thread::sleep(Duration::from_millis(10));
        manager.handle_timeouts();

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(state, NegotiationState::TimedOut { .. }));
        assert_eq!(manager.get_stats().timed_out, 1);
    }

    #[test]
    fn test_cleanup_old_sessions_keeps_in_flight() {
        let manager = CertificateNegotiationManager::default();
        let preferences = CertificateTypePreferences::prefer_raw_public_key();
        let remote_types = CertificateTypeList::raw_public_key_only();

        let finished = manager.start_negotiation(preferences.clone());
        manager
            .complete_negotiation(finished, Some(remote_types.clone()), Some(remote_types))
            .unwrap();
        let in_flight = manager.start_negotiation(preferences);

        // With a zero max age every terminal session is old enough to drop.
        manager.cleanup_old_sessions(Duration::ZERO);

        assert!(manager.get_negotiation_state(finished).is_none());
        assert!(manager.get_negotiation_state(in_flight).is_some());
    }
}
612}