use hkdf::Hkdf;
use sha2::Sha256;
use crate::error::{CryptoError, SrxError};
pub struct KeyDerivation;
impl KeyDerivation {
pub fn derive_initial_seed(
master_key: &[u8],
timestamp: u64,
session_nonce: &[u8],
) -> crate::error::Result<[u8; 32]> {
let mut salt = Vec::with_capacity(8 + session_nonce.len());
salt.extend_from_slice(×tamp.to_be_bytes());
salt.extend_from_slice(session_nonce);
let hk = Hkdf::<Sha256>::new(Some(&salt), master_key);
let mut seed = [0u8; 32];
hk.expand(b"srx-seed-init", &mut seed)
.map_err(|e| SrxError::Crypto(CryptoError::KdfFailed(e.to_string())))?;
Ok(seed)
}
pub fn rotate_seed(
current_seed: &[u8; 32],
traffic_stat_fragment: &[u8],
) -> crate::error::Result<[u8; 32]> {
let hk = Hkdf::<Sha256>::new(Some(traffic_stat_fragment), current_seed);
let mut new_seed = [0u8; 32];
hk.expand(b"srx-seed-rotate", &mut new_seed)
.map_err(|e| SrxError::Crypto(CryptoError::KdfFailed(e.to_string())))?;
Ok(new_seed)
}
pub fn derive_data_key(seed: &[u8; 32], key_index: u64) -> crate::error::Result<[u8; 32]> {
let salt = key_index.to_be_bytes();
let hk = Hkdf::<Sha256>::new(Some(&salt), seed);
let mut key = [0u8; 32];
hk.expand(b"srx-data-key", &mut key)
.map_err(|e| SrxError::Crypto(CryptoError::KdfFailed(e.to_string())))?;
Ok(key)
}
pub fn derive_nonce(seed: &[u8; 32], counter: u64) -> crate::error::Result<[u8; 12]> {
let salt = counter.to_be_bytes();
let hk = Hkdf::<Sha256>::new(Some(&salt), seed);
let mut nonce = [0u8; 12];
hk.expand(b"srx-nonce", &mut nonce)
.map_err(|e| SrxError::Crypto(CryptoError::KdfFailed(e.to_string())))?;
Ok(nonce)
}
pub fn frame_id_mask(seed: &[u8; 32]) -> crate::error::Result<u64> {
let hk = Hkdf::<Sha256>::new(Some(seed), b"");
let mut buf = [0u8; 8];
hk.expand(b"srx-frame-id-mask", &mut buf)
.map_err(|e| SrxError::Crypto(CryptoError::KdfFailed(e.to_string())))?;
Ok(u64::from_be_bytes(buf))
}
pub fn combine_secrets(
pqc_secret: &[u8],
ecdh_secret: &[u8],
) -> crate::error::Result<[u8; 32]> {
let hk = Hkdf::<Sha256>::new(Some(ecdh_secret), pqc_secret);
let mut master = [0u8; 32];
hk.expand(b"srx-hybrid-master", &mut master)
.map_err(|e| SrxError::Crypto(CryptoError::KdfFailed(e.to_string())))?;
Ok(master)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_derive_initial_seed_deterministic() {
        let master = [0xABu8; 32];
        let nonce = b"test-nonce";
        let ts = 1700000000u64;
        let s1 = KeyDerivation::derive_initial_seed(&master, ts, nonce).unwrap();
        let s2 = KeyDerivation::derive_initial_seed(&master, ts, nonce).unwrap();
        assert_eq!(s1, s2);
    }

    #[test]
    fn test_derive_initial_seed_differs_on_input() {
        let master = [0xABu8; 32];
        let nonce = b"test-nonce";
        let s1 = KeyDerivation::derive_initial_seed(&master, 1, nonce).unwrap();
        let s2 = KeyDerivation::derive_initial_seed(&master, 2, nonce).unwrap();
        assert_ne!(s1, s2);
    }

    #[test]
    fn test_derive_initial_seed_differs_on_nonce() {
        let master = [0xABu8; 32];
        let s1 = KeyDerivation::derive_initial_seed(&master, 7, b"nonce-a").unwrap();
        let s2 = KeyDerivation::derive_initial_seed(&master, 7, b"nonce-b").unwrap();
        assert_ne!(s1, s2);
    }

    /// Known-answer check: the salt must be the big-endian timestamp followed by the
    /// session nonce. Guards against accidental reordering or corruption of the salt.
    #[test]
    fn test_derive_initial_seed_salt_layout() {
        let master = [0xABu8; 32];
        let nonce = b"test-nonce";
        let ts = 1_700_000_000u64;
        let mut salt = ts.to_be_bytes().to_vec();
        salt.extend_from_slice(nonce);
        let mut expected = [0u8; 32];
        Hkdf::<Sha256>::new(Some(&salt), &master)
            .expand(b"srx-seed-init", &mut expected)
            .unwrap();
        assert_eq!(
            KeyDerivation::derive_initial_seed(&master, ts, nonce).unwrap(),
            expected
        );
    }

    #[test]
    fn test_rotate_seed() {
        let seed = [0x42u8; 32];
        let stats = b"some-traffic-stats";
        let rotated = KeyDerivation::rotate_seed(&seed, stats).unwrap();
        assert_ne!(seed, rotated);
        let rotated2 = KeyDerivation::rotate_seed(&seed, stats).unwrap();
        assert_eq!(rotated, rotated2);
    }

    #[test]
    fn test_rotate_seed_differs_on_stats() {
        let seed = [0x42u8; 32];
        let r1 = KeyDerivation::rotate_seed(&seed, b"stats-1").unwrap();
        let r2 = KeyDerivation::rotate_seed(&seed, b"stats-2").unwrap();
        assert_ne!(r1, r2);
    }

    #[test]
    fn test_derive_data_key() {
        let seed = [0x11u8; 32];
        let k0 = KeyDerivation::derive_data_key(&seed, 0).unwrap();
        let k1 = KeyDerivation::derive_data_key(&seed, 1).unwrap();
        assert_ne!(k0, k1);
    }

    #[test]
    fn test_derive_nonce() {
        let seed = [0x22u8; 32];
        let n0 = KeyDerivation::derive_nonce(&seed, 0).unwrap();
        let n1 = KeyDerivation::derive_nonce(&seed, 1).unwrap();
        assert_ne!(n0, n1);
        assert_eq!(n0.len(), 12);
    }

    /// Data keys and nonces share the same seed and salt encoding; only the HKDF info
    /// label separates them, so their outputs must not coincide.
    #[test]
    fn test_data_key_and_nonce_are_domain_separated() {
        let seed = [0x44u8; 32];
        let key = KeyDerivation::derive_data_key(&seed, 5).unwrap();
        let nonce = KeyDerivation::derive_nonce(&seed, 5).unwrap();
        assert_ne!(&key[..12], &nonce[..]);
    }

    #[test]
    fn test_frame_id_mask_deterministic() {
        let seed = [0x33u8; 32];
        let m1 = KeyDerivation::frame_id_mask(&seed).unwrap();
        let m2 = KeyDerivation::frame_id_mask(&seed).unwrap();
        assert_eq!(m1, m2);
        assert_ne!(m1, 0);
    }

    #[test]
    fn test_frame_id_mask_differs_on_seed() {
        let m1 = KeyDerivation::frame_id_mask(&[0x11u8; 32]).unwrap();
        let m2 = KeyDerivation::frame_id_mask(&[0x22u8; 32]).unwrap();
        assert_ne!(m1, m2);
    }

    #[test]
    fn test_combine_secrets() {
        let pqc = [0xAAu8; 32];
        let ecdh = [0xBBu8; 32];
        let master = KeyDerivation::combine_secrets(&pqc, &ecdh).unwrap();
        assert_ne!(master, [0u8; 32]);
        let master2 = KeyDerivation::combine_secrets(&ecdh, &pqc).unwrap();
        assert_ne!(master, master2);
    }
}