use super::super::PolyOverZq;
use crate::{
error::MathError,
integer::PolyOverZ,
macros::arithmetics::{
arithmetic_assign_trait_borrowed_to_owned, arithmetic_trait_borrowed_to_owned,
arithmetic_trait_mixed_borrowed_owned, arithmetic_trait_reverse,
},
traits::CompareBase,
};
use flint_sys::fmpz_mod_poly::fmpz_mod_poly_add;
use std::{
ops::{Add, AddAssign},
str::FromStr,
};
impl AddAssign<&PolyOverZq> for PolyOverZq {
    /// Adds `other` to `self` in place.
    ///
    /// Parameters:
    /// - `other`: the polynomial that is added onto `self`
    ///
    /// # Panics
    /// Panics if `self` and `other` do not share the same base (modulus); the panic
    /// message is the error produced by `call_compare_base_error`.
    fn add_assign(&mut self, other: &Self) {
        // The message is only formatted when the base check fails.
        assert!(
            self.compare_base(other),
            "{}",
            self.call_compare_base_error(other).unwrap()
        );

        // SAFETY: both operands are initialised polynomials whose bases were just
        // checked to agree, and `self.modulus` provides the matching context. The
        // output intentionally aliases the first input (in-place update).
        unsafe {
            fmpz_mod_poly_add(
                &mut self.poly,
                &self.poly,
                &other.poly,
                self.modulus.get_fmpz_mod_ctx_struct(),
            );
        }
    }
}
impl AddAssign<&PolyOverZ> for PolyOverZq {
    /// Adds a [`PolyOverZ`] to `self` in place.
    ///
    /// `other` is first converted into a [`PolyOverZq`] under `self`'s modulus and
    /// then added via the [`PolyOverZq`] implementation, so both operands share
    /// the same modulus.
    ///
    /// Parameters:
    /// - `other`: the integer polynomial that is added onto `self`
    fn add_assign(&mut self, other: &PolyOverZ) {
        // Borrow the modulus (as `Add<&PolyOverZ>` does) instead of cloning it via
        // `get_mod`; the conversion returns an owned value, so the shared borrow
        // ends before `self` is mutated below.
        let other = PolyOverZq::from((other, &self.modulus));
        self.add_assign(&other);
    }
}

arithmetic_assign_trait_borrowed_to_owned!(AddAssign, add_assign, PolyOverZq, PolyOverZq);
arithmetic_assign_trait_borrowed_to_owned!(AddAssign, add_assign, PolyOverZq, PolyOverZ);
impl Add for &PolyOverZq {
type Output = PolyOverZq;
fn add(self, other: Self) -> Self::Output {
self.add_safe(other).unwrap()
}
}
arithmetic_trait_borrowed_to_owned!(Add, add, PolyOverZq, PolyOverZq, PolyOverZq);
arithmetic_trait_mixed_borrowed_owned!(Add, add, PolyOverZq, PolyOverZq, PolyOverZq);
impl Add<&PolyOverZ> for &PolyOverZq {
    type Output = PolyOverZq;

    /// Adds a [`PolyOverZ`] to a [`PolyOverZq`] and returns the sum as a new
    /// [`PolyOverZq`] with `self`'s modulus.
    ///
    /// Parameters:
    /// - `other`: the integer polynomial that is added to `self`
    ///
    /// `other` is converted under `self`'s modulus first, so no modulus check is
    /// needed and this never fails on mismatching moduli.
    fn add(self, other: &PolyOverZ) -> Self::Output {
        let modulus = &self.modulus;
        // Keep the converted operand in a named local so the pointer handed to
        // FLINT clearly outlives the call.
        let converted_other = PolyOverZq::from((other, modulus));
        let mut sum = PolyOverZq::from(modulus);

        // SAFETY: `sum`, `self` and `converted_other` are initialised polynomials
        // that all use `modulus`, whose context is passed to FLINT.
        unsafe {
            fmpz_mod_poly_add(
                &mut sum.poly,
                &self.poly,
                &converted_other.poly,
                modulus.get_fmpz_mod_ctx_struct(),
            );
        }
        sum
    }
}

arithmetic_trait_reverse!(Add, add, PolyOverZ, PolyOverZq, PolyOverZq);
arithmetic_trait_borrowed_to_owned!(Add, add, PolyOverZq, PolyOverZ, PolyOverZq);
arithmetic_trait_borrowed_to_owned!(Add, add, PolyOverZ, PolyOverZq, PolyOverZq);
arithmetic_trait_mixed_borrowed_owned!(Add, add, PolyOverZq, PolyOverZ, PolyOverZq);
arithmetic_trait_mixed_borrowed_owned!(Add, add, PolyOverZ, PolyOverZq, PolyOverZq);
impl PolyOverZq {
    /// Implements addition for two [`PolyOverZq`] values.
    ///
    /// Parameters:
    /// - `other`: specifies the polynomial to add to `self`
    ///
    /// Returns the sum of both polynomials as a [`PolyOverZq`] with the shared
    /// modulus, or an error if the moduli do not match.
    ///
    /// # Errors and Failures
    /// - Returns the [`MathError`] produced by `call_compare_base_error` if `self`
    ///   and `other` do not share the same base (modulus).
    pub fn add_safe(&self, other: &Self) -> Result<PolyOverZq, MathError> {
        if !self.compare_base(other) {
            return Err(self.call_compare_base_error(other).unwrap());
        }

        // Build the output directly from the existing modulus instead of printing
        // the modulus to a decimal string and parsing it back, which re-creates a
        // modulus on every addition and hides an `unwrap` on parsing.
        let mut out = PolyOverZq::from(&self.modulus);

        // SAFETY: `self` and `other` are initialised and were checked to share the
        // same modulus; `out` was created with that same modulus, so the context
        // from `self.modulus` matches all three polynomials.
        unsafe {
            fmpz_mod_poly_add(
                &mut out.poly,
                &self.poly,
                &other.poly,
                self.modulus.get_fmpz_mod_ctx_struct(),
            );
        }
        Ok(out)
    }
}
#[cfg(test)]
mod test_add_assign {
    use super::PolyOverZq;
    use crate::integer::PolyOverZ;
    use std::str::FromStr;

    /// Ensures that `+=` works for small values and reduces coefficients.
    #[test]
    fn correct_small() {
        let mut a = PolyOverZq::from_str("3 6 2 -3 mod 7").unwrap();
        let b = PolyOverZq::from_str("5 1 2 5 1 2 mod 7").unwrap();
        let cmp = PolyOverZq::from_str("5 0 4 2 1 2 mod 7").unwrap();

        a += b;

        assert_eq!(cmp, a);
    }

    /// Ensures that `+=` works for large coefficients and moduli.
    #[test]
    fn correct_large() {
        let mut a = PolyOverZq::from_str(&format!(
            "3 {} {} {} mod {}",
            u32::MAX,
            i32::MIN,
            i32::MAX,
            u64::MAX
        ))
        .unwrap();
        let b = PolyOverZq::from_str(&format!("2 {} {} mod {}", u32::MAX, i32::MAX, u64::MAX))
            .unwrap();
        let cmp = PolyOverZq::from_str(&format!(
            "3 {} -1 {} mod {}",
            u64::from(u32::MAX) * 2,
            i32::MAX,
            u64::MAX
        ))
        .unwrap();

        a += b;

        assert_eq!(cmp, a);
    }

    /// Ensures that adding a [`PolyOverZ`] in place reduces its coefficients
    /// modulo the modulus of the [`PolyOverZq`].
    #[test]
    fn correct_poly_over_z() {
        let mut a = PolyOverZq::from_str("3 1 2 -3 mod 5").unwrap();
        let b = PolyOverZ::from_str("4 -2 2 3 7").unwrap();
        // [1, 2, 2] + [3, 2, 3, 2] = [4, 4, 5, 2] = [4, 4, 0, 2] mod 5
        let cmp = PolyOverZq::from_str("4 4 4 0 2 mod 5").unwrap();

        a += b;

        assert_eq!(cmp, a);
    }

    /// Ensures that all owned/borrowed combinations of `+=` are available.
    #[test]
    fn availability() {
        let mut a = PolyOverZq::from_str("3 1 2 -3 mod 5").unwrap();
        let b = PolyOverZq::from_str("3 -1 -2 3 mod 5").unwrap();
        let c = PolyOverZ::from_str("2 -2 2").unwrap();

        a += &b;
        a += b;
        a += &c;
        a += c;
    }

    /// Ensures that `+=` panics if the moduli mismatch.
    #[test]
    #[should_panic]
    fn mismatching_moduli() {
        let mut a: PolyOverZq = PolyOverZq::from_str("3 -5 4 1 mod 7").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("3 -5 4 1 mod 8").unwrap();

        a += b;
    }
}
#[cfg(test)]
mod test_add {
    use super::PolyOverZq;
    use std::str::FromStr;

    /// Ensures that owned addition works.
    #[test]
    fn add() {
        let a: PolyOverZq = PolyOverZq::from_str("3 2 4 1 mod 7").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("3 -5 4 1 mod 7").unwrap();
        let c: PolyOverZq = a + b;
        assert_eq!(c, PolyOverZq::from_str("3 4 1 2 mod 7").unwrap());
    }

    /// Ensures that fully borrowed addition works.
    #[test]
    fn add_borrow() {
        let a: PolyOverZq = PolyOverZq::from_str("3 2 4 1 mod 7").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("3 -5 4 1 mod 7").unwrap();
        let c: PolyOverZq = &a + &b;
        assert_eq!(c, PolyOverZq::from_str("3 4 1 2 mod 7").unwrap());
    }

    /// Ensures that addition with a borrowed first operand works.
    #[test]
    fn add_first_borrowed() {
        let a: PolyOverZq = PolyOverZq::from_str("3 2 4 1 mod 7").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("3 -5 4 1 mod 7").unwrap();
        let c: PolyOverZq = &a + b;
        assert_eq!(c, PolyOverZq::from_str("3 4 1 2 mod 7").unwrap());
    }

    /// Ensures that addition with a borrowed second operand works.
    #[test]
    fn add_second_borrowed() {
        let a: PolyOverZq = PolyOverZq::from_str("3 2 4 1 mod 7").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("3 -5 4 1 mod 7").unwrap();
        let c: PolyOverZq = a + &b;
        assert_eq!(c, PolyOverZq::from_str("3 4 1 2 mod 7").unwrap());
    }

    /// Ensures that a leading coefficient reducing to zero lowers the degree.
    #[test]
    fn add_reduce() {
        let a: PolyOverZq = PolyOverZq::from_str("3 2 4 1 mod 7").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("3 -5 4 6 mod 7").unwrap();
        let c: PolyOverZq = a + b;
        assert_eq!(c, PolyOverZq::from_str("2 4 1 mod 7").unwrap());
    }

    /// Ensures that addition works for large coefficients and moduli.
    #[test]
    fn add_large_numbers() {
        let a: PolyOverZq = PolyOverZq::from_str(&format!(
            "3 -{} 4 {} mod {}",
            u64::MAX,
            i64::MIN,
            u64::MAX - 58
        ))
        .unwrap();
        let b: PolyOverZq = PolyOverZq::from_str(&format!(
            "3 {} 4 {} mod {}",
            i64::MIN,
            i64::MIN,
            u64::MAX - 58
        ))
        .unwrap();
        let c: PolyOverZq = a + b;
        assert_eq!(
            c,
            PolyOverZq::from_str(&format!(
                "3 -{} 8 {} mod {}",
                i128::from(i64::MAX) + 59,
                -59,
                u64::MAX - 58
            ))
            .unwrap()
        );
    }

    /// Ensures that the operator panics if the moduli mismatch.
    #[test]
    #[should_panic]
    fn add_mismatching_modulus() {
        let a: PolyOverZq = PolyOverZq::from_str("3 -5 4 1 mod 7").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("3 -5 4 1 mod 8").unwrap();
        let _c: PolyOverZq = a + b;
    }

    /// Ensures that `add_safe` returns the correct sum for matching moduli.
    #[test]
    fn add_safe_is_ok() {
        let a: PolyOverZq = PolyOverZq::from_str("3 2 4 1 mod 7").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("3 -5 4 1 mod 7").unwrap();
        let c: PolyOverZq = a.add_safe(&b).unwrap();
        assert_eq!(c, PolyOverZq::from_str("3 4 1 2 mod 7").unwrap());
    }

    /// Ensures that `add_safe` returns an error if the moduli mismatch.
    #[test]
    fn add_safe_is_err() {
        let a: PolyOverZq = PolyOverZq::from_str("3 -5 4 1 mod 7").unwrap();
        let b: PolyOverZq = PolyOverZq::from_str("3 -5 4 1 mod 11").unwrap();
        assert!(a.add_safe(&b).is_err());
    }
}
/// Tests for adding a [`PolyOverZ`] to a [`PolyOverZq`] (and vice versa).
#[cfg(test)]
mod test_add_poly_over_z {
    use super::PolyOverZq;
    use crate::integer::PolyOverZ;
    use std::str::FromStr;

    /// Ensures that borrowed mixed addition produces the correct, reduced result.
    #[test]
    fn borrowed_correctness() {
        let poly_1 = PolyOverZq::from_str(&format!("1 {} mod {}", i64::MAX, u64::MAX)).unwrap();
        let poly_2 = PolyOverZ::from_str("2 1 2").unwrap();
        let poly_cmp =
            PolyOverZq::from_str(&format!("2 {} 2 mod {}", i64::MAX as u64 + 1, u64::MAX))
                .unwrap();

        let poly_1 = &poly_1 + &poly_2;

        assert_eq!(poly_cmp, poly_1);
    }

    /// Ensures that all owned/borrowed and left/right combinations are available.
    #[test]
    fn availability() {
        let poly = PolyOverZq::from_str("3 1 2 3 mod 17").unwrap();
        let z = PolyOverZ::from(2);

        _ = poly.clone() + z.clone();
        _ = z.clone() + poly.clone();
        _ = &poly + &z;
        _ = &z + &poly;
        _ = &poly + z.clone();
        _ = z.clone() + &poly;
        _ = &z + poly.clone();
        _ = poly.clone() + &z;
    }
}