use super::super::PolyOverZq;
use crate::{
error::MathError,
integer::PolyOverZ,
macros::arithmetics::{
arithmetic_assign_trait_borrowed_to_owned, arithmetic_trait_borrowed_to_owned,
arithmetic_trait_mixed_borrowed_owned, arithmetic_trait_reverse,
},
traits::CompareBase,
};
use core::panic;
use flint_sys::fmpz_mod_poly::fmpz_mod_poly_mul;
use std::{
ops::{Mul, MulAssign},
str::FromStr,
};
impl MulAssign<&PolyOverZq> for PolyOverZq {
    /// Replaces `self` with the product `self * other`, computed in place.
    ///
    /// Parameters:
    /// - `other`: the polynomial `self` is multiplied with
    ///
    /// # Panics ...
    /// - if `self` and `other` do not share the same base, i.e. if
    ///   `compare_base` rejects them (e.g. because their moduli differ).
    fn mul_assign(&mut self, other: &Self) {
        // The message is only formatted (and the `Option` only unwrapped)
        // when the base check fails.
        assert!(
            self.compare_base(other),
            "{}",
            self.call_compare_base_error(other).unwrap()
        );

        // SAFETY: every pointer is derived from a live reference into `self`
        // or `other`, and both operands share `self`'s modulus context (checked
        // above). The output deliberately aliases the first input; FLINT's
        // `fmpz_mod_poly` routines are expected to permit output/input
        // aliasing — NOTE(review): confirm against the FLINT docs.
        unsafe {
            let ctx = self.modulus.get_fmpz_mod_ctx_struct();
            fmpz_mod_poly_mul(&mut self.poly, &self.poly, &other.poly, ctx);
        }
    }
}
impl MulAssign<&PolyOverZ> for PolyOverZq {
    /// Replaces `self` with the product of `self` and an integer polynomial.
    ///
    /// `other` is first converted into a [`PolyOverZq`] over the modulus of
    /// `self`, then multiplied in via `MulAssign<&PolyOverZq>`.
    ///
    /// Parameters:
    /// - `other`: the [`PolyOverZ`] `self` is multiplied with
    fn mul_assign(&mut self, other: &PolyOverZ) {
        // Borrow the modulus field directly, exactly like `Mul<&PolyOverZ>`
        // does, instead of going through `get_mod()`, which may hand back an
        // owned copy of the modulus.
        let other = PolyOverZq::from((other, &self.modulus));
        self.mul_assign(&other);
    }
}

// Owned right-hand-side variants (`poly *= PolyOverZq` / `poly *= PolyOverZ`),
// presumably forwarding to the borrowed implementations above.
arithmetic_assign_trait_borrowed_to_owned!(MulAssign, mul_assign, PolyOverZq, PolyOverZq);
arithmetic_assign_trait_borrowed_to_owned!(MulAssign, mul_assign, PolyOverZq, PolyOverZ);
impl Mul for &PolyOverZq {
type Output = PolyOverZq;
fn mul(self, other: Self) -> Self::Output {
self.mul_safe(other).unwrap()
}
}
arithmetic_trait_borrowed_to_owned!(Mul, mul, PolyOverZq, PolyOverZq, PolyOverZq);
arithmetic_trait_mixed_borrowed_owned!(Mul, mul, PolyOverZq, PolyOverZq, PolyOverZq);
impl Mul<&PolyOverZ> for &PolyOverZq {
    type Output = PolyOverZq;

    /// Returns the product of a [`PolyOverZq`] and a [`PolyOverZ`].
    ///
    /// `other` is converted into a [`PolyOverZq`] over the modulus of `self`
    /// before multiplying, so the result carries `self`'s modulus.
    ///
    /// Parameters:
    /// - `other`: the integer polynomial to multiply with
    fn mul(self, other: &PolyOverZ) -> Self::Output {
        let modulus = &self.modulus;
        let rhs = PolyOverZq::from((other, modulus));
        let mut product = PolyOverZq::from(modulus);

        // SAFETY: `product`, `self` and `rhs` are live polynomials that all use
        // `modulus`, whose context is passed alongside; `product` is distinct
        // from both inputs.
        unsafe {
            fmpz_mod_poly_mul(
                &mut product.poly,
                &self.poly,
                &rhs.poly,
                modulus.get_fmpz_mod_ctx_struct(),
            );
        }

        product
    }
}

// `PolyOverZ * PolyOverZq`, presumably implemented by swapping the operands of
// the impl above, followed by the owned and mixed owned/borrowed combinations
// for both operand orders.
arithmetic_trait_reverse!(Mul, mul, PolyOverZ, PolyOverZq, PolyOverZq);
arithmetic_trait_borrowed_to_owned!(Mul, mul, PolyOverZq, PolyOverZ, PolyOverZq);
arithmetic_trait_borrowed_to_owned!(Mul, mul, PolyOverZ, PolyOverZq, PolyOverZq);
arithmetic_trait_mixed_borrowed_owned!(Mul, mul, PolyOverZq, PolyOverZ, PolyOverZq);
arithmetic_trait_mixed_borrowed_owned!(Mul, mul, PolyOverZ, PolyOverZq, PolyOverZq);
impl PolyOverZq {
    /// Multiplies `self` with `other` and returns the product.
    ///
    /// Unlike the `*` operator, this does not panic on mismatching moduli but
    /// reports the problem as an error.
    ///
    /// Parameters:
    /// - `other`: the right-hand factor
    ///
    /// Returns the product `self * other` over the shared modulus.
    ///
    /// # Errors and Failures
    /// - Returns the [`MathError`] produced by `call_compare_base_error` if
    ///   `self` and `other` do not share the same base (different moduli).
    pub fn mul_safe(&self, other: &Self) -> Result<PolyOverZq, MathError> {
        if !self.compare_base(other) {
            return Err(self.call_compare_base_error(other).unwrap());
        }

        // Allocate the output directly over the shared modulus instead of
        // printing the modulus to a string and parsing it back; FLINT fully
        // overwrites this buffer with the product below.
        let mut out = PolyOverZq::from(&self.modulus);

        // SAFETY: `out`, `self` and `other` are live polynomials over the same
        // modulus (checked above), whose context is passed alongside; `out` is
        // distinct from both inputs.
        unsafe {
            fmpz_mod_poly_mul(
                &mut out.poly,
                &self.poly,
                &other.poly,
                self.modulus.get_fmpz_mod_ctx_struct(),
            );
        }

        Ok(out)
    }
}
#[cfg(test)]
mod test_mul_assign {
    use super::PolyOverZq;
    use crate::integer::PolyOverZ;
    use std::str::FromStr;

    /// Ensures that `*=` computes the correct product for small entries.
    #[test]
    fn correct_small() {
        let mut a = PolyOverZq::from_str("3 2 4 1 mod 7").unwrap();
        let b = PolyOverZq::from_str("2 2 4 mod 7").unwrap();

        a *= b;

        assert_eq!(a, PolyOverZq::from_str("4 4 2 4 4 mod 7").unwrap());
    }

    /// Ensures that `*=` computes the correct product for entries and a
    /// modulus close to the 64-bit limits.
    #[test]
    fn correct_large() {
        let mut a = PolyOverZq::from_str(&format!(
            "2 {} {} mod {}",
            u64::MAX,
            i64::MAX,
            u64::MAX - 58
        ))
        .unwrap();
        let b = PolyOverZq::from_str(&format!(
            "2 {} {} mod {}",
            i64::MAX,
            i64::MIN,
            u64::MAX - 58
        ))
        .unwrap();

        a *= b;

        assert_eq!(
            a,
            PolyOverZq::from_str(&format!(
                "3 {} {} {} mod {}",
                i128::from(i64::MAX) * 58,
                i128::from(i64::MIN) * 58 + i128::from(i64::MAX) * i128::from(i64::MAX),
                i128::from(i64::MAX) * i128::from(i64::MIN),
                u64::MAX - 58
            ))
            .unwrap()
        );
    }

    /// Ensures that `*=` with a [`PolyOverZ`] computes the correct product,
    /// including for coefficients that are negative or exceed the modulus
    /// (`[-5, 11]` is `[2, 4]` modulo 7).
    #[test]
    fn correct_poly_over_z() {
        let mut a = PolyOverZq::from_str("3 2 4 1 mod 7").unwrap();
        let b = PolyOverZ::from_str("2 -5 11").unwrap();

        a *= &b;

        assert_eq!(a, PolyOverZq::from_str("4 4 2 4 4 mod 7").unwrap());
    }

    /// Ensures that `*=` is available for owned and borrowed [`PolyOverZq`]
    /// and [`PolyOverZ`] right-hand sides.
    #[test]
    fn availability() {
        let mut a = PolyOverZq::from_str("3 1 2 -3 mod 5").unwrap();
        let b = PolyOverZq::from_str("3 -1 -2 3 mod 5").unwrap();
        let c = PolyOverZ::from_str("2 -2 2").unwrap();

        a *= &b;
        a *= b;
        a *= &c;
        a *= c;
    }

    /// Ensures that `*=` panics if the moduli differ.
    #[test]
    #[should_panic]
    fn mismatching_moduli() {
        let mut a: PolyOverZq = PolyOverZq::from_str("3 -5 4 1 mod 7").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("3 -5 4 1 mod 8").unwrap();

        a *= b;
    }
}
#[cfg(test)]
mod test_mul {
    use super::PolyOverZq;
    use std::str::FromStr;

    /// Ensures that multiplying two owned polynomials works.
    #[test]
    fn mul() {
        let a: PolyOverZq = PolyOverZq::from_str("3 2 4 1 mod 7").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("2 2 4 mod 7").unwrap();
        let c: PolyOverZq = a * b;
        assert_eq!(c, PolyOverZq::from_str("4 4 2 4 4 mod 7").unwrap());
    }

    /// Ensures that multiplying two borrowed polynomials works.
    #[test]
    fn mul_borrow() {
        let a: PolyOverZq = PolyOverZq::from_str("3 2 4 1 mod 7").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("2 2 4 mod 7").unwrap();
        let c: PolyOverZq = &a * &b;
        assert_eq!(c, PolyOverZq::from_str("4 4 2 4 4 mod 7").unwrap());
    }

    /// Ensures that multiplying with a borrowed left operand works.
    #[test]
    fn mul_first_borrowed() {
        let a: PolyOverZq = PolyOverZq::from_str("3 2 4 1 mod 7").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("2 2 4 mod 7").unwrap();
        let c: PolyOverZq = &a * b;
        assert_eq!(c, PolyOverZq::from_str("4 4 2 4 4 mod 7").unwrap());
    }

    /// Ensures that multiplying with a borrowed right operand works.
    #[test]
    fn mul_second_borrowed() {
        let a: PolyOverZq = PolyOverZq::from_str("3 2 4 1 mod 7").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("2 2 4 mod 7").unwrap();
        let c: PolyOverZq = a * &b;
        assert_eq!(c, PolyOverZq::from_str("4 4 2 4 4 mod 7").unwrap());
    }

    /// Ensures that multiplying with constant polynomials, including zero,
    /// works.
    #[test]
    fn mul_constant() {
        let a: PolyOverZq = PolyOverZq::from_str("3 2 4 1 mod 7").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("1 2 mod 7").unwrap();
        let c: PolyOverZq = &a * b;
        assert_eq!(c, PolyOverZq::from_str("3 4 1 2 mod 7").unwrap());
        assert_eq!(
            a * PolyOverZq::from_str("0 mod 7").unwrap(),
            PolyOverZq::from_str("0 mod 7").unwrap()
        );
    }

    /// Ensures that multiplying with entries and a modulus close to the 64-bit
    /// limits works.
    #[test]
    fn mul_large_numbers() {
        let a: PolyOverZq = PolyOverZq::from_str(&format!(
            "2 {} {} mod {}",
            u64::MAX,
            i64::MAX,
            u64::MAX - 58
        ))
        .unwrap();
        let b: PolyOverZq = PolyOverZq::from_str(&format!(
            "2 {} {} mod {}",
            i64::MAX,
            i64::MIN,
            u64::MAX - 58
        ))
        .unwrap();
        let c: PolyOverZq = a * &b;
        assert_eq!(
            c,
            PolyOverZq::from_str(&format!(
                "3 {} {} {} mod {}",
                i128::from(i64::MAX) * 58,
                i128::from(i64::MIN) * 58 + i128::from(i64::MAX) * i128::from(i64::MAX),
                i128::from(i64::MAX) * i128::from(i64::MIN),
                u64::MAX - 58
            ))
            .unwrap()
        );
    }

    /// Ensures that `*` panics if the moduli differ.
    #[test]
    #[should_panic]
    fn mul_mismatching_modulus() {
        let a: PolyOverZq = PolyOverZq::from_str("3 2 4 1 mod 8").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("2 -5 4 mod 7").unwrap();
        let _c: PolyOverZq = a * b;
    }

    /// Ensures that `mul_safe` returns the correct product for matching
    /// moduli.
    #[test]
    fn mul_safe_is_ok() {
        let a: PolyOverZq = PolyOverZq::from_str("3 2 4 1 mod 7").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("2 2 4 mod 7").unwrap();
        assert_eq!(
            a.mul_safe(&b).unwrap(),
            PolyOverZq::from_str("4 4 2 4 4 mod 7").unwrap()
        );
    }

    /// Ensures that `mul_safe` returns an error instead of panicking if the
    /// moduli differ.
    #[test]
    fn mul_safe_is_err() {
        let a: PolyOverZq = PolyOverZq::from_str("3 2 4 1 mod 9").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("2 -5 4 mod 7").unwrap();
        assert!(a.mul_safe(&b).is_err());
    }
}
#[cfg(test)]
mod test_mul_poly_over_z {
    use super::PolyOverZq;
    use crate::integer::PolyOverZ;
    use std::str::FromStr;

    /// Ensures that multiplying a borrowed [`PolyOverZq`] with a borrowed
    /// [`PolyOverZ`] yields the correct product for large entries.
    #[test]
    fn borrowed_correctness() {
        let poly_1 = PolyOverZq::from_str(&format!("1 {} mod {}", i64::MAX, u64::MAX)).unwrap();
        let poly_2 = PolyOverZ::from_str("2 1 2").unwrap();
        let poly_cmp = PolyOverZq::from_str(&format!(
            "2 {} {} mod {}",
            i64::MAX,
            i64::MAX as u64 * 2,
            u64::MAX
        ))
        .unwrap();

        let poly_1 = &poly_1 * &poly_2;

        assert_eq!(poly_cmp, poly_1);
    }

    /// Ensures that both operand orders yield the same, correctly reduced
    /// product (`[-1, 18]` is `[16, 1]` modulo 17).
    #[test]
    fn both_orders_correctness() {
        let poly = PolyOverZq::from_str("3 1 2 3 mod 17").unwrap();
        let z = PolyOverZ::from_str("2 -1 18").unwrap();
        let poly_cmp = PolyOverZq::from_str("4 16 16 16 3 mod 17").unwrap();

        assert_eq!(poly_cmp, &poly * &z);
        assert_eq!(poly_cmp, &z * &poly);
    }

    /// Ensures that every owned/borrowed operand combination is available in
    /// both operand orders.
    #[test]
    fn availability() {
        let poly = PolyOverZq::from_str("3 1 2 3 mod 17").unwrap();
        let z = PolyOverZ::from(2);

        _ = poly.clone() * z.clone();
        _ = z.clone() * poly.clone();
        _ = &poly * &z;
        _ = &z * &poly;
        _ = &poly * z.clone();
        _ = z.clone() * &poly;
        _ = &z * poly.clone();
        _ = poly.clone() * &z;
    }
}