use crate::error::{CryptoKitError, Result};
pub mod hkdf_sha256;
pub mod hkdf_sha384;
pub mod hkdf_sha512;
/// Hash functions that HKDF can be instantiated with in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HashAlgorithm {
    /// SHA-256 (32-byte digest).
    SHA256,
    /// SHA-384 (48-byte digest).
    SHA384,
    /// SHA-512 (64-byte digest).
    SHA512,
}

impl HashAlgorithm {
    /// Returns the digest size of this hash function, in bytes.
    pub fn output_length(&self) -> usize {
        match *self {
            Self::SHA256 => 32,
            Self::SHA384 => 48,
            Self::SHA512 => 64,
        }
    }

    /// Returns the largest output length HKDF may be asked for with this
    /// hash: 255 times the digest size (the limit stated in RFC 5869).
    pub fn max_hkdf_output_length(&self) -> usize {
        let digest_len = self.output_length();
        digest_len * 255
    }
}
/// Interface for a key-derivation function exposed as an associated function.
///
/// `derive` takes no `self`, so it is invoked on the implementing type itself
/// (`Type::derive(..)`) rather than on a value.
// NOTE(review): presumably implemented by the `HKDF_SHA256` / `HKDF_SHA384` /
// `HKDF_SHA512` types in the sub-modules; those impls are not visible here — confirm.
pub trait KeyDerivationFunction {
/// Derives key material from `input_key_material`, `salt` and `info`.
///
/// Implementations are expected to return `output_length` bytes on success;
/// the exact validation rules and error variants are up to each implementor
/// (not visible from this declaration — verify against the impls).
fn derive(
input_key_material: &[u8],
salt: &[u8],
info: &[u8],
output_length: usize,
) -> Result<Vec<u8>>;
}
/// Stateless namespace type grouping the HKDF entry points of this module.
pub struct HKDF;

impl HKDF {
    /// Validates the request and dispatches to the HKDF implementation that
    /// matches `algorithm`.
    ///
    /// # Errors
    ///
    /// Checks are performed in this order:
    /// * `CryptoKitError::InvalidInput` if `input_key_material` is empty.
    /// * `CryptoKitError::InvalidLength` if `output_length` is zero or larger
    ///   than `algorithm.max_hkdf_output_length()`.
    ///
    /// Any error produced by the selected implementation is returned as-is.
    pub fn derive_key(
        algorithm: HashAlgorithm,
        input_key_material: &[u8],
        salt: &[u8],
        info: &[u8],
        output_length: usize,
    ) -> Result<Vec<u8>> {
        if input_key_material.is_empty() {
            let message = "Input key material cannot be empty".to_string();
            return Err(CryptoKitError::InvalidInput(message));
        }

        // Accept only lengths in 1..=max for the chosen hash.
        let longest_allowed = algorithm.max_hkdf_output_length();
        if !(1..=longest_allowed).contains(&output_length) {
            return Err(CryptoKitError::InvalidLength);
        }

        match algorithm {
            HashAlgorithm::SHA256 => {
                hkdf_sha256::HKDF_SHA256::derive(input_key_material, salt, info, output_length)
            }
            HashAlgorithm::SHA384 => {
                hkdf_sha384::HKDF_SHA384::derive(input_key_material, salt, info, output_length)
            }
            HashAlgorithm::SHA512 => {
                hkdf_sha512::HKDF_SHA512::derive(input_key_material, salt, info, output_length)
            }
        }
    }

    /// Shorthand for [`HKDF::derive_key`] with `HashAlgorithm::SHA256`.
    pub fn derive_key_sha256(
        input_key_material: &[u8],
        salt: &[u8],
        info: &[u8],
        output_length: usize,
    ) -> Result<Vec<u8>> {
        let algorithm = HashAlgorithm::SHA256;
        Self::derive_key(algorithm, input_key_material, salt, info, output_length)
    }

    /// Shorthand for [`HKDF::derive_key`] with `HashAlgorithm::SHA384`.
    pub fn derive_key_sha384(
        input_key_material: &[u8],
        salt: &[u8],
        info: &[u8],
        output_length: usize,
    ) -> Result<Vec<u8>> {
        let algorithm = HashAlgorithm::SHA384;
        Self::derive_key(algorithm, input_key_material, salt, info, output_length)
    }

    /// Shorthand for [`HKDF::derive_key`] with `HashAlgorithm::SHA512`.
    pub fn derive_key_sha512(
        input_key_material: &[u8],
        salt: &[u8],
        info: &[u8],
        output_length: usize,
    ) -> Result<Vec<u8>> {
        let algorithm = HashAlgorithm::SHA512;
        Self::derive_key(algorithm, input_key_material, salt, info, output_length)
    }

    /// Derives a symmetric key from `shared_secret`, which is passed to
    /// [`HKDF::derive_key`] as the input key material.
    ///
    /// A `salt` of `None` is treated as an empty salt. Validation and errors
    /// are exactly those of [`HKDF::derive_key`].
    pub fn derive_symmetric_key(
        shared_secret: &[u8],
        algorithm: HashAlgorithm,
        salt: Option<&[u8]>,
        info: &[u8],
        output_length: usize,
    ) -> Result<Vec<u8>> {
        // The default `&[u8]` is the empty slice.
        let salt_bytes = salt.unwrap_or_default();
        Self::derive_key(algorithm, shared_secret, salt_bytes, info, output_length)
    }
}
pub use hkdf_sha256::HKDF_SHA256;
pub use hkdf_sha384::HKDF_SHA384;
pub use hkdf_sha512::HKDF_SHA512;
/// Module-level shortcut for HKDF with SHA-256.
///
/// Forwards all four arguments unchanged to `hkdf_sha256::hkdf_sha256_derive`
/// and returns its result. This wrapper performs no validation of its own;
/// any checks and the errors returned come from the delegate.
pub fn hkdf_sha256_derive(
input_key_material: &[u8],
salt: &[u8],
info: &[u8],
output_length: usize,
) -> Result<Vec<u8>> {
hkdf_sha256::hkdf_sha256_derive(input_key_material, salt, info, output_length)
}
/// Module-level shortcut for HKDF with SHA-384.
///
/// Forwards all four arguments unchanged to `hkdf_sha384::hkdf_sha384_derive`
/// and returns its result. This wrapper performs no validation of its own;
/// any checks and the errors returned come from the delegate.
pub fn hkdf_sha384_derive(
input_key_material: &[u8],
salt: &[u8],
info: &[u8],
output_length: usize,
) -> Result<Vec<u8>> {
hkdf_sha384::hkdf_sha384_derive(input_key_material, salt, info, output_length)
}
/// Module-level shortcut for HKDF with SHA-512.
///
/// Forwards all four arguments unchanged to `hkdf_sha512::hkdf_sha512_derive`
/// and returns its result. This wrapper performs no validation of its own;
/// any checks and the errors returned come from the delegate.
pub fn hkdf_sha512_derive(
input_key_material: &[u8],
salt: &[u8],
info: &[u8],
output_length: usize,
) -> Result<Vec<u8>> {
hkdf_sha512::hkdf_sha512_derive(input_key_material, salt, info, output_length)
}
/// Named key lengths, in bytes.
// NOTE(review): presumably meant to be passed as `output_length` to the
// derivation functions in this module; nothing in this file uses them — confirm.
pub mod key_sizes {
/// AES-128 key length: 16 bytes (128 bits).
pub const AES_128: usize = 16;
/// AES-192 key length: 24 bytes (192 bits).
pub const AES_192: usize = 24;
/// AES-256 key length: 32 bytes (256 bits).
pub const AES_256: usize = 32;
/// ChaCha20 key length: 32 bytes (256 bits).
pub const CHACHA20: usize = 32;
/// HMAC-SHA256 key length: 32 bytes, the same as the SHA-256 digest size.
pub const HMAC_SHA256: usize = 32;
/// HMAC-SHA384 key length: 48 bytes, the same as the SHA-384 digest size.
pub const HMAC_SHA384: usize = 48;
/// HMAC-SHA512 key length: 64 bytes, the same as the SHA-512 digest size.
pub const HMAC_SHA512: usize = 64;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Digest sizes and the derived HKDF output ceiling for every algorithm.
    #[test]
    fn test_hash_algorithm_properties() {
        let expected = [
            (HashAlgorithm::SHA256, 32),
            (HashAlgorithm::SHA384, 48),
            (HashAlgorithm::SHA512, 64),
        ];

        for &(algorithm, digest_len) in expected.iter() {
            assert_eq!(algorithm.output_length(), digest_len);
            assert_eq!(algorithm.max_hkdf_output_length(), 255 * digest_len);
        }
    }

    /// `derive_key` must reject empty key material and out-of-range lengths
    /// before reaching any hash-specific implementation.
    #[test]
    fn test_invalid_input_validation() {
        let algorithm = HashAlgorithm::SHA256;

        // Empty input key material.
        let empty_ikm = HKDF::derive_key(algorithm, &[], b"salt", b"info", 32);
        assert!(matches!(empty_ikm, Err(CryptoKitError::InvalidInput(_))));

        // Zero-length output.
        let zero_length = HKDF::derive_key(algorithm, b"ikm", b"salt", b"info", 0);
        assert!(matches!(zero_length, Err(CryptoKitError::InvalidLength)));

        // One byte past the SHA-256 ceiling.
        let over_limit = 255 * 32 + 1;
        let too_long = HKDF::derive_key(algorithm, b"ikm", b"salt", b"info", over_limit);
        assert!(matches!(too_long, Err(CryptoKitError::InvalidLength)));
    }
}