Skip to main content

crypto_bigint/uint/boxed/
invert_mod.rs

1//! [`BoxedUint`] modular inverse (i.e. reciprocal) operations.
2
3use crate::{
4    BoxedUint, Choice, CtEq, CtLt, CtOption, CtSelect, Integer, InvertMod, Limb, NonZero, Odd, U64,
5    modular::safegcd, uint::invert_mod::expand_invert_mod2k,
6};
7
8impl BoxedUint {
9    /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
10    #[deprecated(since = "0.7.0", note = "please use `invert_odd_mod` instead")]
11    #[must_use]
12    pub fn inv_odd_mod(&self, modulus: &Odd<Self>) -> CtOption<Self> {
13        self.invert_odd_mod(modulus)
14    }
15
16    /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
17    #[must_use]
18    pub fn invert_odd_mod(&self, modulus: &Odd<Self>) -> CtOption<Self> {
19        safegcd::boxed::invert_odd_mod::<false>(self, modulus)
20    }
21
22    /// Computes the multiplicative inverse of `self` mod `modulus`, where `modulus` is odd.
23    #[must_use]
24    pub fn invert_odd_mod_vartime(&self, modulus: &Odd<Self>) -> CtOption<Self> {
25        safegcd::boxed::invert_odd_mod::<true>(self, modulus)
26    }
27
28    /// Computes 1/`self` mod `2^k`.
29    /// This method is constant-time w.r.t. `self` but not `k`.
30    ///
31    /// If the inverse does not exist (`k > 0` and `self` is even, or `k > bits_precision()`),
32    /// returns `Choice::FALSE` as the second element of the tuple, otherwise returns `Choice::TRUE`.
33    #[deprecated(since = "0.7.0", note = "please use `invert_mod2k_vartime` instead")]
34    #[must_use]
35    pub fn inv_mod2k_vartime(&self, k: u32) -> (Self, Choice) {
36        self.invert_mod2k_vartime(k)
37    }
38
39    /// Computes 1/`self` mod `2^k`.
40    /// This method is constant-time w.r.t. `self` but not `k`.
41    ///
42    /// If the inverse does not exist (`k > 0` and `self` is even, or `k > bits_precision()`),
43    /// returns `Choice::FALSE` as the second element of the tuple, otherwise returns `Choice::TRUE`.
44    #[must_use]
45    pub fn invert_mod2k_vartime(&self, k: u32) -> (Self, Choice) {
46        let bits = self.bits_precision();
47
48        if k == 0 {
49            (Self::zero_with_precision(bits), Choice::TRUE)
50        } else if k > bits {
51            (Self::zero_with_precision(bits), Choice::FALSE)
52        } else {
53            let is_some = self.is_odd();
54            let inv = Odd(Self::ct_select(
55                &Self::one_with_precision(bits),
56                self,
57                is_some,
58            ))
59            .invert_mod2k_vartime(k);
60            (inv, is_some)
61        }
62    }
63
64    /// Computes 1/`self` mod `2^k`.
65    ///
66    /// If the inverse does not exist (`k > 0` and `self` is even, or `k > bits_precision()`),
67    /// returns `Choice::FALSE` as the second element of the tuple, otherwise returns `Choice::TRUE`.
68    #[deprecated(since = "0.7.0", note = "please use `invert_mod2k` instead")]
69    #[must_use]
70    pub fn inv_mod2k(&self, k: u32) -> (Self, Choice) {
71        self.invert_mod2k(k)
72    }
73
74    /// Computes 1/`self` mod `2^k`.
75    ///
76    /// If the inverse does not exist (`k > 0` and `self` is even, or `k > bits_precision()`),
77    /// returns `Choice::FALSE` as the second element of the tuple, otherwise returns `Choice::TRUE`.
78    #[must_use]
79    pub fn invert_mod2k(&self, k: u32) -> (Self, Choice) {
80        let bits = self.bits_precision();
81        let is_some = k.ct_lt(&(bits + 1)) & (k.ct_eq(&0) | self.is_odd());
82        let mut inv = Odd(Self::ct_select(
83            &Self::one_with_precision(bits),
84            self,
85            is_some,
86        ))
87        .invert_mod_precision();
88        inv.restrict_bits(k);
89        (inv, is_some)
90    }
91
92    /// Computes the multiplicative inverse of `self` mod `modulus`
93    ///
94    /// `self` and `modulus` must have the same number of limbs, or the function will panic
95    ///
96    /// TODO: maybe some better documentation is needed
97    #[deprecated(since = "0.7.0", note = "please use `invert_mod` instead")]
98    #[must_use]
99    pub fn inv_mod(&self, modulus: &Self) -> CtOption<Self> {
100        let is_nz = modulus.is_nonzero();
101        let m = NonZero(Self::ct_select(
102            &Self::one_with_precision(self.bits_precision()),
103            modulus,
104            is_nz,
105        ));
106        let inv_mod_s = self.invert_mod(&m);
107        let is_some = inv_mod_s.is_some();
108        let result =
109            Option::from(inv_mod_s).unwrap_or(Self::zero_with_precision(self.bits_precision()));
110        CtOption::new(result, is_some & is_nz)
111    }
112
113    /// Computes the multiplicative inverse of `self` mod `modulus`
114    ///
115    /// `self` and `modulus` must have the same number of limbs, or the function will panic
116    ///
117    /// TODO: maybe some better documentation is needed
118    #[must_use]
119    pub fn invert_mod(&self, modulus: &NonZero<Self>) -> CtOption<Self> {
120        debug_assert_eq!(self.bits_precision(), modulus.bits_precision());
121        let k = modulus.trailing_zeros();
122        let s = Odd(modulus.shr(k));
123
124        let inv_mod_s = self.invert_odd_mod(&s);
125        let invertible_mod_s = inv_mod_s.is_some();
126        let inv_mod_s = inv_mod_s.unwrap_or(Self::zero_with_precision(self.bits_precision()));
127
128        let (inverse_mod2k, invertible_mod_2k) = self.invert_mod2k(k);
129        let is_some = invertible_mod_s & invertible_mod_2k;
130
131        let s_inverse_mod2k = s.invert_mod_precision();
132        let mut t = inverse_mod2k
133            .wrapping_sub(&inv_mod_s)
134            .wrapping_mul(&s_inverse_mod2k);
135        t.restrict_bits(k);
136        let result = inv_mod_s.wrapping_add(s.wrapping_mul(&t));
137
138        CtOption::new(result, is_some)
139    }
140}
141
142impl Odd<BoxedUint> {
143    /// Compute a full-width quadratic inversion, `self^-1 mod 2^bits_precision()`.
144    #[inline]
145    pub(crate) fn invert_mod_precision(&self) -> BoxedUint {
146        self.invert_mod2k_vartime(self.bits_precision())
147    }
148
149    /// Compute a quadratic inversion, `self^-1 mod 2^k` where `k <= bits_precision()`.
150    ///
151    /// This method is variable-time in `k` only.
152    #[allow(clippy::integer_division_remainder_used, reason = "vartime")]
153    pub(crate) fn invert_mod2k_vartime(&self, k: u32) -> BoxedUint {
154        let bits = self.bits_precision();
155        assert!(k <= bits);
156
157        let k_limbs = k.div_ceil(Limb::BITS) as usize;
158        let inv_64 = U64::from_u64(self.as_uint_ref().invert_mod_u64());
159        let mut inv = BoxedUint::from_words_with_precision(*inv_64.as_words(), bits);
160
161        if k_limbs <= U64::LIMBS {
162            // trim to k_limbs
163            inv.as_mut_uint_ref().trailing_mut(k_limbs).fill(Limb::ZERO);
164        } else {
165            // expand to k_limbs
166            #[allow(clippy::cast_possible_truncation)]
167            let mut scratch = BoxedUint::zero_with_precision(k_limbs as u32 * 2 * Limb::BITS);
168
169            expand_invert_mod2k(
170                self.as_uint_ref(),
171                inv.as_mut_uint_ref().leading_mut(k_limbs),
172                U64::LIMBS,
173                scratch.as_mut_uint_ref().split_at_mut(k_limbs),
174            );
175        }
176
177        // clear bits in the high limb if necessary
178        let k_bits = k % Limb::BITS;
179        if k_bits > 0 {
180            inv.limbs[k_limbs - 1] = inv.limbs[k_limbs - 1].restrict_bits(k_bits);
181        }
182        inv
183    }
184}
185
186impl InvertMod for BoxedUint {
187    type Output = Self;
188
189    fn invert_mod(&self, modulus: &NonZero<Self>) -> CtOption<Self> {
190        self.invert_mod(modulus)
191    }
192}
193
194#[cfg(test)]
195mod tests {
196    use crate::{Limb, Odd, Resize, U256};
197
198    use super::BoxedUint;
199    use hex_literal::hex;
200
201    #[test]
202    fn invert_mod2k() {
203        let v = BoxedUint::from_be_slice(
204            &hex!("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f"),
205            256,
206        )
207        .unwrap();
208        let e = BoxedUint::from_be_slice(
209            &hex!("3642e6faeaac7c6663b93d3d6a0d489e434ddc0123db5fa627c7f6e22ddacacf"),
210            256,
211        )
212        .unwrap();
213        let (a, is_some) = v.invert_mod2k(256);
214        assert_eq!(e, a);
215        assert!(bool::from(is_some));
216
217        let v = BoxedUint::from_be_slice(
218            &hex!("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141"),
219            256,
220        )
221        .unwrap();
222        let e = BoxedUint::from_be_slice(
223            &hex!("261776f29b6b106c7680cf3ed83054a1af5ae537cb4613dbb4f20099aa774ec1"),
224            256,
225        )
226        .unwrap();
227        let (a, is_some) = v.invert_mod2k(256);
228        assert_eq!(e, a);
229        assert!(bool::from(is_some));
230    }
231
232    #[test]
233    fn inv_odd() {
234        let a = BoxedUint::from_be_hex(
235            concat![
236                "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
237                "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
238                "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
239                "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
240            ],
241            1024,
242        )
243        .unwrap();
244        let m = BoxedUint::from_be_hex(
245            concat![
246                "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
247                "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
248                "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
249                "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"
250            ],
251            1024,
252        )
253        .unwrap()
254        .to_odd()
255        .unwrap();
256        let expected = BoxedUint::from_be_hex(
257            concat![
258                "B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55",
259                "D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57",
260                "88D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA",
261                "3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336"
262            ],
263            1024,
264        )
265        .unwrap();
266        assert_eq!(a.invert_odd_mod(&m).unwrap(), expected);
267
268        assert_eq!(a.invert_mod(m.as_nz_ref()).unwrap(), expected);
269    }
270
271    #[test]
272    fn test_invert_odd_no_inverse() {
273        // 2^128 - 159, a prime
274        let p1 = BoxedUint::from_be_hex(
275            "00000000000000000000000000000000ffffffffffffffffffffffffffffff61",
276            256,
277        )
278        .unwrap();
279        // 2^128 - 173, a prime
280        let p2 = BoxedUint::from_be_hex(
281            "00000000000000000000000000000000ffffffffffffffffffffffffffffff53",
282            256,
283        )
284        .unwrap();
285
286        let m = p1.wrapping_mul(&p2).to_odd().unwrap();
287
288        // `m` is a multiple of `p1`, so no inverse exists
289        let res = p1.invert_odd_mod(&m);
290        let is_none: bool = res.is_none().into();
291        assert!(is_none);
292    }
293
294    #[test]
295    fn test_invert_even() {
296        let a = BoxedUint::from_be_hex(
297            concat![
298                "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
299                "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
300                "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
301                "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
302            ],
303            1024,
304        )
305        .unwrap();
306        let m = BoxedUint::from_be_hex(
307            concat![
308                "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
309                "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
310                "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
311                "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156000"
312            ],
313            1024,
314        )
315        .unwrap()
316        .to_nz()
317        .unwrap();
318        let expected = BoxedUint::from_be_hex(
319            concat![
320                "1EBF391306817E1BC610E213F4453AD70911CCBD59A901B2A468A4FC1D64F357",
321                "DBFC6381EC5635CAA664DF280028AF4651482C77A143DF38D6BFD4D64B6C0225",
322                "FC0E199B15A64966FB26D88A86AD144271F6BDCD3D63193AB2B3CC53B99F21A3",
323                "5B9BFAE5D43C6BC6E7A9856C71C7318C76530E9E5AE35882D5ABB02F1696874D",
324            ],
325            1024,
326        )
327        .unwrap();
328
329        let res = a.invert_mod(&m).unwrap();
330        assert_eq!(res, expected);
331    }
332
333    #[test]
334    fn test_invert_small() {
335        let a = BoxedUint::from(3u64);
336        let m = BoxedUint::from(13u64).to_odd().unwrap();
337
338        let res = a.invert_odd_mod(&m).unwrap();
339        assert_eq!(BoxedUint::from(9u64), res);
340    }
341
342    #[test]
343    fn test_no_inverse_small() {
344        let a = BoxedUint::from(14u64);
345        let m = BoxedUint::from(49u64).to_odd().unwrap();
346
347        let res = a.invert_odd_mod(&m);
348        let is_none: bool = res.is_none().into();
349        assert!(is_none);
350    }
351
352    #[test]
353    fn test_invert_edge() {
354        assert!(bool::from(
355            BoxedUint::zero()
356                .invert_odd_mod(&BoxedUint::one().to_odd().unwrap())
357                .is_none()
358        ));
359        assert_eq!(
360            BoxedUint::one()
361                .invert_odd_mod(&BoxedUint::one().to_odd().unwrap())
362                .unwrap(),
363            BoxedUint::zero()
364        );
365        assert_eq!(
366            BoxedUint::one()
367                .invert_odd_mod(&BoxedUint::from(U256::MAX).to_odd().unwrap())
368                .unwrap(),
369            BoxedUint::one()
370        );
371        assert!(bool::from(
372            BoxedUint::from(U256::MAX)
373                .invert_odd_mod(&BoxedUint::from(U256::MAX).to_odd().unwrap())
374                .is_none()
375        ));
376    }
377
378    #[test]
379    fn invert_mod_precision() {
380        const PRECISION: u32 = 8 * Limb::BITS;
381
382        for limbs in 1..10 {
383            let a = Odd(BoxedUint::max(PRECISION).resize_unchecked(limbs));
384            let a_inv = a.invert_mod_precision();
385            assert_eq!(a.as_ref().wrapping_mul(&a_inv), BoxedUint::one());
386        }
387    }
388}