use rsa::{
Oaep, RsaPrivateKey, pkcs1::DecodeRsaPrivateKey, pkcs8::DecodePrivateKey,
traits::PublicKeyParts,
};
use sha2::Sha256;
use crate::encryption::EncryptionError;
/// Format version expected in the first byte of an encrypted-CEK blob.
const CEK_VERSION_BYTE: u8 = 0x01;

/// Unwraps (decrypts) column encryption keys (CEKs) with an RSA private key,
/// using RSA-OAEP padding with SHA-256.
pub struct RsaKeyUnwrapper {
    private_key: RsaPrivateKey,
}

impl RsaKeyUnwrapper {
    /// Parses a PEM-encoded RSA private key, trying PKCS#8 first and then PKCS#1.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::CmkError`] when the text parses in neither
    /// format; the message reports both parser errors.
    pub fn from_pem(pem: &str) -> Result<Self, EncryptionError> {
        RsaPrivateKey::from_pkcs8_pem(pem)
            .or_else(|pkcs8_err| {
                RsaPrivateKey::from_pkcs1_pem(pem)
                    .map_err(|pkcs1_err| Self::key_parse_error(pkcs8_err, pkcs1_err))
            })
            .map(Self::from_key)
    }

    /// Parses a DER-encoded RSA private key, trying PKCS#8 first and then PKCS#1.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::CmkError`] when the bytes parse in neither
    /// format; the message reports both parser errors.
    pub fn from_der(der: &[u8]) -> Result<Self, EncryptionError> {
        RsaPrivateKey::from_pkcs8_der(der)
            .or_else(|pkcs8_err| {
                RsaPrivateKey::from_pkcs1_der(der)
                    .map_err(|pkcs1_err| Self::key_parse_error(pkcs8_err, pkcs1_err))
            })
            .map(Self::from_key)
    }

    /// Wraps an already-loaded RSA private key.
    pub fn from_key(private_key: RsaPrivateKey) -> Self {
        Self { private_key }
    }

    /// Extracts the RSA ciphertext from an encrypted-CEK blob and decrypts it
    /// with RSA-OAEP/SHA-256, returning the plaintext CEK.
    ///
    /// The blob layout is described on `parse_encrypted_cek`.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::CekDecryptionFailed`] if the blob is malformed
    /// or the RSA decryption fails.
    pub fn decrypt_cek(&self, encrypted_cek: &[u8]) -> Result<Vec<u8>, EncryptionError> {
        let ciphertext = self.parse_encrypted_cek(encrypted_cek)?;
        // Reuse the bare-ciphertext path so both entry points decrypt identically.
        self.decrypt_raw(ciphertext)
    }

    /// Decrypts a bare RSA ciphertext (no CEK envelope) with RSA-OAEP/SHA-256.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::CekDecryptionFailed`] if decryption fails.
    pub fn decrypt_raw(&self, ciphertext: &[u8]) -> Result<Vec<u8>, EncryptionError> {
        self.private_key
            .decrypt(Oaep::new::<Sha256>(), ciphertext)
            .map_err(|e| {
                EncryptionError::CekDecryptionFailed(format!("RSA-OAEP decryption failed: {e}"))
            })
    }

    /// Returns the RSA ciphertext embedded in an encrypted-CEK blob.
    ///
    /// Layout read here (length fields are little-endian `u16`):
    ///
    /// 1. 1 byte: version, must equal `CEK_VERSION_BYTE`
    /// 2. 2 bytes: key-path length `P`, in bytes
    /// 3. `P` bytes: key path (skipped, not validated)
    /// 4. 2 bytes: ciphertext length `C`, in bytes
    /// 5. `C` bytes: ciphertext (returned)
    ///
    /// Any bytes after the ciphertext are ignored; nothing is verified against them.
    ///
    /// NOTE(review): Microsoft's SqlClient documents SQL Server Always Encrypted CEKs as
    /// version, key-path length, ciphertext length, key path, ciphertext, signature
    /// (wrapped with OAEP/SHA-1). If this must read blobs produced by SQL Server tooling,
    /// the field order (and padding hash) here likely need to change. Confirm the
    /// intended producer.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptionError::CekDecryptionFailed`] for a short blob, a wrong
    /// version byte, or a truncated length field or ciphertext.
    fn parse_encrypted_cek<'a>(&self, data: &'a [u8]) -> Result<&'a [u8], EncryptionError> {
        // Smallest blob that can hold the version byte and both length fields.
        const MIN_LEN: usize = 1 + 2 + 2;

        if data.len() < MIN_LEN {
            return Err(EncryptionError::CekDecryptionFailed(
                "Encrypted CEK too short".into(),
            ));
        }

        let version = data[0];
        if version != CEK_VERSION_BYTE {
            return Err(EncryptionError::CekDecryptionFailed(format!(
                "Invalid CEK version: expected {:#04x}, got {:#04x}",
                CEK_VERSION_BYTE, version
            )));
        }

        let key_path_len = usize::from(u16::from_le_bytes([data[1], data[2]]));
        let ciphertext_len_offset = 3 + key_path_len;
        let Some(&[len_lo, len_hi]) = data.get(ciphertext_len_offset..ciphertext_len_offset + 2)
        else {
            return Err(EncryptionError::CekDecryptionFailed(
                "Encrypted CEK truncated: missing ciphertext length".into(),
            ));
        };
        let ciphertext_len = usize::from(u16::from_le_bytes([len_lo, len_hi]));

        // The length field was read successfully, so `data.len() >= ciphertext_offset`
        // and the subtraction in the error message cannot underflow.
        let ciphertext_offset = ciphertext_len_offset + 2;
        data.get(ciphertext_offset..ciphertext_offset + ciphertext_len)
            .ok_or_else(|| {
                EncryptionError::CekDecryptionFailed(format!(
                    "Encrypted CEK truncated: expected {} bytes of ciphertext, got {}",
                    ciphertext_len,
                    data.len() - ciphertext_offset
                ))
            })
    }

    /// RSA modulus size in bits: the byte size reported by
    /// [`PublicKeyParts::size`] multiplied by 8.
    pub fn key_bits(&self) -> usize {
        self.private_key.size() * 8
    }

    /// Builds the error for a key that parsed as neither PKCS#8 nor PKCS#1.
    ///
    /// Both causes are kept. Reporting only the PKCS#1 error (the last one tried)
    /// would hide the real problem whenever the input was a malformed PKCS#8 key.
    fn key_parse_error(
        pkcs8_err: impl std::fmt::Display,
        pkcs1_err: impl std::fmt::Display,
    ) -> EncryptionError {
        EncryptionError::CmkError(format!(
            "Failed to parse RSA private key (as PKCS#8: {pkcs8_err}; as PKCS#1: {pkcs1_err})"
        ))
    }
}
/// Builds an encrypted-CEK blob in the layout `RsaKeyUnwrapper` parses:
/// version byte, key-path byte length (`u16` LE), UTF-16LE key path,
/// ciphertext length (`u16` LE), then the ciphertext.
///
/// # Panics
///
/// Panics if the UTF-16 key path or the ciphertext is longer than `u16::MAX`
/// bytes, instead of silently writing a truncated length prefix.
#[cfg(test)]
#[allow(clippy::expect_used)] // test-only helper: oversized input is a bug in the calling test
pub fn create_test_encrypted_cek(key_path: &str, ciphertext: &[u8]) -> Vec<u8> {
    let key_path_utf16: Vec<u8> = key_path
        .encode_utf16()
        .flat_map(u16::to_le_bytes)
        .collect();
    let path_len = u16::try_from(key_path_utf16.len())
        .expect("test key path exceeds u16::MAX bytes when UTF-16 encoded");
    let cipher_len =
        u16::try_from(ciphertext.len()).expect("test ciphertext exceeds u16::MAX bytes");

    let mut result = Vec::with_capacity(1 + 2 + key_path_utf16.len() + 2 + ciphertext.len());
    result.push(CEK_VERSION_BYTE);
    result.extend_from_slice(&path_len.to_le_bytes());
    result.extend_from_slice(&key_path_utf16);
    result.extend_from_slice(&cipher_len.to_le_bytes());
    result.extend_from_slice(ciphertext);
    result
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use rsa::{RsaPrivateKey, pkcs1::EncodeRsaPrivateKey, pkcs8::EncodePrivateKey};
    use std::sync::OnceLock;

    /// Returns a 2048-bit key shared by all tests.
    ///
    /// RSA key generation dominates the runtime of these tests, especially in
    /// unoptimized builds, so it is done once and reused.
    fn test_key() -> &'static RsaPrivateKey {
        static KEY: OnceLock<RsaPrivateKey> = OnceLock::new();
        KEY.get_or_init(|| {
            let mut rng = rand::thread_rng();
            RsaPrivateKey::new(&mut rng, 2048).unwrap()
        })
    }

    /// Builds an unwrapper around a copy of the shared test key.
    fn test_unwrapper() -> RsaKeyUnwrapper {
        RsaKeyUnwrapper::from_key(test_key().clone())
    }

    /// Encrypts `plaintext` to the shared test key with RSA-OAEP/SHA-256.
    fn encrypt_with_test_key(plaintext: &[u8]) -> Vec<u8> {
        let mut rng = rand::thread_rng();
        test_key()
            .to_public_key()
            .encrypt(&mut rng, Oaep::new::<Sha256>(), plaintext)
            .unwrap()
    }

    #[test]
    fn test_key_unwrapper_from_pem_pkcs8() {
        let pem = test_key().to_pkcs8_pem(rsa::pkcs8::LineEnding::LF).unwrap();
        let unwrapper = RsaKeyUnwrapper::from_pem(&pem).unwrap();
        assert_eq!(unwrapper.key_bits(), 2048);
    }

    #[test]
    fn test_key_unwrapper_from_pem_pkcs1() {
        let pem = test_key().to_pkcs1_pem(rsa::pkcs1::LineEnding::LF).unwrap();
        let unwrapper = RsaKeyUnwrapper::from_pem(&pem).unwrap();
        assert_eq!(unwrapper.key_bits(), 2048);
    }

    #[test]
    fn test_key_unwrapper_from_der_pkcs8() {
        let der = test_key().to_pkcs8_der().unwrap();
        let unwrapper = RsaKeyUnwrapper::from_der(der.as_bytes()).unwrap();
        assert_eq!(unwrapper.key_bits(), 2048);
    }

    #[test]
    fn test_key_unwrapper_from_der_pkcs1() {
        let der = test_key().to_pkcs1_der().unwrap();
        let unwrapper = RsaKeyUnwrapper::from_der(der.as_bytes()).unwrap();
        assert_eq!(unwrapper.key_bits(), 2048);
    }

    #[test]
    fn test_key_unwrapper_rejects_garbage() {
        let Err(err) = RsaKeyUnwrapper::from_pem("not a key") else {
            panic!("garbage PEM must not parse");
        };
        assert!(matches!(err, EncryptionError::CmkError(_)));
        assert!(RsaKeyUnwrapper::from_der(&[0x30, 0x00]).is_err());
    }

    #[test]
    fn test_decrypt_raw() {
        let test_cek = [0x42u8; 32];
        let ciphertext = encrypt_with_test_key(&test_cek);
        let decrypted = test_unwrapper().decrypt_raw(&ciphertext).unwrap();
        assert_eq!(decrypted, test_cek);
    }

    #[test]
    fn test_parse_encrypted_cek() {
        let test_ciphertext = vec![0xAB; 256];
        let encrypted_cek = create_test_encrypted_cek("TestKeyPath", &test_ciphertext);
        let unwrapper = test_unwrapper();
        let extracted = unwrapper.parse_encrypted_cek(&encrypted_cek).unwrap();
        assert_eq!(extracted, &test_ciphertext[..]);
    }

    #[test]
    fn test_parse_encrypted_cek_invalid_version() {
        let mut data = create_test_encrypted_cek("Test", &[0u8; 32]);
        data[0] = 0x02;
        let result = test_unwrapper().parse_encrypted_cek(&data).map(<[u8]>::to_vec);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Invalid CEK version")
        );
    }

    #[test]
    fn test_parse_encrypted_cek_too_short() {
        let result = test_unwrapper().parse_encrypted_cek(&[0x01, 0x00]).map(<[u8]>::to_vec);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_encrypted_cek_missing_ciphertext_length() {
        // Declares a 16-byte key path but supplies only 2 bytes after the header.
        let data = [CEK_VERSION_BYTE, 0x10, 0x00, 0xAA, 0xBB];
        let result = test_unwrapper().parse_encrypted_cek(&data).map(<[u8]>::to_vec);
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("missing ciphertext length")
        );
    }

    #[test]
    fn test_parse_encrypted_cek_truncated_ciphertext() {
        let mut data = create_test_encrypted_cek("Test", &[0u8; 32]);
        data.pop();
        let result = test_unwrapper().parse_encrypted_cek(&data).map(<[u8]>::to_vec);
        assert!(result.unwrap_err().to_string().contains("truncated"));
    }

    #[test]
    fn test_decrypt_cek_full_flow() {
        let test_cek = [0x55u8; 32];
        let rsa_ciphertext = encrypt_with_test_key(&test_cek);
        let encrypted_cek = create_test_encrypted_cek("CurrentUser/My/TestCert", &rsa_ciphertext);
        let decrypted = test_unwrapper().decrypt_cek(&encrypted_cek).unwrap();
        assert_eq!(decrypted, test_cek);
    }

    #[test]
    fn test_decrypt_cek_rejects_invalid_ciphertext() {
        let encrypted_cek = create_test_encrypted_cek("TestKeyPath", &[0xAB; 256]);
        let result = test_unwrapper().decrypt_cek(&encrypted_cek);
        assert!(matches!(result, Err(EncryptionError::CekDecryptionFailed(_))));
    }

    #[test]
    fn test_create_test_encrypted_cek() {
        let ciphertext = vec![0x12, 0x34, 0x56, 0x78];
        let encrypted = create_test_encrypted_cek("Test", &ciphertext);
        assert_eq!(encrypted[0], 0x01);
        let path_len = u16::from_le_bytes([encrypted[1], encrypted[2]]);
        assert_eq!(path_len, 8);
        assert_eq!(&encrypted[encrypted.len() - 4..], &ciphertext[..]);
    }
}