Skip to main content

lib_q_ring/
module.rs

1//! Module operations: matrix–vector product in the NTT domain (ML-DSA style).
2
3use alloc::vec::Vec;
4
5use crate::field::reduce_poly_simd;
6use crate::poly::{
7    NttPoly,
8    Poly,
9};
10
/// Column vector of ring elements (time domain).
///
/// Thin newtype around `Vec<Poly>`; the inner vector is public and can be
/// accessed as `.0`. The type derives `zeroize::Zeroize` and
/// `zeroize::ZeroizeOnDrop`, so the polynomials are wiped when the vector is
/// dropped.
#[derive(Clone, Debug, PartialEq, Eq, zeroize::Zeroize, zeroize::ZeroizeOnDrop)]
pub struct ModuleVec(pub Vec<Poly>);
14
/// Public matrix `A` stored row-major as NTT polynomials `Â_{i,j}`.
///
/// Entry `(i, j)` lives at index `i * cols + j` of `entries_ntt`. All fields
/// are public, so the invariant `entries_ntt.len() == rows * cols` is not
/// enforced at construction; [`ModuleMatrix::mul_vec_ntt`] asserts it before
/// reading any entry.
pub struct ModuleMatrix {
    /// Number of rows `k`.
    pub rows: usize,
    /// Number of columns `l`.
    pub cols: usize,
    /// Row-major `k · l` entries, each already in the NTT domain.
    pub entries_ntt: Vec<NttPoly>,
}
24
25impl ModuleMatrix {
26    /// Expand `A` from seed `ρ` with [`crate::expand_a_from_seed`].
27    #[must_use]
28    pub fn expand_from_seed(seed: &[u8; 32], rows: usize, cols: usize) -> Self {
29        Self {
30            rows,
31            cols,
32            entries_ntt: crate::expand_a_from_seed(seed, rows, cols),
33        }
34    }
35
36    /// `y_i = InvNTT( Σ_j Â_{i,j} ∘ v̂_j )` — same pattern as ML-DSA `compute_matrix_x_mask`.
37    #[must_use]
38    pub fn mul_vec_ntt(&self, v_ntt: &[NttPoly]) -> ModuleVec {
39        assert_eq!(v_ntt.len(), self.cols);
40        assert_eq!(self.entries_ntt.len(), self.rows * self.cols);
41        let mut out = Vec::with_capacity(self.rows);
42        for i in 0..self.rows {
43            let mut acc = NttPoly::zero();
44            for (j, v_cell) in v_ntt.iter().enumerate() {
45                let mut prod = v_cell.clone();
46                prod.pointwise_mul_assign(&self.entries_ntt[i * self.cols + j]);
47                acc.add_assign(&prod);
48            }
49            reduce_poly_simd(acc.as_simd_mut());
50            out.push(acc.to_poly());
51        }
52        ModuleVec(out)
53    }
54
55    /// [`ModuleMatrix::mul_vec_ntt`] with automatic `NTT` on each input [`Poly`].
56    #[must_use]
57    pub fn mul_vec(&self, v: &ModuleVec) -> ModuleVec {
58        self.mul_vec_polys(&v.0)
59    }
60
61    /// [`ModuleMatrix::mul_vec_ntt`] over a borrowed witness slice (no `ModuleVec` wrapper copy).
62    #[must_use]
63    pub fn mul_vec_polys(&self, v: &[Poly]) -> ModuleVec {
64        assert_eq!(v.len(), self.cols);
65        let v_ntt: Vec<NttPoly> = v.iter().map(Poly::to_ntt).collect();
66        self.mul_vec_ntt(&v_ntt)
67    }
68}