Skip to main content

cairo_vm/math_utils/
mod.rs

1mod is_prime;
2
3pub use is_prime::is_prime;
4
5use core::cmp::min;
6
7use crate::types::errors::math_errors::MathError;
8use crate::utils::CAIRO_PRIME;
9use crate::Felt252;
10use lazy_static::lazy_static;
11use num_bigint::{BigInt, BigUint, RandBigInt, ToBigInt};
12use num_integer::Integer;
13use num_traits::{One, Signed, Zero};
14use rand::{rngs::SmallRng, SeedableRng};
15use starknet_types_core::felt::NonZeroFelt;
16use std::ops::Shr;
17
18lazy_static! {
19    pub static ref SIGNED_FELT_MAX: BigUint = (&*CAIRO_PRIME).shr(1_u32);
20    static ref POWERS_OF_TWO: Vec<NonZeroFelt> =
21        core::iter::successors(Some(Felt252::ONE), |x| Some(x * Felt252::TWO))
22            .take(252)
23            .map(|x| x.try_into().unwrap())
24            .collect::<Vec<_>>();
25}
26
/// The Mersenne prime `2^31 - 1`, used as the modulus when canonicalizing
/// QM31 coordinates in `qm31_pack_reduced`.
pub const STWO_PRIME: u32 = u32::MAX >> 1;
28
29/// Packs a QM31 element into a Felt252, canonicalizing M31 coordinates.
30/// Lambdaworks' Mersenne31 arithmetic can represent zero as STWO_PRIME
31/// instead of 0. This function normalizes coordinates before packing
32/// so that `QM31::unpack_from_felt` can accept the result.
33pub(crate) fn qm31_pack_reduced(qm31: starknet_types_core::qm31::QM31) -> Felt252 {
34    let (a, b, c, d) = qm31.to_coefficients();
35    starknet_types_core::qm31::QM31::from_coefficients(
36        a % STWO_PRIME,
37        b % STWO_PRIME,
38        c % STWO_PRIME,
39        d % STWO_PRIME,
40    )
41    .pack_into_felt()
42}
43
44/// Returns the `n`th (up to the `251`th power) power of 2 as a [`Felt252`]
45/// in constant time.
46/// It silently returns `1` if the input is out of bounds.
47pub fn pow2_const(n: u32) -> Felt252 {
48    // If the conversion fails then it's out of range and we compute the power as usual
49    POWERS_OF_TWO
50        .get(n as usize)
51        .unwrap_or(&POWERS_OF_TWO[0])
52        .into()
53}
54
55/// Returns the `n`th (up to the `251`th power) power of 2 as a [`&stark_felt::NonZeroFelt`]
56/// in constant time.
57/// It silently returns `1` if the input is out of bounds.
58pub fn pow2_const_nz(n: u32) -> &'static NonZeroFelt {
59    // If the conversion fails then it's out of range and we compute the power as usual
60    POWERS_OF_TWO.get(n as usize).unwrap_or(&POWERS_OF_TWO[0])
61}
62
63/// Converts [`Felt252`] into a [`BigInt`] number in the range: `(- FIELD / 2, FIELD / 2)`.
64///
65/// # Examples
66///
67/// ```
68/// # use cairo_vm::{Felt252, math_utils::signed_felt};
69/// # use num_bigint::BigInt;
70/// let positive = Felt252::from(5);
71/// assert_eq!(signed_felt(positive), BigInt::from(5));
72///
73/// let negative = Felt252::MAX;
74/// assert_eq!(signed_felt(negative), BigInt::from(-1));
75/// ```
76pub fn signed_felt(felt: Felt252) -> BigInt {
77    let biguint = felt.to_biguint();
78    if biguint > *SIGNED_FELT_MAX {
79        BigInt::from_biguint(num_bigint::Sign::Minus, &*CAIRO_PRIME - &biguint)
80    } else {
81        biguint.to_bigint().expect("cannot fail")
82    }
83}
84
85pub fn signed_felt_for_prime(value: Felt252, prime: &BigUint) -> BigInt {
86    let value = value.to_biguint();
87    let half_prime = prime / 2u32;
88    if value > half_prime {
89        BigInt::from_biguint(num_bigint::Sign::Minus, prime - &value)
90    } else {
91        BigInt::from_biguint(num_bigint::Sign::Plus, value)
92    }
93}
94
95///Returns the integer square root of the nonnegative integer n.
96///This is the floor of the exact square root of n.
97///Unlike math.sqrt(), this function doesn't have rounding error issues.
98pub fn isqrt(n: &BigUint) -> Result<BigUint, MathError> {
99    /*    # The following algorithm was copied from
100    # https://stackoverflow.com/questions/15390807/integer-square-root-in-python.
101    x = n
102    y = (x + 1) // 2
103    while y < x:
104        x = y
105        y = (x + n // x) // 2
106    assert x**2 <= n < (x + 1) ** 2
107    return x*/
108
109    let mut x = n.clone();
110    //n.shr(1) = n.div_floor(2)
111    let mut y = (&x + 1_u32).shr(1_u32);
112
113    while y < x {
114        x = y;
115        y = (&x + n.div_floor(&x)).shr(1_u32);
116    }
117
118    if !(&BigUint::pow(&x, 2_u32) <= n && n < &BigUint::pow(&(&x + 1_u32), 2_u32)) {
119        return Err(MathError::FailedToGetSqrt(Box::new(n.clone())));
120    };
121    Ok(x)
122}
123
124/// Performs integer division between x and y; fails if x is not divisible by y.
125pub fn safe_div(x: &Felt252, y: &Felt252) -> Result<Felt252, MathError> {
126    let (q, r) = x.div_rem(&y.try_into().map_err(|_| MathError::DividedByZero)?);
127
128    if !r.is_zero() {
129        Err(MathError::SafeDivFail(Box::new((*x, *y))))
130    } else {
131        Ok(q)
132    }
133}
134
135/// Performs integer division between x and y; fails if x is not divisible by y.
136pub fn safe_div_bigint(x: &BigInt, y: &BigInt) -> Result<BigInt, MathError> {
137    if y.is_zero() {
138        return Err(MathError::DividedByZero);
139    }
140
141    let (q, r) = x.div_mod_floor(y);
142
143    if !r.is_zero() {
144        return Err(MathError::SafeDivFailBigInt(Box::new((
145            x.clone(),
146            y.clone(),
147        ))));
148    }
149
150    Ok(q)
151}
152
153/// Performs integer division between x and y; fails if x is not divisible by y.
154pub fn safe_div_usize(x: usize, y: usize) -> Result<usize, MathError> {
155    if y.is_zero() {
156        return Err(MathError::DividedByZero);
157    }
158
159    let (q, r) = x.div_mod_floor(&y);
160
161    if !r.is_zero() {
162        return Err(MathError::SafeDivFailUsize(Box::new((x, y))));
163    }
164
165    Ok(q)
166}
167
168///Returns num_a^-1 mod p
169pub(crate) fn mul_inv(num_a: &BigInt, p: &BigInt) -> BigInt {
170    if num_a.is_zero() {
171        return BigInt::zero();
172    }
173    let mut a = num_a.abs();
174    let x_sign = num_a.signum();
175    let mut b = p.abs();
176    let (mut x, mut r) = (BigInt::one(), BigInt::zero());
177    let (mut c, mut q);
178    while !b.is_zero() {
179        (q, c) = a.div_mod_floor(&b);
180        x -= &q * &r;
181        (a, b, r, x) = (b, c, x, r)
182    }
183
184    x * x_sign
185}
186
187///Returns x, y, g such that g = x*a + y*b = gcd(a, b).
188fn igcdex(num_a: &BigInt, num_b: &BigInt) -> (BigInt, BigInt, BigInt) {
189    match (num_a, num_b) {
190        (a, b) if a.is_zero() && b.is_zero() => (BigInt::zero(), BigInt::one(), BigInt::zero()),
191        (a, _) if a.is_zero() => (BigInt::zero(), num_b.signum(), num_b.abs()),
192        (_, b) if b.is_zero() => (num_a.signum(), BigInt::zero(), num_a.abs()),
193        _ => {
194            let mut a = num_a.abs();
195            let x_sign = num_a.signum();
196            let mut b = num_b.abs();
197            let y_sign = num_b.signum();
198            let (mut x, mut y, mut r, mut s) =
199                (BigInt::one(), BigInt::zero(), BigInt::zero(), BigInt::one());
200            let (mut c, mut q);
201            while !b.is_zero() {
202                (q, c) = a.div_mod_floor(&b);
203                x -= &q * &r;
204                y -= &q * &s;
205                (a, b, r, s, x, y) = (b, c, x, y, r, s)
206            }
207            (x * x_sign, y * y_sign, a)
208        }
209    }
210}
211
212///Finds a nonnegative integer x < p such that (m * x) % p == n.
213pub fn div_mod(n: &BigInt, m: &BigInt, p: &BigInt) -> Result<BigInt, MathError> {
214    let (a, _, c) = igcdex(m, p);
215    if !c.is_one() {
216        return Err(MathError::DivModIgcdexNotZero(Box::new((
217            n.clone(),
218            m.clone(),
219            p.clone(),
220        ))));
221    }
222    Ok((n * a).mod_floor(p))
223}
224
225pub(crate) fn div_mod_unsigned(
226    n: &BigUint,
227    m: &BigUint,
228    p: &BigUint,
229) -> Result<BigUint, MathError> {
230    // BigUint to BigInt conversion cannot fail & div_mod will always return a positive value if all values are positive so we can safely unwrap here
231    div_mod(
232        &n.to_bigint().unwrap(),
233        &m.to_bigint().unwrap(),
234        &p.to_bigint().unwrap(),
235    )
236    .map(|i| i.to_biguint().unwrap())
237}
238
239pub fn ec_add(
240    point_a: (BigInt, BigInt),
241    point_b: (BigInt, BigInt),
242    prime: &BigInt,
243) -> Result<(BigInt, BigInt), MathError> {
244    let m = line_slope(&point_a, &point_b, prime)?;
245    let x = (m.clone() * m.clone() - point_a.0.clone() - point_b.0).mod_floor(prime);
246    let y = (m * (point_a.0 - x.clone()) - point_a.1).mod_floor(prime);
247    Ok((x, y))
248}
249
250/// Computes the slope of the line connecting the two given EC points over the field GF(p).
251/// Assumes the points are given in affine form (x, y) and have different x coordinates.
252pub fn line_slope(
253    point_a: &(BigInt, BigInt),
254    point_b: &(BigInt, BigInt),
255    prime: &BigInt,
256) -> Result<BigInt, MathError> {
257    debug_assert!(!(&point_a.0 - &point_b.0).is_multiple_of(prime));
258    div_mod(
259        &(&point_a.1 - &point_b.1),
260        &(&point_a.0 - &point_b.0),
261        prime,
262    )
263}
264
265///  Doubles a point on an elliptic curve with the equation y^2 = x^3 + alpha*x + beta mod p.
266/// Assumes the point is given in affine form (x, y) and has y != 0.
267pub fn ec_double(
268    point: (BigInt, BigInt),
269    alpha: &BigInt,
270    prime: &BigInt,
271) -> Result<(BigInt, BigInt), MathError> {
272    let m = ec_double_slope(&point, alpha, prime)?;
273    let x = ((&m * &m) - (2_i32 * &point.0)).mod_floor(prime);
274    let y = (m * (point.0 - &x) - point.1).mod_floor(prime);
275    Ok((x, y))
276}
277/// Computes the slope of an elliptic curve with the equation y^2 = x^3 + alpha*x + beta mod p, at
278/// the given point.
279/// Assumes the point is given in affine form (x, y) and has y != 0.
280pub fn ec_double_slope(
281    point: &(BigInt, BigInt),
282    alpha: &BigInt,
283    prime: &BigInt,
284) -> Result<BigInt, MathError> {
285    debug_assert!(!point.1.is_multiple_of(prime));
286    div_mod(
287        &(3_i32 * &point.0 * &point.0 + alpha),
288        &(2_i32 * &point.1),
289        prime,
290    )
291}
292
293// Adapted from sympy _sqrt_prime_power with k == 1
294pub fn sqrt_prime_power(a: &BigUint, p: &BigUint) -> Option<BigUint> {
295    if p.is_zero() || !is_prime(p) {
296        return None;
297    }
298    let two = BigUint::from(2_u32);
299    let a = a.mod_floor(p);
300    if p == &two {
301        return Some(a);
302    }
303    if !(a < two || (a.modpow(&(p - 1_u32).div_floor(&two), p)).is_one()) {
304        return None;
305    };
306
307    if p.mod_floor(&BigUint::from(4_u32)) == 3_u32.into() {
308        let res = a.modpow(&(p + 1_u32).div_floor(&BigUint::from(4_u32)), p);
309        return Some(min(res.clone(), p - res));
310    };
311
312    if p.mod_floor(&BigUint::from(8_u32)) == 5_u32.into() {
313        let sign = a.modpow(&(p - 1_u32).div_floor(&BigUint::from(4_u32)), p);
314        if sign.is_one() {
315            let res = a.modpow(&(p + 3_u32).div_floor(&BigUint::from(8_u32)), p);
316            return Some(min(res.clone(), p - res));
317        } else {
318            let b = (4_u32 * &a).modpow(&(p - 5_u32).div_floor(&BigUint::from(8_u32)), p);
319            let x = (2_u32 * &a * b).mod_floor(p);
320            if x.modpow(&two, p) == a {
321                return Some(x);
322            }
323        }
324    };
325
326    Some(sqrt_tonelli_shanks(&a, p))
327}
328
329fn sqrt_tonelli_shanks(n: &BigUint, prime: &BigUint) -> BigUint {
330    // Based on Tonelli-Shanks' algorithm for finding square roots
331    // and sympy's library implementation of said algorithm.
332    if n.is_zero() || n.is_one() {
333        return n.clone();
334    }
335    let s = (prime - 1_u32).trailing_zeros().unwrap_or_default();
336    let t = prime >> s;
337    let a = n.modpow(&t, prime);
338    // Rng is not critical here so its safe to use a seeded value
339    let mut rng = SmallRng::seed_from_u64(11480028852697973135);
340    let mut d;
341    loop {
342        d = RandBigInt::gen_biguint_range(&mut rng, &BigUint::from(2_u32), &(prime - 1_u32));
343        let r = legendre_symbol(&d, prime);
344        if r == -1 {
345            break;
346        };
347    }
348    d = d.modpow(&t, prime);
349    let mut m = BigUint::zero();
350    let mut exponent = BigUint::one() << (s - 1);
351    let mut adm;
352    for i in 0..s as u32 {
353        adm = &a * &d.modpow(&m, prime);
354        adm = adm.modpow(&exponent, prime);
355        exponent >>= 1;
356        if adm == (prime - 1_u32) {
357            m += BigUint::from(1_u32) << i;
358        }
359    }
360    let root_1 =
361        (n.modpow(&((t + 1_u32) >> 1), prime) * d.modpow(&(m >> 1), prime)).mod_floor(prime);
362    let root_2 = prime - &root_1;
363    if root_1 < root_2 {
364        root_1
365    } else {
366        root_2
367    }
368}
369
370/* Disclaimer: Some asumptions have been taken based on the functions that rely on this function, make sure these are true before calling this function individually
371Adpted from sympy implementation, asuming:
372    - p is an odd prime number
373    - a.mod_floor(p) == a
374Returns the Legendre symbol `(a / p)`.
375
376    For an integer ``a`` and an odd prime ``p``, the Legendre symbol is
377    defined as
378
379    .. math ::
380        \genfrac(){}{}{a}{p} = \begin{cases}
381             0 & \text{if } p \text{ divides } a\\
382             1 & \text{if } a \text{ is a quadratic residue modulo } p\\
383            -1 & \text{if } a \text{ is a quadratic nonresidue modulo } p
384        \end{cases}
385*/
386fn legendre_symbol(a: &BigUint, p: &BigUint) -> i8 {
387    if a.is_zero() {
388        return 0;
389    };
390    if is_quad_residue(a, p).unwrap_or_default() {
391        1
392    } else {
393        -1
394    }
395}
396
397// Ported from sympy implementation
398// Simplified as a & p are nonnegative
399// Asumes p is a prime number
400pub(crate) fn is_quad_residue(a: &BigUint, p: &BigUint) -> Result<bool, MathError> {
401    if p.is_zero() {
402        return Err(MathError::IsQuadResidueZeroPrime);
403    }
404    let a = if a >= p { a.mod_floor(p) } else { a.clone() };
405    if a < BigUint::from(2_u8) || p < &BigUint::from(3_u8) {
406        return Ok(true);
407    }
408    Ok(
409        a.modpow(&(p - BigUint::one()).div_floor(&BigUint::from(2_u8)), p)
410            .is_one(),
411    )
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test_utils::*;
    use crate::utils::CAIRO_PRIME;
    use assert_matches::assert_matches;

    use num_traits::Num;

    use num_prime::RandPrime;

    use proptest::prelude::*;

    // Only used in proptest for now
    use num_bigint::Sign;

    /// Parses the prime from `PRIME_STR` as hexadecimal, skipping its first
    /// two characters.
    fn prime_from_hex_str() -> BigInt {
        BigInt::from_str_radix(&crate::utils::PRIME_STR[2..], 16).expect("Couldn't parse prime")
    }

    /// Returns `CAIRO_PRIME` converted into a signed big integer.
    fn cairo_prime() -> BigInt {
        (*CAIRO_PRIME).clone().into()
    }

    #[test]
    fn calculate_divmod_a() {
        let dividend = bigint_str!(
            "11260647941622813594563746375280766662237311019551239924981511729608487775604310196863705127454617186486639011517352066501847110680463498585797912894788"
        );
        let divisor = bigint_str!(
            "4020711254448367604954374443741161860304516084891705811279711044808359405970"
        );
        let expected = bigint_str!(
            "2904750555256547440469454488220756360634457312540595732507835416669695939476"
        );
        let prime = prime_from_hex_str();
        assert_eq!(expected, div_mod(&dividend, &divisor, &prime).unwrap());
    }

    #[test]
    fn calculate_divmod_b() {
        let dividend = bigint_str!(
            "29642372811668969595956851264770043260610851505766181624574941701711520154703788233010819515917136995474951116158286220089597404329949295479559895970988"
        );
        let divisor = bigint_str!(
            "3443173965374276972000139705137775968422921151703548011275075734291405722262"
        );
        let expected = bigint_str!(
            "3601388548860259779932034493250169083811722919049731683411013070523752439691"
        );
        let prime = prime_from_hex_str();
        assert_eq!(expected, div_mod(&dividend, &divisor, &prime).unwrap());
    }

    #[test]
    fn calculate_divmod_c() {
        let dividend = bigint_str!(
            "1208267356464811040667664150251401430616174694388968865551115897173431833224432165394286799069453655049199580362994484548890574931604445970825506916876"
        );
        let divisor = bigint_str!(
            "1809792356889571967986805709823554331258072667897598829955472663737669990418"
        );
        let expected = bigint_str!(
            "1545825591488572374291664030703937603499513742109806697511239542787093258962"
        );
        let prime = prime_from_hex_str();
        assert_eq!(expected, div_mod(&dividend, &divisor, &prime).unwrap());
    }

    #[test]
    fn compute_safe_div() {
        let dividend = Felt252::from(26);
        let divisor = Felt252::from(13);
        assert_matches!(safe_div(&dividend, &divisor), Ok(i) if i == Felt252::from(2));
    }

    #[test]
    fn compute_safe_div_non_divisor() {
        let dividend = Felt252::from(25);
        let divisor = Felt252::from(4);
        let outcome = safe_div(&dividend, &divisor);
        assert_matches!(
            outcome,
            Err(MathError::SafeDivFail(bx)) if *bx == (Felt252::from(25), Felt252::from(4)));
    }

    #[test]
    fn compute_safe_div_by_zero() {
        let dividend = Felt252::from(25);
        let divisor = Felt252::ZERO;
        let outcome = safe_div(&dividend, &divisor);
        assert_matches!(outcome, Err(MathError::DividedByZero));
    }

    #[test]
    fn compute_safe_div_usize() {
        assert_matches!(safe_div_usize(26, 13), Ok(2));
    }

    #[test]
    fn compute_safe_div_usize_non_divisor() {
        assert_matches!(
            safe_div_usize(25, 4),
            Err(MathError::SafeDivFailUsize(bx)) if *bx == (25, 4)
        );
    }

    #[test]
    fn compute_safe_div_usize_by_zero() {
        assert_matches!(safe_div_usize(25, 0), Err(MathError::DividedByZero));
    }

    #[test]
    fn compute_line_slope_for_valid_points() {
        let first = (
            bigint_str!(
                "3139037544796708144595053687182055617920475701120786241351436619796497072089"
            ),
            bigint_str!(
                "2119589567875935397690285099786081818522144748339117565577200220779667999801"
            ),
        );
        let second = (
            bigint_str!(
                "3324833730090626974525872402899302150520188025637965566623476530814354734325"
            ),
            bigint_str!(
                "3147007486456030910661996439995670279305852583596209647900952752170983517249"
            ),
        );
        let expected = bigint_str!(
            "992545364708437554384321881954558327331693627531977596999212637460266617010"
        );
        let prime = cairo_prime();
        assert_eq!(expected, line_slope(&first, &second, &prime).unwrap());
    }

    #[test]
    fn compute_double_slope_for_valid_point_a() {
        let point = (
            bigint_str!(
                "3143372541908290873737380228370996772020829254218248561772745122290262847573"
            ),
            bigint_str!(
                "1721586982687138486000069852568887984211460575851774005637537867145702861131"
            ),
        );
        let expected = bigint_str!(
            "3601388548860259779932034493250169083811722919049731683411013070523752439691"
        );
        let prime = cairo_prime();
        let alpha = bigint!(1);
        assert_eq!(expected, ec_double_slope(&point, &alpha, &prime).unwrap());
    }

    #[test]
    fn compute_double_slope_for_valid_point_b() {
        let point = (
            bigint_str!(
                "1937407885261715145522756206040455121546447384489085099828343908348117672673"
            ),
            bigint_str!(
                "2010355627224183802477187221870580930152258042445852905639855522404179702985"
            ),
        );
        let expected = bigint_str!(
            "2904750555256547440469454488220756360634457312540595732507835416669695939476"
        );
        let prime = cairo_prime();
        let alpha = bigint!(1);
        assert_eq!(expected, ec_double_slope(&point, &alpha, &prime).unwrap());
    }

    #[test]
    fn calculate_ec_double_for_valid_point_a() {
        let point = (
            bigint_str!(
                "1937407885261715145522756206040455121546447384489085099828343908348117672673"
            ),
            bigint_str!(
                "2010355627224183802477187221870580930152258042445852905639855522404179702985"
            ),
        );
        let expected = (
            bigint_str!(
                "58460926014232092148191979591712815229424797874927791614218178721848875644"
            ),
            bigint_str!(
                "1065613861227134732854284722490492186040898336012372352512913425790457998694"
            ),
        );
        let prime = cairo_prime();
        let alpha = bigint!(1);
        assert_eq!(expected, ec_double(point, &alpha, &prime).unwrap());
    }

    #[test]
    fn calculate_ec_double_for_valid_point_b() {
        let point = (
            bigint_str!(
                "3143372541908290873737380228370996772020829254218248561772745122290262847573"
            ),
            bigint_str!(
                "1721586982687138486000069852568887984211460575851774005637537867145702861131"
            ),
        );
        let expected = (
            bigint_str!(
                "1937407885261715145522756206040455121546447384489085099828343908348117672673"
            ),
            bigint_str!(
                "2010355627224183802477187221870580930152258042445852905639855522404179702985"
            ),
        );
        let prime = cairo_prime();
        let alpha = bigint!(1);
        assert_eq!(expected, ec_double(point, &alpha, &prime).unwrap());
    }

    #[test]
    fn calculate_ec_double_for_valid_point_c() {
        let point = (
            bigint_str!(
                "634630432210960355305430036410971013200846091773294855689580772209984122075"
            ),
            bigint_str!(
                "904896178444785983993402854911777165629036333948799414977736331868834995209"
            ),
        );
        let expected = (
            bigint_str!(
                "3143372541908290873737380228370996772020829254218248561772745122290262847573"
            ),
            bigint_str!(
                "1721586982687138486000069852568887984211460575851774005637537867145702861131"
            ),
        );
        let prime = cairo_prime();
        let alpha = bigint!(1);
        assert_eq!(expected, ec_double(point, &alpha, &prime).unwrap());
    }

    #[test]
    fn calculate_ec_add_for_valid_points_a() {
        let first = (
            bigint_str!(
                "1183418161532233795704555250127335895546712857142554564893196731153957537489"
            ),
            bigint_str!(
                "1938007580204102038458825306058547644691739966277761828724036384003180924526"
            ),
        );
        let second = (
            bigint_str!(
                "1977703130303461992863803129734853218488251484396280000763960303272760326570"
            ),
            bigint_str!(
                "2565191853811572867032277464238286011368568368717965689023024980325333517459"
            ),
        );
        let expected = (
            bigint_str!(
                "1977874238339000383330315148209250828062304908491266318460063803060754089297"
            ),
            bigint_str!(
                "2969386888251099938335087541720168257053975603483053253007176033556822156706"
            ),
        );
        let prime = cairo_prime();
        assert_eq!(expected, ec_add(first, second, &prime).unwrap());
    }

    #[test]
    fn calculate_ec_add_for_valid_points_b() {
        let first = (
            bigint_str!(
                "3139037544796708144595053687182055617920475701120786241351436619796497072089"
            ),
            bigint_str!(
                "2119589567875935397690285099786081818522144748339117565577200220779667999801"
            ),
        );
        let second = (
            bigint_str!(
                "3324833730090626974525872402899302150520188025637965566623476530814354734325"
            ),
            bigint_str!(
                "3147007486456030910661996439995670279305852583596209647900952752170983517249"
            ),
        );
        let expected = (
            bigint_str!(
                "1183418161532233795704555250127335895546712857142554564893196731153957537489"
            ),
            bigint_str!(
                "1938007580204102038458825306058547644691739966277761828724036384003180924526"
            ),
        );
        let prime = cairo_prime();
        assert_eq!(expected, ec_add(first, second, &prime).unwrap());
    }

    #[test]
    fn calculate_ec_add_for_valid_points_c() {
        let first = (
            bigint_str!(
                "1183418161532233795704555250127335895546712857142554564893196731153957537489"
            ),
            bigint_str!(
                "1938007580204102038458825306058547644691739966277761828724036384003180924526"
            ),
        );
        let second = (
            bigint_str!(
                "1977703130303461992863803129734853218488251484396280000763960303272760326570"
            ),
            bigint_str!(
                "2565191853811572867032277464238286011368568368717965689023024980325333517459"
            ),
        );
        let expected = (
            bigint_str!(
                "1977874238339000383330315148209250828062304908491266318460063803060754089297"
            ),
            bigint_str!(
                "2969386888251099938335087541720168257053975603483053253007176033556822156706"
            ),
        );
        let prime = cairo_prime();
        assert_eq!(expected, ec_add(first, second, &prime).unwrap());
    }

    #[test]
    fn calculate_isqrt_a() {
        let square = biguint!(81);
        assert_matches!(isqrt(&square), Ok(x) if x == biguint!(9));
    }

    #[test]
    fn calculate_isqrt_b() {
        let root = biguint_str!("4573659632505831259480");
        assert_matches!(isqrt(&BigUint::pow(&root, 2_u32)), Ok(num) if num == root);
    }

    #[test]
    fn calculate_isqrt_c() {
        let root = biguint_str!(
            "3618502788666131213697322783095070105623107215331596699973092056135872020481"
        );
        assert_matches!(isqrt(&BigUint::pow(&root, 2_u32)), Ok(inner) if inner == root);
    }

    #[test]
    fn calculate_isqrt_zero() {
        let zero = BigUint::zero();
        assert_matches!(isqrt(&zero), Ok(inner) if inner.is_zero());
    }

    #[test]
    fn safe_div_bigint_by_zero() {
        let dividend = BigInt::one();
        let divisor = BigInt::zero();
        assert_matches!(
            safe_div_bigint(&dividend, &divisor),
            Err(MathError::DividedByZero)
        )
    }

    #[test]
    fn test_sqrt_prime_power() {
        let value: BigUint = 25_u32.into();
        let prime: BigUint = 18446744069414584321_u128.into();
        assert_eq!(sqrt_prime_power(&value, &prime), Some(5_u32.into()));
    }

    #[test]
    fn test_sqrt_prime_power_p_is_zero() {
        let value = BigUint::one();
        let prime: BigUint = BigUint::zero();
        assert_eq!(sqrt_prime_power(&value, &prime), None);
    }

    #[test]
    fn test_sqrt_prime_power_non_prime() {
        let modulus: BigUint = BigUint::from_bytes_be(&[
            69, 15, 232, 82, 215, 167, 38, 143, 173, 94, 133, 111, 1, 2, 182, 229, 110, 113, 76, 0,
            47, 110, 148, 109, 6, 133, 27, 190, 158, 197, 168, 219, 165, 254, 81, 53, 25, 34,
        ]);
        let value = BigUint::from_bytes_be(&[
            9, 13, 22, 191, 87, 62, 157, 83, 157, 85, 93, 105, 230, 187, 32, 101, 51, 181, 49, 202,
            203, 195, 76, 193, 149, 78, 109, 146, 240, 126, 182, 115, 161, 238, 30, 118, 157, 252,
        ]);

        assert_eq!(sqrt_prime_power(&value, &modulus), None);
    }

    #[test]
    fn test_sqrt_prime_power_none() {
        let value: BigUint = 10_u32.into();
        let modulus: BigUint = 602_u32.into();
        assert_eq!(sqrt_prime_power(&value, &modulus), None);
    }

    #[test]
    fn test_sqrt_prime_power_prime_two() {
        let value: BigUint = 25_u32.into();
        let prime: BigUint = 2_u32.into();
        assert_eq!(sqrt_prime_power(&value, &prime), Some(BigUint::one()));
    }

    #[test]
    fn test_sqrt_prime_power_prime_mod_8_is_5_sign_not_one() {
        let value: BigUint = 676_u32.into();
        let prime: BigUint = 9956234341095173_u64.into();
        assert_eq!(
            sqrt_prime_power(&value, &prime),
            Some(BigUint::from(9956234341095147_u64))
        );
    }

    #[test]
    fn test_sqrt_prime_power_prime_mod_8_is_5_sign_is_one() {
        let value: BigUint = 130283432663_u64.into();
        let prime: BigUint = 743900351477_u64.into();
        assert_eq!(
            sqrt_prime_power(&value, &prime),
            Some(BigUint::from(123538694848_u64))
        );
    }

    #[test]
    fn test_legendre_symbol_zero() {
        assert!(legendre_symbol(&BigUint::zero(), &BigUint::one()).is_zero())
    }

    #[test]
    fn test_is_quad_residue_prime_zero() {
        assert_eq!(
            is_quad_residue(&BigUint::one(), &BigUint::zero()),
            Err(MathError::IsQuadResidueZeroPrime)
        )
    }

    #[test]
    fn test_is_quad_residue_prime_a_one_true() {
        assert_eq!(is_quad_residue(&BigUint::one(), &BigUint::one()), Ok(true))
    }

    #[test]
    fn mul_inv_0_is_0() {
        let prime = cairo_prime();
        let zero = BigInt::zero();
        let inverse = mul_inv(&zero, &prime);

        assert_eq!(inverse, BigInt::zero());
    }

    #[test]
    fn igcdex_1_1() {
        assert_eq!(
            igcdex(&BigInt::one(), &BigInt::one()),
            (BigInt::zero(), BigInt::one(), BigInt::one())
        )
    }

    #[test]
    fn igcdex_0_0() {
        assert_eq!(
            igcdex(&BigInt::zero(), &BigInt::zero()),
            (BigInt::zero(), BigInt::one(), BigInt::zero())
        )
    }

    #[test]
    fn igcdex_1_0() {
        assert_eq!(
            igcdex(&BigInt::one(), &BigInt::zero()),
            (BigInt::one(), BigInt::zero(), BigInt::one())
        )
    }

    #[test]
    fn igcdex_4_6() {
        assert_eq!(
            igcdex(&BigInt::from(4), &BigInt::from(6)),
            (BigInt::from(-1), BigInt::one(), BigInt::from(2))
        )
    }

    proptest! {

        #[test]
        fn pow2_const_in_range_returns_power_of_2(x in 0..=251u32) {
            prop_assert_eq!(pow2_const(x), Felt252::TWO.pow(x));
        }

        #[test]
        fn pow2_const_oob_returns_1(x in 252u32..) {
            prop_assert_eq!(pow2_const(x), Felt252::ONE);
        }

        #[test]
        fn pow2_const_nz_in_range_returns_power_of_2(x in 0..=251u32) {
            prop_assert_eq!(Felt252::from(pow2_const_nz(x)), Felt252::TWO.pow(x));
        }

        #[test]
        fn pow2_const_nz_oob_returns_1(x in 252u32..) {
            prop_assert_eq!(Felt252::from(pow2_const_nz(x)), Felt252::ONE);
        }

        #[test]
        // Test for sqrt_prime_power_ of a quadratic residue. Result should be the minimum root.
        fn sqrt_prime_power_using_random_prime(ref x in any::<[u8; 38]>(), ref y in any::<u64>()) {
            let mut rng = SmallRng::seed_from_u64(*y);
            let x = &BigUint::from_bytes_be(x);
            // Generate a prime here instead of relying on y, otherwise y may never be a prime number
            let p : &BigUint = &RandPrime::gen_prime(&mut rng, 384,  None);
            let x_sq = x * x;
            if let Some(sqrt) = sqrt_prime_power(&x_sq, p) {
                // Either the returned root is x itself, or it is the mirrored root p - x.
                let matching_root = if &sqrt == x { sqrt } else { p - sqrt };
                prop_assert_eq!(&matching_root, x);
            }
        }

        #[test]
        fn mul_inv_x_by_x_is_1(ref x in any::<[u8; 32]>()) {
            let p = &cairo_prime();
            let pos_x = &BigInt::from_bytes_be(Sign::Plus, x);
            let neg_x = &BigInt::from_bytes_be(Sign::Minus, x);
            let pos_x_inv = mul_inv(pos_x, p);
            let neg_x_inv = mul_inv(neg_x, p);

            prop_assert_eq!((pos_x * pos_x_inv).mod_floor(p), BigInt::one());
            prop_assert_eq!((neg_x * neg_x_inv).mod_floor(p), BigInt::one());
        }
    }
}
978            prop_assert_eq!((pos_x * pos_x_inv).mod_floor(p), BigInt::one());
979            prop_assert_eq!((neg_x * neg_x_inv).mod_floor(p), BigInt::one());
980        }
981    }
982}