use crate::errors::{CrabError, CrabResult};
use rand_core::OsRng;
use rsa::{
pkcs8::{DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey},
pss::{BlindedSigningKey, Signature, VerifyingKey},
sha2::Sha256,
signature::SignatureEncoding,
traits::PublicKeyParts,
Oaep, RsaPrivateKey, RsaPublicKey as RsaPubKey,
};
use zeroize::{ZeroizeOnDrop, Zeroizing};
/// A detached RSA signature held as raw bytes.
///
/// No validation is performed on the contents; this is a byte container with
/// base64 and hex conversions for transport and storage.
#[derive(Clone, Debug, PartialEq)]
pub struct RsaSignature(Vec<u8>);

impl RsaSignature {
    /// Wraps raw signature bytes without validating them.
    pub fn from_bytes(bytes: Vec<u8>) -> Self {
        RsaSignature(bytes)
    }

    /// Borrows the raw signature bytes.
    pub fn as_bytes(&self) -> &[u8] {
        self.0.as_slice()
    }

    /// Encodes the signature bytes with the crate's base64 encoder.
    pub fn to_base64(&self) -> String {
        let raw = &self.0;
        crate::encoding::base64_encode(raw)
    }

    /// Decodes a base64 string (as produced by `to_base64`) into a signature.
    ///
    /// # Errors
    /// Propagates any error reported by the crate's base64 decoder.
    pub fn from_base64(data: &str) -> CrabResult<Self> {
        let raw = crate::encoding::base64_decode(data)?;
        Ok(Self::from_bytes(raw))
    }

    /// Encodes the signature bytes as a hexadecimal string.
    pub fn to_hex(&self) -> String {
        let raw = &self.0;
        hex::encode(raw)
    }

    /// Decodes a hexadecimal string into a signature.
    ///
    /// # Errors
    /// Returns an encoding error if `data` is not valid hex.
    pub fn from_hex(data: &str) -> CrabResult<Self> {
        let raw = match hex::decode(data) {
            Ok(raw) => raw,
            Err(err) => return Err(CrabError::encoding_error(format!("Invalid hex: {}", err))),
        };
        Ok(Self::from_bytes(raw))
    }

    /// Number of signature bytes.
    pub fn len(&self) -> usize {
        self.as_bytes().len()
    }

    /// Whether the signature contains no bytes at all.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// An RSA public key used for RSAES-OAEP (SHA-256) encryption and
/// RSASSA-PSS (SHA-256) signature verification.
#[derive(Clone, Debug)]
pub struct RsaPublicKey(RsaPubKey);

impl RsaPublicKey {
    /// Parses a SubjectPublicKeyInfo PEM document ("BEGIN PUBLIC KEY").
    ///
    /// # Errors
    /// Returns a key error if the PEM cannot be decoded as an RSA public key.
    pub fn from_pem(pem: &str) -> CrabResult<Self> {
        let key = RsaPubKey::from_public_key_pem(pem)
            .map_err(|e| CrabError::key_error(format!("Invalid RSA public key PEM: {}", e)))?;
        Ok(Self(key))
    }

    /// Encodes the key as a SubjectPublicKeyInfo PEM document with LF line endings.
    ///
    /// # Errors
    /// Returns a key error if encoding fails.
    pub fn to_pem(&self) -> CrabResult<String> {
        self.0
            .to_public_key_pem(rsa::pkcs8::LineEnding::LF)
            .map_err(|e| CrabError::key_error(format!("Failed to encode public key: {}", e)))
    }

    /// Encrypts `plaintext` with RSAES-OAEP using SHA-256 and a fresh random seed.
    ///
    /// The output is randomized, so encrypting the same plaintext twice yields
    /// different ciphertexts. With SHA-256 the plaintext may be at most
    /// `size_bytes() - 66` bytes long (190 bytes for a 2048-bit key).
    ///
    /// # Errors
    /// Returns a crypto error if the plaintext is too long or encryption fails.
    pub fn encrypt(&self, plaintext: &[u8]) -> CrabResult<Vec<u8>> {
        let padding = Oaep::new::<Sha256>();
        let ciphertext = self
            .0
            .encrypt(&mut OsRng, padding, plaintext)
            .map_err(|e| CrabError::crypto_error(format!("RSA encryption failed: {}", e)))?;
        Ok(ciphertext)
    }

    /// Verifies an RSASSA-PSS (SHA-256) `signature` over `message`.
    ///
    /// Returns `Ok(true)` for a valid signature. Any verification failure
    /// (wrong key, altered message, corrupted or wrong-length signature) yields
    /// `Ok(false)`. `Err` is reserved for signature bytes that cannot be decoded.
    pub fn verify(&self, message: &[u8], signature: &RsaSignature) -> CrabResult<bool> {
        use rsa::signature::Verifier;

        let verifying_key: VerifyingKey<Sha256> = VerifyingKey::new(self.0.clone());
        let sig = Signature::try_from(signature.as_bytes())
            .map_err(|_| CrabError::encoding_error("Invalid signature encoding"))?;
        Ok(verifying_key.verify(message, &sig).is_ok())
    }

    /// Exact bit length of the modulus.
    ///
    /// Computed from the modulus itself instead of `size_bytes() * 8`, which
    /// would over-report keys whose bit length is not a multiple of 8.
    pub fn size_bits(&self) -> usize {
        self.0.n().bits()
    }

    /// Modulus length in bytes, i.e. the size of every ciphertext and signature.
    pub fn size_bytes(&self) -> usize {
        self.0.size()
    }

    /// Encodes the key as SubjectPublicKeyInfo DER bytes.
    ///
    /// # Errors
    /// Returns a key error if encoding fails.
    pub fn to_public_key_der(&self) -> CrabResult<Vec<u8>> {
        let der = self
            .0
            .to_public_key_der()
            .map_err(|e| CrabError::key_error(format!("Failed to encode public key: {}", e)))?;
        Ok(der.as_bytes().to_vec())
    }

    /// Parses SubjectPublicKeyInfo DER bytes.
    ///
    /// # Errors
    /// Returns a key error if the bytes are not a valid RSA public key.
    pub fn from_public_key_der(der: &[u8]) -> CrabResult<Self> {
        let key = RsaPubKey::from_public_key_der(der)
            .map_err(|e| CrabError::key_error(format!("Invalid RSA public key DER: {}", e)))?;
        Ok(Self(key))
    }

    /// Encodes the SubjectPublicKeyInfo DER bytes with the crate's base64 encoder.
    ///
    /// # Errors
    /// Returns a key error if DER encoding fails.
    pub fn to_base64(&self) -> CrabResult<String> {
        let der = self.to_public_key_der()?;
        Ok(crate::encoding::base64_encode(&der))
    }

    /// Decodes base64 (as produced by `to_base64`) into a public key.
    ///
    /// # Errors
    /// Propagates base64 decoding errors, or returns a key error for invalid DER.
    pub fn from_base64(data: &str) -> CrabResult<Self> {
        let der = crate::encoding::base64_decode(data)?;
        Self::from_public_key_der(&der)
    }
}
/// An RSA private key, from which the matching public key can be derived.
///
/// The key is stored only as a PKCS#8 DER encoding inside a `Zeroizing`
/// buffer. Each operation decodes a temporary `RsaPrivateKey` from it. The
/// type implements neither `Clone` nor `Debug`.
#[derive(ZeroizeOnDrop)]
pub struct RsaKeyPair {
    /// PKCS#8 DER encoding of the private key; wiped on drop.
    pkcs8_der: Zeroizing<Vec<u8>>,
}

impl RsaKeyPair {
    /// Decodes the stored PKCS#8 DER into an `RsaPrivateKey` and passes it to `f`.
    ///
    /// The decoded key only lives for the duration of this call.
    fn with_private_key<T, F>(&self, f: F) -> CrabResult<T>
    where
        F: FnOnce(&RsaPrivateKey) -> CrabResult<T>,
    {
        let private_key = RsaPrivateKey::from_pkcs8_der(&self.pkcs8_der)
            .map_err(|e| CrabError::key_error(format!("Invalid private key DER: {}", e)))?;
        f(&private_key)
    }

    /// Serializes `private_key` to PKCS#8 DER and stores it in a zeroizing buffer.
    fn from_private_key(private_key: &RsaPrivateKey) -> CrabResult<Self> {
        let document = private_key
            .to_pkcs8_der()
            .map_err(|e| CrabError::key_error(format!("Failed to serialize private key: {}", e)))?;
        Ok(Self {
            pkcs8_der: Zeroizing::new(document.as_bytes().to_vec()),
        })
    }

    /// Generates a new random keypair whose modulus is `bits` bits long.
    ///
    /// # Errors
    /// Returns an invalid-input error if `bits < 2048`. Returns a key error if
    /// generation or serialization fails.
    pub fn generate(bits: usize) -> CrabResult<Self> {
        if bits < 2048 {
            return Err(CrabError::invalid_input("RSA key size must be at least 2048 bits"));
        }
        let private_key = RsaPrivateKey::new(&mut OsRng, bits)
            .map_err(|e| CrabError::key_error(format!("Failed to generate RSA keypair: {}", e)))?;
        Self::from_private_key(&private_key)
    }

    /// Generates a 2048-bit keypair.
    pub fn generate_2048() -> CrabResult<Self> {
        Self::generate(2048)
    }

    /// Generates a 4096-bit keypair.
    pub fn generate_4096() -> CrabResult<Self> {
        Self::generate(4096)
    }

    /// Parses a PKCS#8 private-key PEM document ("BEGIN PRIVATE KEY").
    ///
    /// Unlike `generate`, no minimum key size is enforced here.
    ///
    /// # Errors
    /// Returns a key error if the PEM is not a valid PKCS#8 RSA private key.
    pub fn from_pem(pem: &str) -> CrabResult<Self> {
        let private_key = RsaPrivateKey::from_pkcs8_pem(pem)
            .map_err(|e| CrabError::key_error(format!("Invalid RSA private key PEM: {}", e)))?;
        Self::from_private_key(&private_key)
    }

    /// Encodes the private key as a PKCS#8 PEM document with LF line endings.
    ///
    /// The returned `String` is an ordinary heap string and is not wiped on
    /// drop. Callers handling it should treat it as secret material.
    ///
    /// # Errors
    /// Returns a key error if decoding or encoding fails.
    pub fn to_pem(&self) -> CrabResult<String> {
        self.with_private_key(|private_key| {
            private_key
                .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
                .map(|s| s.to_string())
                .map_err(|e| CrabError::key_error(format!("Failed to encode private key: {}", e)))
        })
    }

    /// Derives the matching public key.
    ///
    /// # Panics
    /// Panics only if the stored DER fails to decode. That cannot happen for
    /// values built by this type's constructors, which store only DER produced
    /// by `to_pkcs8_der`.
    pub fn public_key(&self) -> RsaPublicKey {
        self.with_private_key(|private_key| Ok(RsaPublicKey(private_key.to_public_key())))
            .expect("Failed to extract public key from valid private key")
    }

    /// Encrypts `plaintext` to this keypair's public key (RSAES-OAEP, SHA-256).
    ///
    /// # Errors
    /// Returns a crypto error if the plaintext is too long or encryption fails.
    pub fn encrypt(&self, plaintext: &[u8]) -> CrabResult<Vec<u8>> {
        self.public_key().encrypt(plaintext)
    }

    /// Decrypts an RSAES-OAEP (SHA-256) `ciphertext`.
    ///
    /// # Errors
    /// Returns a crypto error if the ciphertext is malformed, was produced for
    /// a different key, or was tampered with.
    pub fn decrypt(&self, ciphertext: &[u8]) -> CrabResult<Vec<u8>> {
        self.with_private_key(|private_key| {
            let padding = Oaep::new::<Sha256>();
            let plaintext = private_key
                .decrypt(padding, ciphertext)
                .map_err(|e| CrabError::crypto_error(format!("RSA decryption failed: {}", e)))?;
            Ok(plaintext)
        })
    }

    /// Signs `message` with RSASSA-PSS (SHA-256) using a blinded signing key.
    ///
    /// Signing draws randomness from `OsRng`, so repeated signatures over the
    /// same message differ from one another.
    ///
    /// # Errors
    /// Returns a key error if the stored private key cannot be decoded.
    pub fn sign(&self, message: &[u8]) -> CrabResult<RsaSignature> {
        self.with_private_key(|private_key| {
            use rsa::signature::RandomizedSigner;

            let signing_key = BlindedSigningKey::<Sha256>::new(private_key.clone());
            let signature = signing_key.sign_with_rng(&mut OsRng, message);
            Ok(RsaSignature(signature.to_bytes().as_ref().to_vec()))
        })
    }

    /// Verifies a signature with this keypair's public key.
    ///
    /// Returns `Ok(false)` for any signature that does not verify.
    pub fn verify(&self, message: &[u8], signature: &RsaSignature) -> CrabResult<bool> {
        self.public_key().verify(message, signature)
    }

    /// Exact bit length of the modulus.
    ///
    /// This is read from the modulus itself instead of `size_bytes() * 8`, so
    /// e.g. a key generated with `generate(2050)` reports 2050, not 2056.
    ///
    /// # Panics
    /// Panics under the same (unreachable) condition as `public_key`.
    pub fn size_bits(&self) -> usize {
        self.with_private_key(|private_key| Ok(private_key.n().bits()))
            .expect("Failed to read key size from valid private key")
    }

    /// Modulus length in bytes, i.e. the size of every ciphertext and signature.
    ///
    /// # Panics
    /// Panics under the same (unreachable) condition as `public_key`.
    pub fn size_bytes(&self) -> usize {
        self.with_private_key(|private_key| Ok(private_key.size()))
            .expect("Failed to read key size from valid private key")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates a fresh 2048-bit keypair. Key generation dominates this
    /// module's running time, so each test generates at most one key.
    fn keypair_2048() -> RsaKeyPair {
        RsaKeyPair::generate_2048().expect("2048-bit key generation should succeed")
    }

    #[test]
    fn test_rsa_keygen_2048() {
        let keypair = keypair_2048();
        assert_eq!(keypair.size_bits(), 2048);
        assert_eq!(keypair.size_bytes(), 256);
        let public_key = keypair.public_key();
        assert_eq!(public_key.size_bits(), 2048);
        assert_eq!(public_key.size_bytes(), 256);
    }

    #[test]
    fn test_rsa_encrypt_decrypt_small() {
        let keypair = keypair_2048();
        let plaintext = b"Hello, RSA!";
        let ciphertext = keypair.encrypt(plaintext).unwrap();
        assert_ne!(ciphertext.as_slice(), plaintext);
        let decrypted = keypair.decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    /// OAEP with SHA-256 on a 256-byte modulus carries at most 256 - 2*32 - 2 = 190 bytes.
    #[test]
    fn test_rsa_encrypt_decrypt_max_size() {
        let keypair = keypair_2048();
        let plaintext = vec![0x42u8; 190];
        let ciphertext = keypair.encrypt(&plaintext).unwrap();
        let decrypted = keypair.decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_rsa_encrypt_too_large_fails() {
        let keypair = keypair_2048();
        let plaintext = vec![0x42u8; 191];
        assert!(keypair.encrypt(&plaintext).is_err());
    }

    #[test]
    fn test_rsa_encryption_is_randomized() {
        let keypair = keypair_2048();
        let first = keypair.encrypt(b"same plaintext").unwrap();
        let second = keypair.encrypt(b"same plaintext").unwrap();
        assert_ne!(first, second);
        assert_eq!(keypair.decrypt(&first).unwrap(), b"same plaintext");
        assert_eq!(keypair.decrypt(&second).unwrap(), b"same plaintext");
    }

    #[test]
    fn test_rsa_decrypt_tampered_ciphertext_fails() {
        let keypair = keypair_2048();
        let mut ciphertext = keypair.encrypt(b"Secret").unwrap();
        let last = ciphertext.len() - 1;
        ciphertext[last] ^= 0x01;
        assert!(keypair.decrypt(&ciphertext).is_err());
    }

    #[test]
    fn test_rsa_sign_verify() {
        let keypair = keypair_2048();
        let message = b"Important document to sign";
        let signature = keypair.sign(message).unwrap();
        assert!(keypair.verify(message, &signature).unwrap());
        assert!(!keypair.verify(b"Different message", &signature).unwrap());
        // PSS uses a random salt, so a second signature differs but still verifies.
        let second = keypair.sign(message).unwrap();
        assert_ne!(signature, second);
        assert!(keypair.verify(message, &second).unwrap());
    }

    #[test]
    fn test_rsa_sign_verify_with_public_key() {
        let keypair = keypair_2048();
        let public_key = keypair.public_key();
        let message = b"Document to verify";
        let signature = keypair.sign(message).unwrap();
        assert!(public_key.verify(message, &signature).unwrap());
    }

    #[test]
    fn test_rsa_tampered_signature_rejected() {
        let keypair = keypair_2048();
        let message = b"Signed message";
        let signature = keypair.sign(message).unwrap();
        let mut bytes = signature.as_bytes().to_vec();
        let last = bytes.len() - 1;
        bytes[last] ^= 0x01;
        let tampered = RsaSignature::from_bytes(bytes);
        assert!(!keypair.verify(message, &tampered).unwrap());
    }

    #[test]
    fn test_rsa_pem_roundtrip() {
        let keypair = keypair_2048();
        let pem = keypair.to_pem().unwrap();
        let restored = RsaKeyPair::from_pem(&pem).unwrap();
        let message = b"Test message";
        let signature = restored.sign(message).unwrap();
        assert!(keypair.verify(message, &signature).unwrap());
    }

    #[test]
    fn test_rsa_public_key_pem_roundtrip() {
        let keypair = keypair_2048();
        let public_key = keypair.public_key();
        let pem = public_key.to_pem().unwrap();
        let restored = RsaPublicKey::from_pem(&pem).unwrap();
        let plaintext = b"Test";
        let ciphertext = restored.encrypt(plaintext).unwrap();
        let decrypted = keypair.decrypt(&ciphertext).unwrap();
        assert_eq!(decrypted, plaintext);
    }

    #[test]
    fn test_rsa_public_key_der_and_base64_roundtrip() {
        let keypair = keypair_2048();
        let public_key = keypair.public_key();
        let der = public_key.to_public_key_der().unwrap();
        let from_der = RsaPublicKey::from_public_key_der(&der).unwrap();
        assert_eq!(from_der.to_public_key_der().unwrap(), der);
        let encoded = public_key.to_base64().unwrap();
        let from_base64 = RsaPublicKey::from_base64(&encoded).unwrap();
        assert_eq!(from_base64.to_public_key_der().unwrap(), der);
        let ciphertext = from_base64.encrypt(b"Test").unwrap();
        assert_eq!(keypair.decrypt(&ciphertext).unwrap(), b"Test");
    }

    #[test]
    fn test_rsa_signature_encoding() {
        let keypair = keypair_2048();
        let message = b"Test";
        let signature = keypair.sign(message).unwrap();
        let base64 = signature.to_base64();
        let restored = RsaSignature::from_base64(&base64).unwrap();
        assert_eq!(signature, restored);
        let hex = signature.to_hex();
        let restored = RsaSignature::from_hex(&hex).unwrap();
        assert_eq!(signature, restored);
        assert_eq!(signature.len(), keypair.size_bytes());
        assert!(!signature.is_empty());
    }

    #[test]
    fn test_rsa_invalid_encodings_rejected() {
        assert!(RsaSignature::from_hex("zz").is_err());
        assert!(RsaPublicKey::from_pem("not a pem").is_err());
        assert!(RsaKeyPair::from_pem("not a pem").is_err());
        assert!(RsaPublicKey::from_public_key_der(&[0u8; 4]).is_err());
    }

    #[test]
    fn test_rsa_invalid_signature() {
        let keypair = keypair_2048();
        let message = b"Test message";
        let fake_signature = RsaSignature(vec![0u8; 256]);
        assert!(!keypair.verify(message, &fake_signature).unwrap());
    }

    #[test]
    fn test_rsa_key_too_small() {
        assert!(RsaKeyPair::generate(1024).is_err());
        assert!(RsaKeyPair::generate(2047).is_err());
    }
}