use alloc::vec::Vec;
use hkdf::Hkdf;
use sha2::{Sha256, Sha512};
use sha3::Sha3_256;
use crate::error::CryptoError;
/// Key-derivation algorithms offered by this module.
///
/// Every variant is HKDF instantiated over a different hash function; the
/// variant name encodes the hash family and digest size.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum KdfAlgorithm {
    /// HKDF over SHA-256 (SHA-2 family, 32-byte digest).
    HkdfSha2_256,
    /// HKDF over SHA-512 (SHA-2 family, 64-byte digest).
    HkdfSha2_512,
    /// HKDF over SHA3-256 (SHA-3 family, 32-byte digest).
    HkdfSha3_256,
}
impl KdfAlgorithm {
    /// Returns the stable textual identifier of this algorithm, e.g. `"hkdf-sha2-256"`.
    pub fn name(&self) -> &'static str {
        match self {
            KdfAlgorithm::HkdfSha2_256 => "hkdf-sha2-256",
            KdfAlgorithm::HkdfSha2_512 => "hkdf-sha2-512",
            KdfAlgorithm::HkdfSha3_256 => "hkdf-sha3-256",
        }
    }

    /// Returns the largest number of output bytes one derivation may produce.
    ///
    /// This is `255 * HashLen`, the RFC 5869 upper bound on HKDF-Expand output:
    /// 8160 bytes for the 32-byte-digest hashes and 16320 bytes for SHA-512.
    pub fn max_output_length(&self) -> usize {
        match self {
            KdfAlgorithm::HkdfSha2_256 => 255 * 32,
            KdfAlgorithm::HkdfSha2_512 => 255 * 64,
            KdfAlgorithm::HkdfSha3_256 => 255 * 32,
        }
    }

    /// Derives `output_length` bytes with full HKDF (extract, then expand).
    ///
    /// * `ikm` - input keying material.
    /// * `salt` - optional salt, forwarded unchanged to `Hkdf::new`.
    /// * `info` - context/application-specific information bound into the output.
    ///
    /// # Errors
    ///
    /// Returns [`CryptoError::InvalidLength`] if `output_length` exceeds
    /// [`max_output_length`](Self::max_output_length). Errors reported by the
    /// `hkdf` crate are converted into [`CryptoError`] via `?`.
    pub fn derive(
        &self,
        ikm: impl AsRef<[u8]>,
        salt: Option<&[u8]>,
        info: impl AsRef<[u8]>,
        output_length: usize,
    ) -> Result<Vec<u8>, CryptoError> {
        // Validate before allocating so an absurd length is rejected with an
        // error instead of attempting a huge allocation.
        self.check_output_length(output_length)?;
        let mut okm = vec![0u8; output_length];
        self.derive_into(ikm.as_ref(), salt, info.as_ref(), &mut okm)?;
        Ok(okm)
    }

    /// Runs only the HKDF expand step, using `prk` as an already-extracted
    /// pseudorandom key, and returns `output_length` bytes.
    ///
    /// # Errors
    ///
    /// Returns [`CryptoError::InvalidLength`] if `output_length` exceeds
    /// [`max_output_length`](Self::max_output_length). If `Hkdf::from_prk`
    /// rejects `prk`, or expansion fails, the `hkdf` error is converted into
    /// [`CryptoError`] via `?`.
    pub fn expand_only(
        &self,
        prk: impl AsRef<[u8]>,
        info: impl AsRef<[u8]>,
        output_length: usize,
    ) -> Result<Vec<u8>, CryptoError> {
        self.check_output_length(output_length)?;
        let mut okm = vec![0u8; output_length];
        self.expand_into(prk.as_ref(), info.as_ref(), &mut okm)?;
        Ok(okm)
    }

    /// Like [`derive`](Self::derive), but produces exactly `N` bytes as a fixed-size array.
    ///
    /// The output is written straight into the returned array. No intermediate
    /// heap buffer holding a copy of the key material is created.
    ///
    /// # Errors
    ///
    /// Same as [`derive`](Self::derive) with `output_length == N`.
    pub fn derive_array<const N: usize>(
        &self,
        ikm: impl AsRef<[u8]>,
        salt: Option<&[u8]>,
        info: impl AsRef<[u8]>,
    ) -> Result<[u8; N], CryptoError> {
        self.check_output_length(N)?;
        let mut okm = [0u8; N];
        self.derive_into(ikm.as_ref(), salt, info.as_ref(), &mut okm)?;
        Ok(okm)
    }

    /// Like [`expand_only`](Self::expand_only), but produces exactly `N` bytes as a
    /// fixed-size array written in place (no intermediate heap copy).
    ///
    /// # Errors
    ///
    /// Same as [`expand_only`](Self::expand_only) with `output_length == N`.
    pub fn expand_only_array<const N: usize>(
        &self,
        prk: impl AsRef<[u8]>,
        info: impl AsRef<[u8]>,
    ) -> Result<[u8; N], CryptoError> {
        self.check_output_length(N)?;
        let mut okm = [0u8; N];
        self.expand_into(prk.as_ref(), info.as_ref(), &mut okm)?;
        Ok(okm)
    }

    /// Rejects output lengths above [`max_output_length`](Self::max_output_length).
    fn check_output_length(&self, output_length: usize) -> Result<(), CryptoError> {
        let max = self.max_output_length();
        if output_length > max {
            return Err(CryptoError::InvalidLength {
                message: format!("Output length {} exceeds maximum {}", output_length, max),
            });
        }
        Ok(())
    }

    /// Builds an HKDF instance from (`salt`, `ikm`) and expands `info` into all of `okm`.
    fn derive_into(
        &self,
        ikm: &[u8],
        salt: Option<&[u8]>,
        info: &[u8],
        okm: &mut [u8],
    ) -> Result<(), CryptoError> {
        match self {
            KdfAlgorithm::HkdfSha2_256 => Hkdf::<Sha256>::new(salt, ikm).expand(info, okm)?,
            KdfAlgorithm::HkdfSha2_512 => Hkdf::<Sha512>::new(salt, ikm).expand(info, okm)?,
            KdfAlgorithm::HkdfSha3_256 => Hkdf::<Sha3_256>::new(salt, ikm).expand(info, okm)?,
        }
        Ok(())
    }

    /// Builds an HKDF instance directly from `prk` and expands `info` into all of `okm`.
    fn expand_into(&self, prk: &[u8], info: &[u8], okm: &mut [u8]) -> Result<(), CryptoError> {
        match self {
            KdfAlgorithm::HkdfSha2_256 => Hkdf::<Sha256>::from_prk(prk)?.expand(info, okm)?,
            KdfAlgorithm::HkdfSha2_512 => Hkdf::<Sha512>::from_prk(prk)?.expand(info, okm)?,
            KdfAlgorithm::HkdfSha3_256 => Hkdf::<Sha3_256>::from_prk(prk)?.expand(info, okm)?,
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const ALL_ALGORITHMS: [KdfAlgorithm; 3] =
        [KdfAlgorithm::HkdfSha2_256, KdfAlgorithm::HkdfSha2_512, KdfAlgorithm::HkdfSha3_256];

    /// Decodes a hex string, mapping failures onto `CryptoError::InvalidInput`
    /// so tests can propagate with `?`.
    fn unhex(s: &str) -> Result<Vec<u8>, CryptoError> {
        hex::decode(s).map_err(|_| CryptoError::InvalidInput)
    }

    #[test]
    fn test_kdf_algorithm_properties() {
        let algorithms = ALL_ALGORITHMS;
        for algo in algorithms {
            assert!(!algo.name().is_empty());
            assert!(algo.max_output_length() > 0);
            for other in algorithms {
                if algo != other {
                    assert_ne!(algo.name(), other.name());
                }
            }
        }
        assert_eq!(KdfAlgorithm::HkdfSha2_256.name(), "hkdf-sha2-256");
        assert_eq!(KdfAlgorithm::HkdfSha2_512.name(), "hkdf-sha2-512");
        assert_eq!(KdfAlgorithm::HkdfSha3_256.name(), "hkdf-sha3-256");
        // RFC 5869: L <= 255 * HashLen.
        assert_eq!(KdfAlgorithm::HkdfSha2_256.max_output_length(), 255 * 32);
        assert_eq!(KdfAlgorithm::HkdfSha2_512.max_output_length(), 255 * 64);
        assert_eq!(KdfAlgorithm::HkdfSha3_256.max_output_length(), 255 * 32);
    }

    /// RFC 5869 Appendix A.1 (Test Case 1): HKDF-SHA-256 with salt and info.
    /// Pins actual correctness, not just self-consistency.
    #[test]
    fn test_rfc5869_sha256_known_answer() -> Result<(), CryptoError> {
        let ikm = unhex(&"0b".repeat(22))?;
        let salt = unhex("000102030405060708090a0b0c")?;
        let info = unhex("f0f1f2f3f4f5f6f7f8f9")?;
        let prk = unhex("077709362c2e32df0ddc3f0dc47bba6390b6c73bb50f9c3122ec844ad7c2b3e5")?;
        let expected = unhex(
            "3cb25f25faacd57a90434f64d0362f2a2d2d0a90cf1a5a4c5db02d56ecc4c5bf34007208d5b887185865",
        )?;
        let algo = KdfAlgorithm::HkdfSha2_256;

        assert_eq!(algo.derive(&ikm, Some(salt.as_slice()), &info, 42)?, expected);
        assert_eq!(algo.expand_only(&prk, &info, 42)?, expected);

        let derived: [u8; 42] = algo.derive_array(&ikm, Some(salt.as_slice()), &info)?;
        assert_eq!(derived[..], expected[..]);
        let expanded: [u8; 42] = algo.expand_only_array(&prk, &info)?;
        assert_eq!(expanded[..], expected[..]);
        Ok(())
    }

    #[test]
    fn test_hkdf_derivation() -> Result<(), CryptoError> {
        let ikm = b"test input key material";
        let salt = Some(b"optional salt".as_slice());
        let info = b"application info";
        for algo in ALL_ALGORITHMS {
            for &length in &[16, 32, 48, 64] {
                let okm = algo.derive(ikm, salt, info, length)?;
                assert_eq!(okm.len(), length);
                if length > 16 {
                    let shorter = algo.derive(ikm, salt, info, 16)?;
                    assert_eq!(okm[..16], shorter[..]);
                }
            }
            let no_salt = algo.derive(ikm, None, info, 32)?;
            let with_salt = algo.derive(ikm, salt, info, 32)?;
            assert_ne!(no_salt, with_salt);
            let info1 = algo.derive(ikm, salt, b"info1", 32)?;
            let info2 = algo.derive(ikm, salt, b"info2", 32)?;
            assert_ne!(info1, info2);
            let ikm1 = algo.derive(b"ikm1", salt, info, 32)?;
            let ikm2 = algo.derive(b"ikm2", salt, info, 32)?;
            assert_ne!(ikm1, ikm2);
        }
        Ok(())
    }

    #[test]
    fn test_hkdf_array_derivation() -> Result<(), CryptoError> {
        let ikm = b"test input key material";
        let salt = Some(b"salt".as_slice());
        let info = b"info";
        for algo in ALL_ALGORITHMS {
            let array: [u8; 32] = algo.derive_array(ikm, salt, info)?;
            let vec_result = algo.derive(ikm, salt, info, 32)?;
            assert_eq!(array.to_vec(), vec_result);
            let array16: [u8; 16] = algo.derive_array(ikm, salt, info)?;
            let array64: [u8; 64] = algo.derive_array(ikm, salt, info)?;
            assert_eq!(array16[..], array[..16]);
            assert_eq!(array[..], array64[..32]);
            let array_diff_salt: [u8; 32] = algo.derive_array(ikm, None, info)?;
            let array_diff_info: [u8; 32] = algo.derive_array(ikm, salt, b"different")?;
            assert_ne!(array, array_diff_salt);
            assert_ne!(array, array_diff_info);
        }
        Ok(())
    }

    /// `expand_only` / `expand_only_array` had no coverage at all.
    #[test]
    fn test_expand_only() -> Result<(), CryptoError> {
        // 64 bytes is at least the digest length of every supported hash.
        let prk = [0x42u8; 64];
        let info = b"info";
        for algo in ALL_ALGORITHMS {
            let okm = algo.expand_only(prk, info, 48)?;
            assert_eq!(okm.len(), 48);

            let array: [u8; 48] = algo.expand_only_array(prk, info)?;
            assert_eq!(array[..], okm[..]);

            // Shorter requests are prefixes of longer ones.
            let shorter = algo.expand_only(prk, info, 16)?;
            assert_eq!(shorter[..], okm[..16]);

            // Different info yields different output.
            let other_info = algo.expand_only(prk, b"other", 48)?;
            assert_ne!(okm, other_info);

            // A PRK shorter than the digest length is rejected.
            assert!(algo.expand_only(&prk[..16], info, 32).is_err());

            // The output-length limit applies to expand-only as well.
            let too_long = algo.expand_only(prk, info, algo.max_output_length() + 1);
            assert!(matches!(too_long, Err(CryptoError::InvalidLength { .. })));
        }
        Ok(())
    }

    #[test]
    fn test_deterministic_derivation() -> Result<(), CryptoError> {
        let ikm = b"test input key material";
        let salt = Some(b"salt".as_slice());
        let info = b"info";
        for algo in ALL_ALGORITHMS {
            let result1 = algo.derive(ikm, salt, info, 32)?;
            let result2 = algo.derive(ikm, salt, info, 32)?;
            let result3 = algo.derive(ikm, salt, info, 32)?;
            assert_eq!(result1, result2);
            assert_eq!(result2, result3);
        }
        Ok(())
    }

    #[test]
    fn test_error_conditions() -> Result<(), CryptoError> {
        let ikm = b"test";
        for algo in ALL_ALGORITHMS {
            let max_len = algo.max_output_length();

            // Exactly at the limit is allowed.
            let at_max = algo.derive(ikm, None, b"", max_len)?;
            assert_eq!(at_max.len(), max_len);

            // One past the limit is rejected.
            let invalid_result = algo.derive(ikm, None, b"", max_len + 1);
            assert!(matches!(invalid_result, Err(CryptoError::InvalidLength { .. })));

            // Zero-length output is valid and empty.
            let zero_result = algo.derive(ikm, None, b"", 0)?;
            assert!(zero_result.is_empty());
        }
        Ok(())
    }

    /// Every pair of distinct algorithms must produce distinct output
    /// (the previous version never exercised SHA3-256 here).
    #[test]
    fn test_algorithm_differences() -> Result<(), CryptoError> {
        let ikm = b"test input key material";
        let salt = Some(b"salt".as_slice());
        let info = b"info";
        let algorithms = ALL_ALGORITHMS;
        for (i, a) in algorithms.iter().enumerate() {
            for b in &algorithms[i + 1..] {
                let out_a = a.derive(ikm, salt, info, 32)?;
                let out_b = b.derive(ikm, salt, info, 32)?;
                assert_ne!(out_a, out_b, "{} vs {}", a.name(), b.name());
            }
        }
        Ok(())
    }

    #[test]
    fn test_ecies_compatibility() -> Result<(), CryptoError> {
        let ephemeral_pk = unhex("04abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef1234567890")?;
        let shared_secret =
            unhex("1234567890abcdef1234567890abcdef1234567890abcdef1234567890abcdef12")?;
        let mut combined = Vec::with_capacity(ephemeral_pk.len() + shared_secret.len());
        combined.extend_from_slice(&ephemeral_pk);
        combined.extend_from_slice(&shared_secret);
        let derived_key = KdfAlgorithm::HkdfSha2_256.derive_array::<32>(&combined, None, b"")?;
        let derived_key2 = KdfAlgorithm::HkdfSha2_256.derive(&combined, None, b"", 32)?;
        assert_eq!(derived_key.to_vec(), derived_key2);
        assert_eq!(derived_key.len(), 32);
        Ok(())
    }
}