ant_quic/crypto/
tls_extensions.rs

1//! TLS Extensions for RFC 7250 Raw Public Keys Certificate Type Negotiation
2//!
3//! This module implements the TLS 1.3 extensions defined in RFC 7250 Section 4.2:
//! - client_certificate_type (19): certificate types usable for the client's certificate
//! - server_certificate_type (20): certificate types usable for the server's certificate
6//!
7//! These extensions enable proper negotiation of certificate types during TLS handshake,
8//! allowing clients and servers to indicate support for Raw Public Keys (value 2)
9//! in addition to traditional X.509 certificates (value 0).
10
11use std::{
12    collections::HashMap,
13    fmt::{self, Debug},
14};
15
16/// Certificate type values as defined in RFC 7250 and IANA registry
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
18#[repr(u8)]
19pub enum CertificateType {
20    /// X.509 certificate (traditional PKI certificates)
21    X509 = 0,
22    /// Raw Public Key (RFC 7250)
23    RawPublicKey = 2,
24}
25
26impl CertificateType {
27    /// Parse certificate type from wire format
28    pub fn from_u8(value: u8) -> Result<Self, TlsExtensionError> {
29        match value {
30            0 => Ok(CertificateType::X509),
31            2 => Ok(CertificateType::RawPublicKey),
32            _ => Err(TlsExtensionError::UnsupportedCertificateType(value)),
33        }
34    }
35
36    /// Convert certificate type to wire format
37    pub fn to_u8(self) -> u8 {
38        self as u8
39    }
40
41    /// Check if this certificate type is Raw Public Key
42    pub fn is_raw_public_key(self) -> bool {
43        matches!(self, CertificateType::RawPublicKey)
44    }
45
46    /// Check if this certificate type is X.509
47    pub fn is_x509(self) -> bool {
48        matches!(self, CertificateType::X509)
49    }
50}
51
52impl fmt::Display for CertificateType {
53    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54        match self {
55            CertificateType::X509 => write!(f, "X.509"),
56            CertificateType::RawPublicKey => write!(f, "RawPublicKey"),
57        }
58    }
59}
60
61/// Certificate type preference list for negotiation
62#[derive(Debug, Clone, PartialEq, Eq)]
63pub struct CertificateTypeList {
64    /// Ordered list of certificate types by preference (most preferred first)
65    pub types: Vec<CertificateType>,
66}
67
68impl CertificateTypeList {
69    /// Create a new certificate type list
70    pub fn new(types: Vec<CertificateType>) -> Result<Self, TlsExtensionError> {
71        if types.is_empty() {
72            return Err(TlsExtensionError::EmptyCertificateTypeList);
73        }
74        if types.len() > 255 {
75            return Err(TlsExtensionError::CertificateTypeListTooLong(types.len()));
76        }
77
78        // Check for duplicates
79        let mut seen = std::collections::HashSet::new();
80        for cert_type in &types {
81            if !seen.insert(*cert_type) {
82                return Err(TlsExtensionError::DuplicateCertificateType(*cert_type));
83            }
84        }
85
86        Ok(CertificateTypeList { types })
87    }
88
89    /// Create a Raw Public Key only preference list
90    pub fn raw_public_key_only() -> Self {
91        CertificateTypeList {
92            types: vec![CertificateType::RawPublicKey],
93        }
94    }
95
96    /// Create a preference list favoring Raw Public Keys with X.509 fallback
97    pub fn prefer_raw_public_key() -> Self {
98        CertificateTypeList {
99            types: vec![CertificateType::RawPublicKey, CertificateType::X509],
100        }
101    }
102
103    /// Create an X.509 only preference list
104    pub fn x509_only() -> Self {
105        CertificateTypeList {
106            types: vec![CertificateType::X509],
107        }
108    }
109
110    /// Get the most preferred certificate type
111    pub fn most_preferred(&self) -> CertificateType {
112        self.types[0]
113    }
114
115    /// Check if Raw Public Key is supported
116    pub fn supports_raw_public_key(&self) -> bool {
117        self.types.contains(&CertificateType::RawPublicKey)
118    }
119
120    /// Check if X.509 is supported
121    pub fn supports_x509(&self) -> bool {
122        self.types.contains(&CertificateType::X509)
123    }
124
125    /// Find the best common certificate type between two preference lists
126    pub fn negotiate(&self, other: &CertificateTypeList) -> Option<CertificateType> {
127        // Find the first certificate type in our preference list that is also supported by the other party
128        for cert_type in &self.types {
129            if other.types.contains(cert_type) {
130                return Some(*cert_type);
131            }
132        }
133        None
134    }
135
136    /// Serialize to wire format (length-prefixed list)
137    pub fn to_bytes(&self) -> Vec<u8> {
138        let mut bytes = Vec::with_capacity(1 + self.types.len());
139        bytes.push(self.types.len() as u8);
140        for cert_type in &self.types {
141            bytes.push(cert_type.to_u8());
142        }
143        bytes
144    }
145
146    /// Parse from wire format
147    pub fn from_bytes(bytes: &[u8]) -> Result<Self, TlsExtensionError> {
148        if bytes.is_empty() {
149            return Err(TlsExtensionError::InvalidExtensionData(
150                "Empty certificate type list".to_string(),
151            ));
152        }
153
154        let length = bytes[0] as usize;
155        if length == 0 {
156            return Err(TlsExtensionError::EmptyCertificateTypeList);
157        }
158        if length > 255 {
159            return Err(TlsExtensionError::CertificateTypeListTooLong(length));
160        }
161        if bytes.len() != 1 + length {
162            return Err(TlsExtensionError::InvalidExtensionData(format!(
163                "Certificate type list length mismatch: expected {}, got {}",
164                1 + length,
165                bytes.len()
166            )));
167        }
168
169        let mut types = Vec::with_capacity(length);
170        for i in 1..=length {
171            let cert_type = CertificateType::from_u8(bytes[i])?;
172            types.push(cert_type);
173        }
174
175        CertificateTypeList::new(types)
176    }
177}
178
/// TLS `ExtensionType` code points for certificate type negotiation (RFC 7250).
///
/// Values follow the IANA "TLS ExtensionType Values" registry (also listed in RFC 8446
/// Section 4.2). Note that 47 and 48 belong to `certificate_authorities` and
/// `oid_filters`, not to these extensions.
pub mod extension_ids {
    /// `client_certificate_type` extension ID (IANA TLS ExtensionType 19).
    pub const CLIENT_CERTIFICATE_TYPE: u16 = 19;
    /// `server_certificate_type` extension ID (IANA TLS ExtensionType 20).
    pub const SERVER_CERTIFICATE_TYPE: u16 = 20;
}
186
187/// Errors that can occur during TLS extension processing
188#[derive(Debug, Clone)]
189pub enum TlsExtensionError {
190    /// Unsupported certificate type value
191    UnsupportedCertificateType(u8),
192    /// Empty certificate type list
193    EmptyCertificateTypeList,
194    /// Certificate type list too long (>255 entries)
195    CertificateTypeListTooLong(usize),
196    /// Duplicate certificate type in list
197    DuplicateCertificateType(CertificateType),
198    /// Invalid extension data format
199    InvalidExtensionData(String),
200    /// Certificate type negotiation failed
201    NegotiationFailed {
202        client_types: CertificateTypeList,
203        server_types: CertificateTypeList,
204    },
205    /// Extension already registered
206    ExtensionAlreadyRegistered(u16),
207    /// rustls integration error
208    RustlsError(String),
209}
210
211impl fmt::Display for TlsExtensionError {
212    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213        match self {
214            TlsExtensionError::UnsupportedCertificateType(value) => {
215                write!(f, "Unsupported certificate type: {}", value)
216            }
217            TlsExtensionError::EmptyCertificateTypeList => {
218                write!(f, "Certificate type list cannot be empty")
219            }
220            TlsExtensionError::CertificateTypeListTooLong(len) => {
221                write!(f, "Certificate type list too long: {} (max 255)", len)
222            }
223            TlsExtensionError::DuplicateCertificateType(cert_type) => {
224                write!(f, "Duplicate certificate type: {}", cert_type)
225            }
226            TlsExtensionError::InvalidExtensionData(msg) => {
227                write!(f, "Invalid extension data: {}", msg)
228            }
229            TlsExtensionError::NegotiationFailed {
230                client_types,
231                server_types,
232            } => {
233                write!(
234                    f,
235                    "Certificate type negotiation failed: client={:?}, server={:?}",
236                    client_types, server_types
237                )
238            }
239            TlsExtensionError::ExtensionAlreadyRegistered(id) => {
240                write!(f, "Extension already registered: {}", id)
241            }
242            TlsExtensionError::RustlsError(msg) => {
243                write!(f, "rustls error: {}", msg)
244            }
245        }
246    }
247}
248
249impl std::error::Error for TlsExtensionError {}
250
251/// Certificate type negotiation result
252#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
253pub struct NegotiationResult {
254    /// Negotiated client certificate type
255    pub client_cert_type: CertificateType,
256    /// Negotiated server certificate type
257    pub server_cert_type: CertificateType,
258}
259
260impl NegotiationResult {
261    /// Create a new negotiation result
262    pub fn new(client_cert_type: CertificateType, server_cert_type: CertificateType) -> Self {
263        Self {
264            client_cert_type,
265            server_cert_type,
266        }
267    }
268
269    /// Check if Raw Public Keys are used for both client and server
270    pub fn is_raw_public_key_only(&self) -> bool {
271        self.client_cert_type.is_raw_public_key() && self.server_cert_type.is_raw_public_key()
272    }
273
274    /// Check if X.509 certificates are used for both client and server
275    pub fn is_x509_only(&self) -> bool {
276        self.client_cert_type.is_x509() && self.server_cert_type.is_x509()
277    }
278
279    /// Check if this is a mixed deployment (one RPK, one X.509)
280    pub fn is_mixed(&self) -> bool {
281        !self.is_raw_public_key_only() && !self.is_x509_only()
282    }
283}
284
285/// Certificate type negotiation preferences and state
286#[derive(Debug, Clone, PartialEq, Eq)]
287pub struct CertificateTypePreferences {
288    /// Client certificate type preferences (what types we support for client auth)
289    pub client_types: CertificateTypeList,
290    /// Server certificate type preferences (what types we support for server auth)
291    pub server_types: CertificateTypeList,
292    /// Whether to require certificate type extensions (strict mode)
293    pub require_extensions: bool,
294    /// Default fallback certificate types if negotiation fails
295    pub fallback_client: CertificateType,
296    pub fallback_server: CertificateType,
297}
298
299impl CertificateTypePreferences {
300    /// Create preferences favoring Raw Public Keys
301    pub fn prefer_raw_public_key() -> Self {
302        Self {
303            client_types: CertificateTypeList::prefer_raw_public_key(),
304            server_types: CertificateTypeList::prefer_raw_public_key(),
305            require_extensions: false,
306            fallback_client: CertificateType::X509,
307            fallback_server: CertificateType::X509,
308        }
309    }
310
311    /// Create preferences for Raw Public Key only
312    pub fn raw_public_key_only() -> Self {
313        Self {
314            client_types: CertificateTypeList::raw_public_key_only(),
315            server_types: CertificateTypeList::raw_public_key_only(),
316            require_extensions: true,
317            fallback_client: CertificateType::RawPublicKey,
318            fallback_server: CertificateType::RawPublicKey,
319        }
320    }
321
322    /// Create preferences for X.509 only (legacy mode)
323    pub fn x509_only() -> Self {
324        Self {
325            client_types: CertificateTypeList::x509_only(),
326            server_types: CertificateTypeList::x509_only(),
327            require_extensions: false,
328            fallback_client: CertificateType::X509,
329            fallback_server: CertificateType::X509,
330        }
331    }
332
333    /// Negotiate certificate types with remote peer preferences
334    pub fn negotiate(
335        &self,
336        remote_client_types: Option<&CertificateTypeList>,
337        remote_server_types: Option<&CertificateTypeList>,
338    ) -> Result<NegotiationResult, TlsExtensionError> {
339        let client_cert_type = if let Some(remote_types) = remote_client_types {
340            self.client_types.negotiate(remote_types).ok_or_else(|| {
341                TlsExtensionError::NegotiationFailed {
342                    client_types: self.client_types.clone(),
343                    server_types: remote_types.clone(),
344                }
345            })?
346        } else if self.require_extensions {
347            return Err(TlsExtensionError::NegotiationFailed {
348                client_types: self.client_types.clone(),
349                server_types: CertificateTypeList::x509_only(),
350            });
351        } else {
352            self.fallback_client
353        };
354
355        let server_cert_type = if let Some(remote_types) = remote_server_types {
356            self.server_types.negotiate(remote_types).ok_or_else(|| {
357                TlsExtensionError::NegotiationFailed {
358                    client_types: self.server_types.clone(),
359                    server_types: remote_types.clone(),
360                }
361            })?
362        } else if self.require_extensions {
363            return Err(TlsExtensionError::NegotiationFailed {
364                client_types: self.server_types.clone(),
365                server_types: CertificateTypeList::x509_only(),
366            });
367        } else {
368            self.fallback_server
369        };
370
371        Ok(NegotiationResult::new(client_cert_type, server_cert_type))
372    }
373}
374
375impl Default for CertificateTypePreferences {
376    fn default() -> Self {
377        Self::prefer_raw_public_key()
378    }
379}
380
381/// Certificate type negotiation cache for performance optimization
382#[derive(Debug)]
383pub struct NegotiationCache {
384    /// Cache of negotiation results keyed by (local_prefs, remote_prefs) hash
385    cache: HashMap<u64, NegotiationResult>,
386    /// Maximum cache size to prevent unbounded growth
387    max_size: usize,
388}
389
390impl NegotiationCache {
391    /// Create a new negotiation cache
392    pub fn new(max_size: usize) -> Self {
393        Self {
394            cache: HashMap::with_capacity(max_size.min(1000)),
395            max_size,
396        }
397    }
398
399    /// Get cached negotiation result
400    pub fn get(&self, key: u64) -> Option<&NegotiationResult> {
401        self.cache.get(&key)
402    }
403
404    /// Cache a negotiation result
405    pub fn insert(&mut self, key: u64, result: NegotiationResult) {
406        if self.cache.len() >= self.max_size {
407            // Simple eviction: remove oldest entry (first in iteration order)
408            if let Some(oldest_key) = self.cache.keys().next().copied() {
409                self.cache.remove(&oldest_key);
410            }
411        }
412        self.cache.insert(key, result);
413    }
414
415    /// Clear the cache
416    pub fn clear(&mut self) {
417        self.cache.clear();
418    }
419
420    /// Get cache statistics
421    pub fn stats(&self) -> (usize, usize) {
422        (self.cache.len(), self.max_size)
423    }
424}
425
426impl Default for NegotiationCache {
427    fn default() -> Self {
428        Self::new(1000)
429    }
430}
431
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_certificate_type_conversion() {
        // Supported code points round-trip in both directions.
        let known = [
            (CertificateType::X509, 0u8),
            (CertificateType::RawPublicKey, 2u8),
        ];
        for &(cert_type, wire) in known.iter() {
            assert_eq!(cert_type.to_u8(), wire);
            assert_eq!(CertificateType::from_u8(wire).unwrap(), cert_type);
        }

        // Anything else is rejected.
        for &wire in [1u8, 255u8].iter() {
            assert!(CertificateType::from_u8(wire).is_err());
        }
    }

    #[test]
    fn test_certificate_type_list_creation() {
        let list =
            CertificateTypeList::new(vec![CertificateType::RawPublicKey, CertificateType::X509])
                .expect("distinct, non-empty list is valid");
        assert_eq!(list.types.len(), 2);
        assert_eq!(list.most_preferred(), CertificateType::RawPublicKey);
        assert!(list.supports_raw_public_key() && list.supports_x509());

        // An empty list is rejected.
        assert!(CertificateTypeList::new(Vec::new()).is_err());

        // Repeating a certificate type is rejected.
        let duplicated = vec![CertificateType::X509; 2];
        assert!(CertificateTypeList::new(duplicated).is_err());
    }

    #[test]
    fn test_certificate_type_list_serialization() {
        let original = CertificateTypeList::prefer_raw_public_key();
        let encoded = original.to_bytes();
        // Length prefix 2, then RawPublicKey (2) and X.509 (0).
        assert_eq!(encoded, [2u8, 2, 0]);

        let decoded = CertificateTypeList::from_bytes(&encoded).expect("round-trip succeeds");
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_certificate_type_list_negotiation() {
        let rpk_only = CertificateTypeList::raw_public_key_only();
        let prefer_rpk = CertificateTypeList::prefer_raw_public_key();
        let x509_only = CertificateTypeList::x509_only();

        assert_eq!(
            rpk_only.negotiate(&prefer_rpk),
            Some(CertificateType::RawPublicKey)
        );
        assert_eq!(prefer_rpk.negotiate(&x509_only), Some(CertificateType::X509));
        assert_eq!(rpk_only.negotiate(&x509_only), None);
    }

    #[test]
    fn test_preferences_negotiation() {
        let local = CertificateTypePreferences::raw_public_key_only();
        let remote = CertificateTypePreferences::prefer_raw_public_key();

        let outcome = local
            .negotiate(Some(&remote.client_types), Some(&remote.server_types))
            .expect("Raw Public Key is common to both sides");

        assert_eq!(
            outcome,
            NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::RawPublicKey)
        );
        assert!(outcome.is_raw_public_key_only());
    }

    #[test]
    fn test_negotiation_cache() {
        let mut cache = NegotiationCache::new(2);
        let entry = NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509);

        assert!(cache.get(123).is_none());

        cache.insert(123, entry.clone());
        assert_eq!(cache.get(123), Some(&entry));

        // Filling up to capacity keeps every entry.
        cache.insert(456, entry.clone());
        assert_eq!(cache.stats().0, 2);

        // Going past capacity evicts, so the size stays bounded.
        cache.insert(789, entry);
        assert_eq!(cache.stats().0, 2);

        // At least one of the newer keys survives.
        assert!([456u64, 789].iter().any(|&key| cache.get(key).is_some()));
    }
}