use crate::error::{Error, Result};
use sha2::{Digest, Sha256};
pub fn derive_key_ecdh_es(
shared_secret: &[u8],
apu: &[u8],
apv: &[u8],
key_data_len: usize,
) -> Result<Vec<u8>> {
if key_data_len == 0 || !key_data_len.is_multiple_of(8) {
return Err(Error::Cryptography(
"key_data_len must be a positive multiple of 8".to_string(),
));
}
let algorithm_id = b"ECDH-ES+A256KW";
let mut other_info = Vec::new();
other_info.extend_from_slice(&(algorithm_id.len() as u32).to_be_bytes());
other_info.extend_from_slice(algorithm_id);
other_info.extend_from_slice(&(apu.len() as u32).to_be_bytes());
other_info.extend_from_slice(apu);
other_info.extend_from_slice(&(apv.len() as u32).to_be_bytes());
other_info.extend_from_slice(apv);
other_info.extend_from_slice(&(key_data_len as u32).to_be_bytes());
let key_data_len_bytes = key_data_len / 8;
let hash_len = 32; let reps = key_data_len_bytes.div_ceil(hash_len);
let mut derived = Vec::with_capacity(key_data_len_bytes);
for counter in 1..=reps {
let mut hasher = Sha256::new();
hasher.update((counter as u32).to_be_bytes());
hasher.update(shared_secret);
hasher.update(&other_info);
derived.extend_from_slice(&hasher.finalize());
}
derived.truncate(key_data_len_bytes);
Ok(derived)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kdf_basic() {
        let secret = [0x42u8; 32];
        let key = derive_key_ecdh_es(&secret, b"", b"", 256).expect("256 bits is a valid length");
        assert_eq!(key.len(), 32);
    }

    #[test]
    fn test_kdf_with_apu_apv() {
        let secret = [0x42u8; 32];
        let key = derive_key_ecdh_es(&secret, b"sender", b"recipient", 256)
            .expect("256 bits is a valid length");
        assert_eq!(key.len(), 32);
    }

    #[test]
    fn test_kdf_invalid_length() {
        let secret = [0x42u8; 32];
        for bad_len in [0usize, 1, 7, 9, 100] {
            assert!(
                derive_key_ecdh_es(&secret, b"", b"", bad_len).is_err(),
                "key_data_len {bad_len} should be rejected"
            );
        }
    }

    #[test]
    fn test_kdf_deterministic() {
        let secret = [0x42u8; 32];
        let k1 = derive_key_ecdh_es(&secret, b"a", b"b", 256).unwrap();
        let k2 = derive_key_ecdh_es(&secret, b"a", b"b", 256).unwrap();
        assert_eq!(k1, k2);
    }

    #[test]
    fn test_kdf_different_inputs() {
        let secret = [0x42u8; 32];
        let k1 = derive_key_ecdh_es(&secret, b"a", b"b", 256).unwrap();
        let k2 = derive_key_ecdh_es(&secret, b"a", b"c", 256).unwrap();
        assert_ne!(k1, k2);
    }

    #[test]
    fn test_kdf_different_secrets() {
        let k1 = derive_key_ecdh_es(&[0x42u8; 32], b"a", b"b", 256).unwrap();
        let k2 = derive_key_ecdh_es(&[0x43u8; 32], b"a", b"b", 256).unwrap();
        assert_ne!(k1, k2);
    }

    /// Pins the RFC 7518 §4.6.2 input layout for a single-round derivation:
    /// counter || Z || AlgorithmID || PartyUInfo || PartyVInfo || SuppPubInfo.
    #[test]
    fn test_kdf_matches_rfc7518_layout() {
        let secret = [0x42u8; 32];
        let mut input: Vec<u8> = Vec::new();
        input.extend_from_slice(&[0, 0, 0, 1]); // round counter
        input.extend_from_slice(&secret); // Z
        input.extend_from_slice(&[0, 0, 0, 14]);
        input.extend_from_slice(b"ECDH-ES+A256KW");
        input.extend_from_slice(&[0, 0, 0, 5]);
        input.extend_from_slice(b"Alice");
        input.extend_from_slice(&[0, 0, 0, 3]);
        input.extend_from_slice(b"Bob");
        input.extend_from_slice(&[0, 0, 1, 0]); // keydatalen = 256 bits
        let expected = Sha256::digest(input.as_slice());

        let key = derive_key_ecdh_es(&secret, b"Alice", b"Bob", 256).unwrap();
        assert_eq!(&key[..], &expected[..]);
    }

    #[test]
    fn test_kdf_multi_round_lengths() {
        let secret = [0x42u8; 32];
        assert_eq!(derive_key_ecdh_es(&secret, b"", b"", 8).unwrap().len(), 1);
        // 264 bits needs two SHA-256 rounds, truncated to 33 bytes.
        assert_eq!(derive_key_ecdh_es(&secret, b"", b"", 264).unwrap().len(), 33);
        assert_eq!(derive_key_ecdh_es(&secret, b"", b"", 512).unwrap().len(), 64);
    }

    /// keydatalen is part of OtherInfo, so a shorter key must not be a prefix
    /// of a longer one derived from the same inputs.
    #[test]
    fn test_kdf_output_length_is_bound() {
        let secret = [0x42u8; 32];
        let short = derive_key_ecdh_es(&secret, b"a", b"b", 128).unwrap();
        let long = derive_key_ecdh_es(&secret, b"a", b"b", 256).unwrap();
        assert_ne!(&short[..], &long[..16]);
    }

    /// Length prefixes keep ("ab", "c") and ("a", "bc") from hashing identically.
    #[test]
    fn test_kdf_party_info_is_length_prefixed() {
        let secret = [0x42u8; 32];
        let k1 = derive_key_ecdh_es(&secret, b"ab", b"c", 256).unwrap();
        let k2 = derive_key_ecdh_es(&secret, b"a", b"bc", 256).unwrap();
        assert_ne!(k1, k2);
    }
}