use super::super::PolynomialRingZq;
use crate::error::MathError;
use crate::integer::Z;
use crate::integer_mod_q::Zq;
use crate::macros::arithmetics::{
arithmetic_assign_trait_borrowed_to_owned, arithmetic_trait_borrowed_to_owned,
arithmetic_trait_mixed_borrowed_owned, arithmetic_trait_reverse,
};
use crate::macros::for_others::implement_for_others;
use crate::traits::CompareBase;
use std::ops::{Mul, MulAssign};
impl Mul<&Z> for &PolynomialRingZq {
    type Output = PolynomialRingZq;

    /// Implements the [`Mul`] trait for a [`PolynomialRingZq`] and a [`Z`] scalar.
    /// Owned/borrowed combinations and the reversed operand order are provided
    /// by the macro invocations following this impl.
    ///
    /// Parameters:
    /// - `scalar`: specifies the integer to multiply every coefficient of `self` with
    ///
    /// Returns the product of `self` and `scalar` as a [`PolynomialRingZq`]
    /// that shares the modulus of `self`.
    fn mul(self, scalar: &Z) -> Self::Output {
        // Scale the underlying polynomial first; this product is not yet reduced.
        let scaled = &self.poly * scalar;

        // Wrap the scaled polynomial in a ring element with the same modulus
        // and reduce it before handing it back.
        let mut product = PolynomialRingZq::from(&self.modulus);
        product.poly = scaled;
        product.reduce();
        product
    }
}
// Remaining `Mul` variants for `Z` scalars, derived from the borrowed
// `&PolynomialRingZq * &Z` impl. The macros live in `crate::macros` and are not
// visible here; judging by their names they presumably generate:
// - `arithmetic_trait_reverse!`: the reversed operand order `&Z * &PolynomialRingZq`,
// - `arithmetic_trait_borrowed_to_owned!`: the fully owned forms,
// - `arithmetic_trait_mixed_borrowed_owned!`: the mixed owned/borrowed forms,
// - `implement_for_others!`: the same operations for the listed primitive integers.
// All of these forms are exercised by `test_mul_z::availability`.
arithmetic_trait_reverse!(Mul, mul, Z, PolynomialRingZq, PolynomialRingZq);
arithmetic_trait_borrowed_to_owned!(Mul, mul, PolynomialRingZq, Z, PolynomialRingZq);
arithmetic_trait_borrowed_to_owned!(Mul, mul, Z, PolynomialRingZq, PolynomialRingZq);
arithmetic_trait_mixed_borrowed_owned!(Mul, mul, PolynomialRingZq, Z, PolynomialRingZq);
arithmetic_trait_mixed_borrowed_owned!(Mul, mul, Z, PolynomialRingZq, PolynomialRingZq);
implement_for_others!(Z, PolynomialRingZq, PolynomialRingZq, Mul Scalar for i8 i16 i32 i64 u8 u16 u32 u64);
impl Mul<&Zq> for &PolynomialRingZq {
    type Output = PolynomialRingZq;

    /// Implements the [`Mul`] trait for a [`PolynomialRingZq`] and a [`Zq`] scalar.
    ///
    /// Parameters:
    /// - `scalar`: specifies the [`Zq`] value to multiply `self` with
    ///
    /// Returns the product of `self` and `scalar` as a [`PolynomialRingZq`].
    ///
    /// # Panics ...
    /// - if `mul_scalar_zq_safe` reports an error, i.e. the bases of `self` and
    ///   `scalar` are incompatible (e.g. different moduli).
    fn mul(self, scalar: &Zq) -> Self::Output {
        // Delegate to the checked variant; an incompatible base becomes a panic here.
        let checked_product = self.mul_scalar_zq_safe(scalar);
        checked_product.unwrap()
    }
}
// Remaining `Mul` variants for `Zq` scalars, derived from the borrowed
// `&PolynomialRingZq * &Zq` impl. The macros are defined in `crate::macros`; by
// their names they presumably add the reversed order, the owned forms and the
// mixed owned/borrowed forms, all exercised by `test_mul_zq::availability`.
arithmetic_trait_reverse!(Mul, mul, Zq, PolynomialRingZq, PolynomialRingZq);
arithmetic_trait_borrowed_to_owned!(Mul, mul, PolynomialRingZq, Zq, PolynomialRingZq);
arithmetic_trait_borrowed_to_owned!(Mul, mul, Zq, PolynomialRingZq, PolynomialRingZq);
arithmetic_trait_mixed_borrowed_owned!(Mul, mul, PolynomialRingZq, Zq, PolynomialRingZq);
arithmetic_trait_mixed_borrowed_owned!(Mul, mul, Zq, PolynomialRingZq, PolynomialRingZq);
impl MulAssign<&Zq> for PolynomialRingZq {
    /// Multiplies `self` in place with the [`Zq`] scalar `rhs`.
    ///
    /// Parameters:
    /// - `rhs`: specifies the [`Zq`] value to multiply `self` with
    ///
    /// # Panics ...
    /// - if `self.compare_base(rhs)` is `false` (e.g. the moduli of `self` and
    ///   `rhs` differ); the panic message is the error from `call_compare_base_error`.
    fn mul_assign(&mut self, rhs: &Zq) {
        // Validate before touching `self`, so an incompatible scalar leaves it unchanged.
        assert!(
            self.compare_base(rhs),
            "{}",
            self.call_compare_base_error(rhs).unwrap()
        );
        // With compatible bases only the integer representative of `rhs` is needed.
        *self *= &rhs.value;
    }
}
// Presumably derives the owned right-hand side form (`*= Zq`) from the borrowed
// `MulAssign<&Zq>` impl (macro defined in `crate::macros`); `*= zq` is used in
// `test_mul_assign::availability`.
arithmetic_assign_trait_borrowed_to_owned!(MulAssign, mul_assign, PolynomialRingZq, Zq);
impl PolynomialRingZq {
pub fn mul_scalar_zq_safe(&self, scalar: &Zq) -> Result<Self, MathError> {
if !self.compare_base(scalar) {
return Err(self.call_compare_base_error(scalar).unwrap());
}
let mut out = PolynomialRingZq::from(&self.modulus);
out.poly = &self.poly * &scalar.value;
out.reduce();
Ok(out)
}
}
/// Tests for multiplying a `PolynomialRingZq` with integer scalars.
#[cfg(test)]
mod test_mul_z {
    use super::PolynomialRingZq;
    use crate::integer::Z;
    use std::str::FromStr;

    /// Ensures that multiplication with a [`Z`] works with borrowed operands on
    /// either side, including coefficients that exceed `i64::MAX` after scaling.
    #[test]
    fn borrowed_correctness() {
        let factor = Z::from(2);
        let poly = PolynomialRingZq::from_str(&format!(
            "3 1 2 {} / 4 1 2 3 1 mod {}",
            i64::MAX,
            u64::MAX
        ))
        .unwrap();
        let expected = PolynomialRingZq::from_str(&format!(
            "3 2 4 {} / 4 1 2 3 1 mod {}",
            (i64::MAX as u64) * 2,
            u64::MAX
        ))
        .unwrap();

        let scaled_right = &poly * &factor;
        let scaled_left = &factor * &poly;

        assert_eq!(expected, scaled_right);
        assert_eq!(expected, scaled_left);
    }

    /// Ensures that all supported operand types and ownership combinations compile and run.
    #[test]
    fn availability() {
        let poly = PolynomialRingZq::from_str("3 1 2 3 / 4 1 2 3 1 mod 17").unwrap();
        let factor = Z::from(2);

        // owned polynomial on the left
        let _ = poly.clone() * factor.clone();
        let _ = poly.clone() * 2i8;
        let _ = poly.clone() * 2u8;
        let _ = poly.clone() * 2i16;
        let _ = poly.clone() * 2u16;
        let _ = poly.clone() * 2i32;
        let _ = poly.clone() * 2u32;
        let _ = poly.clone() * 2i64;
        let _ = poly.clone() * 2u64;

        // owned polynomial on the right
        let _ = factor.clone() * poly.clone();
        let _ = 2i8 * poly.clone();
        let _ = 2u64 * poly.clone();

        // borrowed and mixed ownership
        let _ = &poly * &factor;
        let _ = &factor * &poly;
        let _ = &poly * factor.clone();
        let _ = factor.clone() * &poly;
        let _ = poly.clone() * &factor;
        let _ = &factor * poly.clone();
        let _ = &poly * 2i8;
        let _ = 2i8 * &poly;
    }
}
/// Tests for multiplying a `PolynomialRingZq` with `Zq` scalars.
#[cfg(test)]
mod test_mul_zq {
    use super::PolynomialRingZq;
    use crate::integer_mod_q::Zq;
    use std::str::FromStr;

    /// Ensures that multiplication with a [`Zq`] works with borrowed operands on
    /// either side, including coefficients that exceed `i64::MAX` after scaling.
    #[test]
    fn borrowed_correctness() {
        let factor = Zq::from((2, u64::MAX));
        let poly = PolynomialRingZq::from_str(&format!(
            "3 1 2 {} / 4 1 2 3 1 mod {}",
            i64::MAX,
            u64::MAX
        ))
        .unwrap();
        let expected = PolynomialRingZq::from_str(&format!(
            "3 2 4 {} / 4 1 2 3 1 mod {}",
            (i64::MAX as u64) * 2,
            u64::MAX
        ))
        .unwrap();

        let scaled_right = &poly * &factor;
        let scaled_left = &factor * &poly;

        assert_eq!(expected, scaled_right);
        assert_eq!(expected, scaled_left);
    }

    /// Ensures that all ownership combinations of the operands compile and run.
    #[test]
    fn availability() {
        let poly = PolynomialRingZq::from_str("3 1 2 3 / 4 1 2 3 1 mod 17").unwrap();
        let factor = Zq::from((2, 17));

        // both owned
        let _ = poly.clone() * factor.clone();
        let _ = factor.clone() * poly.clone();

        // both borrowed
        let _ = &poly * &factor;
        let _ = &factor * &poly;

        // mixed ownership
        let _ = &poly * factor.clone();
        let _ = factor.clone() * &poly;
        let _ = &factor * poly.clone();
        let _ = poly.clone() * &factor;
    }

    /// Ensures that the operator panics if the moduli of the operands differ.
    #[test]
    #[should_panic]
    fn different_moduli_panic() {
        let poly = PolynomialRingZq::from_str("3 1 2 3 / 4 1 2 3 1 mod 17").unwrap();
        let factor = Zq::from((2, 16));

        let _ = &poly * &factor;
    }

    /// Ensures that the checked variant reports an error if the moduli differ.
    #[test]
    fn different_moduli_error() {
        let poly = PolynomialRingZq::from_str("3 1 2 3 / 4 1 2 3 1 mod 17").unwrap();
        let factor = Zq::from((2, 16));

        let outcome = poly.mul_scalar_zq_safe(&factor);

        assert!(outcome.is_err());
    }
}
/// Tests for the in-place scalar multiplication (`*=`) of `PolynomialRingZq`.
#[cfg(test)]
mod test_mul_assign {
    use crate::{
        integer::{PolyOverZ, Z},
        integer_mod_q::{ModulusPolynomialRingZq, PolyOverZq, PolynomialRingZq, Zq},
    };
    use std::str::FromStr;

    /// Ensures that `*=` with a primitive integer yields the same result as `*`.
    #[test]
    fn consistency() {
        let modulus =
            ModulusPolynomialRingZq::from_str(&format!("4 1 0 0 1 mod {}", u64::MAX)).unwrap();
        let poly_z = PolyOverZ::from_str("2 3 1").unwrap();
        let mut polynomial_ring_zq = PolynomialRingZq::from((&poly_z, &modulus));
        let cmp = &polynomial_ring_zq * i64::MAX;
        polynomial_ring_zq *= i64::MAX;
        assert_eq!(cmp, polynomial_ring_zq)
    }

    /// Ensures that `*=` with [`Z`] and [`Zq`] scalars agrees with `*`, matches the
    /// hand-computed product, and that a [`Zq`] scalar acts like its [`Z`] representative.
    #[test]
    fn consistency_z_zq() {
        let modulus =
            ModulusPolynomialRingZq::from_str(&format!("4 1 0 0 1 mod {}", u64::MAX)).unwrap();
        let poly_z = PolyOverZ::from_str("2 3 1").unwrap();
        let polynomial_ring_zq = PolynomialRingZq::from((&poly_z, &modulus));
        let z = Z::from(2);
        let zq = Zq::from((2, u64::MAX));
        // 2 * (3 + X) = 6 + 2X; degree stays below the modulus degree, so no reduction.
        let expected =
            PolynomialRingZq::from((&PolyOverZ::from_str("2 6 2").unwrap(), &modulus));

        let mut assigned_z = polynomial_ring_zq.clone();
        assigned_z *= &z;
        let mut assigned_zq = polynomial_ring_zq.clone();
        assigned_zq *= &zq;

        assert_eq!(&polynomial_ring_zq * &z, assigned_z);
        assert_eq!(&polynomial_ring_zq * &zq, assigned_zq);
        assert_eq!(expected, assigned_z);
        assert_eq!(assigned_z, assigned_zq);
    }

    /// Ensures that `*=` is available for all supported scalar types and ownerships.
    #[test]
    fn availability() {
        let modulus =
            ModulusPolynomialRingZq::from_str(&format!("4 1 0 0 1 mod {}", u64::MAX)).unwrap();
        let poly_z = PolyOverZ::from_str("2 3 1").unwrap();
        let mut polynomial_ring_zq = PolynomialRingZq::from((&poly_z, &modulus));
        let z = Z::from(2);
        let zq = Zq::from((2, u64::MAX));
        polynomial_ring_zq *= &z;
        polynomial_ring_zq *= z;
        polynomial_ring_zq *= &zq;
        polynomial_ring_zq *= zq;
        polynomial_ring_zq *= 1_u8;
        polynomial_ring_zq *= 1_u16;
        polynomial_ring_zq *= 1_u32;
        polynomial_ring_zq *= 1_u64;
        polynomial_ring_zq *= 1_i8;
        polynomial_ring_zq *= 1_i16;
        polynomial_ring_zq *= 1_i32;
        polynomial_ring_zq *= 1_i64;
    }

    /// Ensures that `*=` with a [`Zq`] of a different modulus panics.
    #[test]
    #[should_panic]
    fn mismatching_modulus_zq() {
        let modulus =
            ModulusPolynomialRingZq::from_str(&format!("4 1 0 0 1 mod {}", u64::MAX)).unwrap();
        let poly_z = PolyOverZ::from_str("2 3 1").unwrap();
        let mut polynomial_ring_zq = PolynomialRingZq::from((&poly_z, &modulus));
        let zq = Zq::from((2, u64::MAX - 1));
        polynomial_ring_zq *= &zq;
    }

    /// Ensures that `*=` with a [`PolyOverZq`] of a different modulus panics.
    #[test]
    #[should_panic]
    fn mismatching_modulus_poly_zq() {
        let modulus =
            ModulusPolynomialRingZq::from_str(&format!("4 1 0 0 1 mod {}", u64::MAX)).unwrap();
        let poly_z = PolyOverZ::from_str("2 3 1").unwrap();
        let mut polynomial_ring_zq = PolynomialRingZq::from((&poly_z, &modulus));
        let poly_z = PolyOverZ::from_str("2 3 1").unwrap();
        let poly_zq = PolyOverZq::from((&poly_z, u64::MAX - 1));
        polynomial_ring_zq *= &poly_zq;
    }
}