// modular_math/mod_math/mod_math.rs

1use primitive_types::{U256, U512};
2
3/// `ModMath` is a struct that provides modular arithmetic operations.
4///
5/// It operates on unsigned 256-bit integers (`U256`) and performs operations under a given modulus.
6/// The modulus is provided when creating a new `ModMath` instance and cannot be zero.
7pub struct ModMath {
8    modulus: U256,
9}
10
11impl ModMath {
12    /// Creates a new `ModMath` instance with the given modulus.
13    ///
14    /// # Panics
15    ///
16    /// Panics if the modulus is zero.
17    pub fn new<T: IntoU256>(modulus: T) -> Self {
18        let modulus = modulus.into_u256();
19        if modulus == U256::zero() {
20            panic!("Modulus Cannot be Zero");
21        }
22        ModMath {
23            modulus
24        }
25    }
26
27    pub fn modulus<T: IntoU256>(&self, a: T) -> U256 {
28        a.into_u256() % self.modulus
29    }
30
31    /// Adds two `U256` numbers under the modulus.
32    pub fn add<T: IntoU256>(&self, a: T, b: T) -> U256 {
33        let a = a.into_u256();
34        let b = b.into_u256();
35        match a.checked_add(b) {
36            Some(sum) => sum % self.modulus,
37            None => {
38                let a_512 = U512::from(a);
39                let b_512 = U512::from(b);
40                let modulus_512 = U512::from(self.modulus);
41                let result = (a_512 + b_512) % modulus_512;
42
43                ModMath::u512_to_u256(result)
44            }
45        }
46    }
47
48    /// Subtracts the second `U256` number from the first one under the modulus.
49    pub fn sub<T: IntoU256>(&self, a: T, b: T) -> U256 {
50        let a = a.into_u256();
51        let b = b.into_u256();
52        if b > a {
53            // (self.modulus + a - b) % self.modulus
54            match self.modulus.checked_add(a) {
55                Some(sum) => (sum - b) % self.modulus,
56                None => {
57                    let a_512 = U512::from(a);
58                    let b_512 = U512::from(b);
59                    let modulus_512 = U512::from(self.modulus);
60                    let result = (modulus_512 + a_512 - b_512) % modulus_512;
61
62                    ModMath::u512_to_u256(result)
63                }
64            }
65        } else {
66            (a - b) % self.modulus
67        }
68    }
69
70    /// Multiplies two `U256` numbers under the modulus.
71    pub fn mul<T: IntoU256>(&self, a: T, b: T) -> U256 {
72        let a_mod = a.into_u256() % self.modulus;
73        let b_mod = b.into_u256() % self.modulus;
74    
75        // Use checked_mul for safe multiplication
76        match a_mod.checked_mul(b_mod) {
77            Some(product) => product % self.modulus,
78            None => {
79                let a_mod_u512 = U512::from(a_mod);
80                let b_mod_u512 = U512::from(b_mod);
81                let result  = a_mod_u512 * b_mod_u512 % U512::from(self.modulus);
82
83                ModMath::u512_to_u256(result)
84            },
85        }
86    }
87    
88
89    /// Raises the base to the power of the exponent under the modulus.
90    pub fn exp<T: IntoU256>(&self, base: T, exponent: T) -> U256 {
91        let mut result = U256::one();
92        let mut base = base.into_u256() % self.modulus;
93        let mut exponent = exponent.into_u256();
94        while exponent != U256::zero() {
95            if exponent % U256::from(2) != U256::zero() {
96                result = self.mul(result, base)
97            }
98            base = self.square(base);
99            exponent /= U256::from(2);
100        }
101        result
102    }
103
104    /// Calculates the modular multiplicative inverse of a `U256` number under the modulus.
105    ///
106    /// Returns `None` if the inverse does not exist.
107    pub fn inv<T: IntoU256>(&self, a: T) -> Option<U256> {
108        let (mut m, mut x0, mut x1) = (self.modulus, U256::zero(), U256::one());
109        let mut a = a.into_u256() % self.modulus;
110        if self.modulus == U256::one() {
111            return None;
112        }
113    
114        while a > U256::one() {
115            let q = a / m;
116            let mut temp = m;
117    
118            m = a % m;
119            a = temp;
120            temp = x0;
121            let t = self.mul(q, x0);
122            x0 = self.sub(x1, t);
123            x1 = temp;
124        }
125    
126        if x1 < U256::zero() {
127            x1 = self.add(x1, self.modulus);
128        }
129    
130        if a != U256::one() {
131            None
132        } else {
133            Some(x1)
134        }
135    }
136
137    /// Divides the first `U256` number by the second one under the modulus.
138    ///
139    /// # Panics
140    ///
141    /// Panics if the second number is zero or if its inverse does not exist under the modulus.
142    pub fn div<T: IntoU256>(&self, a: T, b: T) -> U256 {
143        let b = b.into_u256();
144        let b_inv = self.inv(b).unwrap_or_else(|| {
145            panic!("Cannot find Inverse of {}", b);
146        });
147         self.mul(a.into_u256(), b_inv)
148    }
149
150    /// Calculates the additive inverse of a given `U256` under modulus
151    pub fn add_inv<T: IntoU256>(&self, a: T) -> U256 {
152      let a = a.into_u256();
153      if a == U256::zero() {
154        U256::zero()
155      } else {
156        self.modulus - a
157      }
158    }
159    
160    /// Checks if two `U256` numbers are equivalent under the modulus.
161    pub fn eq<T: IntoU256>(&self, a: T, b: T) -> bool {
162        a.into_u256() % self.modulus == b.into_u256() % self.modulus
163    }
164
165    /// Squares a given U256 number under modulus
166    pub fn square<T: IntoU256>(&self, a: T) -> U256 {
167        let a = a.into_u256();
168        self.mul(a, a)
169    }
170
171    fn u512_to_u256(result: U512) -> U256 {
172        let mut result_little_endian = [0_u8; 64];
173        result.to_little_endian(&mut result_little_endian);
174        U256::from_little_endian(&result_little_endian[..32])
175    }
176
177    /// Find the square root of a given `U256` under modulus using tonelli-shanks algorithm
178    /// returns None if no sqrt exists
179    pub fn sqrt<T: IntoU256>(&self, a: T) -> Option<U256> {
180       
181       let a = a.into_u256();
182
183       if self.modulus % U256::from(4) == U256::from(3) { // p = 4k + 3
184        let exponent = Self::floor_div(self.modulus + U256::one(), U256::from(4));
185        return Some(self.exp(a, exponent));
186       } else {
187        // Tonelli Shanks Algorithm
188        return self.tonelli_shanks(a);
189       }
190    }
191
192    fn floor_div(a: U256, b: U256) -> U256 {
193        assert!(b != U256::zero(), "Division by zero error");
194        let div = a / b;
195        if a % b != U256::zero() && (a < U256::zero()) != (b < U256::zero()) {
196            div - U256::one()
197        } else {
198            div
199        }
200    }
201
202    // utility function to find gcd 
203    fn gcd(a: U256, b: U256) -> U256 {
204        if b == U256::zero() {
205            return a;
206        } else {
207            return Self::gcd(b, a % b)
208        }
209    }
210
211    // Returns k such that a^k = 1 (mod p)
212    fn order(&self, a: U256) -> Option<U256> {
213        if Self::gcd(a, self.modulus) != U256::one() {
214            return None;
215        }
216
217        let mut k = U256::one();
218        loop {
219            if self.exp(a, k) == U256::one() {
220                return Some(k);
221            }
222            k += U256::one();
223        }
224    }
225
226    fn convertx2e(mut x: U256) -> (U256, U256) {
227        let mut z = U256::zero();
228        while x % U256::from(2) == U256::zero() {
229            x = x / U256::from(2);
230            z += U256::one();
231        } 
232        (x, z)
233    }
234
235    fn legendre_symbol(&self, a: U256) -> i32 {
236        let exponent = (self.modulus - U256::one()) / U256::from(2);
237        let result = self.exp(a, exponent);
238        
239        if result == U256::one() {
240            1
241        } else if result == U256::zero() {
242            0
243        } else {
244            -1
245        }
246    }
247
248    fn tonelli_shanks(&self, a: U256) -> Option<U256> {
249        
250        if self.modulus == U256::from(2) {
251            return Some(a)
252        }
253
254        if Self::gcd(a, self.modulus) != U256::one() {
255            return None
256        }
257
258        match self.legendre_symbol(a) {
259            -1 => return None,
260            0 => return Some(U256::zero()),
261            _ => (),
262        }
263
264        let (s, e) = Self::convertx2e(self.modulus - U256::one());
265        let mut q = U256::from(2);
266
267        loop {
268            let exponent = (self.modulus - U256::one()) / U256::from(2);
269            if self.exp(q, exponent) == self.modulus - U256::one() {
270                break;
271            }
272            q += U256::one();
273        }
274
275        let exp_a = (s + U256::one()) / U256::from(2);
276        let mut x = self.exp(a, exp_a);
277        let mut b = self.exp(a, s);
278        let mut g = self.exp(q, s);
279
280        let mut r = e;
281
282        loop {
283            let mut m = U256::zero();
284
285            while (m < r) {
286                if self.order(b).is_none() {
287                    return None
288                }
289
290                if self.order(b).unwrap() == U256::from(2).pow(m) {
291                    break;
292                }
293                m += U256::one();
294            }
295
296            if m == U256::zero() {
297                return Some(x);
298            }
299
300            let exp_x = self.exp(U256::from(2), r - m - U256::one());
301            x = self.mul(x, self.exp(g, exp_x));
302            
303            let exp_g = self.exp(U256::from(2), r - m);
304            g = self.exp(g, exp_g);
305            b = self.mul(b, g);
306
307            if b == U256::one() {
308                return Some(x);
309            }
310            r = m;
311        }
312
313
314    }
315
316    
317}
318
319
320pub trait IntoU256 {
321    fn into_u256(self) -> U256;
322}
323
324impl IntoU256 for u32 {
325    fn into_u256(self) -> U256 {
326        U256::from(self)
327    }
328}
329
330impl IntoU256 for i32 {
331    fn into_u256(self) -> U256 {
332        if self < 0 {
333            panic!("Negative value cannot be converted to U256");
334        }
335        U256::from(self as u32)  // Safe cast since the value is non-negative
336    }
337}
338
339impl IntoU256 for u64 {
340    fn into_u256(self) -> U256 {
341        U256::from(self)
342    }
343}
344
345impl IntoU256 for i64 {
346    fn into_u256(self) -> U256 {
347        if self < 0 {
348            panic!("Negative value cannot be converted to U256");
349        }
350        U256::from(self as u64)  // Safe cast since the value is non-negative
351    }
352}
353
354impl IntoU256 for &str {
355    fn into_u256(self) -> U256 {
356        U256::from_dec_str(self).unwrap()
357    }
358}
359
360impl IntoU256 for U256 {
361    fn into_u256(self) -> U256 {
362        self
363    }
364}