use rand::rngs::OsRng;
use std::collections::HashSet;
use std::fmt;
use anubis_core::{
format::{FileKey, Stanza, FILE_KEY_BYTES},
primitives::{aead_decrypt, aead_encrypt, hkdf},
secrecy::ExposeSecret,
};
use base64::{prelude::BASE64_STANDARD_NO_PAD, Engine};
use oqs::kem::{Algorithm, Kem};
use zeroize::{Zeroize, Zeroizing};
use crate::{
error::{DecryptError, EncryptError},
pqc::{mlkem, x25519},
};
/// Stanza tag identifying hybrid X25519 + ML-KEM recipient stanzas; stanzas with any
/// other tag are ignored by `Identity::unwrap_stanza`.
const HYBRID_RECIPIENT_TAG: &str = "hybrid";
/// Label returned alongside the stanzas produced by `crate::Recipient::wrap_file_key`
/// for hybrid recipients.
const HYBRID_LABEL: &str = "postquantum";
/// Builds a liboqs ML-KEM-1024 KEM object, calling `oqs::init()` beforehand.
///
/// # Panics
/// Panics if `Kem::new` fails (e.g. ML-KEM-1024 not enabled in the linked liboqs).
///
/// NOTE(review): nothing in this file calls this private helper (the `mlkem::` paths here
/// resolve to the `crate::pqc::mlkem` module, not to this function); if it is unused it
/// can be removed.
fn mlkem() -> Kem {
    oqs::init();
    let algorithm = Algorithm::MlKem1024;
    Kem::new(algorithm).expect("ML-KEM-1024 algorithm available")
}
fn hybrid_combiner(
x25519_ss: &[u8; 32],
mlkem_ss: &[u8; 32],
x25519_epk: &[u8; 32],
mlkem_ct: &[u8; 1568],
) -> [u8; 32] {
let mut ikm = Vec::with_capacity(64);
ikm.extend_from_slice(x25519_ss);
ikm.extend_from_slice(mlkem_ss);
let mut salt = Vec::with_capacity(1600);
salt.extend_from_slice(x25519_epk);
salt.extend_from_slice(mlkem_ct);
hkdf(&salt, b"anubis-hybrid-v2/X25519+MLKEM-1024", &ikm)
}
/// Hybrid post-quantum identity pairing an X25519 identity with an ML-KEM identity.
///
/// Unwrapping a `hybrid` stanza uses both halves: the X25519 half for the Diffie-Hellman
/// agreement and the ML-KEM half for decapsulation. Their shared secrets are mixed by
/// `hybrid_combiner`, so the wrapping key depends on both.
pub struct Identity {
    /// X25519 half of the identity.
    x25519: x25519::Identity,
    /// ML-KEM half; presumably ML-KEM-1024, matching the 1568-byte ciphertexts this module
    /// expects — confirm against `crate::pqc::mlkem`.
    mlkem: mlkem::Identity,
}
impl Identity {
    /// Generates a new hybrid identity from freshly generated X25519 and ML-KEM identities.
    pub fn generate() -> Self {
        Identity {
            x25519: x25519::Identity::generate(),
            mlkem: mlkem::Identity::generate(),
        }
    }

    /// Returns the public [`Recipient`] corresponding to this identity.
    pub fn to_public(&self) -> Recipient {
        Recipient {
            x25519: self.x25519.to_public(),
            mlkem: self.mlkem.to_public(),
        }
    }

    /// Decodes an unpadded standard-base64 stanza argument that must be exactly `N` bytes.
    ///
    /// Returns `None` if the argument is not valid base64 or decodes to a different length.
    fn decode_fixed_arg<const N: usize>(arg: &str) -> Option<[u8; N]> {
        let bytes = BASE64_STANDARD_NO_PAD.decode(arg).ok()?;
        <[u8; N]>::try_from(bytes.as_slice()).ok()
    }

    /// Attempts to recover the file key from a `hybrid` stanza.
    ///
    /// Returns:
    /// - `None` if the stanza's tag is not `hybrid`, or if the body fails AEAD decryption
    ///   under the derived key. The latter typically means the stanza was wrapped to a
    ///   different identity, so the caller can keep looking.
    /// - `Some(Err(DecryptError::InvalidHeader))` if the stanza is malformed. That covers an
    ///   argument count other than 2, bad base64, an ephemeral key that is not 32 bytes, a
    ///   ciphertext that is not 1568 bytes, or a body of the wrong length.
    /// - `Some(Err(DecryptError::DecryptionFailed))` if X25519 agreement or ML-KEM
    ///   decapsulation returns an error.
    /// - `Some(Ok(file_key))` on success.
    fn unwrap_stanza(&self, stanza: &Stanza) -> Option<Result<FileKey, DecryptError>> {
        // A wrapped file key is the file key followed by a 16-byte AEAD tag.
        const ENCRYPTED_FILE_KEY_BYTES: usize = FILE_KEY_BYTES + 16;

        if stanza.tag != HYBRID_RECIPIENT_TAG {
            return None;
        }

        // Validate the entire stanza shape before any secret-key operation. Malformed input
        // is then rejected cheaply and never reaches X25519 or ML-KEM decapsulation.
        let (epk_arg, ct_arg) = match stanza.args.as_slice() {
            [epk_arg, ct_arg] => (epk_arg, ct_arg),
            _ => return Some(Err(DecryptError::InvalidHeader)),
        };
        let (x25519_epk_bytes, mlkem_ct_bytes) = match (
            Self::decode_fixed_arg::<32>(epk_arg),
            Self::decode_fixed_arg::<1568>(ct_arg),
        ) {
            (Some(epk), Some(ct)) => (epk, ct),
            _ => return Some(Err(DecryptError::InvalidHeader)),
        };
        if stanza.body.len() != ENCRYPTED_FILE_KEY_BYTES {
            return Some(Err(DecryptError::InvalidHeader));
        }

        let x25519_ss = match self.x25519.diffie_hellman(&x25519_epk_bytes) {
            Ok(ss) => ss,
            Err(_) => return Some(Err(DecryptError::DecryptionFailed)),
        };
        let mlkem_ss = match self.mlkem.decapsulate(&mlkem_ct_bytes) {
            Ok(ss) => ss,
            Err(_) => return Some(Err(DecryptError::DecryptionFailed)),
        };
        let wrap_key = Zeroizing::new(hybrid_combiner(
            &x25519_ss,
            &mlkem_ss,
            &x25519_epk_bytes,
            &mlkem_ct_bytes,
        ));

        // An AEAD failure is reported as `None` ("not for this identity"), not as an error,
        // so the caller can try other stanzas or identities.
        aead_decrypt(&wrap_key, FILE_KEY_BYTES, &stanza.body)
            .ok()
            .map(|mut plaintext| {
                Ok(FileKey::init_with_mut(|file_key| {
                    file_key.copy_from_slice(&plaintext);
                    plaintext.zeroize();
                }))
            })
    }
}
impl crate::Identity for Identity {
    /// Forwards to the inherent `Identity::unwrap_stanza`.
    ///
    /// Inherent associated functions take precedence over trait items in path resolution,
    /// so `Self::unwrap_stanza` here names the inherent method and does not recurse.
    fn unwrap_stanza(&self, stanza: &Stanza) -> Option<Result<FileKey, DecryptError>> {
        Self::unwrap_stanza(self, stanza)
    }
}
/// Public half of a hybrid identity: an X25519 recipient paired with an ML-KEM recipient.
///
/// Wrapping a file key to it produces a single `hybrid` stanza that carries material for
/// both halves.
#[derive(Clone)]
pub struct Recipient {
    /// X25519 half; supplies the static public key used in the ephemeral Diffie-Hellman.
    x25519: x25519::Recipient,
    /// ML-KEM half; used to encapsulate a fresh shared secret per wrap.
    mlkem: mlkem::Recipient,
}
impl Recipient {
fn wrap_file_key(&self, file_key: &FileKey) -> Result<Vec<Stanza>, EncryptError> {
let mut rng = OsRng;
let x25519_esk = x25519_dalek::EphemeralSecret::random_from_rng(&mut rng);
let x25519_epk = x25519_dalek::PublicKey::from(&x25519_esk);
let x25519_ss = x25519_esk.diffie_hellman(self.x25519.public_key());
let x25519_ss_bytes: [u8; 32] = *x25519_ss.as_bytes();
let (mlkem_ct, mlkem_ss) = self.mlkem.encapsulate(&mut rng)?;
let mlkem_ss_bytes: [u8; 32] = mlkem_ss[..32].try_into().unwrap();
let x25519_epk_bytes: [u8; 32] = *x25519_epk.as_bytes();
let wrap_key = hybrid_combiner(&x25519_ss_bytes, &mlkem_ss_bytes, &x25519_epk_bytes, &mlkem_ct);
let encrypted_file_key = aead_encrypt(&Zeroizing::new(wrap_key), file_key.expose_secret());
Ok(vec![Stanza {
tag: HYBRID_RECIPIENT_TAG.to_string(),
args: vec![
BASE64_STANDARD_NO_PAD.encode(&x25519_epk_bytes),
BASE64_STANDARD_NO_PAD.encode(&mlkem_ct),
],
body: encrypted_file_key,
}])
}
}
impl crate::Recipient for Recipient {
    /// Produces this recipient's single `hybrid` stanza and tags the result with the
    /// hybrid label.
    ///
    /// Any error from the inherent wrapping step is propagated unchanged.
    fn wrap_file_key(
        &self,
        file_key: &FileKey,
    ) -> Result<(Vec<Stanza>, HashSet<String>), EncryptError> {
        // Inherent methods win method resolution, so this calls the inherent
        // `Recipient::wrap_file_key` rather than recursing into this trait method.
        let stanzas = self.wrap_file_key(file_key)?;
        let labels: HashSet<String> = [HYBRID_LABEL.to_string()].into_iter().collect();
        Ok((stanzas, labels))
    }
}
impl fmt::Display for Recipient {
    /// Writes the `anubis1hybrid` prefix followed directly by the X25519 and ML-KEM
    /// recipient encodings, with no separator between them.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // NOTE(review): with no delimiter, parsing this back relies on the component
        // encodings being fixed-length or self-delimiting — confirm against the parser.
        f.write_fmt(format_args!("anubis1hybrid{}{}", self.x25519, self.mlkem))
    }
}
impl fmt::Display for Identity {
    /// Writes `ANUBIS-HYBRID-SECRET-KEY-1`, then the X25519 identity encoding, a newline,
    /// and the ML-KEM identity encoding.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // NOTE(review): this exposes secret key material through `Display`, so any
        // `to_string()` / `{}` use (including logging) leaks the key. The embedded `\n` also
        // makes the output span two lines — confirm the identity-file parser expects that.
        let (classical, post_quantum) = (&self.x25519, &self.mlkem);
        f.write_fmt(format_args!(
            "ANUBIS-HYBRID-SECRET-KEY-1{}\n{}",
            classical, post_quantum
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a fixed 16-byte file key to `recipient` and returns the single resulting stanza.
    fn wrap_one(recipient: &Recipient) -> Stanza {
        let file_key = FileKey::new(Box::new([42; 16]));
        let mut stanzas = recipient.wrap_file_key(&file_key).unwrap();
        assert_eq!(stanzas.len(), 1);
        stanzas.remove(0)
    }

    #[test]
    fn hybrid_combiner_deterministic() {
        let x25519_ss = [1u8; 32];
        let mlkem_ss = [2u8; 32];
        let x25519_epk = [3u8; 32];
        let mlkem_ct = [4u8; 1568];
        let key1 = hybrid_combiner(&x25519_ss, &mlkem_ss, &x25519_epk, &mlkem_ct);
        let key2 = hybrid_combiner(&x25519_ss, &mlkem_ss, &x25519_epk, &mlkem_ct);
        assert_eq!(key1, key2, "Combiner should be deterministic");
    }

    #[test]
    fn hybrid_combiner_different_inputs() {
        let x25519_ss = [1u8; 32];
        let mlkem_ss = [2u8; 32];
        let x25519_epk = [3u8; 32];
        let mlkem_ct = [4u8; 1568];
        let key1 = hybrid_combiner(&x25519_ss, &mlkem_ss, &x25519_epk, &mlkem_ct);
        let x25519_ss2 = [5u8; 32];
        let key2 = hybrid_combiner(&x25519_ss2, &mlkem_ss, &x25519_epk, &mlkem_ct);
        assert_ne!(key1, key2, "Different X25519 SS should produce different key");
        let mlkem_ss2 = [6u8; 32];
        let key3 = hybrid_combiner(&x25519_ss, &mlkem_ss2, &x25519_epk, &mlkem_ct);
        assert_ne!(key1, key3, "Different ML-KEM SS should produce different key");
    }

    #[test]
    fn hybrid_combiner_binds_transcript() {
        let x25519_ss = [1u8; 32];
        let mlkem_ss = [2u8; 32];
        let x25519_epk = [3u8; 32];
        let mlkem_ct = [4u8; 1568];
        let key1 = hybrid_combiner(&x25519_ss, &mlkem_ss, &x25519_epk, &mlkem_ct);
        let key2 = hybrid_combiner(&x25519_ss, &mlkem_ss, &[7u8; 32], &mlkem_ct);
        assert_ne!(key1, key2, "Different ephemeral X25519 key should produce different key");
        let key3 = hybrid_combiner(&x25519_ss, &mlkem_ss, &x25519_epk, &[8u8; 1568]);
        assert_ne!(key1, key3, "Different ML-KEM ciphertext should produce different key");
    }

    #[test]
    fn hybrid_round_trip() {
        let identity = Identity::generate();
        let recipient = identity.to_public();
        let file_key = FileKey::new(Box::new([42; 16]));
        let stanzas = recipient.wrap_file_key(&file_key).unwrap();
        assert_eq!(stanzas.len(), 1);
        assert_eq!(stanzas[0].tag, HYBRID_RECIPIENT_TAG);
        assert_eq!(stanzas[0].args.len(), 2);
        let decrypted = identity.unwrap_stanza(&stanzas[0]).unwrap().unwrap();
        assert_eq!(decrypted.expose_secret(), file_key.expose_secret());
    }

    #[test]
    fn hybrid_labels() {
        let recipient = Identity::generate().to_public();
        let file_key = FileKey::new(Box::new([42; 16]));
        let (_, labels) = <Recipient as crate::Recipient>::wrap_file_key(&recipient, &file_key)
            .unwrap();
        assert!(labels.contains(HYBRID_LABEL));
        assert_eq!(labels.len(), 1);
    }

    #[test]
    fn unwrap_ignores_foreign_tag() {
        let identity = Identity::generate();
        let mut stanza = wrap_one(&identity.to_public());
        stanza.tag = "X25519".to_string();
        assert!(identity.unwrap_stanza(&stanza).is_none());
    }

    #[test]
    fn unwrap_rejects_tampered_body() {
        let identity = Identity::generate();
        let mut stanza = wrap_one(&identity.to_public());
        stanza.body[0] ^= 0x01;
        assert!(identity.unwrap_stanza(&stanza).is_none());
    }

    #[test]
    fn unwrap_with_wrong_identity_yields_no_key() {
        let recipient = Identity::generate().to_public();
        let other = Identity::generate();
        let stanza = wrap_one(&recipient);
        assert!(
            !matches!(other.unwrap_stanza(&stanza), Some(Ok(_))),
            "A stanza for another recipient must not yield a file key"
        );
    }

    #[test]
    fn unwrap_rejects_malformed_stanzas() {
        let identity = Identity::generate();
        let recipient = identity.to_public();

        let mut missing_arg = wrap_one(&recipient);
        missing_arg.args.pop();
        assert!(matches!(
            identity.unwrap_stanza(&missing_arg),
            Some(Err(DecryptError::InvalidHeader))
        ));

        let mut short_epk = wrap_one(&recipient);
        short_epk.args[0] = BASE64_STANDARD_NO_PAD.encode([0u8; 31]);
        assert!(matches!(
            identity.unwrap_stanza(&short_epk),
            Some(Err(DecryptError::InvalidHeader))
        ));

        let mut short_body = wrap_one(&recipient);
        short_body.body.pop();
        assert!(matches!(
            identity.unwrap_stanza(&short_body),
            Some(Err(DecryptError::InvalidHeader))
        ));
    }
}