#![forbid(unsafe_code, missing_docs, missing_debug_implementations, warnings)]
#![doc(html_root_url = "https://docs.rs/rsa-der/0.2.0")]
use simple_asn1::{oid, ASN1Block, BigInt};
use std::fmt;
use std::fmt::{Display, Formatter};
pub fn public_key_to_der(n: &[u8], e: &[u8]) -> Vec<u8> {
let mut root_sequence = vec![];
let oid = oid!(1, 2, 840, 113_549, 1, 1, 1);
root_sequence.push(ASN1Block::Sequence(
0,
vec![ASN1Block::ObjectIdentifier(0, oid), ASN1Block::Null(0)],
));
let n_block = ASN1Block::Integer(0, BigInt::from_signed_bytes_be(n));
let e_block = ASN1Block::Integer(0, BigInt::from_signed_bytes_be(e));
let rsa_key_bits =
simple_asn1::to_der(&ASN1Block::Sequence(0, vec![n_block, e_block])).unwrap();
root_sequence.push(ASN1Block::BitString(
0,
rsa_key_bits.len() * 8,
rsa_key_bits,
));
simple_asn1::to_der(&ASN1Block::Sequence(0, root_sequence)).unwrap()
}
/// Errors that can occur while decoding a DER-encoded RSA public key.
#[derive(Clone, Debug, PartialEq)]
pub enum Error {
    /// The input, or the payload of the key bit string, was not valid DER.
    /// Wraps the underlying `simple_asn1` decode error.
    InvalidDer(simple_asn1::ASN1DecodeErr),
    /// No ASN.1 `BIT STRING` containing the key was found.
    BitStringNotFound,
    /// The key bit string did not begin with an ASN.1 `SEQUENCE`.
    SequenceNotFound,
    /// No valid modulus `INTEGER` was found as the first element of the key sequence.
    ModulusNotFound,
    /// No valid exponent `INTEGER` was found as the second element of the key sequence.
    ExponentNotFound,
    /// The key sequence did not contain exactly two elements.
    InvalidSequenceLength,
}
/// The standard library's two-parameter `Result`, reachable under a distinct
/// name because the public `Result` alias below shadows it in this crate.
type StdResult<T, E> = ::std::result::Result<T, E>;
/// Result type used by this crate's fallible functions, with [`Error`] as the error.
pub type Result<T> = StdResult<T, Error>;
impl Display for Error {
fn fmt(&self, f: &mut Formatter) -> StdResult<(), fmt::Error> {
match self {
Error::InvalidDer(e) => e.fmt(f)?,
Error::BitStringNotFound => f.write_str("RSA bit string not found in ASN.1 blocks")?,
Error::SequenceNotFound => f.write_str("ASN.1 sequence not found")?,
Error::ModulusNotFound => f.write_str("ASN.1 public key modulus not found")?,
Error::ExponentNotFound => f.write_str("ASN.1 public key exponent not found")?,
Error::InvalidSequenceLength => {
f.write_str("ASN.1 sequence did not contain exactly two values")?
}
}
Ok(())
}
}
impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Error::InvalidDer(e) => Some(e),
_ => None,
}
}
}
pub fn public_key_from_der(der: &[u8]) -> Result<(Vec<u8>, Vec<u8>)> {
let blocks = simple_asn1::from_der(der).map_err(Error::InvalidDer)?;
let mut bit_strings = Vec::with_capacity(1);
find_bit_string(&blocks, &mut bit_strings);
if bit_strings.is_empty() {
return Err(Error::BitStringNotFound);
}
let bit_string = &bit_strings[0];
let inner_asn = simple_asn1::from_der(bit_string).map_err(Error::InvalidDer)?;
let (n, e) = match &inner_asn[0] {
ASN1Block::Sequence(_, blocks) => {
if blocks.len() != 2 {
return Err(Error::InvalidSequenceLength);
}
let n = match &blocks[0] {
ASN1Block::Integer(_, n) => n,
_ => return Err(Error::ModulusNotFound),
};
let e = match &blocks[1] {
ASN1Block::Integer(_, e) => e,
_ => return Err(Error::ExponentNotFound),
};
(n, e)
}
_ => return Err(Error::SequenceNotFound),
};
Ok((n.to_bytes_be().1, e.to_bytes_be().1))
}
/// Walks `blocks` depth-first, in order, and appends a copy of every ASN.1
/// `BIT STRING` payload to `result`.
///
/// The walk descends into nested `SEQUENCE`s. All other block kinds are skipped.
fn find_bit_string(blocks: &[ASN1Block], result: &mut Vec<Vec<u8>>) {
    for block in blocks {
        if let ASN1Block::BitString(_, _, payload) = block {
            result.push(payload.clone());
        } else if let ASN1Block::Sequence(_, children) = block {
            find_bit_string(children, result);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use openssl::rsa::Rsa;

    /// Our encoder's output must be accepted by OpenSSL and carry the same n and e.
    #[test]
    fn test_public_key_to_der() {
        let original = Rsa::generate(2048).unwrap();
        let (modulus, exponent) = (original.n().to_vec(), original.e().to_vec());
        let encoded = public_key_to_der(&modulus, &exponent);
        let parsed = Rsa::public_key_from_der(&encoded).unwrap();
        assert_eq!(original.n(), parsed.n());
        assert_eq!(original.e(), parsed.e());
    }

    /// OpenSSL's public-key DER must decode to the same n and e bytes.
    #[test]
    fn test_public_key_from_der() {
        let original = Rsa::generate(2048).unwrap();
        let encoded = original.public_key_to_der().unwrap();
        let (modulus, exponent) = public_key_from_der(&encoded).unwrap();
        assert_eq!(modulus, original.n().to_vec());
        assert_eq!(exponent, original.e().to_vec());
    }
}