// ant_quic/crypto/tls_extensions.rs

//! TLS extensions for RFC 7250 Raw Public Key certificate type negotiation.
//!
//! This module implements the certificate type extensions defined in RFC 7250,
//! as used in the TLS 1.3 handshake:
//! - client_certificate_type (19): certificate types acceptable for the client's certificate
//! - server_certificate_type (20): certificate types acceptable for the server's certificate
//!
//! These extensions let clients and servers negotiate certificate types during the TLS
//! handshake, indicating support for Raw Public Keys (value 2) in addition to
//! traditional X.509 certificates (value 0).
10
11use std::{
12    collections::HashMap,
13    fmt::{self, Debug},
14};
15
16/// Certificate type values as defined in RFC 7250 and IANA registry
17#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
18#[repr(u8)]
19pub enum CertificateType {
20    /// X.509 certificate (traditional PKI certificates)
21    X509 = 0,
22    /// Raw Public Key (RFC 7250)
23    RawPublicKey = 2,
24}
25
26impl CertificateType {
27    /// Parse certificate type from wire format
28    pub fn from_u8(value: u8) -> Result<Self, TlsExtensionError> {
29        match value {
30            0 => Ok(Self::X509),
31            2 => Ok(Self::RawPublicKey),
32            _ => Err(TlsExtensionError::UnsupportedCertificateType(value)),
33        }
34    }
35
36    /// Convert certificate type to wire format
37    pub fn to_u8(self) -> u8 {
38        self as u8
39    }
40
41    /// Check if this certificate type is Raw Public Key
42    pub fn is_raw_public_key(self) -> bool {
43        matches!(self, Self::RawPublicKey)
44    }
45
46    /// Check if this certificate type is X.509
47    pub fn is_x509(self) -> bool {
48        matches!(self, Self::X509)
49    }
50}
51
52impl fmt::Display for CertificateType {
53    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54        match self {
55            Self::X509 => write!(f, "X.509"),
56            Self::RawPublicKey => write!(f, "RawPublicKey"),
57        }
58    }
59}
60
61/// Certificate type preference list for negotiation
62#[derive(Debug, Clone, PartialEq, Eq)]
63pub struct CertificateTypeList {
64    /// Ordered list of certificate types by preference (most preferred first)
65    pub types: Vec<CertificateType>,
66}
67
68impl CertificateTypeList {
69    /// Create a new certificate type list
70    pub fn new(types: Vec<CertificateType>) -> Result<Self, TlsExtensionError> {
71        if types.is_empty() {
72            return Err(TlsExtensionError::EmptyCertificateTypeList);
73        }
74        if types.len() > 255 {
75            return Err(TlsExtensionError::CertificateTypeListTooLong(types.len()));
76        }
77
78        // Check for duplicates
79        let mut seen = std::collections::HashSet::new();
80        for cert_type in &types {
81            if !seen.insert(*cert_type) {
82                return Err(TlsExtensionError::DuplicateCertificateType(*cert_type));
83            }
84        }
85
86        Ok(Self { types })
87    }
88
89    /// Create a Raw Public Key only preference list
90    pub fn raw_public_key_only() -> Self {
91        Self {
92            types: vec![CertificateType::RawPublicKey],
93        }
94    }
95
96    /// Create a preference list favoring Raw Public Keys with X.509 fallback
97    pub fn prefer_raw_public_key() -> Self {
98        Self {
99            types: vec![CertificateType::RawPublicKey, CertificateType::X509],
100        }
101    }
102
103    /// Create an X.509 only preference list
104    pub fn x509_only() -> Self {
105        Self {
106            types: vec![CertificateType::X509],
107        }
108    }
109
110    /// Get the most preferred certificate type
111    pub fn most_preferred(&self) -> CertificateType {
112        self.types[0]
113    }
114
115    /// Check if Raw Public Key is supported
116    pub fn supports_raw_public_key(&self) -> bool {
117        self.types.contains(&CertificateType::RawPublicKey)
118    }
119
120    /// Check if X.509 is supported
121    pub fn supports_x509(&self) -> bool {
122        self.types.contains(&CertificateType::X509)
123    }
124
125    /// Find the best common certificate type between two preference lists
126    pub fn negotiate(&self, other: &Self) -> Option<CertificateType> {
127        // Find the first certificate type in our preference list that is also supported by the other party
128        for cert_type in &self.types {
129            if other.types.contains(cert_type) {
130                return Some(*cert_type);
131            }
132        }
133        None
134    }
135
136    /// Serialize to wire format (length-prefixed list)
137    pub fn to_bytes(&self) -> Vec<u8> {
138        let mut bytes = Vec::with_capacity(1 + self.types.len());
139        bytes.push(self.types.len() as u8);
140        for cert_type in &self.types {
141            bytes.push(cert_type.to_u8());
142        }
143        bytes
144    }
145
146    /// Parse from wire format
147    pub fn from_bytes(bytes: &[u8]) -> Result<Self, TlsExtensionError> {
148        if bytes.is_empty() {
149            return Err(TlsExtensionError::InvalidExtensionData(
150                "Empty certificate type list".to_string(),
151            ));
152        }
153
154        let length = bytes[0] as usize;
155        if length == 0 {
156            return Err(TlsExtensionError::EmptyCertificateTypeList);
157        }
158        if length > 255 {
159            return Err(TlsExtensionError::CertificateTypeListTooLong(length));
160        }
161        if bytes.len() != 1 + length {
162            return Err(TlsExtensionError::InvalidExtensionData(format!(
163                "Certificate type list length mismatch: expected {}, got {}",
164                1 + length,
165                bytes.len()
166            )));
167        }
168
169        let mut types = Vec::with_capacity(length);
170        for i in 1..=length {
171            let cert_type = CertificateType::from_u8(bytes[i])?;
172            types.push(cert_type);
173        }
174
175        Self::new(types)
176    }
177}
178
/// TLS `ExtensionType` code points for certificate type negotiation (RFC 7250).
///
/// These are the values assigned in the IANA "TLS ExtensionType Values" registry.
/// Values 47 and 48 belong to `certificate_authorities` and `oid_filters` (RFC 8446)
/// and must not be used for these extensions.
pub mod extension_ids {
    /// `client_certificate_type` extension ID (IANA value 19, RFC 7250).
    pub const CLIENT_CERTIFICATE_TYPE: u16 = 19;
    /// `server_certificate_type` extension ID (IANA value 20, RFC 7250).
    pub const SERVER_CERTIFICATE_TYPE: u16 = 20;
}
186
187/// Errors that can occur during TLS extension processing
188#[derive(Debug, Clone)]
189pub enum TlsExtensionError {
190    /// Unsupported certificate type value
191    UnsupportedCertificateType(u8),
192    /// Empty certificate type list
193    EmptyCertificateTypeList,
194    /// Certificate type list too long (>255 entries)
195    CertificateTypeListTooLong(usize),
196    /// Duplicate certificate type in list
197    DuplicateCertificateType(CertificateType),
198    /// Invalid extension data format
199    InvalidExtensionData(String),
200    /// Certificate type negotiation failed
201    NegotiationFailed {
202        client_types: CertificateTypeList,
203        server_types: CertificateTypeList,
204    },
205    /// Extension already registered
206    ExtensionAlreadyRegistered(u16),
207    /// rustls integration error
208    RustlsError(String),
209}
210
211impl fmt::Display for TlsExtensionError {
212    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213        match self {
214            Self::UnsupportedCertificateType(value) => {
215                write!(f, "Unsupported certificate type: {value}")
216            }
217            Self::EmptyCertificateTypeList => {
218                write!(f, "Certificate type list cannot be empty")
219            }
220            Self::CertificateTypeListTooLong(len) => {
221                write!(f, "Certificate type list too long: {len} (max 255)")
222            }
223            Self::DuplicateCertificateType(cert_type) => {
224                write!(f, "Duplicate certificate type: {cert_type}")
225            }
226            Self::InvalidExtensionData(msg) => {
227                write!(f, "Invalid extension data: {msg}")
228            }
229            Self::NegotiationFailed {
230                client_types,
231                server_types,
232            } => {
233                write!(
234                    f,
235                    "Certificate type negotiation failed: client={client_types:?}, server={server_types:?}"
236                )
237            }
238            Self::ExtensionAlreadyRegistered(id) => {
239                write!(f, "Extension already registered: {id}")
240            }
241            Self::RustlsError(msg) => {
242                write!(f, "rustls error: {msg}")
243            }
244        }
245    }
246}
247
248impl std::error::Error for TlsExtensionError {}
249
250/// Certificate type negotiation result
251#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
252pub struct NegotiationResult {
253    /// Negotiated client certificate type
254    pub client_cert_type: CertificateType,
255    /// Negotiated server certificate type
256    pub server_cert_type: CertificateType,
257}
258
259impl NegotiationResult {
260    /// Create a new negotiation result
261    pub fn new(client_cert_type: CertificateType, server_cert_type: CertificateType) -> Self {
262        Self {
263            client_cert_type,
264            server_cert_type,
265        }
266    }
267
268    /// Check if Raw Public Keys are used for both client and server
269    pub fn is_raw_public_key_only(&self) -> bool {
270        self.client_cert_type.is_raw_public_key() && self.server_cert_type.is_raw_public_key()
271    }
272
273    /// Check if X.509 certificates are used for both client and server
274    pub fn is_x509_only(&self) -> bool {
275        self.client_cert_type.is_x509() && self.server_cert_type.is_x509()
276    }
277
278    /// Check if this is a mixed deployment (one RPK, one X.509)
279    pub fn is_mixed(&self) -> bool {
280        !self.is_raw_public_key_only() && !self.is_x509_only()
281    }
282}
283
284/// Certificate type negotiation preferences and state
285#[derive(Debug, Clone, PartialEq, Eq)]
286pub struct CertificateTypePreferences {
287    /// Client certificate type preferences (what types we support for client auth)
288    pub client_types: CertificateTypeList,
289    /// Server certificate type preferences (what types we support for server auth)
290    pub server_types: CertificateTypeList,
291    /// Whether to require certificate type extensions (strict mode)
292    pub require_extensions: bool,
293    /// Default fallback certificate types if negotiation fails
294    pub fallback_client: CertificateType,
295    pub fallback_server: CertificateType,
296}
297
298impl CertificateTypePreferences {
299    /// Create preferences favoring Raw Public Keys
300    pub fn prefer_raw_public_key() -> Self {
301        Self {
302            client_types: CertificateTypeList::prefer_raw_public_key(),
303            server_types: CertificateTypeList::prefer_raw_public_key(),
304            require_extensions: false,
305            fallback_client: CertificateType::X509,
306            fallback_server: CertificateType::X509,
307        }
308    }
309
310    /// Create preferences for Raw Public Key only
311    pub fn raw_public_key_only() -> Self {
312        Self {
313            client_types: CertificateTypeList::raw_public_key_only(),
314            server_types: CertificateTypeList::raw_public_key_only(),
315            require_extensions: true,
316            fallback_client: CertificateType::RawPublicKey,
317            fallback_server: CertificateType::RawPublicKey,
318        }
319    }
320
321    /// Create preferences for X.509 only (legacy mode)
322    pub fn x509_only() -> Self {
323        Self {
324            client_types: CertificateTypeList::x509_only(),
325            server_types: CertificateTypeList::x509_only(),
326            require_extensions: false,
327            fallback_client: CertificateType::X509,
328            fallback_server: CertificateType::X509,
329        }
330    }
331
332    /// Negotiate certificate types with remote peer preferences
333    pub fn negotiate(
334        &self,
335        remote_client_types: Option<&CertificateTypeList>,
336        remote_server_types: Option<&CertificateTypeList>,
337    ) -> Result<NegotiationResult, TlsExtensionError> {
338        let client_cert_type = if let Some(remote_types) = remote_client_types {
339            self.client_types.negotiate(remote_types).ok_or_else(|| {
340                TlsExtensionError::NegotiationFailed {
341                    client_types: self.client_types.clone(),
342                    server_types: remote_types.clone(),
343                }
344            })?
345        } else if self.require_extensions {
346            return Err(TlsExtensionError::NegotiationFailed {
347                client_types: self.client_types.clone(),
348                server_types: CertificateTypeList::x509_only(),
349            });
350        } else {
351            self.fallback_client
352        };
353
354        let server_cert_type = if let Some(remote_types) = remote_server_types {
355            self.server_types.negotiate(remote_types).ok_or_else(|| {
356                TlsExtensionError::NegotiationFailed {
357                    client_types: self.server_types.clone(),
358                    server_types: remote_types.clone(),
359                }
360            })?
361        } else if self.require_extensions {
362            return Err(TlsExtensionError::NegotiationFailed {
363                client_types: self.server_types.clone(),
364                server_types: CertificateTypeList::x509_only(),
365            });
366        } else {
367            self.fallback_server
368        };
369
370        Ok(NegotiationResult::new(client_cert_type, server_cert_type))
371    }
372}
373
374impl Default for CertificateTypePreferences {
375    fn default() -> Self {
376        Self::prefer_raw_public_key()
377    }
378}
379
380/// Certificate type negotiation cache for performance optimization
381#[derive(Debug)]
382pub struct NegotiationCache {
383    /// Cache of negotiation results keyed by (local_prefs, remote_prefs) hash
384    cache: HashMap<u64, NegotiationResult>,
385    /// Maximum cache size to prevent unbounded growth
386    max_size: usize,
387}
388
389impl NegotiationCache {
390    /// Create a new negotiation cache
391    pub fn new(max_size: usize) -> Self {
392        Self {
393            cache: HashMap::with_capacity(max_size.min(1000)),
394            max_size,
395        }
396    }
397
398    /// Get cached negotiation result
399    pub fn get(&self, key: u64) -> Option<&NegotiationResult> {
400        self.cache.get(&key)
401    }
402
403    /// Cache a negotiation result
404    pub fn insert(&mut self, key: u64, result: NegotiationResult) {
405        if self.cache.len() >= self.max_size {
406            // Simple eviction: remove oldest entry (first in iteration order)
407            if let Some(oldest_key) = self.cache.keys().next().copied() {
408                self.cache.remove(&oldest_key);
409            }
410        }
411        self.cache.insert(key, result);
412    }
413
414    /// Clear the cache
415    pub fn clear(&mut self) {
416        self.cache.clear();
417    }
418
419    /// Get cache statistics
420    pub fn stats(&self) -> (usize, usize) {
421        (self.cache.len(), self.max_size)
422    }
423}
424
425impl Default for NegotiationCache {
426    fn default() -> Self {
427        Self::new(1000)
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_certificate_type_conversion() {
        // Every supported type round-trips through its wire value.
        for (cert_type, wire) in [
            (CertificateType::X509, 0u8),
            (CertificateType::RawPublicKey, 2u8),
        ] {
            assert_eq!(cert_type.to_u8(), wire);
            assert_eq!(CertificateType::from_u8(wire).unwrap(), cert_type);
        }

        // Values outside the supported set are rejected.
        for unsupported in [1u8, 255] {
            assert!(CertificateType::from_u8(unsupported).is_err());
        }
    }

    #[test]
    fn test_certificate_type_list_creation() {
        let list =
            CertificateTypeList::new(vec![CertificateType::RawPublicKey, CertificateType::X509])
                .expect("two distinct types form a valid list");
        assert_eq!(list.types.len(), 2);
        assert_eq!(list.most_preferred(), CertificateType::RawPublicKey);
        assert!(list.supports_raw_public_key() && list.supports_x509());

        // An empty list is rejected.
        assert!(CertificateTypeList::new(Vec::new()).is_err());

        // Repeated entries are rejected.
        let duplicated = vec![CertificateType::X509; 2];
        assert!(CertificateTypeList::new(duplicated).is_err());
    }

    #[test]
    fn test_certificate_type_list_serialization() {
        let original = CertificateTypeList::prefer_raw_public_key();
        let wire = original.to_bytes();
        // count = 2, then RawPublicKey (2), then X.509 (0)
        assert_eq!(wire, [2, 2, 0]);
        assert_eq!(CertificateTypeList::from_bytes(&wire).unwrap(), original);
    }

    #[test]
    fn test_certificate_type_list_negotiation() {
        let rpk_only = CertificateTypeList::raw_public_key_only();
        let prefer_rpk = CertificateTypeList::prefer_raw_public_key();
        let x509_only = CertificateTypeList::x509_only();

        assert_eq!(
            rpk_only.negotiate(&prefer_rpk),
            Some(CertificateType::RawPublicKey)
        );
        assert_eq!(prefer_rpk.negotiate(&x509_only), Some(CertificateType::X509));
        assert_eq!(rpk_only.negotiate(&x509_only), None);
    }

    #[test]
    fn test_preferences_negotiation() {
        let strict_rpk = CertificateTypePreferences::raw_public_key_only();
        let flexible = CertificateTypePreferences::prefer_raw_public_key();

        let outcome = strict_rpk
            .negotiate(Some(&flexible.client_types), Some(&flexible.server_types))
            .expect("both sides support raw public keys");

        assert_eq!(outcome.client_cert_type, CertificateType::RawPublicKey);
        assert_eq!(outcome.server_cert_type, CertificateType::RawPublicKey);
        assert!(outcome.is_raw_public_key_only());
    }

    #[test]
    fn test_negotiation_cache() {
        let mut cache = NegotiationCache::new(2);
        let entry = NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509);

        assert!(cache.get(123).is_none());

        cache.insert(123, entry.clone());
        assert_eq!(cache.get(123), Some(&entry));

        // Filling to capacity keeps both entries.
        cache.insert(456, entry.clone());
        assert_eq!(cache.cache.len(), 2);

        // A third key forces an eviction; the size stays bounded.
        cache.insert(789, entry);
        assert_eq!(cache.cache.len(), 2);

        // At least one of the newer keys must survive eviction.
        assert!(cache.get(456).is_some() || cache.get(789).is_some());
    }
}