1use alloc::vec::Vec;
4
5use crate::field::reduce_poly_simd;
6use crate::poly::{
7 NttPoly,
8 Poly,
9};
10
/// A vector of [`Poly`] values, wrapped in a newtype.
///
/// The wrapped `Vec<Poly>` is public and can be accessed as `.0`.
/// The type derives `zeroize::Zeroize` and `zeroize::ZeroizeOnDrop`, so the
/// contained polynomials are wiped when a value is dropped.
#[derive(Clone, Debug, PartialEq, Eq, zeroize::Zeroize, zeroize::ZeroizeOnDrop)]
pub struct ModuleVec(pub Vec<Poly>);
14
/// A `rows` x `cols` matrix of [`NttPoly`] entries kept in one flat vector.
///
/// All fields are public, so nothing in this type enforces that
/// `entries_ntt.len()` matches `rows * cols`; `mul_vec_ntt` checks that at
/// call time.
pub struct ModuleMatrix {
    /// Number of rows; also the length of the vector produced by a
    /// matrix-vector multiplication.
    pub rows: usize,
    /// Number of columns; also the required length of the input vector.
    pub cols: usize,
    /// Matrix entries in row-major order: `mul_vec_ntt` reads the entry at
    /// row `i`, column `j` from index `i * cols + j`.
    pub entries_ntt: Vec<NttPoly>,
}
24
25impl ModuleMatrix {
26 #[must_use]
28 pub fn expand_from_seed(seed: &[u8; 32], rows: usize, cols: usize) -> Self {
29 Self {
30 rows,
31 cols,
32 entries_ntt: crate::expand_a_from_seed(seed, rows, cols),
33 }
34 }
35
36 #[must_use]
38 pub fn mul_vec_ntt(&self, v_ntt: &[NttPoly]) -> ModuleVec {
39 assert_eq!(v_ntt.len(), self.cols);
40 assert_eq!(self.entries_ntt.len(), self.rows * self.cols);
41 let mut out = Vec::with_capacity(self.rows);
42 for i in 0..self.rows {
43 let mut acc = NttPoly::zero();
44 for (j, v_cell) in v_ntt.iter().enumerate() {
45 let mut prod = v_cell.clone();
46 prod.pointwise_mul_assign(&self.entries_ntt[i * self.cols + j]);
47 acc.add_assign(&prod);
48 }
49 reduce_poly_simd(acc.as_simd_mut());
50 out.push(acc.to_poly());
51 }
52 ModuleVec(out)
53 }
54
55 #[must_use]
57 pub fn mul_vec(&self, v: &ModuleVec) -> ModuleVec {
58 self.mul_vec_polys(&v.0)
59 }
60
61 #[must_use]
63 pub fn mul_vec_polys(&self, v: &[Poly]) -> ModuleVec {
64 assert_eq!(v.len(), self.cols);
65 let v_ntt: Vec<NttPoly> = v.iter().map(Poly::to_ntt).collect();
66 self.mul_vec_ntt(&v_ntt)
67 }
68}