use std::collections::HashSet;
use std::fmt;
use std::io;
use anubis_core::{
format::{FileKey, Stanza, FILE_KEY_BYTES},
primitives::{aead_decrypt, aead_encrypt, hkdf},
secrecy::{ExposeSecret, SecretString},
};
use bech32::{ToBase32, Variant};
use oqs::kem::{Algorithm, Kem, SharedSecret};
use zeroize::{Zeroize, Zeroizing};
use crate::{
error::{DecryptError, EncryptError},
util::parse_bech32,
};
/// Bech32 human-readable part of serialized identities (emitted uppercase).
const SECRET_KEY_PREFIX: &str = "ANUBIS-MLKEM-1024-SECRET";
/// Bech32 human-readable part of serialized recipients.
const PUBLIC_KEY_PREFIX: &str = "anubis1mlkem1";
/// Stanza tag identifying ML-KEM-1024 recipient stanzas.
pub const MLKEM1024_RECIPIENT_TAG: &str = "MLKEM-1024";
/// HKDF label used when deriving the key that wraps the file key.
const MLKEM1024_RECIPIENT_KEY_LABEL: &[u8] = b"anubis-encryption.org/v1/MLKEM-1024";
/// Size in bytes of an ML-KEM-1024 public (encapsulation) key.
pub const MLKEM1024_PUBLIC_KEY_BYTES: usize = 1568;
/// Size in bytes of an ML-KEM-1024 secret (decapsulation) key.
pub const MLKEM1024_SECRET_KEY_BYTES: usize = 3168;
/// Size in bytes of an ML-KEM-1024 ciphertext.
pub const MLKEM1024_CIPHERTEXT_BYTES: usize = 1568;
/// Bytes of overhead (the authentication tag) that `aead_encrypt` adds to
/// its plaintext.
const AEAD_TAG_BYTES: usize = 16;
/// Size in bytes of a wrapped file key: the file key plus the AEAD tag.
const ENCRYPTED_FILE_KEY_BYTES: usize = FILE_KEY_BYTES + AEAD_TAG_BYTES;
/// Creates a liboqs ML-KEM-1024 handle, initializing liboqs first.
///
/// # Panics
///
/// Panics if the linked liboqs build does not provide ML-KEM-1024.
fn mlkem() -> Kem {
    // NOTE(review): runs on every call; relies on `oqs::init` being safe to
    // invoke repeatedly — confirm against the `oqs` crate.
    oqs::init();
    let algorithm = Algorithm::MlKem1024;
    Kem::new(algorithm).expect("ML-KEM-1024 algorithm available")
}
/// Derives the 32-byte key that wraps the file key inside one ML-KEM stanza.
///
/// HKDF is run with the ML-KEM shared secret as input keying material, the
/// recipient public key followed by the KEM ciphertext as salt, and
/// [`MLKEM1024_RECIPIENT_KEY_LABEL`] as the label. The derived key therefore
/// depends on both public values as well as the shared secret.
fn derive_wrap_key(
    shared_secret: &SharedSecret,
    public_key: &[u8],
    ciphertext: &[u8],
) -> Zeroizing<[u8; 32]> {
    // Salt layout is `public_key || ciphertext`; wrapping and unwrapping must
    // agree on this order. The buffer is wiped when it goes out of scope.
    let salt = Zeroizing::new([public_key, ciphertext].concat());
    let derived = hkdf(
        salt.as_slice(),
        MLKEM1024_RECIPIENT_KEY_LABEL,
        shared_secret.as_ref(),
    );
    Zeroizing::new(derived)
}
/// Builds an [`io::Error`] of kind [`io::ErrorKind::InvalidData`] whose
/// message is `msg`.
fn invalid_data(msg: &str) -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, msg)
}
#[derive(Clone)]
pub struct Identity {
secret_key: Zeroizing<Vec<u8>>,
public_key: Vec<u8>,
}
impl Identity {
pub fn generate() -> Self {
let kem = mlkem();
let (pk, sk) = kem.keypair().expect("ML-KEM keypair");
Self {
secret_key: Zeroizing::new(sk.as_ref().to_vec()),
public_key: pk.as_ref().to_vec(),
}
}
pub fn to_string(&self) -> SecretString {
let mut material =
Vec::with_capacity(MLKEM1024_SECRET_KEY_BYTES + MLKEM1024_PUBLIC_KEY_BYTES);
material.extend_from_slice(self.secret_key.as_ref());
material.extend_from_slice(&self.public_key);
let encoded = bech32::encode(SECRET_KEY_PREFIX, material.to_base32(), Variant::Bech32)
.expect("valid HRP");
material.zeroize();
SecretString::from(encoded.to_uppercase())
}
pub fn to_public(&self) -> Recipient {
Recipient {
public_key: self.public_key.clone(),
}
}
pub(crate) fn decapsulate(&self, ct: &[u8; 1568]) -> Result<[u8; 32], DecryptError> {
let kem = mlkem();
let sk = kem
.secret_key_from_bytes(self.secret_key.as_ref())
.ok_or(DecryptError::InvalidHeader)?;
let ciphertext = kem
.ciphertext_from_bytes(ct)
.ok_or(DecryptError::InvalidHeader)?;
let shared_secret = kem
.decapsulate(sk, ciphertext)
.map_err(|_| DecryptError::DecryptionFailed)?;
let mut ss_bytes = [0u8; 32];
ss_bytes.copy_from_slice(shared_secret.as_ref());
Ok(ss_bytes)
}
}
/// Writes the Bech32-encoded identity produced by [`Identity::to_string`].
///
/// NOTE(review): this exposes secret key material through any `{}`
/// formatting; prefer the inherent `to_string`, which returns a `SecretString`.
impl fmt::Display for Identity {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Inherent methods take precedence, so this is `Identity::to_string`,
        // not the blanket `ToString` (which would recurse into this impl).
        let encoded = self.to_string();
        f.write_str(encoded.expose_secret())
    }
}
impl std::str::FromStr for Identity {
    type Err = &'static str;

    /// Parses an identity from the Bech32 form produced by [`Identity::to_string`].
    ///
    /// The human-readable part is compared ASCII case-insensitively. The
    /// payload must be exactly an ML-KEM-1024 secret key followed by an
    /// ML-KEM-1024 public key. The public key must also equal the copy embedded
    /// in the secret key, so an inconsistent identity is rejected here instead
    /// of silently failing to decrypt later. The decoded payload is zeroized on
    /// every return path.
    ///
    /// # Errors
    ///
    /// `"invalid Bech32 encoding"`, `"incorrect HRP"`,
    /// `"incorrect identity length"`, or `"public key does not match secret key"`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // FIPS 203 lays out the decapsulation key as dk_PKE || ek || H(ek) || z,
        // where H(ek) and z are 32 bytes each. The encapsulation (public) key
        // therefore ends 64 bytes before the end of the secret key.
        const HASH_AND_Z_BYTES: usize = 64;
        const EMBEDDED_PK_START: usize =
            MLKEM1024_SECRET_KEY_BYTES - MLKEM1024_PUBLIC_KEY_BYTES - HASH_AND_Z_BYTES;

        let (hrp, bytes) = parse_bech32(s).ok_or("invalid Bech32 encoding")?;
        // The payload contains the secret key; wipe it however we return.
        let bytes = Zeroizing::new(bytes);
        if !hrp.eq_ignore_ascii_case(SECRET_KEY_PREFIX) {
            return Err("incorrect HRP");
        }
        if bytes.len() != MLKEM1024_SECRET_KEY_BYTES + MLKEM1024_PUBLIC_KEY_BYTES {
            return Err("incorrect identity length");
        }
        let (secret_key, public_key) = bytes.split_at(MLKEM1024_SECRET_KEY_BYTES);
        let embedded_pk =
            &secret_key[EMBEDDED_PK_START..EMBEDDED_PK_START + MLKEM1024_PUBLIC_KEY_BYTES];
        if embedded_pk != public_key {
            return Err("public key does not match secret key");
        }
        Ok(Self {
            secret_key: Zeroizing::new(secret_key.to_vec()),
            public_key: public_key.to_vec(),
        })
    }
}
impl crate::Identity for Identity {
    /// Tries to recover the file key from an `MLKEM-1024` stanza.
    ///
    /// Returns:
    /// * `None` if the stanza has a different tag, or if the wrapped file key
    ///   fails to authenticate under the derived key (the stanza was wrapped to
    ///   some other identity).
    /// * `Some(Err(DecryptError::InvalidHeader))` if the stanza carries any
    ///   arguments, its body has the wrong length, or liboqs rejects the secret
    ///   key, the ciphertext, or the decapsulation.
    /// * `Some(Ok(file_key))` on success.
    fn unwrap_stanza(&self, stanza: &Stanza) -> Option<Result<FileKey, DecryptError>> {
        if stanza.tag != MLKEM1024_RECIPIENT_TAG {
            return None;
        }
        // `wrap_file_key` emits this stanza with no arguments, so any arguments
        // indicate a malformed header rather than a stanza for another key.
        if !stanza.args.is_empty() {
            return Some(Err(DecryptError::InvalidHeader));
        }
        if stanza.body.len() != MLKEM1024_CIPHERTEXT_BYTES + ENCRYPTED_FILE_KEY_BYTES {
            return Some(Err(DecryptError::InvalidHeader));
        }
        // Body layout: ML-KEM ciphertext || AEAD-encrypted file key.
        let (ct_bytes, encrypted_file_key) = stanza.body.split_at(MLKEM1024_CIPHERTEXT_BYTES);
        let kem = mlkem();
        let sk = match kem.secret_key_from_bytes(self.secret_key.as_ref()) {
            Some(sk) => sk,
            None => return Some(Err(DecryptError::InvalidHeader)),
        };
        let ciphertext = match kem.ciphertext_from_bytes(ct_bytes) {
            Some(ct) => ct,
            None => return Some(Err(DecryptError::InvalidHeader)),
        };
        let shared_secret = match kem.decapsulate(sk, ciphertext) {
            Ok(ss) => ss,
            Err(_) => return Some(Err(DecryptError::InvalidHeader)),
        };
        let wrap_key = derive_wrap_key(&shared_secret, &self.public_key, ct_bytes);
        // An AEAD failure maps to `None` ("not for this identity"). Decapsulating
        // with a non-matching ML-KEM key still succeeds but yields an unrelated
        // shared secret (implicit rejection), so the AEAD check is where a
        // foreign stanza is detected.
        aead_decrypt(&wrap_key, FILE_KEY_BYTES, encrypted_file_key)
            .ok()
            .map(|mut plaintext| {
                Ok(FileKey::init_with_mut(|file_key| {
                    file_key.copy_from_slice(&plaintext);
                    plaintext.zeroize();
                }))
            })
    }
}
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Recipient {
public_key: Vec<u8>,
}
impl Recipient {
fn ensure_length(bytes: &[u8]) -> Result<(), &'static str> {
if bytes.len() == MLKEM1024_PUBLIC_KEY_BYTES {
Ok(())
} else {
Err("incorrect pubkey length")
}
}
pub(crate) fn encapsulate<R: rand::Rng + rand::CryptoRng>(
&self,
_rng: &mut R,
) -> Result<([u8; 1568], Zeroizing<Vec<u8>>), EncryptError> {
let kem = mlkem();
let pk = kem
.public_key_from_bytes(&self.public_key)
.ok_or_else(|| EncryptError::Io(invalid_data("invalid ML-KEM public key")))?;
let (ciphertext, shared_secret) = kem.encapsulate(pk).map_err(|_| {
EncryptError::Io(invalid_data("failed to encapsulate ML-KEM shared secret"))
})?;
let mut ct_bytes = [0u8; 1568];
ct_bytes.copy_from_slice(ciphertext.as_ref());
Ok((ct_bytes, Zeroizing::new(shared_secret.as_ref().to_vec())))
}
}
impl fmt::Display for Recipient {
    /// Writes the recipient as a Bech32 string with the `anubis1mlkem1`
    /// human-readable part and the raw public key as payload.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let data = self.public_key.to_base32();
        let encoded =
            bech32::encode(PUBLIC_KEY_PREFIX, data, Variant::Bech32).expect("valid HRP");
        f.write_str(&encoded)
    }
}
/// Debug output is the same Bech32 string as the `Display` form; a recipient
/// is a public key, so nothing needs redacting.
impl fmt::Debug for Recipient {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}
impl std::str::FromStr for Recipient {
    type Err = &'static str;

    /// Parses a recipient from its Bech32 encoding.
    ///
    /// The human-readable part must equal `anubis1mlkem1` (ASCII
    /// case-insensitively) and the payload must be exactly one ML-KEM-1024
    /// public key.
    ///
    /// # Errors
    ///
    /// `"invalid Bech32 encoding"`, `"incorrect HRP"`, or
    /// `"incorrect pubkey length"`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (hrp, public_key) = match parse_bech32(s) {
            Some(decoded) => decoded,
            None => return Err("invalid Bech32 encoding"),
        };
        if hrp.eq_ignore_ascii_case(PUBLIC_KEY_PREFIX) {
            Self::ensure_length(&public_key).map(|()| Self { public_key })
        } else {
            Err("incorrect HRP")
        }
    }
}
impl crate::Recipient for Recipient {
    /// Wraps `file_key` to this recipient as a single `MLKEM-1024` stanza.
    ///
    /// A fresh shared secret is encapsulated to the public key, and a wrap key is
    /// derived from it with `derive_wrap_key`. The file key is then sealed with
    /// `aead_encrypt`. The stanza has no arguments; its body is the ML-KEM
    /// ciphertext followed by the encrypted file key. The returned label set is
    /// `{"postquantum", "nist-level-5"}`.
    ///
    /// # Errors
    ///
    /// Returns [`EncryptError::Io`] with kind `InvalidData` if liboqs rejects the
    /// stored public key or encapsulation fails.
    fn wrap_file_key(
        &self,
        file_key: &FileKey,
    ) -> Result<(Vec<Stanza>, HashSet<String>), EncryptError> {
        let kem = mlkem();
        let pk = match kem.public_key_from_bytes(&self.public_key) {
            Some(pk) => pk,
            None => return Err(EncryptError::Io(invalid_data("invalid ML-KEM public key"))),
        };
        let (ciphertext, shared_secret) = match kem.encapsulate(pk) {
            Ok(encapsulated) => encapsulated,
            Err(_) => {
                return Err(EncryptError::Io(invalid_data(
                    "failed to encapsulate ML-KEM shared secret",
                )))
            }
        };
        let ct: &[u8] = ciphertext.as_ref();

        let wrap_key = derive_wrap_key(&shared_secret, &self.public_key, ct);
        let encrypted_file_key = aead_encrypt(&wrap_key, file_key.expose_secret());

        let stanza = Stanza {
            tag: MLKEM1024_RECIPIENT_TAG.to_owned(),
            args: Vec::new(),
            // Body layout: ML-KEM ciphertext || AEAD-encrypted file key.
            body: [ct, &encrypted_file_key[..]].concat(),
        };
        let labels: HashSet<String> = ["postquantum", "nist-level-5"]
            .iter()
            .map(|&label| String::from(label))
            .collect();

        Ok((vec![stanza], labels))
    }
}
#[cfg(test)]
mod tests {
    use anubis_core::{format::FileKey, secrecy::ExposeSecret};

    use super::{Identity, Recipient};
    use crate::error::DecryptError;
    use crate::{Identity as _, Recipient as _};

    /// A file key wrapped to an identity's recipient unwraps to the same bytes
    /// and carries the expected labels.
    #[test]
    fn round_trip() {
        let identity = Identity::generate();
        let recipient = identity.to_public();
        let file_key = FileKey::new(Box::new([42; 16]));
        let (stanzas, labels) = recipient.wrap_file_key(&file_key).unwrap();
        assert!(labels.contains("postquantum"));
        assert!(labels.contains("nist-level-5"));
        assert_eq!(stanzas.len(), 1);
        let recovered = identity.unwrap_stanzas(&stanzas).unwrap().unwrap();
        assert_eq!(recovered.expose_secret(), file_key.expose_secret());
    }

    /// Identities and recipients survive a Bech32 encode/parse cycle.
    #[test]
    fn bech32_round_trip() {
        let identity = Identity::generate();
        let encoded = identity.to_string();
        let reparsed: Identity = encoded.expose_secret().parse().unwrap();
        assert_eq!(identity.public_key, reparsed.public_key);
        let recipient = identity.to_public();
        let encoded_recipient = recipient.to_string();
        let reparsed_recipient: Recipient = encoded_recipient.parse().unwrap();
        assert_eq!(recipient, reparsed_recipient);
    }

    /// A stanza wrapped to a different identity is reported as "not mine"
    /// (`None`) rather than as an error.
    #[test]
    fn other_identity_does_not_match() {
        let identity = Identity::generate();
        let other = Identity::generate();
        let file_key = FileKey::new(Box::new([42; 16]));
        let (stanzas, _) = identity.to_public().wrap_file_key(&file_key).unwrap();
        assert!(other.unwrap_stanzas(&stanzas).is_none());
    }

    /// A stanza whose body has the wrong length is a malformed header.
    #[test]
    fn truncated_body_is_invalid_header() {
        let identity = Identity::generate();
        let file_key = FileKey::new(Box::new([42; 16]));
        let (mut stanzas, _) = identity.to_public().wrap_file_key(&file_key).unwrap();
        stanzas[0].body.pop();
        assert!(matches!(
            identity.unwrap_stanzas(&stanzas),
            Some(Err(DecryptError::InvalidHeader))
        ));
    }
}