anubis-age 1.4.0

Post-quantum secure encryption library with hybrid X25519+ML-KEM-1024 mode (internal dependency for anubis-rage)
Documentation
//! ML-KEM-1024 implementation for post-quantum secure age encryption.
//!
//! This module implements NIST Level-5 post-quantum security using ML-KEM-1024
//! (formerly known as Kyber1024), as standardized in FIPS 203.

use std::collections::HashSet;
use std::fmt;
use std::io;

use anubis_core::{
    format::{FileKey, Stanza, FILE_KEY_BYTES},
    primitives::{aead_decrypt, aead_encrypt, hkdf},
    secrecy::{ExposeSecret, SecretString},
};
use bech32::{ToBase32, Variant};
use oqs::kem::{Algorithm, Kem, SharedSecret};
use zeroize::{Zeroize, Zeroizing};

use crate::{
    error::{DecryptError, EncryptError},
    util::parse_bech32,
};

/// Bech32 human-readable part of a serialized [`Identity`]; `Identity::to_string` emits the
/// whole encoding upper-cased, and parsing compares it case-insensitively.
const SECRET_KEY_PREFIX: &str = "ANUBIS-MLKEM-1024-SECRET";
/// Bech32 human-readable part of a serialized [`Recipient`]. Note it itself contains `1`
/// characters; Bech32 decoders split the HRP from the data at the *last* `1`.
const PUBLIC_KEY_PREFIX: &str = "anubis1mlkem1";

/// The stanza tag for ML-KEM-1024 recipients.
pub const MLKEM1024_RECIPIENT_TAG: &str = "MLKEM-1024";
/// Label passed to `hkdf` when deriving the key that wraps the file key.
const MLKEM1024_RECIPIENT_KEY_LABEL: &[u8] = b"anubis-encryption.org/v1/MLKEM-1024";

/// The size in bytes of an ML-KEM-1024 public key.
pub const MLKEM1024_PUBLIC_KEY_BYTES: usize = 1568;
/// The size in bytes of an ML-KEM-1024 secret key.
pub const MLKEM1024_SECRET_KEY_BYTES: usize = 3168;
/// The size in bytes of an ML-KEM-1024 ciphertext.
pub const MLKEM1024_CIPHERTEXT_BYTES: usize = 1568;
/// Length of the wrapped file key inside a stanza body: the file key plus 16 extra bytes
/// (presumably the AEAD authentication tag appended by `aead_encrypt` — confirm in anubis_core).
const ENCRYPTED_FILE_KEY_BYTES: usize = FILE_KEY_BYTES + 16;

/// Builds a liboqs KEM handle configured for ML-KEM-1024.
///
/// `oqs::init()` is called first on every invocation so callers never have to remember it.
///
/// # Panics
///
/// Panics if liboqs was built without ML-KEM-1024 support.
fn mlkem() -> Kem {
    oqs::init();
    let algorithm = Algorithm::MlKem1024;
    Kem::new(algorithm).expect("ML-KEM-1024 algorithm available")
}

/// Derives the 32-byte key used to AEAD-wrap the file key.
///
/// The HKDF salt is `public_key || ciphertext`, which binds the wrap key to this exact
/// recipient and encapsulation. The ML-KEM shared secret is the input keying material, and
/// `MLKEM1024_RECIPIENT_KEY_LABEL` is the label. The temporary salt buffer is wiped before
/// returning, and the derived key is wiped when the returned `Zeroizing` is dropped.
fn derive_wrap_key(
    shared_secret: &SharedSecret,
    public_key: &[u8],
    ciphertext: &[u8],
) -> Zeroizing<[u8; 32]> {
    let mut salt: Vec<u8> = public_key.iter().chain(ciphertext).copied().collect();
    let ikm: &[u8] = shared_secret.as_ref();
    let wrap_key = Zeroizing::new(hkdf(&salt, MLKEM1024_RECIPIENT_KEY_LABEL, ikm));
    salt.zeroize();
    wrap_key
}

/// Builds an [`io::Error`] of kind [`io::ErrorKind::InvalidData`] carrying `msg`.
fn invalid_data(msg: &str) -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, String::from(msg))
}

/// An ML-KEM-1024 identity for decrypting age files.
///
/// This provides NIST Level-5 post-quantum security.
///
/// Only `Clone` is derived (no `Debug`). Note, however, that the `Display` impl in this
/// module prints the full secret-key encoding.
#[derive(Clone)]
pub struct Identity {
    /// Raw ML-KEM-1024 secret (decapsulation) key bytes; wiped on drop via `Zeroizing`.
    secret_key: Zeroizing<Vec<u8>>,
    /// Raw ML-KEM-1024 public (encapsulation) key bytes. Used to build the recipient and
    /// mixed into the wrap-key HKDF salt when unwrapping stanzas.
    public_key: Vec<u8>,
}

impl Identity {
    /// Generates a new ML-KEM-1024 identity.
    pub fn generate() -> Self {
        let kem = mlkem();
        let (pk, sk) = kem.keypair().expect("ML-KEM keypair");
        Self {
            secret_key: Zeroizing::new(sk.as_ref().to_vec()),
            public_key: pk.as_ref().to_vec(),
        }
    }

    /// Serializes this identity to a string.
    ///
    /// Returns a Bech32-encoded string starting with `ANUBIS-MLKEM-1024-SECRET`.
    pub fn to_string(&self) -> SecretString {
        let mut material =
            Vec::with_capacity(MLKEM1024_SECRET_KEY_BYTES + MLKEM1024_PUBLIC_KEY_BYTES);
        material.extend_from_slice(self.secret_key.as_ref());
        material.extend_from_slice(&self.public_key);

        let encoded = bech32::encode(SECRET_KEY_PREFIX, material.to_base32(), Variant::Bech32)
            .expect("valid HRP");

        material.zeroize();

        SecretString::from(encoded.to_uppercase())
    }

    /// Returns the recipient corresponding to this identity.
    pub fn to_public(&self) -> Recipient {
        Recipient {
            public_key: self.public_key.clone(),
        }
    }

    /// Decapsulates an ML-KEM ciphertext. For use in hybrid mode.
    pub(crate) fn decapsulate(&self, ct: &[u8; 1568]) -> Result<[u8; 32], DecryptError> {
        let kem = mlkem();
        let sk = kem
            .secret_key_from_bytes(self.secret_key.as_ref())
            .ok_or(DecryptError::InvalidHeader)?;
        let ciphertext = kem
            .ciphertext_from_bytes(ct)
            .ok_or(DecryptError::InvalidHeader)?;
        let shared_secret = kem
            .decapsulate(sk, ciphertext)
            .map_err(|_| DecryptError::DecryptionFailed)?;

        let mut ss_bytes = [0u8; 32];
        ss_bytes.copy_from_slice(shared_secret.as_ref());
        Ok(ss_bytes)
    }
}

impl fmt::Display for Identity {
    /// Writes the Bech32 secret-key encoding produced by the inherent `Identity::to_string`.
    ///
    /// Beware: this prints the secret key in plain text.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = self.to_string();
        f.write_str(encoded.expose_secret())
    }
}

impl std::str::FromStr for Identity {
    type Err = &'static str;

    /// Parses an identity in the format produced by `Identity::to_string`.
    ///
    /// The Bech32 payload must be `secret_key || public_key`. The HRP is compared
    /// case-insensitively.
    ///
    /// # Errors
    ///
    /// Returns one of the following messages:
    /// - `"invalid Bech32 encoding"`
    /// - `"incorrect HRP"`
    /// - `"incorrect identity length"`
    /// - `"public key does not match secret key"`, when the trailing public key differs from
    ///   the copy embedded in the secret key.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // FIPS 203 encodes an ML-KEM decapsulation key as dk_PKE || ek || H(ek) || z, where
        // |dk_PKE| = 384 * k bytes and k = 4 for ML-KEM-1024. This is the offset of ek.
        const EMBEDDED_PUBLIC_KEY_OFFSET: usize = 384 * 4;

        let (hrp, bytes) = parse_bech32(s).ok_or("invalid Bech32 encoding")?;
        // The decoded payload contains the raw secret key: wipe it on every return path.
        let bytes = Zeroizing::new(bytes);
        if !hrp.eq_ignore_ascii_case(SECRET_KEY_PREFIX) {
            return Err("incorrect HRP");
        }
        if bytes.len() != MLKEM1024_SECRET_KEY_BYTES + MLKEM1024_PUBLIC_KEY_BYTES {
            return Err("incorrect identity length");
        }

        let (sk_bytes, pk_bytes) = bytes.split_at(MLKEM1024_SECRET_KEY_BYTES);

        // Reject identities whose stored public key disagrees with the secret key. Otherwise
        // `to_public()` would hand out a recipient this identity cannot decrypt for, and
        // unwrapping would derive wrap keys from the wrong salt.
        let embedded_pk = &sk_bytes
            [EMBEDDED_PUBLIC_KEY_OFFSET..EMBEDDED_PUBLIC_KEY_OFFSET + MLKEM1024_PUBLIC_KEY_BYTES];
        if embedded_pk != pk_bytes {
            return Err("public key does not match secret key");
        }

        Ok(Self {
            secret_key: Zeroizing::new(sk_bytes.to_vec()),
            public_key: pk_bytes.to_vec(),
        })
    }
}

impl crate::Identity for Identity {
    /// Attempts to recover the file key from an `MLKEM-1024` stanza.
    ///
    /// Returns:
    /// - `None` if the stanza carries a different tag, or if the wrapped file key fails AEAD
    ///   decryption. With ML-KEM's implicit rejection, the latter is what happens when the
    ///   stanza was made for a different key.
    /// - `Some(Err(DecryptError::InvalidHeader))` if the stanza is malformed (unexpected
    ///   arguments or wrong body length), or if liboqs rejects the key, ciphertext or
    ///   decapsulation.
    /// - `Some(Ok(file_key))` on success.
    fn unwrap_stanza(&self, stanza: &Stanza) -> Option<Result<FileKey, DecryptError>> {
        if stanza.tag != MLKEM1024_RECIPIENT_TAG {
            return None;
        }

        // `Recipient::wrap_file_key` never emits stanza arguments, and the body is exactly
        // `ciphertext || encrypted_file_key`. Anything else is a malformed header. Rejecting
        // stray arguments keeps header parsing strict instead of silently ignoring them.
        if !stanza.args.is_empty()
            || stanza.body.len() != MLKEM1024_CIPHERTEXT_BYTES + ENCRYPTED_FILE_KEY_BYTES
        {
            return Some(Err(DecryptError::InvalidHeader));
        }

        let (ct_bytes, encrypted_file_key) = stanza.body.split_at(MLKEM1024_CIPHERTEXT_BYTES);

        let kem = mlkem();
        let sk = match kem.secret_key_from_bytes(self.secret_key.as_ref()) {
            Some(sk) => sk,
            None => return Some(Err(DecryptError::InvalidHeader)),
        };
        let ciphertext = match kem.ciphertext_from_bytes(ct_bytes) {
            Some(ct) => ct,
            None => return Some(Err(DecryptError::InvalidHeader)),
        };
        let shared_secret = match kem.decapsulate(sk, ciphertext) {
            Ok(ss) => ss,
            Err(_) => return Some(Err(DecryptError::InvalidHeader)),
        };

        let wrap_key = derive_wrap_key(&shared_secret, &self.public_key, ct_bytes);

        // An AEAD failure means "not for this identity", so it maps to `None` rather than
        // an error.
        aead_decrypt(&wrap_key, FILE_KEY_BYTES, encrypted_file_key)
            .ok()
            .map(|mut plaintext| {
                Ok(FileKey::init_with_mut(|file_key| {
                    file_key.copy_from_slice(&plaintext);
                    plaintext.zeroize();
                }))
            })
    }
}

/// An ML-KEM-1024 recipient for encrypting age files.
///
/// This provides NIST Level-5 post-quantum security.
///
/// Equality and hashing compare the raw public-key bytes.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct Recipient {
    /// Raw ML-KEM-1024 public (encapsulation) key bytes. The `FromStr` impl checks that they
    /// are `MLKEM1024_PUBLIC_KEY_BYTES` long.
    public_key: Vec<u8>,
}

impl Recipient {
    fn ensure_length(bytes: &[u8]) -> Result<(), &'static str> {
        if bytes.len() == MLKEM1024_PUBLIC_KEY_BYTES {
            Ok(())
        } else {
            Err("incorrect pubkey length")
        }
    }

    /// Encapsulates a shared secret to this recipient. For use in hybrid mode.
    pub(crate) fn encapsulate<R: rand::Rng + rand::CryptoRng>(
        &self,
        _rng: &mut R,
    ) -> Result<([u8; 1568], Zeroizing<Vec<u8>>), EncryptError> {
        let kem = mlkem();
        let pk = kem
            .public_key_from_bytes(&self.public_key)
            .ok_or_else(|| EncryptError::Io(invalid_data("invalid ML-KEM public key")))?;
        let (ciphertext, shared_secret) = kem.encapsulate(pk).map_err(|_| {
            EncryptError::Io(invalid_data("failed to encapsulate ML-KEM shared secret"))
        })?;

        let mut ct_bytes = [0u8; 1568];
        ct_bytes.copy_from_slice(ciphertext.as_ref());
        Ok((ct_bytes, Zeroizing::new(shared_secret.as_ref().to_vec())))
    }
}

impl fmt::Display for Recipient {
    /// Writes the recipient as a lower-case Bech32 string with the `anubis1mlkem1` HRP.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let data = self.public_key.to_base32();
        let encoded =
            bech32::encode(PUBLIC_KEY_PREFIX, data, Variant::Bech32).expect("valid HRP");
        f.write_str(&encoded)
    }
}

impl fmt::Debug for Recipient {
    /// Debug output is identical to the Bech32 `Display` form. A recipient is public data,
    /// so nothing needs redacting.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}

impl std::str::FromStr for Recipient {
    type Err = &'static str;

    /// Parses a Bech32 recipient string.
    ///
    /// The HRP is compared case-insensitively against `anubis1mlkem1`, and the decoded payload
    /// must be exactly one ML-KEM-1024 public key.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let (hrp, public_key) = parse_bech32(s).ok_or("invalid Bech32 encoding")?;
        if hrp.eq_ignore_ascii_case(PUBLIC_KEY_PREFIX) {
            Self::ensure_length(&public_key).map(|()| Self { public_key })
        } else {
            Err("incorrect HRP")
        }
    }
}

impl crate::Recipient for Recipient {
    /// Wraps `file_key` for this recipient.
    ///
    /// Produces a single `MLKEM-1024` stanza with no arguments. Its body is the ML-KEM
    /// ciphertext followed by the file key encrypted under a wrap key derived from the
    /// encapsulated shared secret.
    ///
    /// The returned label set is `{"postquantum", "nist-level-5"}`.
    ///
    /// # Errors
    ///
    /// Returns `EncryptError::Io` (kind `InvalidData`) if liboqs rejects the public key or
    /// encapsulation fails.
    fn wrap_file_key(
        &self,
        file_key: &FileKey,
    ) -> Result<(Vec<Stanza>, HashSet<String>), EncryptError> {
        let kem = mlkem();

        let (ciphertext, shared_secret) = kem
            .public_key_from_bytes(&self.public_key)
            .ok_or_else(|| EncryptError::Io(invalid_data("invalid ML-KEM public key")))
            .and_then(|pk| {
                kem.encapsulate(pk).map_err(|_| {
                    EncryptError::Io(invalid_data("failed to encapsulate ML-KEM shared secret"))
                })
            })?;

        let ct: &[u8] = ciphertext.as_ref();
        let wrap_key = derive_wrap_key(&shared_secret, &self.public_key, ct);

        let mut body = Vec::with_capacity(MLKEM1024_CIPHERTEXT_BYTES + ENCRYPTED_FILE_KEY_BYTES);
        body.extend_from_slice(ct);
        body.extend_from_slice(&aead_encrypt(&wrap_key, file_key.expose_secret()));

        let stanza = Stanza {
            tag: MLKEM1024_RECIPIENT_TAG.to_owned(),
            args: Vec::new(),
            body,
        };
        let labels: HashSet<String> = ["postquantum", "nist-level-5"]
            .iter()
            .map(|label| label.to_string())
            .collect();

        Ok((vec![stanza], labels))
    }
}

#[cfg(test)]
mod tests {
    use anubis_core::{
        format::{FileKey, Stanza},
        secrecy::ExposeSecret,
    };

    use super::{Identity, Recipient, MLKEM1024_RECIPIENT_TAG};
    use crate::error::DecryptError;
    use crate::{Identity as _, Recipient as _};

    /// Wrapping then unwrapping with the matching identity recovers the file key.
    #[test]
    fn round_trip() {
        let identity = Identity::generate();
        let recipient = identity.to_public();
        let file_key = FileKey::new(Box::new([42; 16]));

        let (stanzas, labels) = recipient.wrap_file_key(&file_key).unwrap();
        assert!(labels.contains("postquantum"));
        assert!(labels.contains("nist-level-5"));
        assert_eq!(stanzas.len(), 1);

        let recovered = identity.unwrap_stanzas(&stanzas).unwrap().unwrap();
        assert_eq!(recovered.expose_secret(), file_key.expose_secret());
    }

    /// Stanzas with a different tag are not ours and must be skipped (`None`).
    #[test]
    fn ignores_foreign_stanza_tags() {
        let identity = Identity::generate();
        let stanza = Stanza {
            tag: "X25519".to_owned(),
            args: vec![],
            body: vec![],
        };
        assert!(identity.unwrap_stanza(&stanza).is_none());
    }

    /// A body of the wrong length is a malformed header, not a silent skip.
    #[test]
    fn rejects_truncated_body() {
        let identity = Identity::generate();
        let file_key = FileKey::new(Box::new([7; 16]));
        let (stanzas, _) = identity.to_public().wrap_file_key(&file_key).unwrap();

        let body = &stanzas[0].body;
        let truncated = Stanza {
            tag: MLKEM1024_RECIPIENT_TAG.to_owned(),
            args: vec![],
            body: body[..body.len() - 1].to_vec(),
        };
        assert!(matches!(
            identity.unwrap_stanza(&truncated),
            Some(Err(DecryptError::InvalidHeader))
        ));
    }

    /// A stanza made for one key must not unwrap under another key.
    #[test]
    fn other_identity_cannot_unwrap() {
        let recipient = Identity::generate().to_public();
        let other = Identity::generate();
        let file_key = FileKey::new(Box::new([9; 16]));

        let (stanzas, _) = recipient.wrap_file_key(&file_key).unwrap();
        assert!(other.unwrap_stanza(&stanzas[0]).is_none());
    }

    /// Identities and recipients survive a Bech32 encode/parse round trip unchanged.
    #[test]
    fn bech32_round_trip() {
        let identity = Identity::generate();
        let encoded = identity.to_string();
        let reparsed: Identity = encoded.expose_secret().parse().unwrap();
        assert_eq!(identity.public_key, reparsed.public_key);
        assert_eq!(
            identity.secret_key.as_slice(),
            reparsed.secret_key.as_slice()
        );

        let recipient = identity.to_public();
        let encoded_recipient = recipient.to_string();
        let reparsed_recipient: Recipient = encoded_recipient.parse().unwrap();
        assert_eq!(recipient, reparsed_recipient);
    }
}