use argon2::{
Algorithm, Argon2, Params, ParamsBuilder, PasswordHash, PasswordVerifier, Version,
password_hash::{PasswordHasher, SaltString},
};
use rand::Rng as _;
use thiserror::Error;
use zeroize::Zeroizing;
use crate::EncryptionKey;
// The DEFAULT_* costs mirror the `Interactive` preset. They may be unused in some
// builds, hence the `dead_code` allowances.

/// Default Argon2 memory cost in KiB (64 MiB).
#[allow(dead_code)]
const DEFAULT_M_COST: u32 = 64 * 1024;
/// Default Argon2 time cost (number of passes over memory).
#[allow(dead_code)]
const DEFAULT_T_COST: u32 = 3;
/// Default Argon2 parallelism (number of lanes).
#[allow(dead_code)]
const DEFAULT_P_COST: u32 = 4;
/// Length in bytes of every derived key and hash output.
const OUTPUT_LENGTH: usize = 32;
#[derive(Debug, Error)]
pub enum PbkdfError {
#[error("Invalid password")]
InvalidPassword,
#[error("Invalid salt")]
InvalidSalt,
#[error("Argon2 error: {0}")]
Argon2Error(String),
#[error("Invalid parameters: {0}")]
InvalidParams(String),
#[error("Hash verification failed")]
VerificationFailed,
}
/// Preset Argon2id cost levels, ordered from cheapest to most expensive.
///
/// Memory figures correspond to the `m_cost` (KiB) handed to Argon2.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KeyDerivationStrength {
    /// Cheapest preset: 8 MiB, 1 pass, 1 lane.
    Fast,
    /// 64 MiB, 3 passes, 4 lanes. This is the preset behind
    /// `PasswordKeyDerivation::default()`.
    Interactive,
    /// 256 MiB, 4 passes, 4 lanes.
    Moderate,
    /// 512 MiB, 5 passes, 4 lanes.
    Strong,
    /// Most expensive preset: 1 GiB, 10 passes, 8 lanes.
    Paranoid,
}
impl KeyDerivationStrength {
fn params(&self) -> Result<Params, PbkdfError> {
let (m_cost, t_cost, p_cost) = match self {
Self::Fast => (8 * 1024, 1, 1), Self::Interactive => (64 * 1024, 3, 4), Self::Moderate => (256 * 1024, 4, 4), Self::Strong => (512 * 1024, 5, 4), Self::Paranoid => (1024 * 1024, 10, 8), };
ParamsBuilder::new()
.m_cost(m_cost)
.t_cost(t_cost)
.p_cost(p_cost)
.output_len(OUTPUT_LENGTH)
.build()
.map_err(|e| PbkdfError::Argon2Error(e.to_string()))
}
}
/// Argon2id-based password key deriver and password hasher.
///
/// Holds one fixed set of Argon2 parameters, always with a 32-byte output.
/// Construct it with `new`, `with_params`, or `Default`.
#[derive(Debug, Clone)]
pub struct PasswordKeyDerivation {
    /// Argon2 cost and output-length parameters used for every operation.
    params: Params,
}
impl Default for PasswordKeyDerivation {
fn default() -> Self {
Self::new(KeyDerivationStrength::Interactive)
}
}
impl PasswordKeyDerivation {
pub fn new(strength: KeyDerivationStrength) -> Self {
let params = strength.params().expect("Invalid parameters");
Self { params }
}
pub fn with_params(m_cost: u32, t_cost: u32, p_cost: u32) -> Result<Self, PbkdfError> {
let params = ParamsBuilder::new()
.m_cost(m_cost)
.t_cost(t_cost)
.p_cost(p_cost)
.output_len(OUTPUT_LENGTH)
.build()
.map_err(|e| PbkdfError::InvalidParams(e.to_string()))?;
Ok(Self { params })
}
pub fn derive_key(&self, password: &str) -> Result<(EncryptionKey, Vec<u8>), PbkdfError> {
if password.is_empty() {
return Err(PbkdfError::InvalidPassword);
}
let salt = SaltString::generate(&mut rand_core06::OsRng);
let key = self.derive_key_with_salt(password, salt.as_str())?;
Ok((key, salt.as_str().as_bytes().to_vec()))
}
pub fn derive_key_with_salt(
&self,
password: &str,
salt: &str,
) -> Result<EncryptionKey, PbkdfError> {
if password.is_empty() {
return Err(PbkdfError::InvalidPassword);
}
let argon2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, self.params.clone());
let password_bytes = Zeroizing::new(password.as_bytes().to_vec());
let salt_string = SaltString::from_b64(salt).map_err(|_| PbkdfError::InvalidSalt)?;
let hash = argon2
.hash_password(&password_bytes, &salt_string)
.map_err(|e| PbkdfError::Argon2Error(e.to_string()))?;
let hash_bytes = hash
.hash
.ok_or_else(|| PbkdfError::Argon2Error("No hash output".to_string()))?;
if hash_bytes.len() != OUTPUT_LENGTH {
return Err(PbkdfError::Argon2Error(format!(
"Invalid output length: {} (expected {})",
hash_bytes.len(),
OUTPUT_LENGTH
)));
}
let mut key = [0u8; 32];
key.copy_from_slice(hash_bytes.as_bytes());
Ok(key)
}
pub fn hash_password(&self, password: &str) -> Result<String, PbkdfError> {
if password.is_empty() {
return Err(PbkdfError::InvalidPassword);
}
let argon2 = Argon2::new(Algorithm::Argon2id, Version::V0x13, self.params.clone());
let salt = SaltString::generate(&mut rand_core06::OsRng);
let password_bytes = Zeroizing::new(password.as_bytes().to_vec());
let hash = argon2
.hash_password(&password_bytes, &salt)
.map_err(|e| PbkdfError::Argon2Error(e.to_string()))?;
Ok(hash.to_string())
}
pub fn verify_password(password: &str, hash: &str) -> Result<(), PbkdfError> {
if password.is_empty() {
return Err(PbkdfError::InvalidPassword);
}
let parsed_hash =
PasswordHash::new(hash).map_err(|e| PbkdfError::Argon2Error(e.to_string()))?;
let password_bytes = Zeroizing::new(password.as_bytes().to_vec());
Argon2::default()
.verify_password(&password_bytes, &parsed_hash)
.map_err(|_| PbkdfError::VerificationFailed)
}
}
/// Derives a 32-byte key from `password` with the default (`Interactive`) preset
/// and a freshly generated random salt.
///
/// Returns the key together with the salt, encoded as the bytes of its B64 string.
/// Convert those bytes back to a `&str` (e.g. with `std::str::from_utf8`) and pass
/// it to [`derive_key_with_salt`] to re-derive the same key.
///
/// # Errors
///
/// Same as [`PasswordKeyDerivation::derive_key`].
pub fn derive_key_from_password(password: &str) -> Result<(EncryptionKey, Vec<u8>), PbkdfError> {
    let kdf = PasswordKeyDerivation::default();
    kdf.derive_key(password)
}
/// Re-derives a key from `password` and a B64 salt string using the default
/// (`Interactive`) preset.
///
/// # Errors
///
/// Same as [`PasswordKeyDerivation::derive_key_with_salt`].
pub fn derive_key_with_salt(password: &str, salt: &str) -> Result<EncryptionKey, PbkdfError> {
    let kdf = PasswordKeyDerivation::default();
    kdf.derive_key_with_salt(password, salt)
}
/// Returns 16 random bytes drawn from `rand::rng()`.
///
/// NOTE: these are raw bytes, not the B64 salt string that [`derive_key_with_salt`]
/// accepts. B64-encode them first (for example with `SaltString::encode_b64`)
/// before passing them there.
pub fn generate_salt() -> Vec<u8> {
    let mut bytes = [0u8; 16];
    rand::rng().fill_bytes(&mut bytes);
    bytes.to_vec()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Cheap parameters so the edge-case tests below stay fast.
    fn fast_kdf() -> PasswordKeyDerivation {
        PasswordKeyDerivation::with_params(4096, 2, 1).unwrap()
    }

    #[test]
    fn test_derive_key_from_password() {
        let password = "correct horse battery staple";
        let (key1, salt) = derive_key_from_password(password).unwrap();
        assert_eq!(key1.len(), 32);
        assert!(!salt.is_empty());
        let salt_str = std::str::from_utf8(&salt).unwrap();
        let key2 = derive_key_with_salt(password, salt_str).unwrap();
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_different_passwords_different_keys() {
        let (key1, _) = derive_key_from_password("password1").unwrap();
        let (key2, _) = derive_key_from_password("password2").unwrap();
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_password_hashing() {
        let pbkdf = PasswordKeyDerivation::default();
        let password = "my secret password";
        let hash = pbkdf.hash_password(password).unwrap();
        assert!(hash.starts_with("$argon2id$"));
        assert!(PasswordKeyDerivation::verify_password(password, &hash).is_ok());
        assert!(PasswordKeyDerivation::verify_password("wrong password", &hash).is_err());
    }

    #[ignore = "slow: Argon2 strength-level benchmarking (~200s)"]
    #[test]
    fn test_strength_levels() {
        let password = "test password";
        for strength in &[
            KeyDerivationStrength::Fast,
            KeyDerivationStrength::Interactive,
            KeyDerivationStrength::Moderate,
            KeyDerivationStrength::Strong,
        ] {
            let pbkdf = PasswordKeyDerivation::new(*strength);
            let (key, salt) = pbkdf.derive_key(password).unwrap();
            assert_eq!(key.len(), 32);
            assert!(!salt.is_empty());
        }
    }

    #[test]
    fn test_empty_password() {
        let result = derive_key_from_password("");
        assert!(result.is_err());
    }

    #[test]
    fn test_custom_params() {
        let pbkdf = PasswordKeyDerivation::with_params(4096, 2, 1).unwrap();
        let (key, _) = pbkdf.derive_key("test").unwrap();
        assert_eq!(key.len(), 32);
    }

    #[test]
    fn test_deterministic_derivation() {
        let password = "test password";
        let pbkdf = PasswordKeyDerivation::default();
        let (_, salt1) = pbkdf.derive_key(password).unwrap();
        let salt_str = std::str::from_utf8(&salt1).unwrap();
        let key1 = pbkdf.derive_key_with_salt(password, salt_str).unwrap();
        let key2 = pbkdf.derive_key_with_salt(password, salt_str).unwrap();
        assert_eq!(key1, key2);
    }

    #[test]
    fn test_invalid_salt_rejected() {
        // Too short and outside the B64 alphabet.
        let result = fast_kdf().derive_key_with_salt("password", "!!");
        assert!(matches!(result, Err(PbkdfError::InvalidSalt)));
    }

    #[test]
    fn test_different_salts_different_keys() {
        let kdf = fast_kdf();
        // Unpadded B64 encodings of "saltsalt" and "pepperpepper".
        let key1 = kdf.derive_key_with_salt("password", "c2FsdHNhbHQ").unwrap();
        let key2 = kdf.derive_key_with_salt("password", "cGVwcGVycGVwcGVy").unwrap();
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_verify_hash_with_custom_params() {
        // verify_password must take its cost parameters from the hash string,
        // not assume the parameters of any particular deriver.
        let hash = fast_kdf().hash_password("pw").unwrap();
        assert!(PasswordKeyDerivation::verify_password("pw", &hash).is_ok());
        assert!(matches!(
            PasswordKeyDerivation::verify_password("nope", &hash),
            Err(PbkdfError::VerificationFailed)
        ));
    }

    #[test]
    fn test_empty_password_rejected_everywhere() {
        let kdf = fast_kdf();
        assert!(matches!(
            kdf.hash_password(""),
            Err(PbkdfError::InvalidPassword)
        ));
        assert!(matches!(
            kdf.derive_key_with_salt("", "c2FsdHNhbHQ"),
            Err(PbkdfError::InvalidPassword)
        ));
        assert!(matches!(
            PasswordKeyDerivation::verify_password("", "$argon2id$"),
            Err(PbkdfError::InvalidPassword)
        ));
    }

    #[test]
    fn test_malformed_hash_rejected() {
        let result = PasswordKeyDerivation::verify_password("pw", "not-a-phc-string");
        assert!(matches!(result, Err(PbkdfError::Argon2Error(_))));
    }
}