use chacha20poly1305::{aead::Aead, ChaCha20Poly1305, KeyInit as AeadKeyInit};
use hkdf::Hkdf;
use hmac::Mac;
use sha2::Sha256;
type HmacSha256 = hmac::Hmac<Sha256>;
use zeroize::Zeroizing;
use super::error::CryptoError;
/// KEM-level `suite_id`: "KEM" || I2OSP(kem_id, 2) (RFC 9180).
const KEM_SUITE_ID: &[u8] = &[
    b'K', b'E', b'M', //
    0x00, 0x20, // kem_id: DHKEM(X25519, HKDF-SHA256)
];
/// Key-schedule `suite_id`: "HPKE" || kem_id || kdf_id || aead_id (RFC 9180).
const HPKE_SUITE_ID: &[u8] = &[
    b'H', b'P', b'K', b'E', //
    0x00, 0x20, // kem_id: DHKEM(X25519, HKDF-SHA256)
    0x00, 0x01, // kdf_id: HKDF-SHA256
    0x00, 0x03, // aead_id: ChaCha20Poly1305
];
/// Domain-separation prefix prepended by every labeled extract/expand.
const VERSION_LABEL: &[u8] = b"HPKE-v1";

// ---- Algorithm sizes, in bytes ----
/// `Nh`: output length of HKDF-SHA256 (size of a PRK / extract output).
const NH: usize = 32;
/// `Nsecret`: length of the KEM shared secret.
const N_SECRET: usize = 32;
/// `Nk`: ChaCha20Poly1305 key length.
const NK: usize = 32;
/// `Nn`: ChaCha20Poly1305 nonce length.
const NN: usize = 12;
/// `Npk` / `Nenc`: length of an X25519 public key (and thus of `enc`).
const N_PK: usize = 32;

/// `mode_base` identifier placed at the front of the key-schedule context.
const MODE_BASE: u8 = 0x00;
fn build_labeled_ikm(suite_id: &[u8], label: &[u8], ikm: &[u8]) -> Vec<u8> {
let mut buf = Vec::with_capacity(VERSION_LABEL.len() + suite_id.len() + label.len() + ikm.len());
buf.extend_from_slice(VERSION_LABEL);
buf.extend_from_slice(suite_id);
buf.extend_from_slice(label);
buf.extend_from_slice(ikm);
buf
}
fn build_labeled_info(suite_id: &[u8], label: &[u8], info: &[u8], len: u16) -> Vec<u8> {
let mut buf = Vec::with_capacity(2 + VERSION_LABEL.len() + suite_id.len() + label.len() + info.len());
buf.extend_from_slice(&len.to_be_bytes());
buf.extend_from_slice(VERSION_LABEL);
buf.extend_from_slice(suite_id);
buf.extend_from_slice(label);
buf.extend_from_slice(info);
buf
}
fn labeled_extract(suite_id: &[u8], salt: &[u8], label: &[u8], ikm: &[u8]) -> [u8; NH] {
let labeled_ikm = build_labeled_ikm(suite_id, label, ikm);
let effective_salt: &[u8] = if salt.is_empty() { &[0u8; NH] } else { salt };
let mut mac = <HmacSha256 as Mac>::new_from_slice(effective_salt).expect("HMAC-SHA256 accepts any key length");
Mac::update(&mut mac, &labeled_ikm);
let result = mac.finalize().into_bytes();
let mut prk = [0u8; NH];
prk.copy_from_slice(&result);
prk
}
/// RFC 9180 `LabeledExpand`: HKDF-Expand (SHA-256) of `prk` with the labeled info.
///
/// Returns `len` bytes of output keying material.
///
/// # Errors
/// `CryptoError::ThresholdDecrypt` if `len` does not fit the two-byte length
/// prefix, if `prk` is rejected by HKDF, or if `len` exceeds HKDF's maximum output.
fn labeled_expand(
    suite_id: &[u8],
    prk: &[u8; NH],
    label: &[u8],
    info: &[u8],
    len: usize,
) -> Result<Vec<u8>, CryptoError> {
    // The length is encoded as I2OSP(len, 2); reject instead of silently truncating.
    let encoded_len = u16::try_from(len).map_err(|_| {
        CryptoError::ThresholdDecrypt("HKDF-Expand output length does not fit in u16".into())
    })?;
    let labeled_info = build_labeled_info(suite_id, label, info, encoded_len);
    let hkdf = Hkdf::<Sha256>::from_prk(prk)
        .map_err(|_| CryptoError::ThresholdDecrypt("invalid PRK for HKDF-Expand".into()))?;
    let mut okm = vec![0u8; len];
    hkdf.expand(&labeled_info, &mut okm)
        .map_err(|_| CryptoError::ThresholdDecrypt("HKDF-Expand failed".into()))?;
    Ok(okm)
}
fn extract_and_expand(dh: &[u8], kem_context: &[u8]) -> Result<[u8; N_SECRET], CryptoError> {
let prk = labeled_extract(KEM_SUITE_ID, &[], b"eae_prk", dh);
let shared_secret_vec = labeled_expand(KEM_SUITE_ID, &prk, b"shared_secret", kem_context, N_SECRET)?;
let mut shared_secret = [0u8; N_SECRET];
shared_secret.copy_from_slice(&shared_secret_vec);
Ok(shared_secret)
}
fn key_schedule_base(shared_secret: &[u8; N_SECRET], info: &[u8]) -> Result<([u8; NK], [u8; NN]), CryptoError> {
let psk_id_hash = labeled_extract(HPKE_SUITE_ID, &[], b"psk_id_hash", &[]);
let info_hash = labeled_extract(HPKE_SUITE_ID, &[], b"info_hash", info);
let mut ks_context = Vec::with_capacity(1 + NH + NH);
ks_context.push(MODE_BASE);
ks_context.extend_from_slice(&psk_id_hash);
ks_context.extend_from_slice(&info_hash);
let secret = labeled_extract(HPKE_SUITE_ID, shared_secret, b"secret", &[]);
let key_vec = labeled_expand(HPKE_SUITE_ID, &secret, b"key", &ks_context, NK)?;
let mut key = [0u8; NK];
key.copy_from_slice(&key_vec);
let nonce_vec = labeled_expand(HPKE_SUITE_ID, &secret, b"base_nonce", &ks_context, NN)?;
let mut base_nonce = [0u8; NN];
base_nonce.copy_from_slice(&nonce_vec);
Ok((key, base_nonce))
}
/// Decrypts a single-shot HPKE base-mode message from an externally computed DH output.
///
/// Suite: DHKEM(X25519, HKDF-SHA256) / HKDF-SHA256 / ChaCha20Poly1305, with empty
/// `info`. The AEAD nonce is `base_nonce` itself (sequence number 0). Instead of the
/// recipient private key, the caller supplies `dh = X25519(skR, enc)` directly. The
/// test module computes it exactly that way. The `ThresholdDecrypt` error variant
/// suggests it is meant to be combined from threshold shares; NOTE(review): confirm
/// with callers.
///
/// # Arguments
/// * `dh` - 32-byte X25519 shared secret for (skR, pkE).
/// * `enc` - 32-byte encapsulated key (the sender's ephemeral public key).
/// * `pk_r` - 32-byte recipient public key.
/// * `ciphertext` - AEAD ciphertext, including the authentication tag.
/// * `aad` - additional authenticated data bound by the AEAD.
///
/// # Errors
/// Returns `CryptoError::ThresholdDecrypt` in these cases:
/// * `enc`, `pk_r` or `dh` has the wrong length.
/// * `dh` is all-zero.
/// * AEAD authentication fails.
pub fn decrypt_with_precomputed_dh(
    dh: &[u8],
    enc: &[u8],
    pk_r: &[u8],
    ciphertext: &[u8],
    aad: &[u8],
) -> Result<Zeroizing<Vec<u8>>, CryptoError> {
    // Ndh for X25519.
    const N_DH: usize = 32;

    if enc.len() != N_PK {
        return Err(CryptoError::ThresholdDecrypt(format!(
            "enc must be {} bytes, got {}",
            N_PK,
            enc.len()
        )));
    }
    if pk_r.len() != N_PK {
        return Err(CryptoError::ThresholdDecrypt(format!(
            "pk_r must be {} bytes, got {}",
            N_PK,
            pk_r.len()
        )));
    }
    if dh.len() != N_DH {
        return Err(CryptoError::ThresholdDecrypt(format!(
            "dh must be {} bytes, got {}",
            N_DH,
            dh.len()
        )));
    }
    // RFC 9180 §7.1.4: recipients MUST abort on an all-zero X25519 shared secret
    // (produced by a low-order `enc`). OR-fold every byte rather than short-circuiting.
    if dh.iter().fold(0u8, |acc, &b| acc | b) == 0 {
        return Err(CryptoError::ThresholdDecrypt(
            "dh output is all-zero (low-order enc)".into(),
        ));
    }

    // kem_context = enc || pkR
    let mut kem_context = Vec::with_capacity(N_PK + N_PK);
    kem_context.extend_from_slice(enc);
    kem_context.extend_from_slice(pk_r);
    let shared_secret = Zeroizing::new(extract_and_expand(dh, &kem_context)?);
    let (key, base_nonce) = key_schedule_base(&shared_secret, &[])?;
    let key = Zeroizing::new(key);
    let cipher =
        <ChaCha20Poly1305 as AeadKeyInit>::new(chacha20poly1305::Key::from_slice(key.as_slice()));
    let nonce = chacha20poly1305::Nonce::from_slice(&base_nonce);
    let plaintext = cipher
        .decrypt(nonce, chacha20poly1305::aead::Payload { msg: ciphertext, aad })
        .map_err(|e| CryptoError::ThresholdDecrypt(format!("AEAD decryption failed: {e}")))?;
    Ok(Zeroizing::new(plaintext))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::hpke;

    /// Computes the raw X25519 shared secret `DH(sk, pk)`; panics unless both inputs are 32 bytes.
    fn x25519_dh(sk_bytes: &[u8], pk_bytes: &[u8]) -> [u8; 32] {
        let sk = x25519_dalek::StaticSecret::from(<[u8; 32]>::try_from(sk_bytes).unwrap());
        let pk = x25519_dalek::PublicKey::from(<[u8; 32]>::try_from(pk_bytes).unwrap());
        *sk.diffie_hellman(&pk).as_bytes()
    }

    /// Encrypts `plaintext` to a fresh keypair. Asserts that the precomputed-DH path
    /// recovers it and agrees with the standard `hpke::decrypt`.
    fn assert_roundtrip(plaintext: &[u8], aad: &[u8]) {
        let (sk, pk) = hpke::generate_keypair();
        let (enc, ct) = hpke::encrypt(&pk, plaintext, aad).expect("encrypt failed");
        let dh = x25519_dh(sk.to_bytes(), &enc);
        let recovered = decrypt_with_precomputed_dh(&dh, &enc, pk.to_bytes(), &ct, aad)
            .expect("threshold decrypt failed");
        assert_eq!(&recovered[..], plaintext, "threshold decrypt must recover the plaintext");
        let standard_recovered =
            hpke::decrypt(&sk, &enc, &ct, aad).expect("standard decrypt failed");
        assert_eq!(
            &*recovered, &*standard_recovered,
            "threshold decrypt must match standard decrypt"
        );
    }

    #[test]
    fn equivalence_with_standard_hpke_decrypt() {
        assert_roundtrip(b"threshold decryption equivalence test", b"newton-privacy-aad");
    }

    #[test]
    fn equivalence_with_empty_plaintext() {
        assert_roundtrip(b"", b"empty-test");
    }

    #[test]
    fn equivalence_with_empty_aad() {
        assert_roundtrip(b"test data with no aad", b"");
    }

    #[test]
    fn equivalence_with_large_plaintext() {
        let plaintext = vec![0xABu8; 8192];
        assert_roundtrip(&plaintext, b"large-payload-test");
    }

    #[test]
    fn wrong_dh_output_fails() {
        let (_, pk) = hpke::generate_keypair();
        let plaintext = b"should fail with wrong DH";
        let aad = b"ctx";
        let (enc, ct) = hpke::encrypt(&pk, plaintext, aad).expect("encrypt failed");
        let (wrong_sk, _) = hpke::generate_keypair();
        let wrong_dh = x25519_dh(wrong_sk.to_bytes(), &enc);
        let result = decrypt_with_precomputed_dh(&wrong_dh, &enc, pk.to_bytes(), &ct, aad);
        assert!(result.is_err(), "wrong DH output must cause decryption failure");
    }

    #[test]
    fn all_zero_dh_output_fails() {
        let (_, pk) = hpke::generate_keypair();
        let aad = b"ctx";
        let (enc, ct) = hpke::encrypt(&pk, b"zero dh", aad).expect("encrypt failed");
        let result = decrypt_with_precomputed_dh(&[0u8; 32], &enc, pk.to_bytes(), &ct, aad);
        assert!(result.is_err(), "all-zero DH output must cause decryption failure");
    }

    #[test]
    fn truncated_dh_output_fails() {
        let (sk, pk) = hpke::generate_keypair();
        let aad = b"ctx";
        let (enc, ct) = hpke::encrypt(&pk, b"truncated dh", aad).expect("encrypt failed");
        let dh = x25519_dh(sk.to_bytes(), &enc);
        let result = decrypt_with_precomputed_dh(&dh[..31], &enc, pk.to_bytes(), &ct, aad);
        assert!(result.is_err(), "truncated DH output must cause decryption failure");
    }

    #[test]
    fn wrong_aad_fails() {
        let (sk, pk) = hpke::generate_keypair();
        let plaintext = b"should fail with wrong AAD";
        let aad = b"correct";
        let (enc, ct) = hpke::encrypt(&pk, plaintext, aad).expect("encrypt failed");
        let dh = x25519_dh(sk.to_bytes(), &enc);
        let result = decrypt_with_precomputed_dh(&dh, &enc, pk.to_bytes(), &ct, b"wrong");
        assert!(result.is_err(), "wrong AAD must cause decryption failure");
    }

    #[test]
    fn wrong_pk_r_fails() {
        let (sk, pk) = hpke::generate_keypair();
        let plaintext = b"should fail with wrong pkR";
        let aad = b"ctx";
        let (enc, ct) = hpke::encrypt(&pk, plaintext, aad).expect("encrypt failed");
        let dh = x25519_dh(sk.to_bytes(), &enc);
        let (_, wrong_pk) = hpke::generate_keypair();
        let result = decrypt_with_precomputed_dh(&dh, &enc, wrong_pk.to_bytes(), &ct, aad);
        assert!(result.is_err(), "wrong pkR must cause decryption failure");
    }

    #[test]
    fn invalid_enc_length_rejected() {
        let result =
            decrypt_with_precomputed_dh(&[0u8; 32], &[0u8; 16], &[0u8; 32], &[0u8; 32], &[]);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("enc must be 32 bytes"));
    }

    #[test]
    fn invalid_pk_r_length_rejected() {
        let result =
            decrypt_with_precomputed_dh(&[0u8; 32], &[0u8; 32], &[0u8; 16], &[0u8; 32], &[]);
        assert!(result.is_err());
        let err_msg = result.unwrap_err().to_string();
        assert!(err_msg.contains("pk_r must be 32 bytes"));
    }

    #[test]
    fn multiple_encryptions_same_key_all_decrypt() {
        let (sk, pk) = hpke::generate_keypair();
        let aad = b"multi-msg";
        for i in 0..5 {
            let plaintext = format!("message number {i}");
            let (enc, ct) = hpke::encrypt(&pk, plaintext.as_bytes(), aad).expect("encrypt failed");
            let dh = x25519_dh(sk.to_bytes(), &enc);
            let recovered = decrypt_with_precomputed_dh(&dh, &enc, pk.to_bytes(), &ct, aad)
                .expect("threshold decrypt failed");
            assert_eq!(&*recovered, plaintext.as_bytes());
            let standard_recovered =
                hpke::decrypt(&sk, &enc, &ct, aad).expect("standard decrypt failed");
            assert_eq!(&*recovered, &*standard_recovered);
        }
    }
}