use crate::{
error::MathError,
integer::{MatPolyOverZ, PolyOverZ, Z},
traits::{MatrixDimensions, MatrixSetEntry},
utils::index::evaluate_index,
};
use std::fmt::Display;
impl MatPolyOverZ {
    /// Creates a `num_rows` x `num_cols` matrix whose entries are polynomials of degree
    /// at most `max_degree`, with coefficients sampled from `[lower_bound, upper_bound)`.
    ///
    /// Each entry is produced by its own call to [`PolyOverZ::sample_uniform`] with the
    /// same degree and bounds.
    ///
    /// Parameters:
    /// - `num_rows`: the number of rows of the new matrix
    /// - `num_cols`: the number of columns of the new matrix
    /// - `max_degree`: the maximal degree of every sampled polynomial
    /// - `lower_bound`: the inclusive lower bound for sampled coefficients
    /// - `upper_bound`: the exclusive upper bound for sampled coefficients
    ///
    /// Returns the sampled [`MatPolyOverZ`], or a [`MathError`] if sampling is not possible.
    ///
    /// # Errors and Failures
    /// - Returns a [`MathError`] if `max_degree` cannot be turned into a valid index,
    ///   e.g. if it is negative.
    /// - Returns a [`MathError`] propagated from [`PolyOverZ::sample_uniform`] if the
    ///   interval is rejected, e.g. if `lower_bound` is not smaller than `upper_bound`.
    ///
    /// # Panics ...
    /// - if the matrix cannot be created with the given dimensions, e.g. if `num_rows`
    ///   is `0` (raised by [`MatPolyOverZ::new`]).
    pub fn sample_uniform(
        num_rows: impl TryInto<i64> + Display,
        num_cols: impl TryInto<i64> + Display,
        max_degree: impl TryInto<i64> + Display,
        lower_bound: impl Into<Z>,
        upper_bound: impl Into<Z>,
    ) -> Result<Self, MathError> {
        let (lower, upper): (Z, Z) = (lower_bound.into(), upper_bound.into());
        // Validate the degree before allocating, so an invalid degree is reported as an error.
        let degree = evaluate_index(max_degree)?;

        let mut out = MatPolyOverZ::new(num_rows, num_cols);
        let (rows, cols) = (out.get_num_rows(), out.get_num_columns());

        // Visit every position in row-major order.
        for (row, col) in (0..rows).flat_map(|r| (0..cols).map(move |c| (r, c))) {
            let entry = PolyOverZ::sample_uniform(degree, &lower, &upper)?;
            // SAFETY: `row < rows` and `col < cols`, where both bounds were read from `out`
            // itself, so every index is in bounds.
            // NOTE(review): confirm that bounds are the only `# Safety` requirement of
            // `set_entry_unchecked`.
            unsafe { out.set_entry_unchecked(row, col, entry) };
        }

        Ok(out)
    }
}
#[cfg(test)]
mod test_sample_uniform {
    use crate::traits::{GetCoefficient, MatrixDimensions, MatrixGetEntry};
    use crate::{
        integer::{MatPolyOverZ, Z},
        integer_mod_q::Modulus,
    };

    /// Samples a `1x1` matrix holding a constant polynomial and returns its only coefficient.
    fn sample_constant_coeff(lower_bound: &Z, upper_bound: &Z) -> Z {
        let matrix = MatPolyOverZ::sample_uniform(1, 1, 0, lower_bound, upper_bound).unwrap();
        let entry = matrix.get_entry(0, 0).unwrap();
        entry.get_coeff(0).unwrap()
    }

    /// Checks that coefficients sampled from a small interval stay inside `[17, 32)`.
    #[test]
    fn boundaries_kept_small() {
        let lower_bound = Z::from(17);
        let upper_bound = Z::from(32);

        for _ in 0..32 {
            let coeff = sample_constant_coeff(&lower_bound, &upper_bound);
            assert!(lower_bound <= coeff);
            assert!(coeff < upper_bound);
        }
    }

    /// Checks that coefficients sampled from an interval beyond the `i64` range stay inside it.
    #[test]
    fn boundaries_kept_large() {
        let lower_bound = Z::from(i64::MIN) - Z::from(u64::MAX);
        let upper_bound = Z::from(i64::MIN);

        for _ in 0..256 {
            let coeff = sample_constant_coeff(&lower_bound, &upper_bound);
            assert!(lower_bound <= coeff);
            assert!(coeff < upper_bound);
        }
    }

    /// Checks that the sampled polynomial has exactly the requested degree.
    /// The interval `[1, 15)` excludes zero, so the leading coefficient cannot vanish.
    #[test]
    fn nr_coeffs() {
        for degree in [1, 3, 7, 15, 32, 120] {
            let poly = MatPolyOverZ::sample_uniform(1, 1, degree, 1, 15)
                .unwrap()
                .get_entry(0, 0)
                .unwrap();
            assert_eq!(degree, poly.get_degree());
        }
    }

    /// Checks that requesting a matrix with zero rows panics.
    #[should_panic]
    #[test]
    fn false_size() {
        let _ = MatPolyOverZ::sample_uniform(0, 3, 1, Z::from(-15), Z::from(15));
    }

    /// Checks that an empty or reversed interval results in an error.
    #[test]
    fn invalid_interval() {
        let upper_bound = Z::from(i64::MIN);
        let cases = [
            (3, 3, Z::from(i64::MIN)),
            (4, 1, Z::from(i64::MIN)),
            (1, 5, Z::ZERO),
        ];

        for (rows, cols, lower_bound) in cases {
            let res = MatPolyOverZ::sample_uniform(rows, cols, 0, &lower_bound, &upper_bound);
            assert!(res.is_err());
        }
    }

    /// Checks that a negative maximal degree results in an error.
    #[test]
    fn invalid_max_degree() {
        let lower_bound = Z::from(0);
        let upper_bound = Z::from(15);

        assert!(MatPolyOverZ::sample_uniform(1, 1, -1, &lower_bound, &upper_bound).is_err());
        assert!(
            MatPolyOverZ::sample_uniform(1, 1, i64::MIN, &lower_bound, &upper_bound).is_err()
        );
    }

    /// Checks that the function accepts all intended argument types.
    #[test]
    fn availability() {
        let modulus = Modulus::from(7);
        let z = Z::from(7);

        let _ = MatPolyOverZ::sample_uniform(1, 1, 0u8, 0u16, 7u8);
        let _ = MatPolyOverZ::sample_uniform(1, 1, 0u16, 0u32, 7u16);
        let _ = MatPolyOverZ::sample_uniform(1, 1, 0u32, 0u64, 7u32);
        let _ = MatPolyOverZ::sample_uniform(1, 1, 0u64, 0i8, 7u64);
        let _ = MatPolyOverZ::sample_uniform(1, 1, 0i8, 0i16, 7i8);
        let _ = MatPolyOverZ::sample_uniform(1, 1, 0i16, 0i32, 7i16);
        let _ = MatPolyOverZ::sample_uniform(1, 1, 0i32, 0i64, 7i32);
        let _ = MatPolyOverZ::sample_uniform(1, 1, 0i64, &Z::ZERO, 7i64);
        let _ = MatPolyOverZ::sample_uniform(1, 1, 0, 0u8, &modulus);
        let _ = MatPolyOverZ::sample_uniform(1, 1, Z::ZERO, 0, &z);
    }

    /// Checks that the sampled matrix has the requested dimensions.
    #[test]
    fn matrix_size() {
        let lower_bound = Z::from(-15);
        let upper_bound = Z::from(15);

        for (rows, cols) in [(3, 3), (4, 1), (1, 5), (15, 20)] {
            let matrix =
                MatPolyOverZ::sample_uniform(rows, cols, 0, &lower_bound, &upper_bound).unwrap();
            assert_eq!(rows, matrix.get_num_rows());
            assert_eq!(cols, matrix.get_num_columns());
        }
    }
}