// cairo_vm/math_utils/mod.rs

1mod is_prime;
2
3pub use is_prime::is_prime;
4
5use core::cmp::min;
6
7use crate::stdlib::{boxed::Box, ops::Shr, prelude::Vec};
8use crate::types::errors::math_errors::MathError;
9use crate::utils::CAIRO_PRIME;
10use crate::Felt252;
11use lazy_static::lazy_static;
12use num_bigint::{BigInt, BigUint, RandBigInt, ToBigInt};
13use num_integer::Integer;
14use num_traits::{One, Signed, Zero};
15use rand::{rngs::SmallRng, SeedableRng};
16use starknet_types_core::felt::NonZeroFelt;
17
18lazy_static! {
19    pub static ref SIGNED_FELT_MAX: BigUint = (&*CAIRO_PRIME).shr(1_u32);
20    static ref POWERS_OF_TWO: Vec<NonZeroFelt> =
21        core::iter::successors(Some(Felt252::ONE), |x| Some(x * Felt252::TWO))
22            .take(252)
23            .map(|x| x.try_into().unwrap())
24            .collect::<Vec<_>>();
25}
26
/// Modulus of the Mersenne-31 field used by Stwo: `2^31 - 1`.
pub const STWO_PRIME: u64 = 0x7FFF_FFFF;
/// [`STWO_PRIME`] widened to `u128`, for products that would not fit in `u64`.
const STWO_PRIME_U128: u128 = STWO_PRIME as u128;
/// Mask selecting the low 36 bits (the width of one packed QM31 coordinate).
const MASK_36: u64 = 0xF_FFFF_FFFF;
/// Mask selecting the low 8 bits.
const MASK_8: u64 = 0xFF;
31
32/// Returns the `n`th (up to the `251`th power) power of 2 as a [`Felt252`]
33/// in constant time.
34/// It silently returns `1` if the input is out of bounds.
35pub fn pow2_const(n: u32) -> Felt252 {
36    // If the conversion fails then it's out of range and we compute the power as usual
37    POWERS_OF_TWO
38        .get(n as usize)
39        .unwrap_or(&POWERS_OF_TWO[0])
40        .into()
41}
42
43/// Returns the `n`th (up to the `251`th power) power of 2 as a [`&stark_felt::NonZeroFelt`]
44/// in constant time.
45/// It silently returns `1` if the input is out of bounds.
46pub fn pow2_const_nz(n: u32) -> &'static NonZeroFelt {
47    // If the conversion fails then it's out of range and we compute the power as usual
48    POWERS_OF_TWO.get(n as usize).unwrap_or(&POWERS_OF_TWO[0])
49}
50
51/// Converts [`Felt252`] into a [`BigInt`] number in the range: `(- FIELD / 2, FIELD / 2)`.
52///
53/// # Examples
54///
55/// ```
56/// # use cairo_vm::{Felt252, math_utils::signed_felt};
57/// # use num_bigint::BigInt;
58/// let positive = Felt252::from(5);
59/// assert_eq!(signed_felt(positive), BigInt::from(5));
60///
61/// let negative = Felt252::MAX;
62/// assert_eq!(signed_felt(negative), BigInt::from(-1));
63/// ```
64pub fn signed_felt(felt: Felt252) -> BigInt {
65    let biguint = felt.to_biguint();
66    if biguint > *SIGNED_FELT_MAX {
67        BigInt::from_biguint(num_bigint::Sign::Minus, &*CAIRO_PRIME - &biguint)
68    } else {
69        biguint.to_bigint().expect("cannot fail")
70    }
71}
72
73pub fn signed_felt_for_prime(value: Felt252, prime: &BigUint) -> BigInt {
74    let value = value.to_biguint();
75    let half_prime = prime / 2u32;
76    if value > half_prime {
77        BigInt::from_biguint(num_bigint::Sign::Minus, prime - &value)
78    } else {
79        BigInt::from_biguint(num_bigint::Sign::Plus, value)
80    }
81}
82
83/// QM31 utility function, used specifically for Stwo.
84/// QM31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
85/// Reads four u64 coordinates from a single Felt252.
86/// STWO_PRIME fits in 36 bits, hence each coordinate can be represented by 36 bits and a QM31
87/// element can be stored in the first 144 bits of a Felt252.
88/// Returns an error if the input has over 144 bits or any coordinate is unreduced.
89fn qm31_packed_reduced_read_coordinates(felt: Felt252) -> Result<[u64; 4], MathError> {
90    let limbs = felt.to_le_digits();
91    if limbs[3] != 0 || limbs[2] >= 1 << 16 {
92        return Err(MathError::QM31UnreducedError(Box::new(felt)));
93    }
94    let coordinates = [
95        (limbs[0] & MASK_36),
96        ((limbs[0] >> 36) + ((limbs[1] & MASK_8) << 28)),
97        ((limbs[1] >> 8) & MASK_36),
98        ((limbs[1] >> 44) + (limbs[2] << 20)),
99    ];
100    for x in coordinates.iter() {
101        if *x >= STWO_PRIME {
102            return Err(MathError::QM31UnreducedError(Box::new(felt)));
103        }
104    }
105    Ok(coordinates)
106}
107
108/// QM31 utility function, used specifically for Stwo.
109/// QM31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
110/// Reduces four u64 coordinates and packs them into a single Felt252.
111/// STWO_PRIME fits in 36 bits, hence each coordinate can be represented by 36 bits and a QM31
112/// element can be stored in the first 144 bits of a Felt252.
113pub(crate) fn qm31_coordinates_to_packed_reduced(coordinates: [u64; 4]) -> Felt252 {
114    let bytes_part1 = ((coordinates[0] % STWO_PRIME) as u128
115        + (((coordinates[1] % STWO_PRIME) as u128) << 36))
116        .to_le_bytes();
117    let bytes_part2 = ((coordinates[2] % STWO_PRIME) as u128
118        + (((coordinates[3] % STWO_PRIME) as u128) << 36))
119        .to_le_bytes();
120    let mut result_bytes = [0u8; 32];
121    result_bytes[0..9].copy_from_slice(&bytes_part1[0..9]);
122    result_bytes[9..18].copy_from_slice(&bytes_part2[0..9]);
123    Felt252::from_bytes_le(&result_bytes)
124}
125
126/// QM31 utility function, used specifically for Stwo.
127/// QM31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
128/// Computes the addition of two QM31 elements in reduced form.
129/// Returns an error if either operand is not reduced.
130pub(crate) fn qm31_packed_reduced_add(
131    felt1: Felt252,
132    felt2: Felt252,
133) -> Result<Felt252, MathError> {
134    let coordinates1 = qm31_packed_reduced_read_coordinates(felt1)?;
135    let coordinates2 = qm31_packed_reduced_read_coordinates(felt2)?;
136    let result_unreduced_coordinates = [
137        coordinates1[0] + coordinates2[0],
138        coordinates1[1] + coordinates2[1],
139        coordinates1[2] + coordinates2[2],
140        coordinates1[3] + coordinates2[3],
141    ];
142    Ok(qm31_coordinates_to_packed_reduced(
143        result_unreduced_coordinates,
144    ))
145}
146
147/// QM31 utility function, used specifically for Stwo.
148/// QM31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
149/// Computes the negative of a QM31 element in reduced form.
150/// Returns an error if the input is not reduced.
151#[allow(dead_code)]
152pub(crate) fn qm31_packed_reduced_neg(felt: Felt252) -> Result<Felt252, MathError> {
153    let coordinates = qm31_packed_reduced_read_coordinates(felt)?;
154    Ok(qm31_coordinates_to_packed_reduced([
155        STWO_PRIME - coordinates[0],
156        STWO_PRIME - coordinates[1],
157        STWO_PRIME - coordinates[2],
158        STWO_PRIME - coordinates[3],
159    ]))
160}
161
162/// QM31 utility function, used specifically for Stwo.
163/// QM31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
164/// Computes the subtraction of two QM31 elements in reduced form.
165/// Returns an error if either operand is not reduced.
166pub(crate) fn qm31_packed_reduced_sub(
167    felt1: Felt252,
168    felt2: Felt252,
169) -> Result<Felt252, MathError> {
170    let coordinates1 = qm31_packed_reduced_read_coordinates(felt1)?;
171    let coordinates2 = qm31_packed_reduced_read_coordinates(felt2)?;
172    let result_unreduced_coordinates = [
173        STWO_PRIME + coordinates1[0] - coordinates2[0],
174        STWO_PRIME + coordinates1[1] - coordinates2[1],
175        STWO_PRIME + coordinates1[2] - coordinates2[2],
176        STWO_PRIME + coordinates1[3] - coordinates2[3],
177    ];
178    Ok(qm31_coordinates_to_packed_reduced(
179        result_unreduced_coordinates,
180    ))
181}
182
183/// QM31 utility function, used specifically for Stwo.
184/// QM31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
185/// Computes the multiplication of two QM31 elements in reduced form.
186/// Returns an error if either operand is not reduced.
187pub(crate) fn qm31_packed_reduced_mul(
188    felt1: Felt252,
189    felt2: Felt252,
190) -> Result<Felt252, MathError> {
191    let coordinates1_u64 = qm31_packed_reduced_read_coordinates(felt1)?;
192    let coordinates2_u64 = qm31_packed_reduced_read_coordinates(felt2)?;
193    let coordinates1 = coordinates1_u64.map(u128::from);
194    let coordinates2 = coordinates2_u64.map(u128::from);
195
196    let result_coordinates = [
197        ((5 * STWO_PRIME_U128 * STWO_PRIME_U128 + coordinates1[0] * coordinates2[0]
198            - coordinates1[1] * coordinates2[1]
199            + 2 * coordinates1[2] * coordinates2[2]
200            - 2 * coordinates1[3] * coordinates2[3]
201            - coordinates1[2] * coordinates2[3]
202            - coordinates1[3] * coordinates2[2])
203            % STWO_PRIME_U128) as u64,
204        ((STWO_PRIME_U128 * STWO_PRIME_U128
205            + coordinates1[0] * coordinates2[1]
206            + coordinates1[1] * coordinates2[0]
207            + 2 * (coordinates1[2] * coordinates2[3] + coordinates1[3] * coordinates2[2])
208            + coordinates1[2] * coordinates2[2]
209            - coordinates1[3] * coordinates2[3])
210            % STWO_PRIME_U128) as u64,
211        2 * STWO_PRIME * STWO_PRIME + coordinates1_u64[0] * coordinates2_u64[2]
212            - coordinates1_u64[1] * coordinates2_u64[3]
213            + coordinates1_u64[2] * coordinates2_u64[0]
214            - coordinates1_u64[3] * coordinates2_u64[1],
215        coordinates1_u64[0] * coordinates2_u64[3]
216            + coordinates1_u64[1] * coordinates2_u64[2]
217            + coordinates1_u64[2] * coordinates2_u64[1]
218            + coordinates1_u64[3] * coordinates2_u64[0],
219    ];
220    Ok(qm31_coordinates_to_packed_reduced(result_coordinates))
221}
222
223/// M31 utility function, used specifically for Stwo.
224/// M31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
225/// Computes the inverse in the M31 field using Fermat's little theorem, i.e., returns
226/// `v^(STWO_PRIME-2) modulo STWO_PRIME`, which is the inverse of v unless v % STWO_PRIME == 0.
227pub(crate) fn pow2147483645(v: u64) -> u64 {
228    let t0 = (sqn(v, 2) * v) % STWO_PRIME;
229    let t1 = (sqn(t0, 1) * t0) % STWO_PRIME;
230    let t2 = (sqn(t1, 3) * t0) % STWO_PRIME;
231    let t3 = (sqn(t2, 1) * t0) % STWO_PRIME;
232    let t4 = (sqn(t3, 8) * t3) % STWO_PRIME;
233    let t5 = (sqn(t4, 8) * t3) % STWO_PRIME;
234    (sqn(t5, 7) * t2) % STWO_PRIME
235}
236
237/// M31 utility function, used specifically for Stwo.
238/// M31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
239/// Computes `v^(2^n) modulo STWO_PRIME`.
240fn sqn(v: u64, n: usize) -> u64 {
241    let mut u = v;
242    for _ in 0..n {
243        u = (u * u) % STWO_PRIME;
244    }
245    u
246}
247
248/// QM31 utility function, used specifically for Stwo.
249/// QM31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
250/// Computes the inverse of a QM31 element in reduced form.
251/// Returns an error if the denominator is zero or either operand is not reduced.
252pub(crate) fn qm31_packed_reduced_inv(felt: Felt252) -> Result<Felt252, MathError> {
253    if felt.is_zero() {
254        return Err(MathError::DividedByZero);
255    }
256    let coordinates = qm31_packed_reduced_read_coordinates(felt)?;
257
258    let b2_r = (coordinates[2] * coordinates[2] + STWO_PRIME * STWO_PRIME
259        - coordinates[3] * coordinates[3])
260        % STWO_PRIME;
261    let b2_i = (2 * coordinates[2] * coordinates[3]) % STWO_PRIME;
262
263    let denom_r = (coordinates[0] * coordinates[0] + STWO_PRIME * STWO_PRIME
264        - coordinates[1] * coordinates[1]
265        + 2 * STWO_PRIME
266        - 2 * b2_r
267        + b2_i)
268        % STWO_PRIME;
269    let denom_i =
270        (2 * coordinates[0] * coordinates[1] + 3 * STWO_PRIME - 2 * b2_i - b2_r) % STWO_PRIME;
271
272    let denom_norm_squared = (denom_r * denom_r + denom_i * denom_i) % STWO_PRIME;
273    let denom_norm_inverse_squared = pow2147483645(denom_norm_squared);
274
275    let denom_inverse_r = (denom_r * denom_norm_inverse_squared) % STWO_PRIME;
276    let denom_inverse_i = ((STWO_PRIME - denom_i) * denom_norm_inverse_squared) % STWO_PRIME;
277
278    Ok(qm31_coordinates_to_packed_reduced([
279        coordinates[0] * denom_inverse_r + STWO_PRIME * STWO_PRIME
280            - coordinates[1] * denom_inverse_i,
281        coordinates[0] * denom_inverse_i + coordinates[1] * denom_inverse_r,
282        coordinates[3] * denom_inverse_i + STWO_PRIME * STWO_PRIME
283            - coordinates[2] * denom_inverse_r,
284        2 * STWO_PRIME * STWO_PRIME
285            - coordinates[2] * denom_inverse_i
286            - coordinates[3] * denom_inverse_r,
287    ]))
288}
289
290/// QM31 utility function, used specifically for Stwo.
291/// QM31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
292/// Computes the division of two QM31 elements in reduced form.
293/// Returns an error if the input is zero.
294pub(crate) fn qm31_packed_reduced_div(
295    felt1: Felt252,
296    felt2: Felt252,
297) -> Result<Felt252, MathError> {
298    let felt2_inv = qm31_packed_reduced_inv(felt2)?;
299    qm31_packed_reduced_mul(felt1, felt2_inv)
300}
301
302///Returns the integer square root of the nonnegative integer n.
303///This is the floor of the exact square root of n.
304///Unlike math.sqrt(), this function doesn't have rounding error issues.
305pub fn isqrt(n: &BigUint) -> Result<BigUint, MathError> {
306    /*    # The following algorithm was copied from
307    # https://stackoverflow.com/questions/15390807/integer-square-root-in-python.
308    x = n
309    y = (x + 1) // 2
310    while y < x:
311        x = y
312        y = (x + n // x) // 2
313    assert x**2 <= n < (x + 1) ** 2
314    return x*/
315
316    let mut x = n.clone();
317    //n.shr(1) = n.div_floor(2)
318    let mut y = (&x + 1_u32).shr(1_u32);
319
320    while y < x {
321        x = y;
322        y = (&x + n.div_floor(&x)).shr(1_u32);
323    }
324
325    if !(&BigUint::pow(&x, 2_u32) <= n && n < &BigUint::pow(&(&x + 1_u32), 2_u32)) {
326        return Err(MathError::FailedToGetSqrt(Box::new(n.clone())));
327    };
328    Ok(x)
329}
330
331/// Performs integer division between x and y; fails if x is not divisible by y.
332pub fn safe_div(x: &Felt252, y: &Felt252) -> Result<Felt252, MathError> {
333    let (q, r) = x.div_rem(&y.try_into().map_err(|_| MathError::DividedByZero)?);
334
335    if !r.is_zero() {
336        Err(MathError::SafeDivFail(Box::new((*x, *y))))
337    } else {
338        Ok(q)
339    }
340}
341
342/// Performs integer division between x and y; fails if x is not divisible by y.
343pub fn safe_div_bigint(x: &BigInt, y: &BigInt) -> Result<BigInt, MathError> {
344    if y.is_zero() {
345        return Err(MathError::DividedByZero);
346    }
347
348    let (q, r) = x.div_mod_floor(y);
349
350    if !r.is_zero() {
351        return Err(MathError::SafeDivFailBigInt(Box::new((
352            x.clone(),
353            y.clone(),
354        ))));
355    }
356
357    Ok(q)
358}
359
360/// Performs integer division between x and y; fails if x is not divisible by y.
361pub fn safe_div_usize(x: usize, y: usize) -> Result<usize, MathError> {
362    if y.is_zero() {
363        return Err(MathError::DividedByZero);
364    }
365
366    let (q, r) = x.div_mod_floor(&y);
367
368    if !r.is_zero() {
369        return Err(MathError::SafeDivFailUsize(Box::new((x, y))));
370    }
371
372    Ok(q)
373}
374
375///Returns num_a^-1 mod p
376pub(crate) fn mul_inv(num_a: &BigInt, p: &BigInt) -> BigInt {
377    if num_a.is_zero() {
378        return BigInt::zero();
379    }
380    let mut a = num_a.abs();
381    let x_sign = num_a.signum();
382    let mut b = p.abs();
383    let (mut x, mut r) = (BigInt::one(), BigInt::zero());
384    let (mut c, mut q);
385    while !b.is_zero() {
386        (q, c) = a.div_mod_floor(&b);
387        x -= &q * &r;
388        (a, b, r, x) = (b, c, x, r)
389    }
390
391    x * x_sign
392}
393
394///Returns x, y, g such that g = x*a + y*b = gcd(a, b).
395fn igcdex(num_a: &BigInt, num_b: &BigInt) -> (BigInt, BigInt, BigInt) {
396    match (num_a, num_b) {
397        (a, b) if a.is_zero() && b.is_zero() => (BigInt::zero(), BigInt::one(), BigInt::zero()),
398        (a, _) if a.is_zero() => (BigInt::zero(), num_b.signum(), num_b.abs()),
399        (_, b) if b.is_zero() => (num_a.signum(), BigInt::zero(), num_a.abs()),
400        _ => {
401            let mut a = num_a.abs();
402            let x_sign = num_a.signum();
403            let mut b = num_b.abs();
404            let y_sign = num_b.signum();
405            let (mut x, mut y, mut r, mut s) =
406                (BigInt::one(), BigInt::zero(), BigInt::zero(), BigInt::one());
407            let (mut c, mut q);
408            while !b.is_zero() {
409                (q, c) = a.div_mod_floor(&b);
410                x -= &q * &r;
411                y -= &q * &s;
412                (a, b, r, s, x, y) = (b, c, x, y, r, s)
413            }
414            (x * x_sign, y * y_sign, a)
415        }
416    }
417}
418
419///Finds a nonnegative integer x < p such that (m * x) % p == n.
420pub fn div_mod(n: &BigInt, m: &BigInt, p: &BigInt) -> Result<BigInt, MathError> {
421    let (a, _, c) = igcdex(m, p);
422    if !c.is_one() {
423        return Err(MathError::DivModIgcdexNotZero(Box::new((
424            n.clone(),
425            m.clone(),
426            p.clone(),
427        ))));
428    }
429    Ok((n * a).mod_floor(p))
430}
431
432pub(crate) fn div_mod_unsigned(
433    n: &BigUint,
434    m: &BigUint,
435    p: &BigUint,
436) -> Result<BigUint, MathError> {
437    // BigUint to BigInt conversion cannot fail & div_mod will always return a positive value if all values are positive so we can safely unwrap here
438    div_mod(
439        &n.to_bigint().unwrap(),
440        &m.to_bigint().unwrap(),
441        &p.to_bigint().unwrap(),
442    )
443    .map(|i| i.to_biguint().unwrap())
444}
445
446pub fn ec_add(
447    point_a: (BigInt, BigInt),
448    point_b: (BigInt, BigInt),
449    prime: &BigInt,
450) -> Result<(BigInt, BigInt), MathError> {
451    let m = line_slope(&point_a, &point_b, prime)?;
452    let x = (m.clone() * m.clone() - point_a.0.clone() - point_b.0).mod_floor(prime);
453    let y = (m * (point_a.0 - x.clone()) - point_a.1).mod_floor(prime);
454    Ok((x, y))
455}
456
457/// Computes the slope of the line connecting the two given EC points over the field GF(p).
458/// Assumes the points are given in affine form (x, y) and have different x coordinates.
459pub fn line_slope(
460    point_a: &(BigInt, BigInt),
461    point_b: &(BigInt, BigInt),
462    prime: &BigInt,
463) -> Result<BigInt, MathError> {
464    debug_assert!(!(&point_a.0 - &point_b.0).is_multiple_of(prime));
465    div_mod(
466        &(&point_a.1 - &point_b.1),
467        &(&point_a.0 - &point_b.0),
468        prime,
469    )
470}
471
472///  Doubles a point on an elliptic curve with the equation y^2 = x^3 + alpha*x + beta mod p.
473/// Assumes the point is given in affine form (x, y) and has y != 0.
474pub fn ec_double(
475    point: (BigInt, BigInt),
476    alpha: &BigInt,
477    prime: &BigInt,
478) -> Result<(BigInt, BigInt), MathError> {
479    let m = ec_double_slope(&point, alpha, prime)?;
480    let x = ((&m * &m) - (2_i32 * &point.0)).mod_floor(prime);
481    let y = (m * (point.0 - &x) - point.1).mod_floor(prime);
482    Ok((x, y))
483}
484/// Computes the slope of an elliptic curve with the equation y^2 = x^3 + alpha*x + beta mod p, at
485/// the given point.
486/// Assumes the point is given in affine form (x, y) and has y != 0.
487pub fn ec_double_slope(
488    point: &(BigInt, BigInt),
489    alpha: &BigInt,
490    prime: &BigInt,
491) -> Result<BigInt, MathError> {
492    debug_assert!(!point.1.is_multiple_of(prime));
493    div_mod(
494        &(3_i32 * &point.0 * &point.0 + alpha),
495        &(2_i32 * &point.1),
496        prime,
497    )
498}
499
500// Adapted from sympy _sqrt_prime_power with k == 1
501pub fn sqrt_prime_power(a: &BigUint, p: &BigUint) -> Option<BigUint> {
502    if p.is_zero() || !is_prime(p) {
503        return None;
504    }
505    let two = BigUint::from(2_u32);
506    let a = a.mod_floor(p);
507    if p == &two {
508        return Some(a);
509    }
510    if !(a < two || (a.modpow(&(p - 1_u32).div_floor(&two), p)).is_one()) {
511        return None;
512    };
513
514    if p.mod_floor(&BigUint::from(4_u32)) == 3_u32.into() {
515        let res = a.modpow(&(p + 1_u32).div_floor(&BigUint::from(4_u32)), p);
516        return Some(min(res.clone(), p - res));
517    };
518
519    if p.mod_floor(&BigUint::from(8_u32)) == 5_u32.into() {
520        let sign = a.modpow(&(p - 1_u32).div_floor(&BigUint::from(4_u32)), p);
521        if sign.is_one() {
522            let res = a.modpow(&(p + 3_u32).div_floor(&BigUint::from(8_u32)), p);
523            return Some(min(res.clone(), p - res));
524        } else {
525            let b = (4_u32 * &a).modpow(&(p - 5_u32).div_floor(&BigUint::from(8_u32)), p);
526            let x = (2_u32 * &a * b).mod_floor(p);
527            if x.modpow(&two, p) == a {
528                return Some(x);
529            }
530        }
531    };
532
533    Some(sqrt_tonelli_shanks(&a, p))
534}
535
536fn sqrt_tonelli_shanks(n: &BigUint, prime: &BigUint) -> BigUint {
537    // Based on Tonelli-Shanks' algorithm for finding square roots
538    // and sympy's library implementation of said algorithm.
539    if n.is_zero() || n.is_one() {
540        return n.clone();
541    }
542    let s = (prime - 1_u32).trailing_zeros().unwrap_or_default();
543    let t = prime >> s;
544    let a = n.modpow(&t, prime);
545    // Rng is not critical here so its safe to use a seeded value
546    let mut rng = SmallRng::seed_from_u64(11480028852697973135);
547    let mut d;
548    loop {
549        d = RandBigInt::gen_biguint_range(&mut rng, &BigUint::from(2_u32), &(prime - 1_u32));
550        let r = legendre_symbol(&d, prime);
551        if r == -1 {
552            break;
553        };
554    }
555    d = d.modpow(&t, prime);
556    let mut m = BigUint::zero();
557    let mut exponent = BigUint::one() << (s - 1);
558    let mut adm;
559    for i in 0..s as u32 {
560        adm = &a * &d.modpow(&m, prime);
561        adm = adm.modpow(&exponent, prime);
562        exponent >>= 1;
563        if adm == (prime - 1_u32) {
564            m += BigUint::from(1_u32) << i;
565        }
566    }
567    let root_1 =
568        (n.modpow(&((t + 1_u32) >> 1), prime) * d.modpow(&(m >> 1), prime)).mod_floor(prime);
569    let root_2 = prime - &root_1;
570    if root_1 < root_2 {
571        root_1
572    } else {
573        root_2
574    }
575}
576
577/* Disclaimer: Some asumptions have been taken based on the functions that rely on this function, make sure these are true before calling this function individually
578Adpted from sympy implementation, asuming:
579    - p is an odd prime number
580    - a.mod_floor(p) == a
581Returns the Legendre symbol `(a / p)`.
582
583    For an integer ``a`` and an odd prime ``p``, the Legendre symbol is
584    defined as
585
586    .. math ::
587        \genfrac(){}{}{a}{p} = \begin{cases}
588             0 & \text{if } p \text{ divides } a\\
589             1 & \text{if } a \text{ is a quadratic residue modulo } p\\
590            -1 & \text{if } a \text{ is a quadratic nonresidue modulo } p
591        \end{cases}
592*/
593fn legendre_symbol(a: &BigUint, p: &BigUint) -> i8 {
594    if a.is_zero() {
595        return 0;
596    };
597    if is_quad_residue(a, p).unwrap_or_default() {
598        1
599    } else {
600        -1
601    }
602}
603
604// Ported from sympy implementation
605// Simplified as a & p are nonnegative
606// Asumes p is a prime number
607pub(crate) fn is_quad_residue(a: &BigUint, p: &BigUint) -> Result<bool, MathError> {
608    if p.is_zero() {
609        return Err(MathError::IsQuadResidueZeroPrime);
610    }
611    let a = if a >= p { a.mod_floor(p) } else { a.clone() };
612    if a < BigUint::from(2_u8) || p < &BigUint::from(3_u8) {
613        return Ok(true);
614    }
615    Ok(
616        a.modpow(&(p - BigUint::one()).div_floor(&BigUint::from(2_u8)), p)
617            .is_one(),
618    )
619}
620
621#[cfg(test)]
622mod tests {
623    use super::*;
624    use crate::utils::test_utils::*;
625    use crate::utils::CAIRO_PRIME;
626    use assert_matches::assert_matches;
627
628    use num_traits::Num;
629
630    #[cfg(feature = "std")]
631    use num_prime::RandPrime;
632
633    #[cfg(feature = "std")]
634    use proptest::{array::uniform4, prelude::*};
635
636    // Only used in proptest for now
637    #[cfg(feature = "std")]
638    use num_bigint::Sign;
639
640    #[cfg(target_arch = "wasm32")]
641    use wasm_bindgen_test::*;
642
643    #[test]
644    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
645    fn calculate_divmod_a() {
646        let a = bigint_str!(
647            "11260647941622813594563746375280766662237311019551239924981511729608487775604310196863705127454617186486639011517352066501847110680463498585797912894788"
648        );
649        let b = bigint_str!(
650            "4020711254448367604954374443741161860304516084891705811279711044808359405970"
651        );
652        assert_eq!(
653            bigint_str!(
654                "2904750555256547440469454488220756360634457312540595732507835416669695939476"
655            ),
656            div_mod(
657                &a,
658                &b,
659                &BigInt::from_str_radix(&crate::utils::PRIME_STR[2..], 16)
660                    .expect("Couldn't parse prime")
661            )
662            .unwrap()
663        );
664    }
665
666    #[test]
667    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
668    fn calculate_divmod_b() {
669        let a = bigint_str!(
670            "29642372811668969595956851264770043260610851505766181624574941701711520154703788233010819515917136995474951116158286220089597404329949295479559895970988"
671        );
672        let b = bigint_str!(
673            "3443173965374276972000139705137775968422921151703548011275075734291405722262"
674        );
675        assert_eq!(
676            bigint_str!(
677                "3601388548860259779932034493250169083811722919049731683411013070523752439691"
678            ),
679            div_mod(
680                &a,
681                &b,
682                &BigInt::from_str_radix(&crate::utils::PRIME_STR[2..], 16)
683                    .expect("Couldn't parse prime")
684            )
685            .unwrap()
686        );
687    }
688
689    #[test]
690    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
691    fn calculate_divmod_c() {
692        let a = bigint_str!(
693            "1208267356464811040667664150251401430616174694388968865551115897173431833224432165394286799069453655049199580362994484548890574931604445970825506916876"
694        );
695        let b = bigint_str!(
696            "1809792356889571967986805709823554331258072667897598829955472663737669990418"
697        );
698        assert_eq!(
699            bigint_str!(
700                "1545825591488572374291664030703937603499513742109806697511239542787093258962"
701            ),
702            div_mod(
703                &a,
704                &b,
705                &BigInt::from_str_radix(&crate::utils::PRIME_STR[2..], 16)
706                    .expect("Couldn't parse prime")
707            )
708            .unwrap()
709        );
710    }
711
712    #[test]
713    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
714    fn compute_safe_div() {
715        let x = Felt252::from(26);
716        let y = Felt252::from(13);
717        assert_matches!(safe_div(&x, &y), Ok(i) if i == Felt252::from(2));
718    }
719
720    #[test]
721    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
722    fn compute_safe_div_non_divisor() {
723        let x = Felt252::from(25);
724        let y = Felt252::from(4);
725        let result = safe_div(&x, &y);
726        assert_matches!(
727            result,
728            Err(MathError::SafeDivFail(bx)) if *bx == (Felt252::from(25), Felt252::from(4)));
729    }
730
731    #[test]
732    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
733    fn compute_safe_div_by_zero() {
734        let x = Felt252::from(25);
735        let y = Felt252::ZERO;
736        let result = safe_div(&x, &y);
737        assert_matches!(result, Err(MathError::DividedByZero));
738    }
739
740    #[test]
741    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
742    fn compute_safe_div_usize() {
743        assert_matches!(safe_div_usize(26, 13), Ok(2));
744    }
745
746    #[test]
747    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
748    fn compute_safe_div_usize_non_divisor() {
749        assert_matches!(
750            safe_div_usize(25, 4),
751            Err(MathError::SafeDivFailUsize(bx)) if *bx == (25, 4)
752        );
753    }
754
755    #[test]
756    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
757    fn compute_safe_div_usize_by_zero() {
758        assert_matches!(safe_div_usize(25, 0), Err(MathError::DividedByZero));
759    }
760
761    #[test]
762    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
763    fn compute_line_slope_for_valid_points() {
764        let point_a = (
765            bigint_str!(
766                "3139037544796708144595053687182055617920475701120786241351436619796497072089"
767            ),
768            bigint_str!(
769                "2119589567875935397690285099786081818522144748339117565577200220779667999801"
770            ),
771        );
772        let point_b = (
773            bigint_str!(
774                "3324833730090626974525872402899302150520188025637965566623476530814354734325"
775            ),
776            bigint_str!(
777                "3147007486456030910661996439995670279305852583596209647900952752170983517249"
778            ),
779        );
780        let prime = (*CAIRO_PRIME).clone().into();
781        assert_eq!(
782            bigint_str!(
783                "992545364708437554384321881954558327331693627531977596999212637460266617010"
784            ),
785            line_slope(&point_a, &point_b, &prime).unwrap()
786        );
787    }
788
789    #[test]
790    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
791    fn compute_double_slope_for_valid_point_a() {
792        let point = (
793            bigint_str!(
794                "3143372541908290873737380228370996772020829254218248561772745122290262847573"
795            ),
796            bigint_str!(
797                "1721586982687138486000069852568887984211460575851774005637537867145702861131"
798            ),
799        );
800        let prime = (*CAIRO_PRIME).clone().into();
801        let alpha = bigint!(1);
802        assert_eq!(
803            bigint_str!(
804                "3601388548860259779932034493250169083811722919049731683411013070523752439691"
805            ),
806            ec_double_slope(&point, &alpha, &prime).unwrap()
807        );
808    }
809
810    #[test]
811    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
812    fn compute_double_slope_for_valid_point_b() {
813        let point = (
814            bigint_str!(
815                "1937407885261715145522756206040455121546447384489085099828343908348117672673"
816            ),
817            bigint_str!(
818                "2010355627224183802477187221870580930152258042445852905639855522404179702985"
819            ),
820        );
821        let prime = (*CAIRO_PRIME).clone().into();
822        let alpha = bigint!(1);
823        assert_eq!(
824            bigint_str!(
825                "2904750555256547440469454488220756360634457312540595732507835416669695939476"
826            ),
827            ec_double_slope(&point, &alpha, &prime).unwrap()
828        );
829    }
830
831    #[test]
832    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
833    fn calculate_ec_double_for_valid_point_a() {
834        let point = (
835            bigint_str!(
836                "1937407885261715145522756206040455121546447384489085099828343908348117672673"
837            ),
838            bigint_str!(
839                "2010355627224183802477187221870580930152258042445852905639855522404179702985"
840            ),
841        );
842        let prime = (*CAIRO_PRIME).clone().into();
843        let alpha = bigint!(1);
844        assert_eq!(
845            (
846                bigint_str!(
847                    "58460926014232092148191979591712815229424797874927791614218178721848875644"
848                ),
849                bigint_str!(
850                    "1065613861227134732854284722490492186040898336012372352512913425790457998694"
851                )
852            ),
853            ec_double(point, &alpha, &prime).unwrap()
854        );
855    }
856
857    #[test]
858    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
859    fn calculate_ec_double_for_valid_point_b() {
860        let point = (
861            bigint_str!(
862                "3143372541908290873737380228370996772020829254218248561772745122290262847573"
863            ),
864            bigint_str!(
865                "1721586982687138486000069852568887984211460575851774005637537867145702861131"
866            ),
867        );
868        let prime = (*CAIRO_PRIME).clone().into();
869        let alpha = bigint!(1);
870        assert_eq!(
871            (
872                bigint_str!(
873                    "1937407885261715145522756206040455121546447384489085099828343908348117672673"
874                ),
875                bigint_str!(
876                    "2010355627224183802477187221870580930152258042445852905639855522404179702985"
877                )
878            ),
879            ec_double(point, &alpha, &prime).unwrap()
880        );
881    }
882
883    #[test]
884    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
885    fn calculate_ec_double_for_valid_point_c() {
886        let point = (
887            bigint_str!(
888                "634630432210960355305430036410971013200846091773294855689580772209984122075"
889            ),
890            bigint_str!(
891                "904896178444785983993402854911777165629036333948799414977736331868834995209"
892            ),
893        );
894        let prime = (*CAIRO_PRIME).clone().into();
895        let alpha = bigint!(1);
896        assert_eq!(
897            (
898                bigint_str!(
899                    "3143372541908290873737380228370996772020829254218248561772745122290262847573"
900                ),
901                bigint_str!(
902                    "1721586982687138486000069852568887984211460575851774005637537867145702861131"
903                )
904            ),
905            ec_double(point, &alpha, &prime).unwrap()
906        );
907    }
908
909    #[test]
910    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
911    fn calculate_ec_add_for_valid_points_a() {
912        let point_a = (
913            bigint_str!(
914                "1183418161532233795704555250127335895546712857142554564893196731153957537489"
915            ),
916            bigint_str!(
917                "1938007580204102038458825306058547644691739966277761828724036384003180924526"
918            ),
919        );
920        let point_b = (
921            bigint_str!(
922                "1977703130303461992863803129734853218488251484396280000763960303272760326570"
923            ),
924            bigint_str!(
925                "2565191853811572867032277464238286011368568368717965689023024980325333517459"
926            ),
927        );
928        let prime = (*CAIRO_PRIME).clone().into();
929        assert_eq!(
930            (
931                bigint_str!(
932                    "1977874238339000383330315148209250828062304908491266318460063803060754089297"
933                ),
934                bigint_str!(
935                    "2969386888251099938335087541720168257053975603483053253007176033556822156706"
936                )
937            ),
938            ec_add(point_a, point_b, &prime).unwrap()
939        );
940    }
941
942    #[test]
943    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
944    fn calculate_ec_add_for_valid_points_b() {
945        let point_a = (
946            bigint_str!(
947                "3139037544796708144595053687182055617920475701120786241351436619796497072089"
948            ),
949            bigint_str!(
950                "2119589567875935397690285099786081818522144748339117565577200220779667999801"
951            ),
952        );
953        let point_b = (
954            bigint_str!(
955                "3324833730090626974525872402899302150520188025637965566623476530814354734325"
956            ),
957            bigint_str!(
958                "3147007486456030910661996439995670279305852583596209647900952752170983517249"
959            ),
960        );
961        let prime = (*CAIRO_PRIME).clone().into();
962        assert_eq!(
963            (
964                bigint_str!(
965                    "1183418161532233795704555250127335895546712857142554564893196731153957537489"
966                ),
967                bigint_str!(
968                    "1938007580204102038458825306058547644691739966277761828724036384003180924526"
969                )
970            ),
971            ec_add(point_a, point_b, &prime).unwrap()
972        );
973    }
974
975    #[test]
976    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
977    fn calculate_ec_add_for_valid_points_c() {
978        let point_a = (
979            bigint_str!(
980                "1183418161532233795704555250127335895546712857142554564893196731153957537489"
981            ),
982            bigint_str!(
983                "1938007580204102038458825306058547644691739966277761828724036384003180924526"
984            ),
985        );
986        let point_b = (
987            bigint_str!(
988                "1977703130303461992863803129734853218488251484396280000763960303272760326570"
989            ),
990            bigint_str!(
991                "2565191853811572867032277464238286011368568368717965689023024980325333517459"
992            ),
993        );
994        let prime = (*CAIRO_PRIME).clone().into();
995        assert_eq!(
996            (
997                bigint_str!(
998                    "1977874238339000383330315148209250828062304908491266318460063803060754089297"
999                ),
1000                bigint_str!(
1001                    "2969386888251099938335087541720168257053975603483053253007176033556822156706"
1002                )
1003            ),
1004            ec_add(point_a, point_b, &prime).unwrap()
1005        );
1006    }
1007
1008    #[test]
1009    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1010    fn calculate_isqrt_a() {
1011        let n = biguint!(81);
1012        assert_matches!(isqrt(&n), Ok(x) if x == biguint!(9));
1013    }
1014
1015    #[test]
1016    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1017    fn calculate_isqrt_b() {
1018        let n = biguint_str!("4573659632505831259480");
1019        assert_matches!(isqrt(&BigUint::pow(&n, 2_u32)), Ok(num) if num == n);
1020    }
1021
1022    #[test]
1023    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1024    fn calculate_isqrt_c() {
1025        let n = biguint_str!(
1026            "3618502788666131213697322783095070105623107215331596699973092056135872020481"
1027        );
1028        assert_matches!(isqrt(&BigUint::pow(&n, 2_u32)), Ok(inner) if inner == n);
1029    }
1030
1031    #[test]
1032    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1033    fn calculate_isqrt_zero() {
1034        let n = BigUint::zero();
1035        assert_matches!(isqrt(&n), Ok(inner) if inner.is_zero());
1036    }
1037
1038    #[test]
1039    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1040    fn safe_div_bigint_by_zero() {
1041        let x = BigInt::one();
1042        let y = BigInt::zero();
1043        assert_matches!(safe_div_bigint(&x, &y), Err(MathError::DividedByZero))
1044    }
1045
1046    #[test]
1047    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1048    fn test_sqrt_prime_power() {
1049        let n: BigUint = 25_u32.into();
1050        let p: BigUint = 18446744069414584321_u128.into();
1051        assert_eq!(sqrt_prime_power(&n, &p), Some(5_u32.into()));
1052    }
1053
1054    #[test]
1055    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1056    fn test_sqrt_prime_power_p_is_zero() {
1057        let n = BigUint::one();
1058        let p: BigUint = BigUint::zero();
1059        assert_eq!(sqrt_prime_power(&n, &p), None);
1060    }
1061
1062    #[test]
1063    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1064    fn test_sqrt_prime_power_non_prime() {
1065        let p: BigUint = BigUint::from_bytes_be(&[
1066            69, 15, 232, 82, 215, 167, 38, 143, 173, 94, 133, 111, 1, 2, 182, 229, 110, 113, 76, 0,
1067            47, 110, 148, 109, 6, 133, 27, 190, 158, 197, 168, 219, 165, 254, 81, 53, 25, 34,
1068        ]);
1069        let n = BigUint::from_bytes_be(&[
1070            9, 13, 22, 191, 87, 62, 157, 83, 157, 85, 93, 105, 230, 187, 32, 101, 51, 181, 49, 202,
1071            203, 195, 76, 193, 149, 78, 109, 146, 240, 126, 182, 115, 161, 238, 30, 118, 157, 252,
1072        ]);
1073
1074        assert_eq!(sqrt_prime_power(&n, &p), None);
1075    }
1076
1077    #[test]
1078    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1079    fn test_sqrt_prime_power_none() {
1080        let n: BigUint = 10_u32.into();
1081        let p: BigUint = 602_u32.into();
1082        assert_eq!(sqrt_prime_power(&n, &p), None);
1083    }
1084
1085    #[test]
1086    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1087    fn test_sqrt_prime_power_prime_two() {
1088        let n: BigUint = 25_u32.into();
1089        let p: BigUint = 2_u32.into();
1090        assert_eq!(sqrt_prime_power(&n, &p), Some(BigUint::one()));
1091    }
1092
1093    #[test]
1094    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1095    fn test_sqrt_prime_power_prime_mod_8_is_5_sign_not_one() {
1096        let n: BigUint = 676_u32.into();
1097        let p: BigUint = 9956234341095173_u64.into();
1098        assert_eq!(
1099            sqrt_prime_power(&n, &p),
1100            Some(BigUint::from(9956234341095147_u64))
1101        );
1102    }
1103
1104    #[test]
1105    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1106    fn test_sqrt_prime_power_prime_mod_8_is_5_sign_is_one() {
1107        let n: BigUint = 130283432663_u64.into();
1108        let p: BigUint = 743900351477_u64.into();
1109        assert_eq!(
1110            sqrt_prime_power(&n, &p),
1111            Some(BigUint::from(123538694848_u64))
1112        );
1113    }
1114
1115    #[test]
1116    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1117    fn test_legendre_symbol_zero() {
1118        assert!(legendre_symbol(&BigUint::zero(), &BigUint::one()).is_zero())
1119    }
1120
1121    #[test]
1122    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1123    fn test_is_quad_residue_prime_zero() {
1124        assert_eq!(
1125            is_quad_residue(&BigUint::one(), &BigUint::zero()),
1126            Err(MathError::IsQuadResidueZeroPrime)
1127        )
1128    }
1129
1130    #[test]
1131    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1132    fn test_is_quad_residue_prime_a_one_true() {
1133        assert_eq!(is_quad_residue(&BigUint::one(), &BigUint::one()), Ok(true))
1134    }
1135
1136    #[test]
1137    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1138    fn mul_inv_0_is_0() {
1139        let p = &(*CAIRO_PRIME).clone().into();
1140        let x = &BigInt::zero();
1141        let x_inv = mul_inv(x, p);
1142
1143        assert_eq!(x_inv, BigInt::zero());
1144    }
1145
1146    #[test]
1147    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1148    fn igcdex_1_1() {
1149        assert_eq!(
1150            igcdex(&BigInt::one(), &BigInt::one()),
1151            (BigInt::zero(), BigInt::one(), BigInt::one())
1152        )
1153    }
1154
1155    #[test]
1156    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1157    fn igcdex_0_0() {
1158        assert_eq!(
1159            igcdex(&BigInt::zero(), &BigInt::zero()),
1160            (BigInt::zero(), BigInt::one(), BigInt::zero())
1161        )
1162    }
1163
1164    #[test]
1165    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1166    fn igcdex_1_0() {
1167        assert_eq!(
1168            igcdex(&BigInt::one(), &BigInt::zero()),
1169            (BigInt::one(), BigInt::zero(), BigInt::one())
1170        )
1171    }
1172
1173    #[test]
1174    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1175    fn igcdex_4_6() {
1176        assert_eq!(
1177            igcdex(&BigInt::from(4), &BigInt::from(6)),
1178            (BigInt::from(-1), BigInt::one(), BigInt::from(2))
1179        )
1180    }
1181
1182    #[test]
1183    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1184    fn qm31_packed_reduced_read_coordinates_over_144_bits() {
1185        let mut felt_bytes = [0u8; 32];
1186        felt_bytes[18] = 1;
1187        let felt = Felt252::from_bytes_le(&felt_bytes);
1188        assert_matches!(
1189            qm31_packed_reduced_read_coordinates(felt),
1190            Err(MathError::QM31UnreducedError(bx)) if *bx == felt
1191        );
1192    }
1193
1194    #[test]
1195    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1196    fn qm31_packed_reduced_read_coordinates_unreduced() {
1197        let mut felt_bytes = [0u8; 32];
1198        felt_bytes[0] = 0xff;
1199        felt_bytes[1] = 0xff;
1200        felt_bytes[2] = 0xff;
1201        felt_bytes[3] = (1 << 7) - 1;
1202        let felt = Felt252::from_bytes_le(&felt_bytes);
1203        assert_matches!(
1204            qm31_packed_reduced_read_coordinates(felt),
1205            Err(MathError::QM31UnreducedError(bx)) if *bx == felt
1206        );
1207    }
1208
1209    #[test]
1210    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1211    fn test_qm31_packed_reduced_add() {
1212        let x_coordinates = [1414213562, 1732050807, 1618033988, 1234567890];
1213        let y_coordinates = [1234567890, 1414213562, 1732050807, 1618033988];
1214        let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1215        let y = qm31_coordinates_to_packed_reduced(y_coordinates);
1216        let res = qm31_packed_reduced_add(x, y).unwrap();
1217        let res_coordinates = qm31_packed_reduced_read_coordinates(res);
1218        assert_eq!(
1219            res_coordinates,
1220            Ok([
1221                (1414213562 + 1234567890) % STWO_PRIME,
1222                (1732050807 + 1414213562) % STWO_PRIME,
1223                (1618033988 + 1732050807) % STWO_PRIME,
1224                (1234567890 + 1618033988) % STWO_PRIME,
1225            ])
1226        );
1227    }
1228
1229    #[test]
1230    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1231    fn test_qm31_packed_reduced_neg() {
1232        let x_coordinates = [1749652895, 834624081, 1930174752, 2063872165];
1233        let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1234        let res = qm31_packed_reduced_neg(x).unwrap();
1235        let res_coordinates = qm31_packed_reduced_read_coordinates(res);
1236        assert_eq!(
1237            res_coordinates,
1238            Ok([
1239                STWO_PRIME - x_coordinates[0],
1240                STWO_PRIME - x_coordinates[1],
1241                STWO_PRIME - x_coordinates[2],
1242                STWO_PRIME - x_coordinates[3]
1243            ])
1244        );
1245    }
1246
1247    #[test]
1248    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1249    fn test_qm31_packed_reduced_sub() {
1250        let x_coordinates = [
1251            (1414213562 + 1234567890) % STWO_PRIME,
1252            (1732050807 + 1414213562) % STWO_PRIME,
1253            (1618033988 + 1732050807) % STWO_PRIME,
1254            (1234567890 + 1618033988) % STWO_PRIME,
1255        ];
1256        let y_coordinates = [1414213562, 1732050807, 1618033988, 1234567890];
1257        let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1258        let y = qm31_coordinates_to_packed_reduced(y_coordinates);
1259        let res = qm31_packed_reduced_sub(x, y).unwrap();
1260        let res_coordinates = qm31_packed_reduced_read_coordinates(res);
1261        assert_eq!(
1262            res_coordinates,
1263            Ok([1234567890, 1414213562, 1732050807, 1618033988])
1264        );
1265    }
1266
1267    #[test]
1268    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1269    fn test_qm31_packed_reduced_mul() {
1270        let x_coordinates = [1414213562, 1732050807, 1618033988, 1234567890];
1271        let y_coordinates = [1259921049, 1442249570, 1847759065, 2094551481];
1272        let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1273        let y = qm31_coordinates_to_packed_reduced(y_coordinates);
1274        let res = qm31_packed_reduced_mul(x, y).unwrap();
1275        let res_coordinates = qm31_packed_reduced_read_coordinates(res);
1276        assert_eq!(
1277            res_coordinates,
1278            Ok([947980980, 1510986506, 623360030, 1260310989])
1279        );
1280    }
1281
1282    #[test]
1283    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1284    fn test_qm31_packed_reduced_inv() {
1285        let x_coordinates = [1259921049, 1442249570, 1847759065, 2094551481];
1286        let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1287        let res = qm31_packed_reduced_inv(x).unwrap();
1288        assert_eq!(qm31_packed_reduced_mul(x, res), Ok(Felt252::from(1)));
1289
1290        let x_coordinates = [1, 2, 3, 4];
1291        let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1292        let res = qm31_packed_reduced_inv(x).unwrap();
1293        assert_eq!(qm31_packed_reduced_mul(x, res), Ok(Felt252::from(1)));
1294
1295        let x_coordinates = [1749652895, 834624081, 1930174752, 2063872165];
1296        let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1297        let res = qm31_packed_reduced_inv(x).unwrap();
1298        assert_eq!(qm31_packed_reduced_mul(x, res), Ok(Felt252::from(1)));
1299    }
1300
1301    #[test]
1302    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1303    fn test_qm31_packed_reduced_div() {
1304        let x_coordinates = [1259921049, 1442249570, 1847759065, 2094551481];
1305        let y_coordinates = [1414213562, 1732050807, 1618033988, 1234567890];
1306        let xy_coordinates = [947980980, 1510986506, 623360030, 1260310989];
1307        let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1308        let y = qm31_coordinates_to_packed_reduced(y_coordinates);
1309        let xy = qm31_coordinates_to_packed_reduced(xy_coordinates);
1310
1311        let res = qm31_packed_reduced_div(xy, y).unwrap();
1312        assert_eq!(res, x);
1313
1314        let res = qm31_packed_reduced_div(xy, x).unwrap();
1315        assert_eq!(res, y);
1316    }
1317
1318    /// Necessary strat to use proptest on the QM31 test
1319    #[cfg(feature = "std")]
1320    fn configuration_strat() -> BoxedStrategy<u64> {
1321        prop_oneof![Just(0), Just(1), Just(STWO_PRIME - 1), 0..STWO_PRIME].boxed()
1322    }
1323
1324    #[cfg(feature = "std")]
1325    proptest! {
1326
1327        #[test]
1328        #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1329        fn qm31_packed_reduced_inv_random(x_coordinates in uniform4(0u64..STWO_PRIME)
1330                                                            .prop_filter("All configs cant be 0",
1331                                                            |arr| !arr.iter().all(|x| *x == 0))
1332        ) {
1333            let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1334            let res = qm31_packed_reduced_inv(x).unwrap();
1335            assert_eq!(qm31_packed_reduced_mul(x, res), Ok(Felt252::from(1)));
1336        }
1337
1338        #[test]
1339        #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1340        fn qm31_packed_reduced_inv_extensive(x_coordinates in uniform4(configuration_strat())
1341                                                            .prop_filter("All configs cant be 0",
1342                                                            |arr| !arr.iter().all(|x| *x == 0))
1343                                                            .no_shrink()
1344        ) {
1345            let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1346            let res = qm31_packed_reduced_inv(x).unwrap();
1347            assert_eq!(qm31_packed_reduced_mul(x, res), Ok(Felt252::from(1)));
1348        }
1349
1350        #[test]
1351        fn pow2_const_in_range_returns_power_of_2(x in 0..=251u32) {
1352            prop_assert_eq!(pow2_const(x), Felt252::TWO.pow(x));
1353        }
1354
1355        #[test]
1356        fn pow2_const_oob_returns_1(x in 252u32..) {
1357            prop_assert_eq!(pow2_const(x), Felt252::ONE);
1358        }
1359
1360        #[test]
1361        fn pow2_const_nz_in_range_returns_power_of_2(x in 0..=251u32) {
1362            prop_assert_eq!(Felt252::from(pow2_const_nz(x)), Felt252::TWO.pow(x));
1363        }
1364
1365        #[test]
1366        fn pow2_const_nz_oob_returns_1(x in 252u32..) {
1367            prop_assert_eq!(Felt252::from(pow2_const_nz(x)), Felt252::ONE);
1368        }
1369
1370        #[test]
1371        // Test for sqrt_prime_power_ of a quadratic residue. Result should be the minimum root.
1372        fn sqrt_prime_power_using_random_prime(ref x in any::<[u8; 38]>(), ref y in any::<u64>()) {
1373            let mut rng = SmallRng::seed_from_u64(*y);
1374            let x = &BigUint::from_bytes_be(x);
1375            // Generate a prime here instead of relying on y, otherwise y may never be a prime number
1376            let p : &BigUint = &RandPrime::gen_prime(&mut rng, 384,  None);
1377            let x_sq = x * x;
1378            if let Some(sqrt) = sqrt_prime_power(&x_sq, p) {
1379                if &sqrt != x {
1380                    prop_assert_eq!(&(p - sqrt), x);
1381                } else {
1382                prop_assert_eq!(&sqrt, x);
1383                }
1384            }
1385        }
1386
1387        #[test]
1388        fn mul_inv_x_by_x_is_1(ref x in any::<[u8; 32]>()) {
1389            let p = &(*CAIRO_PRIME).clone().into();
1390            let pos_x = &BigInt::from_bytes_be(Sign::Plus, x);
1391            let neg_x = &BigInt::from_bytes_be(Sign::Minus, x);
1392            let pos_x_inv = mul_inv(pos_x, p);
1393            let neg_x_inv = mul_inv(neg_x, p);
1394
1395            prop_assert_eq!((pos_x * pos_x_inv).mod_floor(p), BigInt::one());
1396            prop_assert_eq!((neg_x * neg_x_inv).mod_floor(p), BigInt::one());
1397        }
1398    }
1399}