1use std::{
12 collections::HashMap,
13 fmt::{self, Debug},
14};
15
16
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
20#[repr(u8)]
21pub enum CertificateType {
22 X509 = 0,
24 RawPublicKey = 2,
26}
27
28impl CertificateType {
29 pub fn from_u8(value: u8) -> Result<Self, TlsExtensionError> {
31 match value {
32 0 => Ok(CertificateType::X509),
33 2 => Ok(CertificateType::RawPublicKey),
34 _ => Err(TlsExtensionError::UnsupportedCertificateType(value)),
35 }
36 }
37
38 pub fn to_u8(self) -> u8 {
40 self as u8
41 }
42
43 pub fn is_raw_public_key(self) -> bool {
45 matches!(self, CertificateType::RawPublicKey)
46 }
47
48 pub fn is_x509(self) -> bool {
50 matches!(self, CertificateType::X509)
51 }
52}
53
54impl fmt::Display for CertificateType {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 match self {
57 CertificateType::X509 => write!(f, "X.509"),
58 CertificateType::RawPublicKey => write!(f, "RawPublicKey"),
59 }
60 }
61}
62
63#[derive(Debug, Clone, PartialEq, Eq)]
65pub struct CertificateTypeList {
66 pub types: Vec<CertificateType>,
68}
69
70impl CertificateTypeList {
71 pub fn new(types: Vec<CertificateType>) -> Result<Self, TlsExtensionError> {
73 if types.is_empty() {
74 return Err(TlsExtensionError::EmptyCertificateTypeList);
75 }
76 if types.len() > 255 {
77 return Err(TlsExtensionError::CertificateTypeListTooLong(types.len()));
78 }
79
80 let mut seen = std::collections::HashSet::new();
82 for cert_type in &types {
83 if !seen.insert(*cert_type) {
84 return Err(TlsExtensionError::DuplicateCertificateType(*cert_type));
85 }
86 }
87
88 Ok(CertificateTypeList { types })
89 }
90
91 pub fn raw_public_key_only() -> Self {
93 CertificateTypeList {
94 types: vec![CertificateType::RawPublicKey],
95 }
96 }
97
98 pub fn prefer_raw_public_key() -> Self {
100 CertificateTypeList {
101 types: vec![CertificateType::RawPublicKey, CertificateType::X509],
102 }
103 }
104
105 pub fn x509_only() -> Self {
107 CertificateTypeList {
108 types: vec![CertificateType::X509],
109 }
110 }
111
112 pub fn most_preferred(&self) -> CertificateType {
114 self.types[0]
115 }
116
117 pub fn supports_raw_public_key(&self) -> bool {
119 self.types.contains(&CertificateType::RawPublicKey)
120 }
121
122 pub fn supports_x509(&self) -> bool {
124 self.types.contains(&CertificateType::X509)
125 }
126
127 pub fn negotiate(&self, other: &CertificateTypeList) -> Option<CertificateType> {
129 for cert_type in &self.types {
131 if other.types.contains(cert_type) {
132 return Some(*cert_type);
133 }
134 }
135 None
136 }
137
138 pub fn to_bytes(&self) -> Vec<u8> {
140 let mut bytes = Vec::with_capacity(1 + self.types.len());
141 bytes.push(self.types.len() as u8);
142 for cert_type in &self.types {
143 bytes.push(cert_type.to_u8());
144 }
145 bytes
146 }
147
148 pub fn from_bytes(bytes: &[u8]) -> Result<Self, TlsExtensionError> {
150 if bytes.is_empty() {
151 return Err(TlsExtensionError::InvalidExtensionData("Empty certificate type list".to_string()));
152 }
153
154 let length = bytes[0] as usize;
155 if length == 0 {
156 return Err(TlsExtensionError::EmptyCertificateTypeList);
157 }
158 if length > 255 {
159 return Err(TlsExtensionError::CertificateTypeListTooLong(length));
160 }
161 if bytes.len() != 1 + length {
162 return Err(TlsExtensionError::InvalidExtensionData(
163 format!("Certificate type list length mismatch: expected {}, got {}", 1 + length, bytes.len())
164 ));
165 }
166
167 let mut types = Vec::with_capacity(length);
168 for i in 1..=length {
169 let cert_type = CertificateType::from_u8(bytes[i])?;
170 types.push(cert_type);
171 }
172
173 CertificateTypeList::new(types)
174 }
175}
176
/// IANA TLS `ExtensionType` code points for the certificate-type extensions
/// defined in RFC 7250.
///
/// These were previously 47 and 48. IANA assigns those values to
/// `certificate_authorities` and `oid_filters` (RFC 8446), so the wrong
/// extensions would have been emitted or matched.
pub mod extension_ids {
    /// `client_certificate_type` extension (RFC 7250, IANA value 19).
    pub const CLIENT_CERTIFICATE_TYPE: u16 = 19;
    /// `server_certificate_type` extension (RFC 7250, IANA value 20).
    pub const SERVER_CERTIFICATE_TYPE: u16 = 20;
}
184
185#[derive(Debug, Clone)]
187pub enum TlsExtensionError {
188 UnsupportedCertificateType(u8),
190 EmptyCertificateTypeList,
192 CertificateTypeListTooLong(usize),
194 DuplicateCertificateType(CertificateType),
196 InvalidExtensionData(String),
198 NegotiationFailed {
200 client_types: CertificateTypeList,
201 server_types: CertificateTypeList,
202 },
203 ExtensionAlreadyRegistered(u16),
205 RustlsError(String),
207}
208
209impl fmt::Display for TlsExtensionError {
210 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211 match self {
212 TlsExtensionError::UnsupportedCertificateType(value) => {
213 write!(f, "Unsupported certificate type: {}", value)
214 }
215 TlsExtensionError::EmptyCertificateTypeList => {
216 write!(f, "Certificate type list cannot be empty")
217 }
218 TlsExtensionError::CertificateTypeListTooLong(len) => {
219 write!(f, "Certificate type list too long: {} (max 255)", len)
220 }
221 TlsExtensionError::DuplicateCertificateType(cert_type) => {
222 write!(f, "Duplicate certificate type: {}", cert_type)
223 }
224 TlsExtensionError::InvalidExtensionData(msg) => {
225 write!(f, "Invalid extension data: {}", msg)
226 }
227 TlsExtensionError::NegotiationFailed { client_types, server_types } => {
228 write!(f, "Certificate type negotiation failed: client={:?}, server={:?}",
229 client_types, server_types)
230 }
231 TlsExtensionError::ExtensionAlreadyRegistered(id) => {
232 write!(f, "Extension already registered: {}", id)
233 }
234 TlsExtensionError::RustlsError(msg) => {
235 write!(f, "rustls error: {}", msg)
236 }
237 }
238 }
239}
240
241impl std::error::Error for TlsExtensionError {}
242
243#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
245pub struct NegotiationResult {
246 pub client_cert_type: CertificateType,
248 pub server_cert_type: CertificateType,
250}
251
252impl NegotiationResult {
253 pub fn new(client_cert_type: CertificateType, server_cert_type: CertificateType) -> Self {
255 Self {
256 client_cert_type,
257 server_cert_type,
258 }
259 }
260
261 pub fn is_raw_public_key_only(&self) -> bool {
263 self.client_cert_type.is_raw_public_key() && self.server_cert_type.is_raw_public_key()
264 }
265
266 pub fn is_x509_only(&self) -> bool {
268 self.client_cert_type.is_x509() && self.server_cert_type.is_x509()
269 }
270
271 pub fn is_mixed(&self) -> bool {
273 !self.is_raw_public_key_only() && !self.is_x509_only()
274 }
275}
276
277#[derive(Debug, Clone, PartialEq, Eq)]
279pub struct CertificateTypePreferences {
280 pub client_types: CertificateTypeList,
282 pub server_types: CertificateTypeList,
284 pub require_extensions: bool,
286 pub fallback_client: CertificateType,
288 pub fallback_server: CertificateType,
289}
290
291impl CertificateTypePreferences {
292 pub fn prefer_raw_public_key() -> Self {
294 Self {
295 client_types: CertificateTypeList::prefer_raw_public_key(),
296 server_types: CertificateTypeList::prefer_raw_public_key(),
297 require_extensions: false,
298 fallback_client: CertificateType::X509,
299 fallback_server: CertificateType::X509,
300 }
301 }
302
303 pub fn raw_public_key_only() -> Self {
305 Self {
306 client_types: CertificateTypeList::raw_public_key_only(),
307 server_types: CertificateTypeList::raw_public_key_only(),
308 require_extensions: true,
309 fallback_client: CertificateType::RawPublicKey,
310 fallback_server: CertificateType::RawPublicKey,
311 }
312 }
313
314 pub fn x509_only() -> Self {
316 Self {
317 client_types: CertificateTypeList::x509_only(),
318 server_types: CertificateTypeList::x509_only(),
319 require_extensions: false,
320 fallback_client: CertificateType::X509,
321 fallback_server: CertificateType::X509,
322 }
323 }
324
325 pub fn negotiate(
327 &self,
328 remote_client_types: Option<&CertificateTypeList>,
329 remote_server_types: Option<&CertificateTypeList>,
330 ) -> Result<NegotiationResult, TlsExtensionError> {
331 let client_cert_type = if let Some(remote_types) = remote_client_types {
332 self.client_types.negotiate(remote_types)
333 .ok_or_else(|| TlsExtensionError::NegotiationFailed {
334 client_types: self.client_types.clone(),
335 server_types: remote_types.clone(),
336 })?
337 } else if self.require_extensions {
338 return Err(TlsExtensionError::NegotiationFailed {
339 client_types: self.client_types.clone(),
340 server_types: CertificateTypeList::x509_only(),
341 });
342 } else {
343 self.fallback_client
344 };
345
346 let server_cert_type = if let Some(remote_types) = remote_server_types {
347 self.server_types.negotiate(remote_types)
348 .ok_or_else(|| TlsExtensionError::NegotiationFailed {
349 client_types: self.server_types.clone(),
350 server_types: remote_types.clone(),
351 })?
352 } else if self.require_extensions {
353 return Err(TlsExtensionError::NegotiationFailed {
354 client_types: self.server_types.clone(),
355 server_types: CertificateTypeList::x509_only(),
356 });
357 } else {
358 self.fallback_server
359 };
360
361 Ok(NegotiationResult::new(client_cert_type, server_cert_type))
362 }
363}
364
365impl Default for CertificateTypePreferences {
366 fn default() -> Self {
367 Self::prefer_raw_public_key()
368 }
369}
370
371#[derive(Debug)]
373pub struct NegotiationCache {
374 cache: HashMap<u64, NegotiationResult>,
376 max_size: usize,
378}
379
380impl NegotiationCache {
381 pub fn new(max_size: usize) -> Self {
383 Self {
384 cache: HashMap::with_capacity(max_size.min(1000)),
385 max_size,
386 }
387 }
388
389 pub fn get(&self, key: u64) -> Option<&NegotiationResult> {
391 self.cache.get(&key)
392 }
393
394 pub fn insert(&mut self, key: u64, result: NegotiationResult) {
396 if self.cache.len() >= self.max_size {
397 if let Some(oldest_key) = self.cache.keys().next().copied() {
399 self.cache.remove(&oldest_key);
400 }
401 }
402 self.cache.insert(key, result);
403 }
404
405 pub fn clear(&mut self) {
407 self.cache.clear();
408 }
409
410 pub fn stats(&self) -> (usize, usize) {
412 (self.cache.len(), self.max_size)
413 }
414}
415
416impl Default for NegotiationCache {
417 fn default() -> Self {
418 Self::new(1000)
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_certificate_type_conversion() {
        let pairs = [
            (CertificateType::X509, 0u8),
            (CertificateType::RawPublicKey, 2u8),
        ];
        for &(cert_type, code) in pairs.iter() {
            assert_eq!(cert_type.to_u8(), code);
            assert_eq!(CertificateType::from_u8(code).unwrap(), cert_type);
        }

        // Code point 1 and values above 2 are rejected.
        assert!(CertificateType::from_u8(1).is_err());
        assert!(CertificateType::from_u8(255).is_err());
    }

    #[test]
    fn test_certificate_type_list_creation() {
        let list =
            CertificateTypeList::new(vec![CertificateType::RawPublicKey, CertificateType::X509])
                .expect("two distinct types form a valid list");
        assert_eq!(list.types.len(), 2);
        assert_eq!(list.most_preferred(), CertificateType::RawPublicKey);
        assert!(list.supports_raw_public_key());
        assert!(list.supports_x509());

        // Empty lists and duplicates are both rejected.
        assert!(CertificateTypeList::new(Vec::new()).is_err());
        assert!(CertificateTypeList::new(vec![CertificateType::X509; 2]).is_err());
    }

    #[test]
    fn test_certificate_type_list_serialization() {
        let original = CertificateTypeList::prefer_raw_public_key();
        let encoded = original.to_bytes();
        // Length byte, then RawPublicKey (2), then X509 (0).
        assert_eq!(encoded, vec![2u8, 2, 0]);

        let decoded = CertificateTypeList::from_bytes(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_certificate_type_list_negotiation() {
        let rpk_only = CertificateTypeList::raw_public_key_only();
        let prefer_rpk = CertificateTypeList::prefer_raw_public_key();
        let x509_only = CertificateTypeList::x509_only();

        assert_eq!(rpk_only.negotiate(&prefer_rpk), Some(CertificateType::RawPublicKey));
        assert_eq!(prefer_rpk.negotiate(&x509_only), Some(CertificateType::X509));
        assert_eq!(rpk_only.negotiate(&x509_only), None);
    }

    #[test]
    fn test_preferences_negotiation() {
        let strict = CertificateTypePreferences::raw_public_key_only();
        let lenient = CertificateTypePreferences::prefer_raw_public_key();

        let outcome = strict
            .negotiate(Some(&lenient.client_types), Some(&lenient.server_types))
            .unwrap();

        assert_eq!(
            outcome,
            NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::RawPublicKey)
        );
        assert!(outcome.is_raw_public_key_only());
    }

    #[test]
    fn test_negotiation_cache() {
        let mut cache = NegotiationCache::new(2);
        let entry = NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509);

        assert!(cache.get(123).is_none());

        cache.insert(123, entry.clone());
        assert_eq!(cache.get(123), Some(&entry));

        cache.insert(456, entry.clone());
        assert_eq!(cache.cache.len(), 2);

        // Going past capacity evicts one entry, so the size stays at the limit.
        cache.insert(789, entry.clone());
        assert_eq!(cache.cache.len(), 2);
        assert!(cache.get(456).is_some() || cache.get(789).is_some());
    }
}