rustpq 0.3.0

Pure Rust post-quantum cryptography suite - ML-KEM, ML-DSA, and more
Documentation
use super::error::Error;
use super::indcpa;
use super::params::{self, POLYBYTES, SYMBYTES};
use super::symmetric::{hash_g, hash_h, kdf};
use rand_core::CryptoRngCore;
use subtle::{ConditionallySelectable, ConstantTimeEq};
use zeroize::{Zeroize, Zeroizing};

/// Generates an ML-KEM parameter-set module named `$mod_name`.
///
/// Arguments, in order: the module name; the parameters `K`, `ETA1`, `ETA2`, `DU`, `DV`
/// (forwarded as const generic arguments to the `indcpa` routines, which is why the
/// invocations wrap them in braces); and the public-key, secret-key and ciphertext sizes
/// in bytes.
///
/// Every secret intermediate (seeds, messages, `G` outputs, the re-encryption buffer and
/// the implicit-rejection input/output) is held in `Zeroizing` so that it is wiped when
/// it goes out of scope, matching the zeroize-on-drop behaviour of `SecretKey` and
/// `SharedSecret`.
macro_rules! impl_kem {
    ($mod_name:ident, $K:expr, $ETA1:expr, $ETA2:expr, $DU:expr, $DV:expr,
     $PK_BYTES:expr, $SK_BYTES:expr, $CT_BYTES:expr) => {
        /// ML-KEM key encapsulation for a single parameter set.
        pub mod $mod_name {
            use super::*;

            /// Length in bytes of an encoded public (encapsulation) key.
            pub const PUBLIC_KEY_BYTES: usize = $PK_BYTES;
            /// Length in bytes of an encoded secret (decapsulation) key.
            pub const SECRET_KEY_BYTES: usize = $SK_BYTES;
            /// Length in bytes of an encoded ciphertext.
            pub const CIPHERTEXT_BYTES: usize = $CT_BYTES;
            /// Length in bytes of the shared secret.
            pub const SHARED_SECRET_BYTES: usize = 32;

            // Length of the IND-CPA secret key that prefixes the KEM secret key.
            // Full secret-key layout: indcpa_sk || pk || H(pk) || z.
            const INDCPA_SECRET_KEY_BYTES: usize = $K * POLYBYTES;

            /// An encoded public (encapsulation) key.
            #[derive(Clone)]
            pub struct PublicKey {
                bytes: [u8; PUBLIC_KEY_BYTES],
            }

            impl PublicKey {
                /// Returns the encoded key bytes.
                pub fn as_bytes(&self) -> &[u8] {
                    &self.bytes
                }

                /// Wraps an encoded public key.
                ///
                /// # Errors
                ///
                /// Returns `Error::InvalidPublicKeyLength` if `bytes` is not exactly
                /// `PUBLIC_KEY_BYTES` long. Only the length is checked; the encoded
                /// coefficients are not validated here.
                pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
                    if bytes.len() != PUBLIC_KEY_BYTES {
                        return Err(Error::InvalidPublicKeyLength);
                    }
                    let mut pk = Self {
                        bytes: [0u8; PUBLIC_KEY_BYTES],
                    };
                    pk.bytes.copy_from_slice(bytes);
                    Ok(pk)
                }
            }

            /// An encoded secret (decapsulation) key; wiped on drop.
            #[derive(Clone)]
            pub struct SecretKey {
                bytes: [u8; SECRET_KEY_BYTES],
            }

            impl Zeroize for SecretKey {
                fn zeroize(&mut self) {
                    self.bytes.zeroize();
                }
            }

            impl Drop for SecretKey {
                fn drop(&mut self) {
                    self.zeroize();
                }
            }

            impl SecretKey {
                /// Returns the encoded key bytes (`indcpa_sk || pk || H(pk) || z`).
                pub fn as_bytes(&self) -> &[u8] {
                    &self.bytes
                }

                /// Wraps an encoded secret key.
                ///
                /// # Errors
                ///
                /// Returns `Error::InvalidSecretKeyLength` if `bytes` is not exactly
                /// `SECRET_KEY_BYTES` long. Only the length is checked; the embedded
                /// `H(pk)` is not compared against the embedded public key here.
                pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
                    if bytes.len() != SECRET_KEY_BYTES {
                        return Err(Error::InvalidSecretKeyLength);
                    }
                    // Copy straight into the zeroize-on-drop wrapper so no extra
                    // owned copy of the key is created along the way.
                    let mut sk = Self {
                        bytes: [0u8; SECRET_KEY_BYTES],
                    };
                    sk.bytes.copy_from_slice(bytes);
                    Ok(sk)
                }
            }

            /// An encoded ciphertext.
            #[derive(Clone)]
            pub struct Ciphertext {
                bytes: [u8; CIPHERTEXT_BYTES],
            }

            impl Ciphertext {
                /// Returns the encoded ciphertext bytes.
                pub fn as_bytes(&self) -> &[u8] {
                    &self.bytes
                }

                /// Wraps an encoded ciphertext.
                ///
                /// # Errors
                ///
                /// Returns `Error::InvalidCiphertextLength` if `bytes` is not exactly
                /// `CIPHERTEXT_BYTES` long.
                pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
                    if bytes.len() != CIPHERTEXT_BYTES {
                        return Err(Error::InvalidCiphertextLength);
                    }
                    let mut ct = Self {
                        bytes: [0u8; CIPHERTEXT_BYTES],
                    };
                    ct.bytes.copy_from_slice(bytes);
                    Ok(ct)
                }
            }

            /// The 32-byte shared secret produced by encapsulation/decapsulation;
            /// wiped on drop.
            #[derive(Clone)]
            pub struct SharedSecret {
                bytes: [u8; SHARED_SECRET_BYTES],
            }

            impl Zeroize for SharedSecret {
                fn zeroize(&mut self) {
                    self.bytes.zeroize();
                }
            }

            impl Drop for SharedSecret {
                fn drop(&mut self) {
                    self.zeroize();
                }
            }

            impl SharedSecret {
                /// Returns the shared-secret bytes.
                pub fn as_bytes(&self) -> &[u8; SHARED_SECRET_BYTES] {
                    &self.bytes
                }
            }

            /// Generates a fresh key pair.
            ///
            /// Draws the 32-byte seeds `d` then `z` from `rng` and delegates to
            /// `generate_deterministic`. Both seeds are wiped afterwards.
            pub fn generate(rng: &mut impl CryptoRngCore) -> (PublicKey, SecretKey) {
                // The seeds fully determine the secret key, so they must not linger.
                let mut d = Zeroizing::new([0u8; SYMBYTES]);
                let mut z = Zeroizing::new([0u8; SYMBYTES]);
                rng.fill_bytes(&mut d[..]);
                rng.fill_bytes(&mut z[..]);
                generate_deterministic(&d, &z)
            }

            /// Derives a key pair deterministically from the seeds `d` and `z`.
            ///
            /// `d` seeds the IND-CPA key pair. `z` is stored at the end of the secret key
            /// and feeds implicit rejection in `decapsulate`. The caller is responsible
            /// for both seeds being secret and uniformly random; `generate` does this
            /// from an RNG.
            pub fn generate_deterministic(d: &[u8; 32], z: &[u8; 32]) -> (PublicKey, SecretKey) {
                let mut pk = PublicKey {
                    bytes: [0u8; PUBLIC_KEY_BYTES],
                };
                let mut sk = SecretKey {
                    bytes: [0u8; SECRET_KEY_BYTES],
                };

                let pk_end = INDCPA_SECRET_KEY_BYTES + PUBLIC_KEY_BYTES;
                let h_end = pk_end + SYMBYTES;

                indcpa::keypair::<$K, $ETA1>(
                    &mut pk.bytes,
                    &mut sk.bytes[..INDCPA_SECRET_KEY_BYTES],
                    d,
                );

                // Cache pk and H(pk) inside the secret key: decapsulation re-encrypts
                // under pk and hashes with H(pk) without the caller supplying them.
                sk.bytes[INDCPA_SECRET_KEY_BYTES..pk_end].copy_from_slice(&pk.bytes);
                let h = hash_h(&pk.bytes);
                sk.bytes[pk_end..h_end].copy_from_slice(&h);
                sk.bytes[h_end..].copy_from_slice(z);

                (pk, sk)
            }

            /// Encapsulates a fresh shared secret to `pk`.
            ///
            /// Draws the 32-byte message `m` from `rng` and delegates to
            /// `encapsulate_deterministic`. `m` is wiped afterwards.
            pub fn encapsulate(
                pk: &PublicKey,
                rng: &mut impl CryptoRngCore,
            ) -> (Ciphertext, SharedSecret) {
                // `m` determines the shared secret, so it must not linger.
                let mut m = Zeroizing::new([0u8; SYMBYTES]);
                rng.fill_bytes(&mut m[..]);
                encapsulate_deterministic(pk, &m)
            }

            /// Encapsulates deterministically using the caller-supplied message `m`.
            ///
            /// Computes `(K, r) = G(m || H(pk))`, encrypts `m` under `pk` with coins `r`,
            /// and returns the ciphertext together with `K` as the shared secret.
            pub fn encapsulate_deterministic(
                pk: &PublicKey,
                m: &[u8; 32],
            ) -> (Ciphertext, SharedSecret) {
                let mut ct = Ciphertext {
                    bytes: [0u8; CIPHERTEXT_BYTES],
                };

                let h = hash_h(&pk.bytes);

                // buf holds m and kr holds (K, r): all secret, so wipe on drop.
                let mut buf = Zeroizing::new([0u8; 2 * SYMBYTES]);
                buf[..SYMBYTES].copy_from_slice(m);
                buf[SYMBYTES..].copy_from_slice(&h);
                let kr = Zeroizing::new(hash_g(&*buf));

                indcpa::enc::<$K, $ETA1, $ETA2, $DU, $DV>(
                    &mut ct.bytes,
                    m,
                    &pk.bytes,
                    kr[SYMBYTES..]
                        .try_into()
                        .expect("G output holds 2 * SYMBYTES bytes"),
                );

                let mut ss = SharedSecret {
                    bytes: [0u8; SHARED_SECRET_BYTES],
                };
                ss.bytes.copy_from_slice(&kr[..SYMBYTES]);

                (ct, ss)
            }

            /// Recovers the shared secret carried by `ct`.
            ///
            /// Uses implicit rejection: the decrypted message is re-encrypted under the
            /// public key stored in `sk`, and the result is compared with `ct` in constant
            /// time. On mismatch, the returned secret is `kdf(z || ct)` instead of the real
            /// key, so an invalid ciphertext yields a value derived from `z` and `ct`
            /// rather than an error.
            pub fn decapsulate(sk: &SecretKey, ct: &Ciphertext) -> SharedSecret {
                let pk_end = INDCPA_SECRET_KEY_BYTES + PUBLIC_KEY_BYTES;
                let h_end = pk_end + SYMBYTES;

                let indcpa_sk = &sk.bytes[..INDCPA_SECRET_KEY_BYTES];
                let pk = &sk.bytes[INDCPA_SECRET_KEY_BYTES..pk_end];
                let h = &sk.bytes[pk_end..h_end];
                let z = &sk.bytes[h_end..h_end + SYMBYTES];

                let mut m_prime = Zeroizing::new([0u8; SYMBYTES]);
                indcpa::dec::<$K, $DU, $DV>(&mut m_prime, &ct.bytes, indcpa_sk);

                // (K', r') = G(m' || H(pk)); every buffer derived from m' is wiped on drop.
                let mut buf = Zeroizing::new([0u8; 2 * SYMBYTES]);
                buf[..SYMBYTES].copy_from_slice(&*m_prime);
                buf[SYMBYTES..].copy_from_slice(h);
                let kr = Zeroizing::new(hash_g(&*buf));

                let mut ct_cmp = Zeroizing::new([0u8; CIPHERTEXT_BYTES]);
                indcpa::enc::<$K, $ETA1, $ETA2, $DU, $DV>(
                    &mut *ct_cmp,
                    &m_prime,
                    pk,
                    kr[SYMBYTES..]
                        .try_into()
                        .expect("G output holds 2 * SYMBYTES bytes"),
                );

                let eq = ct.bytes.ct_eq(&*ct_cmp);

                // The rejection key is computed unconditionally so the timing does not
                // depend on whether the ciphertext was valid. kdf_input contains z.
                let mut kdf_input = Zeroizing::new([0u8; SYMBYTES + CIPHERTEXT_BYTES]);
                kdf_input[..SYMBYTES].copy_from_slice(z);
                kdf_input[SYMBYTES..].copy_from_slice(&ct.bytes);
                let k_fail = Zeroizing::new(kdf(&*kdf_input));

                let mut ss = SharedSecret {
                    bytes: [0u8; SHARED_SECRET_BYTES],
                };
                let k_good = &kr[..SHARED_SECRET_BYTES];
                let k_bad = &k_fail[..SHARED_SECRET_BYTES];
                // Select K' when eq == 1, the rejection key otherwise, without branching.
                for ((out, bad), good) in ss.bytes.iter_mut().zip(k_bad).zip(k_good) {
                    *out = u8::conditional_select(bad, good, eq);
                }

                ss
            }
        }
    };
}

// Instantiates `impl_kem!` for the parameter set whose constants live in
// `params::<name>`. The generated module shares that name.
macro_rules! impl_kem_params {
    ($set:ident) => {
        impl_kem!(
            $set,
            { params::$set::K },
            { params::$set::ETA1 },
            { params::$set::ETA2 },
            { params::$set::DU },
            { params::$set::DV },
            params::$set::PUBLICKEYBYTES,
            params::$set::SECRETKEYBYTES,
            params::$set::CIPHERTEXTBYTES
        );
    };
}

// ML-KEM-512 (`mlkem512` feature).
#[cfg(feature = "mlkem512")]
impl_kem_params!(mlkem512);

// ML-KEM-768 (`mlkem768` feature).
#[cfg(feature = "mlkem768")]
impl_kem_params!(mlkem768);

// ML-KEM-1024 (`mlkem1024` feature).
#[cfg(feature = "mlkem1024")]
impl_kem_params!(mlkem1024);

#[cfg(test)]
mod tests {
    #[test]
    #[cfg(feature = "mlkem768")]
    fn test_mlkem768_roundtrip() {
        use super::mlkem768::*;
        use rand::rngs::OsRng;

        let (pk, sk) = generate(&mut OsRng);
        let (ct, ss_sender) = encapsulate(&pk, &mut OsRng);
        let ss_receiver = decapsulate(&sk, &ct);

        assert_eq!(pk.as_bytes().len(), PUBLIC_KEY_BYTES);
        assert_eq!(sk.as_bytes().len(), SECRET_KEY_BYTES);
        assert_eq!(ct.as_bytes().len(), CIPHERTEXT_BYTES);
        assert_eq!(ss_sender.as_bytes(), ss_receiver.as_bytes());
    }

    #[test]
    #[cfg(feature = "mlkem512")]
    fn test_mlkem512_roundtrip() {
        use super::mlkem512::*;
        use rand::rngs::OsRng;

        let (pk, sk) = generate(&mut OsRng);
        let (ct, ss_sender) = encapsulate(&pk, &mut OsRng);
        let ss_receiver = decapsulate(&sk, &ct);

        assert_eq!(ss_sender.as_bytes(), ss_receiver.as_bytes());
    }

    #[test]
    #[cfg(feature = "mlkem1024")]
    fn test_mlkem1024_roundtrip() {
        use super::mlkem1024::*;
        use rand::rngs::OsRng;

        let (pk, sk) = generate(&mut OsRng);
        let (ct, ss_sender) = encapsulate(&pk, &mut OsRng);
        let ss_receiver = decapsulate(&sk, &ct);

        assert_eq!(ss_sender.as_bytes(), ss_receiver.as_bytes());
    }

    /// A tampered ciphertext must not yield the sender's secret, and the rejection
    /// key must be deterministic in (sk, ct).
    #[test]
    #[cfg(feature = "mlkem768")]
    fn test_implicit_rejection() {
        use super::mlkem768::*;
        use rand::rngs::OsRng;

        let (pk, sk) = generate(&mut OsRng);
        let (ct, ss_sender) = encapsulate(&pk, &mut OsRng);

        let mut bad_bytes = [0u8; CIPHERTEXT_BYTES];
        bad_bytes.copy_from_slice(ct.as_bytes());
        bad_bytes[0] ^= 0xFF;
        let bad_ct = Ciphertext::from_bytes(&bad_bytes).unwrap();

        let ss_bad = decapsulate(&sk, &bad_ct);
        assert_ne!(ss_sender.as_bytes(), ss_bad.as_bytes());

        let ss_bad_again = decapsulate(&sk, &bad_ct);
        assert_eq!(ss_bad.as_bytes(), ss_bad_again.as_bytes());
    }

    #[test]
    #[cfg(feature = "mlkem768")]
    fn test_deterministic_keygen() {
        use super::mlkem768::*;

        let d = [0u8; 32];
        let z = [1u8; 32];

        let (pk1, sk1) = generate_deterministic(&d, &z);
        let (pk2, sk2) = generate_deterministic(&d, &z);

        assert_eq!(pk1.as_bytes(), pk2.as_bytes());
        assert_eq!(sk1.as_bytes(), sk2.as_bytes());
        // The implicit-rejection seed z is stored as the trailing 32 bytes.
        assert_eq!(&sk1.as_bytes()[SECRET_KEY_BYTES - 32..], &z[..]);
    }

    #[test]
    #[cfg(feature = "mlkem768")]
    fn test_deterministic_encaps() {
        use super::mlkem768::*;

        let d = [0u8; 32];
        let z = [1u8; 32];
        let m = [2u8; 32];

        let (pk, _) = generate_deterministic(&d, &z);
        let (ct1, ss1) = encapsulate_deterministic(&pk, &m);
        let (ct2, ss2) = encapsulate_deterministic(&pk, &m);

        assert_eq!(ct1.as_bytes(), ct2.as_bytes());
        assert_eq!(ss1.as_bytes(), ss2.as_bytes());
    }

    #[test]
    #[cfg(feature = "mlkem768")]
    fn test_from_bytes_lengths() {
        use super::mlkem768::*;
        use super::Error;
        use rand::rngs::OsRng;

        assert!(matches!(
            PublicKey::from_bytes(&[0u8; PUBLIC_KEY_BYTES - 1]),
            Err(Error::InvalidPublicKeyLength)
        ));
        assert!(matches!(
            SecretKey::from_bytes(&[0u8; SECRET_KEY_BYTES + 1]),
            Err(Error::InvalidSecretKeyLength)
        ));
        assert!(matches!(
            Ciphertext::from_bytes(&[]),
            Err(Error::InvalidCiphertextLength)
        ));

        let (pk, sk) = generate(&mut OsRng);
        let (ct, _) = encapsulate(&pk, &mut OsRng);
        assert_eq!(
            PublicKey::from_bytes(pk.as_bytes()).unwrap().as_bytes(),
            pk.as_bytes()
        );
        assert_eq!(
            SecretKey::from_bytes(sk.as_bytes()).unwrap().as_bytes(),
            sk.as_bytes()
        );
        assert_eq!(
            Ciphertext::from_bytes(ct.as_bytes()).unwrap().as_bytes(),
            ct.as_bytes()
        );
    }

    #[test]
    #[cfg(feature = "mlkem768")]
    fn test_encrypt_decrypt_hello_world() {
        use super::mlkem768::*;
        use rand::rngs::OsRng;

        let message = b"Hello World";
        // The XOR below zips message and key, so a longer message would leave its
        // tail unencrypted.
        assert!(message.len() <= SHARED_SECRET_BYTES);

        let (pk, sk) = generate(&mut OsRng);
        let (ct, ss_sender) = encapsulate(&pk, &mut OsRng);
        let ss_receiver = decapsulate(&sk, &ct);

        assert_eq!(ss_sender.as_bytes(), ss_receiver.as_bytes());

        let mut ciphertext = *message;
        for (byte, key) in ciphertext.iter_mut().zip(ss_sender.as_bytes()) {
            *byte ^= key;
        }
        assert_ne!(&ciphertext, message);

        let mut plaintext = ciphertext;
        for (byte, key) in plaintext.iter_mut().zip(ss_receiver.as_bytes()) {
            *byte ^= key;
        }

        assert_eq!(&plaintext, message);
        assert_eq!(core::str::from_utf8(&plaintext).unwrap(), "Hello World");
    }
}