cairo_vm/math_utils/
mod.rs

1mod is_prime;
2
3pub use is_prime::is_prime;
4
5use core::cmp::min;
6
7use crate::stdlib::{boxed::Box, ops::Shr, prelude::Vec};
8use crate::types::errors::math_errors::MathError;
9use crate::utils::CAIRO_PRIME;
10use crate::Felt252;
11use lazy_static::lazy_static;
12use num_bigint::{BigInt, BigUint, RandBigInt, ToBigInt};
13use num_integer::Integer;
14use num_traits::{One, Signed, Zero};
15use rand::{rngs::SmallRng, SeedableRng};
16use starknet_types_core::felt::NonZeroFelt;
17
18lazy_static! {
19    pub static ref SIGNED_FELT_MAX: BigUint = (&*CAIRO_PRIME).shr(1_u32);
20    static ref POWERS_OF_TWO: Vec<NonZeroFelt> =
21        core::iter::successors(Some(Felt252::ONE), |x| Some(x * Felt252::TWO))
22            .take(252)
23            .map(|x| x.try_into().unwrap())
24            .collect::<Vec<_>>();
25}
26
/// The Mersenne prime `2^31 - 1` that defines the M31 field used by Stwo.
pub const STWO_PRIME: u64 = 0x7FFF_FFFF;
/// [`STWO_PRIME`] widened to `u128`, for intermediate products that can exceed `u64`.
const STWO_PRIME_U128: u128 = STWO_PRIME as u128;
/// Selects the low 36 bits: the width of one packed QM31 coordinate slot.
const MASK_36: u64 = 0xF_FFFF_FFFF;
/// Selects the low 8 bits.
const MASK_8: u64 = 0xFF;
31
32/// Returns the `n`th (up to the `251`th power) power of 2 as a [`Felt252`]
33/// in constant time.
34/// It silently returns `1` if the input is out of bounds.
35pub fn pow2_const(n: u32) -> Felt252 {
36    // If the conversion fails then it's out of range and we compute the power as usual
37    POWERS_OF_TWO
38        .get(n as usize)
39        .unwrap_or(&POWERS_OF_TWO[0])
40        .into()
41}
42
43/// Returns the `n`th (up to the `251`th power) power of 2 as a [`&stark_felt::NonZeroFelt`]
44/// in constant time.
45/// It silently returns `1` if the input is out of bounds.
46pub fn pow2_const_nz(n: u32) -> &'static NonZeroFelt {
47    // If the conversion fails then it's out of range and we compute the power as usual
48    POWERS_OF_TWO.get(n as usize).unwrap_or(&POWERS_OF_TWO[0])
49}
50
51/// Converts [`Felt252`] into a [`BigInt`] number in the range: `(- FIELD / 2, FIELD / 2)`.
52///
53/// # Examples
54///
55/// ```
56/// # use cairo_vm::{Felt252, math_utils::signed_felt};
57/// # use num_bigint::BigInt;
58/// let positive = Felt252::from(5);
59/// assert_eq!(signed_felt(positive), BigInt::from(5));
60///
61/// let negative = Felt252::MAX;
62/// assert_eq!(signed_felt(negative), BigInt::from(-1));
63/// ```
64pub fn signed_felt(felt: Felt252) -> BigInt {
65    let biguint = felt.to_biguint();
66    if biguint > *SIGNED_FELT_MAX {
67        BigInt::from_biguint(num_bigint::Sign::Minus, &*CAIRO_PRIME - &biguint)
68    } else {
69        biguint.to_bigint().expect("cannot fail")
70    }
71}
72
73pub fn signed_felt_for_prime(value: Felt252, prime: &BigUint) -> BigInt {
74    let value = value.to_biguint();
75    let half_prime = prime / 2u32;
76    if value > half_prime {
77        BigInt::from_biguint(num_bigint::Sign::Minus, prime - &value)
78    } else {
79        BigInt::from_biguint(num_bigint::Sign::Plus, value)
80    }
81}
82
83/// QM31 utility function, used specifically for Stwo.
84/// QM31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
85/// Reads four u64 coordinates from a single Felt252.
86/// STWO_PRIME fits in 36 bits, hence each coordinate can be represented by 36 bits and a QM31
87/// element can be stored in the first 144 bits of a Felt252.
88/// Returns an error if the input has over 144 bits or any coordinate is unreduced.
89fn qm31_packed_reduced_read_coordinates(felt: Felt252) -> Result<[u64; 4], MathError> {
90    let limbs = felt.to_le_digits();
91    if limbs[3] != 0 || limbs[2] >= 1 << 16 {
92        return Err(MathError::QM31UnreducedError(Box::new(felt)));
93    }
94    let coordinates = [
95        (limbs[0] & MASK_36),
96        ((limbs[0] >> 36) + ((limbs[1] & MASK_8) << 28)),
97        ((limbs[1] >> 8) & MASK_36),
98        ((limbs[1] >> 44) + (limbs[2] << 20)),
99    ];
100    for x in coordinates.iter() {
101        if *x >= STWO_PRIME {
102            return Err(MathError::QM31UnreducedError(Box::new(felt)));
103        }
104    }
105    Ok(coordinates)
106}
107
108/// QM31 utility function, used specifically for Stwo.
109/// QM31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
110/// Reduces four u64 coordinates and packs them into a single Felt252.
111/// STWO_PRIME fits in 36 bits, hence each coordinate can be represented by 36 bits and a QM31
112/// element can be stored in the first 144 bits of a Felt252.
113pub fn qm31_coordinates_to_packed_reduced(coordinates: [u64; 4]) -> Felt252 {
114    let bytes_part1 = ((coordinates[0] % STWO_PRIME) as u128
115        + (((coordinates[1] % STWO_PRIME) as u128) << 36))
116        .to_le_bytes();
117    let bytes_part2 = ((coordinates[2] % STWO_PRIME) as u128
118        + (((coordinates[3] % STWO_PRIME) as u128) << 36))
119        .to_le_bytes();
120    let mut result_bytes = [0u8; 32];
121    result_bytes[0..9].copy_from_slice(&bytes_part1[0..9]);
122    result_bytes[9..18].copy_from_slice(&bytes_part2[0..9]);
123    Felt252::from_bytes_le(&result_bytes)
124}
125
126/// QM31 utility function, used specifically for Stwo.
127/// QM31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
128/// Computes the addition of two QM31 elements in reduced form.
129/// Returns an error if either operand is not reduced.
130pub fn qm31_packed_reduced_add(felt1: Felt252, felt2: Felt252) -> Result<Felt252, MathError> {
131    let coordinates1 = qm31_packed_reduced_read_coordinates(felt1)?;
132    let coordinates2 = qm31_packed_reduced_read_coordinates(felt2)?;
133    let result_unreduced_coordinates = [
134        coordinates1[0] + coordinates2[0],
135        coordinates1[1] + coordinates2[1],
136        coordinates1[2] + coordinates2[2],
137        coordinates1[3] + coordinates2[3],
138    ];
139    Ok(qm31_coordinates_to_packed_reduced(
140        result_unreduced_coordinates,
141    ))
142}
143
144/// QM31 utility function, used specifically for Stwo.
145/// QM31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
146/// Computes the negative of a QM31 element in reduced form.
147/// Returns an error if the input is not reduced.
148pub fn qm31_packed_reduced_neg(felt: Felt252) -> Result<Felt252, MathError> {
149    let coordinates = qm31_packed_reduced_read_coordinates(felt)?;
150    Ok(qm31_coordinates_to_packed_reduced([
151        STWO_PRIME - coordinates[0],
152        STWO_PRIME - coordinates[1],
153        STWO_PRIME - coordinates[2],
154        STWO_PRIME - coordinates[3],
155    ]))
156}
157
158/// QM31 utility function, used specifically for Stwo.
159/// QM31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
160/// Computes the subtraction of two QM31 elements in reduced form.
161/// Returns an error if either operand is not reduced.
162pub fn qm31_packed_reduced_sub(felt1: Felt252, felt2: Felt252) -> Result<Felt252, MathError> {
163    let coordinates1 = qm31_packed_reduced_read_coordinates(felt1)?;
164    let coordinates2 = qm31_packed_reduced_read_coordinates(felt2)?;
165    let result_unreduced_coordinates = [
166        STWO_PRIME + coordinates1[0] - coordinates2[0],
167        STWO_PRIME + coordinates1[1] - coordinates2[1],
168        STWO_PRIME + coordinates1[2] - coordinates2[2],
169        STWO_PRIME + coordinates1[3] - coordinates2[3],
170    ];
171    Ok(qm31_coordinates_to_packed_reduced(
172        result_unreduced_coordinates,
173    ))
174}
175
176/// QM31 utility function, used specifically for Stwo.
177/// QM31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
178/// Computes the multiplication of two QM31 elements in reduced form.
179/// Returns an error if either operand is not reduced.
180pub fn qm31_packed_reduced_mul(felt1: Felt252, felt2: Felt252) -> Result<Felt252, MathError> {
181    let coordinates1_u64 = qm31_packed_reduced_read_coordinates(felt1)?;
182    let coordinates2_u64 = qm31_packed_reduced_read_coordinates(felt2)?;
183    let coordinates1 = coordinates1_u64.map(u128::from);
184    let coordinates2 = coordinates2_u64.map(u128::from);
185
186    let result_coordinates = [
187        ((5 * STWO_PRIME_U128 * STWO_PRIME_U128 + coordinates1[0] * coordinates2[0]
188            - coordinates1[1] * coordinates2[1]
189            + 2 * coordinates1[2] * coordinates2[2]
190            - 2 * coordinates1[3] * coordinates2[3]
191            - coordinates1[2] * coordinates2[3]
192            - coordinates1[3] * coordinates2[2])
193            % STWO_PRIME_U128) as u64,
194        ((STWO_PRIME_U128 * STWO_PRIME_U128
195            + coordinates1[0] * coordinates2[1]
196            + coordinates1[1] * coordinates2[0]
197            + 2 * (coordinates1[2] * coordinates2[3] + coordinates1[3] * coordinates2[2])
198            + coordinates1[2] * coordinates2[2]
199            - coordinates1[3] * coordinates2[3])
200            % STWO_PRIME_U128) as u64,
201        2 * STWO_PRIME * STWO_PRIME + coordinates1_u64[0] * coordinates2_u64[2]
202            - coordinates1_u64[1] * coordinates2_u64[3]
203            + coordinates1_u64[2] * coordinates2_u64[0]
204            - coordinates1_u64[3] * coordinates2_u64[1],
205        coordinates1_u64[0] * coordinates2_u64[3]
206            + coordinates1_u64[1] * coordinates2_u64[2]
207            + coordinates1_u64[2] * coordinates2_u64[1]
208            + coordinates1_u64[3] * coordinates2_u64[0],
209    ];
210    Ok(qm31_coordinates_to_packed_reduced(result_coordinates))
211}
212
213/// M31 utility function, used specifically for Stwo.
214/// M31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
215/// Computes the inverse in the M31 field using Fermat's little theorem, i.e., returns
216/// `v^(STWO_PRIME-2) modulo STWO_PRIME`, which is the inverse of v unless v % STWO_PRIME == 0.
217pub fn pow2147483645(v: u64) -> u64 {
218    let t0 = (sqn(v, 2) * v) % STWO_PRIME;
219    let t1 = (sqn(t0, 1) * t0) % STWO_PRIME;
220    let t2 = (sqn(t1, 3) * t0) % STWO_PRIME;
221    let t3 = (sqn(t2, 1) * t0) % STWO_PRIME;
222    let t4 = (sqn(t3, 8) * t3) % STWO_PRIME;
223    let t5 = (sqn(t4, 8) * t3) % STWO_PRIME;
224    (sqn(t5, 7) * t2) % STWO_PRIME
225}
226
227/// M31 utility function, used specifically for Stwo.
228/// M31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
229/// Computes `v^(2^n) modulo STWO_PRIME`.
230fn sqn(v: u64, n: usize) -> u64 {
231    let mut u = v;
232    for _ in 0..n {
233        u = (u * u) % STWO_PRIME;
234    }
235    u
236}
237
238/// QM31 utility function, used specifically for Stwo.
239/// QM31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
240/// Computes the inverse of a QM31 element in reduced form.
241/// Returns an error if the denominator is zero or either operand is not reduced.
242pub fn qm31_packed_reduced_inv(felt: Felt252) -> Result<Felt252, MathError> {
243    if felt.is_zero() {
244        return Err(MathError::DividedByZero);
245    }
246    let coordinates = qm31_packed_reduced_read_coordinates(felt)?;
247
248    let b2_r = (coordinates[2] * coordinates[2] + STWO_PRIME * STWO_PRIME
249        - coordinates[3] * coordinates[3])
250        % STWO_PRIME;
251    let b2_i = (2 * coordinates[2] * coordinates[3]) % STWO_PRIME;
252
253    let denom_r = (coordinates[0] * coordinates[0] + STWO_PRIME * STWO_PRIME
254        - coordinates[1] * coordinates[1]
255        + 2 * STWO_PRIME
256        - 2 * b2_r
257        + b2_i)
258        % STWO_PRIME;
259    let denom_i =
260        (2 * coordinates[0] * coordinates[1] + 3 * STWO_PRIME - 2 * b2_i - b2_r) % STWO_PRIME;
261
262    let denom_norm_squared = (denom_r * denom_r + denom_i * denom_i) % STWO_PRIME;
263    let denom_norm_inverse_squared = pow2147483645(denom_norm_squared);
264
265    let denom_inverse_r = (denom_r * denom_norm_inverse_squared) % STWO_PRIME;
266    let denom_inverse_i = ((STWO_PRIME - denom_i) * denom_norm_inverse_squared) % STWO_PRIME;
267
268    Ok(qm31_coordinates_to_packed_reduced([
269        coordinates[0] * denom_inverse_r + STWO_PRIME * STWO_PRIME
270            - coordinates[1] * denom_inverse_i,
271        coordinates[0] * denom_inverse_i + coordinates[1] * denom_inverse_r,
272        coordinates[3] * denom_inverse_i + STWO_PRIME * STWO_PRIME
273            - coordinates[2] * denom_inverse_r,
274        2 * STWO_PRIME * STWO_PRIME
275            - coordinates[2] * denom_inverse_i
276            - coordinates[3] * denom_inverse_r,
277    ]))
278}
279
280/// QM31 utility function, used specifically for Stwo.
281/// QM31 operations are to be relocated into https://github.com/lambdaclass/lambdaworks.
282/// Computes the division of two QM31 elements in reduced form.
283/// Returns an error if the input is zero.
284pub fn qm31_packed_reduced_div(felt1: Felt252, felt2: Felt252) -> Result<Felt252, MathError> {
285    let felt2_inv = qm31_packed_reduced_inv(felt2)?;
286    qm31_packed_reduced_mul(felt1, felt2_inv)
287}
288
289///Returns the integer square root of the nonnegative integer n.
290///This is the floor of the exact square root of n.
291///Unlike math.sqrt(), this function doesn't have rounding error issues.
292pub fn isqrt(n: &BigUint) -> Result<BigUint, MathError> {
293    /*    # The following algorithm was copied from
294    # https://stackoverflow.com/questions/15390807/integer-square-root-in-python.
295    x = n
296    y = (x + 1) // 2
297    while y < x:
298        x = y
299        y = (x + n // x) // 2
300    assert x**2 <= n < (x + 1) ** 2
301    return x*/
302
303    let mut x = n.clone();
304    //n.shr(1) = n.div_floor(2)
305    let mut y = (&x + 1_u32).shr(1_u32);
306
307    while y < x {
308        x = y;
309        y = (&x + n.div_floor(&x)).shr(1_u32);
310    }
311
312    if !(&BigUint::pow(&x, 2_u32) <= n && n < &BigUint::pow(&(&x + 1_u32), 2_u32)) {
313        return Err(MathError::FailedToGetSqrt(Box::new(n.clone())));
314    };
315    Ok(x)
316}
317
318/// Performs integer division between x and y; fails if x is not divisible by y.
319pub fn safe_div(x: &Felt252, y: &Felt252) -> Result<Felt252, MathError> {
320    let (q, r) = x.div_rem(&y.try_into().map_err(|_| MathError::DividedByZero)?);
321
322    if !r.is_zero() {
323        Err(MathError::SafeDivFail(Box::new((*x, *y))))
324    } else {
325        Ok(q)
326    }
327}
328
329/// Performs integer division between x and y; fails if x is not divisible by y.
330pub fn safe_div_bigint(x: &BigInt, y: &BigInt) -> Result<BigInt, MathError> {
331    if y.is_zero() {
332        return Err(MathError::DividedByZero);
333    }
334
335    let (q, r) = x.div_mod_floor(y);
336
337    if !r.is_zero() {
338        return Err(MathError::SafeDivFailBigInt(Box::new((
339            x.clone(),
340            y.clone(),
341        ))));
342    }
343
344    Ok(q)
345}
346
347/// Performs integer division between x and y; fails if x is not divisible by y.
348pub fn safe_div_usize(x: usize, y: usize) -> Result<usize, MathError> {
349    if y.is_zero() {
350        return Err(MathError::DividedByZero);
351    }
352
353    let (q, r) = x.div_mod_floor(&y);
354
355    if !r.is_zero() {
356        return Err(MathError::SafeDivFailUsize(Box::new((x, y))));
357    }
358
359    Ok(q)
360}
361
362///Returns num_a^-1 mod p
363pub(crate) fn mul_inv(num_a: &BigInt, p: &BigInt) -> BigInt {
364    if num_a.is_zero() {
365        return BigInt::zero();
366    }
367    let mut a = num_a.abs();
368    let x_sign = num_a.signum();
369    let mut b = p.abs();
370    let (mut x, mut r) = (BigInt::one(), BigInt::zero());
371    let (mut c, mut q);
372    while !b.is_zero() {
373        (q, c) = a.div_mod_floor(&b);
374        x -= &q * &r;
375        (a, b, r, x) = (b, c, x, r)
376    }
377
378    x * x_sign
379}
380
381///Returns x, y, g such that g = x*a + y*b = gcd(a, b).
382fn igcdex(num_a: &BigInt, num_b: &BigInt) -> (BigInt, BigInt, BigInt) {
383    match (num_a, num_b) {
384        (a, b) if a.is_zero() && b.is_zero() => (BigInt::zero(), BigInt::one(), BigInt::zero()),
385        (a, _) if a.is_zero() => (BigInt::zero(), num_b.signum(), num_b.abs()),
386        (_, b) if b.is_zero() => (num_a.signum(), BigInt::zero(), num_a.abs()),
387        _ => {
388            let mut a = num_a.abs();
389            let x_sign = num_a.signum();
390            let mut b = num_b.abs();
391            let y_sign = num_b.signum();
392            let (mut x, mut y, mut r, mut s) =
393                (BigInt::one(), BigInt::zero(), BigInt::zero(), BigInt::one());
394            let (mut c, mut q);
395            while !b.is_zero() {
396                (q, c) = a.div_mod_floor(&b);
397                x -= &q * &r;
398                y -= &q * &s;
399                (a, b, r, s, x, y) = (b, c, x, y, r, s)
400            }
401            (x * x_sign, y * y_sign, a)
402        }
403    }
404}
405
406///Finds a nonnegative integer x < p such that (m * x) % p == n.
407pub fn div_mod(n: &BigInt, m: &BigInt, p: &BigInt) -> Result<BigInt, MathError> {
408    let (a, _, c) = igcdex(m, p);
409    if !c.is_one() {
410        return Err(MathError::DivModIgcdexNotZero(Box::new((
411            n.clone(),
412            m.clone(),
413            p.clone(),
414        ))));
415    }
416    Ok((n * a).mod_floor(p))
417}
418
419pub(crate) fn div_mod_unsigned(
420    n: &BigUint,
421    m: &BigUint,
422    p: &BigUint,
423) -> Result<BigUint, MathError> {
424    // BigUint to BigInt conversion cannot fail & div_mod will always return a positive value if all values are positive so we can safely unwrap here
425    div_mod(
426        &n.to_bigint().unwrap(),
427        &m.to_bigint().unwrap(),
428        &p.to_bigint().unwrap(),
429    )
430    .map(|i| i.to_biguint().unwrap())
431}
432
433pub fn ec_add(
434    point_a: (BigInt, BigInt),
435    point_b: (BigInt, BigInt),
436    prime: &BigInt,
437) -> Result<(BigInt, BigInt), MathError> {
438    let m = line_slope(&point_a, &point_b, prime)?;
439    let x = (m.clone() * m.clone() - point_a.0.clone() - point_b.0).mod_floor(prime);
440    let y = (m * (point_a.0 - x.clone()) - point_a.1).mod_floor(prime);
441    Ok((x, y))
442}
443
444/// Computes the slope of the line connecting the two given EC points over the field GF(p).
445/// Assumes the points are given in affine form (x, y) and have different x coordinates.
446pub fn line_slope(
447    point_a: &(BigInt, BigInt),
448    point_b: &(BigInt, BigInt),
449    prime: &BigInt,
450) -> Result<BigInt, MathError> {
451    debug_assert!(!(&point_a.0 - &point_b.0).is_multiple_of(prime));
452    div_mod(
453        &(&point_a.1 - &point_b.1),
454        &(&point_a.0 - &point_b.0),
455        prime,
456    )
457}
458
459///  Doubles a point on an elliptic curve with the equation y^2 = x^3 + alpha*x + beta mod p.
460/// Assumes the point is given in affine form (x, y) and has y != 0.
461pub fn ec_double(
462    point: (BigInt, BigInt),
463    alpha: &BigInt,
464    prime: &BigInt,
465) -> Result<(BigInt, BigInt), MathError> {
466    let m = ec_double_slope(&point, alpha, prime)?;
467    let x = ((&m * &m) - (2_i32 * &point.0)).mod_floor(prime);
468    let y = (m * (point.0 - &x) - point.1).mod_floor(prime);
469    Ok((x, y))
470}
471/// Computes the slope of an elliptic curve with the equation y^2 = x^3 + alpha*x + beta mod p, at
472/// the given point.
473/// Assumes the point is given in affine form (x, y) and has y != 0.
474pub fn ec_double_slope(
475    point: &(BigInt, BigInt),
476    alpha: &BigInt,
477    prime: &BigInt,
478) -> Result<BigInt, MathError> {
479    debug_assert!(!point.1.is_multiple_of(prime));
480    div_mod(
481        &(3_i32 * &point.0 * &point.0 + alpha),
482        &(2_i32 * &point.1),
483        prime,
484    )
485}
486
487// Adapted from sympy _sqrt_prime_power with k == 1
488pub fn sqrt_prime_power(a: &BigUint, p: &BigUint) -> Option<BigUint> {
489    if p.is_zero() || !is_prime(p) {
490        return None;
491    }
492    let two = BigUint::from(2_u32);
493    let a = a.mod_floor(p);
494    if p == &two {
495        return Some(a);
496    }
497    if !(a < two || (a.modpow(&(p - 1_u32).div_floor(&two), p)).is_one()) {
498        return None;
499    };
500
501    if p.mod_floor(&BigUint::from(4_u32)) == 3_u32.into() {
502        let res = a.modpow(&(p + 1_u32).div_floor(&BigUint::from(4_u32)), p);
503        return Some(min(res.clone(), p - res));
504    };
505
506    if p.mod_floor(&BigUint::from(8_u32)) == 5_u32.into() {
507        let sign = a.modpow(&(p - 1_u32).div_floor(&BigUint::from(4_u32)), p);
508        if sign.is_one() {
509            let res = a.modpow(&(p + 3_u32).div_floor(&BigUint::from(8_u32)), p);
510            return Some(min(res.clone(), p - res));
511        } else {
512            let b = (4_u32 * &a).modpow(&(p - 5_u32).div_floor(&BigUint::from(8_u32)), p);
513            let x = (2_u32 * &a * b).mod_floor(p);
514            if x.modpow(&two, p) == a {
515                return Some(x);
516            }
517        }
518    };
519
520    Some(sqrt_tonelli_shanks(&a, p))
521}
522
523fn sqrt_tonelli_shanks(n: &BigUint, prime: &BigUint) -> BigUint {
524    // Based on Tonelli-Shanks' algorithm for finding square roots
525    // and sympy's library implementation of said algorithm.
526    if n.is_zero() || n.is_one() {
527        return n.clone();
528    }
529    let s = (prime - 1_u32).trailing_zeros().unwrap_or_default();
530    let t = prime >> s;
531    let a = n.modpow(&t, prime);
532    // Rng is not critical here so its safe to use a seeded value
533    let mut rng = SmallRng::seed_from_u64(11480028852697973135);
534    let mut d;
535    loop {
536        d = RandBigInt::gen_biguint_range(&mut rng, &BigUint::from(2_u32), &(prime - 1_u32));
537        let r = legendre_symbol(&d, prime);
538        if r == -1 {
539            break;
540        };
541    }
542    d = d.modpow(&t, prime);
543    let mut m = BigUint::zero();
544    let mut exponent = BigUint::one() << (s - 1);
545    let mut adm;
546    for i in 0..s as u32 {
547        adm = &a * &d.modpow(&m, prime);
548        adm = adm.modpow(&exponent, prime);
549        exponent >>= 1;
550        if adm == (prime - 1_u32) {
551            m += BigUint::from(1_u32) << i;
552        }
553    }
554    let root_1 =
555        (n.modpow(&((t + 1_u32) >> 1), prime) * d.modpow(&(m >> 1), prime)).mod_floor(prime);
556    let root_2 = prime - &root_1;
557    if root_1 < root_2 {
558        root_1
559    } else {
560        root_2
561    }
562}
563
564/* Disclaimer: Some asumptions have been taken based on the functions that rely on this function, make sure these are true before calling this function individually
565Adpted from sympy implementation, asuming:
566    - p is an odd prime number
567    - a.mod_floor(p) == a
568Returns the Legendre symbol `(a / p)`.
569
570    For an integer ``a`` and an odd prime ``p``, the Legendre symbol is
571    defined as
572
573    .. math ::
574        \genfrac(){}{}{a}{p} = \begin{cases}
575             0 & \text{if } p \text{ divides } a\\
576             1 & \text{if } a \text{ is a quadratic residue modulo } p\\
577            -1 & \text{if } a \text{ is a quadratic nonresidue modulo } p
578        \end{cases}
579*/
580fn legendre_symbol(a: &BigUint, p: &BigUint) -> i8 {
581    if a.is_zero() {
582        return 0;
583    };
584    if is_quad_residue(a, p).unwrap_or_default() {
585        1
586    } else {
587        -1
588    }
589}
590
591// Ported from sympy implementation
592// Simplified as a & p are nonnegative
593// Asumes p is a prime number
594pub(crate) fn is_quad_residue(a: &BigUint, p: &BigUint) -> Result<bool, MathError> {
595    if p.is_zero() {
596        return Err(MathError::IsQuadResidueZeroPrime);
597    }
598    let a = if a >= p { a.mod_floor(p) } else { a.clone() };
599    if a < BigUint::from(2_u8) || p < &BigUint::from(3_u8) {
600        return Ok(true);
601    }
602    Ok(
603        a.modpow(&(p - BigUint::one()).div_floor(&BigUint::from(2_u8)), p)
604            .is_one(),
605    )
606}
607
608#[cfg(test)]
609mod tests {
610    use super::*;
611    use crate::utils::test_utils::*;
612    use crate::utils::CAIRO_PRIME;
613    use assert_matches::assert_matches;
614
615    use num_traits::Num;
616    use rand::Rng;
617
618    #[cfg(feature = "std")]
619    use num_prime::RandPrime;
620
621    #[cfg(feature = "std")]
622    use proptest::prelude::*;
623
624    // Only used in proptest for now
625    #[cfg(feature = "std")]
626    use num_bigint::Sign;
627
628    #[cfg(target_arch = "wasm32")]
629    use wasm_bindgen_test::*;
630
631    #[test]
632    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
633    fn calculate_divmod_a() {
634        let a = bigint_str!(
635            "11260647941622813594563746375280766662237311019551239924981511729608487775604310196863705127454617186486639011517352066501847110680463498585797912894788"
636        );
637        let b = bigint_str!(
638            "4020711254448367604954374443741161860304516084891705811279711044808359405970"
639        );
640        assert_eq!(
641            bigint_str!(
642                "2904750555256547440469454488220756360634457312540595732507835416669695939476"
643            ),
644            div_mod(
645                &a,
646                &b,
647                &BigInt::from_str_radix(&crate::utils::PRIME_STR[2..], 16)
648                    .expect("Couldn't parse prime")
649            )
650            .unwrap()
651        );
652    }
653
654    #[test]
655    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
656    fn calculate_divmod_b() {
657        let a = bigint_str!(
658            "29642372811668969595956851264770043260610851505766181624574941701711520154703788233010819515917136995474951116158286220089597404329949295479559895970988"
659        );
660        let b = bigint_str!(
661            "3443173965374276972000139705137775968422921151703548011275075734291405722262"
662        );
663        assert_eq!(
664            bigint_str!(
665                "3601388548860259779932034493250169083811722919049731683411013070523752439691"
666            ),
667            div_mod(
668                &a,
669                &b,
670                &BigInt::from_str_radix(&crate::utils::PRIME_STR[2..], 16)
671                    .expect("Couldn't parse prime")
672            )
673            .unwrap()
674        );
675    }
676
677    #[test]
678    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
679    fn calculate_divmod_c() {
680        let a = bigint_str!(
681            "1208267356464811040667664150251401430616174694388968865551115897173431833224432165394286799069453655049199580362994484548890574931604445970825506916876"
682        );
683        let b = bigint_str!(
684            "1809792356889571967986805709823554331258072667897598829955472663737669990418"
685        );
686        assert_eq!(
687            bigint_str!(
688                "1545825591488572374291664030703937603499513742109806697511239542787093258962"
689            ),
690            div_mod(
691                &a,
692                &b,
693                &BigInt::from_str_radix(&crate::utils::PRIME_STR[2..], 16)
694                    .expect("Couldn't parse prime")
695            )
696            .unwrap()
697        );
698    }
699
700    #[test]
701    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
702    fn compute_safe_div() {
703        let x = Felt252::from(26);
704        let y = Felt252::from(13);
705        assert_matches!(safe_div(&x, &y), Ok(i) if i == Felt252::from(2));
706    }
707
708    #[test]
709    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
710    fn compute_safe_div_non_divisor() {
711        let x = Felt252::from(25);
712        let y = Felt252::from(4);
713        let result = safe_div(&x, &y);
714        assert_matches!(
715            result,
716            Err(MathError::SafeDivFail(bx)) if *bx == (Felt252::from(25), Felt252::from(4)));
717    }
718
719    #[test]
720    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
721    fn compute_safe_div_by_zero() {
722        let x = Felt252::from(25);
723        let y = Felt252::ZERO;
724        let result = safe_div(&x, &y);
725        assert_matches!(result, Err(MathError::DividedByZero));
726    }
727
728    #[test]
729    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
730    fn compute_safe_div_usize() {
731        assert_matches!(safe_div_usize(26, 13), Ok(2));
732    }
733
734    #[test]
735    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
736    fn compute_safe_div_usize_non_divisor() {
737        assert_matches!(
738            safe_div_usize(25, 4),
739            Err(MathError::SafeDivFailUsize(bx)) if *bx == (25, 4)
740        );
741    }
742
743    #[test]
744    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
745    fn compute_safe_div_usize_by_zero() {
746        assert_matches!(safe_div_usize(25, 0), Err(MathError::DividedByZero));
747    }
748
749    #[test]
750    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
751    fn compute_line_slope_for_valid_points() {
752        let point_a = (
753            bigint_str!(
754                "3139037544796708144595053687182055617920475701120786241351436619796497072089"
755            ),
756            bigint_str!(
757                "2119589567875935397690285099786081818522144748339117565577200220779667999801"
758            ),
759        );
760        let point_b = (
761            bigint_str!(
762                "3324833730090626974525872402899302150520188025637965566623476530814354734325"
763            ),
764            bigint_str!(
765                "3147007486456030910661996439995670279305852583596209647900952752170983517249"
766            ),
767        );
768        let prime = (*CAIRO_PRIME).clone().into();
769        assert_eq!(
770            bigint_str!(
771                "992545364708437554384321881954558327331693627531977596999212637460266617010"
772            ),
773            line_slope(&point_a, &point_b, &prime).unwrap()
774        );
775    }
776
777    #[test]
778    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
779    fn compute_double_slope_for_valid_point_a() {
780        let point = (
781            bigint_str!(
782                "3143372541908290873737380228370996772020829254218248561772745122290262847573"
783            ),
784            bigint_str!(
785                "1721586982687138486000069852568887984211460575851774005637537867145702861131"
786            ),
787        );
788        let prime = (*CAIRO_PRIME).clone().into();
789        let alpha = bigint!(1);
790        assert_eq!(
791            bigint_str!(
792                "3601388548860259779932034493250169083811722919049731683411013070523752439691"
793            ),
794            ec_double_slope(&point, &alpha, &prime).unwrap()
795        );
796    }
797
798    #[test]
799    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
800    fn compute_double_slope_for_valid_point_b() {
801        let point = (
802            bigint_str!(
803                "1937407885261715145522756206040455121546447384489085099828343908348117672673"
804            ),
805            bigint_str!(
806                "2010355627224183802477187221870580930152258042445852905639855522404179702985"
807            ),
808        );
809        let prime = (*CAIRO_PRIME).clone().into();
810        let alpha = bigint!(1);
811        assert_eq!(
812            bigint_str!(
813                "2904750555256547440469454488220756360634457312540595732507835416669695939476"
814            ),
815            ec_double_slope(&point, &alpha, &prime).unwrap()
816        );
817    }
818
819    #[test]
820    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
821    fn calculate_ec_double_for_valid_point_a() {
822        let point = (
823            bigint_str!(
824                "1937407885261715145522756206040455121546447384489085099828343908348117672673"
825            ),
826            bigint_str!(
827                "2010355627224183802477187221870580930152258042445852905639855522404179702985"
828            ),
829        );
830        let prime = (*CAIRO_PRIME).clone().into();
831        let alpha = bigint!(1);
832        assert_eq!(
833            (
834                bigint_str!(
835                    "58460926014232092148191979591712815229424797874927791614218178721848875644"
836                ),
837                bigint_str!(
838                    "1065613861227134732854284722490492186040898336012372352512913425790457998694"
839                )
840            ),
841            ec_double(point, &alpha, &prime).unwrap()
842        );
843    }
844
845    #[test]
846    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
847    fn calculate_ec_double_for_valid_point_b() {
848        let point = (
849            bigint_str!(
850                "3143372541908290873737380228370996772020829254218248561772745122290262847573"
851            ),
852            bigint_str!(
853                "1721586982687138486000069852568887984211460575851774005637537867145702861131"
854            ),
855        );
856        let prime = (*CAIRO_PRIME).clone().into();
857        let alpha = bigint!(1);
858        assert_eq!(
859            (
860                bigint_str!(
861                    "1937407885261715145522756206040455121546447384489085099828343908348117672673"
862                ),
863                bigint_str!(
864                    "2010355627224183802477187221870580930152258042445852905639855522404179702985"
865                )
866            ),
867            ec_double(point, &alpha, &prime).unwrap()
868        );
869    }
870
871    #[test]
872    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
873    fn calculate_ec_double_for_valid_point_c() {
874        let point = (
875            bigint_str!(
876                "634630432210960355305430036410971013200846091773294855689580772209984122075"
877            ),
878            bigint_str!(
879                "904896178444785983993402854911777165629036333948799414977736331868834995209"
880            ),
881        );
882        let prime = (*CAIRO_PRIME).clone().into();
883        let alpha = bigint!(1);
884        assert_eq!(
885            (
886                bigint_str!(
887                    "3143372541908290873737380228370996772020829254218248561772745122290262847573"
888                ),
889                bigint_str!(
890                    "1721586982687138486000069852568887984211460575851774005637537867145702861131"
891                )
892            ),
893            ec_double(point, &alpha, &prime).unwrap()
894        );
895    }
896
897    #[test]
898    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
899    fn calculate_ec_add_for_valid_points_a() {
900        let point_a = (
901            bigint_str!(
902                "1183418161532233795704555250127335895546712857142554564893196731153957537489"
903            ),
904            bigint_str!(
905                "1938007580204102038458825306058547644691739966277761828724036384003180924526"
906            ),
907        );
908        let point_b = (
909            bigint_str!(
910                "1977703130303461992863803129734853218488251484396280000763960303272760326570"
911            ),
912            bigint_str!(
913                "2565191853811572867032277464238286011368568368717965689023024980325333517459"
914            ),
915        );
916        let prime = (*CAIRO_PRIME).clone().into();
917        assert_eq!(
918            (
919                bigint_str!(
920                    "1977874238339000383330315148209250828062304908491266318460063803060754089297"
921                ),
922                bigint_str!(
923                    "2969386888251099938335087541720168257053975603483053253007176033556822156706"
924                )
925            ),
926            ec_add(point_a, point_b, &prime).unwrap()
927        );
928    }
929
930    #[test]
931    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
932    fn calculate_ec_add_for_valid_points_b() {
933        let point_a = (
934            bigint_str!(
935                "3139037544796708144595053687182055617920475701120786241351436619796497072089"
936            ),
937            bigint_str!(
938                "2119589567875935397690285099786081818522144748339117565577200220779667999801"
939            ),
940        );
941        let point_b = (
942            bigint_str!(
943                "3324833730090626974525872402899302150520188025637965566623476530814354734325"
944            ),
945            bigint_str!(
946                "3147007486456030910661996439995670279305852583596209647900952752170983517249"
947            ),
948        );
949        let prime = (*CAIRO_PRIME).clone().into();
950        assert_eq!(
951            (
952                bigint_str!(
953                    "1183418161532233795704555250127335895546712857142554564893196731153957537489"
954                ),
955                bigint_str!(
956                    "1938007580204102038458825306058547644691739966277761828724036384003180924526"
957                )
958            ),
959            ec_add(point_a, point_b, &prime).unwrap()
960        );
961    }
962
963    #[test]
964    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
965    fn calculate_ec_add_for_valid_points_c() {
966        let point_a = (
967            bigint_str!(
968                "1183418161532233795704555250127335895546712857142554564893196731153957537489"
969            ),
970            bigint_str!(
971                "1938007580204102038458825306058547644691739966277761828724036384003180924526"
972            ),
973        );
974        let point_b = (
975            bigint_str!(
976                "1977703130303461992863803129734853218488251484396280000763960303272760326570"
977            ),
978            bigint_str!(
979                "2565191853811572867032277464238286011368568368717965689023024980325333517459"
980            ),
981        );
982        let prime = (*CAIRO_PRIME).clone().into();
983        assert_eq!(
984            (
985                bigint_str!(
986                    "1977874238339000383330315148209250828062304908491266318460063803060754089297"
987                ),
988                bigint_str!(
989                    "2969386888251099938335087541720168257053975603483053253007176033556822156706"
990                )
991            ),
992            ec_add(point_a, point_b, &prime).unwrap()
993        );
994    }
995
996    #[test]
997    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
998    fn calculate_isqrt_a() {
999        let n = biguint!(81);
1000        assert_matches!(isqrt(&n), Ok(x) if x == biguint!(9));
1001    }
1002
1003    #[test]
1004    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1005    fn calculate_isqrt_b() {
1006        let n = biguint_str!("4573659632505831259480");
1007        assert_matches!(isqrt(&BigUint::pow(&n, 2_u32)), Ok(num) if num == n);
1008    }
1009
1010    #[test]
1011    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1012    fn calculate_isqrt_c() {
1013        let n = biguint_str!(
1014            "3618502788666131213697322783095070105623107215331596699973092056135872020481"
1015        );
1016        assert_matches!(isqrt(&BigUint::pow(&n, 2_u32)), Ok(inner) if inner == n);
1017    }
1018
1019    #[test]
1020    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1021    fn calculate_isqrt_zero() {
1022        let n = BigUint::zero();
1023        assert_matches!(isqrt(&n), Ok(inner) if inner.is_zero());
1024    }
1025
1026    #[test]
1027    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1028    fn safe_div_bigint_by_zero() {
1029        let x = BigInt::one();
1030        let y = BigInt::zero();
1031        assert_matches!(safe_div_bigint(&x, &y), Err(MathError::DividedByZero))
1032    }
1033
1034    #[test]
1035    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1036    fn test_sqrt_prime_power() {
1037        let n: BigUint = 25_u32.into();
1038        let p: BigUint = 18446744069414584321_u128.into();
1039        assert_eq!(sqrt_prime_power(&n, &p), Some(5_u32.into()));
1040    }
1041
1042    #[test]
1043    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1044    fn test_sqrt_prime_power_p_is_zero() {
1045        let n = BigUint::one();
1046        let p: BigUint = BigUint::zero();
1047        assert_eq!(sqrt_prime_power(&n, &p), None);
1048    }
1049
1050    #[test]
1051    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1052    fn test_sqrt_prime_power_non_prime() {
1053        let p: BigUint = BigUint::from_bytes_be(&[
1054            69, 15, 232, 82, 215, 167, 38, 143, 173, 94, 133, 111, 1, 2, 182, 229, 110, 113, 76, 0,
1055            47, 110, 148, 109, 6, 133, 27, 190, 158, 197, 168, 219, 165, 254, 81, 53, 25, 34,
1056        ]);
1057        let n = BigUint::from_bytes_be(&[
1058            9, 13, 22, 191, 87, 62, 157, 83, 157, 85, 93, 105, 230, 187, 32, 101, 51, 181, 49, 202,
1059            203, 195, 76, 193, 149, 78, 109, 146, 240, 126, 182, 115, 161, 238, 30, 118, 157, 252,
1060        ]);
1061
1062        assert_eq!(sqrt_prime_power(&n, &p), None);
1063    }
1064
1065    #[test]
1066    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1067    fn test_sqrt_prime_power_none() {
1068        let n: BigUint = 10_u32.into();
1069        let p: BigUint = 602_u32.into();
1070        assert_eq!(sqrt_prime_power(&n, &p), None);
1071    }
1072
1073    #[test]
1074    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1075    fn test_sqrt_prime_power_prime_two() {
1076        let n: BigUint = 25_u32.into();
1077        let p: BigUint = 2_u32.into();
1078        assert_eq!(sqrt_prime_power(&n, &p), Some(BigUint::one()));
1079    }
1080
1081    #[test]
1082    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1083    fn test_sqrt_prime_power_prime_mod_8_is_5_sign_not_one() {
1084        let n: BigUint = 676_u32.into();
1085        let p: BigUint = 9956234341095173_u64.into();
1086        assert_eq!(
1087            sqrt_prime_power(&n, &p),
1088            Some(BigUint::from(9956234341095147_u64))
1089        );
1090    }
1091
1092    #[test]
1093    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1094    fn test_sqrt_prime_power_prime_mod_8_is_5_sign_is_one() {
1095        let n: BigUint = 130283432663_u64.into();
1096        let p: BigUint = 743900351477_u64.into();
1097        assert_eq!(
1098            sqrt_prime_power(&n, &p),
1099            Some(BigUint::from(123538694848_u64))
1100        );
1101    }
1102
1103    #[test]
1104    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1105    fn test_legendre_symbol_zero() {
1106        assert!(legendre_symbol(&BigUint::zero(), &BigUint::one()).is_zero())
1107    }
1108
1109    #[test]
1110    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1111    fn test_is_quad_residue_prime_zero() {
1112        assert_eq!(
1113            is_quad_residue(&BigUint::one(), &BigUint::zero()),
1114            Err(MathError::IsQuadResidueZeroPrime)
1115        )
1116    }
1117
1118    #[test]
1119    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1120    fn test_is_quad_residue_prime_a_one_true() {
1121        assert_eq!(is_quad_residue(&BigUint::one(), &BigUint::one()), Ok(true))
1122    }
1123
1124    #[test]
1125    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1126    fn mul_inv_0_is_0() {
1127        let p = &(*CAIRO_PRIME).clone().into();
1128        let x = &BigInt::zero();
1129        let x_inv = mul_inv(x, p);
1130
1131        assert_eq!(x_inv, BigInt::zero());
1132    }
1133
1134    #[test]
1135    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1136    fn igcdex_1_1() {
1137        assert_eq!(
1138            igcdex(&BigInt::one(), &BigInt::one()),
1139            (BigInt::zero(), BigInt::one(), BigInt::one())
1140        )
1141    }
1142
1143    #[test]
1144    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1145    fn igcdex_0_0() {
1146        assert_eq!(
1147            igcdex(&BigInt::zero(), &BigInt::zero()),
1148            (BigInt::zero(), BigInt::one(), BigInt::zero())
1149        )
1150    }
1151
1152    #[test]
1153    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1154    fn igcdex_1_0() {
1155        assert_eq!(
1156            igcdex(&BigInt::one(), &BigInt::zero()),
1157            (BigInt::one(), BigInt::zero(), BigInt::one())
1158        )
1159    }
1160
1161    #[test]
1162    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1163    fn igcdex_4_6() {
1164        assert_eq!(
1165            igcdex(&BigInt::from(4), &BigInt::from(6)),
1166            (BigInt::from(-1), BigInt::one(), BigInt::from(2))
1167        )
1168    }
1169
1170    #[test]
1171    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1172    fn qm31_packed_reduced_read_coordinates_over_144_bits() {
1173        let mut felt_bytes = [0u8; 32];
1174        felt_bytes[18] = 1;
1175        let felt = Felt252::from_bytes_le(&felt_bytes);
1176        assert_matches!(
1177            qm31_packed_reduced_read_coordinates(felt),
1178            Err(MathError::QM31UnreducedError(bx)) if *bx == felt
1179        );
1180    }
1181
1182    #[test]
1183    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1184    fn qm31_packed_reduced_read_coordinates_unreduced() {
1185        let mut felt_bytes = [0u8; 32];
1186        felt_bytes[0] = 0xff;
1187        felt_bytes[1] = 0xff;
1188        felt_bytes[2] = 0xff;
1189        felt_bytes[3] = (1 << 7) - 1;
1190        let felt = Felt252::from_bytes_le(&felt_bytes);
1191        assert_matches!(
1192            qm31_packed_reduced_read_coordinates(felt),
1193            Err(MathError::QM31UnreducedError(bx)) if *bx == felt
1194        );
1195    }
1196
1197    #[test]
1198    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1199    fn test_qm31_packed_reduced_add() {
1200        let x_coordinates = [1414213562, 1732050807, 1618033988, 1234567890];
1201        let y_coordinates = [1234567890, 1414213562, 1732050807, 1618033988];
1202        let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1203        let y = qm31_coordinates_to_packed_reduced(y_coordinates);
1204        let res = qm31_packed_reduced_add(x, y).unwrap();
1205        let res_coordinates = qm31_packed_reduced_read_coordinates(res);
1206        assert_eq!(
1207            res_coordinates,
1208            Ok([
1209                (1414213562 + 1234567890) % STWO_PRIME,
1210                (1732050807 + 1414213562) % STWO_PRIME,
1211                (1618033988 + 1732050807) % STWO_PRIME,
1212                (1234567890 + 1618033988) % STWO_PRIME,
1213            ])
1214        );
1215    }
1216
1217    #[test]
1218    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1219    fn test_qm31_packed_reduced_neg() {
1220        let x_coordinates = [1749652895, 834624081, 1930174752, 2063872165];
1221        let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1222        let res = qm31_packed_reduced_neg(x).unwrap();
1223        let res_coordinates = qm31_packed_reduced_read_coordinates(res);
1224        assert_eq!(
1225            res_coordinates,
1226            Ok([
1227                STWO_PRIME - x_coordinates[0],
1228                STWO_PRIME - x_coordinates[1],
1229                STWO_PRIME - x_coordinates[2],
1230                STWO_PRIME - x_coordinates[3]
1231            ])
1232        );
1233    }
1234
1235    #[test]
1236    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1237    fn test_qm31_packed_reduced_sub() {
1238        let x_coordinates = [
1239            (1414213562 + 1234567890) % STWO_PRIME,
1240            (1732050807 + 1414213562) % STWO_PRIME,
1241            (1618033988 + 1732050807) % STWO_PRIME,
1242            (1234567890 + 1618033988) % STWO_PRIME,
1243        ];
1244        let y_coordinates = [1414213562, 1732050807, 1618033988, 1234567890];
1245        let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1246        let y = qm31_coordinates_to_packed_reduced(y_coordinates);
1247        let res = qm31_packed_reduced_sub(x, y).unwrap();
1248        let res_coordinates = qm31_packed_reduced_read_coordinates(res);
1249        assert_eq!(
1250            res_coordinates,
1251            Ok([1234567890, 1414213562, 1732050807, 1618033988])
1252        );
1253    }
1254
1255    #[test]
1256    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1257    fn test_qm31_packed_reduced_mul() {
1258        let x_coordinates = [1414213562, 1732050807, 1618033988, 1234567890];
1259        let y_coordinates = [1259921049, 1442249570, 1847759065, 2094551481];
1260        let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1261        let y = qm31_coordinates_to_packed_reduced(y_coordinates);
1262        let res = qm31_packed_reduced_mul(x, y).unwrap();
1263        let res_coordinates = qm31_packed_reduced_read_coordinates(res);
1264        assert_eq!(
1265            res_coordinates,
1266            Ok([947980980, 1510986506, 623360030, 1260310989])
1267        );
1268    }
1269
1270    #[test]
1271    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1272    fn test_qm31_packed_reduced_inv() {
1273        let x_coordinates = [1259921049, 1442249570, 1847759065, 2094551481];
1274        let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1275        let res = qm31_packed_reduced_inv(x).unwrap();
1276        assert_eq!(qm31_packed_reduced_mul(x, res), Ok(Felt252::from(1)));
1277
1278        let x_coordinates = [1, 2, 3, 4];
1279        let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1280        let res = qm31_packed_reduced_inv(x).unwrap();
1281        assert_eq!(qm31_packed_reduced_mul(x, res), Ok(Felt252::from(1)));
1282
1283        let x_coordinates = [1749652895, 834624081, 1930174752, 2063872165];
1284        let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1285        let res = qm31_packed_reduced_inv(x).unwrap();
1286        assert_eq!(qm31_packed_reduced_mul(x, res), Ok(Felt252::from(1)));
1287    }
1288
1289    // TODO: Refactor using proptest and separating particular cases
1290    #[test]
1291    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1292    fn test_qm31_packed_reduced_inv_extensive() {
1293        let mut rng = SmallRng::seed_from_u64(11480028852697973135);
1294        #[derive(Clone, Copy)]
1295        enum Configuration {
1296            Zero,
1297            One,
1298            MinusOne,
1299            Random,
1300        }
1301        let configurations = [
1302            Configuration::Zero,
1303            Configuration::One,
1304            Configuration::MinusOne,
1305            Configuration::Random,
1306        ];
1307        let mut cartesian_product = vec![];
1308        for &a in &configurations {
1309            for &b in &configurations {
1310                for &c in &configurations {
1311                    for &d in &configurations {
1312                        cartesian_product.push([a, b, c, d]);
1313                    }
1314                }
1315            }
1316        }
1317
1318        for test_case in cartesian_product {
1319            let x_coordinates: [u64; 4] = test_case
1320                .iter()
1321                .map(|&x| match x {
1322                    Configuration::Zero => 0,
1323                    Configuration::One => 1,
1324                    Configuration::MinusOne => STWO_PRIME - 1,
1325                    Configuration::Random => rng.gen_range(0..STWO_PRIME),
1326                })
1327                .collect::<Vec<u64>>()
1328                .try_into()
1329                .unwrap();
1330            if x_coordinates == [0, 0, 0, 0] {
1331                continue;
1332            }
1333            let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1334            let res = qm31_packed_reduced_inv(x).unwrap();
1335            assert_eq!(qm31_packed_reduced_mul(x, res), Ok(Felt252::from(1)));
1336        }
1337    }
1338
1339    #[test]
1340    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
1341    fn test_qm31_packed_reduced_div() {
1342        let x_coordinates = [1259921049, 1442249570, 1847759065, 2094551481];
1343        let y_coordinates = [1414213562, 1732050807, 1618033988, 1234567890];
1344        let xy_coordinates = [947980980, 1510986506, 623360030, 1260310989];
1345        let x = qm31_coordinates_to_packed_reduced(x_coordinates);
1346        let y = qm31_coordinates_to_packed_reduced(y_coordinates);
1347        let xy = qm31_coordinates_to_packed_reduced(xy_coordinates);
1348
1349        let res = qm31_packed_reduced_div(xy, y).unwrap();
1350        assert_eq!(res, x);
1351
1352        let res = qm31_packed_reduced_div(xy, x).unwrap();
1353        assert_eq!(res, y);
1354    }
1355
1356    #[cfg(feature = "std")]
1357    proptest! {
1358        #[test]
1359        fn pow2_const_in_range_returns_power_of_2(x in 0..=251u32) {
1360            prop_assert_eq!(pow2_const(x), Felt252::TWO.pow(x));
1361        }
1362
1363        #[test]
1364        fn pow2_const_oob_returns_1(x in 252u32..) {
1365            prop_assert_eq!(pow2_const(x), Felt252::ONE);
1366        }
1367
1368        #[test]
1369        fn pow2_const_nz_in_range_returns_power_of_2(x in 0..=251u32) {
1370            prop_assert_eq!(Felt252::from(pow2_const_nz(x)), Felt252::TWO.pow(x));
1371        }
1372
1373        #[test]
1374        fn pow2_const_nz_oob_returns_1(x in 252u32..) {
1375            prop_assert_eq!(Felt252::from(pow2_const_nz(x)), Felt252::ONE);
1376        }
1377
1378        #[test]
1379        // Test for sqrt_prime_power_ of a quadratic residue. Result should be the minimum root.
1380        fn sqrt_prime_power_using_random_prime(ref x in any::<[u8; 38]>(), ref y in any::<u64>()) {
1381            let mut rng = SmallRng::seed_from_u64(*y);
1382            let x = &BigUint::from_bytes_be(x);
1383            // Generate a prime here instead of relying on y, otherwise y may never be a prime number
1384            let p : &BigUint = &RandPrime::gen_prime(&mut rng, 384,  None);
1385            let x_sq = x * x;
1386            if let Some(sqrt) = sqrt_prime_power(&x_sq, p) {
1387                if &sqrt != x {
1388                    prop_assert_eq!(&(p - sqrt), x);
1389                } else {
1390                prop_assert_eq!(&sqrt, x);
1391                }
1392            }
1393        }
1394
1395        #[test]
1396        fn mul_inv_x_by_x_is_1(ref x in any::<[u8; 32]>()) {
1397            let p = &(*CAIRO_PRIME).clone().into();
1398            let pos_x = &BigInt::from_bytes_be(Sign::Plus, x);
1399            let neg_x = &BigInt::from_bytes_be(Sign::Minus, x);
1400            let pos_x_inv = mul_inv(pos_x, p);
1401            let neg_x_inv = mul_inv(neg_x, p);
1402
1403            prop_assert_eq!((pos_x * pos_x_inv).mod_floor(p), BigInt::one());
1404            prop_assert_eq!((neg_x * neg_x_inv).mod_floor(p), BigInt::one());
1405        }
1406    }
1407}