use dashu_base::ExtendedGcd;
use crate::{
buffer::Buffer,
error::panic_divide_by_invalid_modulo,
gcd,
helper_macros::debug_assert_zero,
memory::MemoryAllocation,
primitive::{locate_top_word_plus_one, lowest_dword, PrimitiveSigned},
shift::{shl_in_place, shr_in_place},
Sign,
};
use core::ops::{Div, DivAssign};
use super::{
modulo::{Modulo, ModuloDoubleRaw, ModuloLargeRaw, ModuloRepr, ModuloSingleRaw},
modulo_ring::{ModuloRingDouble, ModuloRingLarge, ModuloRingSingle},
};
impl<'a> Modulo<'a> {
    /// Multiplicative inverse of this value within its ring.
    ///
    /// Returns `None` when no inverse exists, i.e. when the value is not
    /// coprime to the ring's modulus.
    #[inline]
    pub fn inv(&self) -> Option<Modulo<'a>> {
        match self.repr() {
            ModuloRepr::Single(raw, ring) => {
                let inverse = ring.inv(raw)?;
                Some(Modulo::from_single(inverse, ring))
            }
            ModuloRepr::Double(raw, ring) => {
                let inverse = ring.inv(raw)?;
                Some(Modulo::from_double(inverse, ring))
            }
            ModuloRepr::Large(raw, ring) => {
                // The large-ring routine consumes its input as scratch space,
                // so hand it a private copy.
                let inverse = ring.inv(raw.clone())?;
                Some(Modulo::from_large(inverse, ring))
            }
        }
    }
}
/// Generates `inv` for a modulo ring whose values fit in a primitive integer.
///
/// `$ring` is the ring type; `$raw` is the newtype wrapping that ring's raw value.
macro_rules! impl_mod_inv_for_primitive {
    ($ring:ty, $raw:ident) => {
        impl $ring {
            /// Computes the multiplicative inverse of `raw`, or `None` when the
            /// extended gcd with the ring's divisor is not 1.
            #[inline]
            fn inv(&self, raw: &$raw) -> Option<$raw> {
                let shift = self.shift();
                // Raw values carry the ring's normalization shift; remove it
                // before running the extended gcd against the divisor.
                let (gcd, _, bezout) = self.0.divisor().gcd_ext(raw.0 >> shift);
                if gcd != 1 {
                    return None;
                }
                // `bezout` is the coefficient of the un-shifted value. Re-apply
                // the shift and fold a negative coefficient back into the ring.
                let (sign, magnitude) = bezout.to_sign_magnitude();
                let candidate = $raw(magnitude << shift);
                let inverse = match sign {
                    Sign::Negative => self.negate(candidate),
                    Sign::Positive => candidate,
                };
                Some(inverse)
            }
        }
    };
}

impl_mod_inv_for_primitive!(ModuloRingSingle, ModuloSingleRaw);
impl_mod_inv_for_primitive!(ModuloRingDouble, ModuloDoubleRaw);
impl ModuloRingLarge {
/// Computes the multiplicative inverse of `raw` modulo this ring's modulus.
///
/// `raw` is taken by value because its buffer is shifted in place and handed
/// to the extended gcd as working storage. Returns `None` if `raw` has no
/// nonzero words, or if the gcd of `raw` and the modulus is not 1.
#[inline]
fn inv(&self, mut raw: ModuloLargeRaw) -> Option<ModuloLargeRaw> {
// Copy the normalized modulus into a fresh buffer, then undo the normalization
// shift on both operands. The debug assertions check that no set bits are
// shifted out.
let mut modulus = Buffer::allocate_exact(self.normalized_modulus().len());
modulus.push_slice(self.normalized_modulus());
debug_assert_zero!(shr_in_place(&mut modulus, self.shift()));
debug_assert_zero!(shr_in_place(&mut raw.0, self.shift()));
// Significant length of `raw`; 0 is treated as "no inverse" below.
let raw_len = locate_top_word_plus_one(&raw.0);
// Pick an extended-gcd routine by operand size. The code after this match
// relies on each routine leaving the magnitude of the Bezout coefficient of
// `raw` in `modulus`, with its sign returned as `b_sign`.
let (is_g_one, b_sign) = match raw_len {
0 => return None,
1 => {
let (g, _, b_sign) = gcd::gcd_ext_word(&mut modulus, *raw.0.first().unwrap());
(g == 1, b_sign)
}
2 => {
let (g, _, b_sign) = gcd::gcd_ext_dword(&mut modulus, lowest_dword(&raw.0));
(g == 1, b_sign)
}
_ => {
// General multi-word case: scratch memory is sized for these operand lengths.
let mut allocation = MemoryAllocation::new(gcd::memory_requirement_ext_exact(
modulus.len(),
raw_len,
));
let (g_len, b_len, b_sign) = gcd::gcd_ext_in_place(
&mut modulus,
&mut raw.0[..raw_len],
&mut allocation.memory(),
);
// Only the low `b_len` words of the coefficient are meaningful; clear the rest.
modulus[b_len..].fill(0);
// Here the gcd is read back from `raw.0` (length `g_len`): it is 1 exactly
// when it is a single word equal to 1.
(g_len == 1 && *raw.0.first().unwrap() == 1, b_sign)
}
};
if !is_g_one {
return None;
}
// Re-apply the normalization shift so the coefficient is back in the ring's
// internal representation, then reuse the buffer as the result.
shl_in_place(&mut modulus, self.shift());
let mut inv = ModuloLargeRaw(modulus.into_boxed_slice());
debug_assert!(self.is_valid(&inv));
// A negative Bezout coefficient is mapped into the ring by negation.
if b_sign == Sign::Negative {
self.negate_in_place(&mut inv);
}
Some(inv)
}
}
impl<'a> Div<Modulo<'a>> for Modulo<'a> {
    type Output = Modulo<'a>;

    /// Owned / owned division; delegates to the borrowed / borrowed implementation.
    #[inline]
    fn div(self, rhs: Modulo<'a>) -> Modulo<'a> {
        &self / &rhs
    }
}
impl<'a> Div<&Modulo<'a>> for Modulo<'a> {
    type Output = Modulo<'a>;

    /// Owned / borrowed division; delegates to the borrowed / borrowed implementation.
    #[inline]
    fn div(self, rhs: &Modulo<'a>) -> Modulo<'a> {
        &self / rhs
    }
}
impl<'a> Div<Modulo<'a>> for &Modulo<'a> {
    type Output = Modulo<'a>;

    /// Borrowed / owned division; delegates to the borrowed / borrowed implementation.
    #[inline]
    fn div(self, rhs: Modulo<'a>) -> Modulo<'a> {
        self / &rhs
    }
}
impl<'a> Div<&Modulo<'a>> for &Modulo<'a> {
    type Output = Modulo<'a>;

    /// Divides `self` by `rhs` by multiplying with the modular inverse of `rhs`.
    ///
    /// # Panics
    ///
    /// Panics if `rhs` has no inverse in its ring.
    #[inline]
    // Division deliberately goes through multiplication, which clippy flags.
    #[allow(clippy::suspicious_arithmetic_impl)]
    fn div(self, rhs: &Modulo<'a>) -> Modulo<'a> {
        let inv_rhs = match rhs.inv() {
            Some(inverse) => inverse,
            None => panic_divide_by_invalid_modulo(),
        };
        self * inv_rhs
    }
}
impl<'a> DivAssign<Modulo<'a>> for Modulo<'a> {
    /// In-place division by an owned value; delegates to the borrowed variant.
    #[inline]
    fn div_assign(&mut self, rhs: Modulo<'a>) {
        *self /= &rhs;
    }
}
impl<'a> DivAssign<&Modulo<'a>> for Modulo<'a> {
    /// Replaces `self` with `self / rhs`.
    ///
    /// # Panics
    ///
    /// Panics if `rhs` has no inverse in its ring.
    #[inline]
    fn div_assign(&mut self, rhs: &Modulo<'a>) {
        let quotient = &*self / rhs;
        *self = quotient;
    }
}