use std::fs;
use std::path::Path;
use rsa::pkcs1::DecodeRsaPrivateKey;
use rsa::pkcs8::DecodePrivateKey;
use rsa::RsaPrivateKey;
use crate::crypto::CryptoError;
pub fn load_rsa_private_key(path: &Path) -> Result<RsaPrivateKey, CryptoError> {
let bytes = fs::read(path)?;
load_rsa_private_key_from_bytes(&bytes)
}
/// Parses an RSA private key from PEM-encoded bytes.
///
/// Two unencrypted PEM encodings are accepted. The encoding is chosen by which
/// BEGIN marker appears in the text, and PKCS#1 is checked first:
///
/// * PKCS#1: `-----BEGIN RSA PRIVATE KEY-----`
/// * PKCS#8: `-----BEGIN PRIVATE KEY-----`
///
/// Encrypted keys are not supported. They are detected explicitly so the caller
/// gets an actionable message instead of a generic "no marker" or parser error.
///
/// # Errors
///
/// Returns [`CryptoError::InvalidPem`] if:
/// * `bytes` is not valid UTF-8,
/// * the key is encrypted (a PKCS#8 `ENCRYPTED PRIVATE KEY` block, or a PKCS#1
///   block carrying a `Proc-Type: 4,ENCRYPTED` header),
/// * the detected PKCS#1 / PKCS#8 block fails to parse, or
/// * no supported private-key marker is present.
pub fn load_rsa_private_key_from_bytes(bytes: &[u8]) -> Result<RsaPrivateKey, CryptoError> {
    const PKCS1_BEGIN: &str = "-----BEGIN RSA PRIVATE KEY-----";
    const PKCS8_BEGIN: &str = "-----BEGIN PRIVATE KEY-----";
    const PKCS8_ENCRYPTED_BEGIN: &str = "-----BEGIN ENCRYPTED PRIVATE KEY-----";
    // RFC 1421-style header used by OpenSSL's "traditional" encrypted PEM format;
    // the body behind it is ciphertext, so it can never parse as plain PKCS#1.
    const PKCS1_ENCRYPTED_HEADER: &str = "Proc-Type: 4,ENCRYPTED";

    let pem = std::str::from_utf8(bytes)
        .map_err(|e| CryptoError::InvalidPem(format!("PEM is not UTF-8: {e}")))?;

    if pem.contains(PKCS1_BEGIN) {
        if pem.contains(PKCS1_ENCRYPTED_HEADER) {
            return Err(CryptoError::InvalidPem(
                "encrypted PKCS#1 keys are not supported; decrypt the key first".to_string(),
            ));
        }
        return RsaPrivateKey::from_pkcs1_pem(pem)
            .map_err(|e| CryptoError::InvalidPem(format!("PKCS#1 parse failed: {e}")));
    }

    if pem.contains(PKCS8_BEGIN) {
        return RsaPrivateKey::from_pkcs8_pem(pem)
            .map_err(|e| CryptoError::InvalidPem(format!("PKCS#8 parse failed: {e}")));
    }

    // Checked only after both plaintext markers, so inputs that previously
    // parsed are unaffected; this merely replaces the generic fallback message.
    if pem.contains(PKCS8_ENCRYPTED_BEGIN) {
        return Err(CryptoError::InvalidPem(
            "encrypted PKCS#8 keys are not supported; decrypt the key first".to_string(),
        ));
    }

    Err(CryptoError::InvalidPem(
        "no RSA private key marker found".to_string(),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::OsRng;
    use rsa::pkcs1::EncodeRsaPrivateKey;
    use rsa::pkcs8::{EncodePrivateKey, LineEnding};
    use rsa::traits::PublicKeyParts;

    /// Generates a fresh 2048-bit key for a round-trip test.
    fn fresh_key() -> RsaPrivateKey {
        let mut rng = OsRng;
        RsaPrivateKey::new(&mut rng, 2048).unwrap()
    }

    /// Asserts that two keys share the same public modulus and exponent,
    /// which is a much stronger round-trip check than comparing sizes.
    fn assert_same_public_parts(loaded: &RsaPrivateKey, original: &RsaPrivateKey) {
        assert_eq!(loaded.n(), original.n());
        assert_eq!(loaded.e(), original.e());
    }

    /// Asserts that parsing `input` fails with `CryptoError::InvalidPem`.
    fn assert_invalid_pem(input: &[u8]) {
        let err = load_rsa_private_key_from_bytes(input).unwrap_err();
        assert!(
            matches!(err, CryptoError::InvalidPem(_)),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn pkcs1_round_trips() {
        let key = fresh_key();
        let pem = key.to_pkcs1_pem(LineEnding::LF).unwrap();
        let loaded = load_rsa_private_key_from_bytes(pem.as_bytes()).unwrap();
        assert_same_public_parts(&loaded, &key);
    }

    #[test]
    fn pkcs8_round_trips() {
        let key = fresh_key();
        let pem = key.to_pkcs8_pem(LineEnding::LF).unwrap();
        let loaded = load_rsa_private_key_from_bytes(pem.as_bytes()).unwrap();
        assert_same_public_parts(&loaded, &key);
    }

    #[test]
    fn unknown_marker_is_rejected() {
        assert_invalid_pem(b"-----BEGIN GARBAGE-----\nabcd\n-----END GARBAGE-----\n");
    }

    #[test]
    fn malformed_pkcs1_returns_invalid_pem() {
        assert_invalid_pem(
            b"-----BEGIN RSA PRIVATE KEY-----\nNOTBASE64\n-----END RSA PRIVATE KEY-----\n",
        );
    }

    #[test]
    fn malformed_pkcs8_returns_invalid_pem() {
        assert_invalid_pem(b"-----BEGIN PRIVATE KEY-----\nNOTBASE64\n-----END PRIVATE KEY-----\n");
    }

    #[test]
    fn non_utf8_input_returns_invalid_pem() {
        assert_invalid_pem(&[0xff, 0xfe, 0x00, 0x80]);
    }

    #[test]
    fn encrypted_pkcs8_returns_invalid_pem() {
        assert_invalid_pem(
            b"-----BEGIN ENCRYPTED PRIVATE KEY-----\nAAAA\n-----END ENCRYPTED PRIVATE KEY-----\n",
        );
    }

    #[test]
    fn encrypted_pkcs1_returns_invalid_pem() {
        assert_invalid_pem(
            b"-----BEGIN RSA PRIVATE KEY-----\n\
              Proc-Type: 4,ENCRYPTED\n\
              DEK-Info: AES-128-CBC,00000000000000000000000000000000\n\
              \n\
              AAAA\n\
              -----END RSA PRIVATE KEY-----\n",
        );
    }

    #[test]
    fn missing_file_is_an_error() {
        // The parent directory is not expected to exist, so the read must fail.
        let path = std::env::temp_dir()
            .join("load-rsa-key-test-no-such-dir-5c1e")
            .join("missing.pem");
        assert!(load_rsa_private_key(&path).is_err());
    }
}