1use super::Uint;
2use crate::{
3 Choice, CtOption, InvertMod, Limb, NonZero, Odd, U64, UintRef, modular::safegcd, mul::karatsuba,
4};
5
6#[inline]
20pub(crate) const fn expand_invert_mod2k(
21 a: &Odd<UintRef>,
22 buf: &mut UintRef,
23 mut k: usize,
24 scratch: (&mut UintRef, &mut UintRef),
25) {
26 assert!(k > 0);
27 let p = buf.nlimbs();
28 let zs = p.trailing_zeros();
29
30 let mut target = if zs > 0 { p >> zs } else { p.div_ceil(2) };
35 if target > 8 {
36 expand_invert_mod2k(a, buf.leading_mut(target), k, (scratch.0, scratch.1));
37 k = target;
38 target = p;
39 } else if target <= k {
40 target = p;
41 }
42
43 while k < p {
45 let mut k2 = k * 2;
46 if k2 >= target {
49 (k2, target) = (target, p);
50 }
51 expand_invert_mod2k_step(a, buf.leading_mut(k2), k, (scratch.0, scratch.1));
52 k = k2;
53 }
54}
55
56#[inline(always)]
59const fn expand_invert_mod2k_step(
60 a: &Odd<UintRef>,
61 buf: &mut UintRef,
62 buf_init_len: usize,
63 scratch: (&mut UintRef, &mut UintRef),
64) {
65 let new_len = buf.nlimbs();
66
67 assert!(
68 scratch.0.nlimbs() >= new_len
69 && scratch.1.nlimbs() >= new_len
70 && buf_init_len < new_len
71 && buf_init_len >= (new_len >> 1)
72 );
73
74 let u0_p2 = scratch.0.leading_mut(new_len);
76 u0_p2.fill(Limb::ZERO);
77 karatsuba::wrapping_square(buf.leading(buf_init_len), u0_p2);
78
79 let tmp = scratch.1.leading_mut(new_len);
81 tmp.fill(Limb::ZERO);
82 karatsuba::wrapping_mul(u0_p2, a.as_ref(), tmp, false);
83
84 buf.shl1_assign();
86 buf.borrowing_sub_assign(tmp, Limb::ZERO);
88}
89
90impl<const LIMBS: usize> Uint<LIMBS> {
91 #[deprecated(since = "0.7.0", note = "please use `invert_mod2k_vartime` instead")]
98 #[must_use]
99 pub const fn inv_mod2k_vartime(&self, k: u32) -> CtOption<Self> {
100 self.invert_mod2k_vartime(k)
101 }
102
103 #[must_use]
109 pub const fn invert_mod2k_vartime(&self, k: u32) -> CtOption<Self> {
110 if k == 0 {
111 CtOption::some(Self::ZERO)
112 } else if k > Self::BITS {
113 CtOption::new(Self::ZERO, Choice::FALSE)
114 } else {
115 let is_some = self.is_odd();
116 let inv = Odd(Uint::select(&Uint::ONE, self, is_some)).invert_mod2k_vartime(k);
117 CtOption::new(inv, is_some)
118 }
119 }
120
121 #[deprecated(since = "0.7.0", note = "please use `invert_mod2k` instead")]
126 #[must_use]
127 pub const fn inv_mod2k(&self, k: u32) -> CtOption<Self> {
128 self.invert_mod2k(k)
129 }
130
131 #[must_use]
136 pub const fn invert_mod2k(&self, k: u32) -> CtOption<Self> {
137 let is_some =
138 Choice::from_u32_le(k, Self::BITS).and(Choice::from_u32_nz(k).not().or(self.is_odd()));
139 let inv = Odd(Uint::select(&Uint::ONE, self, is_some)).invert_mod_precision();
140 CtOption::new(inv.restrict_bits(k), is_some)
141 }
142
143 #[deprecated(since = "0.7.0", note = "please use `invert_odd_mod` instead")]
145 #[must_use]
146 pub const fn inv_odd_mod(&self, modulus: &Odd<Self>) -> CtOption<Self> {
147 self.invert_odd_mod(modulus)
148 }
149
150 #[must_use]
152 pub const fn invert_odd_mod(&self, modulus: &Odd<Self>) -> CtOption<Self> {
153 safegcd::invert_odd_mod::<LIMBS, false>(self, modulus)
154 }
155
156 #[must_use]
160 pub const fn invert_odd_mod_vartime(&self, modulus: &Odd<Self>) -> CtOption<Self> {
161 safegcd::invert_odd_mod::<LIMBS, true>(self, modulus)
162 }
163
164 #[deprecated(since = "0.7.0", note = "please use `invert_mod` instead")]
168 #[must_use]
169 pub const fn inv_mod(&self, modulus: &Self) -> CtOption<Self> {
170 let is_nz = modulus.is_nonzero();
171 let m = NonZero(Uint::select(&Uint::ONE, modulus, is_nz));
172 self.invert_mod(&m).filter_by(is_nz)
173 }
174
175 #[must_use]
179 pub const fn invert_mod(&self, modulus: &NonZero<Self>) -> CtOption<Self> {
180 let k = modulus.as_ref().trailing_zeros();
182 let s = Odd(modulus.as_ref().shr(k));
183
184 let maybe_a = self.invert_odd_mod(&s);
187
188 let maybe_b = self.invert_mod2k(k);
189 let is_some = maybe_a.is_some().and(maybe_b.is_some());
190
191 let a = maybe_a.to_inner_unchecked();
194 let b = maybe_b.to_inner_unchecked();
195
196 let m_odd_inv = s.invert_mod_precision();
203
204 let t = b.wrapping_sub(&a).wrapping_mul(&m_odd_inv).restrict_bits(k);
206
207 let result = a.wrapping_add(&s.as_ref().wrapping_mul(&t));
210 CtOption::new(result, is_some)
211 }
212}
213
214impl<const LIMBS: usize> Odd<Uint<LIMBS>> {
215 #[inline]
217 pub(crate) const fn invert_mod_precision(&self) -> Uint<LIMBS> {
218 self.invert_mod2k_vartime(Self::BITS)
219 }
220
221 #[allow(clippy::integer_division_remainder_used, reason = "vartime")]
225 pub(crate) const fn invert_mod2k_vartime(&self, k: u32) -> Uint<LIMBS> {
226 assert!(k <= Self::BITS);
227
228 let k_limbs = k.div_ceil(Limb::BITS) as usize;
229 let mut inv = U64::from_u64(self.as_uint_ref().invert_mod_u64()).resize::<LIMBS>();
230
231 if k_limbs <= U64::LIMBS {
232 inv.as_mut_uint_ref().trailing_mut(k_limbs).fill(Limb::ZERO);
234 } else {
235 let mut scratch = (Uint::<LIMBS>::ZERO, Uint::<LIMBS>::ZERO);
237 expand_invert_mod2k(
238 self.as_uint_ref(),
239 inv.as_mut_uint_ref().leading_mut(k_limbs),
240 U64::LIMBS,
241 (scratch.0.as_mut_uint_ref(), scratch.1.as_mut_uint_ref()),
242 );
243 }
244
245 let k_bits = k % Limb::BITS;
247 if k_bits > 0 {
248 inv.limbs[k_limbs - 1] = inv.limbs[k_limbs - 1].restrict_bits(k_bits);
249 }
250
251 inv
252 }
253}
254
255impl<const LIMBS: usize> InvertMod for Uint<LIMBS> {
256 type Output = Self;
257
258 fn invert_mod(&self, modulus: &NonZero<Self>) -> CtOption<Self> {
259 self.invert_mod(modulus)
260 }
261}
262
#[cfg(test)]
mod tests {
    use crate::{Odd, U64, U256, U1024, Uint};

    // Inversion modulo powers of two. The constant-time and vartime variants are checked
    // against the same expected values.
    #[test]
    fn invert_mod2k() {
        // Full 256-bit inverse of the secp256k1 field prime (2^256 - 2^32 - 977).
        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f");
        let e =
            U256::from_be_hex("3642e6faeaac7c6663b93d3d6a0d489e434ddc0123db5fa627c7f6e22ddacacf");
        let a = v.invert_mod2k(256).unwrap();
        assert_eq!(e, a);

        let a = v.invert_mod2k_vartime(256).unwrap();
        assert_eq!(e, a);

        // Full 256-bit inverse of the secp256k1 group order.
        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141");
        let e =
            U256::from_be_hex("261776f29b6b106c7680cf3ed83054a1af5ae537cb4613dbb4f20099aa774ec1");
        let a = v.invert_mod2k(256).unwrap();
        assert_eq!(e, a);

        let a = v.invert_mod2k_vartime(256).unwrap();
        assert_eq!(e, a);

        // Same input with k = 90, which is not a multiple of the limb width. The expected
        // value is the low 90 bits of the 256-bit inverse above.
        let v =
            U256::from_be_hex("fffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141");
        let e =
            U256::from_be_hex("0000000000000000000000000000000000000000034613dbb4f20099aa774ec1");
        let a = v.invert_mod2k(90).unwrap();
        assert_eq!(e, a);

        let a = v.invert_mod2k_vartime(90).unwrap();
        assert_eq!(e, a);

        // An even value has no inverse modulo 2^4.
        let a = U256::from(10u64).invert_mod2k(4);
        assert!(a.is_none().to_bool_vartime());

        let a = U256::from(10u64).invert_mod2k_vartime(4);
        assert!(a.is_none().to_bool_vartime());

        // With k = 0, even inputs still yield Some(0).
        let a = U256::from(10u64).invert_mod2k_vartime(0).unwrap();
        assert_eq!(a, U256::ZERO);
    }

    // 1024-bit inverse modulo an odd modulus. The same expected value must come out of both
    // `invert_odd_mod` and the general `invert_mod`.
    #[test]
    fn test_invert_odd() {
        let a = U1024::from_be_hex(concat![
            "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
            "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
            "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
            "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
        ]);
        let m = U1024::from_be_hex(concat![
            "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
            "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
            "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
            "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156767"
        ])
        .to_odd()
        .unwrap();
        let expected = U1024::from_be_hex(concat![
            "B03623284B0EBABCABD5C5881893320281460C0A8E7BF4BFDCFFCBCCBF436A55",
            "D364235C8171E46C7D21AAD0680676E57274A8FDA6D12768EF961CACDD2DAE57",
            "88D93DA5EB8EDC391EE3726CDCF4613C539F7D23E8702200CB31B5ED5B06E5CA",
            "3E520968399B4017BF98A864FABA2B647EFC4998B56774D4F2CB026BC024A336"
        ]);

        let res = a.invert_odd_mod(&m).unwrap();
        assert_eq!(res, expected);

        // With an odd modulus, `invert_mod` takes the k == 0 path of its CRT combination.
        let res = a.invert_mod(m.as_nz_ref()).unwrap();
        assert_eq!(res, expected);
    }

    // `p1` divides the modulus `p1 * p2`, so `p1` has no inverse modulo it.
    #[test]
    fn test_invert_odd_no_inverse() {
        let p1 =
            U256::from_be_hex("00000000000000000000000000000000ffffffffffffffffffffffffffffff61");
        let p2 =
            U256::from_be_hex("00000000000000000000000000000000ffffffffffffffffffffffffffffff53");

        let m = p1.wrapping_mul(&p2).to_odd().unwrap();

        let res = p1.invert_odd_mod(&m);
        assert!(res.is_none().to_bool_vartime());
    }

    // 1024-bit inverse modulo an even modulus (low hex digits `000`), exercising both halves
    // of the CRT combination in `invert_mod`.
    #[test]
    fn test_invert_even() {
        let a = U1024::from_be_hex(concat![
            "000225E99153B467A5B451979A3F451DAEF3BF8D6C6521D2FA24BBB17F29544E",
            "347A412B065B75A351EA9719E2430D2477B11CC9CF9C1AD6EDEE26CB15F463F8",
            "BCC72EF87EA30288E95A48AA792226CEC959DCB0672D8F9D80A54CBBEA85CAD8",
            "382EC224DEB2F5784E62D0CC2F81C2E6AD14EBABE646D6764B30C32B87688985"
        ]);
        let m = U1024::from_be_hex(concat![
            "D509E7854ABDC81921F669F1DC6F61359523F3949803E58ED4EA8BC16483DC6F",
            "37BFE27A9AC9EEA2969B357ABC5C0EE214BE16A7D4C58FC620D5B5A20AFF001A",
            "D198D3155E5799DC4EA76652D64983A7E130B5EACEBAC768D28D589C36EC749C",
            "558D0B64E37CD0775C0D0104AE7D98BA23C815185DD43CD8B16292FD94156000"
        ])
        .to_nz()
        .unwrap();
        let expected = U1024::from_be_hex(concat![
            "1EBF391306817E1BC610E213F4453AD70911CCBD59A901B2A468A4FC1D64F357",
            "DBFC6381EC5635CAA664DF280028AF4651482C77A143DF38D6BFD4D64B6C0225",
            "FC0E199B15A64966FB26D88A86AD144271F6BDCD3D63193AB2B3CC53B99F21A3",
            "5B9BFAE5D43C6BC6E7A9856C71C7318C76530E9E5AE35882D5ABB02F1696874D",
        ]);

        let res = a.invert_mod(&m).unwrap();
        assert_eq!(res, expected);
    }

    // 3 * 9 = 27 ≡ 1 (mod 13).
    #[test]
    fn test_invert_small() {
        let a = U64::from(3u64);
        let m = U64::from(13u64).to_odd().unwrap();

        let res = a.invert_odd_mod(&m).unwrap();
        assert_eq!(U64::from(9u64), res);
    }

    // gcd(14, 49) = 7, so no inverse exists.
    #[test]
    fn test_no_inverse_small() {
        let a = U64::from(14u64);
        let m = U64::from(49u64).to_odd().unwrap();

        let res = a.invert_odd_mod(&m);
        assert!(res.is_none().to_bool_vartime());
    }

    // Boundary inputs for `invert_odd_mod`: zero, one and MAX as operand and modulus.
    #[test]
    fn test_invert_edge() {
        // Zero is reported as having no inverse, even modulo 1.
        assert!(
            U256::ZERO
                .invert_odd_mod(&U256::ONE.to_odd().unwrap())
                .is_none()
                .to_bool_vartime()
        );
        // Modulo 1, the inverse of a non-zero value is 0.
        assert_eq!(
            U256::ONE
                .invert_odd_mod(&U256::ONE.to_odd().unwrap())
                .unwrap(),
            U256::ZERO
        );
        assert_eq!(
            U256::ONE
                .invert_odd_mod(&U256::MAX.to_odd().unwrap())
                .unwrap(),
            U256::ONE
        );
        // MAX shares every factor with itself.
        assert!(
            U256::MAX
                .invert_odd_mod(&U256::MAX.to_odd().unwrap())
                .is_none()
                .to_bool_vartime()
        );
        assert_eq!(
            U256::MAX
                .invert_odd_mod(&U256::ONE.to_odd().unwrap())
                .unwrap(),
            U256::ZERO
        );
    }

    // `invert_mod_precision` must give a * a^-1 == 1 at the full width, for a range of limb
    // counts (including counts that are not powers of two).
    #[test]
    fn invert_mod_precision() {
        const BIG: Odd<Uint<8>> = Odd(Uint::MAX);

        fn test_invert_size<const LIMBS: usize>() {
            let a = BIG.resize::<LIMBS>();
            let a_inv = a.invert_mod_precision();
            assert_eq!(a.as_ref().wrapping_mul(&a_inv), Uint::ONE);
        }

        test_invert_size::<1>();
        test_invert_size::<2>();
        test_invert_size::<3>();
        test_invert_size::<4>();
        test_invert_size::<5>();
        test_invert_size::<6>();
        test_invert_size::<7>();
        test_invert_size::<8>();
        test_invert_size::<9>();
        test_invert_size::<10>();
    }
}