// modular_math/mod_math/mod_math.rs
use primitive_types::{U256, U512};

3pub struct ModMath {
8 modulus: U256,
9}
10
11impl ModMath {
12 pub fn new<T: IntoU256>(modulus: T) -> Self {
18 let modulus = modulus.into_u256();
19 if modulus == U256::zero() {
20 panic!("Modulus Cannot be Zero");
21 }
22 ModMath {
23 modulus
24 }
25 }
26
27 pub fn modulus<T: IntoU256>(&self, a: T) -> U256 {
28 a.into_u256() % self.modulus
29 }
30
31 pub fn add<T: IntoU256>(&self, a: T, b: T) -> U256 {
33 let a = a.into_u256();
34 let b = b.into_u256();
35 match a.checked_add(b) {
36 Some(sum) => sum % self.modulus,
37 None => {
38 let a_512 = U512::from(a);
39 let b_512 = U512::from(b);
40 let modulus_512 = U512::from(self.modulus);
41 let result = (a_512 + b_512) % modulus_512;
42
43 ModMath::u512_to_u256(result)
44 }
45 }
46 }
47
48 pub fn sub<T: IntoU256>(&self, a: T, b: T) -> U256 {
50 let a = a.into_u256();
51 let b = b.into_u256();
52 if b > a {
53 match self.modulus.checked_add(a) {
55 Some(sum) => (sum - b) % self.modulus,
56 None => {
57 let a_512 = U512::from(a);
58 let b_512 = U512::from(b);
59 let modulus_512 = U512::from(self.modulus);
60 let result = (modulus_512 + a_512 - b_512) % modulus_512;
61
62 ModMath::u512_to_u256(result)
63 }
64 }
65 } else {
66 (a - b) % self.modulus
67 }
68 }
69
70 pub fn mul<T: IntoU256>(&self, a: T, b: T) -> U256 {
72 let a_mod = a.into_u256() % self.modulus;
73 let b_mod = b.into_u256() % self.modulus;
74
75 match a_mod.checked_mul(b_mod) {
77 Some(product) => product % self.modulus,
78 None => {
79 let a_mod_u512 = U512::from(a_mod);
80 let b_mod_u512 = U512::from(b_mod);
81 let result = a_mod_u512 * b_mod_u512 % U512::from(self.modulus);
82
83 ModMath::u512_to_u256(result)
84 },
85 }
86 }
87
88
89 pub fn exp<T: IntoU256>(&self, base: T, exponent: T) -> U256 {
91 let mut result = U256::one();
92 let mut base = base.into_u256() % self.modulus;
93 let mut exponent = exponent.into_u256();
94 while exponent != U256::zero() {
95 if exponent % U256::from(2) != U256::zero() {
96 result = self.mul(result, base)
97 }
98 base = self.square(base);
99 exponent /= U256::from(2);
100 }
101 result
102 }
103
104 pub fn inv<T: IntoU256>(&self, a: T) -> Option<U256> {
108 let (mut m, mut x0, mut x1) = (self.modulus, U256::zero(), U256::one());
109 let mut a = a.into_u256() % self.modulus;
110 if self.modulus == U256::one() {
111 return None;
112 }
113
114 while a > U256::one() {
115 let q = a / m;
116 let mut temp = m;
117
118 m = a % m;
119 a = temp;
120 temp = x0;
121 let t = self.mul(q, x0);
122 x0 = self.sub(x1, t);
123 x1 = temp;
124 }
125
126 if x1 < U256::zero() {
127 x1 = self.add(x1, self.modulus);
128 }
129
130 if a != U256::one() {
131 None
132 } else {
133 Some(x1)
134 }
135 }
136
137 pub fn div<T: IntoU256>(&self, a: T, b: T) -> U256 {
143 let b = b.into_u256();
144 let b_inv = self.inv(b).unwrap_or_else(|| {
145 panic!("Cannot find Inverse of {}", b);
146 });
147 self.mul(a.into_u256(), b_inv)
148 }
149
150 pub fn add_inv<T: IntoU256>(&self, a: T) -> U256 {
152 let a = a.into_u256();
153 if a == U256::zero() {
154 U256::zero()
155 } else {
156 self.modulus - a
157 }
158 }
159
160 pub fn eq<T: IntoU256>(&self, a: T, b: T) -> bool {
162 a.into_u256() % self.modulus == b.into_u256() % self.modulus
163 }
164
165 pub fn square<T: IntoU256>(&self, a: T) -> U256 {
167 let a = a.into_u256();
168 self.mul(a, a)
169 }
170
171 fn u512_to_u256(result: U512) -> U256 {
172 let mut result_little_endian = [0_u8; 64];
173 result.to_little_endian(&mut result_little_endian);
174 U256::from_little_endian(&result_little_endian[..32])
175 }
176
177 pub fn sqrt<T: IntoU256>(&self, a: T) -> Option<U256> {
180
181 let a = a.into_u256();
182
183 if self.modulus % U256::from(4) == U256::from(3) { let exponent = Self::floor_div(self.modulus + U256::one(), U256::from(4));
185 return Some(self.exp(a, exponent));
186 } else {
187 return self.tonelli_shanks(a);
189 }
190 }
191
192 fn floor_div(a: U256, b: U256) -> U256 {
193 assert!(b != U256::zero(), "Division by zero error");
194 let div = a / b;
195 if a % b != U256::zero() && (a < U256::zero()) != (b < U256::zero()) {
196 div - U256::one()
197 } else {
198 div
199 }
200 }
201
202 fn gcd(a: U256, b: U256) -> U256 {
204 if b == U256::zero() {
205 return a;
206 } else {
207 return Self::gcd(b, a % b)
208 }
209 }
210
211 fn order(&self, a: U256) -> Option<U256> {
213 if Self::gcd(a, self.modulus) != U256::one() {
214 return None;
215 }
216
217 let mut k = U256::one();
218 loop {
219 if self.exp(a, k) == U256::one() {
220 return Some(k);
221 }
222 k += U256::one();
223 }
224 }
225
226 fn convertx2e(mut x: U256) -> (U256, U256) {
227 let mut z = U256::zero();
228 while x % U256::from(2) == U256::zero() {
229 x = x / U256::from(2);
230 z += U256::one();
231 }
232 (x, z)
233 }
234
235 fn legendre_symbol(&self, a: U256) -> i32 {
236 let exponent = (self.modulus - U256::one()) / U256::from(2);
237 let result = self.exp(a, exponent);
238
239 if result == U256::one() {
240 1
241 } else if result == U256::zero() {
242 0
243 } else {
244 -1
245 }
246 }
247
248 fn tonelli_shanks(&self, a: U256) -> Option<U256> {
249
250 if self.modulus == U256::from(2) {
251 return Some(a)
252 }
253
254 if Self::gcd(a, self.modulus) != U256::one() {
255 return None
256 }
257
258 match self.legendre_symbol(a) {
259 -1 => return None,
260 0 => return Some(U256::zero()),
261 _ => (),
262 }
263
264 let (s, e) = Self::convertx2e(self.modulus - U256::one());
265 let mut q = U256::from(2);
266
267 loop {
268 let exponent = (self.modulus - U256::one()) / U256::from(2);
269 if self.exp(q, exponent) == self.modulus - U256::one() {
270 break;
271 }
272 q += U256::one();
273 }
274
275 let exp_a = (s + U256::one()) / U256::from(2);
276 let mut x = self.exp(a, exp_a);
277 let mut b = self.exp(a, s);
278 let mut g = self.exp(q, s);
279
280 let mut r = e;
281
282 loop {
283 let mut m = U256::zero();
284
285 while (m < r) {
286 if self.order(b).is_none() {
287 return None
288 }
289
290 if self.order(b).unwrap() == U256::from(2).pow(m) {
291 break;
292 }
293 m += U256::one();
294 }
295
296 if m == U256::zero() {
297 return Some(x);
298 }
299
300 let exp_x = self.exp(U256::from(2), r - m - U256::one());
301 x = self.mul(x, self.exp(g, exp_x));
302
303 let exp_g = self.exp(U256::from(2), r - m);
304 g = self.exp(g, exp_g);
305 b = self.mul(b, g);
306
307 if b == U256::one() {
308 return Some(x);
309 }
310 r = m;
311 }
312
313
314 }
315
316
317}
318
319
320pub trait IntoU256 {
321 fn into_u256(self) -> U256;
322}
323
324impl IntoU256 for u32 {
325 fn into_u256(self) -> U256 {
326 U256::from(self)
327 }
328}
329
330impl IntoU256 for i32 {
331 fn into_u256(self) -> U256 {
332 if self < 0 {
333 panic!("Negative value cannot be converted to U256");
334 }
335 U256::from(self as u32) }
337}
338
339impl IntoU256 for u64 {
340 fn into_u256(self) -> U256 {
341 U256::from(self)
342 }
343}
344
345impl IntoU256 for i64 {
346 fn into_u256(self) -> U256 {
347 if self < 0 {
348 panic!("Negative value cannot be converted to U256");
349 }
350 U256::from(self as u64) }
352}
353
354impl IntoU256 for &str {
355 fn into_u256(self) -> U256 {
356 U256::from_dec_str(self).unwrap()
357 }
358}
359
360impl IntoU256 for U256 {
361 fn into_u256(self) -> U256 {
362 self
363 }
364}