ant_quic/crypto/
certificate_negotiation.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7#![allow(missing_docs)]
8
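//! Certificate Type Negotiation Protocol Implementation
//!
//! This module manages certificate type negotiation as defined in RFC 7250. It tracks
//! per-connection negotiation state, caches negotiation results, handles timeouts, and
//! collects statistics for both the client and server certificate types of a TLS connection.
//! The type-selection logic itself is provided by `CertificateTypePreferences::negotiate`
//! in the `tls_extensions` module.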
14
15use std::{
16    collections::HashMap,
17    hash::{Hash, Hasher},
18    sync::{Arc, Mutex, RwLock},
19    time::{Duration, Instant},
20};
21
22use tracing::{Level, debug, info, span, warn};
23
24use super::tls_extensions::{
25    CertificateTypeList, CertificateTypePreferences, NegotiationResult, TlsExtensionError,
26};
27
/// Negotiation state for a single TLS connection
///
/// `Completed`, `Failed`, and `TimedOut` are the terminal states reported by
/// [`NegotiationState::is_complete`]; only `Completed` carries a negotiated result.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NegotiationState {
    /// Negotiation not yet started
    ///
    /// NOTE(review): `CertificateNegotiationManager` in this module never creates sessions in
    /// this state; `start_negotiation` registers them directly as `Waiting`.
    Pending,
    /// Extensions sent, waiting for response
    Waiting {
        /// When the negotiation was started; compared against `NegotiationConfig::timeout`
        /// by `CertificateNegotiationManager::handle_timeouts`.
        sent_at: Instant,
        /// Our local preferences, used once the peer's certificate type lists arrive.
        our_preferences: CertificateTypePreferences,
    },
    /// Negotiation completed successfully
    Completed {
        /// The negotiated client and server certificate types.
        result: NegotiationResult,
        /// When the result was recorded
        completed_at: Instant,
    },
    /// Negotiation failed
    Failed {
        /// The error message
        error: String,
        /// When the failure occurred
        failed_at: Instant,
    },
    /// Timed out waiting for response
    TimedOut {
        /// When the timeout occurred
        timeout_at: Instant,
    },
}
56
57impl NegotiationState {
58    /// Check if negotiation is complete (either succeeded or failed)
59    pub fn is_complete(&self) -> bool {
60        matches!(
61            self,
62            Self::Completed { .. } | Self::Failed { .. } | Self::TimedOut { .. }
63        )
64    }
65
66    /// Check if negotiation succeeded
67    pub fn is_successful(&self) -> bool {
68        matches!(self, Self::Completed { .. })
69    }
70
71    /// Get the negotiation result if successful
72    pub fn get_result(&self) -> Option<&NegotiationResult> {
73        match self {
74            Self::Completed { result, .. } => Some(result),
75            _ => None,
76        }
77    }
78
79    /// Get error message if failed
80    pub fn get_error(&self) -> Option<&str> {
81        match self {
82            Self::Failed { error, .. } => Some(error),
83            _ => None,
84        }
85    }
86}
87
/// Configuration for certificate type negotiation
#[derive(Debug, Clone)]
pub struct NegotiationConfig {
    /// Timeout for waiting for negotiation response
    ///
    /// Sessions still `Waiting` after this long are marked `TimedOut` by
    /// `CertificateNegotiationManager::handle_timeouts`. The manager never calls that method
    /// itself, so callers must invoke it periodically.
    pub timeout: Duration,
    /// Whether to cache negotiation results
    ///
    /// Cached entries expire after a fixed five minutes; that TTL is not configurable here.
    pub enable_caching: bool,
    /// Maximum cache size
    ///
    /// When the cache is full, the oldest entries are evicted before a new one is inserted.
    pub max_cache_size: usize,
    /// Whether to allow fallback to X.509 if RPK negotiation fails
    ///
    /// NOTE(review): not read by `CertificateNegotiationManager` in this module. Confirm
    /// whether callers consult it directly.
    pub allow_fallback: bool,
    /// Default preferences if none specified
    ///
    /// NOTE(review): likewise not read by `CertificateNegotiationManager` in this module.
    pub default_preferences: CertificateTypePreferences,
}
102
103impl Default for NegotiationConfig {
104    fn default() -> Self {
105        Self {
106            timeout: Duration::from_secs(10),
107            enable_caching: true,
108            max_cache_size: 1000,
109            allow_fallback: true,
110            default_preferences: CertificateTypePreferences::prefer_raw_public_key(),
111        }
112    }
113}
114
/// Unique identifier for a negotiation session
///
/// The wrapped integer is private, so code outside this module obtains IDs only through
/// [`NegotiationId::new`] (or `Default`, which delegates to it). Those IDs come from a
/// process-wide counter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NegotiationId(u64);
118
119impl Default for NegotiationId {
120    fn default() -> Self {
121        Self::new()
122    }
123}
124
125impl NegotiationId {
126    /// Generate a new unique negotiation ID
127    pub fn new() -> Self {
128        use std::sync::atomic::{AtomicU64, Ordering};
129        static COUNTER: AtomicU64 = AtomicU64::new(1);
130        Self(COUNTER.fetch_add(1, Ordering::Relaxed))
131    }
132
133    /// Get the raw ID value
134    pub fn as_u64(self) -> u64 {
135        self.0
136    }
137}
138
/// Cache key for negotiation results
///
/// Both fields hold hex-encoded 64-bit `DefaultHasher` digests produced by `CacheKey::new`,
/// not the preferences themselves.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct CacheKey {
    /// Digest of our certificate type preferences
    local_preferences: String,
    /// Digest of the remote client/server certificate type lists
    remote_preferences: String,
}
147
148impl CacheKey {
149    /// Create a cache key from preferences
150    fn new(
151        local: &CertificateTypePreferences,
152        remote_client: Option<&CertificateTypeList>,
153        remote_server: Option<&CertificateTypeList>,
154    ) -> Self {
155        use std::collections::hash_map::DefaultHasher;
156
157        let mut hasher = DefaultHasher::new();
158        local.hash(&mut hasher);
159        let local_hash = hasher.finish();
160
161        let mut hasher = DefaultHasher::new();
162        if let Some(types) = remote_client {
163            types.hash(&mut hasher);
164        }
165        if let Some(types) = remote_server {
166            types.hash(&mut hasher);
167        }
168        let remote_hash = hasher.finish();
169
170        Self {
171            local_preferences: format!("{local_hash:x}"),
172            remote_preferences: format!("{remote_hash:x}"),
173        }
174    }
175}
176
177/// Hash implementation for CertificateTypePreferences
178impl Hash for CertificateTypePreferences {
179    fn hash<H: Hasher>(&self, state: &mut H) {
180        self.client_types.types.hash(state);
181        self.server_types.types.hash(state);
182        self.require_extensions.hash(state);
183        self.fallback_client.hash(state);
184        self.fallback_server.hash(state);
185    }
186}
187
188/// Hash implementation for CertificateTypeList  
189impl Hash for CertificateTypeList {
190    fn hash<H: Hasher>(&self, state: &mut H) {
191        self.types.hash(state);
192    }
193}
194
/// Certificate type negotiation manager
///
/// Holds three pieces of shared state, each behind its own lock so all methods can take
/// `&self`:
/// - per-negotiation state keyed by [`NegotiationId`];
/// - an optional cache of results keyed by our preferences and the peer's type lists;
/// - aggregate [`NegotiationStats`].
pub struct CertificateNegotiationManager {
    /// Configuration for negotiation behavior
    config: NegotiationConfig,
    /// Active negotiation sessions. Finished sessions stay here until
    /// `cleanup_old_sessions` removes them.
    sessions: RwLock<HashMap<NegotiationId, NegotiationState>>,
    /// Result cache for performance optimization. Each entry stores the time it was cached,
    /// which is used for expiry and eviction.
    /// NOTE(review): wrapped in `Arc` but never shared in this module.
    cache: Arc<Mutex<HashMap<CacheKey, (NegotiationResult, Instant)>>>,
    /// Negotiation statistics
    stats: Arc<Mutex<NegotiationStats>>,
}
206
/// Statistics for certificate type negotiation
///
/// Counters are cumulative for the lifetime of the manager. A negotiation served from the
/// cache increments `cache_hits` but not `successful`.
#[derive(Debug, Default, Clone)]
pub struct NegotiationStats {
    /// Total number of negotiations attempted (incremented by `start_negotiation`)
    pub total_attempts: u64,
    /// Number of negotiations computed successfully (cache hits are not included)
    pub successful: u64,
    /// Number of failed negotiations, whether reported by `complete_negotiation` or by
    /// `fail_negotiation`
    pub failed: u64,
    /// Number of timed out negotiations (recorded by `handle_timeouts`)
    pub timed_out: u64,
    /// Number of cache hits
    pub cache_hits: u64,
    /// Number of cache misses (only counted while caching is enabled)
    pub cache_misses: u64,
    /// Average negotiation time
    ///
    /// Updated only when a negotiation is computed successfully. Cache hits and failures add
    /// no timing sample.
    pub avg_negotiation_time: Duration,
}
225
226impl CertificateNegotiationManager {
227    /// Create a new negotiation manager
228    pub fn new(config: NegotiationConfig) -> Self {
229        Self {
230            config,
231            sessions: RwLock::new(HashMap::new()),
232            cache: Arc::new(Mutex::new(HashMap::new())),
233            stats: Arc::new(Mutex::new(NegotiationStats::default())),
234        }
235    }
236
237    /// Start a new certificate type negotiation
238    pub fn start_negotiation(
239        &self,
240        preferences: CertificateTypePreferences,
241    ) -> Result<NegotiationId, TlsExtensionError> {
242        let id = NegotiationId::new();
243        let state = NegotiationState::Waiting {
244            sent_at: Instant::now(),
245            our_preferences: preferences,
246        };
247
248        let mut sessions = self.sessions.write().map_err(|e| {
249            TlsExtensionError::InvalidExtensionData(format!("Session lock poisoned: {}", e))
250        })?;
251        sessions.insert(id, state);
252
253        let mut stats = self.stats.lock().map_err(|e| {
254            TlsExtensionError::InvalidExtensionData(format!("Stats lock poisoned: {}", e))
255        })?;
256        stats.total_attempts += 1;
257
258        debug!("Started certificate type negotiation: {:?}", id);
259        Ok(id)
260    }
261
262    /// Complete a negotiation with remote preferences
263    pub fn complete_negotiation(
264        &self,
265        id: NegotiationId,
266        remote_client_types: Option<CertificateTypeList>,
267        remote_server_types: Option<CertificateTypeList>,
268    ) -> Result<NegotiationResult, TlsExtensionError> {
269        let _span = span!(Level::DEBUG, "complete_negotiation", id = id.as_u64()).entered();
270
271        let mut sessions = self.sessions.write().map_err(|e| {
272            TlsExtensionError::InvalidExtensionData(format!("Session lock poisoned: {}", e))
273        })?;
274        let state = sessions.get(&id).ok_or_else(|| {
275            TlsExtensionError::InvalidExtensionData(format!("Unknown negotiation ID: {id:?}"))
276        })?;
277
278        let our_preferences = match state {
279            NegotiationState::Waiting {
280                our_preferences, ..
281            } => our_preferences.clone(),
282            _ => {
283                return Err(TlsExtensionError::InvalidExtensionData(
284                    "Negotiation not in waiting state".to_string(),
285                ));
286            }
287        };
288
289        // Check cache first if enabled
290        if self.config.enable_caching {
291            let cache_key = CacheKey::new(
292                &our_preferences,
293                remote_client_types.as_ref(),
294                remote_server_types.as_ref(),
295            );
296
297            let mut cache = self.cache.lock().map_err(|e| {
298                TlsExtensionError::InvalidExtensionData(format!("Cache lock poisoned: {}", e))
299            })?;
300            if let Some((cached_result, cached_at)) = cache.get(&cache_key) {
301                // Check if cache entry is still valid (not expired)
302                if cached_at.elapsed() < Duration::from_secs(300) {
303                    // 5 minute cache
304                    let mut stats = self.stats.lock().map_err(|e| {
305                        TlsExtensionError::InvalidExtensionData(format!(
306                            "Stats lock poisoned: {}",
307                            e
308                        ))
309                    })?;
310                    stats.cache_hits += 1;
311
312                    // Update session state
313                    sessions.insert(
314                        id,
315                        NegotiationState::Completed {
316                            result: cached_result.clone(),
317                            completed_at: Instant::now(),
318                        },
319                    );
320
321                    debug!("Cache hit for negotiation: {:?}", id);
322                    return Ok(cached_result.clone());
323                } else {
324                    // Remove expired entry
325                    cache.remove(&cache_key);
326                }
327            }
328
329            let mut stats = self.stats.lock().map_err(|e| {
330                TlsExtensionError::InvalidExtensionData(format!("Stats lock poisoned: {}", e))
331            })?;
332            stats.cache_misses += 1;
333        }
334
335        // Perform actual negotiation
336        let negotiation_start = Instant::now();
337        let result =
338            our_preferences.negotiate(remote_client_types.as_ref(), remote_server_types.as_ref());
339
340        match result {
341            Ok(negotiation_result) => {
342                let completed_at = Instant::now();
343                let negotiation_time = negotiation_start.elapsed();
344
345                // Update session state
346                sessions.insert(
347                    id,
348                    NegotiationState::Completed {
349                        result: negotiation_result.clone(),
350                        completed_at,
351                    },
352                );
353
354                // Update statistics
355                let mut stats = self.stats.lock().map_err(|e| {
356                    TlsExtensionError::InvalidExtensionData(format!("Stats lock poisoned: {}", e))
357                })?;
358                stats.successful += 1;
359
360                // Update average negotiation time (simple moving average)
361                let total_completed = stats.successful + stats.failed;
362                stats.avg_negotiation_time = if total_completed == 1 {
363                    negotiation_time
364                } else {
365                    Duration::from_nanos(
366                        (stats.avg_negotiation_time.as_nanos() as u64 * (total_completed - 1)
367                            + negotiation_time.as_nanos() as u64)
368                            / total_completed,
369                    )
370                };
371
372                // Cache the result if caching is enabled
373                if self.config.enable_caching {
374                    let cache_key = CacheKey::new(
375                        &our_preferences,
376                        remote_client_types.as_ref(),
377                        remote_server_types.as_ref(),
378                    );
379
380                    let mut cache = self.cache.lock().map_err(|e| {
381                        TlsExtensionError::InvalidExtensionData(format!(
382                            "Cache lock poisoned: {}",
383                            e
384                        ))
385                    })?;
386
387                    // Evict old entries if cache is full
388                    if cache.len() >= self.config.max_cache_size {
389                        // Simple eviction: remove oldest entries
390                        let mut entries: Vec<_> =
391                            cache.iter().map(|(k, (_, t))| (k.clone(), *t)).collect();
392                        entries.sort_by_key(|(_, timestamp)| *timestamp);
393
394                        let to_remove = cache.len() - self.config.max_cache_size + 1;
395                        let keys_to_remove: Vec<_> = entries
396                            .iter()
397                            .take(to_remove)
398                            .map(|(key, _)| key.clone())
399                            .collect();
400
401                        for key in keys_to_remove {
402                            cache.remove(&key);
403                        }
404                    }
405
406                    cache.insert(cache_key, (negotiation_result.clone(), completed_at));
407                }
408
409                info!(
410                    "Certificate type negotiation completed successfully: {:?} -> client={}, server={}",
411                    id, negotiation_result.client_cert_type, negotiation_result.server_cert_type
412                );
413
414                Ok(negotiation_result)
415            }
416            Err(error) => {
417                // Update session state
418                sessions.insert(
419                    id,
420                    NegotiationState::Failed {
421                        error: error.to_string(),
422                        failed_at: Instant::now(),
423                    },
424                );
425
426                // Update statistics
427                let mut stats = self.stats.lock().map_err(|e| {
428                    TlsExtensionError::InvalidExtensionData(format!("Stats lock poisoned: {}", e))
429                })?;
430                stats.failed += 1;
431
432                warn!("Certificate type negotiation failed: {:?} -> {}", id, error);
433                Err(error)
434            }
435        }
436    }
437
438    /// Fail a negotiation with an error
439    #[allow(clippy::unwrap_used, clippy::expect_used)]
440    pub fn fail_negotiation(&self, id: NegotiationId, error: String) {
441        let mut sessions = self
442            .sessions
443            .write()
444            .expect("Mutex poisoning is unexpected in normal operation");
445        sessions.insert(
446            id,
447            NegotiationState::Failed {
448                error,
449                failed_at: Instant::now(),
450            },
451        );
452
453        let mut stats = self
454            .stats
455            .lock()
456            .expect("Mutex poisoning is unexpected in normal operation");
457        stats.failed += 1;
458
459        warn!("Certificate type negotiation failed: {:?}", id);
460    }
461
462    /// Get the current state of a negotiation
463    #[allow(clippy::unwrap_used, clippy::expect_used)]
464    pub fn get_negotiation_state(&self, id: NegotiationId) -> Option<NegotiationState> {
465        let sessions = self
466            .sessions
467            .read()
468            .expect("Mutex poisoning is unexpected in normal operation");
469        sessions.get(&id).cloned()
470    }
471
472    /// Check for and handle timed out negotiations
473    #[allow(clippy::unwrap_used, clippy::expect_used)]
474    pub fn handle_timeouts(&self) {
475        let mut sessions = self
476            .sessions
477            .write()
478            .expect("Mutex poisoning is unexpected in normal operation");
479        let mut timed_out_ids = Vec::new();
480
481        for (id, state) in sessions.iter() {
482            if let NegotiationState::Waiting { sent_at, .. } = state {
483                if sent_at.elapsed() > self.config.timeout {
484                    timed_out_ids.push(*id);
485                }
486            }
487        }
488
489        for id in timed_out_ids {
490            sessions.insert(
491                id,
492                NegotiationState::TimedOut {
493                    timeout_at: Instant::now(),
494                },
495            );
496
497            let mut stats = self
498                .stats
499                .lock()
500                .expect("Mutex poisoning is unexpected in normal operation");
501            stats.timed_out += 1;
502
503            warn!("Certificate type negotiation timed out: {:?}", id);
504        }
505    }
506
507    /// Clean up completed negotiations older than the specified duration
508    #[allow(clippy::unwrap_used, clippy::expect_used)]
509    pub fn cleanup_old_sessions(&self, max_age: Duration) {
510        let mut sessions = self
511            .sessions
512            .write()
513            .expect("Mutex poisoning is unexpected in normal operation");
514        let cutoff = Instant::now() - max_age;
515
516        sessions.retain(|id, state| {
517            let should_retain = match state {
518                NegotiationState::Completed { completed_at, .. } => *completed_at > cutoff,
519                NegotiationState::Failed { failed_at, .. } => *failed_at > cutoff,
520                NegotiationState::TimedOut { timeout_at, .. } => *timeout_at > cutoff,
521                _ => true, // Keep pending and waiting sessions
522            };
523
524            if !should_retain {
525                debug!("Cleaned up old negotiation session: {:?}", id);
526            }
527
528            should_retain
529        });
530    }
531
532    /// Get current negotiation statistics
533    #[allow(clippy::unwrap_used, clippy::expect_used)]
534    pub fn get_stats(&self) -> NegotiationStats {
535        self.stats
536            .lock()
537            .expect("Mutex poisoning is unexpected in normal operation")
538            .clone()
539    }
540
541    /// Clear all cached results
542    #[allow(clippy::unwrap_used, clippy::expect_used)]
543    pub fn clear_cache(&self) {
544        let mut cache = self
545            .cache
546            .lock()
547            .expect("Mutex poisoning is unexpected in normal operation");
548        cache.clear();
549        debug!("Cleared certificate type negotiation cache");
550    }
551
552    /// Get cache statistics
553    #[allow(clippy::unwrap_used, clippy::expect_used)]
554    pub fn get_cache_stats(&self) -> (usize, usize) {
555        let cache = self
556            .cache
557            .lock()
558            .expect("Mutex poisoning is unexpected in normal operation");
559        (cache.len(), self.config.max_cache_size)
560    }
561}
562
563impl Default for CertificateNegotiationManager {
564    fn default() -> Self {
565        Self::new(NegotiationConfig::default())
566    }
567}
568
#[cfg(test)]
mod tests {
    use super::super::tls_extensions::CertificateType;
    use super::*;

    #[test]
    fn test_negotiation_id_generation() {
        let id1 = NegotiationId::new();
        let id2 = NegotiationId::new();

        assert_ne!(id1, id2);
        assert!(id1.as_u64() > 0);
        assert!(id2.as_u64() > 0);
    }

    #[test]
    fn test_negotiation_state_checks() {
        let pending = NegotiationState::Pending;
        assert!(!pending.is_complete());
        assert!(!pending.is_successful());

        let completed = NegotiationState::Completed {
            result: NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509),
            completed_at: Instant::now(),
        };
        assert!(completed.is_complete());
        assert!(completed.is_successful());
        assert!(completed.get_result().is_some());

        let failed = NegotiationState::Failed {
            error: "Test error".to_string(),
            failed_at: Instant::now(),
        };
        assert!(failed.is_complete());
        assert!(!failed.is_successful());
        assert_eq!(failed.get_error().unwrap(), "Test error");
    }

    #[test]
    fn test_timed_out_state_checks() {
        let timed_out = NegotiationState::TimedOut {
            timeout_at: Instant::now(),
        };
        assert!(timed_out.is_complete());
        assert!(!timed_out.is_successful());
        assert!(timed_out.get_result().is_none());
        assert!(timed_out.get_error().is_none());
    }

    #[test]
    fn test_negotiation_manager_basic_flow() {
        let manager = CertificateNegotiationManager::default();
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        // Start negotiation
        let id = manager.start_negotiation(preferences).unwrap();

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(state, NegotiationState::Waiting { .. }));

        // Complete negotiation
        let remote_types = CertificateTypeList::raw_public_key_only();
        let result = manager
            .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        assert_eq!(result.client_cert_type, CertificateType::RawPublicKey);
        assert_eq!(result.server_cert_type, CertificateType::RawPublicKey);

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(state.is_successful());
    }

    #[test]
    fn test_complete_unknown_negotiation_fails() {
        let manager = CertificateNegotiationManager::default();
        let unknown = NegotiationId::new();
        assert!(manager.complete_negotiation(unknown, None, None).is_err());
    }

    #[test]
    fn test_complete_twice_is_rejected() {
        let manager = CertificateNegotiationManager::default();
        let id = manager
            .start_negotiation(CertificateTypePreferences::prefer_raw_public_key())
            .unwrap();
        let remote_types = CertificateTypeList::raw_public_key_only();

        manager
            .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types.clone()))
            .unwrap();
        // The session is no longer `Waiting`, so a second completion must be refused.
        assert!(
            manager
                .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types))
                .is_err()
        );
    }

    #[test]
    fn test_fail_negotiation_records_error() {
        let manager = CertificateNegotiationManager::default();
        let id = manager
            .start_negotiation(CertificateTypePreferences::prefer_raw_public_key())
            .unwrap();

        manager.fail_negotiation(id, "handshake aborted".to_string());

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(state.is_complete());
        assert_eq!(state.get_error(), Some("handshake aborted"));
        assert_eq!(manager.get_stats().failed, 1);
    }

    #[test]
    fn test_negotiation_caching() {
        let config = NegotiationConfig {
            enable_caching: true,
            ..Default::default()
        };
        let manager = CertificateNegotiationManager::new(config);
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        // First negotiation
        let id1 = manager.start_negotiation(preferences.clone()).unwrap();
        let remote_types = CertificateTypeList::raw_public_key_only();
        let result1 = manager
            .complete_negotiation(id1, Some(remote_types.clone()), Some(remote_types.clone()))
            .unwrap();

        // Second negotiation with same preferences should hit cache
        let id2 = manager.start_negotiation(preferences).unwrap();
        let result2 = manager
            .complete_negotiation(id2, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        assert_eq!(result1, result2);

        let stats = manager.get_stats();
        assert_eq!(stats.cache_hits, 1);
        assert_eq!(stats.cache_misses, 1);
    }

    #[test]
    fn test_clear_cache_empties_cache() {
        let manager = CertificateNegotiationManager::default();
        let id = manager
            .start_negotiation(CertificateTypePreferences::prefer_raw_public_key())
            .unwrap();
        let remote_types = CertificateTypeList::raw_public_key_only();
        manager
            .complete_negotiation(id, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        assert_eq!(manager.get_cache_stats(), (1, 1000));
        manager.clear_cache();
        assert_eq!(manager.get_cache_stats(), (0, 1000));
    }

    #[test]
    fn test_cleanup_old_sessions_keeps_waiting() {
        let manager = CertificateNegotiationManager::default();
        let preferences = CertificateTypePreferences::prefer_raw_public_key();
        let waiting = manager.start_negotiation(preferences.clone()).unwrap();
        let done = manager.start_negotiation(preferences).unwrap();
        let remote_types = CertificateTypeList::raw_public_key_only();
        manager
            .complete_negotiation(done, Some(remote_types.clone()), Some(remote_types))
            .unwrap();

        // With a zero max age every finished session is old enough to remove.
        manager.cleanup_old_sessions(Duration::ZERO);

        assert!(manager.get_negotiation_state(done).is_none());
        assert!(manager.get_negotiation_state(waiting).is_some());
    }

    #[test]
    fn test_negotiation_timeout_handling() {
        let config = NegotiationConfig {
            timeout: Duration::from_millis(1),
            ..Default::default()
        };
        let manager = CertificateNegotiationManager::new(config);
        let preferences = CertificateTypePreferences::prefer_raw_public_key();

        let id = manager.start_negotiation(preferences).unwrap();

        // Wait for timeout
        std::thread::sleep(Duration::from_millis(10));
        manager.handle_timeouts();

        let state = manager.get_negotiation_state(id).unwrap();
        assert!(matches!(state, NegotiationState::TimedOut { .. }));

        let stats = manager.get_stats();
        assert_eq!(stats.timed_out, 1);
    }
}