Skip to main content

crypto_bigint/uint/
invert_mod.rs

1use super::Uint;
2use crate::{
3    Choice, CtOption, InvertMod, Limb, NonZero, Odd, U64, UintRef, modular::safegcd, mul::karatsuba,
4};
5
6/// Perform a modified recursive Hensel quadratic modular inversion to calculate
7/// `a^-1 mod w^p` given `a^-1 mod w^k` where `w` is the size of `Limb`.
8/// For reference see Algorithm 2: <https://arxiv.org/pdf/1209.6626>
9///
10/// `p` is determined by the length of the in-out buffer `buf`, which must be
11/// pre-populated with `a^-1 mod w^k` (constituting `k` limbs).
12///
13/// This method uses recursion, but the maximum depth is limited by
14/// the bit-width of the number of limbs being inverted (`p`).
15///
16/// This method is variable time in `k` and `p` only.
17///
18/// `scratch` must be a pair of mutable buffers, each with capacity at least `p`.
19#[inline]
20pub(crate) const fn expand_invert_mod2k(
21    a: &Odd<UintRef>,
22    buf: &mut UintRef,
23    mut k: usize,
24    scratch: (&mut UintRef, &mut UintRef),
25) {
26    assert!(k > 0);
27    let p = buf.nlimbs();
28    let zs = p.trailing_zeros();
29
30    // Calculate a target width at which we may need to trim the output of
31    // the doubling loop. We reduce the size of `p` by eliminating multiple factors
32    // of two or a single odd factor, recursing until the target width is small enough
33    // to calculate by doubling without significant overhead.
34    let mut target = if zs > 0 { p >> zs } else { p.div_ceil(2) };
35    if target > 8 {
36        expand_invert_mod2k(a, buf.leading_mut(target), k, (scratch.0, scratch.1));
37        k = target;
38        target = p;
39    } else if target <= k {
40        target = p;
41    }
42
43    // Perform the required number of doublings.
44    while k < p {
45        let mut k2 = k * 2;
46        // `target` represents the point at which we may need to trim the output before
47        // continuing with the doubling until we reach `p`.
48        if k2 >= target {
49            (k2, target) = (target, p);
50        }
51        expand_invert_mod2k_step(a, buf.leading_mut(k2), k, (scratch.0, scratch.1));
52        k = k2;
53    }
54}
55
56/// One step of the Hensel quadratic modular inverse calculation, doubling the width
57/// of the inverted output, and wrapping at capacity of `buf`.
58#[inline(always)]
59const fn expand_invert_mod2k_step(
60    a: &Odd<UintRef>,
61    buf: &mut UintRef,
62    buf_init_len: usize,
63    scratch: (&mut UintRef, &mut UintRef),
64) {
65    let new_len = buf.nlimbs();
66
67    assert!(
68        scratch.0.nlimbs() >= new_len
69            && scratch.1.nlimbs() >= new_len
70            && buf_init_len < new_len
71            && buf_init_len >= (new_len >> 1)
72    );
73
74    // Calculate u0^2, wrapping at `new_len` words
75    let u0_p2 = scratch.0.leading_mut(new_len);
76    u0_p2.fill(Limb::ZERO);
77    karatsuba::wrapping_square(buf.leading(buf_init_len), u0_p2);
78
79    // tmp = u0^2•a
80    let tmp = scratch.1.leading_mut(new_len);
81    tmp.fill(Limb::ZERO);
82    karatsuba::wrapping_mul(u0_p2, a.as_ref(), tmp, false);
83
84    // u1 = u0 << 1
85    buf.shl1_assign();
86    // u1 -= u0^2•a
87    buf.borrowing_sub_assign(tmp, Limb::ZERO);
88}
89
90impl<const LIMBS: usize> Uint<LIMBS> {
91    /// Computes 1/`self` mod `2^k`.
92    /// This method is constant-time w.r.t. `self` but not `k`.
93    ///
94    /// If the inverse does not exist (`k > 0` and `self` is even),
95    /// returns `Choice::FALSE` as the second element of the tuple,
96    /// otherwise returns `Choice::TRUE`.
97    #[deprecated(since = "0.7.0", note = "please use `invert_mod2k_vartime` instead")]
98    #[must_use]
99    pub const fn inv_mod2k_vartime(&self, k: u32) -> CtOption<Self> {
100        self.invert_mod2k_vartime(k)
101    }
102
103    /// Computes 1/`self` mod `2^k`.
104    /// This method is constant-time w.r.t. `self` but not `k`.
105    ///
106    /// If the inverse does not exist (`k > 0` and `self` is even, or `k > Self::BITS`),
107    /// returns `CtOption::none`, otherwise returns `CtOption::some`.
108    #[must_use]
109    pub const fn invert_mod2k_vartime(&self, k: u32) -> CtOption<Self> {
110        if k == 0 {
111            CtOption::some(Self::ZERO)
112        } else if k > Self::BITS {
113            CtOption::new(Self::ZERO, Choice::FALSE)
114        } else {
115            let is_some = self.is_odd();
116            let inv = Odd(Uint::select(&Uint::ONE, self, is_some)).invert_mod2k_vartime(k);
117            CtOption::new(inv, is_some)
118        }
119    }
120
121    /// Computes 1/`self` mod `2^k`.
122    ///
123    /// If the inverse does not exist (`k > 0` and `self` is even, `k > Self::BITS`),
124    /// returns `CtOption::none`, otherwise returns `CtOption::some`.
125    #[deprecated(since = "0.7.0", note = "please use `invert_mod2k` instead")]
126    #[must_use]
127    pub const fn inv_mod2k(&self, k: u32) -> CtOption<Self> {
128        self.invert_mod2k(k)
129    }
130
131    /// Computes 1/`self` mod `2^k`.
132    ///
133    /// If the inverse does not exist (`k > 0` and `self` is even, or `k > Self::BITS`),
134    /// returns `CtOption::none`, otherwise returns `CtOption::some`.
135    #[must_use]
136    pub const fn invert_mod2k(&self, k: u32) -> CtOption<Self> {
137        let is_some =
138            Choice::from_u32_le(k, Self::BITS).and(Choice::from_u32_nz(k).not().or(self.is_odd()));
139        let inv = Odd(Uint::select(&Uint::ONE, self, is_some)).invert_mod_precision();
140        CtOption::new(inv.restrict_bits(k), is_some)
141    }
142
143    /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
144    #[deprecated(since = "0.7.0", note = "please use `invert_odd_mod` instead")]
145    #[must_use]
146    pub const fn inv_odd_mod(&self, modulus: &Odd<Self>) -> CtOption<Self> {
147        self.invert_odd_mod(modulus)
148    }
149
150    /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
151    #[must_use]
152    pub const fn invert_odd_mod(&self, modulus: &Odd<Self>) -> CtOption<Self> {
153        safegcd::invert_odd_mod::<LIMBS, false>(self, modulus)
154    }
155
156    /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
157    ///
158    /// This method is variable-time with respect to `self`.
159    #[must_use]
160    pub const fn invert_odd_mod_vartime(&self, modulus: &Odd<Self>) -> CtOption<Self> {
161        safegcd::invert_odd_mod::<LIMBS, true>(self, modulus)
162    }
163
164    /// Computes the multiplicative inverse of `self` mod `modulus`.
165    ///
166    /// Returns some if an inverse exists, otherwise none.
167    #[deprecated(since = "0.7.0", note = "please use `invert_mod` instead")]
168    #[must_use]
169    pub const fn inv_mod(&self, modulus: &Self) -> CtOption<Self> {
170        let is_nz = modulus.is_nonzero();
171        let m = NonZero(Uint::select(&Uint::ONE, modulus, is_nz));
172        self.invert_mod(&m).filter_by(is_nz)
173    }
174
175    /// Computes the multiplicative inverse of `self` mod `modulus`.
176    ///
177    /// Returns some if an inverse exists, otherwise none.
178    #[must_use]
179    pub const fn invert_mod(&self, modulus: &NonZero<Self>) -> CtOption<Self> {
180        // Decompose `modulus = s * 2^k` where `s` is odd
181        let k = modulus.as_ref().trailing_zeros();
182        let s = Odd(modulus.as_ref().shr(k));
183
184        // Decompose `self` into RNS with moduli `2^k` and `s` and calculate the inverses.
185        // Using the fact that `(z^{-1} mod (m1 * m2)) mod m1 == z^{-1} mod m1`
186        let maybe_a = self.invert_odd_mod(&s);
187
188        let maybe_b = self.invert_mod2k(k);
189        let is_some = maybe_a.is_some().and(maybe_b.is_some());
190
191        // Extract inner values to avoid mapping through CtOptions.
192        // if `a` or `b` don't exist, the returned CtOption will be None anyway.
193        let a = maybe_a.to_inner_unchecked();
194        let b = maybe_b.to_inner_unchecked();
195
196        // Restore from RNS:
197        // self^{-1} = a mod s = b mod 2^k
198        // => self^{-1} = a + s * ((b - a) * s^(-1) mod 2^k)
199        // (essentially one step of the Garner's algorithm for recovery from RNS).
200
201        // `s` is odd, so this always exists
202        let m_odd_inv = s.invert_mod_precision();
203
204        // This part is mod 2^k
205        let t = b.wrapping_sub(&a).wrapping_mul(&m_odd_inv).restrict_bits(k);
206
207        // Will not overflow since `a <= s - 1`, `t <= 2^k - 1`,
208        // so `a + s * t <= s * 2^k - 1 == modulus - 1`.
209        let result = a.wrapping_add(&s.as_ref().wrapping_mul(&t));
210        CtOption::new(result, is_some)
211    }
212}
213
214impl<const LIMBS: usize> Odd<Uint<LIMBS>> {
215    /// Compute a full-width quadratic inversion, `self^-1 mod 2^Self::BITS`.
216    #[inline]
217    pub(crate) const fn invert_mod_precision(&self) -> Uint<LIMBS> {
218        self.invert_mod2k_vartime(Self::BITS)
219    }
220
221    /// Compute a quadratic inversion, `self^-1 mod 2^k` where `k <= Self::BITS`.
222    ///
223    /// This method is variable-time in `k` only.
224    #[allow(clippy::integer_division_remainder_used, reason = "vartime")]
225    pub(crate) const fn invert_mod2k_vartime(&self, k: u32) -> Uint<LIMBS> {
226        assert!(k <= Self::BITS);
227
228        let k_limbs = k.div_ceil(Limb::BITS) as usize;
229        let mut inv = U64::from_u64(self.as_uint_ref().invert_mod_u64()).resize::<LIMBS>();
230
231        if k_limbs <= U64::LIMBS {
232            // trim to k_limbs
233            inv.as_mut_uint_ref().trailing_mut(k_limbs).fill(Limb::ZERO);
234        } else {
235            // expand to k_limbs
236            let mut scratch = (Uint::<LIMBS>::ZERO, Uint::<LIMBS>::ZERO);
237            expand_invert_mod2k(
238                self.as_uint_ref(),
239                inv.as_mut_uint_ref().leading_mut(k_limbs),
240                U64::LIMBS,
241                (scratch.0.as_mut_uint_ref(), scratch.1.as_mut_uint_ref()),
242            );
243        }
244
245        // clear bits in the high limb if necessary
246        let k_bits = k % Limb::BITS;
247        if k_bits > 0 {
248            inv.limbs[k_limbs - 1] = inv.limbs[k_limbs - 1].restrict_bits(k_bits);
249        }
250
251        inv
252    }
253}
254
255impl<const LIMBS: usize> InvertMod for Uint<LIMBS> {
256    type Output = Self;
257
258    fn invert_mod(&self, modulus: &NonZero<Self>) -> CtOption<Self> {
259        self.invert_mod(modulus)
260    }
261}
262
#[cfg(test)]
mod tests {
    use crate::{Odd, U64, U256, U1024, Uint};

    #[test]
    fn invert_mod2k() {
        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f");
        let e =
            U256::from_be_hex("3642e6faeaac7c6663b93d3d6a0d489e434ddc0123db5fa627c7f6e22ddacacf");
        let a = v.invert_mod2k(256).unwrap();
        assert_eq!(e, a);

        let a = v.invert_mod2k_vartime(256).unwrap();
        assert_eq!(e, a);

        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141");
        let e =
            U256::from_be_hex("261776f29b6b106c7680cf3ed83054a1af5ae537cb4613dbb4f20099aa774ec1");
        let a = v.invert_mod2k(256).unwrap();
        assert_eq!(e, a);

        let a = v.invert_mod2k_vartime(256).unwrap();
        assert_eq!(e, a);

        // Check that even if the number is >= 2^k, the inverse is still correct.

        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141");
        let e =
            U256::from_be_hex("0000000000000000000000000000000000000000034613dbb4f20099aa774ec1");
        let a = v.invert_mod2k(90).unwrap();
        assert_eq!(e, a);

        let a = v.invert_mod2k_vartime(90).unwrap();
        assert_eq!(e, a);

        // An inverse of an even number does not exist.

        let a = U256::from(10u64).invert_mod2k(4);
        assert!(a.is_none().to_bool_vartime());

        let a = U256::from(10u64).invert_mod2k_vartime(4);
        assert!(a.is_none().to_bool_vartime());

        // A degenerate case. An inverse mod 2^0 == 1 always exists even for even numbers.

        let a = U256::from(10u64).invert_mod2k_vartime(0).unwrap();
        assert_eq!(a, U256::ZERO);

        let a = U256::from(10u64).invert_mod2k(0).unwrap();
        assert_eq!(a, U256::ZERO);

        // No inverse is reported for a modulus wider than the integer itself.

        let a = U256::from(3u64).invert_mod2k(U256::BITS + 1);
        assert!(a.is_none().to_bool_vartime());

        let a = U256::from(3u64).invert_mod2k_vartime(U256::BITS + 1);
        assert!(a.is_none().to_bool_vartime());
    }

    #[test]
    fn invert_mod2k_small() {
        // 3 * 3 == 9 == 1 mod 4
        let a = U64::from(3u64).invert_mod2k(2).unwrap();
        assert_eq!(a, U64::from(3u64));

        let a = U64::from(3u64).invert_mod2k_vartime(2).unwrap();
        assert_eq!(a, U64::from(3u64));

        // 3 * 0xAAAA_AAAA_AAAA_AAAB == 2^65 + 1 == 1 mod 2^64
        let a = U64::from(3u64).invert_mod2k(64).unwrap();
        assert_eq!(a, U64::from(0xAAAA_AAAA_AAAA_AAABu64));

        let a = U64::from(3u64).invert_mod2k_vartime(64).unwrap();
        assert_eq!(a, U64::from(0xAAAA_AAAA_AAAA_AAABu64));
    }

    #[test]
    fn test_invert_odd() {
        let a = U1024::from_be_hex(concat![
            "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
            "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
            "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
            "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
        ]);
        let m = U1024::from_be_hex(concat![
            "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
            "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
            "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
            "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"
        ])
        .to_odd()
        .unwrap();
        let expected = U1024::from_be_hex(concat![
            "B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55",
            "D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57",
            "88D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA",
            "3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336"
        ]);

        let res = a.invert_odd_mod(&m).unwrap();
        assert_eq!(res, expected);

        // Even though it is less efficient, it still works
        let res = a.invert_mod(m.as_nz_ref()).unwrap();
        assert_eq!(res, expected);
    }

    #[test]
    fn test_invert_odd_no_inverse() {
        // 2^128 - 159, a prime
        let p1 =
            U256::from_be_hex("00000000000000000000000000000000ffffffffffffffffffffffffffffff61");
        // 2^128 - 173, a prime
        let p2 =
            U256::from_be_hex("00000000000000000000000000000000ffffffffffffffffffffffffffffff53");

        let m = p1.wrapping_mul(&p2).to_odd().unwrap();

        // `m` is a multiple of `p1`, so no inverse exists
        let res = p1.invert_odd_mod(&m);
        assert!(res.is_none().to_bool_vartime());
    }

    #[test]
    fn test_invert_even() {
        let a = U1024::from_be_hex(concat![
            "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
            "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
            "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
            "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
        ]);
        let m = U1024::from_be_hex(concat![
            "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
            "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
            "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
            "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156000"
        ])
        .to_nz()
        .unwrap();
        let expected = U1024::from_be_hex(concat![
            "1EBF391306817E1BC610E213F4453AD70911CCBD59A901B2A468A4FC1D64F357",
            "DBFC6381EC5635CAA664DF280028AF4651482C77A143DF38D6BFD4D64B6C0225",
            "FC0E199B15A64966FB26D88A86AD144271F6BDCD3D63193AB2B3CC53B99F21A3",
            "5B9BFAE5D43C6BC6E7A9856C71C7318C76530E9E5AE35882D5ABB02F1696874D",
        ]);

        let res = a.invert_mod(&m).unwrap();
        assert_eq!(res, expected);
    }

    #[test]
    fn test_invert_small() {
        let a = U64::from(3u64);
        let m = U64::from(13u64).to_odd().unwrap();

        let res = a.invert_odd_mod(&m).unwrap();
        assert_eq!(U64::from(9u64), res);
    }

    #[test]
    fn test_no_inverse_small() {
        let a = U64::from(14u64);
        let m = U64::from(49u64).to_odd().unwrap();

        let res = a.invert_odd_mod(&m);
        assert!(res.is_none().to_bool_vartime());
    }

    #[test]
    fn test_invert_edge() {
        assert!(
            U256::ZERO
                .invert_odd_mod(&U256::ONE.to_odd().unwrap())
                .is_none()
                .to_bool_vartime()
        );
        assert_eq!(
            U256::ONE
                .invert_odd_mod(&U256::ONE.to_odd().unwrap())
                .unwrap(),
            U256::ZERO
        );
        assert_eq!(
            U256::ONE
                .invert_odd_mod(&U256::MAX.to_odd().unwrap())
                .unwrap(),
            U256::ONE
        );
        assert!(
            U256::MAX
                .invert_odd_mod(&U256::MAX.to_odd().unwrap())
                .is_none()
                .to_bool_vartime()
        );
        assert_eq!(
            U256::MAX
                .invert_odd_mod(&U256::ONE.to_odd().unwrap())
                .unwrap(),
            U256::ZERO
        );
    }

    #[test]
    fn invert_mod_precision() {
        const BIG: Odd<Uint<8>> = Odd(Uint::MAX);

        fn test_invert_size<const LIMBS: usize>() {
            let a = BIG.resize::<LIMBS>();
            let a_inv = a.invert_mod_precision();
            assert_eq!(a.as_ref().wrapping_mul(&a_inv), Uint::ONE);
        }

        test_invert_size::<1>();
        test_invert_size::<2>();
        test_invert_size::<3>();
        test_invert_size::<4>();
        test_invert_size::<5>();
        test_invert_size::<6>();
        test_invert_size::<7>();
        test_invert_size::<8>();
        test_invert_size::<9>();
        test_invert_size::<10>();
        // Limb counts whose reduced target width exceeds 8, exercising the recursive
        // branch of `expand_invert_mod2k`, plus a pure power-of-two width.
        test_invert_size::<17>();
        test_invert_size::<18>();
        test_invert_size::<32>();
    }
}