1#![allow(missing_docs)]
8
9use std::{
20 collections::HashMap,
21 fmt::{self, Debug},
22};
23
/// Certificate type code points carried in the TLS certificate-type extensions (RFC 7250).
///
/// The explicit discriminants are the bytes written to / read from the wire:
/// `CertificateType::to_u8` casts a variant with `as u8`, and `CertificateType::from_u8`
/// accepts only `0` and `2` (code point `1` is rejected).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[repr(u8)]
pub enum CertificateType {
    /// X.509 certificate (code point `0`).
    X509 = 0,
    /// Raw public key (code point `2`).
    RawPublicKey = 2,
}
33
34impl CertificateType {
35 pub fn from_u8(value: u8) -> Result<Self, TlsExtensionError> {
37 match value {
38 0 => Ok(Self::X509),
39 2 => Ok(Self::RawPublicKey),
40 _ => Err(TlsExtensionError::UnsupportedCertificateType(value)),
41 }
42 }
43
44 pub fn to_u8(self) -> u8 {
46 self as u8
47 }
48
49 pub fn is_raw_public_key(self) -> bool {
51 matches!(self, Self::RawPublicKey)
52 }
53
54 pub fn is_x509(self) -> bool {
56 matches!(self, Self::X509)
57 }
58}
59
60impl fmt::Display for CertificateType {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 match self {
63 Self::X509 => write!(f, "X.509"),
64 Self::RawPublicKey => write!(f, "RawPublicKey"),
65 }
66 }
67}
68
/// An ordered certificate-type preference list, most preferred first.
///
/// Lists produced by `CertificateTypeList::new`, `CertificateTypeList::from_bytes` or the
/// named constructors are non-empty, have at most 255 entries and contain no duplicates.
/// Because `types` is public, a list assigned directly can bypass those checks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CertificateTypeList {
    /// Certificate types in descending order of preference (`types[0]` is preferred).
    pub types: Vec<CertificateType>,
}
75
76impl CertificateTypeList {
77 pub fn new(types: Vec<CertificateType>) -> Result<Self, TlsExtensionError> {
79 if types.is_empty() {
80 return Err(TlsExtensionError::EmptyCertificateTypeList);
81 }
82 if types.len() > 255 {
83 return Err(TlsExtensionError::CertificateTypeListTooLong(types.len()));
84 }
85
86 let mut seen = std::collections::HashSet::new();
88 for cert_type in &types {
89 if !seen.insert(*cert_type) {
90 return Err(TlsExtensionError::DuplicateCertificateType(*cert_type));
91 }
92 }
93
94 Ok(Self { types })
95 }
96
97 pub fn raw_public_key_only() -> Self {
99 Self {
100 types: vec![CertificateType::RawPublicKey],
101 }
102 }
103
104 pub fn prefer_raw_public_key() -> Self {
106 Self {
107 types: vec![CertificateType::RawPublicKey, CertificateType::X509],
108 }
109 }
110
111 pub fn x509_only() -> Self {
113 Self {
114 types: vec![CertificateType::X509],
115 }
116 }
117
118 pub fn most_preferred(&self) -> CertificateType {
120 self.types[0]
121 }
122
123 pub fn supports_raw_public_key(&self) -> bool {
125 self.types.contains(&CertificateType::RawPublicKey)
126 }
127
128 pub fn supports_x509(&self) -> bool {
130 self.types.contains(&CertificateType::X509)
131 }
132
133 pub fn negotiate(&self, other: &Self) -> Option<CertificateType> {
135 for cert_type in &self.types {
137 if other.types.contains(cert_type) {
138 return Some(*cert_type);
139 }
140 }
141 None
142 }
143
144 pub fn to_bytes(&self) -> Vec<u8> {
146 let mut bytes = Vec::with_capacity(1 + self.types.len());
147 bytes.push(self.types.len() as u8);
148 for cert_type in &self.types {
149 bytes.push(cert_type.to_u8());
150 }
151 bytes
152 }
153
154 pub fn from_bytes(bytes: &[u8]) -> Result<Self, TlsExtensionError> {
156 if bytes.is_empty() {
157 return Err(TlsExtensionError::InvalidExtensionData(
158 "Empty certificate type list".to_string(),
159 ));
160 }
161
162 let length = bytes[0] as usize;
163 if length == 0 {
164 return Err(TlsExtensionError::EmptyCertificateTypeList);
165 }
166 if length > 255 {
167 return Err(TlsExtensionError::CertificateTypeListTooLong(length));
168 }
169 if bytes.len() != 1 + length {
170 return Err(TlsExtensionError::InvalidExtensionData(format!(
171 "Certificate type list length mismatch: expected {}, got {}",
172 1 + length,
173 bytes.len()
174 )));
175 }
176
177 let mut types = Vec::with_capacity(length);
178 for i in 1..=length {
179 let cert_type = CertificateType::from_u8(bytes[i])?;
180 types.push(cert_type);
181 }
182
183 Self::new(types)
184 }
185}
186
/// TLS `ExtensionType` code points for the certificate-type extensions.
///
/// Values come from the IANA "TLS ExtensionType Values" registry (defined by RFC 7250 and
/// listed in RFC 8446 §4.2). Code points 47 and 48 are assigned to `certificate_authorities`
/// and `oid_filters`, so they must not be used for these extensions.
pub mod extension_ids {
    /// `client_certificate_type` extension (RFC 7250).
    pub const CLIENT_CERTIFICATE_TYPE: u16 = 19;
    /// `server_certificate_type` extension (RFC 7250).
    pub const SERVER_CERTIFICATE_TYPE: u16 = 20;
}
194
/// Errors produced while building, parsing or negotiating certificate-type extensions.
#[derive(Debug, Clone)]
pub enum TlsExtensionError {
    /// A wire code point that `CertificateType::from_u8` does not recognise.
    UnsupportedCertificateType(u8),
    /// A certificate type list with no entries.
    EmptyCertificateTypeList,
    /// A list with more than 255 entries; carries the offending length.
    CertificateTypeListTooLong(usize),
    /// The same certificate type appeared more than once in a list.
    DuplicateCertificateType(CertificateType),
    /// Malformed extension bytes; the string describes the problem.
    InvalidExtensionData(String),
    /// The local and remote lists share no certificate type, or a required list was missing.
    ///
    /// NOTE(review): `CertificateTypePreferences::negotiate` puts the *local* list in
    /// `client_types` and the *remote* list in `server_types` (or `x509_only()` when the
    /// remote list was absent), for both the client- and server-certificate directions.
    NegotiationFailed {
        client_types: CertificateTypeList,
        server_types: CertificateTypeList,
    },
    /// An extension ID that was already registered. Not constructed in this file;
    /// presumably raised by extension-registration code elsewhere.
    ExtensionAlreadyRegistered(u16),
    /// An error message originating from rustls.
    RustlsError(String),
}
218
219impl fmt::Display for TlsExtensionError {
220 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
221 match self {
222 Self::UnsupportedCertificateType(value) => {
223 write!(f, "Unsupported certificate type: {value}")
224 }
225 Self::EmptyCertificateTypeList => {
226 write!(f, "Certificate type list cannot be empty")
227 }
228 Self::CertificateTypeListTooLong(len) => {
229 write!(f, "Certificate type list too long: {len} (max 255)")
230 }
231 Self::DuplicateCertificateType(cert_type) => {
232 write!(f, "Duplicate certificate type: {cert_type}")
233 }
234 Self::InvalidExtensionData(msg) => {
235 write!(f, "Invalid extension data: {msg}")
236 }
237 Self::NegotiationFailed {
238 client_types,
239 server_types,
240 } => {
241 write!(
242 f,
243 "Certificate type negotiation failed: client={client_types:?}, server={server_types:?}"
244 )
245 }
246 Self::ExtensionAlreadyRegistered(id) => {
247 write!(f, "Extension already registered: {id}")
248 }
249 Self::RustlsError(msg) => {
250 write!(f, "rustls error: {msg}")
251 }
252 }
253 }
254}
255
// Marker impl built on the derived `Debug` and the `Display` impl. `source()` keeps its
// default (`None`): no variant wraps another error value, only codes, lists and strings.
impl std::error::Error for TlsExtensionError {}
257
/// Outcome of a certificate-type negotiation, one selected type per direction.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct NegotiationResult {
    /// Type selected for the client certificate (from the client-side lists).
    pub client_cert_type: CertificateType,
    /// Type selected for the server certificate (from the server-side lists).
    pub server_cert_type: CertificateType,
}
266
267impl NegotiationResult {
268 pub fn new(client_cert_type: CertificateType, server_cert_type: CertificateType) -> Self {
270 Self {
271 client_cert_type,
272 server_cert_type,
273 }
274 }
275
276 pub fn is_raw_public_key_only(&self) -> bool {
278 self.client_cert_type.is_raw_public_key() && self.server_cert_type.is_raw_public_key()
279 }
280
281 pub fn is_x509_only(&self) -> bool {
283 self.client_cert_type.is_x509() && self.server_cert_type.is_x509()
284 }
285
286 pub fn is_mixed(&self) -> bool {
288 !self.is_raw_public_key_only() && !self.is_x509_only()
289 }
290}
291
/// Local certificate-type policy consumed by `CertificateTypePreferences::negotiate`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CertificateTypePreferences {
    /// Acceptable client-certificate types, most preferred first.
    pub client_types: CertificateTypeList,
    /// Acceptable server-certificate types, most preferred first.
    pub server_types: CertificateTypeList,
    /// When `true`, negotiation fails if the peer supplied no list for a direction,
    /// instead of using that direction's fallback.
    pub require_extensions: bool,
    /// Client-certificate type used when the peer sent no client list and
    /// `require_extensions` is `false`.
    pub fallback_client: CertificateType,
    /// Server-certificate type used when the peer sent no server list and
    /// `require_extensions` is `false`.
    pub fallback_server: CertificateType,
}
305
306impl CertificateTypePreferences {
307 pub fn prefer_raw_public_key() -> Self {
309 Self {
310 client_types: CertificateTypeList::prefer_raw_public_key(),
311 server_types: CertificateTypeList::prefer_raw_public_key(),
312 require_extensions: false,
313 fallback_client: CertificateType::X509,
314 fallback_server: CertificateType::X509,
315 }
316 }
317
318 pub fn raw_public_key_only() -> Self {
320 Self {
321 client_types: CertificateTypeList::raw_public_key_only(),
322 server_types: CertificateTypeList::raw_public_key_only(),
323 require_extensions: true,
324 fallback_client: CertificateType::RawPublicKey,
325 fallback_server: CertificateType::RawPublicKey,
326 }
327 }
328
329 pub fn x509_only() -> Self {
331 Self {
332 client_types: CertificateTypeList::x509_only(),
333 server_types: CertificateTypeList::x509_only(),
334 require_extensions: false,
335 fallback_client: CertificateType::X509,
336 fallback_server: CertificateType::X509,
337 }
338 }
339
340 pub fn negotiate(
342 &self,
343 remote_client_types: Option<&CertificateTypeList>,
344 remote_server_types: Option<&CertificateTypeList>,
345 ) -> Result<NegotiationResult, TlsExtensionError> {
346 let client_cert_type = if let Some(remote_types) = remote_client_types {
347 self.client_types.negotiate(remote_types).ok_or_else(|| {
348 TlsExtensionError::NegotiationFailed {
349 client_types: self.client_types.clone(),
350 server_types: remote_types.clone(),
351 }
352 })?
353 } else if self.require_extensions {
354 return Err(TlsExtensionError::NegotiationFailed {
355 client_types: self.client_types.clone(),
356 server_types: CertificateTypeList::x509_only(),
357 });
358 } else {
359 self.fallback_client
360 };
361
362 let server_cert_type = if let Some(remote_types) = remote_server_types {
363 self.server_types.negotiate(remote_types).ok_or_else(|| {
364 TlsExtensionError::NegotiationFailed {
365 client_types: self.server_types.clone(),
366 server_types: remote_types.clone(),
367 }
368 })?
369 } else if self.require_extensions {
370 return Err(TlsExtensionError::NegotiationFailed {
371 client_types: self.server_types.clone(),
372 server_types: CertificateTypeList::x509_only(),
373 });
374 } else {
375 self.fallback_server
376 };
377
378 Ok(NegotiationResult::new(client_cert_type, server_cert_type))
379 }
380}
381
382impl Default for CertificateTypePreferences {
383 fn default() -> Self {
384 Self::prefer_raw_public_key()
385 }
386}
387
388#[derive(Debug)]
390pub struct NegotiationCache {
391 cache: HashMap<u64, NegotiationResult>,
393 max_size: usize,
395}
396
397impl NegotiationCache {
398 pub fn new(max_size: usize) -> Self {
400 Self {
401 cache: HashMap::with_capacity(max_size.min(1000)),
402 max_size,
403 }
404 }
405
406 pub fn get(&self, key: u64) -> Option<&NegotiationResult> {
408 self.cache.get(&key)
409 }
410
411 pub fn insert(&mut self, key: u64, result: NegotiationResult) {
413 if self.cache.len() >= self.max_size {
414 if let Some(oldest_key) = self.cache.keys().next().copied() {
416 self.cache.remove(&oldest_key);
417 }
418 }
419 self.cache.insert(key, result);
420 }
421
422 pub fn clear(&mut self) {
424 self.cache.clear();
425 }
426
427 pub fn stats(&self) -> (usize, usize) {
429 (self.cache.len(), self.max_size)
430 }
431}
432
433impl Default for NegotiationCache {
434 fn default() -> Self {
435 Self::new(1000)
436 }
437}
438
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_certificate_type_conversion() {
        assert_eq!(CertificateType::X509.to_u8(), 0);
        assert_eq!(CertificateType::RawPublicKey.to_u8(), 2);

        assert_eq!(CertificateType::from_u8(0).unwrap(), CertificateType::X509);
        assert_eq!(
            CertificateType::from_u8(2).unwrap(),
            CertificateType::RawPublicKey
        );

        assert!(CertificateType::from_u8(1).is_err());
        assert!(CertificateType::from_u8(255).is_err());
    }

    #[test]
    fn test_certificate_type_display() {
        assert_eq!(CertificateType::X509.to_string(), "X.509");
        assert_eq!(CertificateType::RawPublicKey.to_string(), "RawPublicKey");
    }

    #[test]
    fn test_certificate_type_list_creation() {
        let list =
            CertificateTypeList::new(vec![CertificateType::RawPublicKey, CertificateType::X509])
                .unwrap();
        assert_eq!(list.types.len(), 2);
        assert_eq!(list.most_preferred(), CertificateType::RawPublicKey);
        assert!(list.supports_raw_public_key());
        assert!(list.supports_x509());

        assert!(matches!(
            CertificateTypeList::new(vec![]),
            Err(TlsExtensionError::EmptyCertificateTypeList)
        ));

        assert!(matches!(
            CertificateTypeList::new(vec![CertificateType::X509, CertificateType::X509]),
            Err(TlsExtensionError::DuplicateCertificateType(CertificateType::X509))
        ));
    }

    #[test]
    fn test_certificate_type_list_serialization() {
        let list = CertificateTypeList::prefer_raw_public_key();
        let bytes = list.to_bytes();
        // Count byte 2, then RawPublicKey (2) and X509 (0) in preference order.
        assert_eq!(bytes, vec![2, 2, 0]);

        let parsed = CertificateTypeList::from_bytes(&bytes).unwrap();
        assert_eq!(parsed, list);
    }

    #[test]
    fn test_certificate_type_list_from_bytes_rejects_malformed_input() {
        assert!(matches!(
            CertificateTypeList::from_bytes(&[]),
            Err(TlsExtensionError::InvalidExtensionData(_))
        ));
        assert!(matches!(
            CertificateTypeList::from_bytes(&[0]),
            Err(TlsExtensionError::EmptyCertificateTypeList)
        ));
        // Count says 2 entries but only one follows.
        assert!(matches!(
            CertificateTypeList::from_bytes(&[2, 2]),
            Err(TlsExtensionError::InvalidExtensionData(_))
        ));
        // Count says 1 entry but two follow.
        assert!(matches!(
            CertificateTypeList::from_bytes(&[1, 2, 0]),
            Err(TlsExtensionError::InvalidExtensionData(_))
        ));
        assert!(matches!(
            CertificateTypeList::from_bytes(&[1, 1]),
            Err(TlsExtensionError::UnsupportedCertificateType(1))
        ));
        assert!(matches!(
            CertificateTypeList::from_bytes(&[2, 2, 2]),
            Err(TlsExtensionError::DuplicateCertificateType(
                CertificateType::RawPublicKey
            ))
        ));
    }

    #[test]
    fn test_certificate_type_list_negotiation() {
        let rpk_only = CertificateTypeList::raw_public_key_only();
        let prefer_rpk = CertificateTypeList::prefer_raw_public_key();
        let x509_only = CertificateTypeList::x509_only();

        assert_eq!(
            rpk_only.negotiate(&prefer_rpk).unwrap(),
            CertificateType::RawPublicKey
        );

        assert_eq!(
            prefer_rpk.negotiate(&x509_only).unwrap(),
            CertificateType::X509
        );

        assert!(rpk_only.negotiate(&x509_only).is_none());
    }

    #[test]
    fn test_preferences_negotiation() {
        let rpk_prefs = CertificateTypePreferences::raw_public_key_only();
        let mixed_prefs = CertificateTypePreferences::prefer_raw_public_key();

        let result = rpk_prefs
            .negotiate(
                Some(&mixed_prefs.client_types),
                Some(&mixed_prefs.server_types),
            )
            .unwrap();

        assert_eq!(result.client_cert_type, CertificateType::RawPublicKey);
        assert_eq!(result.server_cert_type, CertificateType::RawPublicKey);
        assert!(result.is_raw_public_key_only());
    }

    #[test]
    fn test_preferences_negotiation_without_remote_extensions() {
        // Lenient preferences fall back to X.509 in both directions.
        let lenient = CertificateTypePreferences::prefer_raw_public_key();
        assert!(lenient.negotiate(None, None).unwrap().is_x509_only());

        // Strict preferences refuse to proceed without the peer's lists.
        let strict = CertificateTypePreferences::raw_public_key_only();
        assert!(matches!(
            strict.negotiate(None, None),
            Err(TlsExtensionError::NegotiationFailed { .. })
        ));
    }

    #[test]
    fn test_preferences_negotiation_no_common_type() {
        let strict = CertificateTypePreferences::raw_public_key_only();
        let x509 = CertificateTypeList::x509_only();
        assert!(matches!(
            strict.negotiate(Some(&x509), Some(&x509)),
            Err(TlsExtensionError::NegotiationFailed { .. })
        ));
    }

    #[test]
    fn test_negotiation_result_classification() {
        let mixed = NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509);
        assert!(mixed.is_mixed());
        assert!(!mixed.is_raw_public_key_only());
        assert!(!mixed.is_x509_only());
    }

    #[test]
    fn test_negotiation_cache() {
        let mut cache = NegotiationCache::new(2);
        let result = NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509);

        assert!(cache.get(123).is_none());

        cache.insert(123, result.clone());
        assert_eq!(cache.get(123).unwrap(), &result);

        cache.insert(456, result.clone());
        assert_eq!(cache.cache.len(), 2);

        // A third key in a full cache evicts exactly one older entry and keeps the new one.
        cache.insert(789, result);
        assert_eq!(cache.cache.len(), 2);
        assert!(cache.get(789).is_some());
        let survivors = [123u64, 456]
            .iter()
            .filter(|&&key| cache.get(key).is_some())
            .count();
        assert_eq!(survivors, 1);

        cache.clear();
        assert_eq!(cache.stats(), (0, 2));
    }
}