1extern crate alloc;
5
6use core::cmp::Ordering;
7use core::ops::{
8 Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign,
9};
10
11use alloc::vec::Vec;
12
/// Classification of a remainder (for example the low bits below a cut
/// point) relative to one half of a unit at that cut point.
#[derive(Debug, Clone, Copy)]
pub(crate) enum LossFraction {
    /// The remainder is zero.
    ExactlyZero,
    /// The remainder is non-zero but below one half.
    LessThanHalf,
    /// The remainder is exactly one half.
    ExactlyHalf,
    /// The remainder is above one half.
    MoreThanHalf,
}

impl LossFraction {
    /// True only for `ExactlyZero`.
    pub fn is_exactly_zero(&self) -> bool {
        matches!(*self, LossFraction::ExactlyZero)
    }

    /// True for `LessThanHalf` and also for `ExactlyZero` (zero is below half).
    pub fn is_lt_half(&self) -> bool {
        matches!(*self, LossFraction::LessThanHalf | LossFraction::ExactlyZero)
    }

    /// True only for `ExactlyHalf`.
    pub fn is_exactly_half(&self) -> bool {
        matches!(*self, LossFraction::ExactlyHalf)
    }

    /// True only for `MoreThanHalf`.
    pub fn is_mt_half(&self) -> bool {
        matches!(*self, LossFraction::MoreThanHalf)
    }

    /// True for every variant except `MoreThanHalf`.
    #[allow(dead_code)]
    pub fn is_lte_half(&self) -> bool {
        !self.is_mt_half()
    }

    /// True for `ExactlyHalf` and `MoreThanHalf`.
    pub fn is_gte_half(&self) -> bool {
        matches!(*self, LossFraction::ExactlyHalf | LossFraction::MoreThanHalf)
    }

    /// Mirrors the classification around one half: below-half and
    /// above-half swap, while zero and exactly-half are returned unchanged.
    pub fn invert(&self) -> LossFraction {
        match *self {
            LossFraction::LessThanHalf => LossFraction::MoreThanHalf,
            LossFraction::MoreThanHalf => LossFraction::LessThanHalf,
            LossFraction::ExactlyZero => LossFraction::ExactlyZero,
            LossFraction::ExactlyHalf => LossFraction::ExactlyHalf,
        }
    }
}
/// Arbitrary-precision unsigned integer made of 64-bit words.
///
/// Words are stored least significant first; the top words may be zero.
#[derive(Debug, Clone)]
pub struct BigInt {
    // Least-significant word at index 0.
    parts: Vec<u64>,
}
77
78impl BigInt {
79 pub fn zero() -> Self {
81 BigInt::from_u64(0)
82 }
83
84 pub fn one() -> Self {
86 Self::from_u64(1)
87 }
88
89 pub fn one_hot(bit: usize) -> Self {
91 let mut x = Self::zero();
92 x.flip_bit(bit);
93 x
94 }
95
96 pub fn all1s(bits: usize) -> Self {
98 if bits == 0 {
99 return Self::zero();
100 }
101 let mut x = Self::one();
102 x.shift_left(bits);
103 let _ = x.inplace_sub(&Self::one());
104 debug_assert_eq!(x.msb_index(), bits);
105 x
106 }
107
108 pub fn from_u64(val: u64) -> Self {
110 let vec = Vec::from([val]);
111 BigInt { parts: vec }
112 }
113
114 pub fn from_u128(val: u128) -> Self {
116 let a = val as u64;
117 let b = (val >> 64) as u64;
118 let vec = Vec::from([a, b]);
119 BigInt { parts: vec }
120 }
121
122 pub fn pseudorandom(parts: usize, seed: u32) -> Self {
125 use crate::utils::Lfsr;
126 let mut ll = Lfsr::new_with_seed(seed);
127
128 BigInt::from_iter(&mut ll, parts)
129 }
130
131 pub fn len(&self) -> usize {
132 self.parts.len()
133 }
134
135 pub fn is_empty(&self) -> bool {
136 self.parts.is_empty()
137 }
138
139 pub fn as_u64(&self) -> u64 {
141 for i in 1..self.len() {
142 debug_assert_eq!(self.parts[i], 0);
143 }
144 self.parts[0]
145 }
146
147 pub fn as_u128(&self) -> u128 {
149 if self.len() >= 2 {
150 for i in 2..self.len() {
151 debug_assert_eq!(self.parts[i], 0);
152 }
153 (self.parts[0] as u128) + ((self.parts[1] as u128) << 64)
154 } else {
155 self.parts[0] as u128
156 }
157 }
158
159 pub fn is_zero(&self) -> bool {
161 for elem in self.parts.iter() {
162 if *elem != 0 {
163 return false;
164 }
165 }
166 true
167 }
168
169 pub fn is_even(&self) -> bool {
171 (self.parts[0] & 0x1) == 0
172 }
173
174 pub fn is_odd(&self) -> bool {
176 (self.parts[0] & 0x1) == 1
177 }
178
179 pub fn flip_bit(&mut self, bit_num: usize) {
181 let which_word = bit_num / u64::BITS as usize;
182 let bit_in_word = bit_num % u64::BITS as usize;
183 self.grow(which_word + 1);
184 debug_assert!(which_word < self.len(), "Bit out of bounds");
185 self.parts[which_word] ^= 1 << bit_in_word;
186 }
187
188 pub fn mask(&mut self, bits: usize) {
190 let mut bits = bits;
191 for i in 0..self.len() {
192 if bits >= 64 {
193 bits -= 64;
194 continue;
195 }
196
197 if bits == 0 {
198 self.parts[i] = 0;
199 continue;
200 }
201
202 let mask = (1u64 << bits) - 1;
203 self.parts[i] &= mask;
204 bits = 0;
205 }
206 }
207
208 pub(crate) fn get_loss_kind_for_bit(&self, bit: usize) -> LossFraction {
210 if self.is_zero() {
211 return LossFraction::ExactlyZero;
212 }
213 if bit > self.len() * 64 {
214 return LossFraction::LessThanHalf;
215 }
216 let mut a = self.clone();
217 a.mask(bit);
218 if a.is_zero() {
219 return LossFraction::ExactlyZero;
220 }
221 let half = Self::one_hot(bit - 1);
222 match a.cmp(&half) {
223 Ordering::Less => LossFraction::LessThanHalf,
224 Ordering::Equal => LossFraction::ExactlyHalf,
225 Ordering::Greater => LossFraction::MoreThanHalf,
226 }
227 }
228
229 pub fn msb_index(&self) -> usize {
233 for i in (0..self.len()).rev() {
234 let part = self.parts[i];
235 if part != 0 {
236 let idx = 64 - part.leading_zeros() as usize;
237 return i * 64 + idx;
238 }
239 }
240 0
241 }
242
243 pub fn trailing_zeros(&self) -> usize {
246 debug_assert!(!self.is_zero());
247 for i in 0..self.len() {
248 let part = self.parts[i];
249 if part != 0 {
250 let idx = part.trailing_zeros() as usize;
251 return i * 64 + idx;
252 }
253 }
254 panic!("Expected a non-zero number");
255 }
256
257 pub fn from_parts(parts: &[u64]) -> Self {
259 let parts: Vec<u64> = parts.to_vec();
260 BigInt { parts }
261 }
262
263 pub fn from_iter<I: Iterator<Item = u64>>(iter: &mut I, k: usize) -> Self {
266 let parts: Vec<u64> = iter.take(k).collect();
267 BigInt { parts }
268 }
269
270 pub fn grow(&mut self, size: usize) {
272 for _ in self.len()..size {
273 self.parts.push(0);
274 }
275 }
276
277 fn shrink(&mut self) {
279 while self.len() > 2 && self.parts[self.len() - 1] == 0 {
280 self.parts.pop();
281 }
282 }
283
284 pub fn inplace_add(&mut self, rhs: &Self) {
286 self.inplace_add_slice(&rhs.parts[..]);
287 }
288
289 #[allow(clippy::needless_range_loop)]
291 pub(crate) fn inplace_add_slice(&mut self, rhs: &[u64]) {
292 self.grow(rhs.len());
293 let mut carry: bool = false;
294 for i in 0..rhs.len() {
295 let first = self.parts[i].overflowing_add(rhs[i]);
296 let second = first.0.overflowing_add(carry as u64);
297 carry = first.1 || second.1;
298 self.parts[i] = second.0;
299 }
300 for i in rhs.len()..self.len() {
302 let second = self.parts[i].overflowing_add(carry as u64);
303 carry = second.1;
304 self.parts[i] = second.0;
305 }
306 if carry {
307 self.parts.push(1);
308 }
309 self.shrink()
310 }
311
312 #[must_use]
314 pub fn inplace_sub(&mut self, rhs: &Self) -> bool {
315 self.inplace_sub_slice(&rhs.parts[..], 0)
316 }
317
318 #[allow(clippy::needless_range_loop)]
323 fn inplace_sub_slice(&mut self, rhs: &[u64], bottom_zeros: usize) -> bool {
324 self.grow(rhs.len());
325 let mut borrow: bool = false;
326 for i in bottom_zeros..rhs.len() {
329 let first = self.parts[i].overflowing_sub(rhs[i]);
330 let second = first.0.overflowing_sub(borrow as u64);
331 borrow = first.1 || second.1;
332 self.parts[i] = second.0;
333 }
334 for i in rhs.len()..self.len() {
336 let second = self.parts[i].overflowing_sub(borrow as u64);
337 self.parts[i] = second.0;
338 borrow = second.1;
339 }
340 self.shrink();
341 borrow
342 }
343
344 fn zeros(size: usize) -> Vec<u64> {
345 core::iter::repeat(0).take(size).collect()
346 }
347
348 pub fn inplace_mul(&mut self, rhs: &Self) {
350 if self.len() > KARATSUBA_SIZE_THRESHOLD
351 || rhs.len() > KARATSUBA_SIZE_THRESHOLD
352 {
353 *self = Self::mul_karatsuba(self, rhs);
354 return;
355 }
356 self.inplace_mul_slice(rhs);
357 }
358
359 fn inplace_mul_slice(&mut self, rhs: &[u64]) {
361 let size = self.len() + rhs.len() + 1;
362 let mut parts = Self::zeros(size);
363 let mut carries = Self::zeros(size);
364
365 for i in 0..self.len() {
366 for j in 0..rhs.len() {
367 let pi = self.parts[i] as u128;
368 let pij = pi * rhs[j] as u128;
369
370 let add0 = parts[i + j].overflowing_add(pij as u64);
371 parts[i + j] = add0.0;
372 carries[i + j] += add0.1 as u64;
373 let add1 = parts[i + j + 1].overflowing_add((pij >> 64) as u64);
374 parts[i + j + 1] = add1.0;
375 carries[i + j + 1] += add1.1 as u64;
376 }
377 }
378 self.grow(size);
379 let mut carry: u64 = 0;
380 for i in 0..size {
381 let add0 = parts[i].overflowing_add(carry);
382 self.parts[i] = add0.0;
383 carry = add0.1 as u64 + carries[i];
384 }
385 self.shrink();
386 assert!(carry == 0);
387 }
388
389 pub fn inplace_div(&mut self, divisor: &Self) -> Self {
391 let mut dividend = self.clone();
392 let mut divisor = divisor.clone();
393 let mut quotient = Self::zero();
394
395 if self.len() == 1 && divisor.parts.len() == 1 {
397 let a = dividend.get_part(0);
398 let b = divisor.get_part(0);
399 let res = a / b;
400 let rem = a % b;
401 self.parts[0] = res;
402 return Self::from_u64(rem);
403 }
404
405 let dividend_msb = dividend.msb_index();
406 let divisor_msb = divisor.msb_index();
407 assert_ne!(divisor_msb, 0, "division by zero");
408
409 if divisor_msb > dividend_msb {
410 let ret = self.clone();
411 *self = Self::zero();
412 return ret;
413 }
414
415 let bits = dividend_msb - divisor_msb;
418 divisor.shift_left(bits);
419
420 for i in (0..bits + 1).rev() {
422 let low_zeros = i / 64;
424
425 if dividend >= divisor {
426 let overflow = dividend.inplace_sub_slice(&divisor, low_zeros);
427 debug_assert!(!overflow);
428 quotient.flip_bit(i);
429 }
430 divisor.shift_right(1);
431 }
432
433 *self = quotient;
434 self.shrink();
435 dividend
436 }
437
438 pub fn shift_left(&mut self, bits: usize) {
440 let words_to_shift = bits / u64::BITS as usize;
441 let bits_in_word = bits % u64::BITS as usize;
442
443 for _ in 0..words_to_shift + 1 {
444 self.parts.push(0);
445 }
446
447 if bits_in_word == 0 {
449 for i in (0..self.len()).rev() {
450 self.parts[i] = if i >= words_to_shift {
451 self.parts[i - words_to_shift]
452 } else {
453 0
454 };
455 }
456 return;
457 }
458
459 for i in (0..self.len()).rev() {
460 let left_val = if i >= words_to_shift {
461 self.parts[i - words_to_shift]
462 } else {
463 0
464 };
465 let right_val = if i > words_to_shift {
466 self.parts[i - words_to_shift - 1]
467 } else {
468 0
469 };
470 let right = right_val >> (u64::BITS as usize - bits_in_word);
471 let left = left_val << bits_in_word;
472 self.parts[i] = left | right;
473 }
474 }
475
476 pub fn shift_right(&mut self, bits: usize) {
478 let words_to_shift = bits / u64::BITS as usize;
479 let bits_in_word = bits % u64::BITS as usize;
480
481 if bits_in_word == 0 {
483 for i in 0..self.len() {
484 self.parts[i] = if i + words_to_shift < self.len() {
485 self.parts[i + words_to_shift]
486 } else {
487 0
488 };
489 }
490 self.shrink();
491 return;
492 }
493
494 for i in 0..self.len() {
495 let left_val = if i + words_to_shift < self.len() {
496 self.parts[i + words_to_shift]
497 } else {
498 0
499 };
500 let right_val = if i + 1 + words_to_shift < self.len() {
501 self.parts[i + 1 + words_to_shift]
502 } else {
503 0
504 };
505 let right = right_val << (u64::BITS as usize - bits_in_word);
506 let left = left_val >> bits_in_word;
507 self.parts[i] = left | right;
508 }
509 self.shrink();
510 }
511
512 pub fn powi(&self, mut exp: u64) -> Self {
514 let mut v = Self::one();
515 let mut base = self.clone();
516 loop {
517 if exp & 0x1 == 1 {
518 v.inplace_mul(&base);
519 }
520 exp >>= 1;
521 if exp == 0 {
522 break;
523 }
524 base.inplace_mul(&base.clone());
525 }
526 v
527 }
528
529 pub fn get_part(&self, idx: usize) -> u64 {
531 self.parts[idx]
532 }
533
534 #[cfg(feature = "std")]
535 pub fn dump(&self) {
536 use std::println;
537 println!("[{}]", self.as_binary());
538 }
539
540 #[cfg(not(feature = "std"))]
541 pub fn dump(&self) {
542 }
544}
545
546impl Default for BigInt {
547 fn default() -> Self {
548 Self::zero()
549 }
550}
551
#[test]
fn test_powi5() {
    // Small powers of five, checked against a precomputed table.
    let expected_powers: [u64; 8] = [1, 5, 25, 125, 625, 3125, 15625, 78125];
    let five = BigInt::from_u64(5);
    for (exp, &want) in expected_powers.iter().enumerate() {
        assert_eq!(five.powi(exp as u64).as_u64(), want);
    }

    // Larger results that still fit in one word.
    assert_eq!(BigInt::from_u64(15).powi(16).as_u64(), 6568408355712890625);
    assert_eq!(BigInt::from_u64(3).powi(21).as_u64(), 10460353203);
}
568
#[test]
fn test_shl() {
    let mut value = BigInt::from_u64(0xff00ff);
    assert_eq!(value.get_part(0), 0xff00ff);

    // (shift amount, word to inspect, expected word value)
    let steps: [(usize, usize, u64); 3] = [
        (17, 0, 0x1fe01fe0000),
        (17, 0, 0x3fc03fc00000000),
        (64, 1, 0x3fc03fc00000000),
    ];
    for &(shift, word, want) in steps.iter() {
        value.shift_left(shift);
        assert_eq!(value.get_part(word), want);
    }
}
580
#[test]
fn test_shr() {
    let mut value = BigInt::from_u64(0xff00ff);
    value.shift_left(128);
    assert_eq!(value.get_part(2), 0xff00ff);

    // (shift amount, word to inspect, expected word value)
    let steps: [(usize, usize, u64); 3] = [
        (17, 1, 0x807f800000000000),
        (17, 1, 0x03fc03fc0000000),
        (64, 0, 0x03fc03fc0000000),
    ];
    for &(shift, word, want) in steps.iter() {
        value.shift_right(shift);
        assert_eq!(value.get_part(word), want);
    }
}
593
#[test]
fn test_mul_basic() {
    let all_ones = BigInt::from_u64(0xffff_ffff_ffff_ffff);
    let mut product = all_ones.clone();
    // (2^64 - 1)^2, then multiplied by 25.
    product.inplace_mul(&all_ones);
    product.inplace_mul(&BigInt::from_u64(25));

    let expected: [u64; 3] = [0x19, 0xffff_ffff_ffff_ffce, 0x18];
    for (idx, &want) in expected.iter().enumerate() {
        assert_eq!(product.get_part(idx), want);
    }
}
604
#[test]
fn test_add_basic() {
    let mut acc = BigInt::from_u64(0xffffffff00000000);

    // Fill in the low half without any carry.
    acc.inplace_add(&BigInt::from_u64(0xffffffff));
    assert_eq!(acc.get_part(0), 0xffffffffffffffff);

    // Overflow the low word so the carry lands in a new word.
    acc.inplace_add(&BigInt::from_u64(0xf));
    assert_eq!(acc.get_part(0), 0xe);
    assert_eq!(acc.get_part(1), 0x1);
}
616
#[test]
fn test_div_basic() {
    let seven = BigInt::from_u64(7);
    // (dividend, expected quotient, expected remainder)
    let cases: [(u64, u64, u64); 2] = [(49, 7, 0), (703, 100, 3)];
    for &(dividend, quotient, remainder) in cases.iter() {
        let mut value = BigInt::from_u64(dividend);
        let rem = value.inplace_div(&seven);
        assert_eq!(value.as_u64(), quotient);
        assert_eq!(rem.as_u64(), remainder);
    }
}
631
#[test]
fn test_div_10() {
    let mut value = BigInt::from_u64(19940521);
    let ten = BigInt::from_u64(10);
    // Each division by ten yields the next decimal digit, least significant first.
    let expected_digits: [u64; 5] = [1, 2, 5, 0, 4];
    for &digit in expected_digits.iter() {
        assert_eq!(value.inplace_div(&ten).as_u64(), digit);
    }
}
642
643#[allow(dead_code)]
644fn test_with_random_values(
645 correct: fn(u128, u128) -> (u128, bool),
646 test: fn(u128, u128) -> (u128, bool),
647) {
648 use super::utils::Lfsr;
649
650 let mut lfsr = Lfsr::new();
652
653 for _ in 0..50000 {
654 let v0 = lfsr.get64();
655 let v1 = lfsr.get64();
656 let v2 = lfsr.get64();
657 let v3 = lfsr.get64();
658
659 let n1 = (v0 as u128) + ((v1 as u128) << 64);
660 let n2 = (v2 as u128) + ((v3 as u128) << 64);
661
662 let v1 = correct(n1, n2);
663 let v2 = test(n1, n2);
664 assert_eq!(v1.0, v2.0, "Incorrect value");
665 assert_eq!(v1.0, v2.0, "Incorrect carry");
666 }
667}
668
#[test]
fn test_sub_basic() {
    // Subtracts `rhs` from `lhs`, expects no borrow, and checks the leading
    // words of the result against `want`.
    fn check(lhs: &[u64], rhs: &[u64], want: &[u64]) {
        let mut value = BigInt::from_parts(lhs);
        let borrowed = value.inplace_sub(&BigInt::from_parts(rhs));
        assert!(!borrowed);
        for (idx, &word) in want.iter().enumerate() {
            assert_eq!(value.get_part(idx), word);
        }
    }

    // Borrow propagates from word 1 into word 0.
    check(&[0x0, 0x1, 0], &[0x1], &[0xffffffffffffffff, 0]);
    // The subtrahend is longer than the minuend (its top word is zero).
    check(&[0x1, 0x1], &[0x0, 0x1, 0x0], &[0x1, 0]);
    // The top word of the minuend is unaffected.
    check(&[0x1, 0x1, 0x1], &[0x0, 0x1, 0x0], &[0x1, 0, 0x1]);
}
694
#[test]
fn test_mask_basic() {
    let mut value = BigInt::from_parts(&[0b11111, 0b10101010101010, 0b111]);
    value.mask(69);
    // Word 0 survives whole, word 1 keeps its low 5 bits, word 2 is cleared.
    let expected: [u64; 3] = [0b11111, 0b01010, 0b0];
    for (idx, &want) in expected.iter().enumerate() {
        assert_eq!(value.get_part(idx), want);
    }
}
703
/// Compares BigInt add/sub/mul/div/cmp against native `u128` arithmetic,
/// including the carry/borrow flag.
#[test]
fn test_basic_operations() {
    fn correct_sub(a: u128, b: u128) -> (u128, bool) {
        a.overflowing_sub(b)
    }
    fn correct_add(a: u128, b: u128) -> (u128, bool) {
        a.overflowing_add(b)
    }
    fn correct_mul(a: u128, b: u128) -> (u128, bool) {
        a.overflowing_mul(b)
    }
    fn correct_div(a: u128, b: u128) -> (u128, bool) {
        a.overflowing_div(b)
    }

    fn test_sub(a: u128, b: u128) -> (u128, bool) {
        let mut a = BigInt::from_u128(a);
        let b = BigInt::from_u128(b);
        let c = a.inplace_sub(&b);
        (a.as_u128(), c)
    }
    fn test_add(a: u128, b: u128) -> (u128, bool) {
        let mut a = BigInt::from_u128(a);
        let b = BigInt::from_u128(b);
        a.inplace_add(&b);
        // `shrink` keeps only non-zero words above the second, so any extra
        // word means the sum overflowed 128 bits.
        let carry = a.len() > 2;
        a.parts.truncate(2);
        (a.as_u128(), carry)
    }
    fn test_mul(a: u128, b: u128) -> (u128, bool) {
        let mut a = BigInt::from_u128(a);
        let b = BigInt::from_u128(b);
        a.inplace_mul(&b);
        let carry = a.len() > 2;
        // The product may have 3 or 4 words. Truncating avoids indexing a
        // word that might not exist.
        a.parts.truncate(2);
        (a.as_u128(), carry)
    }
    fn test_div(a: u128, b: u128) -> (u128, bool) {
        let mut a = BigInt::from_u128(a);
        let b = BigInt::from_u128(b);
        a.inplace_div(&b);
        (a.as_u128(), false)
    }

    fn correct_cmp(a: u128, b: u128) -> (u128, bool) {
        (
            match a.cmp(&b) {
                Ordering::Less => 1,
                Ordering::Equal => 2,
                Ordering::Greater => 3,
            } as u128,
            false,
        )
    }
    fn test_cmp(a: u128, b: u128) -> (u128, bool) {
        let a = BigInt::from_u128(a);
        let b = BigInt::from_u128(b);

        (
            match a.cmp(&b) {
                Ordering::Less => 1,
                Ordering::Equal => 2,
                Ordering::Greater => 3,
            } as u128,
            false,
        )
    }

    test_with_random_values(correct_mul, test_mul);
    test_with_random_values(correct_div, test_div);
    test_with_random_values(correct_add, test_add);
    test_with_random_values(correct_sub, test_sub);
    test_with_random_values(correct_cmp, test_cmp);
}
788
#[test]
fn test_msb() {
    // (single word, expected 1-based msb index)
    let cases: [(u64, usize); 3] = [(0xffffffff00000000, 64), (0x0, 0), (0x1, 1)];
    for &(word, want) in cases.iter() {
        assert_eq!(BigInt::from_u64(word).msb_index(), want);
    }

    let mut far = BigInt::from_u64(0x1);
    far.shift_left(189);
    assert_eq!(far.msb_index(), 189 + 1);

    for shift in 0..256 {
        let mut single_bit = BigInt::from_u64(0x1);
        single_bit.shift_left(shift);
        assert_eq!(single_bit.msb_index(), shift + 1);
    }
}
810
#[test]
fn test_trailing_zero() {
    // (single word, expected trailing zero count)
    let cases: [(u64, usize); 3] = [(0xffffffff00000000, 32), (0x1, 0), (0x8, 3)];
    for &(word, want) in cases.iter() {
        assert_eq!(BigInt::from_u64(word).trailing_zeros(), want);
    }

    let mut far = BigInt::from_u64(0x1);
    far.shift_left(189);
    assert_eq!(far.trailing_zeros(), 189);

    for shift in 0..256 {
        let mut single_bit = BigInt::from_u64(0x1);
        single_bit.shift_left(shift);
        assert_eq!(single_bit.trailing_zeros(), shift);
    }
}
832impl Eq for BigInt {}
833
834impl PartialEq for BigInt {
835 fn eq(&self, other: &BigInt) -> bool {
836 self.cmp(other).is_eq()
837 }
838}
839impl PartialOrd for BigInt {
840 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
841 Some(self.cmp(other))
842 }
843}
844impl Ord for BigInt {
845 fn cmp(&self, other: &Self) -> Ordering {
846 if self.len() > other.len()
848 && self.parts[other.len()..].iter().any(|&x| x != 0)
849 {
850 return Ordering::Greater;
851 }
852
853 if other.len() > self.len()
855 && other.parts[self.len()..].iter().any(|&x| x != 0)
856 {
857 return Ordering::Less;
858 }
859 let same_len = other.len().min(self.len());
860
861 for i in (0..same_len).rev() {
863 match self.parts[i].cmp(&other.parts[i]) {
864 Ordering::Less => return Ordering::Less,
865 Ordering::Equal => {}
866 Ordering::Greater => return Ordering::Greater,
867 }
868 }
869 Ordering::Equal
870 }
871}
872
873macro_rules! declare_operator {
874 ($trait_name:ident,
875 $func_name:ident,
876 $func_impl_name:ident) => {
877 impl $trait_name for BigInt {
879 type Output = Self;
880
881 fn $func_name(self, rhs: Self) -> Self::Output {
882 self.$func_name(&rhs)
883 }
884 }
885
886 impl $trait_name<&Self> for BigInt {
888 type Output = Self;
889 fn $func_name(self, rhs: &Self) -> Self::Output {
890 let mut n = self;
891 let _ = n.$func_impl_name(rhs);
892 n
893 }
894 }
895
896 impl $trait_name<Self> for &BigInt {
898 type Output = BigInt;
899 fn $func_name(self, rhs: Self) -> Self::Output {
900 let mut n = self.clone();
901 let _ = n.$func_impl_name(rhs);
902 n
903 }
904 }
905
906 impl $trait_name<u64> for BigInt {
908 type Output = Self;
909 fn $func_name(self, rhs: u64) -> Self::Output {
910 let mut n = self;
911 let _ = n.$func_impl_name(&Self::from_u64(rhs));
912 n
913 }
914 }
915 };
916}
917
918declare_operator!(Add, add, inplace_add);
919declare_operator!(Sub, sub, inplace_sub);
920declare_operator!(Mul, mul, inplace_mul);
921declare_operator!(Div, div, inplace_div);
922
923macro_rules! declare_assign_operator {
924 ($trait_name:ident,
925 $func_name:ident,
926 $func_impl_name:ident) => {
927 impl $trait_name for BigInt {
928 fn $func_name(&mut self, rhs: Self) {
929 let _ = self.$func_impl_name(&rhs);
930 }
931 }
932
933 impl $trait_name<&BigInt> for BigInt {
934 fn $func_name(&mut self, rhs: &Self) {
935 let _ = self.$func_impl_name(&rhs);
936 }
937 }
938 };
939}
940
941declare_assign_operator!(AddAssign, add_assign, inplace_add);
942declare_assign_operator!(SubAssign, sub_assign, inplace_sub);
943declare_assign_operator!(MulAssign, mul_assign, inplace_mul);
944declare_assign_operator!(DivAssign, div_assign, inplace_div);
945
#[test]
fn test_bigint_operators() {
    let ten = BigInt::from_u64(10);
    let one = BigInt::from_u64(1);

    // (10 - 1) * 10 / 2, using the borrowed, owned and u64 operator forms.
    let nine = &ten - &one;
    let ninety = nine * ten;
    let half = ninety / 2;
    assert_eq!(half.as_u64(), 45);

    let two = &one + &one;
    assert_eq!(two.as_u64(), 2);
}
954 assert_eq!((&y + &y).as_u64(), 2);
955}
956
957#[test]
958fn test_all1s_ctor() {
959 type BI = BigInt;
960 let v0 = BI::all1s(0);
961 let v1 = BI::all1s(1);
962 let v2 = BI::all1s(5);
963 let v3 = BI::all1s(32);
964
965 assert_eq!(v0.get_part(0), 0b0);
966 assert_eq!(v1.get_part(0), 0b1);
967 assert_eq!(v2.get_part(0), 0b11111);
968 assert_eq!(v3.get_part(0), 0xffffffff);
969}
970
#[test]
fn test_flip_bit() {
    // Flipping the same bit twice restores the original value.
    let mut toggled = BigInt::zero();
    assert_eq!(toggled.get_part(0), 0);
    toggled.flip_bit(0);
    assert_eq!(toggled.get_part(0), 1);
    toggled.flip_bit(0);
    assert_eq!(toggled.get_part(0), 0);

    // A bit in the middle of the first word.
    let mut mid_word = BigInt::zero();
    mid_word.flip_bit(16);
    assert_eq!(mid_word.get_part(0), 65536);

    // A bit in the second word forces the storage to grow.
    let mut high_word = BigInt::zero();
    high_word.flip_bit(95);
    high_word.shift_right(95);
    assert_eq!(high_word.get_part(0), 1);
}
997
#[cfg(feature = "std")]
#[test]
fn test_mul_div_encode_decode() {
    use alloc::vec::Vec;
    const BASE: u64 = 5;

    let radix = BigInt::from_u64(BASE);
    // A deterministic message of base-5 symbols.
    let message: Vec<u64> = (0..275).map(|i| ((i + 6) * 17) % BASE).collect();

    // Encode: treat the message as the digits of one big base-5 number.
    let mut packed = BigInt::from_u64(0);
    for &symbol in message.iter() {
        packed.inplace_mul(&radix);
        packed.inplace_add(&BigInt::from_u64(symbol));
    }

    // Decode: repeated division yields the symbols, last one first.
    for &expected in message.iter().rev() {
        let rem = packed.inplace_div(&radix);
        assert_eq!(expected, rem.as_u64());
    }
}
1029
1030impl BigInt {
1031 fn to_digits_impl<const DIGIT: u8>(
1036 num: &mut BigInt,
1037 num_digits: usize,
1038 output: &mut Vec<u8>,
1039 ) -> usize {
1040 const SPLIT_WORD_THRESHOLD: usize = 5;
1041
1042 let bits_per_digit = (8 - DIGIT.leading_zeros()) as usize;
1044 let digits_per_word = 64 / bits_per_digit;
1045 let digit = DIGIT as u64;
1046
1047 let len = num.len();
1049 if len > SPLIT_WORD_THRESHOLD {
1050 let half = len / 2 - 1;
1051 let k = digits_per_word * half;
1053 let mega_digit = BigInt::from_u64(digit).powi(k as u64);
1055 let mut rem = num.inplace_div(&mega_digit);
1057
1058 let tail = Self::to_digits_impl::<DIGIT>(&mut rem, k, output);
1060 let hd = Self::to_digits_impl::<DIGIT>(num, num_digits - k, output);
1061 debug_assert_eq!(tail, k);
1062 debug_assert_eq!(hd, num_digits - k);
1063 return num_digits;
1064 }
1065
1066 let mut extracted = 0;
1067
1068 let divisor = BigInt::from_u64(digit.pow(digits_per_word as u32));
1070 for _ in 0..(num_digits / digits_per_word) {
1072 let mut rem = num.inplace_div(&divisor);
1074 extracted += digits_per_word;
1076 Self::extract_digits::<DIGIT>(digits_per_word, &mut rem, output);
1077 }
1078
1079 let iters = num_digits % digits_per_word;
1081 Self::extract_digits::<DIGIT>(iters, num, output);
1082 extracted += iters;
1083
1084 extracted
1085 }
1086
1087 fn extract_digits<const DIGIT: u8>(
1089 iter: usize,
1090 num: &mut BigInt,
1091 vec: &mut Vec<u8>,
1092 ) {
1093 let digit = BigInt::from_u64(DIGIT as u64);
1094 for _ in 0..iter {
1095 let d = num.inplace_div(&digit).as_u64();
1096 vec.push(d as u8);
1097 }
1098 }
1099
1100 pub(crate) fn to_digits<const DIGIT: u8>(&self) -> Vec<u8> {
1102 let mut num = self.clone();
1103 num.shrink();
1104
1105 let mut output: Vec<u8> = Vec::new();
1106
1107 while !num.is_zero() {
1108 let len = num.len();
1109 let digits = (len * 64 * 59) / 196;
1112 Self::to_digits_impl::<DIGIT>(&mut num, digits, &mut output);
1113 }
1114
1115 while output.len() > 1 && output[output.len() - 1] == 0 {
1118 output.pop();
1119 }
1120 output.reverse();
1121 output
1122 }
1123}
1124
#[test]
pub fn test_bigint_to_digits() {
    use alloc::string::String;
    use core::primitive::char;

    // Renders a digit vector as text in the given base.
    fn render(digits: Vec<u8>, base: u32) -> String {
        digits
            .into_iter()
            .map(|d| char::from_digit(d as u32, base).unwrap())
            .collect()
    }

    let mut binary = BigInt::from_u64(0b111000111000101010);
    binary.shift_left(64);
    assert_eq!(
        render(binary.to_digits::<2>(), 2),
        "1110001110001010100000000000000\
        0000000000000000000000000000000\
        00000000000000000000"
    );

    let small = BigInt::from_u64(90210);
    assert_eq!(render(small.to_digits::<10>(), 10), "90210");

    let wide = BigInt::from_u128(123_456_123_456_987_654_987_654u128);
    assert_eq!(render(wide.to_digits::<10>(), 10), "123456123456987654987654");
}
1159
1160const KARATSUBA_SIZE_THRESHOLD: usize = 64;
1164
1165impl BigInt {
1166 fn mul_karatsuba(lhs: &[u64], rhs: &[u64]) -> BigInt {
1167 if lhs.len().min(rhs.len()) < KARATSUBA_SIZE_THRESHOLD {
1172 if lhs.is_empty() || rhs.is_empty() {
1174 return BigInt::zero();
1175 }
1176 let mut lhs = BigInt::from_parts(lhs);
1177 lhs.inplace_mul_slice(rhs);
1178 return lhs;
1179 }
1180
1181 let mid = lhs.len().max(rhs.len()) / 2;
1184 let a = &lhs[0..mid.min(lhs.len())];
1185 let b = &lhs[mid.min(lhs.len())..];
1186 let c = &rhs[0..mid.min(rhs.len())];
1187 let d = &rhs[mid.min(rhs.len())..];
1188
1189 let ac = Self::mul_karatsuba(a, c);
1191 let mut bd = Self::mul_karatsuba(b, d);
1192
1193 let mut a_b = BigInt::from_parts(a);
1195 a_b.inplace_add_slice(b);
1196 let mut c_d = BigInt::from_parts(c);
1197 c_d.inplace_add_slice(d);
1198
1199 let mut ad_plus_bc = Self::mul_karatsuba(&a_b, &c_d);
1200
1201 ad_plus_bc.inplace_sub_slice(&ac, 0);
1203 ad_plus_bc.inplace_sub_slice(&bd, 0);
1204
1205 bd.shift_left(64 * mid * 2);
1207 ad_plus_bc.shift_left(64 * mid);
1208 bd.inplace_add(&ad_plus_bc);
1209 bd.inplace_add(&ac);
1210 bd
1211 }
1212}
1213
#[test]
fn test_mul_karatsuba() {
    use crate::utils::Lfsr;
    let mut rng = Lfsr::new();

    // Multiplies random `l`-word and `r`-word numbers both ways and compares
    // Karatsuba with schoolbook.
    fn check(l: usize, r: usize, rng: &mut Lfsr) {
        let mut schoolbook = BigInt::from_iter(rng, l);
        let other = BigInt::from_iter(rng, r);
        let karatsuba = BigInt::mul_karatsuba(&schoolbook, &other);
        schoolbook.inplace_mul_slice(&other);
        assert_eq!(karatsuba, schoolbook);
    }

    let fixed_sizes: [(usize, usize); 6] = [
        (1, 1),
        (100, 1),
        (1, 100),
        (100, 100),
        (1000, 1000),
        (1000, 1001),
    ];
    for &(l, r) in fixed_sizes.iter() {
        check(l, r, &mut rng);
    }

    // Sweep sizes around the Karatsuba threshold.
    for l in 64..90 {
        for r in 1..128 {
            check(l, r, &mut rng);
        }
    }
}
1243
1244use core::ops::Deref;
1245
1246impl Deref for BigInt {
1247 type Target = [u64];
1248
1249 fn deref(&self) -> &Self::Target {
1250 &self.parts[..]
1251 }
1252}