1use std::{
12 collections::HashMap,
13 fmt::{self, Debug},
14};
15
16#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
18#[repr(u8)]
19pub enum CertificateType {
20 X509 = 0,
22 RawPublicKey = 2,
24}
25
26impl CertificateType {
27 pub fn from_u8(value: u8) -> Result<Self, TlsExtensionError> {
29 match value {
30 0 => Ok(Self::X509),
31 2 => Ok(Self::RawPublicKey),
32 _ => Err(TlsExtensionError::UnsupportedCertificateType(value)),
33 }
34 }
35
36 pub fn to_u8(self) -> u8 {
38 self as u8
39 }
40
41 pub fn is_raw_public_key(self) -> bool {
43 matches!(self, Self::RawPublicKey)
44 }
45
46 pub fn is_x509(self) -> bool {
48 matches!(self, Self::X509)
49 }
50}
51
52impl fmt::Display for CertificateType {
53 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
54 match self {
55 Self::X509 => write!(f, "X.509"),
56 Self::RawPublicKey => write!(f, "RawPublicKey"),
57 }
58 }
59}
60
61#[derive(Debug, Clone, PartialEq, Eq)]
63pub struct CertificateTypeList {
64 pub types: Vec<CertificateType>,
66}
67
68impl CertificateTypeList {
69 pub fn new(types: Vec<CertificateType>) -> Result<Self, TlsExtensionError> {
71 if types.is_empty() {
72 return Err(TlsExtensionError::EmptyCertificateTypeList);
73 }
74 if types.len() > 255 {
75 return Err(TlsExtensionError::CertificateTypeListTooLong(types.len()));
76 }
77
78 let mut seen = std::collections::HashSet::new();
80 for cert_type in &types {
81 if !seen.insert(*cert_type) {
82 return Err(TlsExtensionError::DuplicateCertificateType(*cert_type));
83 }
84 }
85
86 Ok(Self { types })
87 }
88
89 pub fn raw_public_key_only() -> Self {
91 Self {
92 types: vec![CertificateType::RawPublicKey],
93 }
94 }
95
96 pub fn prefer_raw_public_key() -> Self {
98 Self {
99 types: vec![CertificateType::RawPublicKey, CertificateType::X509],
100 }
101 }
102
103 pub fn x509_only() -> Self {
105 Self {
106 types: vec![CertificateType::X509],
107 }
108 }
109
110 pub fn most_preferred(&self) -> CertificateType {
112 self.types[0]
113 }
114
115 pub fn supports_raw_public_key(&self) -> bool {
117 self.types.contains(&CertificateType::RawPublicKey)
118 }
119
120 pub fn supports_x509(&self) -> bool {
122 self.types.contains(&CertificateType::X509)
123 }
124
125 pub fn negotiate(&self, other: &Self) -> Option<CertificateType> {
127 for cert_type in &self.types {
129 if other.types.contains(cert_type) {
130 return Some(*cert_type);
131 }
132 }
133 None
134 }
135
136 pub fn to_bytes(&self) -> Vec<u8> {
138 let mut bytes = Vec::with_capacity(1 + self.types.len());
139 bytes.push(self.types.len() as u8);
140 for cert_type in &self.types {
141 bytes.push(cert_type.to_u8());
142 }
143 bytes
144 }
145
146 pub fn from_bytes(bytes: &[u8]) -> Result<Self, TlsExtensionError> {
148 if bytes.is_empty() {
149 return Err(TlsExtensionError::InvalidExtensionData(
150 "Empty certificate type list".to_string(),
151 ));
152 }
153
154 let length = bytes[0] as usize;
155 if length == 0 {
156 return Err(TlsExtensionError::EmptyCertificateTypeList);
157 }
158 if length > 255 {
159 return Err(TlsExtensionError::CertificateTypeListTooLong(length));
160 }
161 if bytes.len() != 1 + length {
162 return Err(TlsExtensionError::InvalidExtensionData(format!(
163 "Certificate type list length mismatch: expected {}, got {}",
164 1 + length,
165 bytes.len()
166 )));
167 }
168
169 let mut types = Vec::with_capacity(length);
170 for i in 1..=length {
171 let cert_type = CertificateType::from_u8(bytes[i])?;
172 types.push(cert_type);
173 }
174
175 Self::new(types)
176 }
177}
178
/// IANA TLS `ExtensionType` code points for the RFC 7250 certificate-type
/// extensions.
///
/// Note: 47 and 48 are `certificate_authorities` and `oid_filters` in TLS 1.3
/// (RFC 8446), so they must not be used for these extensions.
pub mod extension_ids {
    /// `client_certificate_type` extension (RFC 7250, IANA value 19).
    pub const CLIENT_CERTIFICATE_TYPE: u16 = 19;

    /// `server_certificate_type` extension (RFC 7250, IANA value 20).
    pub const SERVER_CERTIFICATE_TYPE: u16 = 20;
}
186
187#[derive(Debug, Clone)]
189pub enum TlsExtensionError {
190 UnsupportedCertificateType(u8),
192 EmptyCertificateTypeList,
194 CertificateTypeListTooLong(usize),
196 DuplicateCertificateType(CertificateType),
198 InvalidExtensionData(String),
200 NegotiationFailed {
202 client_types: CertificateTypeList,
203 server_types: CertificateTypeList,
204 },
205 ExtensionAlreadyRegistered(u16),
207 RustlsError(String),
209}
210
211impl fmt::Display for TlsExtensionError {
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 match self {
214 Self::UnsupportedCertificateType(value) => {
215 write!(f, "Unsupported certificate type: {value}")
216 }
217 Self::EmptyCertificateTypeList => {
218 write!(f, "Certificate type list cannot be empty")
219 }
220 Self::CertificateTypeListTooLong(len) => {
221 write!(f, "Certificate type list too long: {len} (max 255)")
222 }
223 Self::DuplicateCertificateType(cert_type) => {
224 write!(f, "Duplicate certificate type: {cert_type}")
225 }
226 Self::InvalidExtensionData(msg) => {
227 write!(f, "Invalid extension data: {msg}")
228 }
229 Self::NegotiationFailed {
230 client_types,
231 server_types,
232 } => {
233 write!(
234 f,
235 "Certificate type negotiation failed: client={client_types:?}, server={server_types:?}"
236 )
237 }
238 Self::ExtensionAlreadyRegistered(id) => {
239 write!(f, "Extension already registered: {id}")
240 }
241 Self::RustlsError(msg) => {
242 write!(f, "rustls error: {msg}")
243 }
244 }
245 }
246}
247
248impl std::error::Error for TlsExtensionError {}
249
250#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
252pub struct NegotiationResult {
253 pub client_cert_type: CertificateType,
255 pub server_cert_type: CertificateType,
257}
258
259impl NegotiationResult {
260 pub fn new(client_cert_type: CertificateType, server_cert_type: CertificateType) -> Self {
262 Self {
263 client_cert_type,
264 server_cert_type,
265 }
266 }
267
268 pub fn is_raw_public_key_only(&self) -> bool {
270 self.client_cert_type.is_raw_public_key() && self.server_cert_type.is_raw_public_key()
271 }
272
273 pub fn is_x509_only(&self) -> bool {
275 self.client_cert_type.is_x509() && self.server_cert_type.is_x509()
276 }
277
278 pub fn is_mixed(&self) -> bool {
280 !self.is_raw_public_key_only() && !self.is_x509_only()
281 }
282}
283
284#[derive(Debug, Clone, PartialEq, Eq)]
286pub struct CertificateTypePreferences {
287 pub client_types: CertificateTypeList,
289 pub server_types: CertificateTypeList,
291 pub require_extensions: bool,
293 pub fallback_client: CertificateType,
295 pub fallback_server: CertificateType,
296}
297
298impl CertificateTypePreferences {
299 pub fn prefer_raw_public_key() -> Self {
301 Self {
302 client_types: CertificateTypeList::prefer_raw_public_key(),
303 server_types: CertificateTypeList::prefer_raw_public_key(),
304 require_extensions: false,
305 fallback_client: CertificateType::X509,
306 fallback_server: CertificateType::X509,
307 }
308 }
309
310 pub fn raw_public_key_only() -> Self {
312 Self {
313 client_types: CertificateTypeList::raw_public_key_only(),
314 server_types: CertificateTypeList::raw_public_key_only(),
315 require_extensions: true,
316 fallback_client: CertificateType::RawPublicKey,
317 fallback_server: CertificateType::RawPublicKey,
318 }
319 }
320
321 pub fn x509_only() -> Self {
323 Self {
324 client_types: CertificateTypeList::x509_only(),
325 server_types: CertificateTypeList::x509_only(),
326 require_extensions: false,
327 fallback_client: CertificateType::X509,
328 fallback_server: CertificateType::X509,
329 }
330 }
331
332 pub fn negotiate(
334 &self,
335 remote_client_types: Option<&CertificateTypeList>,
336 remote_server_types: Option<&CertificateTypeList>,
337 ) -> Result<NegotiationResult, TlsExtensionError> {
338 let client_cert_type = if let Some(remote_types) = remote_client_types {
339 self.client_types.negotiate(remote_types).ok_or_else(|| {
340 TlsExtensionError::NegotiationFailed {
341 client_types: self.client_types.clone(),
342 server_types: remote_types.clone(),
343 }
344 })?
345 } else if self.require_extensions {
346 return Err(TlsExtensionError::NegotiationFailed {
347 client_types: self.client_types.clone(),
348 server_types: CertificateTypeList::x509_only(),
349 });
350 } else {
351 self.fallback_client
352 };
353
354 let server_cert_type = if let Some(remote_types) = remote_server_types {
355 self.server_types.negotiate(remote_types).ok_or_else(|| {
356 TlsExtensionError::NegotiationFailed {
357 client_types: self.server_types.clone(),
358 server_types: remote_types.clone(),
359 }
360 })?
361 } else if self.require_extensions {
362 return Err(TlsExtensionError::NegotiationFailed {
363 client_types: self.server_types.clone(),
364 server_types: CertificateTypeList::x509_only(),
365 });
366 } else {
367 self.fallback_server
368 };
369
370 Ok(NegotiationResult::new(client_cert_type, server_cert_type))
371 }
372}
373
374impl Default for CertificateTypePreferences {
375 fn default() -> Self {
376 Self::prefer_raw_public_key()
377 }
378}
379
380#[derive(Debug)]
382pub struct NegotiationCache {
383 cache: HashMap<u64, NegotiationResult>,
385 max_size: usize,
387}
388
389impl NegotiationCache {
390 pub fn new(max_size: usize) -> Self {
392 Self {
393 cache: HashMap::with_capacity(max_size.min(1000)),
394 max_size,
395 }
396 }
397
398 pub fn get(&self, key: u64) -> Option<&NegotiationResult> {
400 self.cache.get(&key)
401 }
402
403 pub fn insert(&mut self, key: u64, result: NegotiationResult) {
405 if self.cache.len() >= self.max_size {
406 if let Some(oldest_key) = self.cache.keys().next().copied() {
408 self.cache.remove(&oldest_key);
409 }
410 }
411 self.cache.insert(key, result);
412 }
413
414 pub fn clear(&mut self) {
416 self.cache.clear();
417 }
418
419 pub fn stats(&self) -> (usize, usize) {
421 (self.cache.len(), self.max_size)
422 }
423}
424
425impl Default for NegotiationCache {
426 fn default() -> Self {
427 Self::new(1000)
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_certificate_type_conversion() {
        assert_eq!(CertificateType::X509.to_u8(), 0);
        assert_eq!(CertificateType::RawPublicKey.to_u8(), 2);

        assert_eq!(CertificateType::from_u8(0).unwrap(), CertificateType::X509);
        assert_eq!(
            CertificateType::from_u8(2).unwrap(),
            CertificateType::RawPublicKey
        );

        // Wire value 1 has no variant, and unassigned bytes must be rejected too.
        assert!(CertificateType::from_u8(1).is_err());
        assert!(CertificateType::from_u8(255).is_err());
    }

    #[test]
    fn test_certificate_type_display() {
        assert_eq!(CertificateType::X509.to_string(), "X.509");
        assert_eq!(CertificateType::RawPublicKey.to_string(), "RawPublicKey");
    }

    #[test]
    fn test_certificate_type_list_creation() {
        let list =
            CertificateTypeList::new(vec![CertificateType::RawPublicKey, CertificateType::X509])
                .unwrap();
        assert_eq!(list.types.len(), 2);
        assert_eq!(list.most_preferred(), CertificateType::RawPublicKey);
        assert!(list.supports_raw_public_key());
        assert!(list.supports_x509());

        // Empty lists are rejected.
        assert!(CertificateTypeList::new(vec![]).is_err());

        // Duplicates are rejected.
        assert!(
            CertificateTypeList::new(vec![CertificateType::X509, CertificateType::X509]).is_err()
        );

        // The length limit is checked before duplicates.
        assert!(matches!(
            CertificateTypeList::new(vec![CertificateType::X509; 256]),
            Err(TlsExtensionError::CertificateTypeListTooLong(256))
        ));
    }

    #[test]
    fn test_certificate_type_list_serialization() {
        let list = CertificateTypeList::prefer_raw_public_key();
        let bytes = list.to_bytes();
        // One count byte followed by the wire values in preference order.
        assert_eq!(bytes, vec![2, 2, 0]);

        let parsed = CertificateTypeList::from_bytes(&bytes).unwrap();
        assert_eq!(parsed, list);
    }

    #[test]
    fn test_certificate_type_list_from_bytes_rejects_malformed_input() {
        assert!(matches!(
            CertificateTypeList::from_bytes(&[]),
            Err(TlsExtensionError::InvalidExtensionData(_))
        ));
        assert!(matches!(
            CertificateTypeList::from_bytes(&[0]),
            Err(TlsExtensionError::EmptyCertificateTypeList)
        ));
        // Count says 2 but only one type byte follows.
        assert!(matches!(
            CertificateTypeList::from_bytes(&[2, 2]),
            Err(TlsExtensionError::InvalidExtensionData(_))
        ));
        // Count says 1 but two type bytes follow.
        assert!(matches!(
            CertificateTypeList::from_bytes(&[1, 2, 0]),
            Err(TlsExtensionError::InvalidExtensionData(_))
        ));
        assert!(matches!(
            CertificateTypeList::from_bytes(&[1, 1]),
            Err(TlsExtensionError::UnsupportedCertificateType(1))
        ));
        assert!(matches!(
            CertificateTypeList::from_bytes(&[2, 2, 2]),
            Err(TlsExtensionError::DuplicateCertificateType(
                CertificateType::RawPublicKey
            ))
        ));
    }

    #[test]
    fn test_certificate_type_list_negotiation() {
        let rpk_only = CertificateTypeList::raw_public_key_only();
        let prefer_rpk = CertificateTypeList::prefer_raw_public_key();
        let x509_only = CertificateTypeList::x509_only();

        // Both sides support raw public keys.
        assert_eq!(
            rpk_only.negotiate(&prefer_rpk).unwrap(),
            CertificateType::RawPublicKey
        );

        // Only X.509 is shared.
        assert_eq!(
            prefer_rpk.negotiate(&x509_only).unwrap(),
            CertificateType::X509
        );

        // Nothing is shared.
        assert!(rpk_only.negotiate(&x509_only).is_none());
    }

    #[test]
    fn test_preferences_negotiation() {
        let rpk_prefs = CertificateTypePreferences::raw_public_key_only();
        let mixed_prefs = CertificateTypePreferences::prefer_raw_public_key();

        let result = rpk_prefs
            .negotiate(
                Some(&mixed_prefs.client_types),
                Some(&mixed_prefs.server_types),
            )
            .unwrap();

        assert_eq!(result.client_cert_type, CertificateType::RawPublicKey);
        assert_eq!(result.server_cert_type, CertificateType::RawPublicKey);
        assert!(result.is_raw_public_key_only());
    }

    #[test]
    fn test_preferences_without_remote_extensions() {
        // Optional extensions fall back to the configured types.
        let fallback = CertificateTypePreferences::prefer_raw_public_key()
            .negotiate(None, None)
            .unwrap();
        assert!(fallback.is_x509_only());

        // Required extensions make a missing list an error.
        assert!(CertificateTypePreferences::raw_public_key_only()
            .negotiate(None, None)
            .is_err());
    }

    #[test]
    fn test_default_preferences() {
        assert_eq!(
            CertificateTypePreferences::default(),
            CertificateTypePreferences::prefer_raw_public_key()
        );
    }

    #[test]
    fn test_negotiation_result_classification() {
        let mixed = NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509);
        assert!(mixed.is_mixed());
        assert!(!mixed.is_raw_public_key_only());
        assert!(!mixed.is_x509_only());
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            TlsExtensionError::UnsupportedCertificateType(1).to_string(),
            "Unsupported certificate type: 1"
        );
        assert_eq!(
            TlsExtensionError::DuplicateCertificateType(CertificateType::X509).to_string(),
            "Duplicate certificate type: X.509"
        );
    }

    #[test]
    fn test_negotiation_cache() {
        let mut cache = NegotiationCache::new(2);
        let result = NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509);

        assert!(cache.get(123).is_none());

        cache.insert(123, result.clone());
        assert_eq!(cache.get(123).unwrap(), &result);

        cache.insert(456, result.clone());
        assert_eq!(cache.cache.len(), 2);

        // Inserting past capacity evicts an entry and keeps the size bounded;
        // the newly inserted key is always present.
        cache.insert(789, result.clone());
        assert_eq!(cache.cache.len(), 2);
        assert!(cache.get(789).is_some());

        cache.clear();
        assert_eq!(cache.stats(), (0, 2));
    }
}