ant_quic/crypto/
tls_extensions.rs

1//! TLS Extensions for RFC 7250 Raw Public Keys Certificate Type Negotiation
2//!
3//! This module implements the TLS 1.3 extensions defined in RFC 7250 Section 4.2:
4//! - client_certificate_type (IANA extension type 19): certificate types the client can present
5//! - server_certificate_type (IANA extension type 20): certificate types the client can accept from the server
6//!
7//! These extensions enable proper negotiation of certificate types during TLS handshake,
8//! allowing clients and servers to indicate support for Raw Public Keys (value 2)
9//! in addition to traditional X.509 certificates (value 0).
10
11use std::{
12    collections::HashMap,
13    fmt::{self, Debug},
14};
15
16
17
18/// Certificate type values as defined in RFC 7250 and IANA registry
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
20#[repr(u8)]
21pub enum CertificateType {
22    /// X.509 certificate (traditional PKI certificates)
23    X509 = 0,
24    /// Raw Public Key (RFC 7250)
25    RawPublicKey = 2,
26}
27
28impl CertificateType {
29    /// Parse certificate type from wire format
30    pub fn from_u8(value: u8) -> Result<Self, TlsExtensionError> {
31        match value {
32            0 => Ok(CertificateType::X509),
33            2 => Ok(CertificateType::RawPublicKey),
34            _ => Err(TlsExtensionError::UnsupportedCertificateType(value)),
35        }
36    }
37
38    /// Convert certificate type to wire format
39    pub fn to_u8(self) -> u8 {
40        self as u8
41    }
42
43    /// Check if this certificate type is Raw Public Key
44    pub fn is_raw_public_key(self) -> bool {
45        matches!(self, CertificateType::RawPublicKey)
46    }
47
48    /// Check if this certificate type is X.509
49    pub fn is_x509(self) -> bool {
50        matches!(self, CertificateType::X509)
51    }
52}
53
54impl fmt::Display for CertificateType {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        match self {
57            CertificateType::X509 => write!(f, "X.509"),
58            CertificateType::RawPublicKey => write!(f, "RawPublicKey"),
59        }
60    }
61}
62
63/// Certificate type preference list for negotiation
64#[derive(Debug, Clone, PartialEq, Eq)]
65pub struct CertificateTypeList {
66    /// Ordered list of certificate types by preference (most preferred first)
67    pub types: Vec<CertificateType>,
68}
69
70impl CertificateTypeList {
71    /// Create a new certificate type list
72    pub fn new(types: Vec<CertificateType>) -> Result<Self, TlsExtensionError> {
73        if types.is_empty() {
74            return Err(TlsExtensionError::EmptyCertificateTypeList);
75        }
76        if types.len() > 255 {
77            return Err(TlsExtensionError::CertificateTypeListTooLong(types.len()));
78        }
79        
80        // Check for duplicates
81        let mut seen = std::collections::HashSet::new();
82        for cert_type in &types {
83            if !seen.insert(*cert_type) {
84                return Err(TlsExtensionError::DuplicateCertificateType(*cert_type));
85            }
86        }
87        
88        Ok(CertificateTypeList { types })
89    }
90
91    /// Create a Raw Public Key only preference list
92    pub fn raw_public_key_only() -> Self {
93        CertificateTypeList {
94            types: vec![CertificateType::RawPublicKey],
95        }
96    }
97
98    /// Create a preference list favoring Raw Public Keys with X.509 fallback
99    pub fn prefer_raw_public_key() -> Self {
100        CertificateTypeList {
101            types: vec![CertificateType::RawPublicKey, CertificateType::X509],
102        }
103    }
104
105    /// Create an X.509 only preference list
106    pub fn x509_only() -> Self {
107        CertificateTypeList {
108            types: vec![CertificateType::X509],
109        }
110    }
111
112    /// Get the most preferred certificate type
113    pub fn most_preferred(&self) -> CertificateType {
114        self.types[0]
115    }
116
117    /// Check if Raw Public Key is supported
118    pub fn supports_raw_public_key(&self) -> bool {
119        self.types.contains(&CertificateType::RawPublicKey)
120    }
121
122    /// Check if X.509 is supported
123    pub fn supports_x509(&self) -> bool {
124        self.types.contains(&CertificateType::X509)
125    }
126
127    /// Find the best common certificate type between two preference lists
128    pub fn negotiate(&self, other: &CertificateTypeList) -> Option<CertificateType> {
129        // Find the first certificate type in our preference list that is also supported by the other party
130        for cert_type in &self.types {
131            if other.types.contains(cert_type) {
132                return Some(*cert_type);
133            }
134        }
135        None
136    }
137
138    /// Serialize to wire format (length-prefixed list)
139    pub fn to_bytes(&self) -> Vec<u8> {
140        let mut bytes = Vec::with_capacity(1 + self.types.len());
141        bytes.push(self.types.len() as u8);
142        for cert_type in &self.types {
143            bytes.push(cert_type.to_u8());
144        }
145        bytes
146    }
147
148    /// Parse from wire format
149    pub fn from_bytes(bytes: &[u8]) -> Result<Self, TlsExtensionError> {
150        if bytes.is_empty() {
151            return Err(TlsExtensionError::InvalidExtensionData("Empty certificate type list".to_string()));
152        }
153
154        let length = bytes[0] as usize;
155        if length == 0 {
156            return Err(TlsExtensionError::EmptyCertificateTypeList);
157        }
158        if length > 255 {
159            return Err(TlsExtensionError::CertificateTypeListTooLong(length));
160        }
161        if bytes.len() != 1 + length {
162            return Err(TlsExtensionError::InvalidExtensionData(
163                format!("Certificate type list length mismatch: expected {}, got {}", 1 + length, bytes.len())
164            ));
165        }
166
167        let mut types = Vec::with_capacity(length);
168        for i in 1..=length {
169            let cert_type = CertificateType::from_u8(bytes[i])?;
170            types.push(cert_type);
171        }
172
173        CertificateTypeList::new(types)
174    }
175}
176
/// TLS extension type codes for certificate type negotiation (RFC 7250).
///
/// The values are the IANA "TLS ExtensionType Values" assignments. Note that
/// 47 and 48 are *not* these extensions: RFC 8446 assigns them to
/// `certificate_authorities` and `oid_filters` respectively.
pub mod extension_ids {
    /// `client_certificate_type` extension (IANA value 19).
    pub const CLIENT_CERTIFICATE_TYPE: u16 = 19;
    /// `server_certificate_type` extension (IANA value 20).
    pub const SERVER_CERTIFICATE_TYPE: u16 = 20;
}
184
185/// Errors that can occur during TLS extension processing
186#[derive(Debug, Clone)]
187pub enum TlsExtensionError {
188    /// Unsupported certificate type value
189    UnsupportedCertificateType(u8),
190    /// Empty certificate type list
191    EmptyCertificateTypeList,
192    /// Certificate type list too long (>255 entries)
193    CertificateTypeListTooLong(usize),
194    /// Duplicate certificate type in list
195    DuplicateCertificateType(CertificateType),
196    /// Invalid extension data format
197    InvalidExtensionData(String),
198    /// Certificate type negotiation failed
199    NegotiationFailed {
200        client_types: CertificateTypeList,
201        server_types: CertificateTypeList,
202    },
203    /// Extension already registered
204    ExtensionAlreadyRegistered(u16),
205    /// rustls integration error
206    RustlsError(String),
207}
208
209impl fmt::Display for TlsExtensionError {
210    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211        match self {
212            TlsExtensionError::UnsupportedCertificateType(value) => {
213                write!(f, "Unsupported certificate type: {}", value)
214            }
215            TlsExtensionError::EmptyCertificateTypeList => {
216                write!(f, "Certificate type list cannot be empty")
217            }
218            TlsExtensionError::CertificateTypeListTooLong(len) => {
219                write!(f, "Certificate type list too long: {} (max 255)", len)
220            }
221            TlsExtensionError::DuplicateCertificateType(cert_type) => {
222                write!(f, "Duplicate certificate type: {}", cert_type)
223            }
224            TlsExtensionError::InvalidExtensionData(msg) => {
225                write!(f, "Invalid extension data: {}", msg)
226            }
227            TlsExtensionError::NegotiationFailed { client_types, server_types } => {
228                write!(f, "Certificate type negotiation failed: client={:?}, server={:?}", 
229                       client_types, server_types)
230            }
231            TlsExtensionError::ExtensionAlreadyRegistered(id) => {
232                write!(f, "Extension already registered: {}", id)
233            }
234            TlsExtensionError::RustlsError(msg) => {
235                write!(f, "rustls error: {}", msg)
236            }
237        }
238    }
239}
240
241impl std::error::Error for TlsExtensionError {}
242
243/// Certificate type negotiation result
244#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
245pub struct NegotiationResult {
246    /// Negotiated client certificate type
247    pub client_cert_type: CertificateType,
248    /// Negotiated server certificate type
249    pub server_cert_type: CertificateType,
250}
251
252impl NegotiationResult {
253    /// Create a new negotiation result
254    pub fn new(client_cert_type: CertificateType, server_cert_type: CertificateType) -> Self {
255        Self {
256            client_cert_type,
257            server_cert_type,
258        }
259    }
260
261    /// Check if Raw Public Keys are used for both client and server
262    pub fn is_raw_public_key_only(&self) -> bool {
263        self.client_cert_type.is_raw_public_key() && self.server_cert_type.is_raw_public_key()
264    }
265
266    /// Check if X.509 certificates are used for both client and server
267    pub fn is_x509_only(&self) -> bool {
268        self.client_cert_type.is_x509() && self.server_cert_type.is_x509()
269    }
270
271    /// Check if this is a mixed deployment (one RPK, one X.509)
272    pub fn is_mixed(&self) -> bool {
273        !self.is_raw_public_key_only() && !self.is_x509_only()
274    }
275}
276
277/// Certificate type negotiation preferences and state
278#[derive(Debug, Clone, PartialEq, Eq)]
279pub struct CertificateTypePreferences {
280    /// Client certificate type preferences (what types we support for client auth)
281    pub client_types: CertificateTypeList,
282    /// Server certificate type preferences (what types we support for server auth)
283    pub server_types: CertificateTypeList,
284    /// Whether to require certificate type extensions (strict mode)
285    pub require_extensions: bool,
286    /// Default fallback certificate types if negotiation fails
287    pub fallback_client: CertificateType,
288    pub fallback_server: CertificateType,
289}
290
291impl CertificateTypePreferences {
292    /// Create preferences favoring Raw Public Keys
293    pub fn prefer_raw_public_key() -> Self {
294        Self {
295            client_types: CertificateTypeList::prefer_raw_public_key(),
296            server_types: CertificateTypeList::prefer_raw_public_key(),
297            require_extensions: false,
298            fallback_client: CertificateType::X509,
299            fallback_server: CertificateType::X509,
300        }
301    }
302
303    /// Create preferences for Raw Public Key only
304    pub fn raw_public_key_only() -> Self {
305        Self {
306            client_types: CertificateTypeList::raw_public_key_only(),
307            server_types: CertificateTypeList::raw_public_key_only(),
308            require_extensions: true,
309            fallback_client: CertificateType::RawPublicKey,
310            fallback_server: CertificateType::RawPublicKey,
311        }
312    }
313
314    /// Create preferences for X.509 only (legacy mode)
315    pub fn x509_only() -> Self {
316        Self {
317            client_types: CertificateTypeList::x509_only(),
318            server_types: CertificateTypeList::x509_only(),
319            require_extensions: false,
320            fallback_client: CertificateType::X509,
321            fallback_server: CertificateType::X509,
322        }
323    }
324
325    /// Negotiate certificate types with remote peer preferences
326    pub fn negotiate(
327        &self,
328        remote_client_types: Option<&CertificateTypeList>,
329        remote_server_types: Option<&CertificateTypeList>,
330    ) -> Result<NegotiationResult, TlsExtensionError> {
331        let client_cert_type = if let Some(remote_types) = remote_client_types {
332            self.client_types.negotiate(remote_types)
333                .ok_or_else(|| TlsExtensionError::NegotiationFailed {
334                    client_types: self.client_types.clone(),
335                    server_types: remote_types.clone(),
336                })?
337        } else if self.require_extensions {
338            return Err(TlsExtensionError::NegotiationFailed {
339                client_types: self.client_types.clone(),
340                server_types: CertificateTypeList::x509_only(),
341            });
342        } else {
343            self.fallback_client
344        };
345
346        let server_cert_type = if let Some(remote_types) = remote_server_types {
347            self.server_types.negotiate(remote_types)
348                .ok_or_else(|| TlsExtensionError::NegotiationFailed {
349                    client_types: self.server_types.clone(),
350                    server_types: remote_types.clone(),
351                })?
352        } else if self.require_extensions {
353            return Err(TlsExtensionError::NegotiationFailed {
354                client_types: self.server_types.clone(),
355                server_types: CertificateTypeList::x509_only(),
356            });
357        } else {
358            self.fallback_server
359        };
360
361        Ok(NegotiationResult::new(client_cert_type, server_cert_type))
362    }
363}
364
365impl Default for CertificateTypePreferences {
366    fn default() -> Self {
367        Self::prefer_raw_public_key()
368    }
369}
370
371/// Certificate type negotiation cache for performance optimization
372#[derive(Debug)]
373pub struct NegotiationCache {
374    /// Cache of negotiation results keyed by (local_prefs, remote_prefs) hash
375    cache: HashMap<u64, NegotiationResult>,
376    /// Maximum cache size to prevent unbounded growth
377    max_size: usize,
378}
379
380impl NegotiationCache {
381    /// Create a new negotiation cache
382    pub fn new(max_size: usize) -> Self {
383        Self {
384            cache: HashMap::with_capacity(max_size.min(1000)),
385            max_size,
386        }
387    }
388
389    /// Get cached negotiation result
390    pub fn get(&self, key: u64) -> Option<&NegotiationResult> {
391        self.cache.get(&key)
392    }
393
394    /// Cache a negotiation result
395    pub fn insert(&mut self, key: u64, result: NegotiationResult) {
396        if self.cache.len() >= self.max_size {
397            // Simple eviction: remove oldest entry (first in iteration order)
398            if let Some(oldest_key) = self.cache.keys().next().copied() {
399                self.cache.remove(&oldest_key);
400            }
401        }
402        self.cache.insert(key, result);
403    }
404
405    /// Clear the cache
406    pub fn clear(&mut self) {
407        self.cache.clear();
408    }
409
410    /// Get cache statistics
411    pub fn stats(&self) -> (usize, usize) {
412        (self.cache.len(), self.max_size)
413    }
414}
415
416impl Default for NegotiationCache {
417    fn default() -> Self {
418        Self::new(1000)
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    /// Known wire values round-trip; unknown bytes are rejected.
    #[test]
    fn test_certificate_type_conversion() {
        let known = [
            (CertificateType::X509, 0u8),
            (CertificateType::RawPublicKey, 2u8),
        ];
        for &(variant, wire) in known.iter() {
            assert_eq!(variant.to_u8(), wire);
            assert_eq!(CertificateType::from_u8(wire).unwrap(), variant);
        }

        for &unknown in [1u8, 255u8].iter() {
            assert!(CertificateType::from_u8(unknown).is_err());
        }
    }

    /// `new` keeps the preference order and rejects empty or duplicate input.
    #[test]
    fn test_certificate_type_list_creation() {
        let wanted = vec![CertificateType::RawPublicKey, CertificateType::X509];
        let list = CertificateTypeList::new(wanted).unwrap();
        assert_eq!(list.types.len(), 2);
        assert_eq!(list.most_preferred(), CertificateType::RawPublicKey);
        assert!(list.supports_raw_public_key());
        assert!(list.supports_x509());

        // An empty list is rejected.
        let empty = CertificateTypeList::new(Vec::new());
        assert!(empty.is_err());

        // A repeated entry is rejected.
        let repeated = CertificateTypeList::new(vec![CertificateType::X509; 2]);
        assert!(repeated.is_err());
    }

    /// The wire form is `[count, types...]` and parses back to the same list.
    #[test]
    fn test_certificate_type_list_serialization() {
        let original = CertificateTypeList::prefer_raw_public_key();
        let wire = original.to_bytes();
        assert_eq!(wire, vec![2, 2, 0]); // count=2, RPK=2, X509=0

        let round_tripped = CertificateTypeList::from_bytes(&wire).unwrap();
        assert_eq!(round_tripped, original);
    }

    /// List-level negotiation picks the first common type, or none.
    #[test]
    fn test_certificate_type_list_negotiation() {
        let rpk_only = CertificateTypeList::raw_public_key_only();
        let prefer_rpk = CertificateTypeList::prefer_raw_public_key();
        let x509_only = CertificateTypeList::x509_only();

        // RPK-only against prefer-RPK settles on RPK.
        assert_eq!(
            rpk_only.negotiate(&prefer_rpk).unwrap(),
            CertificateType::RawPublicKey
        );

        // Prefer-RPK against X.509-only falls back to X.509.
        assert_eq!(
            prefer_rpk.negotiate(&x509_only).unwrap(),
            CertificateType::X509
        );

        // RPK-only against X.509-only has nothing in common.
        assert!(rpk_only.negotiate(&x509_only).is_none());
    }

    /// Strict RPK preferences against a flexible peer yield RPK both ways.
    #[test]
    fn test_preferences_negotiation() {
        let strict = CertificateTypePreferences::raw_public_key_only();
        let flexible = CertificateTypePreferences::prefer_raw_public_key();

        let outcome = strict
            .negotiate(Some(&flexible.client_types), Some(&flexible.server_types))
            .unwrap();

        assert_eq!(outcome.client_cert_type, CertificateType::RawPublicKey);
        assert_eq!(outcome.server_cert_type, CertificateType::RawPublicKey);
        assert!(outcome.is_raw_public_key_only());
    }

    /// The cache returns stored values and never exceeds its size limit.
    #[test]
    fn test_negotiation_cache() {
        let mut cache = NegotiationCache::new(2);
        let sample = NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509);

        assert!(cache.get(123).is_none());

        cache.insert(123, sample.clone());
        assert_eq!(cache.get(123).unwrap(), &sample);

        // A second entry still fits within the limit.
        cache.insert(456, sample.clone());
        assert_eq!(cache.stats().0, 2);

        // A third entry forces an eviction, so the size stays at the limit.
        cache.insert(789, sample.clone());
        assert_eq!(cache.stats().0, 2);

        // At least one of the newer entries must have survived.
        assert!(cache.get(456).is_some() || cache.get(789).is_some());
    }
}