ant_quic/crypto/
tls_extensions.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7#![allow(missing_docs)]
8
//! TLS extensions for RFC 7250 Raw Public Key certificate type negotiation.
//!
//! This module implements the two certificate type extensions defined in RFC 7250, which
//! TLS 1.3 (RFC 8446) also uses:
//! - `client_certificate_type`: the client's certificate type preferences
//! - `server_certificate_type`: the server's certificate type preferences
//!
//! The extension code points are listed in [`extension_ids`]. Together these extensions let
//! clients and servers negotiate the certificate type during the TLS handshake, advertising
//! support for Raw Public Keys (certificate type 2) in addition to traditional X.509
//! certificates (certificate type 0).
18
19use std::{
20    collections::HashMap,
21    fmt::{self, Debug},
22};
23
24/// Certificate type values as defined in RFC 7250 and IANA registry
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
26#[repr(u8)]
27pub enum CertificateType {
28    /// X.509 certificate (traditional PKI certificates)
29    X509 = 0,
30    /// Raw Public Key (RFC 7250)
31    RawPublicKey = 2,
32}
33
34impl CertificateType {
35    /// Parse certificate type from wire format
36    pub fn from_u8(value: u8) -> Result<Self, TlsExtensionError> {
37        match value {
38            0 => Ok(Self::X509),
39            2 => Ok(Self::RawPublicKey),
40            _ => Err(TlsExtensionError::UnsupportedCertificateType(value)),
41        }
42    }
43
44    /// Convert certificate type to wire format
45    pub fn to_u8(self) -> u8 {
46        self as u8
47    }
48
49    /// Check if this certificate type is Raw Public Key
50    pub fn is_raw_public_key(self) -> bool {
51        matches!(self, Self::RawPublicKey)
52    }
53
54    /// Check if this certificate type is X.509
55    pub fn is_x509(self) -> bool {
56        matches!(self, Self::X509)
57    }
58}
59
60impl fmt::Display for CertificateType {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        match self {
63            Self::X509 => write!(f, "X.509"),
64            Self::RawPublicKey => write!(f, "RawPublicKey"),
65        }
66    }
67}
68
69/// Certificate type preference list for negotiation
70#[derive(Debug, Clone, PartialEq, Eq)]
71pub struct CertificateTypeList {
72    /// Ordered list of certificate types by preference (most preferred first)
73    pub types: Vec<CertificateType>,
74}
75
76impl CertificateTypeList {
77    /// Create a new certificate type list
78    pub fn new(types: Vec<CertificateType>) -> Result<Self, TlsExtensionError> {
79        if types.is_empty() {
80            return Err(TlsExtensionError::EmptyCertificateTypeList);
81        }
82        if types.len() > 255 {
83            return Err(TlsExtensionError::CertificateTypeListTooLong(types.len()));
84        }
85
86        // Check for duplicates
87        let mut seen = std::collections::HashSet::new();
88        for cert_type in &types {
89            if !seen.insert(*cert_type) {
90                return Err(TlsExtensionError::DuplicateCertificateType(*cert_type));
91            }
92        }
93
94        Ok(Self { types })
95    }
96
97    /// Create a Raw Public Key only preference list
98    pub fn raw_public_key_only() -> Self {
99        Self {
100            types: vec![CertificateType::RawPublicKey],
101        }
102    }
103
104    /// Create a preference list favoring Raw Public Keys with X.509 fallback
105    pub fn prefer_raw_public_key() -> Self {
106        Self {
107            types: vec![CertificateType::RawPublicKey, CertificateType::X509],
108        }
109    }
110
111    /// Create an X.509 only preference list
112    pub fn x509_only() -> Self {
113        Self {
114            types: vec![CertificateType::X509],
115        }
116    }
117
118    /// Get the most preferred certificate type
119    pub fn most_preferred(&self) -> CertificateType {
120        self.types[0]
121    }
122
123    /// Check if Raw Public Key is supported
124    pub fn supports_raw_public_key(&self) -> bool {
125        self.types.contains(&CertificateType::RawPublicKey)
126    }
127
128    /// Check if X.509 is supported
129    pub fn supports_x509(&self) -> bool {
130        self.types.contains(&CertificateType::X509)
131    }
132
133    /// Find the best common certificate type between two preference lists
134    pub fn negotiate(&self, other: &Self) -> Option<CertificateType> {
135        // Find the first certificate type in our preference list that is also supported by the other party
136        for cert_type in &self.types {
137            if other.types.contains(cert_type) {
138                return Some(*cert_type);
139            }
140        }
141        None
142    }
143
144    /// Serialize to wire format (length-prefixed list)
145    pub fn to_bytes(&self) -> Vec<u8> {
146        let mut bytes = Vec::with_capacity(1 + self.types.len());
147        bytes.push(self.types.len() as u8);
148        for cert_type in &self.types {
149            bytes.push(cert_type.to_u8());
150        }
151        bytes
152    }
153
154    /// Parse from wire format
155    pub fn from_bytes(bytes: &[u8]) -> Result<Self, TlsExtensionError> {
156        if bytes.is_empty() {
157            return Err(TlsExtensionError::InvalidExtensionData(
158                "Empty certificate type list".to_string(),
159            ));
160        }
161
162        let length = bytes[0] as usize;
163        if length == 0 {
164            return Err(TlsExtensionError::EmptyCertificateTypeList);
165        }
166        if length > 255 {
167            return Err(TlsExtensionError::CertificateTypeListTooLong(length));
168        }
169        if bytes.len() != 1 + length {
170            return Err(TlsExtensionError::InvalidExtensionData(format!(
171                "Certificate type list length mismatch: expected {}, got {}",
172                1 + length,
173                bytes.len()
174            )));
175        }
176
177        let mut types = Vec::with_capacity(length);
178        for i in 1..=length {
179            let cert_type = CertificateType::from_u8(bytes[i])?;
180            types.push(cert_type);
181        }
182
183        Self::new(types)
184    }
185}
186
/// TLS `ExtensionType` code points for certificate type negotiation (RFC 7250).
///
/// Values come from the IANA "TLS ExtensionType Values" registry. Code points 47 and 48, which
/// earlier revisions of this module used, are assigned to `certificate_authorities` and
/// `oid_filters` (RFC 8446). Sending certificate-type data under those IDs would collide with
/// those TLS 1.3 extensions.
pub mod extension_ids {
    /// `client_certificate_type` extension (IANA code point 19, `0x0013`).
    pub const CLIENT_CERTIFICATE_TYPE: u16 = 19;
    /// `server_certificate_type` extension (IANA code point 20, `0x0014`).
    pub const SERVER_CERTIFICATE_TYPE: u16 = 20;
}
194
195/// Errors that can occur during TLS extension processing
196#[derive(Debug, Clone)]
197pub enum TlsExtensionError {
198    /// Unsupported certificate type value
199    UnsupportedCertificateType(u8),
200    /// Empty certificate type list
201    EmptyCertificateTypeList,
202    /// Certificate type list too long (>255 entries)
203    CertificateTypeListTooLong(usize),
204    /// Duplicate certificate type in list
205    DuplicateCertificateType(CertificateType),
206    /// Invalid extension data format
207    InvalidExtensionData(String),
208    /// Certificate type negotiation failed
209    NegotiationFailed {
210        client_types: CertificateTypeList,
211        server_types: CertificateTypeList,
212    },
213    /// Extension already registered
214    ExtensionAlreadyRegistered(u16),
215    /// rustls integration error
216    RustlsError(String),
217}
218
219impl fmt::Display for TlsExtensionError {
220    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221        match self {
222            Self::UnsupportedCertificateType(value) => {
223                write!(f, "Unsupported certificate type: {value}")
224            }
225            Self::EmptyCertificateTypeList => {
226                write!(f, "Certificate type list cannot be empty")
227            }
228            Self::CertificateTypeListTooLong(len) => {
229                write!(f, "Certificate type list too long: {len} (max 255)")
230            }
231            Self::DuplicateCertificateType(cert_type) => {
232                write!(f, "Duplicate certificate type: {cert_type}")
233            }
234            Self::InvalidExtensionData(msg) => {
235                write!(f, "Invalid extension data: {msg}")
236            }
237            Self::NegotiationFailed {
238                client_types,
239                server_types,
240            } => {
241                write!(
242                    f,
243                    "Certificate type negotiation failed: client={client_types:?}, server={server_types:?}"
244                )
245            }
246            Self::ExtensionAlreadyRegistered(id) => {
247                write!(f, "Extension already registered: {id}")
248            }
249            Self::RustlsError(msg) => {
250                write!(f, "rustls error: {msg}")
251            }
252        }
253    }
254}
255
256impl std::error::Error for TlsExtensionError {}
257
258/// Certificate type negotiation result
259#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
260pub struct NegotiationResult {
261    /// Negotiated client certificate type
262    pub client_cert_type: CertificateType,
263    /// Negotiated server certificate type
264    pub server_cert_type: CertificateType,
265}
266
267impl NegotiationResult {
268    /// Create a new negotiation result
269    pub fn new(client_cert_type: CertificateType, server_cert_type: CertificateType) -> Self {
270        Self {
271            client_cert_type,
272            server_cert_type,
273        }
274    }
275
276    /// Check if Raw Public Keys are used for both client and server
277    pub fn is_raw_public_key_only(&self) -> bool {
278        self.client_cert_type.is_raw_public_key() && self.server_cert_type.is_raw_public_key()
279    }
280
281    /// Check if X.509 certificates are used for both client and server
282    pub fn is_x509_only(&self) -> bool {
283        self.client_cert_type.is_x509() && self.server_cert_type.is_x509()
284    }
285
286    /// Check if this is a mixed deployment (one RPK, one X.509)
287    pub fn is_mixed(&self) -> bool {
288        !self.is_raw_public_key_only() && !self.is_x509_only()
289    }
290}
291
292/// Certificate type negotiation preferences and state
293#[derive(Debug, Clone, PartialEq, Eq)]
294pub struct CertificateTypePreferences {
295    /// Client certificate type preferences (what types we support for client auth)
296    pub client_types: CertificateTypeList,
297    /// Server certificate type preferences (what types we support for server auth)
298    pub server_types: CertificateTypeList,
299    /// Whether to require certificate type extensions (strict mode)
300    pub require_extensions: bool,
301    /// Default fallback certificate types if negotiation fails
302    pub fallback_client: CertificateType,
303    pub fallback_server: CertificateType,
304}
305
306impl CertificateTypePreferences {
307    /// Create preferences favoring Raw Public Keys
308    pub fn prefer_raw_public_key() -> Self {
309        Self {
310            client_types: CertificateTypeList::prefer_raw_public_key(),
311            server_types: CertificateTypeList::prefer_raw_public_key(),
312            require_extensions: false,
313            fallback_client: CertificateType::X509,
314            fallback_server: CertificateType::X509,
315        }
316    }
317
318    /// Create preferences for Raw Public Key only
319    pub fn raw_public_key_only() -> Self {
320        Self {
321            client_types: CertificateTypeList::raw_public_key_only(),
322            server_types: CertificateTypeList::raw_public_key_only(),
323            require_extensions: true,
324            fallback_client: CertificateType::RawPublicKey,
325            fallback_server: CertificateType::RawPublicKey,
326        }
327    }
328
329    /// Create preferences for X.509 only (legacy mode)
330    pub fn x509_only() -> Self {
331        Self {
332            client_types: CertificateTypeList::x509_only(),
333            server_types: CertificateTypeList::x509_only(),
334            require_extensions: false,
335            fallback_client: CertificateType::X509,
336            fallback_server: CertificateType::X509,
337        }
338    }
339
340    /// Negotiate certificate types with remote peer preferences
341    pub fn negotiate(
342        &self,
343        remote_client_types: Option<&CertificateTypeList>,
344        remote_server_types: Option<&CertificateTypeList>,
345    ) -> Result<NegotiationResult, TlsExtensionError> {
346        let client_cert_type = if let Some(remote_types) = remote_client_types {
347            self.client_types.negotiate(remote_types).ok_or_else(|| {
348                TlsExtensionError::NegotiationFailed {
349                    client_types: self.client_types.clone(),
350                    server_types: remote_types.clone(),
351                }
352            })?
353        } else if self.require_extensions {
354            return Err(TlsExtensionError::NegotiationFailed {
355                client_types: self.client_types.clone(),
356                server_types: CertificateTypeList::x509_only(),
357            });
358        } else {
359            self.fallback_client
360        };
361
362        let server_cert_type = if let Some(remote_types) = remote_server_types {
363            self.server_types.negotiate(remote_types).ok_or_else(|| {
364                TlsExtensionError::NegotiationFailed {
365                    client_types: self.server_types.clone(),
366                    server_types: remote_types.clone(),
367                }
368            })?
369        } else if self.require_extensions {
370            return Err(TlsExtensionError::NegotiationFailed {
371                client_types: self.server_types.clone(),
372                server_types: CertificateTypeList::x509_only(),
373            });
374        } else {
375            self.fallback_server
376        };
377
378        Ok(NegotiationResult::new(client_cert_type, server_cert_type))
379    }
380}
381
382impl Default for CertificateTypePreferences {
383    fn default() -> Self {
384        Self::prefer_raw_public_key()
385    }
386}
387
388/// Certificate type negotiation cache for performance optimization
389#[derive(Debug)]
390pub struct NegotiationCache {
391    /// Cache of negotiation results keyed by (local_prefs, remote_prefs) hash
392    cache: HashMap<u64, NegotiationResult>,
393    /// Maximum cache size to prevent unbounded growth
394    max_size: usize,
395}
396
397impl NegotiationCache {
398    /// Create a new negotiation cache
399    pub fn new(max_size: usize) -> Self {
400        Self {
401            cache: HashMap::with_capacity(max_size.min(1000)),
402            max_size,
403        }
404    }
405
406    /// Get cached negotiation result
407    pub fn get(&self, key: u64) -> Option<&NegotiationResult> {
408        self.cache.get(&key)
409    }
410
411    /// Cache a negotiation result
412    pub fn insert(&mut self, key: u64, result: NegotiationResult) {
413        if self.cache.len() >= self.max_size {
414            // Simple eviction: remove oldest entry (first in iteration order)
415            if let Some(oldest_key) = self.cache.keys().next().copied() {
416                self.cache.remove(&oldest_key);
417            }
418        }
419        self.cache.insert(key, result);
420    }
421
422    /// Clear the cache
423    pub fn clear(&mut self) {
424        self.cache.clear();
425    }
426
427    /// Get cache statistics
428    pub fn stats(&self) -> (usize, usize) {
429        (self.cache.len(), self.max_size)
430    }
431}
432
433impl Default for NegotiationCache {
434    fn default() -> Self {
435        Self::new(1000)
436    }
437}
438
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_certificate_type_conversion() {
        assert_eq!(CertificateType::X509.to_u8(), 0);
        assert_eq!(CertificateType::RawPublicKey.to_u8(), 2);

        assert_eq!(CertificateType::from_u8(0).unwrap(), CertificateType::X509);
        assert_eq!(
            CertificateType::from_u8(2).unwrap(),
            CertificateType::RawPublicKey
        );

        assert!(CertificateType::from_u8(1).is_err());
        assert!(CertificateType::from_u8(255).is_err());
    }

    #[test]
    fn test_certificate_type_display() {
        assert_eq!(CertificateType::X509.to_string(), "X.509");
        assert_eq!(CertificateType::RawPublicKey.to_string(), "RawPublicKey");
    }

    #[test]
    fn test_certificate_type_list_creation() {
        let list =
            CertificateTypeList::new(vec![CertificateType::RawPublicKey, CertificateType::X509])
                .unwrap();
        assert_eq!(list.types.len(), 2);
        assert_eq!(list.most_preferred(), CertificateType::RawPublicKey);
        assert!(list.supports_raw_public_key());
        assert!(list.supports_x509());

        // Test empty list error
        assert!(CertificateTypeList::new(vec![]).is_err());

        // Test duplicate error
        assert!(
            CertificateTypeList::new(vec![CertificateType::X509, CertificateType::X509]).is_err()
        );
    }

    #[test]
    fn test_certificate_type_list_serialization() {
        let list = CertificateTypeList::prefer_raw_public_key();
        let bytes = list.to_bytes();
        assert_eq!(bytes, vec![2, 2, 0]); // length=2, RPK=2, X509=0

        let parsed = CertificateTypeList::from_bytes(&bytes).unwrap();
        assert_eq!(parsed, list);
    }

    #[test]
    fn test_certificate_type_list_from_bytes_rejects_malformed_input() {
        // No length byte at all.
        assert!(matches!(
            CertificateTypeList::from_bytes(&[]),
            Err(TlsExtensionError::InvalidExtensionData(_))
        ));
        // Zero-length list.
        assert!(matches!(
            CertificateTypeList::from_bytes(&[0]),
            Err(TlsExtensionError::EmptyCertificateTypeList)
        ));
        // Length prefix larger and smaller than the payload.
        assert!(matches!(
            CertificateTypeList::from_bytes(&[2, 2]),
            Err(TlsExtensionError::InvalidExtensionData(_))
        ));
        assert!(matches!(
            CertificateTypeList::from_bytes(&[1, 2, 0]),
            Err(TlsExtensionError::InvalidExtensionData(_))
        ));
        // Unknown certificate type byte.
        assert!(matches!(
            CertificateTypeList::from_bytes(&[1, 1]),
            Err(TlsExtensionError::UnsupportedCertificateType(1))
        ));
        // Duplicate entry.
        assert!(matches!(
            CertificateTypeList::from_bytes(&[2, 0, 0]),
            Err(TlsExtensionError::DuplicateCertificateType(CertificateType::X509))
        ));
    }

    #[test]
    fn test_certificate_type_list_negotiation() {
        let rpk_only = CertificateTypeList::raw_public_key_only();
        let prefer_rpk = CertificateTypeList::prefer_raw_public_key();
        let x509_only = CertificateTypeList::x509_only();

        // RPK only with prefer RPK should negotiate to RPK
        assert_eq!(
            rpk_only.negotiate(&prefer_rpk).unwrap(),
            CertificateType::RawPublicKey
        );

        // Prefer RPK with X509 only should negotiate to X509
        assert_eq!(
            prefer_rpk.negotiate(&x509_only).unwrap(),
            CertificateType::X509
        );

        // RPK only with X509 only should fail
        assert!(rpk_only.negotiate(&x509_only).is_none());
    }

    #[test]
    fn test_preferences_negotiation() {
        let rpk_prefs = CertificateTypePreferences::raw_public_key_only();
        let mixed_prefs = CertificateTypePreferences::prefer_raw_public_key();

        let result = rpk_prefs
            .negotiate(
                Some(&mixed_prefs.client_types),
                Some(&mixed_prefs.server_types),
            )
            .unwrap();

        assert_eq!(result.client_cert_type, CertificateType::RawPublicKey);
        assert_eq!(result.server_cert_type, CertificateType::RawPublicKey);
        assert!(result.is_raw_public_key_only());
    }

    #[test]
    fn test_preferences_missing_extensions() {
        // Non-strict preferences fall back when the peer sends no lists.
        let mixed = CertificateTypePreferences::prefer_raw_public_key();
        assert!(mixed.negotiate(None, None).unwrap().is_x509_only());

        // Strict preferences reject a missing list for either role.
        let strict = CertificateTypePreferences::raw_public_key_only();
        let rpk = CertificateTypeList::raw_public_key_only();
        assert!(matches!(
            strict.negotiate(Some(&rpk), None),
            Err(TlsExtensionError::NegotiationFailed { .. })
        ));
        assert!(matches!(
            strict.negotiate(None, Some(&rpk)),
            Err(TlsExtensionError::NegotiationFailed { .. })
        ));
    }

    #[test]
    fn test_negotiation_result_classification() {
        let mixed = NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509);
        assert!(mixed.is_mixed());
        assert!(!mixed.is_raw_public_key_only());
        assert!(!mixed.is_x509_only());

        let x509 = NegotiationResult::new(CertificateType::X509, CertificateType::X509);
        assert!(x509.is_x509_only());
        assert!(!x509.is_mixed());
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            TlsExtensionError::UnsupportedCertificateType(1).to_string(),
            "Unsupported certificate type: 1"
        );
        assert_eq!(
            TlsExtensionError::DuplicateCertificateType(CertificateType::X509).to_string(),
            "Duplicate certificate type: X.509"
        );
        assert_eq!(
            TlsExtensionError::EmptyCertificateTypeList.to_string(),
            "Certificate type list cannot be empty"
        );
    }

    #[test]
    fn test_negotiation_cache() {
        let mut cache = NegotiationCache::new(2);
        let result = NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509);

        assert!(cache.get(123).is_none());

        cache.insert(123, result.clone());
        assert_eq!(cache.get(123).unwrap(), &result);

        // Test that cache size is limited
        cache.insert(456, result.clone());
        assert_eq!(cache.cache.len(), 2); // Should have 2 entries

        cache.insert(789, result.clone());
        assert_eq!(cache.cache.len(), 2); // Should still have 2 entries after eviction

        // At least one of the new entries should be present
        assert!(cache.get(456).is_some() || cache.get(789).is_some());
    }
}