use crate::crypto::{CryptoError, CryptoResult, defaults};
use argon2::{Argon2, Params, Algorithm, Version, PasswordHash, PasswordHasher, PasswordVerifier};
use chacha20poly1305::Key;
use rand::{RngCore, rngs::OsRng};
use zeroize::{Zeroize, ZeroizeOnDrop};
use subtle::ConstantTimeEq;
/// A symmetric key paired with the salt it was (or is claimed to have been) derived with.
///
/// Values are normally produced by Argon2 password derivation (`from_password*`), but can
/// also be reconstructed from a hex dump via `from_hex`. `ZeroizeOnDrop` wipes the fields
/// when a value is dropped; `Clone` makes an independent copy that is wiped separately.
#[derive(Clone, ZeroizeOnDrop)]
pub struct DerivedKey {
    // The derived key bytes (a `chacha20poly1305::Key`).
    key: Key,
    // The salt that was fed to Argon2 alongside the password.
    salt: [u8; defaults::SALT_LENGTH],
}
impl DerivedKey {
    /// Derives a key from `password` using Argon2 with the cost settings and salt in `params`.
    ///
    /// The Argon2 algorithm and version always come from `defaults::ARGON2_ALGORITHM` and
    /// `defaults::ARGON2_VERSION`; only the cost parameters and the salt come from `params`.
    ///
    /// # Errors
    /// Returns the Argon2 hashing error, converted into `CryptoError`, if derivation fails.
    pub fn from_password(password: &str, params: &KeyDerivationParams) -> CryptoResult<Self> {
        let argon2 = Argon2::new(
            defaults::ARGON2_ALGORITHM,
            defaults::ARGON2_VERSION,
            params.argon2_params.clone(),
        );
        let mut key_bytes = [0u8; defaults::KEY_LENGTH];
        let outcome = argon2.hash_password_into(password.as_bytes(), &params.salt, &mut key_bytes);
        // Copy the key out only on success, but wipe the stack buffer on every path so
        // partial output never lingers after an error.
        let key = outcome.map(|()| Key::from_slice(&key_bytes).clone());
        key_bytes.zeroize();
        Ok(Self {
            key: key.map_err(CryptoError::from)?,
            salt: params.salt,
        })
    }

    /// Derives a key using the default Argon2 parameters and a freshly generated random salt.
    ///
    /// # Errors
    /// Fails if the OS RNG cannot produce a salt or if Argon2 derivation fails.
    pub fn from_password_with_random_salt(password: &str) -> CryptoResult<Self> {
        let params = KeyDerivationParams::new_random()?;
        Self::from_password(password, &params)
    }

    /// Derives a key using the default Argon2 parameters and the caller-supplied `salt`.
    ///
    /// # Errors
    /// Returns an invalid-salt error if `salt` is not exactly `defaults::SALT_LENGTH` bytes,
    /// or the Argon2 error if derivation fails.
    pub fn from_password_with_salt(password: &str, salt: &[u8]) -> CryptoResult<Self> {
        if salt.len() != defaults::SALT_LENGTH {
            return Err(CryptoError::invalid_salt(format!(
                "Salt must be {} bytes, got {}",
                defaults::SALT_LENGTH,
                salt.len()
            )));
        }
        let mut salt_array = [0u8; defaults::SALT_LENGTH];
        salt_array.copy_from_slice(salt);
        let params = KeyDerivationParams::from_salt(salt_array)?;
        Self::from_password(password, &params)
    }

    /// Returns the derived key.
    pub fn key(&self) -> &Key {
        &self.key
    }

    /// Returns the salt associated with this key.
    pub fn salt(&self) -> &[u8; defaults::SALT_LENGTH] {
        &self.salt
    }

    /// Re-derives a key from `password` with this key's salt and compares it to the stored
    /// key in constant time.
    ///
    /// Note: re-derivation always uses the default Argon2 parameters, so a key that was
    /// produced by `from_password` with custom `KeyDerivationParams` will generally not
    /// verify, even with the correct password.
    ///
    /// # Errors
    /// Propagates any error from re-deriving the candidate key.
    pub fn verify_password(&self, password: &str) -> CryptoResult<bool> {
        // `candidate` is wiped on drop (ZeroizeOnDrop).
        let candidate = Self::from_password_with_salt(password, &self.salt)?;
        let matches = self.key.as_slice().ct_eq(candidate.key.as_slice());
        Ok(bool::from(matches))
    }

    /// Hex-encodes the key bytes followed by the salt bytes.
    ///
    /// The output contains the raw key material and must be treated as a secret.
    pub fn to_hex(&self) -> String {
        let mut combined = Vec::with_capacity(defaults::KEY_LENGTH + defaults::SALT_LENGTH);
        combined.extend_from_slice(self.key.as_slice());
        combined.extend_from_slice(&self.salt);
        let hex_string = hex::encode(&combined);
        // Wipe the plaintext scratch copy of the key before it is freed.
        combined.zeroize();
        hex_string
    }

    /// Parses the format produced by [`DerivedKey::to_hex`]: key bytes followed by salt bytes.
    ///
    /// # Errors
    /// Returns a hex-decoding error for non-hex input, or a serialization error if the
    /// decoded length is not `KEY_LENGTH + SALT_LENGTH` bytes.
    pub fn from_hex(hex_str: &str) -> CryptoResult<Self> {
        const EXPECTED_LEN: usize = defaults::KEY_LENGTH + defaults::SALT_LENGTH;
        let mut bytes = hex::decode(hex_str)?;
        if bytes.len() != EXPECTED_LEN {
            let got = bytes.len();
            // The decoded buffer may still hold key material; wipe it before bailing out.
            bytes.zeroize();
            return Err(CryptoError::serialization(format!(
                "Invalid hex length: expected {}, got {}",
                EXPECTED_LEN, got
            )));
        }
        let key = Key::from_slice(&bytes[..defaults::KEY_LENGTH]).clone();
        let mut salt = [0u8; defaults::SALT_LENGTH];
        salt.copy_from_slice(&bytes[defaults::KEY_LENGTH..]);
        // Wipe the heap copy of the key now that it has been moved into `Self`.
        bytes.zeroize();
        Ok(Self { key, salt })
    }
}
impl std::fmt::Debug for DerivedKey {
    /// Renders the salt as hex and masks the key so secrets never reach logs.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let salt_hex = hex::encode(&self.salt);
        let mut out = f.debug_struct("DerivedKey");
        out.field("key", &"[REDACTED]");
        out.field("salt", &salt_hex);
        out.finish()
    }
}
/// Inputs to Argon2 key derivation: the cost settings plus the salt.
///
/// Holds no key material; the derived `Debug` impl prints the salt bytes as-is.
#[derive(Debug, Clone)]
pub struct KeyDerivationParams {
    /// Argon2 cost settings (memory cost, time cost, parallelism, output length).
    pub argon2_params: Params,
    /// Salt mixed into the derivation.
    pub salt: [u8; defaults::SALT_LENGTH],
}
impl KeyDerivationParams {
pub fn new_random() -> CryptoResult<Self> {
let mut salt = [0u8; defaults::SALT_LENGTH];
OsRng
.try_fill_bytes(&mut salt)
.map_err(|e| CryptoError::random_generation(e.to_string()))?;
Ok(Self {
argon2_params: defaults::ARGON2_PARAMS,
salt,
})
}
pub fn from_salt(salt: [u8; defaults::SALT_LENGTH]) -> CryptoResult<Self> {
Ok(Self {
argon2_params: defaults::ARGON2_PARAMS,
salt,
})
}
pub fn with_custom_params(
memory_cost: u32,
time_cost: u32,
parallelism: u32,
salt: [u8; defaults::SALT_LENGTH],
) -> CryptoResult<Self> {
let argon2_params = Params::new(
memory_cost,
time_cost,
parallelism,
Some(defaults::KEY_LENGTH),
)
.map_err(|e| CryptoError::key_derivation(e.to_string()))?;
Ok(Self {
argon2_params,
salt,
})
}
pub fn salt(&self) -> &[u8] {
&self.salt
}
}
/// Stateless namespace for helpers that draw randomness from the operating-system RNG (`OsRng`).
pub struct SecureRandom;
impl SecureRandom {
pub fn generate_salt() -> CryptoResult<[u8; defaults::SALT_LENGTH]> {
let mut salt = [0u8; defaults::SALT_LENGTH];
OsRng
.try_fill_bytes(&mut salt)
.map_err(|e| CryptoError::random_generation(e.to_string()))?;
Ok(salt)
}
pub fn generate_nonce() -> CryptoResult<[u8; defaults::NONCE_LENGTH]> {
let mut nonce = [0u8; defaults::NONCE_LENGTH];
OsRng
.try_fill_bytes(&mut nonce)
.map_err(|e| CryptoError::random_generation(e.to_string()))?;
Ok(nonce)
}
pub fn generate_password(length: usize) -> CryptoResult<String> {
const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()_+-=[]{}|;:,.<>?";
let mut password = String::with_capacity(length);
let charset_len = CHARSET.len();
for _ in 0..length {
let mut idx = [0u8; 1];
loop {
OsRng
.try_fill_bytes(&mut idx)
.map_err(|e| CryptoError::random_generation(e.to_string()))?;
if (idx[0] as usize) < charset_len * (256 / charset_len) {
break;
}
}
password.push(CHARSET[idx[0] as usize % charset_len] as char);
}
Ok(password)
}
pub fn generate_bytes(length: usize) -> CryptoResult<Vec<u8>> {
let mut bytes = vec![0u8; length];
OsRng
.try_fill_bytes(&mut bytes)
.map_err(|e| CryptoError::random_generation(e.to_string()))?;
Ok(bytes)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_key_derivation_with_random_salt() {
        let password = "test_password_123";
        let key = DerivedKey::from_password_with_random_salt(password).unwrap();
        assert!(key.verify_password(password).unwrap());
        assert!(!key.verify_password("wrong_password").unwrap());
    }

    #[test]
    fn test_key_derivation_deterministic() {
        let password = "test_password_123";
        let salt = [42u8; defaults::SALT_LENGTH];
        let key1 = DerivedKey::from_password_with_salt(password, &salt).unwrap();
        let key2 = DerivedKey::from_password_with_salt(password, &salt).unwrap();
        assert_eq!(key1.key().as_slice(), key2.key().as_slice());
        assert_eq!(key1.salt(), key2.salt());
    }

    #[test]
    fn test_key_serialization() {
        let password = "test_password_123";
        let key1 = DerivedKey::from_password_with_random_salt(password).unwrap();
        let hex_repr = key1.to_hex();
        let key2 = DerivedKey::from_hex(&hex_repr).unwrap();
        assert_eq!(key1.key().as_slice(), key2.key().as_slice());
        assert_eq!(key1.salt(), key2.salt());
    }

    #[test]
    fn test_from_hex_rejects_bad_input() {
        // Well-formed hex, but far too short to hold a key and a salt.
        assert!(DerivedKey::from_hex("abcd").is_err());
        // Not hex at all.
        assert!(DerivedKey::from_hex("zz").is_err());
    }

    #[test]
    fn test_secure_random_generation() {
        let salt1 = SecureRandom::generate_salt().unwrap();
        let salt2 = SecureRandom::generate_salt().unwrap();
        assert_ne!(salt1, salt2);
        assert_eq!(salt1.len(), defaults::SALT_LENGTH);
        let nonce1 = SecureRandom::generate_nonce().unwrap();
        let nonce2 = SecureRandom::generate_nonce().unwrap();
        assert_ne!(nonce1, nonce2);
        assert_eq!(nonce1.len(), defaults::NONCE_LENGTH);
    }

    #[test]
    fn test_generate_password_and_bytes() {
        let password = SecureRandom::generate_password(100).unwrap();
        assert_eq!(password.chars().count(), 100);
        assert!(password.is_ascii());
        assert_eq!(SecureRandom::generate_password(0).unwrap(), "");
        assert_eq!(SecureRandom::generate_bytes(32).unwrap().len(), 32);
        assert!(SecureRandom::generate_bytes(0).unwrap().is_empty());
    }

    /// Repeats verification to check the result is stable across calls.
    /// This does not measure timing; constant-time behavior comes from `subtle::ct_eq`.
    #[test]
    fn test_password_verification_constant_time() {
        let password = "test_password_123";
        let key = DerivedKey::from_password_with_random_salt(password).unwrap();
        for _ in 0..10 {
            assert!(key.verify_password(password).unwrap());
            assert!(!key.verify_password("wrong").unwrap());
        }
    }

    #[test]
    fn test_key_params_creation() {
        let params = KeyDerivationParams::new_random().unwrap();
        assert_eq!(params.salt.len(), defaults::SALT_LENGTH);
        let salt = [1u8; defaults::SALT_LENGTH];
        let params2 = KeyDerivationParams::from_salt(salt).unwrap();
        assert_eq!(params2.salt, salt);
    }

    #[test]
    fn test_custom_params_rejects_zero_parallelism() {
        let salt = [0u8; defaults::SALT_LENGTH];
        assert!(KeyDerivationParams::with_custom_params(65536, 3, 0, salt).is_err());
    }

    #[test]
    fn test_invalid_salt_length() {
        let password = "test";
        // One byte longer than required, so this is invalid whatever SALT_LENGTH is.
        let invalid_salt = [1u8; defaults::SALT_LENGTH + 1];
        let result = DerivedKey::from_password_with_salt(password, &invalid_salt);
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), CryptoError::InvalidSalt { .. }));
    }
}