use core::borrow::Borrow;
use crate::Gf2Poly;
impl Gf2Poly {
    /// Returns `lhs * rhs` reduced modulo `self`.
    ///
    /// The `%` reduction is skipped when the product's degree is already
    /// below the modulus' degree.
    pub fn mod_mul(&self, lhs: &Self, rhs: &Self) -> Self {
        let mut res = lhs * rhs;
        if res.deg() >= self.deg() {
            res %= self;
        }
        res
    }

    /// Returns the multiplicative inverse of `elem` modulo `self`, or `None`
    /// when `gcd(self, elem)` is not one (i.e. `elem` is not invertible).
    ///
    /// Modulo one every polynomial is congruent to zero, so `Some(zero)` is
    /// returned in that case.
    pub fn mod_inv(&self, elem: &Self) -> Option<Self> {
        if self.is_one() {
            return Some(Gf2Poly::zero());
        }
        // NOTE(review): assumes `xgcd` returns Bézout coefficients `[s, t]`
        // with `s*self + t*elem = gcd`, and that `t` is already reduced
        // modulo `self`; it is not re-reduced here.
        let (gcd, [_, inv]) = self.clone().xgcd(elem.clone());
        if gcd.is_one() {
            Some(inv)
        } else {
            None
        }
    }

    /// Returns `lhs * rhs^-1` modulo `self`, or `None` when `rhs` has no
    /// inverse modulo `self`.
    pub fn mod_div(&self, lhs: &Self, rhs: &Self) -> Option<Self> {
        self.mod_inv(rhs).map(|inv| self.mod_mul(lhs, &inv))
    }

    /// Returns `elem^2` reduced modulo `self`.
    ///
    /// The `%` reduction is skipped when the square's degree is already below
    /// the modulus' degree.
    pub fn mod_square(&self, elem: &Self) -> Self {
        let mut square = elem.square();
        if square.deg() >= self.deg() {
            square %= self;
        }
        square
    }

    /// Returns `base^n` reduced modulo `self`, using square-and-multiply.
    ///
    /// Special cases are checked in this order:
    /// - a modulus of one gives zero;
    /// - `n == 0` gives one;
    /// - a zero `base` gives zero.
    pub fn mod_power(&self, base: &Self, mut n: u64) -> Self {
        if self.is_one() {
            return Gf2Poly::zero();
        }
        if n == 0 {
            return Gf2Poly::one();
        }
        if base.is_zero() {
            return Gf2Poly::zero();
        }
        // `base` walks through base^(2^k). After the first squaring it
        // borrows `squared`, which owns the current power.
        let mut base = base;
        let mut squared;
        let mut result = Gf2Poly::one();
        loop {
            if n & 1 == 1 {
                result = self.mod_mul(&result, base);
            }
            n >>= 1;
            // Stop before squaring again: once every bit of `n` has been
            // consumed, the next square would never be used.
            if n == 0 {
                break result;
            }
            squared = self.mod_square(base);
            base = &squared;
        }
    }
}
pub struct Gf2PolyMod<T: Borrow<Gf2Poly>> {
modulus: T,
barrett_reducer: Gf2Poly,
}
impl<T: Borrow<Gf2Poly>> Gf2PolyMod<T> {
pub fn new(modulus: T) -> Self {
if modulus.borrow().is_zero() {
panic!("Zero modulus is not allowed.");
}
let barrett_reducer =
Gf2Poly::x_to_the_power_of(2 * modulus.borrow().deg()) / modulus.borrow();
Gf2PolyMod {
modulus,
barrett_reducer,
}
}
pub fn modulus(&self) -> &Gf2Poly {
self.modulus.borrow()
}
pub fn modulus_value(self) -> T {
self.modulus
}
pub fn deg(&self) -> u64 {
self.modulus().deg()
}
fn barrett_step(&self, upper_half: Gf2Poly) -> Gf2Poly {
(upper_half * &self.barrett_reducer) >> self.deg()
}
fn barrett_divmod(&self, poly: &Gf2Poly) -> (Gf2Poly, Gf2Poly) {
let quotient = self.barrett_step(poly >> self.deg());
let remainder = poly - "ient * self.modulus();
(quotient, remainder)
}
fn barrett_remainder(&self, poly: &Gf2Poly) -> Gf2Poly {
self.barrett_divmod(poly).1
}
pub fn remainder(&self, elem: &Gf2Poly) -> Gf2Poly {
if elem.deg() < self.deg() {
return elem.clone();
}
if self.modulus().is_one() {
return Gf2Poly::zero();
}
let step = self.deg();
if elem.deg() < 2 * step {
return self.barrett_remainder(elem);
}
let last_segment = elem.deg() / step;
let range = |segment: u64| segment * step..(segment + 1) * step;
let mut remainder = elem.subrange(range(last_segment));
for segment in (0..last_segment).rev() {
remainder <<= step;
remainder += elem.subrange(range(segment));
remainder = self.barrett_remainder(&remainder);
}
remainder
}
pub fn divmod(&self, elem: &Gf2Poly) -> (Gf2Poly, Gf2Poly) {
if elem.deg() < self.deg() {
return (Gf2Poly::zero(), elem.clone());
}
if self.modulus().is_one() {
return (elem.clone(), Gf2Poly::zero());
}
let step = self.deg();
if elem.deg() < 2 * step {
let quotient = self.barrett_step(elem >> self.deg());
let remainder = elem - "ient * self.modulus();
return (quotient, remainder);
}
let last_segment = elem.deg() / step;
let range = |segment: u64| segment * step..(segment + 1) * step;
let mut remainder = elem.subrange(range(last_segment));
let mut quotient = Gf2Poly::zero();
for segment in (0..last_segment).rev() {
remainder <<= step;
remainder += elem.subrange(range(segment));
let (q, r) = self.barrett_divmod(&remainder);
remainder = r;
quotient.fused_shl_add(&q, segment * step);
}
(quotient, remainder)
}
pub fn quotient(&self, elem: &Gf2Poly) -> Gf2Poly {
self.divmod(elem).0
}
pub fn mul(&self, lhs: &Gf2Poly, rhs: &Gf2Poly) -> Gf2Poly {
let product = lhs * rhs;
if product.deg() < self.deg() {
return product;
}
self.remainder(&product)
}
pub fn square(&self, elem: &Gf2Poly) -> Gf2Poly {
let square = elem.square();
if square.deg() < self.deg() {
return square;
}
self.remainder(&square)
}
pub fn inverse(&self, elem: &Gf2Poly) -> Option<Gf2Poly> {
self.modulus().mod_inv(elem)
}
pub fn div(&self, lhs: &Gf2Poly, rhs: &Gf2Poly) -> Option<Gf2Poly> {
self.modulus().mod_div(lhs, rhs)
}
}
#[cfg(test)]
mod tests {
    use crate::prop_assert_poly_eq;
    use super::*;
    use proptest::prelude::*;
    proptest! {
        #[test]
        fn modmul_is_mul_remainder(modulus: Gf2Poly, a: Gf2Poly, b: Gf2Poly) {
            let res = modulus.mod_mul(&a, &b);
            let rem = &a * &b % &modulus;
            prop_assert_poly_eq!(res, rem);
        }
        #[test]
        fn modsquare_is_modmul_self(modulo: Gf2Poly, a: Gf2Poly) {
            prop_assume!(!modulo.is_zero());
            prop_assert_poly_eq!(modulo.mod_square(&a), modulo.mod_mul(&a, &a));
        }
        #[test]
        fn modulus_mul(modulo: Gf2Poly, a: Gf2Poly, b: Gf2Poly) {
            prop_assume!(!modulo.is_zero());
            let res1 = modulo.mod_mul(&a, &b);
            let modulus = Gf2PolyMod::new(&modulo);
            let res2 = modulus.mul(&a, &b);
            prop_assert_poly_eq!(res1, res2);
        }
        #[test]
        fn modulus_square(modulo: Gf2Poly, a: Gf2Poly) {
            prop_assume!(!modulo.is_zero());
            let modulus = Gf2PolyMod::new(&modulo);
            prop_assert_poly_eq!(modulus.square(&a), modulo.mod_mul(&a, &a));
        }
        #[test]
        fn modular_inv(modulus: Gf2Poly, elem: Gf2Poly) {
            let inv = modulus.mod_inv(&elem);
            let elem = &elem % &modulus;
            if let Some(inv) = inv {
                prop_assert_poly_eq!(modulus.mod_mul(&elem, &inv), Gf2Poly::one() % modulus);
            }
        }
        #[test]
        fn modulus_remainder(modulus: Gf2Poly, elem: Gf2Poly) {
            prop_assume!(!modulus.is_zero());
            let modulus = Gf2PolyMod::new(modulus);
            let res1 = modulus.remainder(&elem);
            let res2 = elem % &modulus.modulus;
            prop_assert_poly_eq!(res1, res2);
        }
        #[test]
        fn modulus_divisor(modulus: Gf2Poly, elem: Gf2Poly) {
            prop_assume!(!modulus.is_zero());
            let modulus = Gf2PolyMod::new(modulus);
            let res1 = modulus.quotient(&elem);
            let res2 = elem / &modulus.modulus;
            prop_assert_poly_eq!(res1, res2);
        }
        #[test]
        fn modulus_divmod(modulo: Gf2Poly, elem: Gf2Poly) {
            prop_assume!(!modulo.is_zero());
            let modulus = Gf2PolyMod::new(&modulo);
            let (quotient, remainder) = modulus.divmod(&elem);
            prop_assert_poly_eq!(quotient, elem.clone() / &modulo);
            prop_assert_poly_eq!(remainder, elem % &modulo);
        }
        #[test]
        fn modular_div(modulo: Gf2Poly, lhs: Gf2Poly, rhs: Gf2Poly) {
            prop_assume!(!rhs.is_zero());
            let div = modulo.mod_div(&lhs, &rhs);
            if let Some(div) = div {
                prop_assert_poly_eq!(modulo.mod_mul(&div, &rhs), lhs % modulo);
            }
        }
        #[test]
        fn modulus_div(modulo: Gf2Poly, lhs: Gf2Poly, rhs: Gf2Poly) {
            prop_assume!(!modulo.is_zero());
            let modulus = Gf2PolyMod::new(&modulo);
            prop_assert_eq!(modulus.div(&lhs, &rhs), modulo.mod_div(&lhs, &rhs));
        }
        #[test]
        fn exponent_homo(modulo: Gf2Poly, a: Gf2Poly, n in 0..128u64, m in 0..128u64) {
            prop_assert_poly_eq!(
                modulo.mod_mul(&modulo.mod_power(&a, n), &modulo.mod_power(&a, m)),
                modulo.mod_power(&a, n + m));
        }
        #[test]
        fn power_homo(modulo: Gf2Poly, a: Gf2Poly, b: Gf2Poly, n in 0..128u64) {
            prop_assert_poly_eq!(
                modulo.mod_power(&modulo.mod_mul(&a, &b), n),
                modulo.mod_mul(&modulo.mod_power(&a, n), &modulo.mod_power(&b, n))
            );
        }
    }
}