1use core::cmp::Ordering;
14
/// Minimum limb count of the *shorter* operand before `BigUint::mul_ref`
/// dispatches to Karatsuba instead of schoolbook multiplication.
const KARATSUBA_THRESHOLD_LIMBS: usize = 32;
/// Karatsuba is only used when the longer operand has at most this many times
/// as many limbs as the shorter one; more lopsided products use schoolbook.
const KARATSUBA_MAX_IMBALANCE: usize = 2;
21
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
/// Sign of a [`BigInt`]. `BigInt::from_parts` canonicalises values so that
/// `Zero` accompanies a zero magnitude and only a zero magnitude.
pub enum Sign {
    /// Strictly greater than zero.
    Positive,
    /// Strictly less than zero.
    Negative,
    /// Exactly zero.
    Zero,
}
32
#[derive(Clone, Debug, Eq, PartialEq)]
/// Arbitrary-precision unsigned integer.
///
/// Stored as 64-bit limbs in little-endian order (`limbs[0]` is least
/// significant). Invariant: no most-significant zero limbs, so zero is the
/// empty vector and the derived equality is numeric equality.
pub struct BigUint {
    limbs: Vec<u64>,
}
38
#[derive(Clone, Debug, Eq, PartialEq)]
/// Signed arbitrary-precision integer in sign-magnitude form.
pub struct BigInt {
    sign: Sign,
    magnitude: BigUint,
}
45
#[derive(Clone, Debug, Eq, PartialEq)]
/// Precomputed constants for Montgomery arithmetic modulo an odd `modulus`.
///
/// With `width = modulus` limb count and `R = 2^(64·width)`, a value `x` is
/// represented in Montgomery form as `x·R mod modulus`.
pub struct MontgomeryCtx {
    modulus: BigUint,
    // `-modulus^{-1} mod 2^64`, the per-word reduction multiplier.
    n0_inv: u64,
    // `R² mod modulus`, used to move values into Montgomery form.
    r2_mod: BigUint,
    // `R mod modulus`, i.e. the Montgomery form of 1.
    one_mont: BigUint,
}
62
63impl Ord for BigUint {
64 fn cmp(&self, other: &Self) -> Ordering {
65 debug_assert!(
68 self.limbs.last().copied() != Some(0),
69 "BigUint invariant: no leading zero limbs",
70 );
71 debug_assert!(
72 other.limbs.last().copied() != Some(0),
73 "BigUint invariant: no leading zero limbs",
74 );
75 match self.limbs.len().cmp(&other.limbs.len()) {
76 Ordering::Equal => {}
77 ord => return ord,
78 }
79
80 for (&lhs, &rhs) in self.limbs.iter().rev().zip(other.limbs.iter().rev()) {
81 match lhs.cmp(&rhs) {
82 Ordering::Equal => {}
83 ord => return ord,
84 }
85 }
86
87 Ordering::Equal
88 }
89}
90
91impl PartialOrd for BigUint {
92 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
93 Some(self.cmp(other))
94 }
95}
96
97impl BigUint {
98 #[must_use]
100 pub fn zero() -> Self {
101 Self { limbs: Vec::new() }
102 }
103
104 #[must_use]
106 pub fn one() -> Self {
107 Self { limbs: vec![1] }
108 }
109
110 #[must_use]
112 pub fn from_u64(value: u64) -> Self {
113 if value == 0 {
114 Self::zero()
115 } else {
116 Self { limbs: vec![value] }
117 }
118 }
119
120 #[must_use]
126 pub fn from_u128(value: u128) -> Self {
127 if value == 0 {
128 return Self::zero();
129 }
130
131 let lo =
132 u64::try_from(value & u128::from(u64::MAX)).expect("low 64 bits always fit into u64");
133 let hi = u64::try_from(value >> 64).expect("high 64 bits always fit into u64");
134 if hi == 0 {
135 Self { limbs: vec![lo] }
136 } else {
137 Self {
138 limbs: vec![lo, hi],
139 }
140 }
141 }
142
143 #[must_use]
147 pub fn from_be_bytes(bytes: &[u8]) -> Self {
148 if bytes.is_empty() {
149 return Self::zero();
150 }
151
152 let mut limbs = Vec::with_capacity(bytes.len().div_ceil(8));
153 let mut acc = 0u64;
154 let mut shift = 0u32;
155
156 for &byte in bytes.iter().rev() {
161 acc |= u64::from(byte) << shift;
162 shift += 8;
163 if shift == 64 {
164 limbs.push(acc);
165 acc = 0;
166 shift = 0;
167 }
168 }
169
170 if shift != 0 {
171 limbs.push(acc);
172 }
173
174 let mut out = Self { limbs };
175 out.normalize();
176 out
177 }
178
179 #[must_use]
190 pub fn to_be_bytes(&self) -> Vec<u8> {
191 if self.is_zero() {
192 return vec![0];
193 }
194
195 let mut out = Vec::with_capacity(self.limbs.len() * 8);
196 for &limb in self.limbs.iter().rev() {
197 out.extend_from_slice(&limb.to_be_bytes());
198 }
199
200 let first_nonzero = out
201 .iter()
202 .position(|&byte| byte != 0)
203 .expect("non-zero bigint must encode to at least one non-zero byte");
204 out.drain(0..first_nonzero);
205 out
206 }
207
208 #[must_use]
210 pub fn is_zero(&self) -> bool {
211 self.limbs.is_empty()
212 }
213
214 #[must_use]
216 pub fn is_odd(&self) -> bool {
217 !self.is_zero() && (self.limbs[0] & 1) == 1
218 }
219
220 #[must_use]
222 pub fn is_one(&self) -> bool {
223 self.limbs.len() == 1 && self.limbs[0] == 1
224 }
225
226 #[must_use]
233 pub fn bits(&self) -> usize {
234 if self.is_zero() {
235 return 0;
236 }
237
238 let top = *self
239 .limbs
240 .last()
241 .expect("non-zero bigint has at least one limb");
242 let top_bits = (u64::BITS - top.leading_zeros()) as usize;
243 (self.limbs.len() - 1) * 64 + top_bits
244 }
245
246 #[must_use]
248 pub fn sqrt_floor(&self) -> Self {
249 if self.is_zero() {
250 return Self::zero();
251 }
252 if self.is_one() {
253 return Self::one();
254 }
255
256 let mut low = Self::one();
257 let mut high = Self::zero();
258 high.set_bit(self.bits().div_ceil(2));
263
264 while {
265 let next_low = low.add_ref(&Self::one());
266 next_low < high
267 } {
268 let mut middle = low.add_ref(&high);
269 middle.shr1();
270 let square = middle.square_ref();
271 if square <= *self {
272 low = middle;
273 } else {
274 high = middle;
275 }
276 }
277
278 low
279 }
280
281 #[must_use]
283 pub fn bit(&self, index: usize) -> bool {
284 let limb = index / 64;
285 let shift = index % 64;
286 if limb >= self.limbs.len() {
287 false
288 } else {
289 ((self.limbs[limb] >> shift) & 1) == 1
290 }
291 }
292
293 pub fn set_bit(&mut self, index: usize) {
295 let limb = index / 64;
296 let shift = index % 64;
297 if self.limbs.len() <= limb {
298 self.limbs.resize(limb + 1, 0);
299 }
300 self.limbs[limb] |= 1u64 << shift;
301 }
302
303 pub fn add_assign_ref(&mut self, other: &Self) {
310 if other.is_zero() {
311 return;
312 }
313
314 if self.limbs.len() < other.limbs.len() {
315 self.limbs.resize(other.limbs.len(), 0);
316 }
317
318 let mut carry = 0u128;
319 for i in 0..other.limbs.len() {
320 let sum = u128::from(self.limbs[i]) + u128::from(other.limbs[i]) + carry;
321 self.limbs[i] = low_u64(sum);
322 carry = sum >> 64;
323 }
324
325 let mut i = other.limbs.len();
326 while carry != 0 && i < self.limbs.len() {
327 let sum = u128::from(self.limbs[i]) + carry;
328 self.limbs[i] = low_u64(sum);
329 carry = sum >> 64;
330 i += 1;
331 }
332
333 if carry != 0 {
334 self.limbs
335 .push(u64::try_from(carry).expect("final carry from u64 addition is at most 1"));
336 }
337 }
338
339 #[must_use]
341 pub fn add_ref(&self, other: &Self) -> Self {
342 let mut out = self.clone();
343 out.add_assign_ref(other);
344 out
345 }
346
347 pub fn sub_assign_ref(&mut self, other: &Self) {
353 assert!((*self).cmp(other) != Ordering::Less, "BigUint underflow");
354 if other.is_zero() {
355 return;
356 }
357
358 let mut borrow = 0u128;
359 for i in 0..self.limbs.len() {
360 let lhs = u128::from(self.limbs[i]);
361 let rhs = if i < other.limbs.len() {
362 u128::from(other.limbs[i])
363 } else {
364 0
365 };
366
367 let subtrahend = rhs + borrow;
368 if lhs >= subtrahend {
369 self.limbs[i] = low_u64(lhs - subtrahend);
370 borrow = 0;
371 } else {
372 self.limbs[i] = low_u64((1u128 << 64) + lhs - subtrahend);
373 borrow = 1;
374 }
375 }
376
377 self.normalize();
378 }
379
380 #[must_use]
382 pub fn sub_ref(&self, other: &Self) -> Self {
383 let mut out = self.clone();
384 out.sub_assign_ref(other);
385 out
386 }
387
388 #[must_use]
395 pub fn mul_ref(&self, other: &Self) -> Self {
396 if self.is_zero() || other.is_zero() {
397 return Self::zero();
398 }
399
400 if Self::should_use_karatsuba(self, other) {
401 return self.mul_karatsuba_ref(other);
402 }
403
404 Self::mul_schoolbook_ref(self, other)
405 }
406
407 #[must_use]
409 pub fn square_ref(&self) -> Self {
410 self.mul_ref(self)
411 }
412
413 fn split_at_limb(&self, split: usize) -> (Self, Self) {
414 let low_end = split.min(self.limbs.len());
415 let mut low = Self {
416 limbs: self.limbs[..low_end].to_vec(),
417 };
418 low.normalize();
419
420 if split >= self.limbs.len() {
421 return (low, Self::zero());
422 }
423
424 let mut high = Self {
425 limbs: self.limbs[split..].to_vec(),
426 };
427 high.normalize();
428 (low, high)
429 }
430
431 fn should_use_karatsuba(lhs: &Self, rhs: &Self) -> bool {
432 let short = lhs.limbs.len().min(rhs.limbs.len());
433 let long = lhs.limbs.len().max(rhs.limbs.len());
434 short >= KARATSUBA_THRESHOLD_LIMBS && long <= short * KARATSUBA_MAX_IMBALANCE
435 }
436
437 fn mul_karatsuba_ref(&self, other: &Self) -> Self {
438 let split = self.limbs.len().max(other.limbs.len()) / 2;
439 if split == 0 {
440 return Self::mul_schoolbook_ref(self, other);
441 }
442
443 let (a0, a1) = self.split_at_limb(split);
444 let (b0, b1) = other.split_at_limb(split);
445 if a1.is_zero() || b1.is_zero() {
446 return Self::mul_schoolbook_ref(self, other);
447 }
448
449 let z0 = a0.mul_ref(&b0);
450 let z2 = a1.mul_ref(&b1);
451
452 let a_sum = a0.add_ref(&a1);
453 let b_sum = b0.add_ref(&b1);
454 let mut z1 = a_sum.mul_ref(&b_sum);
455 z1.sub_assign_ref(&z0);
456 z1.sub_assign_ref(&z2);
457
458 let mut out = z0;
459 z1.shl_bits(split * 64);
460 out.add_assign_ref(&z1);
461
462 let mut z2_shifted = z2;
463 z2_shifted.shl_bits(split * 128);
464 out.add_assign_ref(&z2_shifted);
465 out
466 }
467
468 fn mul_schoolbook_ref(lhs: &Self, rhs: &Self) -> Self {
469 let mut out = vec![0u64; lhs.limbs.len() + rhs.limbs.len()];
470 for (i, &lhs_limb) in lhs.limbs.iter().enumerate() {
471 let mut carry = 0u128;
472 for (j, &rhs_limb) in rhs.limbs.iter().enumerate() {
473 let idx = i + j;
474 let acc =
475 u128::from(out[idx]) + u128::from(lhs_limb) * u128::from(rhs_limb) + carry;
476 out[idx] = low_u64(acc);
477 carry = acc >> 64;
478 }
479
480 let mut idx = i + rhs.limbs.len();
481 while carry != 0 {
482 let acc = u128::from(out[idx]) + carry;
483 out[idx] = low_u64(acc);
484 carry = acc >> 64;
485 idx += 1;
486 }
487 }
488
489 let mut result = Self { limbs: out };
490 result.normalize();
494 result
495 }
496
497 pub fn shl1(&mut self) {
499 if self.is_zero() {
500 return;
501 }
502
503 let mut carry = 0u64;
504 for limb in &mut self.limbs {
505 let next = *limb >> 63;
506 *limb = (*limb << 1) | carry;
507 carry = next;
508 }
509
510 if carry != 0 {
511 self.limbs.push(carry);
512 }
513 }
516
517 pub fn shr1(&mut self) {
519 if self.is_zero() {
520 return;
521 }
522
523 let mut carry = 0u64;
524 for limb in self.limbs.iter_mut().rev() {
525 let next = (*limb & 1) << 63;
526 *limb = (*limb >> 1) | carry;
527 carry = next;
528 }
529
530 self.normalize();
531 }
532
533 pub fn bitxor_assign(&mut self, other: &BigUint) {
539 if self.limbs.len() < other.limbs.len() {
540 self.limbs.resize(other.limbs.len(), 0);
541 }
542 for (s, &o) in self.limbs.iter_mut().zip(other.limbs.iter()) {
543 *s ^= o;
544 }
545 self.normalize();
546 }
547
548 pub fn shl_bits(&mut self, n: usize) {
554 if self.is_zero() || n == 0 {
555 return;
556 }
557 let limb_shifts = n / 64;
558 let bit_shifts = n % 64;
559 if limb_shifts > 0 {
561 let mut new_limbs = vec![0u64; limb_shifts];
562 new_limbs.extend_from_slice(&self.limbs);
563 self.limbs = new_limbs;
564 }
565 if bit_shifts > 0 {
567 let mut carry = 0u64;
568 for limb in &mut self.limbs {
569 let next_carry = *limb >> (64 - bit_shifts);
570 *limb = (*limb << bit_shifts) | carry;
571 carry = next_carry;
572 }
573 if carry != 0 {
574 self.limbs.push(carry);
575 }
576 }
577 }
580
581 #[must_use]
583 pub fn modulo(&self, modulus: &Self) -> Self {
584 let (_, remainder) = self.div_rem(modulus);
585 remainder
586 }
587
588 #[must_use]
594 pub fn rem_u64(&self, modulus: u64) -> u64 {
595 assert!(modulus != 0, "division by zero");
596 if self.is_zero() {
597 return 0;
598 }
599
600 let mut remainder = 0u128;
601 for &limb in self.limbs.iter().rev() {
605 let acc = (remainder << 64) | u128::from(limb);
606 remainder = acc % u128::from(modulus);
607 }
608
609 u64::try_from(remainder).expect("remainder modulo u64 fits into u64")
610 }
611
612 #[must_use]
626 pub fn mod_mul(lhs: &Self, rhs: &Self, modulus: &Self) -> Self {
627 assert!(!modulus.is_zero(), "modulus must be non-zero");
628 if modulus == &Self::one() {
629 return Self::zero();
630 }
631 if let Some(ctx) = MontgomeryCtx::new(modulus) {
632 return ctx.mul(lhs, rhs);
633 }
634 Self::mod_mul_plain(lhs, rhs, modulus)
635 }
636
637 #[must_use]
646 pub(crate) fn mod_mul_plain(lhs: &Self, rhs: &Self, modulus: &Self) -> Self {
647 if lhs.is_zero() || rhs.is_zero() {
648 return Self::zero();
649 }
650
651 let mut a = lhs.modulo(modulus);
652 let mut b = rhs.clone();
653 let mut out = Self::zero();
654 while !b.is_zero() {
655 if b.is_odd() {
656 out = out.add_ref(&a).modulo(modulus);
657 }
658 a = a.add_ref(&a).modulo(modulus);
659 b.shr1();
660 }
661 out
662 }
663
664 #[must_use]
670 pub fn div_rem(&self, divisor: &Self) -> (Self, Self) {
671 assert!(!divisor.is_zero(), "division by zero");
672 if self.cmp(divisor) == Ordering::Less {
673 return (Self::zero(), self.clone());
674 }
675
676 let mut quotient = Self::zero();
677 let mut remainder = Self::zero();
678
679 for bit in (0..self.bits()).rev() {
684 remainder.shl1();
685 if self.bit(bit) {
686 if remainder.is_zero() {
687 remainder.limbs.push(1);
688 } else {
689 remainder.limbs[0] |= 1;
690 }
691 }
692
693 if remainder.cmp(divisor) != Ordering::Less {
694 remainder.sub_assign_ref(divisor);
695 quotient.set_bit(bit);
696 }
697 }
698
699 (quotient, remainder)
700 }
701
702 fn normalize(&mut self) {
703 while self.limbs.last().copied() == Some(0) {
707 self.limbs.pop();
708 }
709 }
710
711 fn limb_or_zero(&self, idx: usize) -> u64 {
712 self.limbs.get(idx).copied().unwrap_or(0)
713 }
714
715 fn montgomery_mul_odd_with_workspace(
716 lhs: &Self,
717 rhs: &Self,
718 modulus: &Self,
719 n0_inv: u64,
720 workspace: &mut Vec<u64>,
721 ) -> Self {
722 debug_assert!(modulus.is_odd(), "Montgomery path requires an odd modulus");
723 let width = modulus.limbs.len();
724 let needed = width * 2 + 2;
727 if workspace.len() != needed {
728 workspace.resize(needed, 0);
729 } else {
730 workspace.fill(0);
731 }
732
733 for i in 0..width {
735 let lhs_limb = lhs.limb_or_zero(i);
736 let mut carry = 0u128;
737 for j in 0..width {
738 let idx = i + j;
739 let acc = u128::from(workspace[idx])
740 + u128::from(lhs_limb) * u128::from(rhs.limb_or_zero(j))
741 + carry;
742 workspace[idx] = low_u64(acc);
743 carry = acc >> 64;
744 }
745
746 let mut idx = i + width;
747 while carry != 0 {
748 let acc = u128::from(workspace[idx]) + carry;
749 workspace[idx] = low_u64(acc);
750 carry = acc >> 64;
751 idx += 1;
752 }
753 }
754
755 for i in 0..width {
762 let m = workspace[i].wrapping_mul(n0_inv);
763 let mut carry = 0u128;
764 for j in 0..width {
765 let idx = i + j;
766 let acc = u128::from(workspace[idx])
767 + u128::from(m) * u128::from(modulus.limb_or_zero(j))
768 + carry;
769 workspace[idx] = low_u64(acc);
770 carry = acc >> 64;
771 }
772
773 let mut idx = i + width;
774 while carry != 0 {
775 let acc = u128::from(workspace[idx]) + carry;
776 workspace[idx] = low_u64(acc);
777 carry = acc >> 64;
778 idx += 1;
779 }
780 }
781
782 let mut out = Self {
783 limbs: workspace[width..=(width * 2)].to_vec(),
784 };
785 out.normalize();
786 if out >= *modulus {
789 out.sub_assign_ref(modulus);
790 }
791 out
792 }
793}
794
795impl MontgomeryCtx {
796 fn encode_with_workspace(&self, value: &BigUint, workspace: &mut Vec<u64>) -> BigUint {
797 if value.is_zero() {
798 return BigUint::zero();
799 }
800
801 BigUint::montgomery_mul_odd_with_workspace(
802 &value.modulo(&self.modulus),
803 &self.r2_mod,
804 &self.modulus,
805 self.n0_inv,
806 workspace,
807 )
808 }
809
810 fn decode_with_workspace(&self, value: &BigUint, workspace: &mut Vec<u64>) -> BigUint {
811 BigUint::montgomery_mul_odd_with_workspace(
812 value,
813 &BigUint::one(),
814 &self.modulus,
815 self.n0_inv,
816 workspace,
817 )
818 }
819
820 fn pow_encoded_with_workspace(
821 &self,
822 base_mont: &BigUint,
823 exponent: &BigUint,
824 workspace: &mut Vec<u64>,
825 ) -> BigUint {
826 if self.modulus == BigUint::one() {
827 return BigUint::zero();
828 }
829
830 let mut result = self.one_mont.clone();
831 let mut power = base_mont.clone();
832
833 for bit in 0..exponent.bits() {
834 if exponent.bit(bit) {
835 result = BigUint::montgomery_mul_odd_with_workspace(
836 &result,
837 &power,
838 &self.modulus,
839 self.n0_inv,
840 workspace,
841 );
842 }
843 power = BigUint::montgomery_mul_odd_with_workspace(
844 &power,
845 &power,
846 &self.modulus,
847 self.n0_inv,
848 workspace,
849 );
850 }
851
852 self.decode_with_workspace(&result, workspace)
853 }
854
855 #[must_use]
857 pub fn new(modulus: &BigUint) -> Option<Self> {
858 if modulus.is_zero() || !modulus.is_odd() {
859 return None;
860 }
861
862 let n0_inv = montgomery_n0_inv(modulus.limbs[0]);
863
864 let mut r2 = BigUint::zero();
869 r2.set_bit(modulus.limbs.len() * 128);
870 let r2_mod = r2.modulo(modulus);
871
872 let mut r = BigUint::zero();
875 r.set_bit(modulus.limbs.len() * 64);
876 let one_mont = r.modulo(modulus);
877
878 Some(Self {
879 modulus: modulus.clone(),
880 n0_inv,
881 r2_mod,
882 one_mont,
883 })
884 }
885
886 #[must_use]
888 pub fn modulus(&self) -> &BigUint {
889 &self.modulus
890 }
891
892 #[must_use]
894 pub fn encode(&self, value: &BigUint) -> BigUint {
895 let mut workspace = Vec::new();
896 self.encode_with_workspace(value, &mut workspace)
897 }
898
899 #[must_use]
901 pub fn decode(&self, value: &BigUint) -> BigUint {
902 let mut workspace = Vec::new();
903 self.decode_with_workspace(value, &mut workspace)
904 }
905
906 #[must_use]
908 pub fn mul(&self, lhs: &BigUint, rhs: &BigUint) -> BigUint {
909 let mut workspace = Vec::new();
910 let lhs_mont = self.encode_with_workspace(lhs, &mut workspace);
911 let rhs_mont = self.encode_with_workspace(rhs, &mut workspace);
912 let product_mont = BigUint::montgomery_mul_odd_with_workspace(
913 &lhs_mont,
914 &rhs_mont,
915 &self.modulus,
916 self.n0_inv,
917 &mut workspace,
918 );
919 self.decode_with_workspace(&product_mont, &mut workspace)
920 }
921
922 #[must_use]
924 pub fn square(&self, value: &BigUint) -> BigUint {
925 let mut workspace = Vec::new();
926 let value_mont = self.encode_with_workspace(value, &mut workspace);
927 let square_mont = BigUint::montgomery_mul_odd_with_workspace(
928 &value_mont,
929 &value_mont,
930 &self.modulus,
931 self.n0_inv,
932 &mut workspace,
933 );
934 self.decode_with_workspace(&square_mont, &mut workspace)
935 }
936
937 #[must_use]
939 pub fn pow(&self, base: &BigUint, exponent: &BigUint) -> BigUint {
940 let mut workspace = Vec::new();
941 let base_mont = self.encode_with_workspace(&base.modulo(&self.modulus), &mut workspace);
942 self.pow_encoded_with_workspace(&base_mont, exponent, &mut workspace)
943 }
944
945 #[must_use]
950 pub fn pow_encoded(&self, base_mont: &BigUint, exponent: &BigUint) -> BigUint {
951 let mut workspace = Vec::new();
952 self.pow_encoded_with_workspace(base_mont, exponent, &mut workspace)
953 }
954}
955
956impl Drop for BigUint {
957 fn drop(&mut self) {
958 crate::ct::zeroize_slice(self.limbs.as_mut_slice());
962 }
963}
964
/// Truncates a `u128` to its least-significant 64 bits.
#[inline]
fn low_u64(value: u128) -> u64 {
    let low_half = value & u128::from(u64::MAX);
    u64::try_from(low_half).expect("masked low 64 bits always fit into u64")
}
969
/// Returns `-n0⁻¹ mod 2^64` for an odd word `n0`, the per-word multiplier
/// used by Montgomery reduction.
///
/// Newton/Hensel iteration `x ← x·(2 − n0·x)` doubles the number of correct
/// low bits on each step. The seed `1` is correct modulo 2 for any odd `n0`,
/// so six steps reach 64 bits.
fn montgomery_n0_inv(n0: u64) -> u64 {
    debug_assert!(n0 & 1 == 1, "Montgomery path requires an odd modulus");
    let inverse = (0..6).fold(1u64, |approx, _| {
        approx.wrapping_mul(2u64.wrapping_sub(n0.wrapping_mul(approx)))
    });
    0u64.wrapping_sub(inverse)
}
981
982impl BigInt {
983 #[must_use]
985 pub fn zero() -> Self {
986 Self {
987 sign: Sign::Zero,
988 magnitude: BigUint::zero(),
989 }
990 }
991
992 #[must_use]
994 pub fn from_parts(sign: Sign, magnitude: BigUint) -> Self {
995 if magnitude.is_zero() {
996 return Self::zero();
997 }
998
999 let canonical_sign = match sign {
1000 Sign::Zero => Sign::Positive,
1001 other => other,
1002 };
1003
1004 Self {
1005 sign: canonical_sign,
1006 magnitude,
1007 }
1008 }
1009
1010 #[must_use]
1012 pub fn from_biguint(magnitude: BigUint) -> Self {
1013 Self::from_parts(Sign::Positive, magnitude)
1014 }
1015
1016 #[must_use]
1018 pub fn sign(&self) -> Sign {
1019 self.sign
1020 }
1021
1022 #[must_use]
1024 pub fn magnitude(&self) -> &BigUint {
1025 &self.magnitude
1026 }
1027
1028 #[must_use]
1030 pub fn negated(&self) -> Self {
1031 let sign = match self.sign {
1032 Sign::Positive => Sign::Negative,
1033 Sign::Negative => Sign::Positive,
1034 Sign::Zero => Sign::Zero,
1035 };
1036 Self {
1037 sign,
1038 magnitude: self.magnitude.clone(),
1039 }
1040 }
1041
1042 #[must_use]
1044 pub fn add_ref(&self, other: &Self) -> Self {
1045 match (self.sign, other.sign) {
1046 (Sign::Zero, _) => other.clone(),
1047 (_, Sign::Zero) => self.clone(),
1048 (Sign::Positive, Sign::Positive) => {
1049 Self::from_parts(Sign::Positive, self.magnitude.add_ref(&other.magnitude))
1050 }
1051 (Sign::Negative, Sign::Negative) => {
1052 Self::from_parts(Sign::Negative, self.magnitude.add_ref(&other.magnitude))
1053 }
1054 (Sign::Positive, Sign::Negative) => self.sub_ref(&other.negated()),
1055 (Sign::Negative, Sign::Positive) => other.sub_ref(&self.negated()),
1056 }
1057 }
1058
1059 #[must_use]
1061 pub fn sub_ref(&self, other: &Self) -> Self {
1062 match (self.sign, other.sign) {
1063 (_, Sign::Zero) => self.clone(),
1064 (Sign::Zero, _) => other.negated(),
1065 (Sign::Positive, Sign::Negative) => {
1066 Self::from_parts(Sign::Positive, self.magnitude.add_ref(&other.magnitude))
1067 }
1068 (Sign::Negative, Sign::Positive) => {
1069 Self::from_parts(Sign::Negative, self.magnitude.add_ref(&other.magnitude))
1070 }
1071 (Sign::Positive, Sign::Positive) => match self.magnitude.cmp(&other.magnitude) {
1072 Ordering::Greater => {
1073 Self::from_parts(Sign::Positive, self.magnitude.sub_ref(&other.magnitude))
1074 }
1075 Ordering::Less => {
1076 Self::from_parts(Sign::Negative, other.magnitude.sub_ref(&self.magnitude))
1077 }
1078 Ordering::Equal => Self::zero(),
1079 },
1080 (Sign::Negative, Sign::Negative) => match self.magnitude.cmp(&other.magnitude) {
1081 Ordering::Greater => {
1082 Self::from_parts(Sign::Negative, self.magnitude.sub_ref(&other.magnitude))
1083 }
1084 Ordering::Less => {
1085 Self::from_parts(Sign::Positive, other.magnitude.sub_ref(&self.magnitude))
1086 }
1087 Ordering::Equal => Self::zero(),
1088 },
1089 }
1090 }
1091
1092 #[must_use]
1094 pub fn mul_biguint_ref(&self, factor: &BigUint) -> Self {
1095 if factor.is_zero() || self.sign == Sign::Zero {
1096 return Self::zero();
1097 }
1098
1099 Self::from_parts(self.sign, self.magnitude.mul_ref(factor))
1100 }
1101
1102 #[must_use]
1108 pub fn modulo_positive(&self, modulus: &BigUint) -> BigUint {
1109 assert!(!modulus.is_zero(), "modulus must be non-zero");
1110 match self.sign {
1111 Sign::Zero => BigUint::zero(),
1112 Sign::Positive => self.magnitude.modulo(modulus),
1113 Sign::Negative => {
1114 let rem = self.magnitude.modulo(modulus);
1115 if rem.is_zero() {
1116 BigUint::zero()
1117 } else {
1118 modulus.sub_ref(&rem)
1119 }
1120 }
1121 }
1122 }
1123}
1124
#[cfg(test)]
mod tests {
    use super::{BigInt, BigUint, MontgomeryCtx, Sign};
    use core::cmp::Ordering;

    /// Deterministic 64-bit LCG so the randomized tests are reproducible.
    fn lcg_next(state: &mut u64) -> u64 {
        *state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
        *state
    }

    /// Pseudo-random normalized value with exactly `words` limbs.
    fn seeded_biguint(words: usize, state: &mut u64) -> BigUint {
        let mut limbs = Vec::with_capacity(words);
        for _ in 0..words {
            limbs.push(lcg_next(state));
        }
        if words > 0 && limbs[words - 1] == 0 {
            limbs[words - 1] = 1;
        }
        BigUint { limbs }
    }

    /// Pseudo-random odd value with `words` limbs, usable as a Montgomery
    /// modulus.
    fn seeded_odd_modulus(words: usize, state: &mut u64) -> BigUint {
        let mut modulus = seeded_biguint(words, state);
        modulus.limbs[0] |= 1;
        modulus
    }

    /// Reference conversion used to cross-check `BigInt` against `i64`.
    fn bigint_from_i64(value: i64) -> BigInt {
        let sign = match value.cmp(&0) {
            Ordering::Less => Sign::Negative,
            Ordering::Equal => Sign::Zero,
            Ordering::Greater => Sign::Positive,
        };
        BigInt::from_parts(sign, BigUint::from_u64(value.unsigned_abs()))
    }

    #[test]
    fn bytes_roundtrip() {
        let value =
            BigUint::from_be_bytes(&[0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x11, 0x22]);
        assert_eq!(
            value.to_be_bytes(),
            vec![0x12, 0x34, 0x56, 0x78, 0x9a, 0xbc, 0xde, 0xf0, 0x11, 0x22]
        );
    }

    #[test]
    fn bytes_edge_cases() {
        assert_eq!(BigUint::from_be_bytes(&[]), BigUint::zero());
        assert_eq!(BigUint::from_be_bytes(&[0, 0, 0]), BigUint::zero());
        assert_eq!(BigUint::zero().to_be_bytes(), vec![0]);
        assert_eq!(BigUint::from_be_bytes(&[0, 0, 1]), BigUint::one());
        assert_eq!(BigUint::from_be_bytes(&[1, 0, 0, 0, 0, 0, 0, 0, 0]).bits(), 65);
    }

    #[test]
    fn add_sub_mul_small_values() {
        let a = BigUint::from_u128(1_000_000_000_000);
        let b = BigUint::from_u128(777_777_777_777);
        assert_eq!(a.add_ref(&b), BigUint::from_u128(1_777_777_777_777));
        assert_eq!(
            a.sub_ref(&BigUint::from_u64(1)),
            BigUint::from_u128(999_999_999_999)
        );
        assert_eq!(
            a.mul_ref(&b),
            BigUint::from_u128(777_777_777_777_000_000_000_000)
        );
    }

    #[test]
    fn shifts_and_bits() {
        let mut value = BigUint::from_u64(0x8000_0000_0000_0001);
        value.shl1();
        assert_eq!(value, BigUint::from_u128(0x1_0000_0000_0000_0002));
        value.shr1();
        assert_eq!(value, BigUint::from_u64(0x8000_0000_0000_0001));
        value.shl_bits(70);
        assert_eq!(value.bits(), 64 + 70);
        assert!(value.bit(70));
        assert!(value.bit(133));
        assert!(!value.bit(71));
    }

    #[test]
    fn square_ref_matches_mul_ref() {
        let mut seed = 0x9e37_79b9_7f4a_7c15;
        for words in [1usize, 2, 8, 32, 48] {
            for _ in 0..8 {
                let value = seeded_biguint(words, &mut seed);
                assert_eq!(value.square_ref(), value.mul_ref(&value));
            }
        }
    }

    #[test]
    fn karatsuba_dispatch_matches_schoolbook() {
        let mut seed = 0x243f_6a88_85a3_08d3;
        for words in [32usize, 40, 64] {
            for _ in 0..6 {
                let lhs = seeded_biguint(words, &mut seed);
                let rhs = seeded_biguint(words, &mut seed);
                let dispatched = lhs.mul_ref(&rhs);
                let schoolbook = BigUint::mul_schoolbook_ref(&lhs, &rhs);
                assert_eq!(dispatched, schoolbook);
            }
        }
    }

    #[test]
    fn karatsuba_unbalanced_operands_match_schoolbook() {
        let mut seed = 0x4528_21e6_38d0_1377;
        for (lhs_words, rhs_words) in [(32usize, 63usize), (40, 70), (48, 33)] {
            let lhs = seeded_biguint(lhs_words, &mut seed);
            let rhs = seeded_biguint(rhs_words, &mut seed);
            let schoolbook = BigUint::mul_schoolbook_ref(&lhs, &rhs);
            assert_eq!(lhs.mul_ref(&rhs), schoolbook);
            assert_eq!(rhs.mul_ref(&lhs), schoolbook);
        }
    }

    #[test]
    fn division_roundtrip() {
        let dividend = BigUint::from_u128(1_234_567_890_123_456_789);
        let divisor = BigUint::from_u64(37);
        let (q, r) = dividend.div_rem(&divisor);
        assert_eq!(q, BigUint::from_u128(33_366_699_733_066_399));
        assert_eq!(r, BigUint::from_u64(26));
        assert_eq!(q.mul_ref(&divisor).add_ref(&r), dividend);
    }

    #[test]
    fn div_rem_reconstructs_random_dividends() {
        let mut seed = 0xbe54_66cf_34e9_0c6c;
        for (num_words, den_words) in [(1usize, 1usize), (3, 1), (4, 2), (5, 5), (2, 4)] {
            let dividend = seeded_biguint(num_words, &mut seed);
            let divisor = seeded_biguint(den_words, &mut seed);
            let (q, r) = dividend.div_rem(&divisor);
            assert!(r < divisor);
            assert_eq!(q.mul_ref(&divisor).add_ref(&r), dividend);
            assert_eq!(dividend.modulo(&divisor), r);

            let small = lcg_next(&mut seed) | 1;
            assert_eq!(
                BigUint::from_u64(dividend.rem_u64(small)),
                dividend.modulo(&BigUint::from_u64(small))
            );
        }
    }

    #[test]
    fn sqrt_floor_small_values() {
        assert_eq!(BigUint::from_u64(0).sqrt_floor(), BigUint::from_u64(0));
        assert_eq!(BigUint::from_u64(1).sqrt_floor(), BigUint::from_u64(1));
        assert_eq!(BigUint::from_u64(2).sqrt_floor(), BigUint::from_u64(1));
        assert_eq!(BigUint::from_u64(15).sqrt_floor(), BigUint::from_u64(3));
        assert_eq!(BigUint::from_u64(16).sqrt_floor(), BigUint::from_u64(4));
        assert_eq!(BigUint::from_u64(17).sqrt_floor(), BigUint::from_u64(4));
        assert_eq!(
            BigUint::from_u128(17_184_849_881).sqrt_floor(),
            BigUint::from_u64(131_090)
        );
    }

    #[test]
    fn sqrt_floor_brackets_root() {
        let mut seed = 0xc0ac_29b7_c97c_50dd;
        for words in [1usize, 2, 3] {
            for _ in 0..4 {
                let value = seeded_biguint(words, &mut seed);
                let root = value.sqrt_floor();
                let next = root.add_ref(&BigUint::one());
                assert!(root.square_ref() <= value);
                assert!(next.square_ref() > value);
            }
        }
    }

    #[test]
    fn mod_mul_matches_small_arithmetic() {
        let a = BigUint::from_u64(123_456_789);
        let b = BigUint::from_u64(987_654_321);
        let m = BigUint::from_u64(1_000_000_007);
        assert_eq!(BigUint::mod_mul(&a, &b, &m), BigUint::from_u64(259_106_859));
    }

    #[test]
    fn montgomery_mod_pow_matches_small_arithmetic() {
        let ctx = MontgomeryCtx::new(&BigUint::from_u64(1_000_000_007))
            .expect("odd modulus builds a context");
        let base = BigUint::from_u64(123_456_789);
        let exponent = BigUint::from_u64(65_537);
        assert_eq!(ctx.pow(&base, &exponent), BigUint::from_u64(560_583_526));
    }

    #[test]
    fn montgomery_ctx_mul_matches_small_arithmetic() {
        let ctx = MontgomeryCtx::new(&BigUint::from_u64(1_000_000_007))
            .expect("odd modulus builds a context");
        let a = BigUint::from_u64(123_456_789);
        let b = BigUint::from_u64(987_654_321);
        assert_eq!(ctx.mul(&a, &b), BigUint::from_u64(259_106_859));
    }

    #[test]
    fn montgomery_paths_agree_with_plain_reduction() {
        let mut seed = 0x1319_8a2e_0370_7344;
        for words in [1usize, 2, 4] {
            for _ in 0..3 {
                let modulus = seeded_odd_modulus(words, &mut seed);
                let ctx = MontgomeryCtx::new(&modulus).expect("odd modulus builds a context");
                // Deliberately wider than the modulus to exercise input reduction.
                let lhs = seeded_biguint(words + 1, &mut seed);
                let rhs = seeded_biguint(words, &mut seed);
                let expected = lhs.mul_ref(&rhs).modulo(&modulus);

                assert_eq!(ctx.mul(&lhs, &rhs), expected);
                assert_eq!(BigUint::mod_mul(&lhs, &rhs, &modulus), expected);
                assert_eq!(BigUint::mod_mul_plain(&lhs, &rhs, &modulus), expected);
                assert_eq!(ctx.square(&lhs), lhs.square_ref().modulo(&modulus));
                assert_eq!(ctx.decode(&ctx.encode(&lhs)), lhs.modulo(&modulus));
            }
        }
    }

    #[test]
    fn montgomery_pow_matches_square_and_multiply() {
        let mut seed = 0x082e_fa98_ec4e_6c89;
        for words in [1usize, 2, 3] {
            let modulus = seeded_odd_modulus(words, &mut seed);
            let ctx = MontgomeryCtx::new(&modulus).expect("odd modulus builds a context");
            let base = seeded_biguint(words, &mut seed);
            let exponent = seeded_biguint(1, &mut seed);

            // Left-to-right reference exponentiation with plain reduction.
            let mut expected = BigUint::one().modulo(&modulus);
            for bit in (0..exponent.bits()).rev() {
                expected = expected.square_ref().modulo(&modulus);
                if exponent.bit(bit) {
                    expected = expected.mul_ref(&base).modulo(&modulus);
                }
            }

            assert_eq!(ctx.pow(&base, &exponent), expected);
            assert_eq!(ctx.pow_encoded(&ctx.encode(&base), &exponent), expected);
        }
    }

    #[test]
    fn mod_mul_even_modulus_uses_fallback_path() {
        let a = BigUint::from_u64(37);
        let b = BigUint::from_u64(19);
        let modulus = BigUint::from_u64(100);
        assert_eq!(BigUint::mod_mul(&a, &b, &modulus), BigUint::from_u64(3));
    }

    #[test]
    fn mod_mul_even_modulus_matches_plain_reduction() {
        let mut seed = 0xa409_3822_299f_31d0;
        for words in [1usize, 2, 3] {
            // Doubling a non-zero value yields a non-zero even modulus.
            let mut modulus = seeded_biguint(words, &mut seed);
            modulus.shl1();
            let lhs = seeded_biguint(words, &mut seed);
            let rhs = seeded_biguint(words, &mut seed);
            assert_eq!(
                BigUint::mod_mul(&lhs, &rhs, &modulus),
                lhs.mul_ref(&rhs).modulo(&modulus)
            );
        }
    }

    #[test]
    fn bigint_sign_normalization() {
        let zero = BigInt::from_parts(Sign::Negative, BigUint::zero());
        assert_eq!(zero.sign(), Sign::Zero);

        let value = BigInt::from_parts(Sign::Positive, BigUint::from_u64(7));
        assert_eq!(value.negated().sign(), Sign::Negative);
        assert_eq!(value.magnitude(), &BigUint::from_u64(7));
    }

    #[test]
    fn bigint_add_sub_and_modulo() {
        let a = BigInt::from_biguint(BigUint::from_u64(10));
        let b = BigInt::from_parts(Sign::Negative, BigUint::from_u64(3));
        assert_eq!(a.add_ref(&b), BigInt::from_biguint(BigUint::from_u64(7)));
        assert_eq!(
            b.sub_ref(&a),
            BigInt::from_parts(Sign::Negative, BigUint::from_u64(13))
        );
        assert_eq!(
            BigInt::from_parts(Sign::Negative, BigUint::from_u64(3))
                .modulo_positive(&BigUint::from_u64(11)),
            BigUint::from_u64(8)
        );
    }

    #[test]
    fn bigint_arithmetic_matches_i64_for_all_sign_combinations() {
        let samples = [-13i64, -7, -1, 0, 1, 7, 13];
        for &a in &samples {
            for &b in &samples {
                let (lhs, rhs) = (bigint_from_i64(a), bigint_from_i64(b));
                assert_eq!(lhs.add_ref(&rhs), bigint_from_i64(a + b), "{a} + {b}");
                assert_eq!(lhs.sub_ref(&rhs), bigint_from_i64(a - b), "{a} - {b}");
            }
            assert_eq!(
                bigint_from_i64(a).mul_biguint_ref(&BigUint::from_u64(3)),
                bigint_from_i64(a * 3)
            );
            assert_eq!(
                bigint_from_i64(a).modulo_positive(&BigUint::from_u64(5)),
                BigUint::from_u64(a.rem_euclid(5).unsigned_abs())
            );
        }
    }
}