use aes_gcm::{
aead::{Aead, KeyInit},
Aes256Gcm, Nonce,
};
use rand::RngCore;
use thiserror::Error;
#[derive(Debug, Error)]
pub enum KeyEncryptionError {
#[error("Encryption failed: {0}")]
EncryptionFailed(String),
#[error("Decryption failed: {0}")]
DecryptionFailed(String),
#[error("Invalid key length: expected 32 bytes, got {0}")]
InvalidKeyLength(usize),
#[error("Invalid encrypted data: too short")]
InvalidEncryptedData,
}
/// Length in bytes of the random nonce prefixed to every encrypted blob
/// (96 bits, the standard GCM nonce size).
const NONCE_SIZE: usize = 96 / 8;
/// Encrypts `private_key` with AES-256-GCM under `encryption_key`.
///
/// Every call draws a fresh random nonce from `rand::thread_rng()` and
/// prepends it to the sealed bytes. The result is therefore laid out as
/// `nonce (NONCE_SIZE bytes) || AEAD ciphertext`. Pass the whole result to
/// [`decrypt_private_key`] to recover the key.
///
/// # Errors
///
/// * [`KeyEncryptionError::InvalidKeyLength`] when `encryption_key` is not
///   exactly 32 bytes long.
/// * [`KeyEncryptionError::EncryptionFailed`] when the cipher cannot be
///   constructed or sealing fails.
// NOTE(review): `deprecated` is presumably silenced for `Nonce::from_slice`
// and/or `rand::thread_rng` under the pinned dependency versions. Confirm,
// and migrate to the non-deprecated APIs so this allow can be dropped.
#[allow(deprecated)]
pub fn encrypt_private_key(
    private_key: &[u8],
    encryption_key: &[u8],
) -> Result<Vec<u8>, KeyEncryptionError> {
    let key_len = encryption_key.len();
    if key_len != 32 {
        return Err(KeyEncryptionError::InvalidKeyLength(key_len));
    }

    let aead = Aes256Gcm::new_from_slice(encryption_key)
        .map_err(|err| KeyEncryptionError::EncryptionFailed(err.to_string()))?;

    // Fresh nonce per call: GCM must never reuse a (key, nonce) pair.
    let mut iv = [0u8; NONCE_SIZE];
    rand::thread_rng().fill_bytes(&mut iv);

    let sealed = aead
        .encrypt(Nonce::from_slice(&iv), private_key)
        .map_err(|err| KeyEncryptionError::EncryptionFailed(err.to_string()))?;

    // Output layout: nonce first, then the sealed bytes.
    Ok([&iv[..], &sealed[..]].concat())
}
/// Decrypts a blob produced by [`encrypt_private_key`] and returns the
/// original private key bytes.
///
/// `encrypted_data` is expected to be `nonce (NONCE_SIZE bytes) || AEAD
/// ciphertext`. The AEAD ciphertext ends with the 16-byte GCM authentication
/// tag. An empty plaintext therefore yields the shortest valid input, exactly
/// `NONCE_SIZE + 16` bytes.
///
/// The returned `Vec` holds secret material and is not wiped on drop. Callers
/// that need zeroization must handle it themselves.
///
/// # Errors
///
/// * [`KeyEncryptionError::InvalidKeyLength`] when `encryption_key` is not
///   exactly 32 bytes long. This is checked before the data is inspected.
/// * [`KeyEncryptionError::InvalidEncryptedData`] when `encrypted_data` is
///   shorter than a nonce plus an authentication tag.
/// * [`KeyEncryptionError::DecryptionFailed`] when the cipher cannot be
///   constructed or authentication fails (wrong key, tampered or corrupted data).
// NOTE(review): `deprecated` is presumably silenced for `Nonce::from_slice`
// under the pinned dependency versions. Confirm, and migrate to the
// non-deprecated API so this allow can be dropped.
#[allow(deprecated)]
pub fn decrypt_private_key(
    encrypted_data: &[u8],
    encryption_key: &[u8],
) -> Result<Vec<u8>, KeyEncryptionError> {
    // Length of the GCM authentication tag appended to every ciphertext.
    const TAG_SIZE: usize = 16;

    if encryption_key.len() != 32 {
        return Err(KeyEncryptionError::InvalidKeyLength(encryption_key.len()));
    }
    // A nonce plus a bare tag is the output for an empty plaintext. It must be
    // accepted, or encrypt/decrypt would not round-trip for every input.
    if encrypted_data.len() < NONCE_SIZE + TAG_SIZE {
        return Err(KeyEncryptionError::InvalidEncryptedData);
    }

    let cipher = Aes256Gcm::new_from_slice(encryption_key)
        .map_err(|e| KeyEncryptionError::DecryptionFailed(e.to_string()))?;

    // The length check above guarantees `split_at` is in bounds and that the
    // nonce slice is exactly NONCE_SIZE bytes, as `Nonce::from_slice` requires.
    let (nonce_bytes, ciphertext) = encrypted_data.split_at(NONCE_SIZE);
    let nonce = Nonce::from_slice(nonce_bytes);

    cipher
        .decrypt(nonce, ciphertext)
        .map_err(|e| KeyEncryptionError::DecryptionFailed(e.to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Length of the GCM authentication tag appended by the cipher.
    const TAG_LEN: usize = 16;

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let private_key = [0x42u8; 32];
        let encryption_key = [0x01u8; 32];
        let encrypted = encrypt_private_key(&private_key, &encryption_key).unwrap();
        // Output is nonce || ciphertext || tag.
        assert_eq!(encrypted.len(), NONCE_SIZE + private_key.len() + TAG_LEN);
        let decrypted = decrypt_private_key(&encrypted, &encryption_key).unwrap();
        assert_eq!(decrypted, private_key);
    }

    #[test]
    fn test_wrong_key_fails() {
        let private_key = [0x42u8; 32];
        let encryption_key = [0x01u8; 32];
        let wrong_key = [0x02u8; 32];
        let encrypted = encrypt_private_key(&private_key, &encryption_key).unwrap();
        let result = decrypt_private_key(&encrypted, &wrong_key);
        assert!(matches!(result, Err(KeyEncryptionError::DecryptionFailed(_))));
    }

    #[test]
    fn test_invalid_key_length() {
        let private_key = [0x42u8; 32];
        let short_key = [0x01u8; 16];
        let result = encrypt_private_key(&private_key, &short_key);
        assert!(matches!(
            result,
            Err(KeyEncryptionError::InvalidKeyLength(16))
        ));
    }

    #[test]
    fn test_decrypt_invalid_key_length() {
        let private_key = [0x42u8; 32];
        let encryption_key = [0x01u8; 32];
        let encrypted = encrypt_private_key(&private_key, &encryption_key).unwrap();
        let result = decrypt_private_key(&encrypted, &encryption_key[..16]);
        assert!(matches!(
            result,
            Err(KeyEncryptionError::InvalidKeyLength(16))
        ));
    }

    #[test]
    fn test_invalid_encrypted_data() {
        let encryption_key = [0x01u8; 32];
        let too_short = [0u8; 20];
        let result = decrypt_private_key(&too_short, &encryption_key);
        assert!(matches!(
            result,
            Err(KeyEncryptionError::InvalidEncryptedData)
        ));
    }

    #[test]
    fn test_tampered_ciphertext_fails() {
        let private_key = [0x42u8; 32];
        let encryption_key = [0x01u8; 32];
        let mut encrypted = encrypt_private_key(&private_key, &encryption_key).unwrap();
        encrypted[NONCE_SIZE + 5] ^= 0xFF;
        let result = decrypt_private_key(&encrypted, &encryption_key);
        assert!(matches!(result, Err(KeyEncryptionError::DecryptionFailed(_))));
    }

    #[test]
    fn test_tampered_nonce_fails() {
        let private_key = [0x42u8; 32];
        let encryption_key = [0x01u8; 32];
        let mut encrypted = encrypt_private_key(&private_key, &encryption_key).unwrap();
        encrypted[0] ^= 0x01;
        let result = decrypt_private_key(&encrypted, &encryption_key);
        assert!(matches!(result, Err(KeyEncryptionError::DecryptionFailed(_))));
    }

    #[test]
    fn test_tampered_tag_fails() {
        let private_key = [0x42u8; 32];
        let encryption_key = [0x01u8; 32];
        let mut encrypted = encrypt_private_key(&private_key, &encryption_key).unwrap();
        let last = encrypted.len() - 1;
        encrypted[last] ^= 0x01;
        let result = decrypt_private_key(&encrypted, &encryption_key);
        assert!(matches!(result, Err(KeyEncryptionError::DecryptionFailed(_))));
    }

    #[test]
    fn test_truncated_ciphertext_fails() {
        let private_key = [0x42u8; 32];
        let encryption_key = [0x01u8; 32];
        let encrypted = encrypt_private_key(&private_key, &encryption_key).unwrap();
        let truncated = &encrypted[..encrypted.len() - 1];
        let result = decrypt_private_key(truncated, &encryption_key);
        assert!(matches!(result, Err(KeyEncryptionError::DecryptionFailed(_))));
    }

    #[test]
    fn test_nonce_is_fresh_per_call() {
        let private_key = [0x42u8; 32];
        let encryption_key = [0x01u8; 32];
        let first = encrypt_private_key(&private_key, &encryption_key).unwrap();
        let second = encrypt_private_key(&private_key, &encryption_key).unwrap();
        // Same key and plaintext must still yield distinct nonces (collision odds ~2^-96).
        assert_ne!(first[..NONCE_SIZE], second[..NONCE_SIZE]);
        assert_ne!(first, second);
    }
}