use std::fmt;
use aws_lc_rs::encoding::AsDer;
use aws_lc_rs::rand::SystemRandom;
use aws_lc_rs::rsa::{KeyPair as RsaKeyPair, KeySize};
use aws_lc_rs::signature::{self, KeyPair as SignatureKeyPair, UnparsedPublicKey};
use crate::error::Error;
/// RSA modulus sizes offered when generating a brand-new key.
///
/// `#[non_exhaustive]` lets further sizes be added later without a breaking change.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RsaBits {
    /// A 2048-bit modulus.
    Rsa2048,
    /// A 4096-bit modulus.
    Rsa4096,
}
impl RsaBits {
#[must_use]
pub const fn as_u32(self) -> u32 {
match self {
Self::Rsa2048 => 2048,
Self::Rsa4096 => 4096,
}
}
const fn as_key_size(self) -> KeySize {
match self {
Self::Rsa2048 => KeySize::Rsa2048,
Self::Rsa4096 => KeySize::Rsa4096,
}
}
}
/// An RSA private key that produces RSASSA-PKCS1-v1_5 / SHA-256 signatures.
///
/// Alongside the parsed key pair it keeps the PKCS#8 DER of the private key and
/// the DER of the public key, so neither has to be re-encoded on demand.
///
/// NOTE(review): `pkcs8_der` is private key material and no `Drop` impl in view
/// wipes it before the allocation is freed — consider zeroising it.
pub struct RsaSigningKey {
    // Parsed aws-lc-rs key pair used for signing.
    inner: RsaKeyPair,
    // PKCS#8 DER of the private key: freshly encoded (v1) for generated keys,
    // or the caller's bytes verbatim for keys loaded via `from_pkcs8_der`.
    pkcs8_der: Vec<u8>,
    // DER of the public key, produced once by `public_key().as_der()` at construction.
    public_spki_der: Vec<u8>,
    // Modulus size in bits: the requested size for generated keys, or the
    // modulus length in bytes times eight for imported keys.
    bits: u32,
}
impl RsaSigningKey {
pub fn generate(bits: RsaBits) -> Result<Self, Error> {
let pair = RsaKeyPair::generate(bits.as_key_size())
.map_err(|_| Error::KeyGeneration("RSA generation failed"))?;
let pkcs8_der = pair
.as_der()
.map_err(|_| Error::KeyGeneration("RSA PKCS#8 v1 serialisation failed"))?
.as_ref()
.to_vec();
Self::build(pair, pkcs8_der, bits.as_u32())
}
pub fn from_pkcs8_der(der: &[u8]) -> Result<Self, Error> {
let pair =
RsaKeyPair::from_pkcs8(der).map_err(|e| Error::InvalidPkcs8(format!("RSA: {e}")))?;
let bits = u32::try_from(pair.public_modulus_len() * 8).unwrap_or(u32::MAX);
if !(2048..=8192).contains(&bits) || bits % 256 != 0 {
return Err(Error::UnsupportedRsaSize(bits));
}
Self::build(pair, der.to_vec(), bits)
}
fn build(pair: RsaKeyPair, pkcs8_der: Vec<u8>, bits: u32) -> Result<Self, Error> {
let spki = pair
.public_key()
.as_der()
.map_err(|_| Error::Crypto("RSA SPKI serialisation failed"))?
.as_ref()
.to_vec();
Ok(Self {
inner: pair,
pkcs8_der,
public_spki_der: spki,
bits,
})
}
#[must_use]
pub fn to_pkcs8_der(&self) -> &[u8] {
&self.pkcs8_der
}
#[must_use]
pub const fn bits(&self) -> u32 {
self.bits
}
#[must_use]
pub fn public_key(&self) -> RsaPublicKey {
RsaPublicKey {
spki_der: self.public_spki_der.clone(),
}
}
pub fn sign(&self, message: &[u8]) -> Result<Vec<u8>, Error> {
let rng = SystemRandom::new();
let mut sig = vec![0u8; self.inner.public_modulus_len()];
self.inner
.sign(&signature::RSA_PKCS1_SHA256, &rng, message, &mut sig)
.map_err(|_| Error::Crypto("RSA PKCS#1 SHA-256 signing failed"))?;
Ok(sig)
}
}
impl fmt::Debug for RsaSigningKey {
    // Only the key size is rendered; the key pair and both DER buffers are
    // deliberately left out so private material never reaches logs.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("RsaSigningKey");
        builder.field("bits", &self.bits);
        builder.finish_non_exhaustive()
    }
}
/// An RSA public key, stored as DER bytes, that verifies RSASSA-PKCS1-v1_5 /
/// SHA-256 signatures.
///
/// The bytes are not parsed when the value is constructed. A malformed key only
/// shows up later, as a verification failure.
#[derive(Clone)]
pub struct RsaPublicKey {
    // DER-encoded public key; presumably X.509 SubjectPublicKeyInfo, as the
    // name suggests — TODO confirm for keys supplied via `from_spki_der`.
    spki_der: Vec<u8>,
}
impl RsaPublicKey {
pub fn from_spki_der(der: &[u8]) -> Result<Self, Error> {
if der.is_empty() {
return Err(Error::InvalidPkcs8("empty SPKI DER".into()));
}
Ok(Self {
spki_der: der.to_vec(),
})
}
#[must_use]
pub fn as_spki_der(&self) -> &[u8] {
&self.spki_der
}
pub fn verify(&self, message: &[u8], signature_bytes: &[u8]) -> Result<(), Error> {
UnparsedPublicKey::new(
&signature::RSA_PKCS1_2048_8192_SHA256,
self.spki_der.as_slice(),
)
.verify(message, signature_bytes)
.map_err(|_| Error::VerificationFailed)
}
}
impl fmt::Debug for RsaPublicKey {
    // Prints only the length of the encoded key, not the key bytes themselves.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let spki_len = self.spki_der.len();
        f.debug_struct("RsaPublicKey")
            .field("spki_bytes", &spki_len)
            .finish_non_exhaustive()
    }
}
// Two public keys are equal when their stored DER bytes match exactly.
impl PartialEq for RsaPublicKey {
    fn eq(&self, other: &Self) -> bool {
        self.spki_der.as_slice() == other.spki_der.as_slice()
    }
}

// Byte equality is reflexive, so the total-equality marker is sound.
impl Eq for RsaPublicKey {}
#[cfg(test)]
mod tests {
    use pretty_assertions::assert_eq;

    use super::*;

    /// Generates a 2048-bit key; the smallest supported size keeps the suite fast.
    fn fresh_key() -> RsaSigningKey {
        RsaSigningKey::generate(RsaBits::Rsa2048).expect("rng available")
    }

    #[test]
    fn generate_then_sign_and_verify_roundtrips() {
        let key = fresh_key();
        let public = key.public_key();
        let msg = b"ActivityPub inbox delivery";
        let sig = key.sign(msg).expect("sign must succeed");
        // A 2048-bit modulus yields a 256-byte signature.
        assert_eq!(sig.len(), 256);
        public.verify(msg, &sig).expect("signature must verify");
    }

    #[test]
    fn tampered_message_fails_verification() {
        let key = fresh_key();
        let public = key.public_key();
        let sig = key.sign(b"original message").expect("sign");
        let err = public
            .verify(b"tampered message", &sig)
            .expect_err("tampered message must not verify");
        assert!(matches!(err, Error::VerificationFailed));
    }

    #[test]
    fn signature_from_another_key_fails_verification() {
        let signer = fresh_key();
        let stranger = fresh_key();
        let msg = b"who signed this?";
        let sig = signer.sign(msg).expect("sign");
        let err = stranger
            .public_key()
            .verify(msg, &sig)
            .expect_err("a different key must not verify the signature");
        assert!(matches!(err, Error::VerificationFailed));
    }

    #[test]
    fn pkcs8_roundtrip_preserves_key() {
        let original = fresh_key();
        let reloaded =
            RsaSigningKey::from_pkcs8_der(original.to_pkcs8_der()).expect("reload must succeed");
        assert_eq!(original.bits(), reloaded.bits());
        // Imported keys keep the caller's DER verbatim.
        assert_eq!(original.to_pkcs8_der(), reloaded.to_pkcs8_der());
        assert_eq!(original.public_key(), reloaded.public_key());
        let msg = b"cross-verify me";
        let sig = reloaded.sign(msg).expect("sign");
        original
            .public_key()
            .verify(msg, &sig)
            .expect("reloaded key must produce signatures the original can verify");
    }

    #[test]
    fn garbage_pkcs8_is_rejected() {
        let err = RsaSigningKey::from_pkcs8_der(b"definitely not a PKCS#8 document")
            .expect_err("garbage must not parse");
        assert!(matches!(err, Error::InvalidPkcs8(_)));
    }

    #[test]
    fn empty_spki_is_rejected() {
        let err = RsaPublicKey::from_spki_der(&[]).expect_err("empty SPKI must be rejected");
        assert!(matches!(err, Error::InvalidPkcs8(_)));
    }

    #[test]
    fn public_key_survives_spki_roundtrip() {
        let key = fresh_key();
        let public = key.public_key();
        let reparsed =
            RsaPublicKey::from_spki_der(public.as_spki_der()).expect("non-empty SPKI must load");
        assert_eq!(public, reparsed);
        let msg = b"reparsed key check";
        let sig = key.sign(msg).expect("sign");
        reparsed
            .verify(msg, &sig)
            .expect("reparsed public key must verify");
    }

    #[test]
    fn debug_output_omits_key_material() {
        let key = fresh_key();
        assert_eq!(format!("{key:?}"), "RsaSigningKey { bits: 2048, .. }".to_string());
        let public = key.public_key();
        let expected = format!(
            "RsaPublicKey {{ spki_bytes: {}, .. }}",
            public.as_spki_der().len()
        );
        assert_eq!(format!("{public:?}"), expected);
    }

    #[test]
    fn rsa_bits_reports_correct_width() {
        let key = fresh_key();
        assert_eq!(key.bits(), 2048);
    }

    #[test]
    fn rsa_bits_as_u32_matches_variant() {
        assert_eq!(RsaBits::Rsa2048.as_u32(), 2048);
        assert_eq!(RsaBits::Rsa4096.as_u32(), 4096);
    }
}