arpfloat/
bigint.rs

1//! This module contains the implementation of the big-int data structure that
2//! we use for the significand of the float.
3
4extern crate alloc;
5
6use core::cmp::Ordering;
7use core::ops::{
8    Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign,
9};
10
11use alloc::vec::Vec;
12
/// Reports the kind of value that is lost when bits are truncated (for example
/// when shifting right), relative to one half of the lowest kept unit. In some
/// contexts this is used as the two guard bits.
#[derive(Debug, Clone, Copy)]
pub(crate) enum LossFraction {
    /// All discarded bits are zero (`0000000`).
    ExactlyZero,
    /// The discarded bits are non-zero and below one half (`0xxxxxx`).
    LessThanHalf,
    /// The discarded bits are exactly one half (`1000000`).
    ExactlyHalf,
    /// The discarded bits are above one half (`1xxxxxx`).
    MoreThanHalf,
}

impl LossFraction {
    /// Returns true if nothing was lost.
    pub fn is_exactly_zero(&self) -> bool {
        matches!(self, Self::ExactlyZero)
    }

    /// Returns true if the loss is strictly below one half (zero included).
    pub fn is_lt_half(&self) -> bool {
        matches!(self, Self::ExactlyZero | Self::LessThanHalf)
    }

    /// Returns true if the loss is exactly one half.
    pub fn is_exactly_half(&self) -> bool {
        matches!(self, Self::ExactlyHalf)
    }

    /// Returns true if the loss is strictly above one half.
    pub fn is_mt_half(&self) -> bool {
        matches!(self, Self::MoreThanHalf)
    }

    /// Returns true if the loss is at most one half (zero included).
    // Kept for API symmetry even though nothing in this module calls it.
    #[allow(dead_code)]
    pub fn is_lte_half(&self) -> bool {
        self.is_lt_half() || self.is_exactly_half()
    }

    /// Returns true if the loss is at least one half.
    pub fn is_gte_half(&self) -> bool {
        self.is_mt_half() || self.is_exactly_half()
    }

    /// Returns the inverted loss fraction: `LessThanHalf` and `MoreThanHalf`
    /// swap, while `ExactlyZero` and `ExactlyHalf` map to themselves.
    pub fn invert(&self) -> LossFraction {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::LessThanHalf => Self::MoreThanHalf,
            Self::MoreThanHalf => Self::LessThanHalf,
            Self::ExactlyZero => Self::ExactlyZero,
            Self::ExactlyHalf => Self::ExactlyHalf,
        }
    }
}
/// An arbitrary-size unsigned integer, used to hold the mantissa of the
/// floating point number. The digits live in a heap-allocated `Vec<u64>`, and
/// the type provides the basic arithmetic operations (add, sub, mul, div, ...).
///
/// # Examples
///
/// ```
///    use arpfloat::BigInt;
///
///    let x = BigInt::from_u64(1995);
///    let y = BigInt::from_u64(90210);
///
///    let z = x * y;
///    let z = z.powi(10);
///
///    // Prints: 3564312949426686000....
///    println!("{}", z.as_decimal());
/// ```
///
#[derive(Debug, Clone)]
pub struct BigInt {
    /// The 64-bit words of the number, least significant word first.
    parts: Vec<u64>,
}
77
78impl BigInt {
79    /// Create a new zero big int number.
80    pub fn zero() -> Self {
81        BigInt::from_u64(0)
82    }
83
84    /// Create a new number with the value 1.
85    pub fn one() -> Self {
86        Self::from_u64(1)
87    }
88
89    /// Create a new number with a single '1' set at bit `bit`.
90    pub fn one_hot(bit: usize) -> Self {
91        let mut x = Self::zero();
92        x.flip_bit(bit);
93        x
94    }
95
96    /// Create a new number, where the first `bits` bits are set to 1.
97    pub fn all1s(bits: usize) -> Self {
98        if bits == 0 {
99            return Self::zero();
100        }
101        let mut x = Self::one();
102        x.shift_left(bits);
103        let _ = x.inplace_sub(&Self::one());
104        debug_assert_eq!(x.msb_index(), bits);
105        x
106    }
107
108    /// Create a number and set the lowest 64 bits to `val`.
109    pub fn from_u64(val: u64) -> Self {
110        let vec = Vec::from([val]);
111        BigInt { parts: vec }
112    }
113
114    /// Create a number and set the lowest 128 bits to `val`.
115    pub fn from_u128(val: u128) -> Self {
116        let a = val as u64;
117        let b = (val >> 64) as u64;
118        let vec = Vec::from([a, b]);
119        BigInt { parts: vec }
120    }
121
122    /// Create a pseudorandom number with `parts` number of parts in the word.
123    /// The random number generator is initialized with `seed`.
124    pub fn pseudorandom(parts: usize, seed: u32) -> Self {
125        use crate::utils::Lfsr;
126        let mut ll = Lfsr::new_with_seed(seed);
127
128        BigInt::from_iter(&mut ll, parts)
129    }
130
131    pub fn len(&self) -> usize {
132        self.parts.len()
133    }
134
135    pub fn is_empty(&self) -> bool {
136        self.parts.is_empty()
137    }
138
139    /// Returns the lowest 64 bits.
140    pub fn as_u64(&self) -> u64 {
141        for i in 1..self.len() {
142            debug_assert_eq!(self.parts[i], 0);
143        }
144        self.parts[0]
145    }
146
147    /// Returns the lowest 64 bits.
148    pub fn as_u128(&self) -> u128 {
149        if self.len() >= 2 {
150            for i in 2..self.len() {
151                debug_assert_eq!(self.parts[i], 0);
152            }
153            (self.parts[0] as u128) + ((self.parts[1] as u128) << 64)
154        } else {
155            self.parts[0] as u128
156        }
157    }
158
159    /// Return true if the number is equal to zero.
160    pub fn is_zero(&self) -> bool {
161        for elem in self.parts.iter() {
162            if *elem != 0 {
163                return false;
164            }
165        }
166        true
167    }
168
169    /// Returns true if this number is even.
170    pub fn is_even(&self) -> bool {
171        (self.parts[0] & 0x1) == 0
172    }
173
174    /// Returns true if this number is odd.
175    pub fn is_odd(&self) -> bool {
176        (self.parts[0] & 0x1) == 1
177    }
178
179    /// Flip the `bit_num` bit.
180    pub fn flip_bit(&mut self, bit_num: usize) {
181        let which_word = bit_num / u64::BITS as usize;
182        let bit_in_word = bit_num % u64::BITS as usize;
183        self.grow(which_word + 1);
184        debug_assert!(which_word < self.len(), "Bit out of bounds");
185        self.parts[which_word] ^= 1 << bit_in_word;
186    }
187
188    /// Zero out all of the bits above `bits`.
189    pub fn mask(&mut self, bits: usize) {
190        let mut bits = bits;
191        for i in 0..self.len() {
192            if bits >= 64 {
193                bits -= 64;
194                continue;
195            }
196
197            if bits == 0 {
198                self.parts[i] = 0;
199                continue;
200            }
201
202            let mask = (1u64 << bits) - 1;
203            self.parts[i] &= mask;
204            bits = 0;
205        }
206    }
207
208    /// Returns the fractional part that's lost during truncation at `bit`.
209    pub(crate) fn get_loss_kind_for_bit(&self, bit: usize) -> LossFraction {
210        if self.is_zero() {
211            return LossFraction::ExactlyZero;
212        }
213        if bit > self.len() * 64 {
214            return LossFraction::LessThanHalf;
215        }
216        let mut a = self.clone();
217        a.mask(bit);
218        if a.is_zero() {
219            return LossFraction::ExactlyZero;
220        }
221        let half = Self::one_hot(bit - 1);
222        match a.cmp(&half) {
223            Ordering::Less => LossFraction::LessThanHalf,
224            Ordering::Equal => LossFraction::ExactlyHalf,
225            Ordering::Greater => LossFraction::MoreThanHalf,
226        }
227    }
228
229    /// Returns the index of the most significant bit (the highest '1'),
230    /// using 1-based counting (the first bit is 1, and zero means no bits are
231    /// set).
232    pub fn msb_index(&self) -> usize {
233        for i in (0..self.len()).rev() {
234            let part = self.parts[i];
235            if part != 0 {
236                let idx = 64 - part.leading_zeros() as usize;
237                return i * 64 + idx;
238            }
239        }
240        0
241    }
242
243    /// Returns the index of the first '1' in the number. The number must not
244    ///  be a zero.
245    pub fn trailing_zeros(&self) -> usize {
246        debug_assert!(!self.is_zero());
247        for i in 0..self.len() {
248            let part = self.parts[i];
249            if part != 0 {
250                let idx = part.trailing_zeros() as usize;
251                return i * 64 + idx;
252            }
253        }
254        panic!("Expected a non-zero number");
255    }
256
257    // Construct a bigint from the words in 'parts'.
258    pub fn from_parts(parts: &[u64]) -> Self {
259        let parts: Vec<u64> = parts.to_vec();
260        BigInt { parts }
261    }
262
263    // Construct a bigint from an iterator that generates u64 parts.
264    // Take the first 'k' words.
265    pub fn from_iter<I: Iterator<Item = u64>>(iter: &mut I, k: usize) -> Self {
266        let parts: Vec<u64> = iter.take(k).collect();
267        BigInt { parts }
268    }
269
270    /// Ensure that there are at least 'size' words in the bigint.
271    pub fn grow(&mut self, size: usize) {
272        for _ in self.len()..size {
273            self.parts.push(0);
274        }
275    }
276
277    /// Remove the leading zero words from the bigint.
278    fn shrink(&mut self) {
279        while self.len() > 2 && self.parts[self.len() - 1] == 0 {
280            self.parts.pop();
281        }
282    }
283
284    /// Add `rhs` to this number.
285    pub fn inplace_add(&mut self, rhs: &Self) {
286        self.inplace_add_slice(&rhs.parts[..]);
287    }
288
289    /// Implements addition of the 'rhs' sequence of words to this number.
290    #[allow(clippy::needless_range_loop)]
291    pub(crate) fn inplace_add_slice(&mut self, rhs: &[u64]) {
292        self.grow(rhs.len());
293        let mut carry: bool = false;
294        for i in 0..rhs.len() {
295            let first = self.parts[i].overflowing_add(rhs[i]);
296            let second = first.0.overflowing_add(carry as u64);
297            carry = first.1 || second.1;
298            self.parts[i] = second.0;
299        }
300        // Continue to propagate the carry flag.
301        for i in rhs.len()..self.len() {
302            let second = self.parts[i].overflowing_add(carry as u64);
303            carry = second.1;
304            self.parts[i] = second.0;
305        }
306        if carry {
307            self.parts.push(1);
308        }
309        self.shrink()
310    }
311
312    /// Add `rhs` to self, and return true if the operation overflowed (borrow).
313    #[must_use]
314    pub fn inplace_sub(&mut self, rhs: &Self) -> bool {
315        self.inplace_sub_slice(&rhs.parts[..], 0)
316    }
317
318    /// Implements subtraction of the 'rhs' sequence of words to this number.
319    /// The parameter `known_zeros` specifies how many lower *words* in `rhs`
320    /// are zeros and can be ignored. This is used by the division algorithm
321    /// that shifts the divisor.
322    #[allow(clippy::needless_range_loop)]
323    fn inplace_sub_slice(&mut self, rhs: &[u64], bottom_zeros: usize) -> bool {
324        self.grow(rhs.len());
325        let mut borrow: bool = false;
326        // Do the part of the vectors that both sides have.
327
328        for i in bottom_zeros..rhs.len() {
329            let first = self.parts[i].overflowing_sub(rhs[i]);
330            let second = first.0.overflowing_sub(borrow as u64);
331            borrow = first.1 || second.1;
332            self.parts[i] = second.0;
333        }
334        // Propagate the carry bit.
335        for i in rhs.len()..self.len() {
336            let second = self.parts[i].overflowing_sub(borrow as u64);
337            self.parts[i] = second.0;
338            borrow = second.1;
339        }
340        self.shrink();
341        borrow
342    }
343
344    fn zeros(size: usize) -> Vec<u64> {
345        core::iter::repeat(0).take(size).collect()
346    }
347
348    /// Multiply `rhs` to self, and return true if the operation overflowed.
349    pub fn inplace_mul(&mut self, rhs: &Self) {
350        if self.len() > KARATSUBA_SIZE_THRESHOLD
351            || rhs.len() > KARATSUBA_SIZE_THRESHOLD
352        {
353            *self = Self::mul_karatsuba(self, rhs);
354            return;
355        }
356        self.inplace_mul_slice(rhs);
357    }
358
359    /// Implements multiplication of the 'rhs' sequence of words to this number.
360    fn inplace_mul_slice(&mut self, rhs: &[u64]) {
361        let size = self.len() + rhs.len() + 1;
362        let mut parts = Self::zeros(size);
363        let mut carries = Self::zeros(size);
364
365        for i in 0..self.len() {
366            for j in 0..rhs.len() {
367                let pi = self.parts[i] as u128;
368                let pij = pi * rhs[j] as u128;
369
370                let add0 = parts[i + j].overflowing_add(pij as u64);
371                parts[i + j] = add0.0;
372                carries[i + j] += add0.1 as u64;
373                let add1 = parts[i + j + 1].overflowing_add((pij >> 64) as u64);
374                parts[i + j + 1] = add1.0;
375                carries[i + j + 1] += add1.1 as u64;
376            }
377        }
378        self.grow(size);
379        let mut carry: u64 = 0;
380        for i in 0..size {
381            let add0 = parts[i].overflowing_add(carry);
382            self.parts[i] = add0.0;
383            carry = add0.1 as u64 + carries[i];
384        }
385        self.shrink();
386        assert!(carry == 0);
387    }
388
389    /// Divide self by `divisor`, and return the reminder.
390    pub fn inplace_div(&mut self, divisor: &Self) -> Self {
391        let mut dividend = self.clone();
392        let mut divisor = divisor.clone();
393        let mut quotient = Self::zero();
394
395        // Single word division.
396        if self.len() == 1 && divisor.parts.len() == 1 {
397            let a = dividend.get_part(0);
398            let b = divisor.get_part(0);
399            let res = a / b;
400            let rem = a % b;
401            self.parts[0] = res;
402            return Self::from_u64(rem);
403        }
404
405        let dividend_msb = dividend.msb_index();
406        let divisor_msb = divisor.msb_index();
407        assert_ne!(divisor_msb, 0, "division by zero");
408
409        if divisor_msb > dividend_msb {
410            let ret = self.clone();
411            *self = Self::zero();
412            return ret;
413        }
414
415        // Align the first bit of the divisor with the first bit of the
416        // dividend.
417        let bits = dividend_msb - divisor_msb;
418        divisor.shift_left(bits);
419
420        // Perform the long division.
421        for i in (0..bits + 1).rev() {
422            // Find out how many of the lower words of the divisor are zeros.
423            let low_zeros = i / 64;
424
425            if dividend >= divisor {
426                let overflow = dividend.inplace_sub_slice(&divisor, low_zeros);
427                debug_assert!(!overflow);
428                quotient.flip_bit(i);
429            }
430            divisor.shift_right(1);
431        }
432
433        *self = quotient;
434        self.shrink();
435        dividend
436    }
437
438    /// Shift the bits in the numbers `bits` to the left.
439    pub fn shift_left(&mut self, bits: usize) {
440        let words_to_shift = bits / u64::BITS as usize;
441        let bits_in_word = bits % u64::BITS as usize;
442
443        for _ in 0..words_to_shift + 1 {
444            self.parts.push(0);
445        }
446
447        // If we only need to move blocks.
448        if bits_in_word == 0 {
449            for i in (0..self.len()).rev() {
450                self.parts[i] = if i >= words_to_shift {
451                    self.parts[i - words_to_shift]
452                } else {
453                    0
454                };
455            }
456            return;
457        }
458
459        for i in (0..self.len()).rev() {
460            let left_val = if i >= words_to_shift {
461                self.parts[i - words_to_shift]
462            } else {
463                0
464            };
465            let right_val = if i > words_to_shift {
466                self.parts[i - words_to_shift - 1]
467            } else {
468                0
469            };
470            let right = right_val >> (u64::BITS as usize - bits_in_word);
471            let left = left_val << bits_in_word;
472            self.parts[i] = left | right;
473        }
474    }
475
476    /// Shift the bits in the numbers `bits` to the right.
477    pub fn shift_right(&mut self, bits: usize) {
478        let words_to_shift = bits / u64::BITS as usize;
479        let bits_in_word = bits % u64::BITS as usize;
480
481        // If we only need to move blocks.
482        if bits_in_word == 0 {
483            for i in 0..self.len() {
484                self.parts[i] = if i + words_to_shift < self.len() {
485                    self.parts[i + words_to_shift]
486                } else {
487                    0
488                };
489            }
490            self.shrink();
491            return;
492        }
493
494        for i in 0..self.len() {
495            let left_val = if i + words_to_shift < self.len() {
496                self.parts[i + words_to_shift]
497            } else {
498                0
499            };
500            let right_val = if i + 1 + words_to_shift < self.len() {
501                self.parts[i + 1 + words_to_shift]
502            } else {
503                0
504            };
505            let right = right_val << (u64::BITS as usize - bits_in_word);
506            let left = left_val >> bits_in_word;
507            self.parts[i] = left | right;
508        }
509        self.shrink();
510    }
511
512    /// Raise this number to the power of `exp` and return the value.
513    pub fn powi(&self, mut exp: u64) -> Self {
514        let mut v = Self::one();
515        let mut base = self.clone();
516        loop {
517            if exp & 0x1 == 1 {
518                v.inplace_mul(&base);
519            }
520            exp >>= 1;
521            if exp == 0 {
522                break;
523            }
524            base.inplace_mul(&base.clone());
525        }
526        v
527    }
528
529    /// Returns the word at idx `idx`.
530    pub fn get_part(&self, idx: usize) -> u64 {
531        self.parts[idx]
532    }
533
534    #[cfg(feature = "std")]
535    pub fn dump(&self) {
536        use std::println;
537        println!("[{}]", self.as_binary());
538    }
539
540    #[cfg(not(feature = "std"))]
541    pub fn dump(&self) {
542        // No-op in no_std environments
543    }
544}
545
546impl Default for BigInt {
547    fn default() -> Self {
548        Self::zero()
549    }
550}
551
#[test]
fn test_powi5() {
    // Small powers of five fit in a single word.
    let five = BigInt::from_u64(5);
    let expected: [u64; 8] = [1, 5, 25, 125, 625, 3125, 15625, 78125];
    for (exp, want) in expected.iter().enumerate() {
        assert_eq!(five.powi(exp as u64).as_u64(), *want);
    }

    // Larger exponents exercise several square-and-multiply rounds.
    assert_eq!(BigInt::from_u64(15).powi(16).as_u64(), 6568408355712890625); // 15^16
    assert_eq!(BigInt::from_u64(3).powi(21).as_u64(), 10460353203); // 3^21
}
568
#[test]
fn test_shl() {
    let mut value = BigInt::from_u64(0xff00ff);
    assert_eq!(value.get_part(0), 0xff00ff);

    // Two sub-word shifts keep the value inside the lowest word.
    for want in [0x1fe01fe0000u64, 0x3fc03fc00000000].iter() {
        value.shift_left(17);
        assert_eq!(value.get_part(0), *want);
    }

    // A whole-word shift moves the value into the next word.
    value.shift_left(64);
    assert_eq!(value.get_part(1), 0x3fc03fc00000000);
}
580
#[test]
fn test_shr() {
    let mut value = BigInt::from_u64(0xff00ff);
    value.shift_left(128);
    assert_eq!(value.get_part(2), 0xff00ff);

    // Sub-word shifts move bits across the word-2/word-1 boundary.
    for want in [0x807f800000000000u64, 0x03fc03fc0000000].iter() {
        value.shift_right(17);
        assert_eq!(value.get_part(1), *want);
    }

    // A whole-word shift moves word 1 into word 0.
    value.shift_right(64);
    assert_eq!(value.get_part(0), 0x03fc03fc0000000);
}
593
#[test]
fn test_mul_basic() {
    // (2^64 - 1)^2 * 25, spread across three words.
    let max_word = BigInt::from_u64(0xffff_ffff_ffff_ffff);
    let mut product = max_word.clone();
    product.inplace_mul(&max_word);
    product.inplace_mul(&BigInt::from_u64(25));

    let words = [product.get_part(0), product.get_part(1), product.get_part(2)];
    assert_eq!(words, [0x19, 0xffff_ffff_ffff_ffce, 0x18]);
}
604
#[test]
fn test_add_basic() {
    let mut sum = BigInt::from_u64(0xffffffff00000000);

    // Filling the low half yields an all-ones word with no carry.
    sum.inplace_add(&BigInt::from_u64(0xffffffff));
    assert_eq!(sum.get_part(0), 0xffffffffffffffff);

    // Adding 0xf to an all-ones word carries into a new word.
    sum.inplace_add(&BigInt::from_u64(0xf));
    assert_eq!((sum.get_part(0), sum.get_part(1)), (0xe, 0x1));
}
616
#[test]
fn test_div_basic() {
    let seven = BigInt::from_u64(7);
    // (dividend, quotient, remainder)
    let cases: [(u64, u64, u64); 2] = [(49, 7, 0), (703, 100, 3)];

    for &(dividend, quotient, remainder) in cases.iter() {
        let mut n = BigInt::from_u64(dividend);
        let rem = n.inplace_div(&seven);
        assert_eq!(n.as_u64(), quotient);
        assert_eq!(rem.as_u64(), remainder);
    }
}
631
#[test]
fn test_div_10() {
    let ten = BigInt::from_u64(10);
    let mut n = BigInt::from_u64(19940521);

    // Peel off decimal digits from least to most significant.
    for digit in [1u64, 2, 5, 0, 4].iter() {
        assert_eq!(n.inplace_div(&ten).as_u64(), *digit);
    }
}
642
643#[allow(dead_code)]
644fn test_with_random_values(
645    correct: fn(u128, u128) -> (u128, bool),
646    test: fn(u128, u128) -> (u128, bool),
647) {
648    use super::utils::Lfsr;
649
650    // Test addition, multiplication, subtraction with random values.
651    let mut lfsr = Lfsr::new();
652
653    for _ in 0..50000 {
654        let v0 = lfsr.get64();
655        let v1 = lfsr.get64();
656        let v2 = lfsr.get64();
657        let v3 = lfsr.get64();
658
659        let n1 = (v0 as u128) + ((v1 as u128) << 64);
660        let n2 = (v2 as u128) + ((v3 as u128) << 64);
661
662        let v1 = correct(n1, n2);
663        let v2 = test(n1, n2);
664        assert_eq!(v1.0, v2.0, "Incorrect value");
665        assert_eq!(v1.0, v2.0, "Incorrect carry");
666    }
667}
668
#[test]
fn test_sub_basic() {
    // A borrow that ripples from word 0 into word 1 (no overall underflow).
    {
        let mut x = BigInt::from_parts(&[0x0, 0x1, 0]);
        let borrowed = x.inplace_sub(&BigInt::from_u64(0x1));
        assert!(!borrowed);
        assert_eq!((x.get_part(0), x.get_part(1)), (0xffffffffffffffff, 0));
    }

    // A right-hand side with more words than the left-hand side.
    {
        let mut x = BigInt::from_parts(&[0x1, 0x1]);
        let borrowed = x.inplace_sub(&BigInt::from_parts(&[0x0, 0x1, 0x0]));
        assert!(!borrowed);
        assert_eq!((x.get_part(0), x.get_part(1)), (0x1, 0));
    }

    // Operands of equal length; the top word is untouched.
    {
        let mut x = BigInt::from_parts(&[0x1, 0x1, 0x1]);
        let borrowed = x.inplace_sub(&BigInt::from_parts(&[0x0, 0x1, 0x0]));
        assert!(!borrowed);
        assert_eq!(
            (x.get_part(0), x.get_part(1), x.get_part(2)),
            (0x1, 0, 0x1)
        );
    }
}
694
#[test]
fn test_mask_basic() {
    let mut value = BigInt::from_parts(&[0b11111, 0b10101010101010, 0b111]);
    // 69 bits = one full word plus the bottom 5 bits of the next.
    value.mask(69);

    let expected: [u64; 3] = [0b11111, 0b01010, 0b0];
    for (idx, want) in expected.iter().enumerate() {
        assert_eq!(value.get_part(idx), *want);
    }
}
703
#[test]
fn test_basic_operations() {
    // Check add, mul, sub, div and cmp against the native u128 implementation.

    fn correct_sub(a: u128, b: u128) -> (u128, bool) {
        a.overflowing_sub(b)
    }
    fn correct_add(a: u128, b: u128) -> (u128, bool) {
        a.overflowing_add(b)
    }
    fn correct_mul(a: u128, b: u128) -> (u128, bool) {
        a.overflowing_mul(b)
    }
    fn correct_div(a: u128, b: u128) -> (u128, bool) {
        a.overflowing_div(b)
    }

    /// Clears every word above the low 128 bits of `x`. Returns true if any
    /// of those words was non-zero, i.e. the value overflowed a u128.
    fn truncate_to_128_bits(x: &mut BigInt) -> bool {
        let mut overflow = false;
        for word in x.parts.iter_mut().skip(2) {
            overflow |= *word != 0;
            *word = 0;
        }
        overflow
    }

    fn test_sub(a: u128, b: u128) -> (u128, bool) {
        let mut a = BigInt::from_u128(a);
        let b = BigInt::from_u128(b);
        let c = a.inplace_sub(&b);
        (a.as_u128(), c)
    }
    fn test_add(a: u128, b: u128) -> (u128, bool) {
        let mut a = BigInt::from_u128(a);
        let b = BigInt::from_u128(b);
        a.inplace_add(&b);
        let carry = truncate_to_128_bits(&mut a);
        (a.as_u128(), carry)
    }
    fn test_mul(a: u128, b: u128) -> (u128, bool) {
        let mut a = BigInt::from_u128(a);
        let b = BigInt::from_u128(b);
        a.inplace_mul(&b);
        // Derive overflow from the value, not the word count: the product may
        // occupy 3 or 4 words.
        let carry = truncate_to_128_bits(&mut a);
        (a.as_u128(), carry)
    }
    fn test_div(a: u128, b: u128) -> (u128, bool) {
        let mut a = BigInt::from_u128(a);
        let b = BigInt::from_u128(b);
        a.inplace_div(&b);
        (a.as_u128(), false)
    }

    /// Encodes an ordering as a number so it fits the (u128, bool) shape.
    fn encode_ordering(o: Ordering) -> u128 {
        match o {
            Ordering::Less => 1,
            Ordering::Equal => 2,
            Ordering::Greater => 3,
        }
    }
    fn correct_cmp(a: u128, b: u128) -> (u128, bool) {
        (encode_ordering(a.cmp(&b)), false)
    }
    fn test_cmp(a: u128, b: u128) -> (u128, bool) {
        let a = BigInt::from_u128(a);
        let b = BigInt::from_u128(b);
        (encode_ordering(a.cmp(&b)), false)
    }

    test_with_random_values(correct_mul, test_mul);
    test_with_random_values(correct_div, test_div);
    test_with_random_values(correct_add, test_add);
    test_with_random_values(correct_sub, test_sub);
    test_with_random_values(correct_cmp, test_cmp);
}
788
#[test]
fn test_msb() {
    // (value, expected 1-based msb index)
    let cases: [(u64, usize); 3] = [(0xffffffff00000000, 64), (0x0, 0), (0x1, 1)];
    for &(value, want) in cases.iter() {
        assert_eq!(BigInt::from_u64(value).msb_index(), want);
    }

    // A single bit at position `shift` has msb_index `shift + 1`.
    let bit_at = |shift: usize| {
        let mut x = BigInt::from_u64(0x1);
        x.shift_left(shift);
        x
    };
    assert_eq!(bit_at(189).msb_index(), 189 + 1);
    for shift in 0..256 {
        assert_eq!(bit_at(shift).msb_index(), shift + 1);
    }
}
810
#[test]
fn test_trailing_zero() {
    // (value, expected number of trailing zeros)
    let cases: [(u64, usize); 3] = [(0xffffffff00000000, 32), (0x1, 0), (0x8, 3)];
    for &(value, want) in cases.iter() {
        assert_eq!(BigInt::from_u64(value).trailing_zeros(), want);
    }

    // A single bit at position `shift` has exactly `shift` trailing zeros.
    let bit_at = |shift: usize| {
        let mut x = BigInt::from_u64(0x1);
        x.shift_left(shift);
        x
    };
    assert_eq!(bit_at(189).trailing_zeros(), 189);
    for shift in 0..256 {
        assert_eq!(bit_at(shift).trailing_zeros(), shift);
    }
}
832impl Eq for BigInt {}
833
834impl PartialEq for BigInt {
835    fn eq(&self, other: &BigInt) -> bool {
836        self.cmp(other).is_eq()
837    }
838}
839impl PartialOrd for BigInt {
840    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
841        Some(self.cmp(other))
842    }
843}
844impl Ord for BigInt {
845    fn cmp(&self, other: &Self) -> Ordering {
846        // This part word is longer.
847        if self.len() > other.len()
848            && self.parts[other.len()..].iter().any(|&x| x != 0)
849        {
850            return Ordering::Greater;
851        }
852
853        // The other word is longer.
854        if other.len() > self.len()
855            && other.parts[self.len()..].iter().any(|&x| x != 0)
856        {
857            return Ordering::Less;
858        }
859        let same_len = other.len().min(self.len());
860
861        // Compare all of the digits, from MSB to LSB.
862        for i in (0..same_len).rev() {
863            match self.parts[i].cmp(&other.parts[i]) {
864                Ordering::Less => return Ordering::Less,
865                Ordering::Equal => {}
866                Ordering::Greater => return Ordering::Greater,
867            }
868        }
869        Ordering::Equal
870    }
871}
872
873macro_rules! declare_operator {
874    ($trait_name:ident,
875     $func_name:ident,
876     $func_impl_name:ident) => {
877        // Self + Self
878        impl $trait_name for BigInt {
879            type Output = Self;
880
881            fn $func_name(self, rhs: Self) -> Self::Output {
882                self.$func_name(&rhs)
883            }
884        }
885
886        // Self + &Self -> Self
887        impl $trait_name<&Self> for BigInt {
888            type Output = Self;
889            fn $func_name(self, rhs: &Self) -> Self::Output {
890                let mut n = self;
891                let _ = n.$func_impl_name(rhs);
892                n
893            }
894        }
895
896        // &Self + &Self -> Self
897        impl $trait_name<Self> for &BigInt {
898            type Output = BigInt;
899            fn $func_name(self, rhs: Self) -> Self::Output {
900                let mut n = self.clone();
901                let _ = n.$func_impl_name(rhs);
902                n
903            }
904        }
905
906        // &Self + u64 -> Self
907        impl $trait_name<u64> for BigInt {
908            type Output = Self;
909            fn $func_name(self, rhs: u64) -> Self::Output {
910                let mut n = self;
911                let _ = n.$func_impl_name(&Self::from_u64(rhs));
912                n
913            }
914        }
915    };
916}
917
918declare_operator!(Add, add, inplace_add);
919declare_operator!(Sub, sub, inplace_sub);
920declare_operator!(Mul, mul, inplace_mul);
921declare_operator!(Div, div, inplace_div);
922
923macro_rules! declare_assign_operator {
924    ($trait_name:ident,
925     $func_name:ident,
926     $func_impl_name:ident) => {
927        impl $trait_name for BigInt {
928            fn $func_name(&mut self, rhs: Self) {
929                let _ = self.$func_impl_name(&rhs);
930            }
931        }
932
933        impl $trait_name<&BigInt> for BigInt {
934            fn $func_name(&mut self, rhs: &Self) {
935                let _ = self.$func_impl_name(&rhs);
936            }
937        }
938    };
939}
940
941declare_assign_operator!(AddAssign, add_assign, inplace_add);
942declare_assign_operator!(SubAssign, sub_assign, inplace_sub);
943declare_assign_operator!(MulAssign, mul_assign, inplace_mul);
944declare_assign_operator!(DivAssign, div_assign, inplace_div);
945
#[test]
fn test_bigint_operators() {
    let ten = BigInt::from_u64(10);
    let one = BigInt::from_u64(1);

    // (10 - 1) * 10 / 2 == 45, mixing reference, owned and u64 operands.
    let nine = &ten - &one;
    let ninety = nine * ten;
    let half = ninety / 2;
    assert_eq!(half.as_u64(), 45);

    // &BigInt + &BigInt leaves both operands usable.
    let two = &one + &one;
    assert_eq!(two.as_u64(), 2);
}
956
#[test]
fn test_all1s_ctor() {
    // (number of low bits set, expected lowest word)
    let cases: [(usize, u64); 4] = [(0, 0b0), (1, 0b1), (5, 0b11111), (32, 0xffffffff)];
    for &(bits, want) in cases.iter() {
        assert_eq!(BigInt::all1s(bits).get_part(0), want);
    }
}
970
#[test]
fn test_flip_bit() {
    // Flipping the same bit twice restores the original value.
    let mut toggled = BigInt::zero();
    assert_eq!(toggled.get_part(0), 0);
    toggled.flip_bit(0);
    assert_eq!(toggled.get_part(0), 1);
    toggled.flip_bit(0);
    assert_eq!(toggled.get_part(0), 0);

    // Bit 16 inside the first word.
    let mut mid = BigInt::zero();
    mid.flip_bit(16);
    assert_eq!(mid.get_part(0), 65536);

    // Bit 95 lives in the second word; shifting it back yields one.
    let mut high = BigInt::zero();
    high.flip_bit(95);
    high.shift_right(95);
    assert_eq!(high.get_part(0), 1);
}
997
#[cfg(feature = "std")]
#[test]
fn test_mul_div_encode_decode() {
    // Pack a string of base-5 symbols into one large number by repeated
    // multiply-add, then unpack them with repeated division.
    const BASE: u64 = 5;
    let base = BigInt::from_u64(BASE);

    // A deterministic message of 275 symbols in 0..BASE.
    let message: Vec<u64> = (0..275u64).map(|i| ((i + 6) * 17) % BASE).collect();

    // Encode: the first symbol ends up in the most significant position.
    let mut bitstream = BigInt::from_u64(0);
    for &symbol in &message {
        bitstream.inplace_mul(&base);
        bitstream.inplace_add(&BigInt::from_u64(symbol));
    }

    // Decode: remainders come out last symbol first.
    for &expected in message.iter().rev() {
        let rem = bitstream.inplace_div(&base);
        assert_eq!(expected, rem.as_u64());
    }
}
1029
1030impl BigInt {
1031    /// Converts this number into a sequence of digits in the range 0..DIGIT.
1032    /// Use a recursive algorithm to split the number in half, if the number is
1033    /// too big.
1034    /// Return the number of digits that were converted.
1035    fn to_digits_impl<const DIGIT: u8>(
1036        num: &mut BigInt,
1037        num_digits: usize,
1038        output: &mut Vec<u8>,
1039    ) -> usize {
1040        const SPLIT_WORD_THRESHOLD: usize = 5;
1041
1042        // Figure out how many digits fit in a single word.
1043        let bits_per_digit = (8 - DIGIT.leading_zeros()) as usize;
1044        let digits_per_word = 64 / bits_per_digit;
1045        let digit = DIGIT as u64;
1046
1047        // If the word is too big, split it in half.
1048        let len = num.len();
1049        if len > SPLIT_WORD_THRESHOLD {
1050            let half = len / 2 - 1;
1051            // Figure out how many digits to extract:
1052            let k = digits_per_word * half;
1053            // Create a mega digit (a*a*a*a....).
1054            let mega_digit = BigInt::from_u64(digit).powi(k as u64);
1055            // Extract the lowest k digits.
1056            let mut rem = num.inplace_div(&mega_digit);
1057
1058            // Convert the two parts to digits:
1059            let tail = Self::to_digits_impl::<DIGIT>(&mut rem, k, output);
1060            let hd = Self::to_digits_impl::<DIGIT>(num, num_digits - k, output);
1061            debug_assert_eq!(tail, k);
1062            debug_assert_eq!(hd, num_digits - k);
1063            return num_digits;
1064        }
1065
1066        let mut extracted = 0;
1067
1068        // Multiply a*a*a*a ... until we fill a 64bit word.
1069        let divisor = BigInt::from_u64(digit.pow(digits_per_word as u32));
1070        // For each word:
1071        for _ in 0..(num_digits / digits_per_word) {
1072            // Pull a single word of [a*a*a*a ....].
1073            let mut rem = num.inplace_div(&divisor);
1074            // This is fast because we operate on a single word.
1075            extracted += digits_per_word;
1076            Self::extract_digits::<DIGIT>(digits_per_word, &mut rem, output);
1077        }
1078
1079        // Handle the rest of the digits.
1080        let iters = num_digits % digits_per_word;
1081        Self::extract_digits::<DIGIT>(iters, num, output);
1082        extracted += iters;
1083
1084        extracted
1085    }
1086
1087    // Extract 'iter' digits from 'num', one by one, and push them to 'vec'.
1088    fn extract_digits<const DIGIT: u8>(
1089        iter: usize,
1090        num: &mut BigInt,
1091        vec: &mut Vec<u8>,
1092    ) {
1093        let digit = BigInt::from_u64(DIGIT as u64);
1094        for _ in 0..iter {
1095            let d = num.inplace_div(&digit).as_u64();
1096            vec.push(d as u8);
1097        }
1098    }
1099
1100    /// Converts this number into a sequence of digits in the range 0..DIGIT.
1101    pub(crate) fn to_digits<const DIGIT: u8>(&self) -> Vec<u8> {
1102        let mut num = self.clone();
1103        num.shrink();
1104
1105        let mut output: Vec<u8> = Vec::new();
1106
1107        while !num.is_zero() {
1108            let len = num.len();
1109            // Figure out how many digits fit in the number.
1110            // See 'get_decimal_accuracy'.
1111            let digits = (len * 64 * 59) / 196;
1112            Self::to_digits_impl::<DIGIT>(&mut num, digits, &mut output);
1113        }
1114
1115        // Eliminate leading zeros.
1116
1117        while output.len() > 1 && output[output.len() - 1] == 0 {
1118            output.pop();
1119        }
1120        output.reverse();
1121        output
1122    }
1123}
1124
#[test]
pub fn test_bigint_to_digits() {
    use alloc::string::String;
    use core::primitive::char;

    /// Render a most-significant-first digit vector in radix `base` as text.
    fn vec_to_string(vec: Vec<u8>, base: u32) -> String {
        vec.into_iter()
            .map(|d| char::from_digit(d as u32, base).unwrap())
            .collect()
    }

    // Binary: an 18-bit pattern shifted up by one full 64-bit word.
    let mut wide = BigInt::from_u64(0b111000111000101010);
    wide.shift_left(64);
    let binary = vec_to_string(wide.to_digits::<2>(), 2);
    assert_eq!(
        binary,
        "1110001110001010100000000000000\
        0000000000000000000000000000000\
        00000000000000000000"
    );

    // Decimal, fits in a single word.
    let small = BigInt::from_u64(90210);
    let decimal = vec_to_string(small.to_digits::<10>(), 10);
    assert_eq!(decimal, "90210");

    // Decimal, spans two words.
    let large = BigInt::from_u128(123_456_123_456_987_654_987_654u128);
    let decimal = vec_to_string(large.to_digits::<10>(), 10);
    assert_eq!(decimal, "123456123456987654987654");
}
1159
/// Minimum operand size, in 64-bit words, for Karatsuba multiplication.
/// `mul_karatsuba` splits the operands only when BOTH have at least this many
/// words; if either operand is shorter, it falls back to the traditional
/// O(n^2) schoolbook multiplication.
const KARATSUBA_SIZE_THRESHOLD: usize = 64;
1164
1165impl BigInt {
1166    fn mul_karatsuba(lhs: &[u64], rhs: &[u64]) -> BigInt {
1167        // Algorithm description:
1168        // https://en.wikipedia.org/wiki/Karatsuba_algorithm
1169
1170        // Handle small numbers using the traditional O(n^2) algorithm.
1171        if lhs.len().min(rhs.len()) < KARATSUBA_SIZE_THRESHOLD {
1172            // Handle zero-sized inputs.
1173            if lhs.is_empty() || rhs.is_empty() {
1174                return BigInt::zero();
1175            }
1176            let mut lhs = BigInt::from_parts(lhs);
1177            lhs.inplace_mul_slice(rhs);
1178            return lhs;
1179        }
1180
1181        // Split the big-int into two parts. One of the parts might be
1182        // zero-sized.
1183        let mid = lhs.len().max(rhs.len()) / 2;
1184        let a = &lhs[0..mid.min(lhs.len())];
1185        let b = &lhs[mid.min(lhs.len())..];
1186        let c = &rhs[0..mid.min(rhs.len())];
1187        let d = &rhs[mid.min(rhs.len())..];
1188
1189        // Compute 'a*c' and 'b*d'.
1190        let ac = Self::mul_karatsuba(a, c);
1191        let mut bd = Self::mul_karatsuba(b, d);
1192
1193        // Compute (a+b) * (c+d).
1194        let mut a_b = BigInt::from_parts(a);
1195        a_b.inplace_add_slice(b);
1196        let mut c_d = BigInt::from_parts(c);
1197        c_d.inplace_add_slice(d);
1198
1199        let mut ad_plus_bc = Self::mul_karatsuba(&a_b, &c_d);
1200
1201        // Compute (a+b) * (c+d) - ac - bd
1202        ad_plus_bc.inplace_sub_slice(&ac, 0);
1203        ad_plus_bc.inplace_sub_slice(&bd, 0);
1204
1205        // Add the parts of the word together.
1206        bd.shift_left(64 * mid * 2);
1207        ad_plus_bc.shift_left(64 * mid);
1208        bd.inplace_add(&ad_plus_bc);
1209        bd.inplace_add(&ac);
1210        bd
1211    }
1212}
1213
#[test]
fn test_mul_karatsuba() {
    use crate::utils::Lfsr;

    /// Multiplies two pseudo-random numbers of `lhs_len` and `rhs_len` words
    /// with both Karatsuba and schoolbook multiplication and checks that the
    /// results agree.
    fn check(lhs_len: usize, rhs_len: usize, rng: &mut Lfsr) {
        let mut expected = BigInt::from_iter(rng, lhs_len);
        let rhs = BigInt::from_iter(rng, rhs_len);
        let actual = BigInt::mul_karatsuba(&expected, &rhs);
        expected.inplace_mul_slice(&rhs);
        assert_eq!(actual, expected);
    }

    let mut rng = Lfsr::new();

    // Hand-picked shapes: tiny, lopsided, square, and off-by-one.
    let fixed = [(1, 1), (100, 1), (1, 100), (100, 100), (1000, 1000), (1000, 1001)];
    for &(lhs_len, rhs_len) in fixed.iter() {
        check(lhs_len, rhs_len, &mut rng);
    }

    // Sweep operand sizes around the Karatsuba threshold.
    for lhs_len in 64..90 {
        for rhs_len in 1..128 {
            check(lhs_len, rhs_len, &mut rng);
        }
    }
}
1243
1244use core::ops::Deref;
1245
1246impl Deref for BigInt {
1247    type Target = [u64];
1248
1249    fn deref(&self) -> &Self::Target {
1250        &self.parts[..]
1251    }
1252}