#![deny(unsafe_code)]
#![deny(missing_docs)]
#![deny(clippy::unwrap_used)]
#![deny(clippy::panic)]
use crate::prelude::error::{LatticeArcError, Result};
use aws_lc_rs::hmac::{self, HMAC_SHA256};
use subtle::ConstantTimeEq;
use zeroize::Zeroizing;
/// Key material produced by [`counter_kdf`].
///
/// The bytes are only reachable through [`CounterKdfResult::key`]; the
/// `Debug` implementation for this type prints a placeholder instead of
/// the key itself.
pub struct CounterKdfResult {
    /// Derived key bytes, held in a `Zeroizing` wrapper so the buffer is
    /// wiped when the value is dropped.
    key: Zeroizing<Vec<u8>>,
}
/// Debug formatting that reports only the key length, never the key bytes.
impl std::fmt::Debug for CounterKdfResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let key_length = self.key.len();
        let mut out = f.debug_struct("CounterKdfResult");
        // The secret is replaced by a fixed marker so it cannot leak via logs.
        out.field("key", &"[REDACTED]");
        out.field("key_length", &key_length);
        out.finish()
    }
}
/// Constant-time comparison of two derived keys, delegating to the
/// `ConstantTimeEq` implementation for byte slices.
impl ConstantTimeEq for CounterKdfResult {
    fn ct_eq(&self, other: &Self) -> subtle::Choice {
        let lhs: &[u8] = &self.key;
        let rhs: &[u8] = &other.key;
        lhs.ct_eq(rhs)
    }
}
impl CounterKdfResult {
    /// Borrows the derived key bytes.
    #[must_use]
    pub fn key(&self) -> &[u8] {
        self.key.as_slice()
    }

    /// Returns the number of derived key bytes.
    #[must_use]
    pub fn key_length(&self) -> usize {
        self.key().len()
    }
}
/// Label and context inputs for [`counter_kdf`].
///
/// Both values are mixed into every PRF invocation, so changing either one
/// yields a different derived key from the same keying material.
#[derive(Clone)]
pub struct CounterKdfParams {
    /// Label bytes; written into the PRF input before the `0x00` separator.
    pub label: Vec<u8>,
    /// Context bytes; written into the PRF input after the `0x00` separator.
    /// May be empty.
    pub context: Vec<u8>,
}
/// Debug formatting that shows only the sizes of the label and context,
/// not their contents.
impl std::fmt::Debug for CounterKdfParams {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label_summary = format!("[{} bytes]", self.label.len());
        let context_summary = format!("[{} bytes]", self.context.len());
        let mut out = f.debug_struct("CounterKdfParams");
        out.field("label", &label_summary);
        out.field("context", &context_summary);
        out.finish()
    }
}
/// Default parameters: the label `Default KDF Label` with an empty context.
impl Default for CounterKdfParams {
    fn default() -> Self {
        Self::new(b"Default KDF Label")
    }
}
impl CounterKdfParams {
    /// Creates parameters with the given `label` and an empty context.
    #[must_use]
    pub fn new(label: &[u8]) -> Self {
        Self { label: Vec::from(label), context: Vec::new() }
    }

    /// Returns these parameters with the context replaced by `context`.
    #[must_use]
    pub fn with_context(self, context: &[u8]) -> Self {
        Self { context: Vec::from(context), ..self }
    }

    /// Parameters labelled `Encryption Key`, with an empty context.
    #[must_use]
    pub fn for_encryption() -> Self {
        Self { label: b"Encryption Key".to_vec(), context: Vec::new() }
    }

    /// Parameters labelled `MAC Key`, with an empty context.
    #[must_use]
    pub fn for_mac() -> Self {
        Self { label: b"MAC Key".to_vec(), context: Vec::new() }
    }

    /// Parameters labelled `IV Generation`, with an empty context.
    #[must_use]
    pub fn for_iv() -> Self {
        Self { label: b"IV Generation".to_vec(), context: Vec::new() }
    }
}
/// Derives `key_length` bytes from the keying material `ki` using
/// HMAC-SHA256 in counter mode.
///
/// For every counter value `i = 1, 2, ..., ceil(key_length / 32)` one block
/// is computed as
///
/// ```text
/// HMAC-SHA256(ki, i || label || 0x00 || context || L)
/// ```
///
/// where `i` and `L` (the requested output length in bits) are encoded as
/// 32-bit big-endian integers. The blocks are concatenated and the final
/// block is truncated so that exactly `key_length` bytes are returned. This
/// input layout follows the counter-mode construction of NIST SP 800-108.
///
/// Because `L` is part of every PRF input, requesting a different
/// `key_length` yields unrelated output rather than a prefix or extension of
/// a shorter request.
///
/// # Errors
///
/// Returns [`LatticeArcError::InvalidParameter`] when
/// * `ki` is empty,
/// * `key_length` is zero, or
/// * `key_length` is so large that its bit length does not fit in a `u32`
///   (or the block count does not fit in a `u32`).
pub fn counter_kdf(
    ki: &[u8],
    params: &CounterKdfParams,
    key_length: usize,
) -> Result<CounterKdfResult> {
    // Output size of HMAC-SHA256 in bytes.
    const HASH_LEN: usize = 32;

    if ki.is_empty() {
        return Err(LatticeArcError::InvalidParameter(
            "Keying material must not be empty".to_string(),
        ));
    }
    if key_length == 0 {
        return Err(LatticeArcError::InvalidParameter(
            "Key length must be greater than 0".to_string(),
        ));
    }

    // 2^32 blocks of HASH_LEN bytes. The product is a compile-time-sized
    // constant (2^37) and cannot overflow a u64, hence the lint exemption.
    #[allow(clippy::arithmetic_side_effects)]
    let max_len = (1u64 << 32) * HASH_LEN as u64;
    if key_length > usize::try_from(max_len).unwrap_or(usize::MAX) {
        return Err(LatticeArcError::InvalidParameter(format!(
            "Key length {} exceeds maximum of {}",
            key_length, max_len
        )));
    }

    // `checked_mul` rather than `saturating_mul`: on targets where `usize`
    // is 32 bits wide a saturated product would still fit in a `u32` and
    // silently encode the wrong `L` instead of being rejected.
    let l_bits = key_length
        .checked_mul(8)
        .and_then(|bits| u32::try_from(bits).ok())
        .ok_or_else(|| {
            LatticeArcError::InvalidParameter(
                "Key length too large for bit representation".to_string(),
            )
        })?;

    let iterations_u32 = u32::try_from(key_length.div_ceil(HASH_LEN)).map_err(|_e| {
        LatticeArcError::InvalidParameter("Too many KDF iterations required".to_string())
    })?;

    // All validation is complete; only now allocate the output buffer.
    let hmac_key = hmac::Key::new(HMAC_SHA256, ki);
    let l_bytes = l_bits.to_be_bytes();
    let mut derived_key: Zeroizing<Vec<u8>> = Zeroizing::new(vec![0u8; key_length]);

    // `chunks_mut` yields exactly `iterations_u32` blocks; only the last one
    // may be shorter than HASH_LEN.
    for (counter, block) in (1..=iterations_u32).zip(derived_key.chunks_mut(HASH_LEN)) {
        // Stream the PRF input into the HMAC context instead of assembling
        // it in a temporary buffer on every iteration.
        let mut ctx = hmac::Context::with_key(&hmac_key);
        ctx.update(&counter.to_be_bytes());
        ctx.update(&params.label);
        ctx.update(&[0x00]);
        ctx.update(&params.context);
        ctx.update(&l_bytes);
        let tag = ctx.sign();

        // Copy straight from the tag into the output; an intermediate heap
        // copy would only add one more place where key bytes reside.
        let prf_output = tag.as_ref().get(..block.len()).ok_or_else(|| {
            LatticeArcError::InvalidParameter("KDF source slice out of bounds".to_string())
        })?;
        block.copy_from_slice(prf_output);
    }

    Ok(CounterKdfResult { key: derived_key })
}
/// Derives one key per entry of `key_specs` from the same keying material.
///
/// Each entry is a `(label, length)` pair. Every key is derived with
/// [`counter_kdf`] using that entry's label, the shared `context`, and the
/// requested length in bytes. The returned keys are in the same order as
/// `key_specs`.
///
/// # Errors
///
/// Returns the first error reported by [`counter_kdf`]; no keys are returned
/// in that case.
pub fn derive_multiple_keys(
    ki: &[u8],
    context: &[u8],
    key_specs: &[(&[u8], usize)],
) -> Result<Vec<CounterKdfResult>> {
    key_specs
        .iter()
        .map(|&(label, length)| {
            let params = CounterKdfParams::new(label).with_context(context);
            counter_kdf(ki, &params, length)
        })
        .collect()
}
/// Derives a 32-byte key using the `Encryption Key` label and the given
/// `context`.
///
/// # Errors
///
/// Propagates any error from [`counter_kdf`], e.g. when `ki` is empty.
pub fn derive_encryption_key(ki: &[u8], context: &[u8]) -> Result<CounterKdfResult> {
    let params = CounterKdfParams::for_encryption().with_context(context);
    counter_kdf(ki, &params, 32)
}
/// Derives a 32-byte key using the `MAC Key` label and the given `context`.
///
/// # Errors
///
/// Propagates any error from [`counter_kdf`], e.g. when `ki` is empty.
pub fn derive_mac_key(ki: &[u8], context: &[u8]) -> Result<CounterKdfResult> {
    let params = CounterKdfParams::for_mac().with_context(context);
    counter_kdf(ki, &params, 32)
}
/// Derives 16 bytes using the `IV Generation` label and the given `context`.
///
/// # Errors
///
/// Propagates any error from [`counter_kdf`], e.g. when `ki` is empty.
pub fn derive_iv(ki: &[u8], context: &[u8]) -> Result<CounterKdfResult> {
    let params = CounterKdfParams::for_iv().with_context(context);
    counter_kdf(ki, &params, 16)
}
#[cfg(test)]
// Tests may unwrap: a failed derivation should abort the test immediately.
#[allow(clippy::unwrap_used)]
// Tests index into fixed-size results whose lengths are asserted.
#[allow(clippy::indexing_slicing)]
mod tests {
    use super::*;

    #[test]
    fn test_counter_kdf_basic_succeeds() {
        let ki = b"0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b";
        let params = CounterKdfParams::new(b"Example Label");
        let result = counter_kdf(ki.as_ref(), &params, 32).unwrap();
        assert_eq!(result.key.len(), 32);
        assert_eq!(result.key_length(), 32);
    }

    #[test]
    fn test_counter_kdf_deterministic_succeeds() {
        let ki = b"test keying material";
        let params = CounterKdfParams::new(b"Test Label");
        let result1 = counter_kdf(ki, &params, 32).unwrap();
        let result2 = counter_kdf(ki, &params, 32).unwrap();
        assert_eq!(result1.key, result2.key);
    }

    /// Pins the exact PRF input layout for a single-block derivation:
    /// counter || label || 0x00 || context || L, with 32-bit big-endian
    /// counter and bit length.
    #[test]
    fn test_counter_kdf_single_block_matches_manual_hmac_succeeds() {
        let ki = b"test keying material";
        let params = CounterKdfParams::new(b"Label").with_context(b"ctx");
        let result = counter_kdf(ki, &params, 32).unwrap();

        let mut input = Vec::new();
        input.extend_from_slice(&1u32.to_be_bytes());
        input.extend_from_slice(b"Label");
        input.push(0x00);
        input.extend_from_slice(b"ctx");
        input.extend_from_slice(&256u32.to_be_bytes());
        let expected = hmac::sign(&hmac::Key::new(HMAC_SHA256, ki), &input);

        assert_eq!(result.key(), expected.as_ref());
    }

    #[test]
    fn test_counter_kdf_different_labels_produce_distinct_keys_are_unique() {
        let ki = b"test keying material";
        let params1 = CounterKdfParams::new(b"Label 1");
        let params2 = CounterKdfParams::new(b"Label 2");
        let result1 = counter_kdf(ki, &params1, 32).unwrap();
        let result2 = counter_kdf(ki, &params2, 32).unwrap();
        assert_ne!(result1.key, result2.key);
    }

    #[test]
    fn test_counter_kdf_different_contexts_produce_distinct_keys_are_unique() {
        let ki = b"test keying material";
        let params1 = CounterKdfParams::new(b"Label").with_context(b"Context 1");
        let params2 = CounterKdfParams::new(b"Label").with_context(b"Context 2");
        let result1 = counter_kdf(ki, &params1, 32).unwrap();
        let result2 = counter_kdf(ki, &params2, 32).unwrap();
        assert_ne!(result1.key, result2.key);
    }

    #[test]
    fn test_counter_kdf_different_lengths_produce_correct_sizes_has_correct_size() {
        let ki = b"test keying material";
        let params = CounterKdfParams::new(b"Label");
        let result16 = counter_kdf(ki, &params, 16).unwrap();
        let result32 = counter_kdf(ki, &params, 32).unwrap();
        let result64 = counter_kdf(ki, &params, 64).unwrap();
        assert_eq!(result16.key.len(), 16);
        assert_eq!(result32.key.len(), 32);
        assert_eq!(result64.key.len(), 64);
    }

    #[test]
    fn test_counter_kdf_different_ki_produce_distinct_keys_are_unique() {
        let params = CounterKdfParams::new(b"Label");
        let result1 = counter_kdf(b"ki1", &params, 32).unwrap();
        let result2 = counter_kdf(b"ki2", &params, 32).unwrap();
        assert_ne!(result1.key, result2.key);
    }

    #[test]
    fn test_counter_kdf_with_context_succeeds() {
        let ki = b"test keying material";
        let params_with_context = CounterKdfParams::new(b"Label").with_context(b"My Context");
        let params_without_context = CounterKdfParams::new(b"Label");
        let result1 = counter_kdf(ki, &params_with_context, 32).unwrap();
        let result2 = counter_kdf(ki, &params_without_context, 32).unwrap();
        assert_ne!(result1.key, result2.key);
    }

    #[test]
    fn test_counter_kdf_validation_rejects_empty_ki_and_zero_length_fails() {
        let params = CounterKdfParams::new(b"Label");
        assert!(counter_kdf(b"", &params, 32).is_err());
        assert!(counter_kdf(b"ki", &params, 0).is_err());
        // Smallest valid length, exact block, block + 1, and two blocks.
        assert!(counter_kdf(b"ki", &params, 1).is_ok());
        assert!(counter_kdf(b"ki", &params, 32).is_ok());
        assert!(counter_kdf(b"ki", &params, 33).is_ok());
        assert!(counter_kdf(b"ki", &params, 64).is_ok());
    }

    #[test]
    fn test_derive_multiple_keys_succeeds_with_distinct_outputs_are_unique() {
        let ki = b"master secret";
        let context = b"my-app-v1";
        let key_specs =
            vec![("encryption".as_bytes(), 32), ("mac".as_bytes(), 32), ("iv".as_bytes(), 16)];
        let keys = derive_multiple_keys(ki, context, &key_specs).unwrap();
        assert_eq!(keys.len(), 3);
        assert_eq!(keys[0].key.len(), 32);
        assert_eq!(keys[1].key.len(), 32);
        assert_eq!(keys[2].key.len(), 16);
        assert_ne!(keys[0].key, keys[1].key);
        assert_ne!(keys[1].key, keys[2].key);
        assert_ne!(keys[0].key, keys[2].key);
    }

    #[test]
    fn test_derive_encryption_key_returns_32_bytes_succeeds() {
        let ki = b"master secret";
        let context = b"my-app-v1";
        let key = derive_encryption_key(ki, context).unwrap();
        assert_eq!(key.key().len(), 32);
        assert_eq!(key.key_length(), 32);
    }

    #[test]
    fn test_derive_mac_key_returns_32_bytes_succeeds() {
        let ki = b"master secret";
        let context = b"my-app-v1";
        let key = derive_mac_key(ki, context).unwrap();
        assert_eq!(key.key().len(), 32);
        assert_eq!(key.key_length(), 32);
    }

    #[test]
    fn test_derive_iv_returns_16_bytes_succeeds() {
        let ki = b"master secret";
        let context = b"my-app-v1";
        let iv = derive_iv(ki, context).unwrap();
        assert_eq!(iv.key().len(), 16);
        assert_eq!(iv.key_length(), 16);
    }

    #[test]
    fn test_convenience_functions_are_unique() {
        let ki = b"master secret";
        let context = b"my-app-v1";
        let enc_key = derive_encryption_key(ki, context).unwrap();
        let mac_key = derive_mac_key(ki, context).unwrap();
        let iv = derive_iv(ki, context).unwrap();
        assert_ne!(enc_key.key, mac_key.key);
        // Compare equal-length prefixes; comparing slices of different
        // lengths would pass trivially.
        assert_eq!(iv.key().len(), 16);
        assert_ne!(&mac_key.key()[..16], iv.key());
        assert_ne!(&enc_key.key()[..16], iv.key());
    }

    #[test]
    fn test_counter_kdf_empty_label_succeeds() {
        let ki = b"test keying material";
        let params = CounterKdfParams::new(b"");
        let result = counter_kdf(ki, &params, 32).unwrap();
        assert_eq!(result.key.len(), 32);
    }

    #[test]
    fn test_counter_kdf_empty_context_matches_no_context_succeeds() {
        let ki = b"test keying material";
        let params = CounterKdfParams::new(b"Label").with_context(b"");
        let result1 = counter_kdf(ki, &params, 32).unwrap();
        let params2 = CounterKdfParams::new(b"Label");
        let result2 = counter_kdf(ki, &params2, 32).unwrap();
        assert_eq!(result1.key, result2.key);
    }

    #[test]
    fn test_counter_kdf_long_inputs_succeeds() {
        let ki = vec![0u8; 256];
        let label = vec![b'A'; 256];
        let context = vec![0xFF; 256];
        let params = CounterKdfParams::new(&label).with_context(&context);
        let result = counter_kdf(&ki, &params, 32).unwrap();
        assert_eq!(result.key.len(), 32);
    }

    // Only checks that a clone of the key stays usable after the original
    // result is dropped; the wiped memory itself is not inspected.
    #[test]
    fn test_counter_kdf_result_zeroize_on_drop_clears_memory_succeeds() {
        let ki = b"test keying material";
        let params = CounterKdfParams::new(b"Label");
        let key_bytes = {
            let result = counter_kdf(ki, &params, 32).unwrap();
            let key_copy = result.key.clone();
            drop(result);
            key_copy
        };
        assert_eq!(key_bytes.len(), 32);
    }

    #[test]
    fn test_default_params_has_expected_label_succeeds() {
        let ki = b"test keying material";
        let params = CounterKdfParams::default();
        assert_eq!(params.label, b"Default KDF Label");
        assert!(params.context.is_empty());
        let result = counter_kdf(ki, &params, 32).unwrap();
        assert_eq!(result.key.len(), 32);
    }

    #[test]
    fn test_counter_kdf_large_output_succeeds() {
        let ki = b"test keying material";
        let params = CounterKdfParams::new(b"Label");
        let result = counter_kdf(ki, &params, 100).unwrap();
        assert_eq!(result.key.len(), 100);
    }

    #[test]
    fn test_counter_kdf_result_key_accessor_returns_full_key_succeeds() {
        let ki = b"test keying material";
        let params = CounterKdfParams::new(b"Label");
        let result = counter_kdf(ki, &params, 32).unwrap();
        assert_eq!(result.key(), &result.key[..]);
        assert_eq!(result.key().len(), 32);
    }

    #[test]
    fn test_counter_kdf_result_debug_redacts_key_succeeds() {
        let ki = b"test keying material";
        let params = CounterKdfParams::new(b"Label");
        let result = counter_kdf(ki, &params, 32).unwrap();
        let debug = format!("{:?}", result);
        assert!(debug.contains("CounterKdfResult"));
        assert!(debug.contains("[REDACTED]"));
        assert!(debug.contains("key_length: 32"));
    }

    #[test]
    fn test_counter_kdf_result_ct_eq_matches_equality_succeeds() {
        let ki = b"test keying material";
        let params = CounterKdfParams::new(b"Label");
        let other_params = CounterKdfParams::new(b"Other Label");
        let a = counter_kdf(ki, &params, 32).unwrap();
        let b = counter_kdf(ki, &params, 32).unwrap();
        let c = counter_kdf(ki, &other_params, 32).unwrap();
        assert!(bool::from(a.ct_eq(&b)));
        assert!(!bool::from(a.ct_eq(&c)));
    }

    #[test]
    fn test_counter_kdf_params_for_encryption_has_correct_label_succeeds() {
        let params = CounterKdfParams::for_encryption();
        assert_eq!(params.label, b"Encryption Key");
        assert!(params.context.is_empty());
    }

    #[test]
    fn test_counter_kdf_params_for_mac_has_correct_label_succeeds() {
        let params = CounterKdfParams::for_mac();
        assert_eq!(params.label, b"MAC Key");
    }

    #[test]
    fn test_counter_kdf_params_for_iv_has_correct_label_succeeds() {
        let params = CounterKdfParams::for_iv();
        assert_eq!(params.label, b"IV Generation");
    }

    #[test]
    fn test_counter_kdf_params_debug_produces_redacted_output_succeeds() {
        let params = CounterKdfParams::new(b"Test").with_context(b"ctx");
        let debug = format!("{:?}", params);
        assert!(debug.contains("CounterKdfParams"));
        // Only the sizes of the label and context are reported.
        assert!(debug.contains("[4 bytes]"));
        assert!(debug.contains("[3 bytes]"));
    }

    // Exercises the accessors before an explicit drop; the wiped memory
    // itself is not inspected.
    #[test]
    fn test_counter_kdf_result_zeroize_clears_on_drop_succeeds() {
        let ki = b"test keying material";
        let params = CounterKdfParams::new(b"Label");
        let result = counter_kdf(ki, &params, 32).unwrap();
        let key_copy = result.key().to_vec();
        assert_eq!(key_copy, result.key());
        assert_eq!(result.key_length(), 32);
        drop(result);
    }

    #[test]
    fn test_counter_kdf_exact_hash_boundary_succeeds() {
        let ki = b"test keying material";
        let params = CounterKdfParams::new(b"Label");
        let result = counter_kdf(ki, &params, 32).unwrap();
        assert_eq!(result.key.len(), 32);
    }

    #[test]
    fn test_counter_kdf_one_byte_over_boundary_succeeds() {
        let ki = b"test keying material";
        let params = CounterKdfParams::new(b"Label");
        let result = counter_kdf(ki, &params, 33).unwrap();
        assert_eq!(result.key.len(), 33);
    }
}