arithmetic_eval/arith/bigint.rs

1use num_bigint::{BigInt, BigUint};
2use num_traits::{One, Signed, Zero};
3
4use core::{convert::TryFrom, mem};
5
6use super::{Arithmetic, ArithmeticError, ModularArithmetic};
7
8impl ModularArithmetic<BigUint> {
9    fn invert_big(&self, value: BigUint) -> Option<BigUint> {
10        let value = value % &self.modulus; // Reduce value since this influences speed.
11        let mut t = BigInt::zero();
12        let mut new_t = BigInt::one();
13
14        let modulus = BigInt::from(self.modulus.clone());
15        let mut r = modulus.clone();
16        let mut new_r = BigInt::from(value);
17
18        while !new_r.is_zero() {
19            let quotient = &r / &new_r;
20            t -= &quotient * &new_t;
21            mem::swap(&mut new_t, &mut t);
22            r -= quotient * &new_r;
23            mem::swap(&mut new_r, &mut r);
24        }
25
26        if r > BigInt::one() {
27            None // r = gcd(self.modulus, value) > 1
28        } else {
29            if t.is_negative() {
30                t += modulus;
31            }
32            Some(BigUint::try_from(t).unwrap())
33            // ^-- `unwrap` is safe by construction
34        }
35    }
36}
37
38impl Arithmetic<BigUint> for ModularArithmetic<BigUint> {
39    fn add(&self, x: BigUint, y: BigUint) -> Result<BigUint, ArithmeticError> {
40        Ok((x + y) % &self.modulus)
41    }
42
43    fn sub(&self, x: BigUint, y: BigUint) -> Result<BigUint, ArithmeticError> {
44        let y_neg = &self.modulus - (y % &self.modulus);
45        self.add(x, y_neg)
46    }
47
48    fn mul(&self, x: BigUint, y: BigUint) -> Result<BigUint, ArithmeticError> {
49        Ok((x * y) % &self.modulus)
50    }
51
52    fn div(&self, x: BigUint, y: BigUint) -> Result<BigUint, ArithmeticError> {
53        if y.is_zero() {
54            Err(ArithmeticError::DivisionByZero)
55        } else {
56            let y_inv = self.invert_big(y).ok_or(ArithmeticError::NoInverse)?;
57            self.mul(x, y_inv)
58        }
59    }
60
61    fn pow(&self, x: BigUint, y: BigUint) -> Result<BigUint, ArithmeticError> {
62        Ok(x.modpow(&y, &self.modulus))
63    }
64
65    fn neg(&self, x: BigUint) -> Result<BigUint, ArithmeticError> {
66        let x = x % &self.modulus;
67        Ok(&self.modulus - x)
68    }
69
70    fn eq(&self, x: &BigUint, y: &BigUint) -> bool {
71        x % &self.modulus == y % &self.modulus
72    }
73}
74
#[cfg(test)]
mod bigint_tests {
    use super::*;
    use crate::arith::{CheckedArithmetic, NegateOnlyZero, OrdArithmetic, Unchecked};

    use num_bigint::{BigInt, BigUint};
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use static_assertions::assert_impl_all;

    assert_impl_all!(CheckedArithmetic<NegateOnlyZero>: OrdArithmetic<BigUint>);
    assert_impl_all!(CheckedArithmetic<Unchecked>: OrdArithmetic<BigInt>);
    assert_impl_all!(ModularArithmetic<BigUint>: Arithmetic<BigUint>);

    /// Generates a random `BigUint` strictly below `2^bits`, using random bytes
    /// from `rng` with the excess high bits masked off.
    fn gen_biguint<R: Rng>(rng: &mut R, bits: u64) -> BigUint {
        let bits = usize::try_from(bits).expect("Capacity overflow");
        let (div, rem) = (bits / 8, bits % 8);

        let mut buffer = vec![0_u8; div + (rem != 0) as usize];
        rng.fill_bytes(&mut buffer);
        if rem > 0 {
            // Zero out most significant bits in the first byte.
            let mask = u8::try_from((1_u16 << rem) - 1).unwrap();
            buffer[0] &= mask;
        }

        BigUint::from_bytes_be(&buffer)
    }

    /// Cross-checks `ModularArithmetic<BigUint>` against direct `num-bigint`
    /// computations for the given `modulus`, using `sample_count` random samples
    /// per check (a tenth of that for the exponentiation checks).
    ///
    /// `modulus` must be prime: the inversion and Fermat checks rely on it.
    fn mini_fuzz_for_big_prime_modulus(modulus: &BigUint, sample_count: usize) {
        let arithmetic = ModularArithmetic::new(modulus.clone());
        let mut rng = StdRng::seed_from_u64(modulus.bits());
        let signed_modulus = BigInt::from(modulus.clone());

        for _ in 0..sample_count {
            let x = gen_biguint(&mut rng, modulus.bits() - 1);
            let y = gen_biguint(&mut rng, modulus.bits() - 1);
            let expected = (&x + &y) % modulus;
            assert_eq!(arithmetic.add(x.clone(), y.clone()).unwrap(), expected);

            // `neg(x)` must be an additive inverse of `x`.
            let neg_x = arithmetic.neg(x.clone()).unwrap();
            assert_eq!((neg_x + &x) % modulus, BigUint::zero());

            let mut expected =
                (BigInt::from(x.clone()) - BigInt::from(y.clone())) % &signed_modulus;
            if expected < BigInt::zero() {
                expected += &signed_modulus;
            }
            let expected = BigUint::try_from(expected).unwrap();
            assert_eq!(arithmetic.sub(x.clone(), y.clone()).unwrap(), expected);

            let expected = (&x * &y) % modulus;
            assert_eq!(arithmetic.mul(x, y).unwrap(), expected);
        }

        for _ in 0..sample_count {
            let x = gen_biguint(&mut rng, modulus.bits());
            let inv = arithmetic.div(BigUint::one(), x.clone());
            if (&x % modulus).is_zero() {
                // Quite unlikely, but better be safe than sorry.
                assert!(inv.is_err());
            } else {
                let inv = inv.unwrap();
                assert_eq!((inv * &x) % modulus, BigUint::one());
            }
        }

        for _ in 0..(sample_count / 10) {
            let x = gen_biguint(&mut rng, modulus.bits());

            // Check a random small exponent.
            let exp = rng.gen_range(1_u64..1_000);
            let expected_pow = (0..exp).fold(BigUint::one(), |acc, _| (acc * &x) % modulus);
            assert_eq!(
                arithmetic.pow(x.clone(), BigUint::from(exp)).unwrap(),
                expected_pow
            );

            if !(&x % modulus).is_zero() {
                // Check Fermat's little theorem.
                let pow = arithmetic.pow(x, modulus - 1_u32).unwrap();
                assert_eq!(pow, BigUint::one());
            }
        }
    }

    // Primes taken from https://bigprimes.org/

    #[test]
    fn mini_fuzz_for_128_bit_prime_modulus() {
        let modulus = "904717851509176637007209984924163038177";
        mini_fuzz_for_big_prime_modulus(&modulus.parse().unwrap(), 10_000);
    }

    #[test]
    fn mini_fuzz_for_256_bit_prime_modulus() {
        let modulus =
            "35383204059922826862591333932184957269284020569026927321130404396066349029943";
        mini_fuzz_for_big_prime_modulus(&modulus.parse().unwrap(), 5_000);
    }

    #[test]
    fn mini_fuzz_for_384_bit_prime_modulus() {
        let modulus =
            "680077592003957715873956706738577254635634257392753873876268782486415186187701100959\
             54501183649227109037342431341197";
        mini_fuzz_for_big_prime_modulus(&modulus.parse().unwrap(), 2_000);
    }

    #[test]
    fn mini_fuzz_for_512_bit_prime_modulus() {
        let modulus =
            "134956060831834915306923365068985449378393338769474235719041178417311022526812045709\
             1169866466743447386864273902296614844109589811099153700965207136981133";
        mini_fuzz_for_big_prime_modulus(&modulus.parse().unwrap(), 2_000);
    }
}