1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
use dashu_base::ExtendedGcd;
use crate::{
buffer::Buffer,
gcd,
helper_macros::debug_assert_zero,
memory::MemoryAllocation,
primitive::{locate_top_word_plus_one, lowest_dword, PrimitiveSigned},
shift::{shl_in_place, shr_in_place},
Sign,
};
use super::{
modulo::{Modulo, ModuloDoubleRaw, ModuloLargeRaw, ModuloRepr, ModuloSingleRaw},
modulo_ring::{ModuloRingDouble, ModuloRingLarge, ModuloRingSingle},
};
impl<'a> Modulo<'a> {
    /// Returns the multiplicative inverse of this value in its ring.
    ///
    /// The result is `None` when no inverse exists. This happens when the residue is zero
    /// or is not coprime to the modulus.
    #[inline]
    pub fn inv(self) -> Option<Modulo<'a>> {
        match self.into_repr() {
            ModuloRepr::Single(value, r) => {
                let inverse = r.inv(value)?;
                Some(Modulo::from_single(inverse, r))
            }
            ModuloRepr::Double(value, r) => {
                let inverse = r.inv(value)?;
                Some(Modulo::from_double(inverse, r))
            }
            ModuloRepr::Large(value, r) => {
                let inverse = r.inv(value)?;
                Some(Modulo::from_large(inverse, r))
            }
        }
    }
}
// Generates `inv` for the rings whose modulus fits in a primitive integer
// (the single-word and double-word rings instantiated below).
macro_rules! impl_mod_inv_for_primitive {
    ($ring:ty, $raw:ident) => {
        impl $ring {
            /// Inverts the normalized residue `raw`.
            ///
            /// Returns `None` if `raw` is zero or not coprime to the modulus.
            #[inline]
            fn inv(&self, raw: $raw) -> Option<$raw> {
                // Zero is never invertible.
                if raw.0 == 0 {
                    return None;
                }

                // The residue is stored shifted left by `shift`; undo that so the
                // extended gcd runs against the actual divisor.
                let shift = self.shift();
                let (common, _, bezout) = self.0.divisor().gcd_ext(raw.0 >> shift);
                if common != 1 {
                    return None;
                }

                // The third gcd_ext result is the coefficient paired with the residue.
                // Re-normalize its magnitude. If the coefficient was negative, negate
                // it inside the ring.
                let (sign, magnitude) = bezout.to_sign_magnitude();
                let normalized = $raw(magnitude << shift);
                Some(match sign {
                    Sign::Positive => normalized,
                    Sign::Negative => self.negate(normalized),
                })
            }
        }
    };
}
impl_mod_inv_for_primitive!(ModuloRingSingle, ModuloSingleRaw);
impl_mod_inv_for_primitive!(ModuloRingDouble, ModuloDoubleRaw);
impl ModuloRingLarge {
    /// Inverts the normalized multi-word residue `raw`.
    ///
    /// Returns `None` when `raw` is zero or the extended gcd with the modulus is not one.
    /// Otherwise the inverse is returned in the same normalized (left-shifted) form.
    #[inline]
    fn inv(&self, mut raw: ModuloLargeRaw) -> Option<ModuloLargeRaw> {
        let shift = self.shift();
        let normalized = self.normalized_modulus();

        // Working copy of the modulus, shifted back to its un-normalized value.
        // The extended gcd routines below take this buffer mutably, and the
        // result written there is used as the inverse magnitude.
        let mut work = Buffer::allocate_exact(normalized.len());
        work.push_slice(normalized);
        debug_assert_zero!(shr_in_place(&mut work, shift));

        // Bring the residue back to its un-normalized value as well.
        debug_assert_zero!(shr_in_place(&mut raw.0, shift));
        let len = locate_top_word_plus_one(&raw.0);

        // Dispatch on the residue's significant length. `None` means the gcd is not one.
        let coeff_sign = match len {
            0 => return None,
            1 => {
                let (g, _, sign) = gcd::gcd_ext_word(&mut work, raw.0[0]);
                if g == 1 { Some(sign) } else { None }
            }
            2 => {
                let (g, _, sign) = gcd::gcd_ext_dword(&mut work, lowest_dword(&raw.0));
                if g == 1 { Some(sign) } else { None }
            }
            _ => {
                let mut allocation =
                    MemoryAllocation::new(gcd::memory_requirement_ext_exact(work.len(), len));
                let (g_len, b_len, sign) = gcd::gcd_ext_in_place(
                    &mut work,
                    &mut raw.0[..len],
                    &mut allocation.memory(),
                );
                // Clear the words above the coefficient's length.
                work[b_len..].fill(0);
                // The gcd is tested through `raw`: it must be a single word equal to one.
                if g_len == 1 && raw.0[0] == 1 {
                    Some(sign)
                } else {
                    None
                }
            }
        }?;

        // Re-normalize the coefficient and apply its sign within the ring.
        shl_in_place(&mut work, shift);
        let mut inverse = ModuloLargeRaw(work.into_boxed_slice());
        debug_assert!(self.is_valid(&inverse));
        if coeff_sign == Sign::Negative {
            self.negate_in_place(&mut inverse);
        }
        Some(inverse)
    }
}