use alloc::string::String;
use alloc::vec::Vec;
use core::fmt;
use core::str::FromStr;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
use crate::{
error::ProtoResult,
rr::{RData, RecordData, RecordDataDecodable, RecordType},
serialize::{
binary::{
BinDecodable, BinDecoder, BinEncodable, BinEncoder, DecodeError, RDataEncoding,
Restrict, RestrictedMath,
},
txt::ParseError,
},
};
/// Certificate type carried in the first (16-bit) field of CERT RDATA.
///
/// The numeric code point of each variant is fixed by the `From<u16>` /
/// `From<CertType> for u16` conversions in this file. Variants that carry a
/// `u16` keep the raw wire value, so every possible code point round-trips.
/// NOTE(review): the named variants appear to mirror the RFC 4398 / IANA
/// CERT type registry — confirm against the registry before relying on it.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum CertType {
    /// Code point 0.
    Reserved0,
    /// Code point 1.
    PKIX,
    /// Code point 2.
    SPKI,
    /// Code point 3.
    PGP,
    /// Code point 4.
    IPKIX,
    /// Code point 5.
    ISPKI,
    /// Code point 6.
    IPGP,
    /// Code point 7.
    ACPKIX,
    /// Code point 8.
    IACPKIX,
    /// Code point 253.
    URI,
    /// Code point 254.
    OID,
    /// Code point 255.
    Reserved255,
    /// Any code point in `9..=252` or `256..=65279`; the raw value is kept.
    Unassigned(u16),
    /// Any code point in `65280..=65534`; the raw value is kept.
    Experimental(u16),
    /// Code point 65535.
    Reserved65535,
}
impl From<u16> for CertType {
    /// Maps a wire-format certificate type code point onto its variant.
    ///
    /// Code points without a dedicated variant are preserved inside
    /// `Unassigned` (9..=252, 256..=65279) or `Experimental` (65280..=65534).
    fn from(cert_type: u16) -> Self {
        match cert_type {
            0 => Self::Reserved0,
            1 => Self::PKIX,
            2 => Self::SPKI,
            3 => Self::PGP,
            4 => Self::IPKIX,
            5 => Self::ISPKI,
            6 => Self::IPGP,
            7 => Self::ACPKIX,
            8 => Self::IACPKIX,
            253 => Self::URI,
            254 => Self::OID,
            255 => Self::Reserved255,
            65535 => Self::Reserved65535,
            65280..=65534 => Self::Experimental(cert_type),
            raw @ (9..=252 | 256..=65279) => Self::Unassigned(raw),
        }
    }
}
impl From<CertType> for u16 {
    /// Returns the wire code point for `cert_type`.
    ///
    /// Payload-carrying variants hand back their stored raw value unchanged;
    /// every other variant maps to its fixed code point.
    fn from(cert_type: CertType) -> Self {
        match cert_type {
            CertType::Unassigned(raw) | CertType::Experimental(raw) => raw,
            CertType::Reserved0 => 0,
            CertType::PKIX => 1,
            CertType::SPKI => 2,
            CertType::PGP => 3,
            CertType::IPKIX => 4,
            CertType::ISPKI => 5,
            CertType::IPGP => 6,
            CertType::ACPKIX => 7,
            CertType::IACPKIX => 8,
            CertType::URI => 253,
            CertType::OID => 254,
            CertType::Reserved255 => 255,
            CertType::Reserved65535 => u16::MAX,
        }
    }
}
impl<'r> BinDecodable<'r> for CertType {
    /// Reads a big 16-bit certificate type from `decoder` and converts it
    /// with `CertType::from`; any decoder error is propagated unchanged.
    fn read(decoder: &mut BinDecoder<'r>) -> Result<Self, DecodeError> {
        decoder
            .read_u16()
            .map(|raw| Self::from(raw.unverified()))
    }
}
impl fmt::Display for CertType {
    /// Writes the certificate type in presentation (zone-file) form.
    ///
    /// Named types print as their mnemonic (`PKIX`, `URI`, ...). Reserved,
    /// unassigned and experimental code points print as plain decimal
    /// numbers. The previous `Debug`-based output (`Unassigned(9)`,
    /// `Reserved0`) was not a valid field value and could not be parsed back.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Reserved0 => f.write_str("0"),
            Self::Reserved255 => f.write_str("255"),
            Self::Reserved65535 => f.write_str("65535"),
            Self::Unassigned(value) | Self::Experimental(value) => write!(f, "{value}"),
            // Remaining variants are unit variants whose Debug name is the mnemonic.
            named => write!(f, "{named:?}"),
        }
    }
}
/// Algorithm carried in the third (8-bit) field of CERT RDATA.
///
/// The numeric code point of each variant is fixed by the `From<u8>` /
/// `From<Algorithm> for u8` conversions in this file. `Reserved` and
/// `Unassigned` keep the raw wire value, so every possible code point
/// round-trips. NOTE(review): the named variants appear to mirror the
/// IANA DNSSEC algorithm-number registry — confirm against the registry.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
pub enum Algorithm {
    /// Produced by `From<u8>` for code points 0, 9 and 11; the raw value is kept.
    Reserved(u8),
    /// Code point 1.
    RSAMD5,
    /// Code point 2.
    DH,
    /// Code point 3.
    DSA,
    /// Code point 4.
    ECC,
    /// Code point 5.
    RSASHA1,
    /// Code point 252.
    INDIRECT,
    /// Code point 253.
    PRIVATEDNS,
    /// Code point 254.
    PRIVATEOID,
    /// Code point 6.
    DSANSEC3SHA1,
    /// Code point 7.
    RSASHA1NSEC3SHA1,
    /// Code point 8.
    RSASHA256,
    /// Code point 10.
    RSASHA512,
    /// Code point 12.
    ECCGOST,
    /// Code point 13.
    ECDSAP256SHA256,
    /// Code point 14.
    ECDSAP384SHA384,
    /// Code point 15.
    ED25519,
    /// Code point 16.
    ED448,
    /// Code point 17.
    SM2SM3,
    /// Code point 23.
    ECCGOST12,
    /// Produced by `From<u8>` for 18..=22, 24..=251 and 255; the raw value is kept.
    Unassigned(u8),
}
impl From<u8> for Algorithm {
    /// Maps a wire-format algorithm code point onto its variant.
    ///
    /// Code points 0, 9 and 11 become `Reserved`. Every code point without a
    /// named variant (18..=22, 24..=251, 255) becomes `Unassigned`. In both
    /// cases the raw value is kept.
    fn from(algorithm: u8) -> Self {
        match algorithm {
            0 | 9 | 11 => Self::Reserved(algorithm),
            1 => Self::RSAMD5,
            2 => Self::DH,
            3 => Self::DSA,
            4 => Self::ECC,
            5 => Self::RSASHA1,
            6 => Self::DSANSEC3SHA1,
            7 => Self::RSASHA1NSEC3SHA1,
            8 => Self::RSASHA256,
            10 => Self::RSASHA512,
            12 => Self::ECCGOST,
            13 => Self::ECDSAP256SHA256,
            14 => Self::ECDSAP384SHA384,
            15 => Self::ED25519,
            16 => Self::ED448,
            17 => Self::SM2SM3,
            23 => Self::ECCGOST12,
            252 => Self::INDIRECT,
            253 => Self::PRIVATEDNS,
            254 => Self::PRIVATEOID,
            other => Self::Unassigned(other),
        }
    }
}
impl From<Algorithm> for u8 {
    /// Returns the wire code point for `algorithm`.
    ///
    /// `Reserved` and `Unassigned` hand back their stored raw value
    /// unchanged. Every named variant maps to its fixed code point. The
    /// previous value-guarded `Reserved`/`Unassigned` arms were redundant with
    /// the unguarded ones and are collapsed into a single or-pattern.
    fn from(algorithm: Algorithm) -> Self {
        match algorithm {
            Algorithm::Reserved(value) | Algorithm::Unassigned(value) => value,
            Algorithm::RSAMD5 => 1,
            Algorithm::DH => 2,
            Algorithm::DSA => 3,
            Algorithm::ECC => 4,
            Algorithm::RSASHA1 => 5,
            Algorithm::DSANSEC3SHA1 => 6,
            Algorithm::RSASHA1NSEC3SHA1 => 7,
            Algorithm::RSASHA256 => 8,
            Algorithm::RSASHA512 => 10,
            Algorithm::ECCGOST => 12,
            Algorithm::ECDSAP256SHA256 => 13,
            Algorithm::ECDSAP384SHA384 => 14,
            Algorithm::ED25519 => 15,
            Algorithm::ED448 => 16,
            Algorithm::SM2SM3 => 17,
            Algorithm::ECCGOST12 => 23,
            Algorithm::INDIRECT => 252,
            Algorithm::PRIVATEDNS => 253,
            Algorithm::PRIVATEOID => 254,
        }
    }
}
impl<'r> BinDecodable<'r> for Algorithm {
    /// Reads a single algorithm byte from `decoder` and converts it with
    /// `Algorithm::from`; any decoder error is propagated unchanged.
    fn read(decoder: &mut BinDecoder<'r>) -> Result<Self, DecodeError> {
        decoder
            .read_u8()
            .map(|raw| Self::from(raw.unverified()))
    }
}
impl fmt::Display for Algorithm {
    /// Writes the algorithm in presentation (zone-file) form.
    ///
    /// Named algorithms print as their variant name (`RSASHA256`, ...).
    /// `Reserved` and `Unassigned` print as the plain decimal code point. The
    /// previous `Debug`-based output (`Unassigned(200)`) was not a valid field
    /// value and could not be parsed back.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Reserved(value) | Self::Unassigned(value) => write!(f, "{value}"),
            // Remaining variants are unit variants; their Debug name is the mnemonic.
            named => write!(f, "{named:?}"),
        }
    }
}
/// RDATA of a CERT record.
///
/// Wire layout, as written by `emit` and read by `read_data` in this file:
/// a 16-bit certificate type, a 16-bit key tag, an 8-bit algorithm, then
/// the remaining RDATA bytes as `cert_data`. In presentation form `Display`
/// renders `cert_data` as base64.
#[cfg_attr(feature = "serde", derive(Deserialize, Serialize))]
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
#[non_exhaustive]
pub struct CERT {
    /// Certificate type (first, 16-bit field).
    pub cert_type: CertType,
    /// Key tag (second, 16-bit field).
    /// NOTE(review): presumably the DNSSEC key tag of the associated key — confirm.
    pub key_tag: u16,
    /// Algorithm (third, 8-bit field).
    pub algorithm: Algorithm,
    /// Raw certificate bytes: everything after the 5-byte fixed header.
    pub cert_data: Vec<u8>,
}
impl CERT {
    /// Builds a CERT record from its four RDATA fields.
    pub const fn new(
        cert_type: CertType,
        key_tag: u16,
        algorithm: Algorithm,
        cert_data: Vec<u8>,
    ) -> Self {
        Self {
            cert_type,
            key_tag,
            algorithm,
            cert_data,
        }
    }

    /// Parses CERT RDATA from presentation-format tokens.
    ///
    /// Expected tokens, in order:
    /// 1. the certificate type, as a decimal number or a mnemonic such as `PKIX`;
    /// 2. the key tag, as a decimal `u16`;
    /// 3. the algorithm, as a decimal number or a mnemonic such as `RSASHA256`;
    /// 4. the certificate data in base64.
    ///
    /// Mnemonics are matched case-insensitively. The base64 data may be
    /// split across any number of whitespace-separated tokens; all remaining
    /// tokens are concatenated before decoding. Previously only the first
    /// data token was decoded and the rest were silently dropped.
    ///
    /// # Errors
    /// Returns `ParseError::Message` when a field is missing, a numeric field
    /// is neither a known mnemonic nor a valid integer, or the concatenated
    /// certificate data is not valid base64.
    pub(crate) fn from_tokens<'i, I: Iterator<Item = &'i str>>(
        tokens: I,
    ) -> Result<Self, ParseError> {
        let mut iter = tokens;

        let token = iter
            .next()
            .ok_or(ParseError::Message("CERT cert type field missing"))?;
        let cert_type = Self::parse_cert_type(token)?;

        let token = iter
            .next()
            .ok_or(ParseError::Message("CERT key tag field missing"))?;
        let key_tag = u16::from_str(token)
            .map_err(|_| ParseError::Message("Invalid digit found in key_tag token"))?;

        let token = iter
            .next()
            .ok_or(ParseError::Message("CERT algorithm field missing"))?;
        let algorithm = Self::parse_algorithm(token)?;

        // At least one data token is required; any further tokens are
        // continuation chunks of the same base64 payload.
        let first_chunk = iter
            .next()
            .ok_or(ParseError::Message("CERT data missing"))?;
        let mut encoded = String::from(first_chunk);
        encoded.extend(iter);
        let cert_data = data_encoding::BASE64
            .decode(encoded.as_bytes())
            .map_err(|_| ParseError::Message("Invalid base64 CERT data"))?;

        Ok(Self::new(cert_type, key_tag, algorithm, cert_data))
    }

    /// Parses the certificate-type field: a mnemonic (matching the names
    /// `Display` emits for named types) or a decimal `u16` code point.
    fn parse_cert_type(token: &str) -> Result<CertType, ParseError> {
        const MNEMONICS: &[(&str, CertType)] = &[
            ("PKIX", CertType::PKIX),
            ("SPKI", CertType::SPKI),
            ("PGP", CertType::PGP),
            ("IPKIX", CertType::IPKIX),
            ("ISPKI", CertType::ISPKI),
            ("IPGP", CertType::IPGP),
            ("ACPKIX", CertType::ACPKIX),
            ("IACPKIX", CertType::IACPKIX),
            ("URI", CertType::URI),
            ("OID", CertType::OID),
        ];

        let known = MNEMONICS
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(token))
            .map(|&(_, cert_type)| cert_type);
        match known {
            Some(cert_type) => Ok(cert_type),
            None => u16::from_str(token)
                .map(CertType::from)
                .map_err(|_| ParseError::Message("Invalid digit found in cert_type token")),
        }
    }

    /// Parses the algorithm field: a mnemonic (the variant names `Display`
    /// emits, plus the hyphenated registry spellings) or a decimal `u8`.
    fn parse_algorithm(token: &str) -> Result<Algorithm, ParseError> {
        const MNEMONICS: &[(&str, Algorithm)] = &[
            ("RSAMD5", Algorithm::RSAMD5),
            ("DH", Algorithm::DH),
            ("DSA", Algorithm::DSA),
            ("ECC", Algorithm::ECC),
            ("RSASHA1", Algorithm::RSASHA1),
            ("DSANSEC3SHA1", Algorithm::DSANSEC3SHA1),
            ("DSA-NSEC3-SHA1", Algorithm::DSANSEC3SHA1),
            ("RSASHA1NSEC3SHA1", Algorithm::RSASHA1NSEC3SHA1),
            ("RSASHA1-NSEC3-SHA1", Algorithm::RSASHA1NSEC3SHA1),
            ("RSASHA256", Algorithm::RSASHA256),
            ("RSASHA512", Algorithm::RSASHA512),
            ("ECCGOST", Algorithm::ECCGOST),
            ("ECC-GOST", Algorithm::ECCGOST),
            ("ECDSAP256SHA256", Algorithm::ECDSAP256SHA256),
            ("ECDSAP384SHA384", Algorithm::ECDSAP384SHA384),
            ("ED25519", Algorithm::ED25519),
            ("ED448", Algorithm::ED448),
            ("SM2SM3", Algorithm::SM2SM3),
            ("ECCGOST12", Algorithm::ECCGOST12),
            ("ECC-GOST12", Algorithm::ECCGOST12),
            ("INDIRECT", Algorithm::INDIRECT),
            ("PRIVATEDNS", Algorithm::PRIVATEDNS),
            ("PRIVATEOID", Algorithm::PRIVATEOID),
        ];

        let known = MNEMONICS
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(token))
            .map(|&(_, algorithm)| algorithm);
        match known {
            Some(algorithm) => Ok(algorithm),
            None => u8::from_str(token)
                .map(Algorithm::from)
                .map_err(|_| ParseError::Message("Invalid digit found in algorithm token")),
        }
    }

    /// Returns `cert_data` encoded as standard base64.
    pub fn cert_base64(&self) -> String {
        // `encode` already returns an owned String; no clone is needed.
        data_encoding::BASE64.encode(&self.cert_data)
    }
}
impl TryFrom<&[u8]> for CERT {
type Error = DecodeError;
fn try_from(cert_record: &[u8]) -> Result<Self, Self::Error> {
let mut decoder = BinDecoder::new(cert_record);
let length = Restrict::new(cert_record.len() as u16); Self::read_data(&mut decoder, length) }
}
impl BinEncodable for CERT {
    /// Writes the CERT RDATA to `encoder`: 16-bit cert type, 16-bit key tag,
    /// 8-bit algorithm, then the raw certificate bytes.
    fn emit(&self, encoder: &mut BinEncoder<'_>) -> ProtoResult<()> {
        let mut rdata = encoder.with_rdata_behavior(RDataEncoding::Other);
        let type_code: u16 = self.cert_type.into();
        let algorithm_code: u8 = self.algorithm.into();

        rdata.emit_u16(type_code)?;
        rdata.emit_u16(self.key_tag)?;
        rdata.emit_u8(algorithm_code)?;
        rdata.emit_vec(&self.cert_data)?;
        Ok(())
    }
}
impl<'r> RecordDataDecodable<'r> for CERT {
    /// Decodes `length` bytes of CERT RDATA from `decoder`.
    ///
    /// The fixed header (16-bit cert type, 16-bit key tag, 8-bit algorithm)
    /// is read first. Every remaining RDATA byte becomes `cert_data`.
    ///
    /// # Errors
    /// Returns `DecodeError::IncorrectRDataLengthRead` when `length` is 5 or
    /// less, so at least one certificate byte is required. Errors from the
    /// underlying reads are propagated.
    fn read_data(decoder: &mut BinDecoder<'r>, length: Restrict<u16>) -> Result<Self, DecodeError> {
        let declared = length.map(|u| u as usize).unverified();
        if declared < 6 {
            return Err(DecodeError::IncorrectRDataLengthRead {
                read: declared,
                len: 6,
            });
        }

        let header_start = decoder.index();
        let cert_type = CertType::read(decoder)?;
        let key_tag = decoder.read_u16()?.unverified();
        let algorithm = Algorithm::read(decoder)?;
        let header_len = decoder.index() - header_start;

        let cert_len = length
            .map(|u| u as usize)
            .checked_sub(header_len)
            .map_err(|len| DecodeError::IncorrectRDataLengthRead {
                read: header_len,
                len,
            })?
            .unverified();
        let cert_data = decoder.read_vec(cert_len)?.unverified();

        Ok(Self {
            cert_type,
            key_tag,
            algorithm,
            cert_data,
        })
    }
}
impl RecordData for CERT {
    /// Borrows the CERT payload out of `data`, or `None` for any other variant.
    fn try_borrow(data: &RData) -> Option<&Self> {
        if let RData::CERT(cert) = data {
            Some(cert)
        } else {
            None
        }
    }

    /// Always `RecordType::CERT`.
    fn record_type(&self) -> RecordType {
        RecordType::CERT
    }

    /// Wraps this record in `RData::CERT`.
    fn into_rdata(self) -> RData {
        RData::CERT(self)
    }
}
impl fmt::Display for CERT {
    /// Writes `<cert type> <key tag> <algorithm> <base64 data>`, separated by
    /// single spaces.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        let encoded = data_encoding::BASE64.encode(&self.cert_data);
        write!(
            f,
            "{} {} {} {}",
            self.cert_type, self.key_tag, self.algorithm, encoded
        )
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::dbg_macro, clippy::print_stdout)]

    use super::*;

    /// Named cert types and representative unassigned/experimental values
    /// convert to and from their expected code points.
    #[test]
    fn test_cert_type() {
        assert_eq!(CertType::Reserved0, CertType::from(0));
        assert_eq!(CertType::PKIX, CertType::from(1));
        assert_eq!(CertType::SPKI, CertType::from(2));
        assert_eq!(CertType::PGP, CertType::from(3));
        assert_eq!(CertType::IPKIX, CertType::from(4));
        assert_eq!(CertType::ISPKI, CertType::from(5));
        assert_eq!(CertType::IPGP, CertType::from(6));
        assert_eq!(CertType::ACPKIX, CertType::from(7));
        assert_eq!(CertType::IACPKIX, CertType::from(8));
        assert_eq!(CertType::URI, CertType::from(253));
        assert_eq!(CertType::OID, CertType::from(254));
        assert_eq!(CertType::Reserved255, CertType::from(255));
        assert_eq!(CertType::Unassigned(9), CertType::from(9));
        assert_eq!(CertType::Unassigned(90), CertType::from(90));
        assert_eq!(CertType::Experimental(65280), CertType::from(65280));
        assert_eq!(CertType::Experimental(65390), CertType::from(65390));
        assert_eq!(CertType::Reserved65535, CertType::from(65535));

        let cert_type_iana_9 = CertType::Unassigned(9);
        let cert_type_iana_90 = CertType::Unassigned(90);
        let cert_type_experimental_80 = CertType::Experimental(65280);
        let cert_type_experimental_90 = CertType::Experimental(65290);

        assert_eq!(u16::from(CertType::Reserved0), 0);
        assert_eq!(u16::from(CertType::PKIX), 1);
        assert_eq!(u16::from(CertType::SPKI), 2);
        assert_eq!(u16::from(CertType::PGP), 3);
        assert_eq!(u16::from(CertType::IPKIX), 4);
        assert_eq!(u16::from(CertType::ISPKI), 5);
        assert_eq!(u16::from(CertType::IPGP), 6);
        assert_eq!(u16::from(CertType::ACPKIX), 7);
        assert_eq!(u16::from(CertType::IACPKIX), 8);
        assert_eq!(u16::from(cert_type_iana_9), 9);
        assert_eq!(u16::from(cert_type_iana_90), 90);
        assert_eq!(u16::from(CertType::URI), 253);
        assert_eq!(u16::from(CertType::OID), 254);
        assert_eq!(u16::from(CertType::Reserved255), 255);
        assert_eq!(u16::from(cert_type_experimental_80), 65280);
        assert_eq!(u16::from(cert_type_experimental_90), 65290);
        assert_eq!(u16::from(CertType::Reserved65535), 65535);
    }

    /// Every `u16` survives a `CertType` round trip, and the range
    /// boundaries land in the expected catch-all variants.
    #[test]
    fn test_cert_type_round_trip() {
        for value in 0..=u16::MAX {
            assert_eq!(u16::from(CertType::from(value)), value, "cert type {value}");
        }
        assert_eq!(CertType::Unassigned(252), CertType::from(252));
        assert_eq!(CertType::Unassigned(256), CertType::from(256));
        assert_eq!(CertType::Unassigned(65279), CertType::from(65279));
        assert_eq!(CertType::Experimental(65534), CertType::from(65534));
    }

    /// Named algorithms and representative reserved/unassigned values
    /// convert to and from their expected code points.
    #[test]
    fn test_algorithm() {
        assert_eq!(Algorithm::Reserved(0), Algorithm::from(0));
        assert_eq!(Algorithm::RSAMD5, Algorithm::from(1));
        assert_eq!(Algorithm::DH, Algorithm::from(2));
        assert_eq!(Algorithm::DSA, Algorithm::from(3));
        assert_eq!(Algorithm::ECC, Algorithm::from(4));
        assert_eq!(Algorithm::RSASHA1, Algorithm::from(5));
        assert_eq!(Algorithm::DSANSEC3SHA1, Algorithm::from(6));
        assert_eq!(Algorithm::RSASHA1NSEC3SHA1, Algorithm::from(7));
        assert_eq!(Algorithm::RSASHA256, Algorithm::from(8));
        assert_eq!(Algorithm::Reserved(9), Algorithm::from(9));
        assert_eq!(Algorithm::RSASHA512, Algorithm::from(10));
        assert_eq!(Algorithm::Reserved(11), Algorithm::from(11));
        assert_eq!(Algorithm::ECCGOST, Algorithm::from(12));
        assert_eq!(Algorithm::ECDSAP256SHA256, Algorithm::from(13));
        assert_eq!(Algorithm::ECDSAP384SHA384, Algorithm::from(14));
        assert_eq!(Algorithm::ED25519, Algorithm::from(15));
        assert_eq!(Algorithm::ED448, Algorithm::from(16));
        assert_eq!(Algorithm::SM2SM3, Algorithm::from(17));
        assert_eq!(Algorithm::Unassigned(18), Algorithm::from(18));
        assert_eq!(Algorithm::Unassigned(20), Algorithm::from(20));
        assert_eq!(Algorithm::ECCGOST12, Algorithm::from(23));
        assert_eq!(Algorithm::INDIRECT, Algorithm::from(252));
        assert_eq!(Algorithm::PRIVATEDNS, Algorithm::from(253));
        assert_eq!(Algorithm::PRIVATEOID, Algorithm::from(254));

        let algorithm_reserved_0 = Algorithm::Reserved(0);
        let algorithm_reserved_9 = Algorithm::Reserved(9);

        assert_eq!(u8::from(algorithm_reserved_0), 0);
        assert_eq!(u8::from(Algorithm::RSAMD5), 1);
        assert_eq!(u8::from(Algorithm::DH), 2);
        assert_eq!(u8::from(Algorithm::DSA), 3);
        assert_eq!(u8::from(Algorithm::ECC), 4);
        assert_eq!(u8::from(Algorithm::RSASHA1), 5);
        assert_eq!(u8::from(Algorithm::DSANSEC3SHA1), 6);
        assert_eq!(u8::from(Algorithm::RSASHA1NSEC3SHA1), 7);
        assert_eq!(u8::from(Algorithm::RSASHA256), 8);
        assert_eq!(u8::from(Algorithm::Reserved(9)), 9);
        assert_eq!(u8::from(Algorithm::RSASHA512), 10);
        assert_eq!(u8::from(Algorithm::Reserved(11)), 11);
        assert_eq!(u8::from(Algorithm::ECCGOST), 12);
        assert_eq!(u8::from(Algorithm::ECDSAP256SHA256), 13);
        assert_eq!(u8::from(Algorithm::ECDSAP384SHA384), 14);
        assert_eq!(u8::from(Algorithm::ED25519), 15);
        assert_eq!(u8::from(Algorithm::ED448), 16);
        assert_eq!(u8::from(Algorithm::SM2SM3), 17);
        assert_eq!(u8::from(Algorithm::Unassigned(18)), 18);
        assert_eq!(u8::from(Algorithm::Unassigned(20)), 20);
        assert_eq!(u8::from(Algorithm::ECCGOST12), 23);
        assert_eq!(u8::from(Algorithm::INDIRECT), 252);
        assert_eq!(u8::from(Algorithm::PRIVATEDNS), 253);
        assert_eq!(u8::from(Algorithm::PRIVATEOID), 254);
        assert_eq!(u8::from(algorithm_reserved_9), 9);
    }

    /// Every `u8` survives an `Algorithm` round trip, and values without a
    /// named variant land in `Unassigned`.
    #[test]
    fn test_algorithm_round_trip() {
        for value in 0..=u8::MAX {
            assert_eq!(u8::from(Algorithm::from(value)), value, "algorithm {value}");
        }
        assert_eq!(Algorithm::Unassigned(24), Algorithm::from(24));
        assert_eq!(Algorithm::Unassigned(122), Algorithm::from(122));
        assert_eq!(Algorithm::Unassigned(200), Algorithm::from(200));
        assert_eq!(Algorithm::Unassigned(255), Algorithm::from(255));
    }

    /// Six bytes (5-byte header + 1 data byte) is the shortest accepted RDATA.
    #[test]
    fn test_valid_cert_data_length() {
        let valid_cert_data = [1, 2, 3, 4, 5, 6];
        let result = CERT::try_from(&valid_cert_data[..]);
        assert!(
            result.is_ok(),
            "Expected a valid result with sufficient cert_data length"
        );
    }

    /// RDATA consisting of only the 5-byte header (no cert data) is rejected.
    #[test]
    fn test_header_only_cert_record_rejected() {
        let header_only = [0x00, 0x01, 0x30, 0x39, 0x08];
        let result = CERT::try_from(&header_only[..]);
        assert!(
            matches!(
                result,
                Err(DecodeError::IncorrectRDataLengthRead { read: 5, len: 6 })
            ),
            "Expected error for header-only cert_record, got {result:?}"
        );
    }

    #[test]
    fn test_cert_creation() {
        let cert_type = CertType::PKIX;
        let key_tag = 12345;
        let algorithm = Algorithm::RSASHA256;
        let cert_data = [1, 2, 3, 4, 5];

        let cert = CERT {
            cert_type,
            key_tag,
            algorithm,
            cert_data: cert_data.to_vec(),
        };

        assert_eq!(cert.cert_type, cert_type);
        assert_eq!(cert.key_tag, key_tag);
        assert_eq!(cert.algorithm, algorithm);
        assert_eq!(cert.cert_data, cert_data);
    }

    #[test]
    fn test_cert_empty_cert_data() {
        let cert_type = CertType::PKIX;
        let key_tag = 12345;
        let algorithm = Algorithm::RSASHA256;
        let cert_data = Vec::new();

        let cert = CERT {
            cert_type,
            key_tag,
            algorithm,
            cert_data,
        };

        assert_eq!(cert.cert_type, cert_type);
        assert_eq!(cert.key_tag, key_tag);
        assert_eq!(cert.algorithm, algorithm);
        assert!(cert.cert_data.is_empty());
    }

    /// Decodes type 1 (PKIX), key tag 0x3039 (12345), algorithm 8, data "AQID".
    #[test]
    fn test_valid_cert_record() {
        let valid_cert_record = [0x00, 0x01, 0x30, 0x39, 0x08, 65, 81, 73, 68];
        let cert = CERT::try_from(&valid_cert_record[..]);
        assert!(cert.is_ok(), "Expected valid cert_record");

        let cert = cert.unwrap();
        assert_eq!(cert.cert_type, CertType::PKIX);
        assert_eq!(cert.key_tag, 12345);
        assert_eq!(cert.algorithm, Algorithm::RSASHA256);
        assert_eq!(cert.cert_data, [65, 81, 73, 68]);
    }

    #[test]
    fn test_invalid_cert_record_length() {
        let invalid_cert_record = [1, 2, 3, 4];
        let result = CERT::try_from(&invalid_cert_record[..]);
        assert!(
            matches!(
                result,
                Err(DecodeError::IncorrectRDataLengthRead { read: 4, len: 6 })
            ),
            "Expected error due to invalid cert_record length, got {result:?}"
        );
    }

    #[test]
    fn test_valid_cert_data() {
        let tokens = vec!["1", "123", "3", "Q2VydGlmaWNhdGUgZGF0YQ=="].into_iter();
        let result = CERT::from_tokens(tokens);
        assert!(result.is_ok());

        let cert = result.unwrap();
        assert_eq!(cert.cert_type, CertType::from(1));
        assert_eq!(cert.key_tag, 123);
        assert_eq!(cert.algorithm, Algorithm::from(3));
        assert_eq!(cert.cert_data, b"Certificate data".to_vec());
    }

    #[test]
    fn test_invalid_base64_data() {
        let tokens = vec!["1", "123", "3", "Invalid_base64"].into_iter();
        let result = CERT::from_tokens(tokens);
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert_eq!(format!("{err}"), "Invalid base64 CERT data");
    }

    /// With the cert-type token missing, the base64 data lands in the
    /// algorithm slot and must be rejected as a non-numeric algorithm.
    #[test]
    fn test_invalid_token_digit() {
        let tokens = vec!["123", "3", "Q2VydGlmaWNhdGUgZGF0YQ=="].into_iter();
        let result = CERT::from_tokens(tokens);
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert_eq!(format!("{err}"), "Invalid digit found in algorithm token");
    }

    #[test]
    fn test_missing_cert_data() {
        let tokens = vec!["1", "123", "3"].into_iter();
        let result = CERT::from_tokens(tokens);
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert_eq!(format!("{err}"), "CERT data missing");
    }
}