use crate::errors::{ErrorKind, Result};
/// The concrete kind of key held by a `PemEncodedKey`: the algorithm family
/// together with whether it is the private or the public half.
///
/// Fieldless, so it is cheap to copy and has total equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PemType {
    /// Elliptic-curve public key (or certificate carrying one).
    EcPublic,
    /// Elliptic-curve private key.
    EcPrivate,
    /// RSA public key (PKCS#1 or PKCS#8/certificate).
    RsaPublic,
    /// RSA private key (PKCS#1 or PKCS#8).
    RsaPrivate,
    /// Ed25519 public key (or certificate carrying one).
    EdPublic,
    /// Ed25519 private key.
    EdPrivate,
}
/// Which structural standard the decoded PEM body follows.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Standard {
    /// PKCS#1: the bare RSA key structure, used for the `RSA PRIVATE KEY` and
    /// `RSA PUBLIC KEY` PEM labels.
    Pkcs1,
    /// Algorithm-tagged structure, used for the `PRIVATE KEY`, `PUBLIC KEY` and
    /// `CERTIFICATE` PEM labels.
    Pkcs8,
}
/// Key algorithm family, detected from the algorithm OID inside the DER body.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Classification {
    /// OID 1.2.840.10045.2.1 (id-ecPublicKey).
    Ec,
    /// OID 1.3.101.112 (Ed25519).
    Ed,
    /// OID 1.2.840.113549.1.1.1 (rsaEncryption).
    Rsa,
}
/// A decoded PEM document together with what is known about the key inside it.
///
/// Instances are produced by `PemEncodedKey::new`; the `as_*` accessors then
/// hand out the bytes appropriate for a given algorithm.
#[derive(Debug)]
pub(crate) struct PemEncodedKey {
    /// Binary body of the PEM document (the bytes that were parsed as DER).
    content: Vec<u8>,
    /// `content` parsed into top-level ASN.1 blocks.
    asn1: Vec<simple_asn1::ASN1Block>,
    /// Algorithm family and public/private half of the key.
    pem_type: PemType,
    /// Whether the body was treated as PKCS#1 or PKCS#8.
    standard: Standard,
}
impl PemEncodedKey {
pub fn new(input: &[u8]) -> Result<PemEncodedKey> {
match pem::parse(input) {
Ok(content) => {
let asn1_content = match simple_asn1::from_der(content.contents()) {
Ok(asn1) => asn1,
Err(_) => return Err(ErrorKind::InvalidKeyFormat.into()),
};
match content.tag() {
"RSA PRIVATE KEY" => Ok(PemEncodedKey {
content: content.into_contents(),
asn1: asn1_content,
pem_type: PemType::RsaPrivate,
standard: Standard::Pkcs1,
}),
"RSA PUBLIC KEY" => Ok(PemEncodedKey {
content: content.into_contents(),
asn1: asn1_content,
pem_type: PemType::RsaPublic,
standard: Standard::Pkcs1,
}),
tag @ "PRIVATE KEY" | tag @ "PUBLIC KEY" | tag @ "CERTIFICATE" => {
match classify_pem(&asn1_content) {
Some(c) => {
let is_private = tag == "PRIVATE KEY";
let pem_type = match c {
Classification::Ec => {
if is_private {
PemType::EcPrivate
} else {
PemType::EcPublic
}
}
Classification::Ed => {
if is_private {
PemType::EdPrivate
} else {
PemType::EdPublic
}
}
Classification::Rsa => {
if is_private {
PemType::RsaPrivate
} else {
PemType::RsaPublic
}
}
};
Ok(PemEncodedKey {
content: content.into_contents(),
asn1: asn1_content,
pem_type,
standard: Standard::Pkcs8,
})
}
None => Err(ErrorKind::InvalidKeyFormat.into()),
}
}
_ => Err(ErrorKind::InvalidKeyFormat.into()),
}
}
Err(_) => Err(ErrorKind::InvalidKeyFormat.into()),
}
}
pub fn as_ec_private_key(&self) -> Result<&[u8]> {
match self.standard {
Standard::Pkcs1 => Err(ErrorKind::InvalidKeyFormat.into()),
Standard::Pkcs8 => match self.pem_type {
PemType::EcPrivate => Ok(self.content.as_slice()),
_ => Err(ErrorKind::InvalidKeyFormat.into()),
},
}
}
pub fn as_ec_public_key(&self) -> Result<&[u8]> {
match self.standard {
Standard::Pkcs1 => Err(ErrorKind::InvalidKeyFormat.into()),
Standard::Pkcs8 => match self.pem_type {
PemType::EcPublic => extract_first_bitstring(&self.asn1),
_ => Err(ErrorKind::InvalidKeyFormat.into()),
},
}
}
pub fn as_ed_private_key(&self) -> Result<&[u8]> {
match self.standard {
Standard::Pkcs1 => Err(ErrorKind::InvalidKeyFormat.into()),
Standard::Pkcs8 => match self.pem_type {
PemType::EdPrivate => Ok(self.content.as_slice()),
_ => Err(ErrorKind::InvalidKeyFormat.into()),
},
}
}
pub fn as_ed_public_key(&self) -> Result<&[u8]> {
match self.standard {
Standard::Pkcs1 => Err(ErrorKind::InvalidKeyFormat.into()),
Standard::Pkcs8 => match self.pem_type {
PemType::EdPublic => extract_first_bitstring(&self.asn1),
_ => Err(ErrorKind::InvalidKeyFormat.into()),
},
}
}
pub fn as_rsa_key(&self) -> Result<&[u8]> {
match self.standard {
Standard::Pkcs1 => Ok(self.content.as_slice()),
Standard::Pkcs8 => match self.pem_type {
PemType::RsaPrivate => extract_first_bitstring(&self.asn1),
PemType::RsaPublic => extract_first_bitstring(&self.asn1),
_ => Err(ErrorKind::InvalidKeyFormat.into()),
},
}
}
}
/// Finds the first BIT STRING or OCTET STRING in `asn1` and returns its bytes.
///
/// The search is depth-first, in order, and descends into SEQUENCE blocks only;
/// other block kinds are skipped. For a BIT STRING only the byte payload is
/// returned, and its other fields are ignored.
///
/// # Errors
///
/// `ErrorKind::InvalidEcdsaKey` when no such string exists. This kind is
/// returned whatever the key algorithm is.
fn extract_first_bitstring(asn1: &[simple_asn1::ASN1Block]) -> Result<&[u8]> {
    asn1.iter()
        .find_map(|block| match block {
            // A failed nested search is not fatal: keep scanning siblings.
            simple_asn1::ASN1Block::Sequence(_, children) => {
                extract_first_bitstring(children).ok()
            }
            simple_asn1::ASN1Block::BitString(_, _, bytes)
            | simple_asn1::ASN1Block::OctetString(_, bytes) => Some(bytes.as_slice()),
            _ => None,
        })
        .ok_or_else(|| ErrorKind::InvalidEcdsaKey.into())
}
/// Determines which key algorithm an ASN.1 structure describes.
///
/// Walks `asn1` depth-first and descends into SEQUENCE blocks only. Returns the
/// classification of the first OBJECT IDENTIFIER matching one of:
/// - `1.2.840.10045.2.1` → `Ec`
/// - `1.2.840.113549.1.1.1` → `Rsa`
/// - `1.3.101.112` → `Ed`
///
/// Other OIDs are skipped. Returns `None` when nothing matches.
fn classify_pem(asn1: &[simple_asn1::ASN1Block]) -> Option<Classification> {
    // Build the reference OIDs once per classification. Previously they were
    // rebuilt at every nested SEQUENCE, because the walk is recursive.
    let ec_public_key_oid = simple_asn1::oid!(1, 2, 840, 10_045, 2, 1);
    let rsa_public_key_oid = simple_asn1::oid!(1, 2, 840, 113_549, 1, 1, 1);
    let ed25519_oid = simple_asn1::oid!(1, 3, 101, 112);
    classify_blocks(asn1, &ec_public_key_oid, &rsa_public_key_oid, &ed25519_oid)
}

/// Recursive worker for `classify_pem`: compares each OID it meets against the
/// pre-built reference OIDs in the order EC, RSA, Ed.
fn classify_blocks(
    asn1: &[simple_asn1::ASN1Block],
    ec_oid: &simple_asn1::OID,
    rsa_oid: &simple_asn1::OID,
    ed_oid: &simple_asn1::OID,
) -> Option<Classification> {
    for entry in asn1 {
        match entry {
            simple_asn1::ASN1Block::Sequence(_, children) => {
                if let Some(found) = classify_blocks(children, ec_oid, rsa_oid, ed_oid) {
                    return Some(found);
                }
            }
            simple_asn1::ASN1Block::ObjectIdentifier(_, oid) => {
                if oid == ec_oid {
                    return Some(Classification::Ec);
                }
                if oid == rsa_oid {
                    return Some(Classification::Rsa);
                }
                if oid == ed_oid {
                    return Some(Classification::Ed);
                }
            }
            _ => {}
        }
    }
    None
}