rustpq 0.3.0

Pure Rust post-quantum cryptography suite - ML-KEM, ML-DSA, and more
Documentation
use crate::ml_dsa::params::N;
use crate::ml_dsa::poly::Poly;
use crate::ml_dsa::polyvec::PolyVec;
use sha3::digest::{ExtendableOutput, Update, XofReader};
use sha3::Shake128;

/// Expands the public seed `rho` into a `K`-by-`L` matrix of polynomials.
///
/// Entry `(row, col)` is filled by [`poly_uniform`] using `rho` and a 16-bit
/// nonce whose high byte is `row` and whose low byte is `col`. `poly_uniform`
/// absorbs that nonce little-endian, i.e. as the byte pair `[col, row]`.
///
/// NOTE(review): both indices are assumed to fit in one byte (K, L <= 256);
/// larger values would make distinct entries share a nonce.
pub fn expand_a<const K: usize, const L: usize>(rho: &[u8; 32]) -> [PolyVec<L>; K] {
    let mut matrix: [PolyVec<L>; K] = core::array::from_fn(|_| PolyVec::new());

    for (row_idx, row) in matrix.iter_mut().enumerate() {
        // High byte of the nonce is fixed for the whole row.
        let row_nonce = (row_idx as u16) << 8;
        for (col_idx, entry) in row.vec.iter_mut().enumerate().take(L) {
            poly_uniform(entry, rho, row_nonce | col_idx as u16);
        }
    }

    matrix
}

/// Fills all `N` coefficients of `a` with values in `[0, Q)` by rejection
/// sampling from the SHAKE128 stream of `seed || nonce` (nonce little-endian).
///
/// The stream is consumed in 3-byte groups. Each group is read as a
/// little-endian integer with the top bit of its third byte cleared, giving a
/// 23-bit candidate that is kept only if it is below `Q`.
pub fn poly_uniform(a: &mut Poly, seed: &[u8; 32], nonce: u16) {
    use crate::ml_dsa::params::Q;

    // Squeeze five SHAKE128 blocks per refill. 5 * 168 = 840 is a multiple of 3,
    // so every refill starts on a group boundary and no stream bytes are skipped.
    const SHAKE128_RATE: usize = 168;
    const BLOCK_BYTES: usize = SHAKE128_RATE * 5;

    let mut xof = Shake128::default();
    xof.update(seed);
    xof.update(&nonce.to_le_bytes());
    let mut stream = xof.finalize_xof();

    let mut block = [0u8; BLOCK_BYTES];
    let mut filled = 0;

    while filled < N {
        stream.read(&mut block);
        for triple in block.chunks_exact(3) {
            let candidate = u32::from_le_bytes([triple[0], triple[1], triple[2] & 0x7F, 0]);
            if candidate < Q as u32 {
                a.coeffs[filled] = candidate as i32;
                filled += 1;
                if filled == N {
                    break;
                }
            }
        }
    }
}

/// Samples the challenge polynomial `c` from a 32-byte `seed`.
///
/// Thin wrapper around [`sample_in_ball`]; see it for the sampling procedure
/// and panics.
///
/// NOTE(review): FIPS 204 (final) absorbs the *entire* commitment hash
/// (32, 48 or 64 bytes for ML-DSA-44/65/87) in SampleInBall, whereas this entry
/// point only accepts 32 bytes. Callers for the larger parameter sets that
/// truncate the hash should call [`sample_in_ball`] with the full hash instead.
/// Confirm which revision of the spec the callers target.
pub fn challenge<const TAU: usize>(c: &mut Poly, seed: &[u8; 32]) {
    sample_in_ball::<TAU>(c, seed);
}

/// Overwrites `c` with a polynomial that has exactly `TAU` coefficients in
/// `{-1, +1}` and all other coefficients zero, derived from SHAKE256(`seed`).
///
/// The first 8 XOF bytes, read as a little-endian `u64`, provide one sign bit
/// per nonzero coefficient, least-significant bit first. Then, for each `i` in
/// `N - TAU..N`, further XOF bytes are drawn until one value `j <= i` appears.
/// At that point `c[i] = c[j]` and `c[j]` receives the next sign (`1` bit => -1).
///
/// Unlike [`challenge`], `seed` may be any length.
///
/// # Panics
///
/// In debug builds, panics if `TAU > N`, or if `TAU > 64` (only 64 sign bits are drawn).
pub fn sample_in_ball<const TAU: usize>(c: &mut Poly, seed: &[u8]) {
    use sha3::Shake256;

    // SHAKE256 rate in bytes; the XOF is squeezed one block at a time.
    const SHAKE256_RATE: usize = 136;

    debug_assert!(TAU <= N, "TAU must not exceed the polynomial degree N");
    debug_assert!(TAU <= 64, "only 64 sign bits are sampled, so TAU must be <= 64");

    let mut state = Shake256::default();
    state.update(seed);
    let mut reader = state.finalize_xof();

    let mut buf = [0u8; SHAKE256_RATE];
    reader.read(&mut buf);

    // The first 8 stream bytes are the sign bits, consumed LSB first.
    let mut sign_bytes = [0u8; 8];
    sign_bytes.copy_from_slice(&buf[..8]);
    let mut signs = u64::from_le_bytes(sign_bytes);

    c.coeffs.fill(0);

    let mut pos = sign_bytes.len();
    for i in (N - TAU)..N {
        // Rejection-sample a position j in [0, i] from single stream bytes.
        let j = loop {
            if pos >= buf.len() {
                reader.read(&mut buf);
                pos = 0;
            }
            let candidate = usize::from(buf[pos]);
            pos += 1;
            if candidate <= i {
                break candidate;
            }
        };

        c.coeffs[i] = c.coeffs[j];
        c.coeffs[j] = if signs & 1 != 0 { -1 } else { 1 };
        signs >>= 1;
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ml_dsa::params::Q;

    #[test]
    fn test_expand_a() {
        let rho = [0u8; 32];
        let mat = expand_a::<4, 4>(&rho);

        // Rejection sampling must only ever accept coefficients in [0, Q).
        for row in mat.iter() {
            for poly in row.vec.iter() {
                assert!(poly
                    .coeffs
                    .iter()
                    .all(|&x| x >= 0 && (x as u32) < Q as u32));
            }
        }

        // Entries (0, 1) and (1, 0) use nonces 0x0001 and 0x0100 (byte-swapped
        // encodings); none of them may collide with each other or with (0, 0).
        assert_ne!(mat[0].vec[0].coeffs, mat[0].vec[1].coeffs);
        assert_ne!(mat[0].vec[0].coeffs, mat[1].vec[0].coeffs);
        assert_ne!(mat[0].vec[1].coeffs, mat[1].vec[0].coeffs);
    }

    #[test]
    fn test_challenge() {
        let seed = [0u8; 32];
        let mut c = Poly::new();
        challenge::<39>(&mut c, &seed);

        // Exactly TAU coefficients are nonzero, and every coefficient is in {-1, 0, 1}.
        assert_eq!(c.coeffs.iter().filter(|&&x| x != 0).count(), 39);
        assert!(c.coeffs.iter().all(|&x| x == 0 || x == 1 || x == -1));

        // Sampling is deterministic in the seed and overwrites stale contents of `c`.
        let mut again = Poly::new();
        again.coeffs[0] = 5;
        challenge::<39>(&mut again, &seed);
        assert_eq!(c.coeffs, again.coeffs);

        // A different seed yields a different challenge.
        let mut other = Poly::new();
        challenge::<39>(&mut other, &[1u8; 32]);
        assert_ne!(c.coeffs, other.coeffs);
    }
}