use alloc::vec::Vec;
use crate::field::reduce_poly_simd;
use crate::poly::{
NttPoly,
Poly,
};
/// A vector of polynomials, stored as a newtype around `Vec<Poly>`.
///
/// The type derives `zeroize::Zeroize` and `zeroize::ZeroizeOnDrop`, so its
/// contents can be wiped explicitly and are wiped when the value is dropped.
/// Note that `Clone` is also derived: every clone is an independent copy that
/// is wiped on its own drop.
#[derive(Clone, Debug, PartialEq, Eq, zeroize::Zeroize, zeroize::ZeroizeOnDrop)]
pub struct ModuleVec(pub Vec<Poly>);
/// A `rows` x `cols` matrix of polynomials held as `NttPoly` values in one
/// flat vector.
///
/// The fields are public and nothing at construction time ties them together;
/// `mul_vec_ntt` asserts `entries_ntt.len() == rows * cols` before using them.
pub struct ModuleMatrix {
/// Number of rows; also the number of polynomials produced by a
/// matrix-vector product.
pub rows: usize,
/// Number of columns; the required length of any vector being multiplied.
pub cols: usize,
/// Matrix entries in row-major order: the entry at row `i`, column `j` is
/// read from index `i * cols + j`.
pub entries_ntt: Vec<NttPoly>,
}
impl ModuleMatrix {
    /// Builds a `rows` x `cols` matrix whose entries come from
    /// `crate::expand_a_from_seed(seed, rows, cols)`.
    ///
    /// The length of the returned entry vector is not checked here; the
    /// `rows * cols` invariant is asserted later by [`Self::mul_vec_ntt`].
    #[must_use]
    pub fn expand_from_seed(seed: &[u8; 32], rows: usize, cols: usize) -> Self {
        let entries_ntt = crate::expand_a_from_seed(seed, rows, cols);
        Self {
            rows,
            cols,
            entries_ntt,
        }
    }

    /// Multiplies the matrix by a vector whose elements are already `NttPoly`
    /// values.
    ///
    /// For every row, each vector element is cloned, multiplied pointwise by
    /// the matching matrix entry and added into an accumulator that starts at
    /// `NttPoly::zero()`. The accumulator is then passed through
    /// `reduce_poly_simd` and converted with `to_poly()`, giving one output
    /// polynomial per row.
    ///
    /// # Panics
    ///
    /// Panics if `v_ntt.len() != self.cols` or if
    /// `self.entries_ntt.len() != self.rows * self.cols`.
    #[must_use]
    pub fn mul_vec_ntt(&self, v_ntt: &[NttPoly]) -> ModuleVec {
        assert_eq!(v_ntt.len(), self.cols);
        assert_eq!(self.entries_ntt.len(), self.rows * self.cols);

        let mut products = Vec::with_capacity(self.rows);
        for row in 0..self.rows {
            // Row-major storage: row `row` occupies `cols` consecutive entries.
            let start = row * self.cols;
            let row_entries = &self.entries_ntt[start..start + self.cols];

            let mut sum = NttPoly::zero();
            for (operand, entry) in v_ntt.iter().zip(row_entries) {
                let mut term = operand.clone();
                term.pointwise_mul_assign(entry);
                sum.add_assign(&term);
            }

            // Reduction runs once per row, after all terms have been added.
            reduce_poly_simd(sum.as_simd_mut());
            products.push(sum.to_poly());
        }
        ModuleVec(products)
    }

    /// Multiplies the matrix by a [`ModuleVec`]; a thin wrapper that forwards
    /// the wrapped polynomials to [`Self::mul_vec_polys`].
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::mul_vec_polys`].
    #[must_use]
    pub fn mul_vec(&self, v: &ModuleVec) -> ModuleVec {
        self.mul_vec_polys(v.0.as_slice())
    }

    /// Multiplies the matrix by a slice of polynomials: each one is converted
    /// with `Poly::to_ntt` and the result is handed to [`Self::mul_vec_ntt`].
    ///
    /// The length check is done up front so a wrongly sized input fails before
    /// any conversion work.
    ///
    /// NOTE(review): the converted copies live in a temporary `Vec<NttPoly>`;
    /// if `v` is secret, confirm that `NttPoly` is wiped on drop.
    ///
    /// # Panics
    ///
    /// Panics if `v.len() != self.cols`, or if the matrix shape assertion in
    /// [`Self::mul_vec_ntt`] fails.
    #[must_use]
    pub fn mul_vec_polys(&self, v: &[Poly]) -> ModuleVec {
        assert_eq!(v.len(), self.cols);
        let transformed = v.iter().map(Poly::to_ntt).collect::<Vec<NttPoly>>();
        self.mul_vec_ntt(transformed.as_slice())
    }
}