use std::path::Path;
/// Big-endian byte encodings of the components of a two-prime RSA-2048
/// private key, as produced by this module's loaders for import into a card
/// key slot.
///
/// Values come from `to_bytes_be`, so they carry no leading zero bytes; the
/// CRT values (`u`, `dp`, `dq`) may therefore be shorter than half the
/// modulus length. All buffers are wiped on drop by this type's `Drop` impl.
pub struct RsaKeyParts {
    /// Public exponent `e`.
    pub e: Vec<u8>,
    /// First prime factor `p`.
    pub p: Vec<u8>,
    /// Second prime factor `q`.
    pub q: Vec<u8>,
    /// CRT coefficient taken from the `rsa` crate's `qinv` (`q^-1 mod p`).
    pub u: Vec<u8>,
    /// CRT exponent for `p` (`d mod (p - 1)`).
    pub dp: Vec<u8>,
    /// CRT exponent for `q` (`d mod (q - 1)`).
    pub dq: Vec<u8>,
    /// Modulus `n`.
    pub n: Vec<u8>,
}
impl Drop for RsaKeyParts {
    /// Overwrites every component buffer with zeros before the memory is
    /// released, so key material does not linger in freed heap allocations.
    fn drop(&mut self) {
        use zeroize::Zeroize;

        // Exhaustive destructuring: adding a field to `RsaKeyParts` without
        // listing it here becomes a compile error instead of a silent leak.
        let Self { e, p, q, u, dp, dq, n } = self;
        for component in [e, p, q, u, dp, dq, n] {
            component.zeroize();
        }
    }
}
/// Everything that can go wrong while obtaining card-ready RSA key parts.
#[derive(Debug)]
pub enum RsaKeyError {
    /// The key file could not be read.
    Io(std::io::Error),
    /// The bytes could not be decoded as an RSA private key; carries the
    /// decoder's message.
    Parse(String),
    /// The modulus is not 2048 bits; carries the actual bit length.
    WrongSize(usize),
    /// Key generation or CRT precomputation failed, or the key is not
    /// two-prime; carries a description.
    Crypto(String),
    /// A precomputed CRT value (e.g. `dp`, `dq`, `qinv`) was unavailable.
    MissingComponent(&'static str),
}

impl std::fmt::Display for RsaKeyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(source) => write!(f, "cannot read key file: {source}"),
            Self::Parse(detail) => write!(f, "could not parse RSA private key: {detail}"),
            Self::WrongSize(bits) => write!(
                f,
                "key is RSA-{bits}, but the card slot is RSA-2048; \
                 import only supports 2048-bit keys"
            ),
            Self::Crypto(detail) => write!(f, "RSA operation failed: {detail}"),
            Self::MissingComponent(which) => {
                write!(f, "RSA key missing precomputed {which}")
            }
        }
    }
}

/// Uses the default (no) `source()`: the `Io` message already embeds the
/// underlying error's text in `Display`.
impl std::error::Error for RsaKeyError {}

/// Lets `?` lift file-reading failures into [`RsaKeyError::Io`].
impl From<std::io::Error> for RsaKeyError {
    fn from(source: std::io::Error) -> Self {
        Self::Io(source)
    }
}
pub fn generate_2048() -> Result<RsaKeyParts, RsaKeyError> {
let mut rng = rand::thread_rng();
let key =
rsa::RsaPrivateKey::new(&mut rng, 2048).map_err(|e| RsaKeyError::Crypto(e.to_string()))?;
parts_from_key(key)
}
/// Reads a PEM or DER RSA private key from `path` and splits it into card
/// import components.
///
/// The raw file contents are held in a `Zeroizing` buffer, so they are wiped
/// when this function returns, whether parsing succeeded or not.
///
/// # Errors
///
/// [`RsaKeyError::Io`] if the file cannot be read; otherwise whatever
/// `parts_from_encoded` reports.
pub fn load_from_file(path: &Path) -> Result<RsaKeyParts, RsaKeyError> {
    let contents = std::fs::read(path).map(zeroize::Zeroizing::new)?;
    parts_from_encoded(contents.as_slice())
}
fn parts_from_encoded(bytes: &[u8]) -> Result<RsaKeyParts, RsaKeyError> {
use rsa::pkcs1::DecodeRsaPrivateKey;
use rsa::pkcs8::DecodePrivateKey;
let key = if bytes.starts_with(b"-----BEGIN") {
let text = std::str::from_utf8(bytes)
.map_err(|_| RsaKeyError::Parse("key file is not valid PEM/UTF-8".into()))?;
rsa::RsaPrivateKey::from_pkcs8_pem(text)
.or_else(|_| rsa::RsaPrivateKey::from_pkcs1_pem(text))
.map_err(|e| RsaKeyError::Parse(e.to_string()))?
} else {
rsa::RsaPrivateKey::from_pkcs8_der(bytes)
.or_else(|_| rsa::RsaPrivateKey::from_pkcs1_der(bytes))
.map_err(|e| RsaKeyError::Parse(e.to_string()))?
};
parts_from_key(key)
}
/// Validates a decoded or generated key and extracts its card import
/// components.
///
/// The key must have a 2048-bit modulus and exactly two primes. CRT values
/// are (re)computed with `precompute`, which is why the key is taken mutably.
///
/// # Errors
///
/// - [`RsaKeyError::WrongSize`] when the modulus is not 2048 bits.
/// - [`RsaKeyError::Crypto`] when precomputation fails or the key does not
///   have exactly two primes.
/// - [`RsaKeyError::MissingComponent`] when a CRT value is still absent after
///   precomputation.
fn parts_from_key(mut key: rsa::RsaPrivateKey) -> Result<RsaKeyParts, RsaKeyError> {
    use rsa::traits::{PrivateKeyParts, PublicKeyParts};

    let bits = key.n().bits();
    if bits != 2048 {
        return Err(RsaKeyError::WrongSize(bits));
    }
    key.precompute()
        .map_err(|e| RsaKeyError::Crypto(e.to_string()))?;

    let (p, q) = match key.primes() {
        [p, q] => (p, q),
        _ => return Err(RsaKeyError::Crypto("expected a 2-prime RSA key".into())),
    };

    // Every fallible lookup happens before any secret is serialized into a
    // plain `Vec`, so an early return can never drop an un-wiped copy of key
    // material. Once `RsaKeyParts` owns the buffers, its `Drop` impl wipes them.
    let dp = key.dp().ok_or(RsaKeyError::MissingComponent("dp"))?;
    let dq = key.dq().ok_or(RsaKeyError::MissingComponent("dq"))?;
    let qinv = key.qinv().ok_or(RsaKeyError::MissingComponent("qinv"))?;

    Ok(RsaKeyParts {
        e: key.e().to_bytes_be(),
        n: key.n().to_bytes_be(),
        p: p.to_bytes_be(),
        q: q.to_bytes_be(),
        // `qinv` is a signed integer; only the magnitude is kept.
        // NOTE(review): relies on `rsa` storing qinv already reduced into
        // [0, p) — re-confirm when upgrading the crate.
        u: qinv.to_bytes_be().1,
        dp: dp.to_bytes_be(),
        dq: dq.to_bytes_be(),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use rsa::pkcs1::EncodeRsaPrivateKey;
    use rsa::pkcs8::{EncodePrivateKey, LineEnding};
    use rsa::traits::{PrivateKeyParts, PublicKeyParts};

    /// Asserts that `parsed` carries exactly the components of `key`.
    fn assert_matches_key(parsed: &RsaKeyParts, key: &rsa::RsaPrivateKey, what: &str) {
        assert_eq!(parsed.n, key.n().to_bytes_be(), "{what}: modulus");
        assert_eq!(parsed.e, key.e().to_bytes_be(), "{what}: public exponent");
        assert_eq!(parsed.p, key.primes()[0].to_bytes_be(), "{what}: p");
        assert_eq!(parsed.q, key.primes()[1].to_bytes_be(), "{what}: q");
        assert_eq!(parsed.dp, key.dp().expect("dp").to_bytes_be(), "{what}: dp");
        assert_eq!(parsed.dq, key.dq().expect("dq").to_bytes_be(), "{what}: dq");
        assert_eq!(
            parsed.u,
            key.qinv().expect("qinv").to_bytes_be().1,
            "{what}: u"
        );
    }

    #[test]
    fn generate_2048_has_expected_shapes() {
        let k = generate_2048().expect("keygen");
        assert_eq!(k.n.len(), 256, "modulus should be 256 bytes");
        assert_eq!(k.p.len(), 128, "p should be 128 bytes");
        assert_eq!(k.q.len(), 128, "q should be 128 bytes");
        assert_eq!(k.e, vec![0x01, 0x00, 0x01]);
        assert!(!k.u.is_empty() && !k.dp.is_empty() && !k.dq.is_empty());
        assert!(k.dp.len() <= 128 && k.dq.len() <= 128);
    }

    #[test]
    fn load_round_trips_through_every_encoding() {
        // One keygen shared by all encodings keeps the test reasonably fast.
        let mut rng = rand::thread_rng();
        let key = rsa::RsaPrivateKey::new(&mut rng, 2048).expect("keygen");

        let pkcs8_der = key.to_pkcs8_der().expect("encode pkcs8 der");
        let parsed = parts_from_encoded(pkcs8_der.as_bytes()).expect("parse pkcs8 der");
        assert_matches_key(&parsed, &key, "pkcs8 der");

        let pkcs8_pem = key.to_pkcs8_pem(LineEnding::LF).expect("encode pkcs8 pem");
        let parsed = parts_from_encoded(pkcs8_pem.as_bytes()).expect("parse pkcs8 pem");
        assert_matches_key(&parsed, &key, "pkcs8 pem");

        let pkcs1_der = key.to_pkcs1_der().expect("encode pkcs1 der");
        let parsed = parts_from_encoded(pkcs1_der.as_bytes()).expect("parse pkcs1 der");
        assert_matches_key(&parsed, &key, "pkcs1 der");

        let pkcs1_pem = key.to_pkcs1_pem(LineEnding::LF).expect("encode pkcs1 pem");
        let parsed = parts_from_encoded(pkcs1_pem.as_bytes()).expect("parse pkcs1 pem");
        assert_matches_key(&parsed, &key, "pkcs1 pem");
    }

    #[test]
    fn rejects_non_2048() {
        let mut rng = rand::thread_rng();
        let key = rsa::RsaPrivateKey::new(&mut rng, 1024).expect("keygen");
        let der = key.to_pkcs8_der().expect("encode der");
        match parts_from_encoded(der.as_bytes()) {
            Err(RsaKeyError::WrongSize(1024)) => {}
            Err(e) => panic!("expected WrongSize(1024), got error: {e}"),
            Ok(_) => panic!("expected WrongSize(1024), but parsing succeeded"),
        }
    }

    #[test]
    fn rejects_garbage() {
        assert!(matches!(
            parts_from_encoded(&[0xDE, 0xAD, 0xBE, 0xEF]),
            Err(RsaKeyError::Parse(_))
        ));
    }

    #[test]
    fn rejects_non_utf8_pem() {
        assert!(matches!(
            parts_from_encoded(b"-----BEGIN PRIVATE KEY-----\n\xff\xfe\n"),
            Err(RsaKeyError::Parse(_))
        ));
    }
}