// modmath/inv.rs

1mod signed;
2use signed::Signed;
3
/*
Based on this pseudocode (extended Euclidean algorithm):

function inverse(a, n)
    t := 0;     newt := 1
    r := n;     newr := a

    while newr ≠ 0 do
        quotient := r div newr
        (t, newt) := (newt, t − quotient × newt)
        (r, newr) := (newr, r − quotient × newr)

    if r > 1 then
        return "a is not invertible"
    if t < 0 then
        t := t + n

    return t
*/

24/// # Modular Inverse (Strict)
25/// Most constrained version that works with references. Requires
26/// reference-based operations for division and subtraction.
27pub fn strict_mod_inv<T>(a: T, modulus: &T) -> Option<T>
28where
29    T: num_traits::Zero
30        + num_traits::One
31        + PartialEq
32        + core::ops::Sub<Output = T>
33        + core::cmp::PartialOrd,
34    for<'a> T: core::ops::Mul<&'a T, Output = T>
35        + core::ops::Div<&'a T, Output = T>
36        + core::ops::Sub<&'a T, Output = T>
37        + core::ops::Add<&'a T, Output = T>
38        + core::ops::AddAssign<&'a T>
39        + core::cmp::PartialOrd,
40    for<'a> &'a T: core::ops::Div<&'a T, Output = T> + core::ops::Sub<T, Output = T>,
41{
42    let mut t = Signed::new(T::zero(), false);
43    let mut new_t = Signed::new(T::one(), false);
44    // makes a clone of modulus
45    let mut r = T::zero() + modulus;
46    let mut new_r = a;
47
48    while new_r != T::zero() {
49        let quotient = &r / &new_r;
50
51        // clone
52        let tmp_t = Signed::new(T::zero(), false) + &new_t;
53        new_t = t - new_t * &quotient;
54        t = tmp_t;
55
56        // clone
57        let tmp_r = T::zero() + &new_r;
58        new_r = r - new_r * &quotient;
59        r = tmp_r;
60    }
61
62    if r > T::one() {
63        return None;
64    }
65
66    if t < T::zero().into() {
67        t = t + modulus;
68    }
69
70    Some(t.into_inner())
71}
72
// A variant with fewer constraints: it still takes the modulus by reference,
// but requires `T: Clone`.

76/// # Modular Inverse (Constrained)
77/// Version that works with references. Requires Clone and
78/// reference-based operations.
79pub fn constrained_mod_inv<T>(a: T, modulus: &T) -> Option<T>
80where
81    T: num_traits::Zero
82        + num_traits::One
83        + Clone
84        + PartialEq
85        + core::cmp::PartialOrd
86        + core::ops::Sub<Output = T>,
87    for<'a> T: core::ops::Add<&'a T, Output = T> + core::ops::Sub<&'a T, Output = T>,
88    for<'a> &'a T: core::ops::Sub<T, Output = T> + core::ops::Div<&'a T, Output = T>,
89{
90    let mut t = Signed::new(T::zero(), false);
91    let mut new_t = Signed::new(T::one(), false);
92    let mut r = modulus.clone();
93    let mut new_r = a;
94
95    while new_r != T::zero() {
96        let quotient = &r / &new_r;
97
98        let tmp_t = new_t.clone();
99        new_t = t - new_t * quotient.clone();
100        t = tmp_t;
101
102        let tmp_r = new_r.clone();
103        new_r = r - new_r * quotient;
104        r = tmp_r;
105    }
106
107    if r > T::one() {
108        return None;
109    }
110
111    if t < T::zero().into() {
112        t = t + modulus;
113    }
114
115    Some(t.into_inner())
116}
117
118/// # Modular Inverse (Basic)
119/// Simple version that operates on values and copies them.
120pub fn basic_mod_inv<T>(a: T, modulus: T) -> Option<T>
121where
122    T: num_traits::Zero
123        + num_traits::One
124        + Copy
125        + PartialEq
126        + core::ops::Div<Output = T>
127        + core::ops::Sub<Output = T>
128        + core::cmp::PartialOrd,
129{
130    let mut t = Signed::new(T::zero(), false);
131    let mut new_t = Signed::new(T::one(), false);
132    let mut r = Signed::new(modulus, false);
133    let mut new_r = Signed::new(a, false);
134
135    while new_r != Signed::new(T::zero(), false) {
136        let quotient = r / new_r;
137
138        let tmp_t = new_t;
139        new_t = t - new_t * quotient;
140        t = tmp_t;
141
142        let tmp_r = new_r;
143        new_r = r - new_r * quotient;
144        r = tmp_r;
145    }
146
147    if r > T::one().into() {
148        return None;
149    }
150
151    if t < T::zero().into() {
152        t = t + modulus.into();
153    }
154
155    Some(t.into_inner())
156}
157
/// Defines a local adapter `fn mod_inv(a: $ty, modulus: &$ty) -> Option<$ty>`
/// so the shared test bodies can call every implementation the same way.
///
/// * `by_val` — forwards `*modulus` to an implementation taking the modulus by
///   value (the dereference requires `$ty: Copy`).
/// * `by_ref` — forwards the reference unchanged to an implementation taking
///   `&$ty`.
#[cfg(test)]
macro_rules! select_mod_inv {
    ($target:path, $ty:ty, by_val) => {
        fn mod_inv(a: $ty, modulus: &$ty) -> Option<$ty> {
            let owned_modulus = *modulus;
            $target(a, owned_modulus)
        }
    };
    ($target:path, $ty:ty, by_ref) => {
        fn mod_inv(a: $ty, modulus: &$ty) -> Option<$ty> {
            let borrowed_modulus: &$ty = modulus;
            $target(a, borrowed_modulus)
        }
    };
}

/// Generates a `mod_inv` adapter (via `select_mod_inv!`) plus a `#[test]` that
/// checks a modular-inverse implementation on small `u32` inputs.
///
/// `$by_ref` is `by_ref` or `by_val` and selects how the modulus is forwarded.
///
/// NOTE: the assertions use `u32` literals, so `$t` must be `u32`.
#[cfg(test)]
macro_rules! generate_mod_inv_tests_block_1 {
    ($mod_inv:path, $t:ty , $by_ref:tt) => {
        select_mod_inv!($mod_inv, $t, $by_ref);

        #[test]
        fn test_mod_inv_small_u32() {
            // gcd(0, 7) = 7, so 0 has no inverse modulo 7.
            assert_eq!(mod_inv(0u32, &7u32), None);
            assert_eq!(mod_inv(1u32, &7u32).unwrap(), 1);
            // gcd(6, 8) = 2, so 6 has no inverse modulo 8.
            assert_eq!(mod_inv(6u32, &8u32), None);
            assert_eq!(mod_inv(1u32, &13u32).unwrap(), 1u32); // 1 * 1 = 1 ≡ 1 (mod 13)
            assert_eq!(mod_inv(8u32, &13u32).unwrap(), 5u32); // 8 * 5 = 40 ≡ 1 (mod 13)
            assert_eq!(mod_inv(12u32, &13u32).unwrap(), 12u32); // 12 * 12 = 144 ≡ 1 (mod 13)
            // Operands larger than the modulus are accepted without pre-reduction.
            assert_eq!(mod_inv(14u32, &13u32).unwrap(), 1); // 14 * 1 = 14 ≡ 1 (mod 13)
            assert_eq!(mod_inv(15u32, &13u32).unwrap(), 7u32); // 15 * 7 = 105 ≡ 1 (mod 13)
            assert_eq!(mod_inv(16u32, &13u32).unwrap(), 9u32); // 16 * 9 = 144 ≡ 1 (mod 13)
            assert_eq!(mod_inv(10u32, &17).unwrap(), 12); // 10 * 12 = 120 ≡ 1 (mod 17)
        }
    };
}

// Run the shared small-`u32` suite against each implementation. The strict
// and constrained variants borrow the modulus (`by_ref`); the basic variant
// takes it by value, so its adapter dereferences it (`by_val`).
#[cfg(test)]
mod strict_mod_inv_tests {
    generate_mod_inv_tests_block_1! { super::strict_mod_inv, u32, by_ref }
}

#[cfg(test)]
mod constrained_mod_inv_tests {
    generate_mod_inv_tests_block_1! { super::constrained_mod_inv, u32, by_ref }
}

#[cfg(test)]
mod basic_mod_inv_tests {
    generate_mod_inv_tests_block_1! { super::basic_mod_inv, u32, by_val }
}

/// Generates a `<stem>_tests` module containing one test that checks
/// `5 * 8 ≡ 1 (mod 13)` with each modular-inverse variant for a given
/// big-integer type.
///
/// * `$stem` — prefix of the generated module name (joined with `paste`).
/// * `$type_path` — imported into the generated module. The test body names
///   the type `U256`; when the imported name differs, also pass
///   `type U256 = <type>;` to alias it.
/// * `strict` / `constrained` / `basic` — flags (`on` / `off` in this file),
///   each handed to `crate::maybe_test!` together with the matching assertion.
///   Presumably `off` drops the assertion for types that lack a required trait
///   impl — confirm against `maybe_test!`.
#[cfg(test)]
macro_rules! inv_test_module {
    (
        $stem:ident,
        $type_path:path,
        $(type $type_def:ty = $type_expr:ty;)? // optional alias so the body can say `U256`
        strict: $strict:expr,
        constrained: $constrained:expr,
        basic: $basic:expr,
    ) => {
        paste::paste! {
            mod [<$stem _tests>] {
                #[allow(unused_imports)]
                use $type_path;
                $( type $type_def = $type_expr; )?

                #[test]
                // Disabled checks leave the helpers below unused.
                #[allow(unused_variables)]
                fn test_mod_inv_basic() {
                    // 5 * 8 = 40 = 3 * 13 + 1. Fresh operands are built per
                    // call because the inverse functions consume them.
                    let operand = || U256::from(5u8);
                    let expected = || U256::from(8u8);
                    let modulus = U256::from(13u8);

                    crate::maybe_test!(
                        $strict,
                        assert_eq!(super::strict_mod_inv(operand(), &modulus), Some(expected()))
                    );
                    crate::maybe_test!(
                        $constrained,
                        assert_eq!(super::constrained_mod_inv(operand(), &modulus), Some(expected()))
                    );
                    crate::maybe_test!(
                        $basic,
                        assert_eq!(super::basic_mod_inv(operand(), modulus), Some(expected()))
                    );
                }
            }
        }
    };
}

// Instantiates `inv_test_module!` for each big-integer crate and its
// `_patched` counterpart. A variant is switched `off` where the type cannot
// satisfy that variant's trait bounds.
#[cfg(test)]
mod bnum_inv_tests {
    use super::{basic_mod_inv, constrained_mod_inv, strict_mod_inv};

    inv_test_module!(
        bnum,
        bnum::types::U256,
        strict: on,
        constrained: on,
        basic: on,
    );

    inv_test_module!(
        bnum_patched,
        bnum_patched::types::U256,
        strict: on,
        constrained: on,
        basic: on,
    );

    inv_test_module!(
        crypto_bigint,
        crypto_bigint::U256,
        strict: off, // `&'a T: Div<&'a T>` (division on references) is missing
        constrained: off, // `&'a T: Div<&'a T>` (division on references) is missing
        basic: on,
    );

    inv_test_module!(
        crypto_bigint_patched,
        crypto_bigint_patched::U256,
        strict: on,
        constrained: on,
        basic: on,
    );

    inv_test_module!(
        num_bigint,
        num_bigint::BigUint,
        type U256 = num_bigint::BigUint;
        strict: on,
        constrained: on,
        basic: off, // heap-allocated, so not `Copy`
    );

    inv_test_module!(
        num_bigint_patched,
        num_bigint_patched::BigUint,
        type U256 = num_bigint_patched::BigUint;
        strict: on,
        constrained: on,
        basic: off, // heap-allocated, so not `Copy`
    );

    inv_test_module!(
        ibig,
        ibig::UBig,
        type U256 = ibig::UBig;
        strict: on,
        constrained: on,
        basic: off, // heap-allocated, so not `Copy`
    );

    inv_test_module!(
        ibig_patched,
        ibig_patched::UBig,
        type U256 = ibig_patched::UBig;
        strict: on,
        constrained: on,
        basic: off, // heap-allocated, so not `Copy`
    );

    inv_test_module!(
        fixed_bigint,
        fixed_bigint::FixedUInt,
        type U256 = fixed_bigint::FixedUInt<u8, 4>;
        strict: on,
        constrained: on,
        basic: on,
    );
}