// rustpq 0.3.0
//
// Pure Rust post-quantum cryptography suite - ML-KEM, ML-DSA, and more
// Documentation
use crate::ml_dsa::ntt::{invntt_tomont, ntt};
use crate::ml_dsa::params::{N, Q};
use crate::ml_dsa::poly::Poly;
use crate::ml_dsa::reduce::montgomery_reduce;
use zeroize::Zeroize;

/// A fixed-length vector of `K` polynomials.
///
/// The `#[zeroize(drop)]` attribute makes the derive also generate a `Drop`
/// implementation that zeroizes the contents when a value goes out of scope.
#[derive(Clone, Zeroize)]
#[zeroize(drop)]
pub struct PolyVec<const K: usize> {
    /// The `K` polynomial entries, in index order.
    pub vec: [Poly; K],
}

impl<const K: usize> PolyVec<K> {
    /// Creates a vector whose `K` entries are each produced by [`Poly::new`].
    pub fn new() -> Self {
        Self {
            vec: core::array::from_fn(|_| Poly::new()),
        }
    }

    /// Calls [`Poly::reduce`] on every entry, in place.
    pub fn reduce(&mut self) {
        for poly in &mut self.vec {
            poly.reduce();
        }
    }

    /// Calls [`Poly::freeze_coeffs`] on every entry, in place.
    pub fn freeze(&mut self) {
        for poly in &mut self.vec {
            poly.freeze_coeffs();
        }
    }

    /// Calls [`Poly::reduce_montgomery`] on every entry, in place.
    pub fn reduce_montgomery(&mut self) {
        for poly in &mut self.vec {
            poly.reduce_montgomery();
        }
    }

    /// Returns a copy of `self` in which every coefficient `c` has been
    /// replaced by `montgomery_reduce(c * (2^64 mod Q))`.
    ///
    /// NOTE(review): if `montgomery_reduce` divides by 2^32 modulo `Q`, this
    /// maps `c` to `c * 2^32 mod Q`, i.e. into Montgomery form — confirm
    /// against the `reduce` module.
    pub fn to_mont(&self) -> PolyVec<K> {
        const Q128: i128 = Q as i128;
        const R2_MOD_Q: i64 = ((1_i128 << 64) % Q128) as i64;
        // Start from a clone so anything in `Poly` besides the coefficients
        // is carried over unchanged; the coefficients are overwritten below.
        let mut result = self.clone();
        for (dst, src) in result.vec.iter_mut().zip(self.vec.iter()) {
            for (out, &coeff) in dst.coeffs.iter_mut().zip(src.coeffs.iter()) {
                *out = montgomery_reduce(i64::from(coeff) * R2_MOD_Q);
            }
        }
        result
    }

    /// Adds `b` to `self` entry by entry using [`Poly::add`], in place.
    pub fn add(&mut self, b: &PolyVec<K>) {
        for (dst, src) in self.vec.iter_mut().zip(b.vec.iter()) {
            dst.add(src);
        }
    }

    /// Runs the forward `ntt` over the coefficients of every entry, in place.
    pub fn ntt(&mut self) {
        for poly in &mut self.vec {
            ntt(&mut poly.coeffs);
        }
    }

    /// Runs `invntt_tomont` over the coefficients of every entry, in place.
    pub fn invntt_tomont(&mut self) {
        for poly in &mut self.vec {
            invntt_tomont(&mut poly.coeffs);
        }
    }

    /// Returns a new vector whose entry `i` is
    /// `Poly::pointwise_montgomery(&self.vec[i], a)`.
    pub fn pointwise_poly_montgomery(&self, a: &Poly) -> PolyVec<K> {
        let mut r = Self::new();
        for (dst, src) in r.vec.iter_mut().zip(self.vec.iter()) {
            *dst = Poly::pointwise_montgomery(src, a);
        }
        r
    }

    /// Returns `true` as soon as [`Poly::chknorm`] returns `true` for any
    /// entry with the given `bound`, and `false` if it does for none.
    pub fn chknorm(&self, bound: i32) -> bool {
        self.vec.iter().any(|poly| poly.chknorm(bound))
    }

    /// Returns the largest absolute coefficient value across all entries,
    /// or `0` when there are no coefficients.
    ///
    /// Before taking the absolute value, a coefficient above `Q / 2` is
    /// replaced by `Q - coeff` and one below `-(Q / 2)` by `coeff + Q`;
    /// anything in between is used as is.
    pub fn infinity_norm(&self) -> i32 {
        let half_q = Q / 2;
        self.vec
            .iter()
            .flat_map(|poly| poly.coeffs.iter().copied())
            .map(|coeff| {
                let centered = if coeff > half_q {
                    Q - coeff
                } else if coeff < -half_q {
                    coeff + Q
                } else {
                    coeff
                };
                centered.abs()
            })
            .fold(0, i32::max)
    }

    /// Calls [`Poly::uniform_eta`] on entry `i` with `seed` and the nonce
    /// `nonce + i`, taken modulo 2^16.
    ///
    /// The addition wraps explicitly so that debug and release builds agree:
    /// a plain `+` would panic on overflow in debug builds only.
    pub fn uniform_eta<const ETA: usize>(&mut self, seed: &[u8; 64], nonce: u16) {
        for (i, poly) in self.vec.iter_mut().enumerate() {
            poly.uniform_eta::<ETA>(seed, nonce.wrapping_add(i as u16));
        }
    }

    /// Calls [`Poly::uniform_gamma1`] on entry `i` with `seed` and the nonce
    /// `nonce + i`, taken modulo 2^16.
    ///
    /// The addition wraps explicitly so that debug and release builds agree:
    /// a plain `+` would panic on overflow in debug builds only.
    pub fn uniform_gamma1<const GAMMA1: usize>(&mut self, seed: &[u8; 64], nonce: u16) {
        for (i, poly) in self.vec.iter_mut().enumerate() {
            poly.uniform_gamma1::<GAMMA1>(seed, nonce.wrapping_add(i as u16));
        }
    }
}

impl<const K: usize> Default for PolyVec<K> {
    fn default() -> Self {
        Self::new()
    }
}

/// Multiplies the `K`-row matrix `mat` by the length-`L` vector `v`.
///
/// `v` is first converted with [`PolyVec::to_mont`]. Output entry `i` is then
/// the sum over `j` of `Poly::pointwise_montgomery(&mat[i].vec[j], &v_mont.vec[j])`,
/// accumulated with [`Poly::add`].
///
/// # Panics
///
/// Panics if `L == 0` while `K > 0`, because the first product of each row is
/// taken from index `0`.
pub fn polyvec_matrix_pointwise_montgomery<const K: usize, const L: usize>(
    mat: &[PolyVec<L>; K],
    v: &PolyVec<L>,
) -> PolyVec<K> {
    let v_mont = v.to_mont();
    let mut product = PolyVec::<K>::new();

    for (row, out) in mat.iter().zip(product.vec.iter_mut()) {
        // Seed the accumulator with the first column's product, then fold in
        // the remaining columns.
        let mut acc = Poly::pointwise_montgomery(&row.vec[0], &v_mont.vec[0]);
        for (entry, operand) in row.vec.iter().zip(v_mont.vec.iter()).skip(1) {
            let term = Poly::pointwise_montgomery(entry, operand);
            acc.add(&term);
        }
        *out = acc;
    }

    product
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Adding two vectors sums the matching coefficients in place.
    #[test]
    fn test_polyvec_add() {
        let mut lhs = PolyVec::<4>::new();
        lhs.vec[0].coeffs[0] = 100;

        let mut rhs = PolyVec::<4>::new();
        rhs.vec[0].coeffs[0] = 50;

        lhs.add(&rhs);

        assert_eq!(lhs.vec[0].coeffs[0], 150);
    }

    /// A single coefficient of 100 passes a bound of 200 but trips a bound of 50.
    #[test]
    fn test_polyvec_chknorm() {
        let mut v = PolyVec::<4>::new();
        v.vec[0].coeffs[0] = 100;

        assert!(!v.chknorm(200));
        assert!(v.chknorm(50));
    }
}