use crate::{
integer::Z,
integer_mod_q::{MatNTTPolynomialRingZq, ModulusPolynomialRingZq},
traits::MatrixDimensions,
};
/// Reports the dimensions of a [`MatNTTPolynomialRingZq`] through the
/// crate-wide [`MatrixDimensions`] trait.
impl MatrixDimensions for MatNTTPolynomialRingZq {
    /// Returns the number of rows of the matrix.
    ///
    /// The value is read from the stored `nr_rows` field and converted to
    /// `i64` with an `as` cast.
    fn get_num_rows(&self) -> i64 {
        let rows = self.nr_rows;
        rows as i64
    }

    /// Returns the number of columns of the matrix.
    ///
    /// The value is read from the stored `nr_columns` field and converted to
    /// `i64` with an `as` cast.
    fn get_num_columns(&self) -> i64 {
        let columns = self.nr_columns;
        columns as i64
    }
}
impl MatNTTPolynomialRingZq {
    /// Returns an owned copy (via `clone`) of the modulus this matrix is
    /// defined over.
    pub fn get_mod(&self) -> ModulusPolynomialRingZq {
        let modulus = self.modulus.clone();
        modulus
    }

    /// Returns the entry at position (`row`, `column`) as a borrowed slice of
    /// [`Z`] values.
    ///
    /// Every entry occupies `degree` consecutive elements of the internal
    /// `matrix` storage, where `degree` is the degree of the modulus. The
    /// entry starts at offset `degree * row + degree * nr_rows * column`, so
    /// entries of the same column are stored next to each other.
    ///
    /// Parameters:
    /// - `row`: zero-based row index of the requested entry
    /// - `column`: zero-based column index of the requested entry
    ///
    /// Returns a slice of length `degree` that borrows from `self`.
    ///
    /// # Panics ...
    /// - if `row` is not smaller than the number of rows.
    /// - if `column` is not smaller than the number of columns.
    pub fn get_entry(&self, row: usize, column: usize) -> &[Z] {
        assert!(
            row < self.nr_rows,
            "`row` needs to be smaller than `nr_rows`."
        );
        assert!(
            column < self.nr_columns,
            "`column` needs to be smaller than `nr_columns`."
        );

        // Evaluated once; the original expression order is kept so that any
        // arithmetic overflow behaves exactly as before.
        let degree = self.modulus.get_degree() as usize;
        let start = degree * row + degree * self.nr_rows * column;
        let end = start + degree;

        &self.matrix[start..end]
    }
}
#[cfg(test)]
mod test_matrix_dimensions {
    use crate::{
        integer::{MatPolyOverZ, Z},
        integer_mod_q::{MatNTTPolynomialRingZq, MatPolynomialRingZq, ModulusPolynomialRingZq},
        traits::MatrixDimensions,
    };
    use std::str::FromStr;

    /// Builds the NTT-enabled modulus shared by every test in this module:
    /// the polynomial `"5 1 0 0 0 1 mod 257"` configured with
    /// `set_ntt_unchecked(64)`.
    fn ntt_modulus() -> ModulusPolynomialRingZq {
        let mut modulus = ModulusPolynomialRingZq::from_str("5 1 0 0 0 1 mod 257").unwrap();
        modulus.set_ntt_unchecked(64);
        modulus
    }

    /// Ensures the number of rows is reported correctly.
    #[test]
    fn nr_rows() {
        let modulus = ntt_modulus();
        let matrix = MatNTTPolynomialRingZq::sample_uniform(17, 2, &modulus);

        let nr_rows = matrix.get_num_rows();

        assert_eq!(17, nr_rows);
    }

    /// Ensures the number of columns is reported correctly.
    #[test]
    fn nr_columns() {
        let modulus = ntt_modulus();
        let matrix = MatNTTPolynomialRingZq::sample_uniform(2, 13, &modulus);

        let nr_columns = matrix.get_num_columns();

        assert_eq!(13, nr_columns);
    }

    /// Ensures entries of a column vector are returned in NTT representation.
    #[test]
    fn get_entry() {
        let modulus = ntt_modulus();
        let mat_poly = MatPolyOverZ::from_str("[[4 15 17 19 21],[4 1 2 3 4]]").unwrap();
        let matrix = MatPolynomialRingZq::from((&mat_poly, &modulus));
        let ntt_matrix = MatNTTPolynomialRingZq::from(&matrix);

        assert_eq!(
            [Z::from(112), Z::from(189), Z::from(81), Z::from(192)],
            ntt_matrix.get_entry(0, 0)
        );
        assert_eq!(
            [Z::from(97), Z::from(56), Z::from(66), Z::from(42)],
            ntt_matrix.get_entry(1, 0)
        );
    }

    /// Ensures the column offset is applied: the same polynomials as in
    /// `get_entry`, but placed side by side in a single row.
    #[test]
    fn get_entry_column_offset() {
        let modulus = ntt_modulus();
        let mat_poly = MatPolyOverZ::from_str("[[4 1 2 3 4, 4 15 17 19 21]]").unwrap();
        let matrix = MatPolynomialRingZq::from((&mat_poly, &modulus));
        let ntt_matrix = MatNTTPolynomialRingZq::from(&matrix);

        assert_eq!(
            [Z::from(97), Z::from(56), Z::from(66), Z::from(42)],
            ntt_matrix.get_entry(0, 0)
        );
        assert_eq!(
            [Z::from(112), Z::from(189), Z::from(81), Z::from(192)],
            ntt_matrix.get_entry(0, 1)
        );
    }

    /// Ensures every entry, including the very last one, is in bounds and has
    /// as many coefficients as the modulus degree.
    #[test]
    fn get_entry_last_in_bounds() {
        let modulus = ntt_modulus();
        let matrix = MatNTTPolynomialRingZq::sample_uniform(3, 5, &modulus);

        for row in 0..3 {
            for column in 0..5 {
                assert_eq!(4, matrix.get_entry(row, column).len());
            }
        }
    }

    /// Ensures a row index equal to the number of rows is rejected.
    #[test]
    #[should_panic(expected = "`row` needs to be smaller than `nr_rows`.")]
    fn get_entry_row_out_of_bounds() {
        let modulus = ntt_modulus();
        let matrix = MatNTTPolynomialRingZq::sample_uniform(2, 3, &modulus);

        let _ = matrix.get_entry(2, 0);
    }

    /// Ensures a column index equal to the number of columns is rejected.
    #[test]
    #[should_panic(expected = "`column` needs to be smaller than `nr_columns`.")]
    fn get_entry_column_out_of_bounds() {
        let modulus = ntt_modulus();
        let matrix = MatNTTPolynomialRingZq::sample_uniform(2, 3, &modulus);

        let _ = matrix.get_entry(0, 3);
    }

    /// Ensures the returned modulus matches the one the matrix was built with.
    #[test]
    fn get_mod() {
        let modulus = ntt_modulus();
        let matrix = MatNTTPolynomialRingZq::sample_uniform(2, 2, &modulus);

        assert_eq!(modulus.get_degree(), matrix.get_mod().get_degree());
    }
}