use crate::{CryptoError, CryptoResult};
use argon2::{
password_hash::{PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Argon2, Params, Version,
};
use chacha20poly1305::{
aead::{Aead, KeyInit, OsRng},
XChaCha20Poly1305, XNonce,
};
use rand::RngCore;
use std::fmt;
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Length in bytes of a key-encryption key (256 bits).
pub const KEK_SIZE: usize = 256 / 8;
/// Length in bytes of the salt accepted by [`Kek::derive_from_password`] (128 bits).
pub const SALT_SIZE: usize = 128 / 8;
/// Length in bytes of the XChaCha20-Poly1305 nonce prefixed to every wrapped DEK (192 bits).
pub const NONCE_SIZE: usize = 192 / 8;
/// A 256-bit key-encryption key (KEK) used to wrap and unwrap data-encryption keys.
///
/// The key bytes are zeroized when the value is dropped.
#[derive(Clone)]
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct Kek {
    /// Raw symmetric key material.
    key: [u8; KEK_SIZE],
}
impl Kek {
pub fn derive_from_password(password: &str, salt: &[u8]) -> CryptoResult<Self> {
if salt.len() != SALT_SIZE {
return Err(CryptoError::InvalidKeySize {
expected: SALT_SIZE,
actual: salt.len(),
});
}
let params = Params::new(65536, 3, 4, Some(KEK_SIZE))
.map_err(|e| CryptoError::KeyDerivation(e.to_string()))?;
let argon2 = Argon2::new(
argon2::Algorithm::Argon2id,
Version::V0x13,
params,
);
let mut key = [0u8; KEK_SIZE];
argon2
.hash_password_into(password.as_bytes(), salt, &mut key)
.map_err(|e| CryptoError::KeyDerivation(e.to_string()))?;
Ok(Self { key })
}
pub fn from_bytes(bytes: [u8; KEK_SIZE]) -> Self {
Self { key: bytes }
}
pub fn as_bytes(&self) -> &[u8; KEK_SIZE] {
&self.key
}
pub fn encrypt_dek(&self, dek: &[u8]) -> CryptoResult<Vec<u8>> {
let cipher = XChaCha20Poly1305::new(&self.key.into());
let mut nonce_bytes = [0u8; NONCE_SIZE];
OsRng.fill_bytes(&mut nonce_bytes);
let nonce = XNonce::from_slice(&nonce_bytes);
let ciphertext = cipher
.encrypt(nonce, dek)
.map_err(|e| CryptoError::Encryption(e.to_string()))?;
let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
result.extend_from_slice(&nonce_bytes);
result.extend_from_slice(&ciphertext);
Ok(result)
}
pub fn decrypt_dek(&self, encrypted_dek: &[u8]) -> CryptoResult<Vec<u8>> {
if encrypted_dek.len() < NONCE_SIZE {
return Err(CryptoError::Decryption(
"Encrypted DEK too short".to_string(),
));
}
let cipher = XChaCha20Poly1305::new(&self.key.into());
let (nonce_bytes, ciphertext) = encrypted_dek.split_at(NONCE_SIZE);
let nonce = XNonce::from_slice(nonce_bytes);
let dek = cipher
.decrypt(nonce, ciphertext)
.map_err(|e| CryptoError::Decryption(e.to_string()))?;
Ok(dek)
}
}
/// Debug output that never reveals key material: prints `Kek { key: "[REDACTED]" }`.
impl fmt::Debug for Kek {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        const REDACTED: &str = "[REDACTED]";
        let mut builder = formatter.debug_struct("Kek");
        builder.field("key", &REDACTED);
        builder.finish()
    }
}
/// Returns a fresh random salt of [`SALT_SIZE`] bytes read from the OS RNG,
/// suitable for [`Kek::derive_from_password`].
pub fn generate_salt() -> [u8; SALT_SIZE] {
    let mut rng = OsRng;
    let mut salt_bytes = [0_u8; SALT_SIZE];
    rng.fill_bytes(salt_bytes.as_mut_slice());
    salt_bytes
}
pub fn hash_password(password: &str) -> CryptoResult<String> {
let salt = SaltString::generate(&mut OsRng);
let params = Params::new(65536, 3, 4, None)
.map_err(|e| CryptoError::KeyDerivation(e.to_string()))?;
let argon2 = Argon2::new(
argon2::Algorithm::Argon2id,
Version::V0x13,
params,
);
let hash = argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|e| CryptoError::KeyDerivation(e.to_string()))?
.to_string();
Ok(hash)
}
pub fn verify_password(password: &str, hash: &str) -> CryptoResult<bool> {
let parsed_hash = PasswordHash::new(hash)
.map_err(|e| CryptoError::KeyDerivation(e.to_string()))?;
let argon2 = Argon2::default();
Ok(argon2
.verify_password(password.as_bytes(), &parsed_hash)
.is_ok())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kek_derive_from_password() {
        let password = "super_secret_password";
        let salt = generate_salt();
        let kek1 = Kek::derive_from_password(password, &salt).unwrap();
        let kek2 = Kek::derive_from_password(password, &salt).unwrap();
        // Same password + salt must be deterministic.
        assert_eq!(kek1.as_bytes(), kek2.as_bytes());
        // A different salt must yield a different key.
        let different_salt = generate_salt();
        let kek3 = Kek::derive_from_password(password, &different_salt).unwrap();
        assert_ne!(kek1.as_bytes(), kek3.as_bytes());
    }

    #[test]
    fn test_derive_rejects_wrong_salt_length() {
        for len in [0, SALT_SIZE - 1, SALT_SIZE + 1] {
            let salt = vec![0u8; len];
            let result = Kek::derive_from_password("password", &salt);
            assert!(
                matches!(
                    result,
                    Err(CryptoError::InvalidKeySize { expected, actual })
                        if expected == SALT_SIZE && actual == len
                ),
                "salt of length {len} should be rejected"
            );
        }
    }

    #[test]
    fn test_encrypt_decrypt_dek() {
        let password = "test_password";
        let salt = generate_salt();
        let kek = Kek::derive_from_password(password, &salt).unwrap();
        let dek = b"this_is_a_32_byte_dek_key_123456";
        let encrypted = kek.encrypt_dek(dek).unwrap();
        assert!(encrypted.len() > dek.len());
        let decrypted = kek.decrypt_dek(&encrypted).unwrap();
        assert_eq!(&decrypted[..], &dek[..]);
    }

    #[test]
    fn test_encrypt_uses_fresh_nonce() {
        // `from_bytes` avoids the cost of Argon2 for tests that only exercise the AEAD.
        let kek = Kek::from_bytes([0x42; KEK_SIZE]);
        let dek = [7u8; 32];
        let first = kek.encrypt_dek(&dek).unwrap();
        let second = kek.encrypt_dek(&dek).unwrap();
        assert_ne!(&first[..NONCE_SIZE], &second[..NONCE_SIZE]);
        assert_ne!(first, second);
        assert_eq!(kek.decrypt_dek(&first).unwrap(), dek.to_vec());
        assert_eq!(kek.decrypt_dek(&second).unwrap(), dek.to_vec());
    }

    #[test]
    fn test_wrong_password() {
        let salt = generate_salt();
        let kek1 = Kek::derive_from_password("password1", &salt).unwrap();
        let kek2 = Kek::derive_from_password("password2", &salt).unwrap();
        let dek = b"sample_dek_key_32_bytes_long_123";
        let encrypted = kek1.encrypt_dek(dek).unwrap();
        let result = kek2.decrypt_dek(&encrypted);
        assert!(result.is_err());
    }

    #[test]
    fn test_tampered_ciphertext_rejected() {
        let kek = Kek::from_bytes([0x11; KEK_SIZE]);
        let encrypted = kek
            .encrypt_dek(b"another_32_byte_dek_for_testing!")
            .unwrap();
        // Flip one bit in the nonce, the ciphertext body, and the tag in turn.
        for index in [0, NONCE_SIZE, encrypted.len() - 1] {
            let mut tampered = encrypted.clone();
            tampered[index] ^= 0x01;
            assert!(
                kek.decrypt_dek(&tampered).is_err(),
                "bit flip at byte {index} was not detected"
            );
        }
    }

    #[test]
    fn test_decrypt_rejects_truncated_input() {
        let kek = Kek::from_bytes([0u8; KEK_SIZE]);
        assert!(kek.decrypt_dek(&[0u8; 0]).is_err());
        assert!(kek.decrypt_dek(&[0u8; NONCE_SIZE - 1]).is_err());
        // A bare nonce with no tag cannot authenticate.
        assert!(kek.decrypt_dek(&[0u8; NONCE_SIZE]).is_err());
    }

    #[test]
    fn test_password_hashing() {
        let password = "my_secure_password";
        let hash = hash_password(password).unwrap();
        assert!(hash.starts_with("$argon2id$"));
        assert!(verify_password(password, &hash).unwrap());
        assert!(!verify_password("wrong_password", &hash).unwrap());
    }

    #[test]
    fn test_verify_password_rejects_malformed_hash() {
        assert!(verify_password("my_secure_password", "not-a-phc-string").is_err());
    }

    #[test]
    fn test_salt_generation() {
        let salt1 = generate_salt();
        let salt2 = generate_salt();
        assert_ne!(salt1, salt2);
        assert_eq!(salt1.len(), SALT_SIZE);
    }

    #[test]
    fn test_kek_debug_is_redacted() {
        let kek = Kek::from_bytes([0xAB; KEK_SIZE]);
        let rendered = format!("{kek:?}");
        assert!(rendered.contains("[REDACTED]"));
        // 0xAB == 171; a derived Debug would print the raw byte values.
        assert!(!rendered.contains("171"));
    }
}