#![deny(unsafe_code)]
#![deny(missing_docs)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::panic)]
use crate::prelude::error::{LatticeArcError, Result};
use aws_lc_rs::hmac::{self, HMAC_SHA256, HMAC_SHA512};
use subtle::ConstantTimeEq;
use zeroize::Zeroizing;
/// Pseudorandom function used as the core primitive of PBKDF2.
#[non_exhaustive]
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
#[derive(serde::Serialize, serde::Deserialize)]
pub enum PrfType {
    /// HMAC-SHA256; each PBKDF2 block is 32 bytes.
    HmacSha256,
    /// HMAC-SHA512; each PBKDF2 block is 64 bytes.
    HmacSha512,
}
/// Parameters controlling a single PBKDF2 derivation.
///
/// Construct with [`Pbkdf2Params::new`] (random salt) or
/// [`Pbkdf2Params::with_salt`] (caller-supplied salt), then adjust with the
/// builder methods. Values are validated by [`pbkdf2`], not at construction.
#[derive(Debug, Clone)]
pub struct Pbkdf2Params {
    /// Salt mixed into every block; [`pbkdf2`] rejects empty or all-zero salts.
    pub salt: Vec<u8>,
    /// Iteration count; [`pbkdf2`] accepts values from 1,000 to 10,000,000 inclusive.
    pub iterations: u32,
    /// Length of the derived key in bytes; [`pbkdf2`] rejects 0.
    pub key_length: usize,
    /// HMAC variant used as the pseudorandom function.
    pub prf: PrfType,
}
impl Pbkdf2Params {
    /// Default iteration count shared by both constructors.
    const DEFAULT_ITERATIONS: u32 = 600_000;
    /// Default derived-key length in bytes shared by both constructors.
    const DEFAULT_KEY_LENGTH: usize = 32;

    /// Creates parameters with a freshly generated random salt of
    /// `salt_length` bytes, 600,000 iterations, a 32-byte key length and
    /// HMAC-SHA256 as the PRF.
    ///
    /// # Errors
    ///
    /// Returns [`LatticeArcError::InvalidParameter`] if `salt_length` is 0.
    pub fn new(salt_length: usize) -> Result<Self> {
        if salt_length == 0 {
            return Err(LatticeArcError::InvalidParameter(
                "Salt length must be greater than 0".to_string(),
            ));
        }
        let mut salt = vec![0u8; salt_length];
        // Filled in place by the parent module's RNG helper. If the buffer were
        // ever left all zeros, `pbkdf2` would reject it via its all-zero check.
        get_random_bytes(&mut salt);
        Ok(Self::from_salt_vec(salt))
    }

    /// Creates parameters around a caller-supplied salt, using the same
    /// defaults as [`Pbkdf2Params::new`].
    #[must_use]
    pub fn with_salt(salt: &[u8]) -> Self {
        Self::from_salt_vec(salt.to_vec())
    }

    /// Wraps an owned salt with the default iterations, key length and PRF.
    fn from_salt_vec(salt: Vec<u8>) -> Self {
        Self {
            salt,
            iterations: Self::DEFAULT_ITERATIONS,
            key_length: Self::DEFAULT_KEY_LENGTH,
            prf: PrfType::HmacSha256,
        }
    }

    /// Sets the iteration count (validated later by [`pbkdf2`]).
    #[must_use]
    pub fn iterations(mut self, iterations: u32) -> Self {
        self.iterations = iterations;
        self
    }

    /// Sets the derived-key length in bytes (validated later by [`pbkdf2`]).
    #[must_use]
    pub fn key_length(mut self, key_length: usize) -> Self {
        self.key_length = key_length;
        self
    }

    /// Sets the pseudorandom function.
    #[must_use]
    pub fn prf(mut self, prf: PrfType) -> Self {
        self.prf = prf;
        self
    }
}
/// Output of [`pbkdf2`]: the derived key plus the parameters that produced it.
///
/// The key is stored in a [`Zeroizing`] buffer so it is wiped when this value
/// is dropped. Its `Debug` output redacts the key.
pub struct Pbkdf2Result {
    /// Derived key bytes, zeroized on drop.
    key: Zeroizing<Vec<u8>>,
    /// Parameters (salt, iterations, length, PRF) needed to re-derive the key.
    params: Pbkdf2Params,
}
impl std::fmt::Debug for Pbkdf2Result {
    // Never print the derived key: it is secret material and Debug output
    // routinely ends up in logs. The parameters are non-secret and shown.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("Pbkdf2Result");
        builder.field("key", &"[REDACTED]");
        builder.field("params", &self.params);
        builder.finish()
    }
}
impl ConstantTimeEq for Pbkdf2Result {
    // Equality is defined by the key bytes only; parameters are ignored.
    // `subtle`'s slice comparison is constant-time in the contents.
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        let ours: &[u8] = &self.key;
        let theirs: &[u8] = &other.key;
        ours.ct_eq(theirs)
    }
}
impl Pbkdf2Result {
    /// Returns the derived key bytes.
    #[must_use]
    pub fn key(&self) -> &[u8] {
        &self.key
    }

    /// Returns the length of the derived key in bytes.
    #[must_use]
    pub fn key_length(&self) -> usize {
        self.key.len()
    }

    /// Returns the parameters used to derive this key.
    #[must_use]
    pub fn params(&self) -> &Pbkdf2Params {
        &self.params
    }

    /// Re-derives a key from `password` with the stored parameters and
    /// compares it against the stored key in constant time.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`pbkdf2`], e.g. if the stored parameters
    /// have been mutated into an invalid state.
    pub fn verify_password(&self, password: &[u8]) -> Result<bool> {
        let candidate = pbkdf2(password, &self.params)?;
        // Reuse this type's `ConstantTimeEq` impl rather than a hand-rolled
        // fold. It yields "not equal" for mismatched lengths as well.
        Ok(self.ct_eq(&candidate).into())
    }
}
/// Derives a key from `password` with PBKDF2 (RFC 8018 §5.2) using `params`.
///
/// # Errors
///
/// Returns [`LatticeArcError::InvalidParameter`] in any of these cases:
/// - the salt is empty or all zeros;
/// - `iterations` is outside `1_000..=10_000_000`;
/// - `key_length` is 0;
/// - `key_length` exceeds the RFC 8018 maximum of `(2^32 - 1) * hLen` bytes,
///   where `hLen` is 32 for HMAC-SHA256 and 64 for HMAC-SHA512.
pub fn pbkdf2(password: &[u8], params: &Pbkdf2Params) -> Result<Pbkdf2Result> {
    if params.salt.is_empty() {
        return Err(LatticeArcError::InvalidParameter("Salt must not be empty".to_string()));
    }
    if params.salt.iter().all(|&b| b == 0) {
        return Err(LatticeArcError::InvalidParameter(
            "Salt must not be all zeros - use a cryptographically random salt".to_string(),
        ));
    }
    if params.iterations < 1000 {
        return Err(LatticeArcError::InvalidParameter(
            "Iteration count must be at least 1000".to_string(),
        ));
    }
    if params.iterations > 10_000_000 {
        return Err(LatticeArcError::InvalidParameter(
            "Iteration count must not exceed 10,000,000".to_string(),
        ));
    }
    if params.key_length == 0 {
        return Err(LatticeArcError::InvalidParameter(
            "Key length must be greater than 0".to_string(),
        ));
    }

    let prf_output_len: usize = match params.prf {
        PrfType::HmacSha256 => 32,
        PrfType::HmacSha512 => 64,
    };

    // RFC 8018 §5.2 step 1: the block counter INT(i) is a 32-bit integer, so
    // at most 2^32 - 1 blocks exist. Checking up front avoids allocating the
    // output buffer and computing blocks before discovering the overflow.
    let block_count = params.key_length.div_ceil(prf_output_len);
    if u32::try_from(block_count).is_err() {
        return Err(LatticeArcError::InvalidParameter(
            "Key length exceeds the PBKDF2 maximum of (2^32 - 1) * hLen bytes".to_string(),
        ));
    }

    let mut derived_key: Zeroizing<Vec<u8>> = Zeroizing::new(vec![0u8; params.key_length]);
    // DK = T_1 || T_2 || ... : each hLen-sized chunk (the last possibly shorter)
    // is filled from the block with the matching 1-based index, truncated.
    for (chunk_index, chunk) in derived_key.chunks_mut(prf_output_len).enumerate() {
        let block_index = chunk_index
            .checked_add(1)
            .and_then(|index| u32::try_from(index).ok())
            .ok_or_else(|| {
                LatticeArcError::InvalidParameter("Block index exceeds u32::MAX".to_string())
            })?;
        let block =
            generate_block(password, &params.salt, params.iterations, block_index, params.prf);
        let src = block.get(..chunk.len()).ok_or_else(|| {
            LatticeArcError::InvalidParameter("Block slice out of bounds".to_string())
        })?;
        chunk.copy_from_slice(src);
    }

    Ok(Pbkdf2Result { key: derived_key, params: params.clone() })
}
/// Maps a [`PrfType`] to the aws-lc-rs HMAC algorithm that implements it.
fn hmac_algorithm(prf: PrfType) -> hmac::Algorithm {
    match prf {
        PrfType::HmacSha256 => HMAC_SHA256,
        PrfType::HmacSha512 => HMAC_SHA512,
    }
}

/// Computes one PBKDF2 output block `T_i = U_1 ^ U_2 ^ ... ^ U_c`
/// (RFC 8018 §5.2), where `U_1 = PRF(password, salt || INT(block_index))`
/// and `U_j = PRF(password, U_{j-1})`.
///
/// The returned buffer is one PRF output long and is zeroized on drop.
fn generate_block(
    password: &[u8],
    salt: &[u8],
    iterations: u32,
    block_index: u32,
    prf: PrfType,
) -> Zeroizing<Vec<u8>> {
    // Keying HMAC hashes the padded password into inner/outer states. Do it
    // once per block rather than once per iteration, which would add extra
    // hash work per round without adding any security.
    let key = hmac::Key::new(hmac_algorithm(prf), password);

    // U_1 = PRF(P, S || INT(i)). Streaming both parts avoids building a
    // concatenated copy of the salt.
    let mut first = hmac::Context::with_key(&key);
    first.update(salt);
    first.update(&block_index.to_be_bytes());
    let mut u: Zeroizing<Vec<u8>> = Zeroizing::new(first.sign().as_ref().to_vec());
    let mut result: Zeroizing<Vec<u8>> = u.clone();

    for _ in 1..iterations {
        let tag = hmac::sign(&key, &u);
        // Keep U_j for the next round and fold it into the running XOR.
        for ((acc, prev), next) in result.iter_mut().zip(u.iter_mut()).zip(tag.as_ref()) {
            *prev = *next;
            *acc ^= *next;
        }
    }
    result
}
/// Derives a 32-byte PBKDF2-HMAC-SHA256 key from `password` using 600,000
/// iterations and a fresh random 16-byte salt.
///
/// The salt is available via [`Pbkdf2Result::params`] so it can be stored
/// alongside the key for later verification.
///
/// # Errors
///
/// Returns an error if parameter construction or key derivation fails.
pub fn pbkdf2_simple(password: &[u8]) -> Result<Pbkdf2Result> {
    // Spelled out explicitly, rather than relying on `Pbkdf2Params` defaults,
    // so this function's output format stays fixed even if the defaults change.
    let params = Pbkdf2Params::new(16)?
        .iterations(600_000)
        .key_length(32)
        .prf(PrfType::HmacSha256);
    pbkdf2(password, &params)
}
pub fn verify_password(
password: &[u8],
derived_key: &[u8],
salt: &[u8],
iterations: u32,
) -> Result<bool> {
let params = Pbkdf2Params::with_salt(salt)
.iterations(iterations)
.key_length(derived_key.len())
.prf(PrfType::HmacSha256);
let result = pbkdf2(password, ¶ms)?;
let len_eq = derived_key.len().ct_eq(&result.key.len());
let bytes_eq = derived_key
.iter()
.zip(result.key.iter())
.fold(subtle::Choice::from(1u8), |acc, (x, y)| acc & x.ct_eq(y));
Ok((len_eq & bytes_eq).into())
}
use super::get_random_bytes;
#[cfg(test)]
// Tests unwrap and index freely: an unexpected error should fail the test loudly.
#[allow(clippy::unwrap_used)]
#[allow(clippy::panic_in_result_fn)]
#[allow(clippy::indexing_slicing)]
mod tests {
    use super::*;

    #[test]
    fn test_pbkdf2_basic_roundtrip() -> std::result::Result<(), Box<dyn std::error::Error>> {
        let password = b"password";
        let salt = b"salt123456789012";
        let params = Pbkdf2Params::with_salt(salt).iterations(1000).key_length(32);
        let result = pbkdf2(password, &params)?;
        assert_eq!(result.key.len(), 32);
        let result2 = pbkdf2(password, &params)?;
        assert_eq!(result.key, result2.key);
        assert!(bool::from(result.ct_eq(&result2)));
        Ok(())
    }

    #[test]
    fn test_pbkdf2_different_passwords_produce_different_keys_succeeds()
    -> std::result::Result<(), Box<dyn std::error::Error>> {
        let salt = b"salt123456789012";
        let params = Pbkdf2Params::with_salt(salt).iterations(1000).key_length(32);
        let result1 = pbkdf2(b"password1", &params)?;
        let result2 = pbkdf2(b"password2", &params)?;
        assert_ne!(result1.key, result2.key);
        assert!(!bool::from(result1.ct_eq(&result2)));
        Ok(())
    }

    #[test]
    fn test_pbkdf2_different_salts_produce_different_keys_succeeds()
    -> std::result::Result<(), Box<dyn std::error::Error>> {
        let params1 = Pbkdf2Params::with_salt(b"salt123456789012").iterations(1000).key_length(32);
        let params2 = Pbkdf2Params::with_salt(b"salt223456789012").iterations(1000).key_length(32);
        let result1 = pbkdf2(b"password", &params1)?;
        let result2 = pbkdf2(b"password", &params2)?;
        assert_ne!(result1.key, result2.key);
        Ok(())
    }

    #[test]
    fn test_pbkdf2_different_iterations_produce_different_keys_succeeds() {
        let password = b"password";
        let salt = b"salt123456789012";
        let params1 = Pbkdf2Params::with_salt(salt).iterations(1000).key_length(32);
        let params2 = Pbkdf2Params::with_salt(salt).iterations(2000).key_length(32);
        let result1 = pbkdf2(password, &params1).unwrap();
        let result2 = pbkdf2(password, &params2).unwrap();
        assert_ne!(result1.key, result2.key);
    }

    /// Known-answer test: PBKDF2-HMAC-SHA256, P = "password", S = "salt",
    /// c = 4096, dkLen = 32 (single block). Widely published test vector.
    #[test]
    fn test_pbkdf2_sha256_known_answer_single_block_is_correct() {
        let params = Pbkdf2Params::with_salt(b"salt").iterations(4096).key_length(32);
        let result = pbkdf2(b"password", &params).unwrap();
        let expected: [u8; 32] = [
            0xc5, 0xe4, 0x78, 0xd5, 0x92, 0x88, 0xc8, 0x41, 0xaa, 0x53, 0x0d, 0xb6, 0x84, 0x5c,
            0x4c, 0x8d, 0x96, 0x28, 0x93, 0xa0, 0x01, 0xce, 0x4e, 0x11, 0xa4, 0x96, 0x38, 0x73,
            0xaa, 0x98, 0x13, 0x4a,
        ];
        assert_eq!(result.key(), &expected[..]);
        assert!(verify_password(b"password", &expected, b"salt", 4096).unwrap());
    }

    /// Known-answer test: PBKDF2-HMAC-SHA256, P = "passwordPASSWORDpassword",
    /// S = "saltSALTsaltSALTsaltSALTsaltSALTsalt", c = 4096, dkLen = 40.
    /// Spans two blocks, exercising the truncated final block.
    #[test]
    fn test_pbkdf2_sha256_known_answer_multi_block_is_correct() {
        let params = Pbkdf2Params::with_salt(b"saltSALTsaltSALTsaltSALTsaltSALTsalt")
            .iterations(4096)
            .key_length(40);
        let result = pbkdf2(b"passwordPASSWORDpassword", &params).unwrap();
        let expected: [u8; 40] = [
            0x34, 0x8c, 0x89, 0xdb, 0xcb, 0xd3, 0x2b, 0x2f, 0x32, 0xd8, 0x14, 0xb8, 0x11, 0x6e,
            0x84, 0xcf, 0x2b, 0x17, 0x34, 0x7e, 0xbc, 0x18, 0x00, 0x18, 0x1c, 0x4e, 0x2a, 0x1f,
            0xb8, 0xdd, 0x53, 0xe1, 0xc6, 0x35, 0x51, 0x8c, 0x7d, 0xac, 0x47, 0xe9,
        ];
        assert_eq!(result.key(), &expected[..]);
    }

    #[test]
    fn test_pbkdf2_simple_produces_different_keys_with_different_salts_succeeds() {
        let password = b"testpassword";
        let result1 = pbkdf2_simple(password).unwrap();
        let result2 = pbkdf2_simple(password).unwrap();
        assert_ne!(result1.key, result2.key);
        assert_eq!(result1.key.len(), 32);
        assert_eq!(result2.key.len(), 32);
    }

    #[test]
    fn test_password_verification_is_correct() {
        let password = b"correctpassword";
        let wrong_password = b"wrongpassword";
        let result = pbkdf2_simple(password).unwrap();
        assert!(result.verify_password(password).unwrap());
        assert!(!result.verify_password(wrong_password).unwrap());
    }

    #[test]
    fn test_verify_password_function_is_correct() {
        let password = b"testpass";
        let salt = b"1234567890123456";
        let params = Pbkdf2Params::with_salt(salt).iterations(1000).key_length(32);
        let derived = pbkdf2(password, &params).unwrap();
        assert!(verify_password(password, &derived.key, salt, 1000).unwrap());
        assert!(!verify_password(b"wrongpass", &derived.key, salt, 1000).unwrap());
        assert!(!verify_password(password, &derived.key, b"wrongsalt123456", 1000).unwrap());
        assert!(!verify_password(password, &derived.key, salt, 2000).unwrap());
    }

    #[test]
    fn test_pbkdf2_validation_fails_for_invalid_params_fails() {
        let password = b"pass";
        let salt = b"salt";
        let params_empty_salt = Pbkdf2Params::with_salt(b"").iterations(1000).key_length(32);
        assert!(pbkdf2(password, &params_empty_salt).is_err());
        let params_zero_salt = Pbkdf2Params::with_salt(&[0u8; 16]).iterations(1000).key_length(32);
        assert!(pbkdf2(password, &params_zero_salt).is_err());
        let params_low_iter = Pbkdf2Params::with_salt(salt).iterations(500).key_length(32);
        assert!(pbkdf2(password, &params_low_iter).is_err());
        let params_high_iter =
            Pbkdf2Params::with_salt(salt).iterations(10_000_001).key_length(32);
        assert!(pbkdf2(password, &params_high_iter).is_err());
        let params_zero_len = Pbkdf2Params::with_salt(salt).iterations(1000).key_length(0);
        assert!(pbkdf2(password, &params_zero_len).is_err());
    }

    #[test]
    fn test_prf_types_produce_correct_key_lengths_has_correct_size() {
        let password = b"password";
        let salt = b"salt123456789012";
        let params_sha256 =
            Pbkdf2Params::with_salt(salt).iterations(1000).key_length(32).prf(PrfType::HmacSha256);
        let params_sha512 =
            Pbkdf2Params::with_salt(salt).iterations(1000).key_length(64).prf(PrfType::HmacSha512);
        let result_sha256 = pbkdf2(password, &params_sha256).unwrap();
        let result_sha512 = pbkdf2(password, &params_sha512).unwrap();
        assert_eq!(result_sha256.key.len(), 32);
        assert_eq!(result_sha512.key.len(), 64);
        assert_ne!(result_sha256.key(), &result_sha512.key()[..32]);
    }

    /// Only checks that a clone of the key outlives the original result being
    /// dropped. It does not (and cannot, without unsafe) observe that the
    /// dropped buffer was actually wiped.
    #[test]
    fn test_zeroize_on_drop_succeeds() {
        let password = b"password";
        let salt = b"salt123456789012";
        let params = Pbkdf2Params::with_salt(salt).iterations(1000).key_length(32);
        let key_bytes = {
            let result = pbkdf2(password, &params).unwrap();
            let key_copy = result.key.clone();
            drop(result);
            key_copy
        };
        assert_eq!(key_bytes.len(), 32);
    }

    #[test]
    fn test_pbkdf2_result_debug_redacts_key_succeeds() {
        let params = Pbkdf2Params::with_salt(b"salt123456789012").iterations(1000).key_length(32);
        let result = pbkdf2(b"password", &params).unwrap();
        let debug = format!("{:?}", result);
        assert!(debug.contains("[REDACTED]"));
        assert!(!debug.contains(&format!("{:?}", result.key())));
    }

    #[test]
    fn test_pbkdf2_params_new_zero_salt_length_fails() {
        let result = Pbkdf2Params::new(0);
        assert!(result.is_err());
    }

    #[test]
    fn test_pbkdf2_params_new_valid_has_correct_defaults_succeeds() {
        let params = Pbkdf2Params::new(16).unwrap();
        assert_eq!(params.salt.len(), 16);
        assert_eq!(params.iterations, 600_000);
        assert_eq!(params.key_length, 32);
        assert_eq!(params.prf, PrfType::HmacSha256);
    }

    #[test]
    fn test_pbkdf2_multi_block_sha256_has_correct_length_has_correct_size() {
        let password = b"password";
        let salt = b"salt123456789012";
        let params = Pbkdf2Params::with_salt(salt).iterations(1000).key_length(64);
        let result = pbkdf2(password, &params).unwrap();
        assert_eq!(result.key.len(), 64);
    }

    #[test]
    fn test_pbkdf2_multi_block_sha512_has_correct_length_has_correct_size() {
        let password = b"password";
        let salt = b"salt123456789012";
        let params =
            Pbkdf2Params::with_salt(salt).iterations(1000).key_length(128).prf(PrfType::HmacSha512);
        let result = pbkdf2(password, &params).unwrap();
        assert_eq!(result.key.len(), 128);
    }

    #[test]
    fn test_pbkdf2_result_key_accessor_returns_correct_value_succeeds() {
        let password = b"password";
        let salt = b"salt123456789012";
        let params = Pbkdf2Params::with_salt(salt).iterations(1000).key_length(32);
        let result = pbkdf2(password, &params).unwrap();
        assert_eq!(result.key(), &result.key[..]);
        assert_eq!(result.key().len(), 32);
        assert_eq!(result.key_length(), 32);
        assert_eq!(result.params().iterations, 1000);
        assert_eq!(result.params().salt, salt.to_vec());
    }

    #[test]
    fn test_pbkdf2_params_builder_chain_is_correct() {
        let params = Pbkdf2Params::with_salt(b"saltsaltsaltsalt")
            .iterations(5000)
            .key_length(48)
            .prf(PrfType::HmacSha512);
        assert_eq!(params.iterations, 5000);
        assert_eq!(params.key_length, 48);
        assert_eq!(params.prf, PrfType::HmacSha512);
    }

    #[test]
    fn test_prf_type_debug_clone_eq_is_correct() {
        let prf = PrfType::HmacSha256;
        let cloned = prf;
        assert_eq!(prf, cloned);
        assert_ne!(PrfType::HmacSha256, PrfType::HmacSha512);
        let debug = format!("{:?}", prf);
        assert!(debug.contains("HmacSha256"));
    }
}