// ant_quic/crypto/certificate_negotiation.rs

//! Certificate Type Negotiation Protocol Implementation
//!
//! This module implements certificate type negotiation as defined in RFC 7250,
//! including per-session state management, result caching, and integration
//! with both the client and server sides of TLS connections.

7use std::{
8    collections::HashMap,
9    hash::{Hash, Hasher},
10    sync::{Arc, Mutex, RwLock},
11    time::{Duration, Instant},
12};
13
14use tracing::{Level, debug, info, span, warn};
15
16use super::tls_extensions::{
17    CertificateTypeList, CertificateTypePreferences, NegotiationResult, TlsExtensionError,
18};
19
20/// Negotiation state for a single TLS connection
21#[derive(Debug, Clone, PartialEq, Eq)]
22pub enum NegotiationState {
23    /// Negotiation not yet started
24    Pending,
25    /// Extensions sent, waiting for response
26    Waiting {
27        sent_at: Instant,
28        our_preferences: CertificateTypePreferences,
29    },
30    /// Negotiation completed successfully
31    Completed {
32        result: NegotiationResult,
33        completed_at: Instant,
34    },
35    /// Negotiation failed
36    Failed { error: String, failed_at: Instant },
37    /// Timed out waiting for response
38    TimedOut { timeout_at: Instant },
39}
40
41impl NegotiationState {
42    /// Check if negotiation is complete (either succeeded or failed)
43    pub fn is_complete(&self) -> bool {
44        matches!(
45            self,
46            NegotiationState::Completed { .. }
47                | NegotiationState::Failed { .. }
48                | NegotiationState::TimedOut { .. }
49        )
50    }
51
52    /// Check if negotiation succeeded
53    pub fn is_successful(&self) -> bool {
54        matches!(self, NegotiationState::Completed { .. })
55    }
56
57    /// Get the negotiation result if successful
58    pub fn get_result(&self) -> Option<&NegotiationResult> {
59        match self {
60            NegotiationState::Completed { result, .. } => Some(result),
61            _ => None,
62        }
63    }
64
65    /// Get error message if failed
66    pub fn get_error(&self) -> Option<&str> {
67        match self {
68            NegotiationState::Failed { error, .. } => Some(error),
69            _ => None,
70        }
71    }
72}
73
74/// Configuration for certificate type negotiation
75#[derive(Debug, Clone)]
76pub struct NegotiationConfig {
77    /// Timeout for waiting for negotiation response
78    pub timeout: Duration,
79    /// Whether to cache negotiation results
80    pub enable_caching: bool,
81    /// Maximum cache size
82    pub max_cache_size: usize,
83    /// Whether to allow fallback to X.509 if RPK negotiation fails
84    pub allow_fallback: bool,
85    /// Default preferences if none specified
86    pub default_preferences: CertificateTypePreferences,
87}
88
89impl Default for NegotiationConfig {
90    fn default() -> Self {
91        Self {
92            timeout: Duration::from_secs(10),
93            enable_caching: true,
94            max_cache_size: 1000,
95            allow_fallback: true,
96            default_preferences: CertificateTypePreferences::prefer_raw_public_key(),
97        }
98    }
99}
100
/// Process-unique identifier handed out for each negotiation session.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NegotiationId(u64);

impl NegotiationId {
    /// Allocates a fresh identifier from a process-wide counter starting at 1.
    ///
    /// `fetch_add` is atomic, so concurrent callers never receive the same
    /// value. `Relaxed` suffices because no other memory is synchronised
    /// through the counter.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering};

        static NEXT_ID: AtomicU64 = AtomicU64::new(1);
        let raw = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        NegotiationId(raw)
    }

    /// Returns the underlying numeric value.
    pub fn as_u64(self) -> u64 {
        let NegotiationId(raw) = self;
        raw
    }
}
118
119/// Cache key for negotiation results
120#[derive(Debug, Clone, PartialEq, Eq, Hash)]
121struct CacheKey {
122    /// Our certificate type preferences
123    local_preferences: String, // Serialized preferences for hashing
124    /// Remote certificate type preferences  
125    remote_preferences: String, // Serialized preferences for hashing
126}
127
128impl CacheKey {
129    /// Create a cache key from preferences
130    fn new(
131        local: &CertificateTypePreferences,
132        remote_client: Option<&CertificateTypeList>,
133        remote_server: Option<&CertificateTypeList>,
134    ) -> Self {
135        use std::collections::hash_map::DefaultHasher;
136
137        let mut hasher = DefaultHasher::new();
138        local.hash(&mut hasher);
139        let local_hash = hasher.finish();
140
141        let mut hasher = DefaultHasher::new();
142        if let Some(types) = remote_client {
143            types.hash(&mut hasher);
144        }
145        if let Some(types) = remote_server {
146            types.hash(&mut hasher);
147        }
148        let remote_hash = hasher.finish();
149
150        Self {
151            local_preferences: format!("{:x}", local_hash),
152            remote_preferences: format!("{:x}", remote_hash),
153        }
154    }
155}
156
157/// Hash implementation for CertificateTypePreferences
158impl Hash for CertificateTypePreferences {
159    fn hash<H: Hasher>(&self, state: &mut H) {
160        self.client_types.types.hash(state);
161        self.server_types.types.hash(state);
162        self.require_extensions.hash(state);
163        self.fallback_client.hash(state);
164        self.fallback_server.hash(state);
165    }
166}
167
168/// Hash implementation for CertificateTypeList  
169impl Hash for CertificateTypeList {
170    fn hash<H: Hasher>(&self, state: &mut H) {
171        self.types.hash(state);
172    }
173}
174
175/// Certificate type negotiation manager
176pub struct CertificateNegotiationManager {
177    /// Configuration for negotiation behavior
178    config: NegotiationConfig,
179    /// Active negotiation sessions
180    sessions: RwLock<HashMap<NegotiationId, NegotiationState>>,
181    /// Result cache for performance optimization
182    cache: Arc<Mutex<HashMap<CacheKey, (NegotiationResult, Instant)>>>,
183    /// Negotiation statistics
184    stats: Arc<Mutex<NegotiationStats>>,
185}
186
/// Counters describing certificate type negotiation activity.
#[derive(Debug, Default, Clone)]
pub struct NegotiationStats {
    /// Negotiations started via `start_negotiation`.
    pub total_attempts: u64,
    /// Negotiations completed by actually running the negotiation.
    /// Results served from the cache are counted in `cache_hits` instead.
    pub successful: u64,
    /// Negotiations that failed, either inside `complete_negotiation` or
    /// through `fail_negotiation`.
    pub failed: u64,
    /// Negotiations moved to `TimedOut` by `handle_timeouts`.
    pub timed_out: u64,
    /// Cache lookups that returned a still-valid result.
    pub cache_hits: u64,
    /// Cache lookups that found nothing usable.
    pub cache_misses: u64,
    /// Running average of negotiation duration.
    pub avg_negotiation_time: Duration,
}
205
206impl CertificateNegotiationManager {
207    /// Create a new negotiation manager
208    pub fn new(config: NegotiationConfig) -> Self {
209        Self {
210            config,
211            sessions: RwLock::new(HashMap::new()),
212            cache: Arc::new(Mutex::new(HashMap::new())),
213            stats: Arc::new(Mutex::new(NegotiationStats::default())),
214        }
215    }
216
217    /// Start a new certificate type negotiation
218    pub fn start_negotiation(&self, preferences: CertificateTypePreferences) -> NegotiationId {
219        let id = NegotiationId::new();
220        let state = NegotiationState::Waiting {
221            sent_at: Instant::now(),
222            our_preferences: preferences,
223        };
224
225        let mut sessions = self.sessions.write().unwrap();
226        sessions.insert(id, state);
227
228        let mut stats = self.stats.lock().unwrap();
229        stats.total_attempts += 1;
230
231        debug!("Started certificate type negotiation: {:?}", id);
232        id
233    }
234
235    /// Complete a negotiation with remote preferences
236    pub fn complete_negotiation(
237        &self,
238        id: NegotiationId,
239        remote_client_types: Option<CertificateTypeList>,
240        remote_server_types: Option<CertificateTypeList>,
241    ) -> Result<NegotiationResult, TlsExtensionError> {
242        let _span = span!(Level::DEBUG, "complete_negotiation", id = id.as_u64()).entered();
243
244        let mut sessions = self.sessions.write().unwrap();
245        let state = sessions.get(&id).ok_or_else(|| {
246            TlsExtensionError::InvalidExtensionData(format!("Unknown negotiation ID: {:?}", id))
247        })?;
248
249        let our_preferences = match state {
250            NegotiationState::Waiting {
251                our_preferences, ..
252            } => our_preferences.clone(),
253            _ => {
254                return Err(TlsExtensionError::InvalidExtensionData(
255                    "Negotiation not in waiting state".to_string(),
256                ));
257            }
258        };
259
260        // Check cache first if enabled
261        if self.config.enable_caching {
262            let cache_key = CacheKey::new(
263                &our_preferences,
264                remote_client_types.as_ref(),
265                remote_server_types.as_ref(),
266            );
267
268            let mut cache = self.cache.lock().unwrap();
269            if let Some((cached_result, cached_at)) = cache.get(&cache_key) {
270                // Check if cache entry is still valid (not expired)
271                if cached_at.elapsed() < Duration::from_secs(300) {
272                    // 5 minute cache
273                    let mut stats = self.stats.lock().unwrap();
274                    stats.cache_hits += 1;
275
276                    // Update session state
277                    sessions.insert(
278                        id,
279                        NegotiationState::Completed {
280                            result: cached_result.clone(),
281                            completed_at: Instant::now(),
282                        },
283                    );
284
285                    debug!("Cache hit for negotiation: {:?}", id);
286                    return Ok(cached_result.clone());
287                } else {
288                    // Remove expired entry
289                    cache.remove(&cache_key);
290                }
291            }
292
293            let mut stats = self.stats.lock().unwrap();
294            stats.cache_misses += 1;
295        }
296
297        // Perform actual negotiation
298        let negotiation_start = Instant::now();
299        let result =
300            our_preferences.negotiate(remote_client_types.as_ref(), remote_server_types.as_ref());
301
302        match result {
303            Ok(negotiation_result) => {
304                let completed_at = Instant::now();
305                let negotiation_time = negotiation_start.elapsed();
306
307                // Update session state
308                sessions.insert(
309                    id,
310                    NegotiationState::Completed {
311                        result: negotiation_result.clone(),
312                        completed_at,
313                    },
314                );
315
316                // Update statistics
317                let mut stats = self.stats.lock().unwrap();
318                stats.successful += 1;
319
320                // Update average negotiation time (simple moving average)
321                let total_completed = stats.successful + stats.failed;
322                stats.avg_negotiation_time = if total_completed == 1 {
323                    negotiation_time
324                } else {
325                    Duration::from_nanos(
326                        (stats.avg_negotiation_time.as_nanos() as u64 * (total_completed - 1)
327                            + negotiation_time.as_nanos() as u64)
328                            / total_completed,
329                    )
330                };
331
332                // Cache the result if caching is enabled
333                if self.config.enable_caching {
334                    let cache_key = CacheKey::new(
335                        &our_preferences,
336                        remote_client_types.as_ref(),
337                        remote_server_types.as_ref(),
338                    );
339
340                    let mut cache = self.cache.lock().unwrap();
341
342                    // Evict old entries if cache is full
343                    if cache.len() >= self.config.max_cache_size {
344                        // Simple eviction: remove oldest entries
345                        let mut entries: Vec<_> = cache
346                            .iter()
347                            .map(|(k, (_, t))| (k.clone(), t.clone()))
348                            .collect();
349                        entries.sort_by_key(|(_, timestamp)| *timestamp);
350
351                        let to_remove = cache.len() - self.config.max_cache_size + 1;
352                        let keys_to_remove: Vec<_> = entries
353                            .iter()
354                            .take(to_remove)
355                            .map(|(key, _)| key.clone())
356                            .collect();
357
358                        for key in keys_to_remove {
359                            cache.remove(&key);
360                        }
361                    }
362
363                    cache.insert(cache_key, (negotiation_result.clone(), completed_at));
364                }
365
366                info!(
367                    "Certificate type negotiation completed successfully: {:?} -> client={}, server={}",
368                    id, negotiation_result.client_cert_type, negotiation_result.server_cert_type
369                );
370
371                Ok(negotiation_result)
372            }
373            Err(error) => {
374                // Update session state
375                sessions.insert(
376                    id,
377                    NegotiationState::Failed {
378                        error: error.to_string(),
379                        failed_at: Instant::now(),
380                    },
381                );
382
383                // Update statistics
384                let mut stats = self.stats.lock().unwrap();
385                stats.failed += 1;
386
387                warn!("Certificate type negotiation failed: {:?} -> {}", id, error);
388                Err(error)
389            }
390        }
391    }
392
393    /// Fail a negotiation with an error
394    pub fn fail_negotiation(&self, id: NegotiationId, error: String) {
395        let mut sessions = self.sessions.write().unwrap();
396        sessions.insert(
397            id,
398            NegotiationState::Failed {
399                error,
400                failed_at: Instant::now(),
401            },
402        );
403
404        let mut stats = self.stats.lock().unwrap();
405        stats.failed += 1;
406
407        warn!("Certificate type negotiation failed: {:?}", id);
408    }
409
410    /// Get the current state of a negotiation
411    pub fn get_negotiation_state(&self, id: NegotiationId) -> Option<NegotiationState> {
412        let sessions = self.sessions.read().unwrap();
413        sessions.get(&id).cloned()
414    }
415
416    /// Check for and handle timed out negotiations
417    pub fn handle_timeouts(&self) {
418        let mut sessions = self.sessions.write().unwrap();
419        let mut timed_out_ids = Vec::new();
420
421        for (id, state) in sessions.iter() {
422            if let NegotiationState::Waiting { sent_at, .. } = state {
423                if sent_at.elapsed() > self.config.timeout {
424                    timed_out_ids.push(*id);
425                }
426            }
427        }
428
429        for id in timed_out_ids {
430            sessions.insert(
431                id,
432                NegotiationState::TimedOut {
433                    timeout_at: Instant::now(),
434                },
435            );
436
437            let mut stats = self.stats.lock().unwrap();
438            stats.timed_out += 1;
439
440            warn!("Certificate type negotiation timed out: {:?}", id);
441        }
442    }
443
444    /// Clean up completed negotiations older than the specified duration
445    pub fn cleanup_old_sessions(&self, max_age: Duration) {
446        let mut sessions = self.sessions.write().unwrap();
447        let cutoff = Instant::now() - max_age;
448
449        sessions.retain(|id, state| {
450            let should_retain = match state {
451                NegotiationState::Completed { completed_at, .. } => *completed_at > cutoff,
452                NegotiationState::Failed { failed_at, .. } => *failed_at > cutoff,
453                NegotiationState::TimedOut { timeout_at, .. } => *timeout_at > cutoff,
454                _ => true, // Keep pending and waiting sessions
455            };
456
457            if !should_retain {
458                debug!("Cleaned up old negotiation session: {:?}", id);
459            }
460
461            should_retain
462        });
463    }
464
465    /// Get current negotiation statistics
466    pub fn get_stats(&self) -> NegotiationStats {
467        self.stats.lock().unwrap().clone()
468    }
469
470    /// Clear all cached results
471    pub fn clear_cache(&self) {
472        let mut cache = self.cache.lock().unwrap();
473        cache.clear();
474        debug!("Cleared certificate type negotiation cache");
475    }
476
477    /// Get cache statistics
478    pub fn get_cache_stats(&self) -> (usize, usize) {
479        let cache = self.cache.lock().unwrap();
480        (cache.len(), self.config.max_cache_size)
481    }
482}
483
484impl Default for CertificateNegotiationManager {
485    fn default() -> Self {
486        Self::new(NegotiationConfig::default())
487    }
488}
489
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;

    #[test]
    fn test_negotiation_id_generation() {
        let id1 = NegotiationId::new();
        let id2 = NegotiationId::new();

        assert_ne!(id1, id2);
        assert!(id1.as_u64() > 0);
        assert!(id2.as_u64() > 0);
    }

    #[test]
    fn test_negotiation_state_checks() {
        let pending = NegotiationState::Pending;
        assert!(!pending.is_complete());
        assert!(!pending.is_successful());

        let completed = NegotiationState::Completed {
            result: NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509),
            completed_at: Instant::now(),
        };
        assert!(completed.is_complete());
        assert!(completed.is_successful());
        assert!(completed.get_result().is_some());

        let failed = NegotiationState::Failed {
            error: "Test error".to_string(),
            failed_at: Instant::now(),
        };
        assert!(failed.is_complete());
        assert!(!failed.is_successful());
        assert_eq!(failed.get_error().unwrap(), "Test error");
    }

    #[test]
    fn test_negotiation_manager_basic_flow() {
        let manager = CertificateNegotiationManager::default();
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        // Start negotiation
        let id = manager.start_negotiation(preferences);

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(state, NegotiationState::Waiting { .. }));

        // Complete negotiation
        let remote_types = CertificateTypeList::raw_public_key_only();
        let result = manager
            .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        assert_eq!(result.client_cert_type, CertificateType::RawPublicKey);
        assert_eq!(result.server_cert_type, CertificateType::RawPublicKey);

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(state.is_successful());
    }

    #[test]
    fn test_negotiation_caching() {
        let config = NegotiationConfig {
            enable_caching: true,
            ..Default::default()
        };
        let manager = CertificateNegotiationManager::new(config);
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        // First negotiation
        let id1 = manager.start_negotiation(preferences.clone());
        let remote_types = CertificateTypeList::raw_public_key_only();
        let result1 = manager
            .complete_negotiation(id1, Some(remote_types.clone()), Some(remote_types.clone()))
            .unwrap();

        // Second negotiation with same preferences should hit cache
        let id2 = manager.start_negotiation(preferences);
        let result2 = manager
            .complete_negotiation(id2, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        assert_eq!(result1, result2);

        let stats = manager.get_stats();
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
    }

    #[test]
    fn test_negotiation_timeout_handling() {
        let config = NegotiationConfig {
            timeout: Duration::from_millis(1),
            ..Default::default()
        };
        let manager = CertificateNegotiationManager::new(config);
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        let id = manager.start_negotiation(preferences);

        // Wait for timeout
        std::thread::sleep(Duration::from_millis(10));
        manager.handle_timeouts();

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(state, NegotiationState::TimedOut { .. }));

        let stats = manager.get_stats();
        assert_eq!(stats.timed_out, 1);
    }

    #[test]
    fn test_complete_unknown_negotiation_fails() {
        let manager = CertificateNegotiationManager::default();
        // An ID that was never registered with this manager.
        let unknown = NegotiationId::new();
        let remote_types = CertificateTypeList::raw_public_key_only();

        let outcome =
            manager.complete_negotiation(unknown, Some(remote_types.clone()), Some(remote_types));
        assert!(outcome.is_err());
        assert!(manager.get_negotiation_state(unknown).is_none());
    }

    #[test]
    fn test_complete_twice_fails() {
        let manager = CertificateNegotiationManager::default();
        let id = manager.start_negotiation(CertificateTypePreferences::prefer_raw_public_key());
        let remote_types = CertificateTypeList::raw_public_key_only();

        manager
            .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types.clone()))
            .unwrap();

        // A completed session is no longer `Waiting`, so a second completion is rejected
        // and must not disturb the recorded success.
        let second = manager.complete_negotiation(id, Some(remote_types.clone()), Some(remote_types));
        assert!(second.is_err());
        assert!(manager.get_negotiation_state(id).unwrap().is_successful());
    }

    #[test]
    fn test_fail_negotiation_records_failure() {
        let manager = CertificateNegotiationManager::default();
        let id = manager.start_negotiation(CertificateTypePreferences::prefer_raw_public_key());

        manager.fail_negotiation(id, "handshake aborted".to_string());

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(state.is_complete());
        assert_eq!(state.get_error(), Some("handshake aborted"));
        assert_eq!(manager.get_stats().failed, 1);
    }

    #[test]
    fn test_cleanup_removes_only_finished_sessions() {
        let manager = CertificateNegotiationManager::default();
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        let finished = manager.start_negotiation(preferences.clone());
        manager.fail_negotiation(finished, "done".to_string());
        let waiting = manager.start_negotiation(preferences);

        manager.cleanup_old_sessions(Duration::ZERO);

        assert!(manager.get_negotiation_state(finished).is_none());
        assert!(manager.get_negotiation_state(waiting).is_some());
    }

    #[test]
    fn test_clear_cache() {
        let manager = CertificateNegotiationManager::default();
        let id = manager.start_negotiation(CertificateTypePreferences::prefer_raw_public_key());
        let remote_types = CertificateTypeList::raw_public_key_only();
        manager
            .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        assert_eq!(manager.get_cache_stats(), (1, 1000));
        manager.clear_cache();
        assert_eq!(manager.get_cache_stats().0, 0);
    }
}