use hkdf::Hkdf;
use sha2::Sha256;
use zeroize::Zeroize;
// HKDF `info` labels. Each derivation passes a distinct label to
// `Hkdf::expand`, so outputs produced for different purposes differ even when
// they come from the same HKDF instance. Changing any of these bytes changes
// every key derived with that label.

/// `info` label passed to HKDF-Expand in `kdf_rk`.
const INFO_ROOT: &[u8] = b"pq-ratchet v1 root";
/// `info` label passed to HKDF-Expand in `kdf_ck` for the next chain key.
const INFO_CHAIN: &[u8] = b"pq-ratchet v1 chain-key";
/// `info` label passed to HKDF-Expand in `kdf_ck` for the message key.
const INFO_MSG: &[u8] = b"pq-ratchet v1 msg-key";
/// Derives a new root key and a new chain key with HKDF-SHA256.
///
/// * `root_key` is used as the HKDF salt.
/// * `dh_output` and `pq_output` are concatenated, in that order, to form the
///   64-byte input keying material.
///
/// HKDF-Expand is run once with the `INFO_ROOT` label to produce 64 bytes.
/// The returned tuple is `(new_root_key, new_chain_key)`: the first 32 output
/// bytes and the last 32 output bytes respectively.
///
/// The local scratch buffers holding the concatenated input and the 64-byte
/// output are zeroized before the function returns.
///
/// # Panics
///
/// Only if HKDF-Expand rejects the requested length, which the `expect`
/// message states cannot happen for a 64-byte output.
pub fn kdf_rk(
    root_key: &[u8; 32],
    dh_output: &[u8; 32],
    pq_output: &[u8; 32],
) -> ([u8; 32], [u8; 32]) {
    // IKM = dh_output || pq_output
    let mut secret_input = [0u8; 64];
    {
        let (dh_half, pq_half) = secret_input.split_at_mut(32);
        dh_half.copy_from_slice(dh_output);
        pq_half.copy_from_slice(pq_output);
    }

    let hkdf = Hkdf::<Sha256>::new(Some(&root_key[..]), &secret_input);
    // The concatenated secrets are no longer needed once HKDF has been keyed.
    secret_input.zeroize();

    let mut okm = [0u8; 64];
    hkdf.expand(INFO_ROOT, &mut okm)
        .expect("HKDF-SHA256 output length 64 is always valid");

    let mut next_root = [0u8; 32];
    let mut next_chain = [0u8; 32];
    {
        let (root_half, chain_half) = okm.split_at(32);
        next_root.copy_from_slice(root_half);
        next_chain.copy_from_slice(chain_half);
    }
    // Wipe the combined output; only the two split copies leave this function.
    okm.zeroize();

    (next_root, next_chain)
}
/// Derives the next chain key and a message key from `chain_key`.
///
/// A single HKDF-SHA256 instance is keyed with `chain_key` as input keying
/// material and no salt. HKDF-Expand is then run twice on it: first with the
/// `INFO_CHAIN` label, then with the `INFO_MSG` label, producing 32 bytes
/// each time.
///
/// Returns `(new_chain_key, message_key)`.
///
/// # Panics
///
/// Only if HKDF-Expand rejects the requested length, which the `expect`
/// message states cannot happen for a 32-byte output.
pub fn kdf_ck(chain_key: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
    let hkdf = Hkdf::<Sha256>::new(None, &chain_key[..]);

    // One 32-byte expand per label; the label is the only thing that differs.
    let derive = |info: &[u8]| -> [u8; 32] {
        let mut okm = [0u8; 32];
        hkdf.expand(info, &mut okm)
            .expect("HKDF-SHA256 output length 32 is always valid");
        okm
    };

    let next_chain = derive(INFO_CHAIN);
    let message_key = derive(INFO_MSG);
    (next_chain, message_key)
}