use std::sync::Arc;
use memsecurity::EncryptedMem;
use zeroize::Zeroizing;
use crate::error::Error;
/// A 32-byte key kept encrypted in memory by [`EncryptedMem`].
///
/// The encrypted buffer sits behind an [`Arc`], so the derived `Clone`
/// only increments a reference count: every clone shares the same
/// encrypted storage rather than holding its own copy.
#[derive(Clone)]
pub struct EncryptedKey {
    /// Shared handle to the encrypted key material.
    key: Arc<EncryptedMem>,
}
impl EncryptedKey {
pub fn new(key: &[u8; 32]) -> Result<Self, Error> {
let mut encrypted_mem = EncryptedMem::new();
encrypted_mem.encrypt(key).map_err(|_| {
tracing::error!("Key encryption failed");
Error::Helper {
name: "encryption".to_owned(),
reason: "Failed to encrypt key".to_owned(),
}
})?;
tracing::debug!("EncryptedKey created");
Ok(Self {
key: Arc::new(encrypted_mem),
})
}
pub fn key(&self) -> Result<Zeroizing<[u8; 32]>, Error> {
let decrypted = self.key.decrypt().map_err(|_| {
tracing::error!(
"Key decryption failed, possible memory corruption"
);
Error::Helper {
name: "decryption".to_owned(),
reason: "Failed to decrypt key".to_owned(),
}
})?;
let mut key_array = Zeroizing::new([0u8; 32]);
key_array.copy_from_slice(decrypted.as_ref());
tracing::debug!("Key accessed");
Ok(key_array)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A clone must decrypt to the same plaintext as the instance it came from.
    #[test]
    fn test_clone_decrypts_to_same_key() {
        let key = [1u8; 32];
        let original = EncryptedKey::new(&key).unwrap();
        let from_original = original.key().unwrap();
        let cloned = original.clone();
        let from_clone = cloned.key().unwrap();
        assert_eq!(*from_original, *from_clone);
        assert_eq!(*from_clone, key);
    }

    /// Two independent encryptions of the same plaintext decrypt identically.
    #[test]
    fn test_same_plaintext_decrypts_identically() {
        let key = [42u8; 32];
        let first = EncryptedKey::new(&key).unwrap();
        let second = EncryptedKey::new(&key).unwrap();
        let key1 = first.key().unwrap();
        let key2 = second.key().unwrap();
        assert_eq!(*key1, *key2);
        assert_eq!(*key1, key);
    }

    /// Different plaintexts must not decrypt to the same bytes.
    #[test]
    fn test_different_plaintexts_decrypt_differently() {
        let key1_data = [1u8; 32];
        let key2_data = [2u8; 32];
        let first = EncryptedKey::new(&key1_data).unwrap();
        let second = EncryptedKey::new(&key2_data).unwrap();
        let key1 = first.key().unwrap();
        let key2 = second.key().unwrap();
        assert_ne!(*key1, *key2);
    }

    /// The key can be decrypted repeatedly, and each call returns the original plaintext.
    #[test]
    fn test_key_access_is_repeatable() {
        let original = [123u8; 32];
        let eh = EncryptedKey::new(&original).unwrap();
        let decrypted = eh.key().unwrap();
        assert_eq!(*decrypted, original);
        let decrypted2 = eh.key().unwrap();
        assert_eq!(*decrypted2, original);
    }

    /// Dropping a previously returned plaintext copy must not affect later accesses.
    #[test]
    fn test_key_available_after_previous_copy_dropped() {
        let original = [99u8; 32];
        let eh = EncryptedKey::new(&original).unwrap();
        {
            let key = eh.key().unwrap();
            assert_eq!(*key, original);
        }
        let key2 = eh.key().unwrap();
        assert_eq!(*key2, original);
    }

    /// Cloning shares the encrypted storage through the `Arc`, and dropping a clone releases its reference.
    #[test]
    fn test_clone_is_cheap() {
        let original = [77u8; 32];
        let eh = EncryptedKey::new(&original).unwrap();
        let eh_clone = eh.clone();
        assert_eq!(Arc::strong_count(&eh.key), 2);
        assert_eq!(Arc::strong_count(&eh_clone.key), 2);
        let key1 = eh.key().unwrap();
        let key2 = eh_clone.key().unwrap();
        assert_eq!(*key1, *key2);
        drop(eh_clone);
        assert_eq!(Arc::strong_count(&eh.key), 1);
    }
}