#![allow(missing_docs)]
use std::{
collections::HashMap,
fmt::{self, Debug},
};
/// A certificate type that a peer may present during the handshake.
///
/// The explicit discriminants are the single-byte code points produced by
/// [`CertificateType::to_u8`] and accepted by [`CertificateType::from_u8`].
/// Code point 1 is not represented and is rejected when decoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
#[repr(u8)]
pub enum CertificateType {
    /// X.509 certificate (code point 0).
    X509 = 0,
    /// Raw public key (code point 2).
    RawPublicKey = 2,
}
impl CertificateType {
    /// Decodes a single-byte code point.
    ///
    /// # Errors
    /// Returns [`TlsExtensionError::UnsupportedCertificateType`] carrying the
    /// offending byte for anything other than 0 or 2.
    pub fn from_u8(value: u8) -> Result<Self, TlsExtensionError> {
        let decoded = match value {
            0 => Self::X509,
            2 => Self::RawPublicKey,
            unknown => return Err(TlsExtensionError::UnsupportedCertificateType(unknown)),
        };
        Ok(decoded)
    }
    /// Encodes this type as its code point (the `repr(u8)` discriminant).
    pub fn to_u8(self) -> u8 {
        self as u8
    }
    /// Returns `true` only for [`CertificateType::RawPublicKey`].
    pub fn is_raw_public_key(self) -> bool {
        self == Self::RawPublicKey
    }
    /// Returns `true` only for [`CertificateType::X509`].
    pub fn is_x509(self) -> bool {
        self == Self::X509
    }
}
impl fmt::Display for CertificateType {
    /// Writes a short human-readable label: `X.509` or `RawPublicKey`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::X509 => "X.509",
            Self::RawPublicKey => "RawPublicKey",
        };
        f.write_str(label)
    }
}
/// An ordered list of certificate types, most preferred first.
///
/// Lists produced by [`CertificateTypeList::new`] and
/// [`CertificateTypeList::from_bytes`] are validated to be non-empty, at most
/// 255 entries long and free of duplicates. The preset constructors return
/// lists that satisfy the same rules. Because `types` is public, a list
/// assembled or mutated directly through the field is *not* validated.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CertificateTypeList {
    /// Certificate types in descending order of preference.
    pub types: Vec<CertificateType>,
}
impl CertificateTypeList {
    /// Validates `types` and wraps it.
    ///
    /// # Errors
    /// - [`TlsExtensionError::EmptyCertificateTypeList`] if `types` is empty.
    /// - [`TlsExtensionError::CertificateTypeListTooLong`] if it has more than
    ///   255 entries (the most a one-byte length prefix can describe).
    /// - [`TlsExtensionError::DuplicateCertificateType`] carrying the first
    ///   entry that repeats an earlier one.
    pub fn new(types: Vec<CertificateType>) -> Result<Self, TlsExtensionError> {
        if types.is_empty() {
            return Err(TlsExtensionError::EmptyCertificateTypeList);
        }
        if types.len() > usize::from(u8::MAX) {
            return Err(TlsExtensionError::CertificateTypeListTooLong(types.len()));
        }
        let mut seen = std::collections::HashSet::with_capacity(types.len());
        // `insert` returns false for a value already seen, i.e. a duplicate.
        if let Some(duplicate) = types
            .iter()
            .copied()
            .find(|cert_type| !seen.insert(*cert_type))
        {
            return Err(TlsExtensionError::DuplicateCertificateType(duplicate));
        }
        Ok(Self { types })
    }
    /// A list accepting only raw public keys.
    pub fn raw_public_key_only() -> Self {
        Self {
            types: vec![CertificateType::RawPublicKey],
        }
    }
    /// A list preferring raw public keys but also accepting X.509.
    pub fn prefer_raw_public_key() -> Self {
        Self {
            types: vec![CertificateType::RawPublicKey, CertificateType::X509],
        }
    }
    /// A list accepting only X.509 certificates.
    pub fn x509_only() -> Self {
        Self {
            types: vec![CertificateType::X509],
        }
    }
    /// Returns the first (most preferred) entry.
    ///
    /// # Panics
    /// Panics if `types` is empty. A validated list is never empty, so this
    /// can only happen when the public field was set or cleared directly.
    pub fn most_preferred(&self) -> CertificateType {
        *self
            .types
            .first()
            .expect("CertificateTypeList must be non-empty; build it with `new` or `from_bytes`")
    }
    /// Returns `true` if [`CertificateType::RawPublicKey`] is listed.
    pub fn supports_raw_public_key(&self) -> bool {
        self.types.contains(&CertificateType::RawPublicKey)
    }
    /// Returns `true` if [`CertificateType::X509`] is listed.
    pub fn supports_x509(&self) -> bool {
        self.types.contains(&CertificateType::X509)
    }
    /// Picks the first entry of `self`, in `self`'s preference order, that
    /// `other` also contains. Returns `None` if the lists share no entry.
    pub fn negotiate(&self, other: &Self) -> Option<CertificateType> {
        self.types
            .iter()
            .copied()
            .find(|candidate| other.types.contains(candidate))
    }
    /// Encodes the list as a one-byte entry count followed by one code point
    /// per entry.
    ///
    /// At most 255 entries are written so that the count byte always matches
    /// the number of code points that follow. A longer list can only exist if
    /// the public `types` field was set directly, bypassing validation.
    /// Previously the count was truncated modulo 256 while every entry was still
    /// written, which produced malformed output.
    pub fn to_bytes(&self) -> Vec<u8> {
        let count = self.types.len().min(usize::from(u8::MAX));
        let encoded = &self.types[..count];
        let mut bytes = Vec::with_capacity(1 + count);
        // Lossless: `count` was clamped to at most 255 above.
        bytes.push(count as u8);
        bytes.extend(encoded.iter().map(|cert_type| cert_type.to_u8()));
        bytes
    }
    /// Decodes bytes in the format produced by [`CertificateTypeList::to_bytes`].
    ///
    /// # Errors
    /// - [`TlsExtensionError::InvalidExtensionData`] if `bytes` is empty, or
    ///   if its total length is not exactly `1 + count`.
    /// - [`TlsExtensionError::EmptyCertificateTypeList`] if the count byte is 0.
    /// - [`TlsExtensionError::UnsupportedCertificateType`] for the first
    ///   unknown code point.
    /// - [`TlsExtensionError::DuplicateCertificateType`] if an entry repeats.
    pub fn from_bytes(bytes: &[u8]) -> Result<Self, TlsExtensionError> {
        let (&count, entries) = match bytes.split_first() {
            Some(split) => split,
            None => {
                return Err(TlsExtensionError::InvalidExtensionData(
                    "Empty certificate type list".to_string(),
                ))
            }
        };
        let length = usize::from(count);
        if length == 0 {
            return Err(TlsExtensionError::EmptyCertificateTypeList);
        }
        // No upper-bound check is needed: a one-byte count cannot exceed 255.
        if entries.len() != length {
            return Err(TlsExtensionError::InvalidExtensionData(format!(
                "Certificate type list length mismatch: expected {}, got {}",
                1 + length,
                bytes.len()
            )));
        }
        // Stops at the first unsupported code point, like a manual loop with `?`.
        let types = entries
            .iter()
            .map(|&code| CertificateType::from_u8(code))
            .collect::<Result<Vec<_>, _>>()?;
        Self::new(types)
    }
}
/// Extension type numbers for the certificate-type extensions.
pub mod extension_ids {
    /// `client_certificate_type` extension number (decimal 47).
    pub const CLIENT_CERTIFICATE_TYPE: u16 = 0x002F;
    /// `server_certificate_type` extension number (decimal 48).
    pub const SERVER_CERTIFICATE_TYPE: u16 = 0x0030;
}
/// Errors produced while building, encoding, decoding or negotiating
/// certificate-type lists.
#[derive(Debug, Clone)]
pub enum TlsExtensionError {
    /// A certificate type code point that is not recognised (carries the byte).
    UnsupportedCertificateType(u8),
    /// A certificate type list with no entries.
    EmptyCertificateTypeList,
    /// A certificate type list longer than 255 entries (carries the length).
    CertificateTypeListTooLong(usize),
    /// A certificate type listed more than once.
    DuplicateCertificateType(CertificateType),
    /// Malformed extension bytes, with a human-readable reason.
    InvalidExtensionData(String),
    /// No certificate type could be agreed on; carries both lists involved.
    NegotiationFailed {
        client_types: CertificateTypeList,
        server_types: CertificateTypeList,
    },
    /// An extension with the given ID is already registered.
    ExtensionAlreadyRegistered(u16),
    /// An error message reported as coming from rustls.
    RustlsError(String),
}
impl fmt::Display for TlsExtensionError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnsupportedCertificateType(code) => {
                write!(f, "Unsupported certificate type: {}", code)
            }
            Self::EmptyCertificateTypeList => f.write_str("Certificate type list cannot be empty"),
            Self::CertificateTypeListTooLong(count) => {
                write!(f, "Certificate type list too long: {} (max 255)", count)
            }
            Self::DuplicateCertificateType(repeated) => {
                write!(f, "Duplicate certificate type: {}", repeated)
            }
            Self::InvalidExtensionData(reason) => write!(f, "Invalid extension data: {}", reason),
            Self::NegotiationFailed {
                client_types: ours,
                server_types: theirs,
            } => write!(
                f,
                "Certificate type negotiation failed: client={:?}, server={:?}",
                ours, theirs
            ),
            Self::ExtensionAlreadyRegistered(ext_id) => {
                write!(f, "Extension already registered: {}", ext_id)
            }
            Self::RustlsError(detail) => write!(f, "rustls error: {}", detail),
        }
    }
}
/// Marker impl so the error works with `?` into `Box<dyn Error>` and similar.
impl std::error::Error for TlsExtensionError {}
/// The outcome of negotiation: the certificate type chosen for each side.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct NegotiationResult {
    /// Certificate type negotiated for the client side.
    pub client_cert_type: CertificateType,
    /// Certificate type negotiated for the server side.
    pub server_cert_type: CertificateType,
}
impl NegotiationResult {
    /// Bundles the two negotiated types.
    pub fn new(client_cert_type: CertificateType, server_cert_type: CertificateType) -> Self {
        Self {
            client_cert_type,
            server_cert_type,
        }
    }
    /// `true` when both sides use raw public keys.
    pub fn is_raw_public_key_only(&self) -> bool {
        [self.client_cert_type, self.server_cert_type]
            .iter()
            .all(|side| side.is_raw_public_key())
    }
    /// `true` when both sides use X.509 certificates.
    pub fn is_x509_only(&self) -> bool {
        [self.client_cert_type, self.server_cert_type]
            .iter()
            .all(|side| side.is_x509())
    }
    /// `true` when the two sides use different certificate types.
    pub fn is_mixed(&self) -> bool {
        !(self.is_raw_public_key_only() || self.is_x509_only())
    }
}
/// Local certificate-type policy covering both directions of a handshake.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CertificateTypePreferences {
    /// Acceptable client certificate types, most preferred first.
    pub client_types: CertificateTypeList,
    /// Acceptable server certificate types, most preferred first.
    pub server_types: CertificateTypeList,
    /// If `true`, a missing peer list makes negotiation fail instead of
    /// using the fallback.
    pub require_extensions: bool,
    /// Client type used when the peer sends no client list and
    /// `require_extensions` is `false`.
    pub fallback_client: CertificateType,
    /// Server type used when the peer sends no server list and
    /// `require_extensions` is `false`.
    pub fallback_server: CertificateType,
}
impl CertificateTypePreferences {
    /// Prefer raw public keys and also accept X.509. Fall back to X.509 when
    /// the peer sends no list.
    pub fn prefer_raw_public_key() -> Self {
        Self::symmetric(
            CertificateTypeList::prefer_raw_public_key(),
            false,
            CertificateType::X509,
        )
    }
    /// Accept only raw public keys and require the peer to send both lists.
    pub fn raw_public_key_only() -> Self {
        Self::symmetric(
            CertificateTypeList::raw_public_key_only(),
            true,
            CertificateType::RawPublicKey,
        )
    }
    /// Accept only X.509. Fall back to X.509 when the peer sends no list.
    pub fn x509_only() -> Self {
        Self::symmetric(CertificateTypeList::x509_only(), false, CertificateType::X509)
    }
    /// Builds a policy that uses the same list and fallback for both directions.
    fn symmetric(
        types: CertificateTypeList,
        require_extensions: bool,
        fallback: CertificateType,
    ) -> Self {
        Self {
            client_types: types.clone(),
            server_types: types,
            require_extensions,
            fallback_client: fallback,
            fallback_server: fallback,
        }
    }
    /// Chooses the client and then the server certificate type.
    ///
    /// For each direction, the first entry of the local list that the peer
    /// also lists is chosen. If the peer sent no list, the matching fallback
    /// is used unless `require_extensions` is set.
    ///
    /// # Errors
    /// Returns [`TlsExtensionError::NegotiationFailed`] for the first
    /// direction (client before server) that has no common type, or whose
    /// peer list is missing while `require_extensions` is set. In the error,
    /// `client_types` holds the local list for that direction. `server_types`
    /// holds the peer's list, or an X.509-only placeholder when the peer's list
    /// was missing.
    pub fn negotiate(
        &self,
        remote_client_types: Option<&CertificateTypeList>,
        remote_server_types: Option<&CertificateTypeList>,
    ) -> Result<NegotiationResult, TlsExtensionError> {
        let client_choice = Self::negotiate_direction(
            &self.client_types,
            remote_client_types,
            self.require_extensions,
            self.fallback_client,
        )?;
        let server_choice = Self::negotiate_direction(
            &self.server_types,
            remote_server_types,
            self.require_extensions,
            self.fallback_server,
        )?;
        Ok(NegotiationResult::new(client_choice, server_choice))
    }
    /// Resolves one direction of negotiation against the peer's optional list.
    fn negotiate_direction(
        local: &CertificateTypeList,
        remote: Option<&CertificateTypeList>,
        require_extensions: bool,
        fallback: CertificateType,
    ) -> Result<CertificateType, TlsExtensionError> {
        match remote {
            Some(offered) => {
                local
                    .negotiate(offered)
                    .ok_or_else(|| TlsExtensionError::NegotiationFailed {
                        client_types: local.clone(),
                        server_types: offered.clone(),
                    })
            }
            None if require_extensions => Err(TlsExtensionError::NegotiationFailed {
                client_types: local.clone(),
                server_types: CertificateTypeList::x509_only(),
            }),
            None => Ok(fallback),
        }
    }
}
impl Default for CertificateTypePreferences {
    /// Equivalent to [`CertificateTypePreferences::prefer_raw_public_key`].
    fn default() -> Self {
        Self::prefer_raw_public_key()
    }
}
/// A bounded cache of [`NegotiationResult`]s keyed by a caller-supplied `u64`.
///
/// The cache never holds more than `max_size` entries. Inserting a *new* key
/// into a full cache evicts one existing entry first. The evicted entry is
/// whichever key the underlying `HashMap` yields first, so the eviction order
/// is unspecified (neither LRU nor FIFO).
#[derive(Debug)]
pub struct NegotiationCache {
    cache: HashMap<u64, NegotiationResult>,
    max_size: usize,
}
impl NegotiationCache {
    /// Creates an empty cache bounded to `max_size` entries.
    ///
    /// At most 1000 slots are pre-allocated, regardless of `max_size`. A
    /// bound of 0 produces a cache that never stores anything.
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: HashMap::with_capacity(max_size.min(1000)),
            max_size,
        }
    }
    /// Returns the cached result for `key`, if present.
    pub fn get(&self, key: u64) -> Option<&NegotiationResult> {
        self.cache.get(&key)
    }
    /// Stores `result` under `key`.
    ///
    /// Overwriting an existing key replaces its value in place and never
    /// evicts other entries. Adding a new key to a full cache first evicts
    /// one arbitrary entry. With a bound of 0 this is a no-op.
    pub fn insert(&mut self, key: u64, result: NegotiationResult) {
        if self.max_size == 0 {
            // Storing even one entry would exceed a zero bound.
            return;
        }
        if let Some(existing) = self.cache.get_mut(&key) {
            // The size does not change, so no other entry needs to go.
            *existing = result;
            return;
        }
        if self.cache.len() >= self.max_size {
            // HashMap iteration order is arbitrary: this is *an* entry, not the oldest.
            if let Some(victim) = self.cache.keys().next().copied() {
                self.cache.remove(&victim);
            }
        }
        self.cache.insert(key, result);
    }
    /// Removes every entry. The size bound is unchanged.
    pub fn clear(&mut self) {
        self.cache.clear();
    }
    /// Returns `(current entry count, maximum entry count)`.
    pub fn stats(&self) -> (usize, usize) {
        (self.cache.len(), self.max_size)
    }
}
impl Default for NegotiationCache {
    /// A cache bounded to 1000 entries.
    fn default() -> Self {
        Self::new(1000)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Known code points round-trip; unknown ones are rejected.
    #[test]
    fn test_certificate_type_conversion() {
        let known = [
            (CertificateType::X509, 0u8),
            (CertificateType::RawPublicKey, 2u8),
        ];
        for &(cert_type, code) in known.iter() {
            assert_eq!(cert_type.to_u8(), code);
            assert_eq!(CertificateType::from_u8(code).unwrap(), cert_type);
        }
        for &bad_code in [1u8, 255u8].iter() {
            assert!(CertificateType::from_u8(bad_code).is_err());
        }
    }

    /// `new` accepts a distinct list and rejects empty or duplicated ones.
    #[test]
    fn test_certificate_type_list_creation() {
        let both =
            CertificateTypeList::new(vec![CertificateType::RawPublicKey, CertificateType::X509])
                .expect("a distinct, non-empty list is valid");
        assert_eq!(both.types.len(), 2);
        assert_eq!(both.most_preferred(), CertificateType::RawPublicKey);
        assert!(both.supports_raw_public_key() && both.supports_x509());
        let empty = CertificateTypeList::new(Vec::new());
        assert!(empty.is_err());
        let duplicated = CertificateTypeList::new(vec![CertificateType::X509; 2]);
        assert!(duplicated.is_err());
    }

    /// Encoding is a count byte plus code points, and decoding inverts it.
    #[test]
    fn test_certificate_type_list_serialization() {
        let original = CertificateTypeList::prefer_raw_public_key();
        let encoded = original.to_bytes();
        assert_eq!(encoded, [2u8, 2, 0]);
        let decoded = CertificateTypeList::from_bytes(&encoded).expect("round trip succeeds");
        assert_eq!(decoded, original);
    }

    /// The first mutually supported type wins; disjoint lists yield `None`.
    #[test]
    fn test_certificate_type_list_negotiation() {
        let rpk = CertificateTypeList::raw_public_key_only();
        let either = CertificateTypeList::prefer_raw_public_key();
        let x509 = CertificateTypeList::x509_only();
        assert_eq!(rpk.negotiate(&either), Some(CertificateType::RawPublicKey));
        assert_eq!(either.negotiate(&x509), Some(CertificateType::X509));
        assert_eq!(rpk.negotiate(&x509), None);
    }

    /// A strict raw-public-key policy agrees with a flexible peer.
    #[test]
    fn test_preferences_negotiation() {
        let strict = CertificateTypePreferences::raw_public_key_only();
        let flexible = CertificateTypePreferences::prefer_raw_public_key();
        let outcome = strict
            .negotiate(Some(&flexible.client_types), Some(&flexible.server_types))
            .expect("raw public keys are supported by both sides");
        assert_eq!(
            outcome,
            NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::RawPublicKey)
        );
        assert!(outcome.is_raw_public_key_only());
    }

    /// The cache stays within its bound once full.
    #[test]
    fn test_negotiation_cache() {
        let mut cache = NegotiationCache::new(2);
        let entry = NegotiationResult::new(CertificateType::RawPublicKey, CertificateType::X509);
        assert!(cache.get(123).is_none());
        cache.insert(123, entry.clone());
        assert_eq!(cache.get(123), Some(&entry));
        cache.insert(456, entry.clone());
        assert_eq!(cache.cache.len(), 2);
        cache.insert(789, entry.clone());
        assert_eq!(cache.cache.len(), 2);
        assert!([456u64, 789u64].iter().any(|&key| cache.get(key).is_some()));
    }
}