arithmetic_eval/arith/
bigint.rs1use num_bigint::{BigInt, BigUint};
2use num_traits::{One, Signed, Zero};
3
4use core::{convert::TryFrom, mem};
5
6use super::{Arithmetic, ArithmeticError, ModularArithmetic};
7
8impl ModularArithmetic<BigUint> {
9 fn invert_big(&self, value: BigUint) -> Option<BigUint> {
10 let value = value % &self.modulus; let mut t = BigInt::zero();
12 let mut new_t = BigInt::one();
13
14 let modulus = BigInt::from(self.modulus.clone());
15 let mut r = modulus.clone();
16 let mut new_r = BigInt::from(value);
17
18 while !new_r.is_zero() {
19 let quotient = &r / &new_r;
20 t -= "ient * &new_t;
21 mem::swap(&mut new_t, &mut t);
22 r -= quotient * &new_r;
23 mem::swap(&mut new_r, &mut r);
24 }
25
26 if r > BigInt::one() {
27 None } else {
29 if t.is_negative() {
30 t += modulus;
31 }
32 Some(BigUint::try_from(t).unwrap())
33 }
35 }
36}
37
38impl Arithmetic<BigUint> for ModularArithmetic<BigUint> {
39 fn add(&self, x: BigUint, y: BigUint) -> Result<BigUint, ArithmeticError> {
40 Ok((x + y) % &self.modulus)
41 }
42
43 fn sub(&self, x: BigUint, y: BigUint) -> Result<BigUint, ArithmeticError> {
44 let y_neg = &self.modulus - (y % &self.modulus);
45 self.add(x, y_neg)
46 }
47
48 fn mul(&self, x: BigUint, y: BigUint) -> Result<BigUint, ArithmeticError> {
49 Ok((x * y) % &self.modulus)
50 }
51
52 fn div(&self, x: BigUint, y: BigUint) -> Result<BigUint, ArithmeticError> {
53 if y.is_zero() {
54 Err(ArithmeticError::DivisionByZero)
55 } else {
56 let y_inv = self.invert_big(y).ok_or(ArithmeticError::NoInverse)?;
57 self.mul(x, y_inv)
58 }
59 }
60
61 fn pow(&self, x: BigUint, y: BigUint) -> Result<BigUint, ArithmeticError> {
62 Ok(x.modpow(&y, &self.modulus))
63 }
64
65 fn neg(&self, x: BigUint) -> Result<BigUint, ArithmeticError> {
66 let x = x % &self.modulus;
67 Ok(&self.modulus - x)
68 }
69
70 fn eq(&self, x: &BigUint, y: &BigUint) -> bool {
71 x % &self.modulus == y % &self.modulus
72 }
73}
74
#[cfg(test)]
mod bigint_tests {
    use super::*;
    use crate::arith::{CheckedArithmetic, NegateOnlyZero, OrdArithmetic, Unchecked};

    use num_bigint::{BigInt, BigUint};
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use static_assertions::assert_impl_all;

    assert_impl_all!(CheckedArithmetic<NegateOnlyZero>: OrdArithmetic<BigUint>);
    assert_impl_all!(CheckedArithmetic<Unchecked>: OrdArithmetic<BigInt>);
    assert_impl_all!(ModularArithmetic<BigUint>: Arithmetic<BigUint>);

    /// Generates a random `BigUint` with at most `bits` significant bits.
    fn gen_biguint<R: Rng>(rng: &mut R, bits: u64) -> BigUint {
        let bits = usize::try_from(bits).expect("Capacity overflow");
        let spare_bits = bits % 8;
        let byte_len = bits / 8 + usize::from(spare_bits != 0);

        let mut bytes = vec![0_u8; byte_len];
        rng.fill_bytes(&mut bytes);
        // Big-endian layout: a partial byte, if any, is the most significant one,
        // so only its low `spare_bits` bits are kept.
        if spare_bits != 0 {
            bytes[0] &= 0xFF_u8 >> (8 - spare_bits);
        }

        BigUint::from_bytes_be(&bytes)
    }

    /// Cross-checks `ModularArithmetic` against direct `BigUint` / `BigInt`
    /// computations for a prime `modulus`. The RNG is seeded by the modulus bit
    /// length, so every run draws the same samples.
    fn mini_fuzz_for_big_prime_modulus(modulus: &BigUint, sample_count: usize) {
        let arithmetic = ModularArithmetic::new(modulus.clone());
        let bit_len = modulus.bits();
        let mut rng = StdRng::seed_from_u64(bit_len);
        let signed_modulus = BigInt::from(modulus.clone());

        // add / sub / mul with operands strictly below the modulus.
        for _ in 0..sample_count {
            let lhs = gen_biguint(&mut rng, bit_len - 1);
            let rhs = gen_biguint(&mut rng, bit_len - 1);

            let sum = (&lhs + &rhs) % modulus;
            assert_eq!(arithmetic.add(lhs.clone(), rhs.clone()).unwrap(), sum);

            let signed_diff =
                (BigInt::from(lhs.clone()) - BigInt::from(rhs.clone())) % &signed_modulus;
            let signed_diff = if signed_diff < BigInt::zero() {
                signed_diff + &signed_modulus
            } else {
                signed_diff
            };
            let diff = BigUint::try_from(signed_diff).unwrap();
            assert_eq!(arithmetic.sub(lhs.clone(), rhs.clone()).unwrap(), diff);

            let product = (&lhs * &rhs) % modulus;
            assert_eq!(arithmetic.mul(lhs, rhs).unwrap(), product);
        }

        // Inversion through `div(1, x)`; here `x` may exceed the modulus.
        for _ in 0..sample_count {
            let divisor = gen_biguint(&mut rng, bit_len);
            let inverse = arithmetic.div(BigUint::one(), divisor.clone());
            if (&divisor % modulus).is_zero() {
                assert!(inverse.is_err());
            } else {
                assert_eq!((inverse.unwrap() * &divisor) % modulus, BigUint::one());
            }
        }

        // pow versus naive repeated multiplication, plus Fermat's little theorem.
        for _ in 0..(sample_count / 10) {
            let base = gen_biguint(&mut rng, bit_len);
            let exp = rng.gen_range(1_u64..1_000);

            let naive_pow = (0..exp).fold(BigUint::one(), |acc, _| (acc * &base) % modulus);
            assert_eq!(
                arithmetic.pow(base.clone(), BigUint::from(exp)).unwrap(),
                naive_pow
            );

            if !(&base % modulus).is_zero() {
                let fermat = arithmetic.pow(base, modulus - 1_u32).unwrap();
                assert_eq!(fermat, BigUint::one());
            }
        }
    }

    #[test]
    fn mini_fuzz_for_128_bit_prime_modulus() {
        let modulus: BigUint = "904717851509176637007209984924163038177".parse().unwrap();
        mini_fuzz_for_big_prime_modulus(&modulus, 10_000);
    }

    #[test]
    fn mini_fuzz_for_256_bit_prime_modulus() {
        let modulus: BigUint =
            "35383204059922826862591333932184957269284020569026927321130404396066349029943"
                .parse()
                .unwrap();
        mini_fuzz_for_big_prime_modulus(&modulus, 5_000);
    }

    #[test]
    fn mini_fuzz_for_384_bit_prime_modulus() {
        let modulus: BigUint =
            "680077592003957715873956706738577254635634257392753873876268782486415186187701100959\
             54501183649227109037342431341197"
                .parse()
                .unwrap();
        mini_fuzz_for_big_prime_modulus(&modulus, 2_000);
    }

    #[test]
    fn mini_fuzz_for_512_bit_prime_modulus() {
        let modulus: BigUint =
            "134956060831834915306923365068985449378393338769474235719041178417311022526812045709\
             1169866466743447386864273902296614844109589811099153700965207136981133"
                .parse()
                .unwrap();
        mini_fuzz_for_big_prime_modulus(&modulus, 2_000);
    }
}