// rustpq 0.3.0: pure-Rust post-quantum cryptography suite (ML-KEM, ML-DSA, and more).
// This module implements the IND-CPA public-key encryption layer used by ML-KEM.
use super::encode::{
    poly_compress, poly_decompress, poly_from_msg, poly_to_msg, polyvec_compress,
    polyvec_decompress, polyvec_from_bytes, polyvec_to_bytes,
};
use super::ntt;
use super::params::{N, POLYBYTES, SYMBYTES};
use super::poly::Poly;
use super::polyvec::PolyVec;
use super::sampling::{sample_poly_cbd_eta1, sample_poly_cbd_eta2, sample_uniform};
use super::symmetric::hash_g;

/// Fills the `K x K` matrix `a` with polynomials expanded from `seed` via `sample_uniform`.
///
/// For the entry at row `r`, column `c`, the two index bytes handed to `sample_uniform` are
/// `(c, r)` when `transposed` is `false` and `(r, c)` when it is `true`. For the same seed the
/// two modes therefore fill transposes of each other.
fn gen_matrix<const K: usize>(a: &mut [[Poly; K]; K], seed: &[u8], transposed: bool) {
    for (row_idx, row) in a.iter_mut().enumerate() {
        for (col_idx, entry) in row.iter_mut().enumerate() {
            let (first, second) = if transposed {
                (row_idx, col_idx)
            } else {
                (col_idx, row_idx)
            };
            sample_uniform(entry, seed, first as u8, second as u8);
        }
    }
}

/// Deterministically derives an IND-CPA key pair from `seed`.
///
/// `seed` is expanded with `hash_g` into a matrix seed (the first `SYMBYTES` bytes) and a noise
/// seed (the rest). The secret vector `s` is sampled with nonces `0..K` and the error vector `e`
/// with nonces `K..2K`. The public vector `t_hat = A_hat * s_hat + e_hat` is computed and stored
/// in the NTT domain.
///
/// * `pk` – receives `K * POLYBYTES` bytes of serialised `t_hat` followed by the `SYMBYTES`-byte
///   matrix seed. It must be at least `K * POLYBYTES + SYMBYTES` bytes long.
/// * `sk` – receives the `K * POLYBYTES`-byte serialisation of `s_hat`. Only that prefix is
///   written, so a longer buffer may be passed.
///
/// # Panics
///
/// Panics if `pk` or `sk` is shorter than stated above.
pub fn keypair<const K: usize, const ETA1: usize>(
    pk: &mut [u8],
    sk: &mut [u8],
    seed: &[u8; SYMBYTES],
) {
    // NOTE(review): FIPS 203 (final) computes (rho, sigma) = G(d || k), appending the module
    // rank k for domain separation between parameter sets. Here G is applied to `seed` alone,
    // which is the round-3 Kyber / draft behaviour. Changing this alters every derived key, so
    // confirm which specification this crate targets before touching it.
    let buf = hash_g(seed);
    let (publicseed, noiseseed) = buf.split_at(SYMBYTES);
    // Convert the noise seed once instead of on every sampling call.
    let noiseseed: &[u8; SYMBYTES] = noiseseed
        .try_into()
        .expect("hash_g output must split into two SYMBYTES-byte seeds");

    let mut a: [[Poly; K]; K] = core::array::from_fn(|_| core::array::from_fn(|_| Poly::new()));
    gen_matrix::<K>(&mut a, publicseed, false);

    // Secret s takes nonces 0..K and error e takes K..2K, all derived from the same noise seed.
    let mut skpv: PolyVec<K> = PolyVec::new();
    let mut e: PolyVec<K> = PolyVec::new();
    let mut nonce = 0u8;
    for poly in skpv.vec.iter_mut().chain(e.vec.iter_mut()) {
        sample_poly_cbd_eta1::<ETA1>(poly, noiseseed, nonce);
        nonce += 1;
    }

    skpv.ntt();
    skpv.reduce();
    e.ntt();

    // t_hat = A_hat * s_hat + e_hat, computed row by row in the NTT domain.
    let mut pkpv: PolyVec<K> = PolyVec::new();
    for i in 0..K {
        let mut t = Poly::new();
        for j in 0..K {
            let tmp = ntt::basemul(&a[i][j], &skpv.vec[j]);
            t.add(&tmp);
        }
        t.reduce();
        // NOTE(review): presumably this is the to-Montgomery conversion (`poly_tomont` in the
        // reference implementation) that cancels the R^-1 factor introduced by `basemul`.
        // Confirm in `poly`.
        t.montgomery_reduce_coeffs();
        pkpv.vec[i] = t;
    }

    pkpv.add(&e);
    pkpv.reduce();

    // `enc` follows `reduce` with `cond_sub_q` before serialising. Do the same here so both keys
    // are encoded with canonical coefficients. A conditional subtraction of q cannot change a
    // coefficient's residue, so the keys are mathematically unchanged.
    pkpv.cond_sub_q();
    skpv.cond_sub_q();

    let poly_bytes = K * POLYBYTES;
    polyvec_to_bytes::<K>(&pkpv, &mut pk[..poly_bytes]);
    pk[poly_bytes..poly_bytes + SYMBYTES].copy_from_slice(publicseed);
    polyvec_to_bytes::<K>(&skpv, &mut sk[..poly_bytes]);
}

/// Encrypts a 32-byte message under an IND-CPA public key.
///
/// * `ct` – output buffer. The first `K * N * DU / 8` bytes receive the compressed `u` vector
///   and the remaining bytes receive the compressed `v` polynomial.
/// * `msg` – message to encrypt.
/// * `pk` – public key: `K * POLYBYTES` bytes of serialised NTT-domain polynomials followed by
///   the `SYMBYTES`-byte matrix seed.
/// * `coins` – randomness from which every noise polynomial is sampled.
///
/// `ETA2` is accepted for signature symmetry but not consulted: `e1` and `e2` are drawn with
/// `sample_poly_cbd_eta2`, which takes no width parameter.
///
/// # Panics
///
/// Panics if `pk` is shorter than `K * POLYBYTES + SYMBYTES` bytes or `ct` is shorter than
/// `K * N * DU / 8` bytes.
pub fn enc<
    const K: usize,
    const ETA1: usize,
    const ETA2: usize,
    const DU: usize,
    const DV: usize,
>(
    ct: &mut [u8],
    msg: &[u8; SYMBYTES],
    pk: &[u8],
    coins: &[u8; SYMBYTES],
) {
    let poly_region = K * POLYBYTES;

    let mut t_hat: PolyVec<K> = PolyVec::new();
    polyvec_from_bytes::<K>(&mut t_hat, &pk[..poly_region]);
    let rho = &pk[poly_region..poly_region + SYMBYTES];

    let mut a_t: [[Poly; K]; K] =
        core::array::from_fn(|_| core::array::from_fn(|_| Poly::new()));
    gen_matrix::<K>(&mut a_t, rho, true);

    // Noise: r takes nonces 0..K, e1 takes K..2K, and e2 takes 2K.
    let mut r: PolyVec<K> = PolyVec::new();
    let mut e1: PolyVec<K> = PolyVec::new();
    let mut e2 = Poly::new();
    let mut nonce: u8 = 0;
    for poly in r.vec.iter_mut() {
        sample_poly_cbd_eta1::<ETA1>(poly, coins, nonce);
        nonce += 1;
    }
    for poly in e1.vec.iter_mut() {
        sample_poly_cbd_eta2(poly, coins, nonce);
        nonce += 1;
    }
    sample_poly_cbd_eta2(&mut e2, coins, nonce);

    r.ntt();
    r.reduce();

    // u_hat = A_hat^T * r_hat, one row at a time.
    let mut u: PolyVec<K> = PolyVec::new();
    for (row, u_i) in a_t.iter().zip(u.vec.iter_mut()) {
        let mut acc = Poly::new();
        for (a_ij, r_j) in row.iter().zip(r.vec.iter()) {
            acc.add(&ntt::basemul(a_ij, r_j));
        }
        acc.reduce();
        *u_i = acc;
    }

    let mut v = t_hat.pointwise_acc_montgomery(&r);

    u.inv_ntt();
    ntt::inv_ntt(&mut v);

    u.add(&e1);
    u.reduce();

    let mut m = Poly::new();
    poly_from_msg(&mut m, msg);
    v.add(&e2);
    v.add(&m);
    v.reduce();

    u.cond_sub_q();
    v.cond_sub_q();

    let (ct_u, ct_v) = ct.split_at_mut(K * N * DU / 8);
    polyvec_compress::<K, DU>(&u, ct_u);
    poly_compress::<DV>(&v, ct_v);
}

pub fn dec<const K: usize, const DU: usize, const DV: usize>(
    msg: &mut [u8; SYMBYTES],
    ct: &[u8],
    sk: &[u8],
) {
    let mut u: PolyVec<K> = PolyVec::new();
    let mut v = Poly::new();

    let u_bytes = K * N * DU / 8;
    polyvec_decompress::<K, DU>(&mut u, &ct[..u_bytes]);
    poly_decompress::<DV>(&mut v, &ct[u_bytes..]);

    let mut skpv: PolyVec<K> = PolyVec::new();
    polyvec_from_bytes::<K>(&mut skpv, sk);

    u.ntt();
    let mut mp = skpv.pointwise_acc_montgomery(&u);
    ntt::inv_ntt(&mut mp);

    v.sub(&mp);
    v.reduce();
    v.cond_sub_q();

    poly_to_msg(&v, msg);
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Runs keypair -> enc -> dec for one parameter set and checks the message survives.
    ///
    /// The caller supplies correctly sized buffers, because array lengths cannot be computed
    /// from these const generics on stable Rust.
    fn assert_roundtrip<const K: usize, const ETA1: usize, const DU: usize, const DV: usize>(
        pk: &mut [u8],
        sk: &mut [u8],
        ct: &mut [u8],
        msg: &[u8; SYMBYTES],
    ) {
        let seed = [0u8; SYMBYTES];
        let coins = [1u8; SYMBYTES];

        keypair::<K, ETA1>(pk, sk, &seed);
        // eta2 = 2 for every ML-KEM parameter set.
        enc::<K, ETA1, 2, DU, DV>(ct, msg, pk, &coins);

        let mut recovered = [0u8; SYMBYTES];
        dec::<K, DU, DV>(&mut recovered, ct, sk);

        assert_eq!(*msg, recovered);
    }

    #[test]
    fn test_indcpa_roundtrip_768() {
        let mut pk = [0u8; crate::ml_kem::params::mlkem768::PUBLICKEYBYTES];
        let mut sk = [0u8; 3 * POLYBYTES];
        let mut ct = [0u8; crate::ml_kem::params::mlkem768::CIPHERTEXTBYTES];
        assert_roundtrip::<3, 2, 10, 4>(&mut pk, &mut sk, &mut ct, &[2u8; SYMBYTES]);
    }

    #[test]
    fn test_indcpa_roundtrip_768_mixed_bits() {
        // Exercise every bit position rather than a single repeated byte.
        let msg: [u8; SYMBYTES] = core::array::from_fn(|i| (i as u8).wrapping_mul(37) ^ 0xA5);
        let mut pk = [0u8; crate::ml_kem::params::mlkem768::PUBLICKEYBYTES];
        let mut sk = [0u8; 3 * POLYBYTES];
        let mut ct = [0u8; crate::ml_kem::params::mlkem768::CIPHERTEXTBYTES];
        assert_roundtrip::<3, 2, 10, 4>(&mut pk, &mut sk, &mut ct, &msg);
    }

    #[test]
    fn test_indcpa_roundtrip_512() {
        // ML-KEM-512: k = 2, eta1 = 3, du = 10, dv = 4.
        let mut pk = [0u8; 2 * POLYBYTES + SYMBYTES];
        let mut sk = [0u8; 2 * POLYBYTES];
        let mut ct = [0u8; 2 * N * 10 / 8 + N * 4 / 8];
        assert_roundtrip::<2, 3, 10, 4>(&mut pk, &mut sk, &mut ct, &[2u8; SYMBYTES]);
    }

    #[test]
    fn test_indcpa_roundtrip_1024() {
        // ML-KEM-1024: k = 4, eta1 = 2, du = 11, dv = 5.
        let mut pk = [0u8; 4 * POLYBYTES + SYMBYTES];
        let mut sk = [0u8; 4 * POLYBYTES];
        let mut ct = [0u8; 4 * N * 11 / 8 + N * 5 / 8];
        assert_roundtrip::<4, 2, 11, 5>(&mut pk, &mut sk, &mut ct, &[2u8; SYMBYTES]);
    }

    #[test]
    fn test_keypair_is_deterministic() {
        let seed = [7u8; SYMBYTES];

        let mut pk1 = [0u8; 3 * POLYBYTES + SYMBYTES];
        let mut sk1 = [0u8; 3 * POLYBYTES];
        keypair::<3, 2>(&mut pk1, &mut sk1, &seed);

        let mut pk2 = [0u8; 3 * POLYBYTES + SYMBYTES];
        let mut sk2 = [0u8; 3 * POLYBYTES];
        keypair::<3, 2>(&mut pk2, &mut sk2, &seed);

        assert_eq!(pk1, pk2);
        assert_eq!(sk1, sk2);
    }
}