use super::super::PolyOverZq;
use crate::error::MathError;
use crate::integer::Z;
use crate::integer_mod_q::Zq;
use crate::macros::arithmetics::{
arithmetic_assign_between_types, arithmetic_assign_trait_borrowed_to_owned,
arithmetic_trait_borrowed_to_owned, arithmetic_trait_mixed_borrowed_owned,
arithmetic_trait_reverse,
};
use crate::macros::for_others::implement_for_others;
use crate::traits::CompareBase;
use flint_sys::fmpz_mod_poly::{fmpz_mod_poly_scalar_mul_fmpz, fmpz_mod_poly_scalar_mul_ui};
use std::ops::{Mul, MulAssign};
impl Mul<&Z> for &PolyOverZq {
    type Output = PolyOverZq;

    /// Multiplies every coefficient of `self` by the integer `scalar`.
    ///
    /// Parameters:
    /// - `scalar`: the [`Z`] value each coefficient is multiplied with
    ///
    /// Returns a new [`PolyOverZq`] built over the same modulus as `self`.
    fn mul(self, scalar: &Z) -> Self::Output {
        let ctx = self.modulus.get_fmpz_mod_ctx_struct();
        let mut product = PolyOverZq::from(&self.modulus);
        // SAFETY: `product` is constructed from `self.modulus`, so the output, the input
        // polynomial and `ctx` all belong to the same FLINT modulus context.
        unsafe {
            fmpz_mod_poly_scalar_mul_fmpz(&mut product.poly, &self.poly, &scalar.value, ctx)
        };
        product
    }
}
// Further `Mul` impls between `PolyOverZq` and `Z`, judging by the macro names derived from
// the borrowed `&PolyOverZq * &Z` impl above: the reversed operand order (`Z * PolyOverZq`),
// owned operands and mixed borrowed/owned operands. `test_mul_z::availability` exercises
// these forms. (Macro definitions live in `crate::macros::arithmetics`.)
arithmetic_trait_reverse!(Mul, mul, Z, PolyOverZq, PolyOverZq);
arithmetic_trait_borrowed_to_owned!(Mul, mul, PolyOverZq, Z, PolyOverZq);
arithmetic_trait_borrowed_to_owned!(Mul, mul, Z, PolyOverZq, PolyOverZq);
arithmetic_trait_mixed_borrowed_owned!(Mul, mul, PolyOverZq, Z, PolyOverZq);
arithmetic_trait_mixed_borrowed_owned!(Mul, mul, Z, PolyOverZq, PolyOverZq);
// Scalar multiplication with the listed primitive integer types, presumably routed through
// `Z` (macro defined in `crate::macros::for_others`).
implement_for_others!(Z, PolyOverZq, PolyOverZq, Mul Scalar for i8 i16 i32 i64 u8 u16 u32 u64);
impl Mul<&Zq> for &PolyOverZq {
    type Output = PolyOverZq;

    /// Multiplies every coefficient of `self` by the residue `scalar`.
    ///
    /// Parameters:
    /// - `scalar`: the [`Zq`] value each coefficient is multiplied with
    ///
    /// Returns a new [`PolyOverZq`] over the modulus shared by both operands.
    ///
    /// # Panics
    /// Panics if `self` and `scalar` do not share a base (see [`CompareBase`]).
    /// Use [`PolyOverZq::mul_scalar_zq_safe`] for a non-panicking variant.
    fn mul(self, scalar: &Zq) -> PolyOverZq {
        match self.mul_scalar_zq_safe(scalar) {
            Ok(product) => product,
            // Report the error through its `Display` message, exactly like
            // `MulAssign<&Zq>` does, instead of `unwrap()`'s generic `Debug` dump.
            Err(err) => panic!("{}", err),
        }
    }
}
// Further `Mul` impls between `PolyOverZq` and `Zq`, judging by the macro names derived from
// the borrowed `&PolyOverZq * &Zq` impl above: reversed operand order, owned operands and
// mixed borrowed/owned operands (exercised by `test_mul_zq::availability`).
arithmetic_trait_reverse!(Mul, mul, Zq, PolyOverZq, PolyOverZq);
arithmetic_trait_borrowed_to_owned!(Mul, mul, PolyOverZq, Zq, PolyOverZq);
arithmetic_trait_borrowed_to_owned!(Mul, mul, Zq, PolyOverZq, PolyOverZq);
arithmetic_trait_mixed_borrowed_owned!(Mul, mul, PolyOverZq, Zq, PolyOverZq);
arithmetic_trait_mixed_borrowed_owned!(Mul, mul, Zq, PolyOverZq, PolyOverZq);
impl PolyOverZq {
    /// Multiplies `self` by the residue `scalar`, reporting a mismatch of the two
    /// operands' bases as an error instead of panicking.
    ///
    /// Parameters:
    /// - `scalar`: the [`Zq`] value each coefficient is multiplied with
    ///
    /// Returns the product as a new [`PolyOverZq`] over `scalar`'s modulus.
    ///
    /// # Errors
    /// Returns the [`MathError`] produced by `call_compare_base_error` if `self` and
    /// `scalar` do not share a base (see [`CompareBase`]).
    pub fn mul_scalar_zq_safe(&self, scalar: &Zq) -> Result<Self, MathError> {
        if self.compare_base(scalar) {
            let mut product = PolyOverZq::from(&scalar.modulus);
            let ctx = product.modulus.get_fmpz_mod_ctx_struct();
            // SAFETY: the bases were just checked to match, and `ctx` is taken from the
            // modulus `product` was built with, so every FLINT argument uses that context.
            unsafe {
                fmpz_mod_poly_scalar_mul_fmpz(
                    &mut product.poly,
                    &self.poly,
                    &scalar.value.value,
                    ctx,
                )
            }
            Ok(product)
        } else {
            Err(self.call_compare_base_error(scalar).unwrap())
        }
    }
}
impl MulAssign<&Z> for PolyOverZq {
    /// Multiplies `self` in place by the integer `scalar`.
    ///
    /// Parameters:
    /// - `scalar`: the [`Z`] value each coefficient is multiplied with
    fn mul_assign(&mut self, scalar: &Z) {
        let ctx = self.modulus.get_fmpz_mod_ctx_struct();
        // One pointer serves as both output and input: the product overwrites `self.poly`.
        let target: *mut _ = &mut self.poly;
        // SAFETY: FLINT's fmpz_mod_poly routines allow the result to alias an input, and
        // `ctx` is the context of the modulus `self.poly` is defined over.
        unsafe { fmpz_mod_poly_scalar_mul_fmpz(target, target, &scalar.value, ctx) };
    }
}
impl MulAssign<&Zq> for PolyOverZq {
    /// Multiplies `self` in place by the residue `scalar`.
    ///
    /// Parameters:
    /// - `scalar`: the [`Zq`] value each coefficient is multiplied with
    ///
    /// # Panics
    /// Panics with the `MathError` message if `self` and `scalar` do not share a base
    /// (see [`CompareBase`]).
    fn mul_assign(&mut self, scalar: &Zq) {
        assert!(
            self.compare_base(scalar),
            "{}",
            self.call_compare_base_error(scalar).unwrap()
        );

        let ctx = self.modulus.get_fmpz_mod_ctx_struct();
        let target: *mut _ = &mut self.poly;
        // SAFETY: FLINT's fmpz_mod_poly routines allow the result to alias an input, and
        // the assertion above guarantees `scalar` lives in the same ring as `self`.
        unsafe { fmpz_mod_poly_scalar_mul_fmpz(target, target, &scalar.value.value, ctx) };
    }
}
// Owned `PolyOverZq *= Zq`, judging by the macro name derived from the borrowed impl above
// (exercised by `test_mul_assign::availability`).
arithmetic_assign_trait_borrowed_to_owned!(MulAssign, mul_assign, PolyOverZq, Zq);
impl MulAssign<i64> for PolyOverZq {
    /// Multiplies `self` in place by the signed integer `other`.
    ///
    /// The value is first converted into a [`Z`] so the `fmpz` scalar routine can be used.
    ///
    /// Parameters:
    /// - `other`: the factor each coefficient is multiplied with
    fn mul_assign(&mut self, other: i64) {
        let factor = Z::from(other);
        let ctx = self.modulus.get_fmpz_mod_ctx_struct();
        let target: *mut _ = &mut self.poly;
        // SAFETY: FLINT's fmpz_mod_poly routines allow the result to alias an input;
        // `factor` stays alive until the end of this function, after the call returns.
        unsafe { fmpz_mod_poly_scalar_mul_fmpz(target, target, &factor.value, ctx) };
    }
}
impl MulAssign<u64> for PolyOverZq {
    /// Multiplies `self` in place by the unsigned integer `other`.
    ///
    /// The value is handed to FLINT's `_ui` scalar routine directly, without a [`Z`] detour.
    ///
    /// Parameters:
    /// - `other`: the factor each coefficient is multiplied with
    fn mul_assign(&mut self, other: u64) {
        let ctx = self.modulus.get_fmpz_mod_ctx_struct();
        let target: *mut _ = &mut self.poly;
        // SAFETY: FLINT's fmpz_mod_poly routines allow the result to alias an input, and
        // `ctx` is the context of the modulus `self.poly` is defined over.
        unsafe { fmpz_mod_poly_scalar_mul_ui(target, target, other, ctx) };
    }
}
// Owned `PolyOverZq *= Z`, judging by the macro name derived from `MulAssign<&Z>` above.
arithmetic_assign_trait_borrowed_to_owned!(MulAssign, mul_assign, PolyOverZq, Z);
// `MulAssign` for the smaller signed / unsigned primitives, presumably by converting them to
// `i64` / `u64` and reusing the impls above (see `crate::macros::arithmetics` to confirm).
arithmetic_assign_between_types!(MulAssign, mul_assign, PolyOverZq, i64, i32 i16 i8);
arithmetic_assign_between_types!(MulAssign, mul_assign, PolyOverZq, u64, u32 u16 u8);
#[cfg(test)]
mod test_mul_z {
    use super::PolyOverZq;
    use crate::integer::Z;
    use std::str::FromStr;

    /// Borrowed `PolyOverZq * Z` and `Z * PolyOverZq` both match the expected product,
    /// including for a coefficient as large as `i64::MAX` under a `u64::MAX` modulus.
    #[test]
    fn borrowed_correctness() {
        let modulus = u64::MAX;
        let doubled_max = (i64::MAX as u64) * 2;
        let base =
            PolyOverZq::from_str(&format!("3 1 2 {} mod {}", i64::MAX, modulus)).unwrap();
        let expected =
            PolyOverZq::from_str(&format!("3 2 4 {} mod {}", doubled_max, modulus)).unwrap();
        let two = Z::from(2);

        let product_right = &base * &two;
        let product_left = &two * &base;

        assert_eq!(expected, product_right);
        assert_eq!(expected, product_left);
    }

    /// Every supported combination of owned/borrowed operands and scalar types compiles.
    #[test]
    fn availability() {
        let poly = PolyOverZq::from_str("3 1 2 3 mod 17").unwrap();
        let z = Z::from(2);

        // owned polynomial times owned scalar
        let _ = poly.clone() * z.clone();
        let _ = poly.clone() * 2i8;
        let _ = poly.clone() * 2u8;
        let _ = poly.clone() * 2i16;
        let _ = poly.clone() * 2u16;
        let _ = poly.clone() * 2i32;
        let _ = poly.clone() * 2u32;
        let _ = poly.clone() * 2i64;
        let _ = poly.clone() * 2u64;

        // owned scalar times owned polynomial
        let _ = z.clone() * poly.clone();
        let _ = 2i8 * poly.clone();
        let _ = 2u64 * poly.clone();

        // borrowed and mixed operands
        let _ = &poly * &z;
        let _ = &z * &poly;
        let _ = &poly * z.clone();
        let _ = z.clone() * &poly;
        let _ = poly.clone() * &z;
        let _ = &z * poly.clone();
        let _ = &poly * 2i8;
        let _ = 2i8 * &poly;
    }
}
#[cfg(test)]
mod test_mul_zq {
    use super::PolyOverZq;
    use crate::integer_mod_q::Zq;
    use std::str::FromStr;

    /// Borrowed `PolyOverZq * Zq` and `Zq * PolyOverZq` both match the expected product.
    #[test]
    fn borrowed_correctness() {
        let modulus = u64::MAX;
        let doubled_max = (i64::MAX as u64) * 2;
        let base =
            PolyOverZq::from_str(&format!("3 1 2 {} mod {}", i64::MAX, modulus)).unwrap();
        let expected =
            PolyOverZq::from_str(&format!("3 2 4 {} mod {}", doubled_max, modulus)).unwrap();
        let two = Zq::from((2, modulus));

        let product_right = &base * &two;
        let product_left = &two * &base;

        assert_eq!(expected, product_right);
        assert_eq!(expected, product_left);
    }

    /// Every owned/borrowed operand combination with a `Zq` scalar compiles.
    #[test]
    fn availability() {
        let poly = PolyOverZq::from_str("3 1 2 3 mod 17").unwrap();
        let z = Zq::from((2, 17));

        // fully owned
        let _ = poly.clone() * z.clone();
        let _ = z.clone() * poly.clone();

        // fully borrowed
        let _ = &poly * &z;
        let _ = &z * &poly;

        // mixed
        let _ = &poly * z.clone();
        let _ = z.clone() * &poly;
        let _ = &z * poly.clone();
        let _ = poly.clone() * &z;
    }

    /// The operator form panics when the moduli differ.
    #[test]
    #[should_panic]
    fn different_moduli_panic() {
        let poly = PolyOverZq::from_str("3 1 2 3 mod 17").unwrap();
        let foreign_scalar = Zq::from((2, 16));
        let _ = &poly * &foreign_scalar;
    }

    /// The `_safe` variant reports differing moduli as an error.
    #[test]
    fn different_moduli_error() {
        let poly = PolyOverZq::from_str("3 1 2 3 mod 17").unwrap();
        let foreign_scalar = Zq::from((2, 16));
        let outcome = poly.mul_scalar_zq_safe(&foreign_scalar);
        assert!(outcome.is_err());
    }
}
#[cfg(test)]
mod test_mul_assign {
    use crate::integer::{PolyOverZ, Z};
    use crate::integer_mod_q::{PolyOverZq, Zq};
    use std::str::FromStr;

    /// `a *= x` yields the same polynomial as `&a * x`.
    #[test]
    fn consistency() {
        let mut poly = PolyOverZq::from_str(&format!("2 2 -1 mod {}", u64::MAX - 1)).unwrap();
        let expected = &poly * i32::MAX;

        poly *= i32::MAX;

        assert_eq!(expected, poly);
    }

    /// `*=` is available for every supported right-hand side type.
    #[test]
    fn availability() {
        let mut target = PolyOverZq::from_str("3 1 2 -3 mod 8").unwrap();
        let z = Z::from(2);
        let zq = Zq::from((2, 8));
        let poly_z = PolyOverZ::from_str("2 3 1").unwrap();

        // arbitrary-precision integer, borrowed then owned
        target *= &z;
        target *= z;
        // residue modulo 8, borrowed then owned
        target *= &zq;
        target *= zq;
        // integer polynomial, borrowed then owned
        target *= &poly_z;
        target *= poly_z;
        // unsigned primitives
        target *= 1_u8;
        target *= 1_u16;
        target *= 1_u32;
        target *= 1_u64;
        // signed primitives
        target *= 1_i8;
        target *= 1_i16;
        target *= 1_i32;
        target *= 1_i64;
    }

    /// Assigning the product with a `Zq` of a different modulus panics.
    #[test]
    #[should_panic]
    fn mismatching_modulus_zq() {
        let mut target = PolyOverZq::from_str("3 1 2 -3 mod 8").unwrap();
        let foreign_scalar = Zq::from((2, u64::MAX - 1));
        target *= &foreign_scalar;
    }
}