use super::MatPolynomialRingZq;
use crate::{
error::MathError,
integer::PolyOverZ,
traits::{CompareBase, MatrixDimensions, MatrixGetEntry, Tensor},
};
use flint_sys::{fmpz_poly::fmpz_poly_mul, fmpz_poly_mat::fmpz_poly_mat_entry};
impl Tensor for MatPolynomialRingZq {
    /// Computes the tensor product of `self` with `other`.
    ///
    /// This is the panicking counterpart of
    /// [`MatPolynomialRingZq::tensor_product_safe`]: the result of that
    /// function is unwrapped directly.
    ///
    /// Parameters:
    /// - `other`: the right-hand factor of the tensor product
    ///
    /// Returns the matrix produced by `tensor_product_safe`.
    ///
    /// # Panics ...
    /// - if `tensor_product_safe` returns an error, i.e. if `self` and
    ///   `other` are reported as incompatible by `compare_base`.
    fn tensor_product(&self, other: &Self) -> Self {
        let product = self.tensor_product_safe(other);
        product.unwrap()
    }
}
impl MatPolynomialRingZq {
    /// Computes the tensor product of `self` with `other` and reports
    /// incompatible inputs as an error instead of panicking.
    ///
    /// The result is built block-wise: for every entry `self[i, j]`, the block
    /// at block position `(i, j)` of the result is filled with
    /// `self[i, j] * other`. Blocks belonging to zero entries of `self` are
    /// skipped and therefore keep whatever `MatPolynomialRingZq::new`
    /// initialised them with.
    ///
    /// Parameters:
    /// - `other`: the right-hand factor of the tensor product
    ///
    /// Returns a matrix with
    /// `self.get_num_rows() * other.get_num_rows()` rows and
    /// `self.get_num_columns() * other.get_num_columns()` columns,
    /// created over the modulus of `self`.
    ///
    /// # Errors and Failures
    /// - Returns the [`MathError`] provided by `call_compare_base_error`
    ///   if `compare_base` reports that `self` and `other` do not match.
    pub fn tensor_product_safe(&self, other: &Self) -> Result<Self, MathError> {
        if !self.compare_base(other) {
            return Err(self.call_compare_base_error(other).unwrap());
        }

        let rows_self = self.get_num_rows();
        let columns_self = self.get_num_columns();

        let mut product = MatPolynomialRingZq::new(
            rows_self * other.get_num_rows(),
            columns_self * other.get_num_columns(),
            self.get_mod(),
        );

        for row in 0..rows_self {
            for column in 0..columns_self {
                // SAFETY: `row` and `column` are bounded by the dimensions of `self`.
                let factor: PolyOverZ = unsafe { self.get_entry_unchecked(row, column) };

                // A zero factor contributes nothing, so its block is not touched.
                if factor.is_zero() {
                    continue;
                }

                // SAFETY: `product` consists of `rows_self x columns_self` blocks
                // of the size of `other`, hence block `(row, column)` lies inside it.
                unsafe { set_matrix_window_mul(&mut product, row, column, factor, other) }
            }
        }

        Ok(product)
    }
}
/// Overwrites one block of `out` with `scalar * matrix`.
///
/// `out` is treated as a grid of blocks, each having the dimensions of
/// `matrix`. The block at grid position `(row_left, column_upper)` is written,
/// i.e. the entries from
/// `[row_left * rows, column_upper * columns]` up to (excluding)
/// `[(row_left + 1) * rows, (column_upper + 1) * columns]`,
/// where `rows` and `columns` are the dimensions of `matrix`.
/// After each product is stored, `out.reduce_entry` is called on that entry.
///
/// Parameters:
/// - `out`: the matrix whose block is overwritten
/// - `row_left`: the row index of the block within the block grid
///   (not an entry index; it is scaled by the number of rows of `matrix`)
/// - `column_upper`: the column index of the block within the block grid
///   (scaled by the number of columns of `matrix`)
/// - `scalar`: the polynomial every entry of `matrix` is multiplied with
/// - `matrix`: the matrix that is scaled and copied into the block
///
/// # Safety
/// The addressed block is validated by assertions before any entry is
/// written. This function does not check that `out` and `matrix` share the
/// same base; the caller has to ensure that.
///
/// # Panics ...
/// - if `row_left` or `column_upper` is negative.
/// - if the addressed block does not lie completely inside of `out`.
unsafe fn set_matrix_window_mul(
    out: &mut MatPolynomialRingZq,
    row_left: i64,
    column_upper: i64,
    scalar: PolyOverZ,
    matrix: &MatPolynomialRingZq,
) {
    let columns_other = matrix.get_num_columns();
    let rows_other = matrix.get_num_rows();

    // Entry coordinates of the top-left corner of the addressed block.
    let row_offset = row_left * rows_other;
    let column_offset = column_upper * columns_other;

    // The bounds have to be checked against the scaled offsets, since these
    // are the coordinates that are actually dereferenced below.
    assert!(row_left >= 0 && row_offset + rows_other <= out.get_num_rows());
    assert!(column_upper >= 0 && column_offset + columns_other <= out.get_num_columns());

    for i_other in 0..rows_other {
        for j_other in 0..columns_other {
            let row = row_offset + i_other;
            let column = column_offset + j_other;

            // SAFETY: `(row, column)` lies inside `out` by the assertions above,
            // `(i_other, j_other)` lies inside `matrix` by the loop bounds, and
            // the destination entry belongs to `out`, which is distinct from
            // both operands.
            unsafe {
                fmpz_poly_mul(
                    fmpz_poly_mat_entry(&out.matrix.matrix, row, column),
                    &scalar.poly,
                    fmpz_poly_mat_entry(&matrix.matrix.matrix, i_other, j_other),
                );
                out.reduce_entry(row, column);
            }
        }
    }
}
#[cfg(test)]
mod test_tensor {
    use crate::{
        integer_mod_q::{MatPolynomialRingZq, ModulusPolynomialRingZq},
        traits::{MatrixDimensions, Tensor},
    };
    use std::str::FromStr;

    /// Ensures that the dimensions of the tensor product are the products of
    /// the dimensions of its factors.
    #[test]
    fn dimensions_fit() {
        let modulus = ModulusPolynomialRingZq::from_str("3 1 2 1 mod 17").unwrap();
        let left = MatPolynomialRingZq::new(17, 13, &modulus);
        let right = MatPolynomialRingZq::new(3, 4, &modulus);

        let product = left.tensor_product(&right);

        assert_eq!(51, product.get_num_rows());
        assert_eq!(52, product.get_num_columns());
    }

    /// Ensures that the tensor product with the identity matrix works from
    /// both sides, also for entries at the boundaries of `i64`.
    #[test]
    fn identity() {
        let modulus =
            ModulusPolynomialRingZq::from_str(&format!("3 1 2 1 mod {}", u64::MAX)).unwrap();
        let identity = MatPolynomialRingZq::identity(2, 2, &modulus);
        let matrix = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 1 1],[0, 1 {}, 1 -1]] / 3 1 2 1 mod {}",
            i64::MAX,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();

        // identity (x) matrix: `matrix` repeated on the block diagonal
        let expected_identity_left = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 1 1, 0, 0, 0],[0, 1 {}, 1 -1, 0, 0, 0],[0, 0, 0, 1 1, 1 {}, 1 1],[0, 0, 0, 0, 1 {}, 1 -1]] / 3 1 2 1 mod {}",
            i64::MAX,
            i64::MIN,
            i64::MAX,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();
        // matrix (x) identity: every entry of `matrix` expanded to a 2x2 diagonal block
        let expected_identity_right = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 0, 1 {}, 0, 1 1, 0],[0, 1 1, 0, 1 {}, 0, 1 1],[0, 0, 1 {}, 0, 1 -1, 0],[0, 0, 0, 1 {}, 0, 1 -1]] / 3 1 2 1 mod {}",
            i64::MAX,
            i64::MAX,
            i64::MIN,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();

        let identity_left = identity.tensor_product(&matrix);
        let identity_right = matrix.tensor_product(&identity);

        assert_eq!(expected_identity_left, identity_left);
        assert_eq!(expected_identity_right, identity_right);
    }

    /// Ensures that the tensor product of a column vector and a matrix is
    /// computed correctly in both orders.
    #[test]
    fn vector_matrix() {
        let vector =
            MatPolynomialRingZq::from_str(&format!("[[1 1],[1 -1]] / 3 1 2 1 mod {}", u64::MAX))
                .unwrap();
        let matrix = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 1 1],[0, 1 {}, 1 -1]] / 3 1 2 1 mod {}",
            i64::MAX,
            i64::MAX,
            u64::MAX
        ))
        .unwrap();

        // vector (x) matrix: `matrix` stacked on top of its negation
        let expected_vector_first = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 1 1],[0, 1 {}, 1 -1],[1 -1, 1 -{}, 1 -1],[0, 1 -{}, 1 1]] / 3 1 2 1 mod {}",
            i64::MAX,
            i64::MAX,
            i64::MAX,
            i64::MAX,
            u64::MAX
        ))
        .unwrap();
        // matrix (x) vector: every row of `matrix` followed by its negation
        let expected_matrix_first = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 1 1],[1 -1, 1 -{}, 1 -1],[0, 1 {}, 1 -1],[0, 1 -{}, 1 1]] / 3 1 2 1 mod {}",
            i64::MAX,
            i64::MAX,
            i64::MAX,
            i64::MAX,
            u64::MAX
        ))
        .unwrap();

        let vector_first = vector.tensor_product(&matrix);
        let matrix_first = matrix.tensor_product(&vector);

        assert_eq!(expected_vector_first, vector_first);
        assert_eq!(expected_matrix_first, matrix_first);
    }

    /// Ensures that the tensor product of two column vectors is computed
    /// correctly in both orders.
    #[test]
    fn vector_vector() {
        let small =
            MatPolynomialRingZq::from_str(&format!("[[1 2],[1 1]] / 3 1 2 1 mod {}", u64::MAX))
                .unwrap();
        let large = MatPolynomialRingZq::from_str(&format!(
            "[[1 {}],[1 {}]] / 3 1 2 1 mod {}",
            (u64::MAX - 1) / 2,
            i64::MIN / 2,
            u64::MAX
        ))
        .unwrap();

        let expected_small_first = MatPolynomialRingZq::from_str(&format!(
            "[[1 {}],[1 {}],[1 {}],[1 {}]] / 3 1 2 1 mod {}",
            u64::MAX - 1,
            i64::MIN,
            (u64::MAX - 1) / 2,
            i64::MIN / 2,
            u64::MAX
        ))
        .unwrap();
        let expected_large_first = MatPolynomialRingZq::from_str(&format!(
            "[[1 {}],[1 {}],[1 {}],[1 {}]] / 3 1 2 1 mod {}",
            u64::MAX - 1,
            (u64::MAX - 1) / 2,
            i64::MIN,
            i64::MIN / 2,
            u64::MAX
        ))
        .unwrap();

        let small_first = small.tensor_product(&large);
        let large_first = large.tensor_product(&small);

        assert_eq!(expected_small_first, small_first);
        assert_eq!(expected_large_first, large_first);
    }

    /// Ensures that the tensor product is computed correctly if the entries
    /// are polynomials of degree larger than zero.
    #[test]
    fn higher_degree() {
        let left = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 2 0 1, 2 1 1]] / 3 1 2 1 mod {}",
            u64::MAX
        ))
        .unwrap();
        let right = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 2 1 {}]] / 3 1 2 1 mod {}",
            i64::MAX,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();

        let expected = MatPolynomialRingZq::from_str(&format!(
            "[[1 1, 1 {}, 2 1 {}, 2 0 1, 2 0 {}, 3 0 1 {}, 2 1 1, 2 {} {}, 3 1 {} {}]] / 3 1 2 1 mod {}",
            i64::MAX,
            i64::MIN,
            i64::MAX,
            i64::MIN,
            i64::MAX,
            i64::MAX,
            i64::MIN + 1,
            i64::MIN,
            u64::MAX
        ))
        .unwrap();

        let product = left.tensor_product(&right);

        assert_eq!(expected, product);
    }

    /// Ensures that `tensor_product` panics if the moduli of the factors differ.
    #[test]
    #[should_panic]
    fn moduli_mismatch_panic() {
        let modulus_17 = ModulusPolynomialRingZq::from_str("3 1 2 1 mod 17").unwrap();
        let modulus_16 = ModulusPolynomialRingZq::from_str("3 1 2 1 mod 16").unwrap();
        let left = MatPolynomialRingZq::new(17, 13, &modulus_17);
        let right = MatPolynomialRingZq::new(3, 4, &modulus_16);

        let _ = left.tensor_product(&right);
    }

    /// Ensures that `tensor_product_safe` returns an error if the moduli of
    /// the factors differ.
    #[test]
    fn moduli_mismatch_error() {
        let modulus_17 = ModulusPolynomialRingZq::from_str("3 1 2 1 mod 17").unwrap();
        let modulus_16 = ModulusPolynomialRingZq::from_str("3 1 2 1 mod 16").unwrap();
        let left = MatPolynomialRingZq::new(17, 13, &modulus_17);
        let right = MatPolynomialRingZq::new(3, 4, &modulus_16);

        let result = left.tensor_product_safe(&right);

        assert!(result.is_err());
    }
}