1mod is_prime;
2
3pub use is_prime::is_prime;
4
5use core::cmp::min;
6
7use crate::types::errors::math_errors::MathError;
8use crate::utils::CAIRO_PRIME;
9use crate::Felt252;
10use lazy_static::lazy_static;
11use num_bigint::{BigInt, BigUint, RandBigInt, ToBigInt};
12use num_integer::Integer;
13use num_traits::{One, Signed, Zero};
14use rand::{rngs::SmallRng, SeedableRng};
15use starknet_types_core::felt::NonZeroFelt;
16use std::ops::Shr;
17
18lazy_static! {
19 pub static ref SIGNED_FELT_MAX: BigUint = (&*CAIRO_PRIME).shr(1_u32);
20 static ref POWERS_OF_TWO: Vec<NonZeroFelt> =
21 core::iter::successors(Some(Felt252::ONE), |x| Some(x * Felt252::TWO))
22 .take(252)
23 .map(|x| x.try_into().unwrap())
24 .collect::<Vec<_>>();
25}
26
/// The value `2^31 - 1`, i.e. a `u32` with its 31 low bits set.
pub const STWO_PRIME: u32 = u32::MAX >> 1;
28
29pub(crate) fn qm31_pack_reduced(qm31: starknet_types_core::qm31::QM31) -> Felt252 {
34 let (a, b, c, d) = qm31.to_coefficients();
35 starknet_types_core::qm31::QM31::from_coefficients(
36 a % STWO_PRIME,
37 b % STWO_PRIME,
38 c % STWO_PRIME,
39 d % STWO_PRIME,
40 )
41 .pack_into_felt()
42}
43
44pub fn pow2_const(n: u32) -> Felt252 {
48 POWERS_OF_TWO
50 .get(n as usize)
51 .unwrap_or(&POWERS_OF_TWO[0])
52 .into()
53}
54
55pub fn pow2_const_nz(n: u32) -> &'static NonZeroFelt {
59 POWERS_OF_TWO.get(n as usize).unwrap_or(&POWERS_OF_TWO[0])
61}
62
63pub fn signed_felt(felt: Felt252) -> BigInt {
77 let biguint = felt.to_biguint();
78 if biguint > *SIGNED_FELT_MAX {
79 BigInt::from_biguint(num_bigint::Sign::Minus, &*CAIRO_PRIME - &biguint)
80 } else {
81 biguint.to_bigint().expect("cannot fail")
82 }
83}
84
85pub fn signed_felt_for_prime(value: Felt252, prime: &BigUint) -> BigInt {
86 let value = value.to_biguint();
87 let half_prime = prime / 2u32;
88 if value > half_prime {
89 BigInt::from_biguint(num_bigint::Sign::Minus, prime - &value)
90 } else {
91 BigInt::from_biguint(num_bigint::Sign::Plus, value)
92 }
93}
94
95pub fn isqrt(n: &BigUint) -> Result<BigUint, MathError> {
99 let mut x = n.clone();
110 let mut y = (&x + 1_u32).shr(1_u32);
112
113 while y < x {
114 x = y;
115 y = (&x + n.div_floor(&x)).shr(1_u32);
116 }
117
118 if !(&BigUint::pow(&x, 2_u32) <= n && n < &BigUint::pow(&(&x + 1_u32), 2_u32)) {
119 return Err(MathError::FailedToGetSqrt(Box::new(n.clone())));
120 };
121 Ok(x)
122}
123
124pub fn safe_div(x: &Felt252, y: &Felt252) -> Result<Felt252, MathError> {
126 let (q, r) = x.div_rem(&y.try_into().map_err(|_| MathError::DividedByZero)?);
127
128 if !r.is_zero() {
129 Err(MathError::SafeDivFail(Box::new((*x, *y))))
130 } else {
131 Ok(q)
132 }
133}
134
135pub fn safe_div_bigint(x: &BigInt, y: &BigInt) -> Result<BigInt, MathError> {
137 if y.is_zero() {
138 return Err(MathError::DividedByZero);
139 }
140
141 let (q, r) = x.div_mod_floor(y);
142
143 if !r.is_zero() {
144 return Err(MathError::SafeDivFailBigInt(Box::new((
145 x.clone(),
146 y.clone(),
147 ))));
148 }
149
150 Ok(q)
151}
152
153pub fn safe_div_usize(x: usize, y: usize) -> Result<usize, MathError> {
155 if y.is_zero() {
156 return Err(MathError::DividedByZero);
157 }
158
159 let (q, r) = x.div_mod_floor(&y);
160
161 if !r.is_zero() {
162 return Err(MathError::SafeDivFailUsize(Box::new((x, y))));
163 }
164
165 Ok(q)
166}
167
168pub(crate) fn mul_inv(num_a: &BigInt, p: &BigInt) -> BigInt {
170 if num_a.is_zero() {
171 return BigInt::zero();
172 }
173 let mut a = num_a.abs();
174 let x_sign = num_a.signum();
175 let mut b = p.abs();
176 let (mut x, mut r) = (BigInt::one(), BigInt::zero());
177 let (mut c, mut q);
178 while !b.is_zero() {
179 (q, c) = a.div_mod_floor(&b);
180 x -= &q * &r;
181 (a, b, r, x) = (b, c, x, r)
182 }
183
184 x * x_sign
185}
186
187fn igcdex(num_a: &BigInt, num_b: &BigInt) -> (BigInt, BigInt, BigInt) {
189 match (num_a, num_b) {
190 (a, b) if a.is_zero() && b.is_zero() => (BigInt::zero(), BigInt::one(), BigInt::zero()),
191 (a, _) if a.is_zero() => (BigInt::zero(), num_b.signum(), num_b.abs()),
192 (_, b) if b.is_zero() => (num_a.signum(), BigInt::zero(), num_a.abs()),
193 _ => {
194 let mut a = num_a.abs();
195 let x_sign = num_a.signum();
196 let mut b = num_b.abs();
197 let y_sign = num_b.signum();
198 let (mut x, mut y, mut r, mut s) =
199 (BigInt::one(), BigInt::zero(), BigInt::zero(), BigInt::one());
200 let (mut c, mut q);
201 while !b.is_zero() {
202 (q, c) = a.div_mod_floor(&b);
203 x -= &q * &r;
204 y -= &q * &s;
205 (a, b, r, s, x, y) = (b, c, x, y, r, s)
206 }
207 (x * x_sign, y * y_sign, a)
208 }
209 }
210}
211
212pub fn div_mod(n: &BigInt, m: &BigInt, p: &BigInt) -> Result<BigInt, MathError> {
214 let (a, _, c) = igcdex(m, p);
215 if !c.is_one() {
216 return Err(MathError::DivModIgcdexNotZero(Box::new((
217 n.clone(),
218 m.clone(),
219 p.clone(),
220 ))));
221 }
222 Ok((n * a).mod_floor(p))
223}
224
225pub(crate) fn div_mod_unsigned(
226 n: &BigUint,
227 m: &BigUint,
228 p: &BigUint,
229) -> Result<BigUint, MathError> {
230 div_mod(
232 &n.to_bigint().unwrap(),
233 &m.to_bigint().unwrap(),
234 &p.to_bigint().unwrap(),
235 )
236 .map(|i| i.to_biguint().unwrap())
237}
238
239pub fn ec_add(
240 point_a: (BigInt, BigInt),
241 point_b: (BigInt, BigInt),
242 prime: &BigInt,
243) -> Result<(BigInt, BigInt), MathError> {
244 let m = line_slope(&point_a, &point_b, prime)?;
245 let x = (m.clone() * m.clone() - point_a.0.clone() - point_b.0).mod_floor(prime);
246 let y = (m * (point_a.0 - x.clone()) - point_a.1).mod_floor(prime);
247 Ok((x, y))
248}
249
250pub fn line_slope(
253 point_a: &(BigInt, BigInt),
254 point_b: &(BigInt, BigInt),
255 prime: &BigInt,
256) -> Result<BigInt, MathError> {
257 debug_assert!(!(&point_a.0 - &point_b.0).is_multiple_of(prime));
258 div_mod(
259 &(&point_a.1 - &point_b.1),
260 &(&point_a.0 - &point_b.0),
261 prime,
262 )
263}
264
265pub fn ec_double(
268 point: (BigInt, BigInt),
269 alpha: &BigInt,
270 prime: &BigInt,
271) -> Result<(BigInt, BigInt), MathError> {
272 let m = ec_double_slope(&point, alpha, prime)?;
273 let x = ((&m * &m) - (2_i32 * &point.0)).mod_floor(prime);
274 let y = (m * (point.0 - &x) - point.1).mod_floor(prime);
275 Ok((x, y))
276}
277pub fn ec_double_slope(
281 point: &(BigInt, BigInt),
282 alpha: &BigInt,
283 prime: &BigInt,
284) -> Result<BigInt, MathError> {
285 debug_assert!(!point.1.is_multiple_of(prime));
286 div_mod(
287 &(3_i32 * &point.0 * &point.0 + alpha),
288 &(2_i32 * &point.1),
289 prime,
290 )
291}
292
293pub fn sqrt_prime_power(a: &BigUint, p: &BigUint) -> Option<BigUint> {
295 if p.is_zero() || !is_prime(p) {
296 return None;
297 }
298 let two = BigUint::from(2_u32);
299 let a = a.mod_floor(p);
300 if p == &two {
301 return Some(a);
302 }
303 if !(a < two || (a.modpow(&(p - 1_u32).div_floor(&two), p)).is_one()) {
304 return None;
305 };
306
307 if p.mod_floor(&BigUint::from(4_u32)) == 3_u32.into() {
308 let res = a.modpow(&(p + 1_u32).div_floor(&BigUint::from(4_u32)), p);
309 return Some(min(res.clone(), p - res));
310 };
311
312 if p.mod_floor(&BigUint::from(8_u32)) == 5_u32.into() {
313 let sign = a.modpow(&(p - 1_u32).div_floor(&BigUint::from(4_u32)), p);
314 if sign.is_one() {
315 let res = a.modpow(&(p + 3_u32).div_floor(&BigUint::from(8_u32)), p);
316 return Some(min(res.clone(), p - res));
317 } else {
318 let b = (4_u32 * &a).modpow(&(p - 5_u32).div_floor(&BigUint::from(8_u32)), p);
319 let x = (2_u32 * &a * b).mod_floor(p);
320 if x.modpow(&two, p) == a {
321 return Some(x);
322 }
323 }
324 };
325
326 Some(sqrt_tonelli_shanks(&a, p))
327}
328
329fn sqrt_tonelli_shanks(n: &BigUint, prime: &BigUint) -> BigUint {
330 if n.is_zero() || n.is_one() {
333 return n.clone();
334 }
335 let s = (prime - 1_u32).trailing_zeros().unwrap_or_default();
336 let t = prime >> s;
337 let a = n.modpow(&t, prime);
338 let mut rng = SmallRng::seed_from_u64(11480028852697973135);
340 let mut d;
341 loop {
342 d = RandBigInt::gen_biguint_range(&mut rng, &BigUint::from(2_u32), &(prime - 1_u32));
343 let r = legendre_symbol(&d, prime);
344 if r == -1 {
345 break;
346 };
347 }
348 d = d.modpow(&t, prime);
349 let mut m = BigUint::zero();
350 let mut exponent = BigUint::one() << (s - 1);
351 let mut adm;
352 for i in 0..s as u32 {
353 adm = &a * &d.modpow(&m, prime);
354 adm = adm.modpow(&exponent, prime);
355 exponent >>= 1;
356 if adm == (prime - 1_u32) {
357 m += BigUint::from(1_u32) << i;
358 }
359 }
360 let root_1 =
361 (n.modpow(&((t + 1_u32) >> 1), prime) * d.modpow(&(m >> 1), prime)).mod_floor(prime);
362 let root_2 = prime - &root_1;
363 if root_1 < root_2 {
364 root_1
365 } else {
366 root_2
367 }
368}
369
370fn legendre_symbol(a: &BigUint, p: &BigUint) -> i8 {
387 if a.is_zero() {
388 return 0;
389 };
390 if is_quad_residue(a, p).unwrap_or_default() {
391 1
392 } else {
393 -1
394 }
395}
396
397pub(crate) fn is_quad_residue(a: &BigUint, p: &BigUint) -> Result<bool, MathError> {
401 if p.is_zero() {
402 return Err(MathError::IsQuadResidueZeroPrime);
403 }
404 let a = if a >= p { a.mod_floor(p) } else { a.clone() };
405 if a < BigUint::from(2_u8) || p < &BigUint::from(3_u8) {
406 return Ok(true);
407 }
408 Ok(
409 a.modpow(&(p - BigUint::one()).div_floor(&BigUint::from(2_u8)), p)
410 .is_one(),
411 )
412}
413
#[cfg(test)]
mod tests {
    //! Unit and property tests for the math helpers of the parent module.
    //!
    //! Fixture points that appear in several tests are built by small helper functions so
    //! that every large literal is written down only once.

    use super::*;
    use crate::utils::test_utils::*;
    use crate::utils::CAIRO_PRIME;
    use assert_matches::assert_matches;

    use num_traits::Num;

    use num_prime::RandPrime;

    use proptest::prelude::*;

    use num_bigint::Sign;

    /// `CAIRO_PRIME` converted to a signed big integer.
    fn cairo_prime() -> BigInt {
        (*CAIRO_PRIME).clone().into()
    }

    /// The prime parsed from the hexadecimal `PRIME_STR` (its two-character prefix skipped).
    fn prime_from_hex_str() -> BigInt {
        BigInt::from_str_radix(&crate::utils::PRIME_STR[2..], 16).expect("Couldn't parse prime")
    }

    fn point_3139() -> (BigInt, BigInt) {
        (
            bigint_str!(
                "3139037544796708144595053687182055617920475701120786241351436619796497072089"
            ),
            bigint_str!(
                "2119589567875935397690285099786081818522144748339117565577200220779667999801"
            ),
        )
    }

    fn point_3324() -> (BigInt, BigInt) {
        (
            bigint_str!(
                "3324833730090626974525872402899302150520188025637965566623476530814354734325"
            ),
            bigint_str!(
                "3147007486456030910661996439995670279305852583596209647900952752170983517249"
            ),
        )
    }

    fn point_3143() -> (BigInt, BigInt) {
        (
            bigint_str!(
                "3143372541908290873737380228370996772020829254218248561772745122290262847573"
            ),
            bigint_str!(
                "1721586982687138486000069852568887984211460575851774005637537867145702861131"
            ),
        )
    }

    fn point_1937() -> (BigInt, BigInt) {
        (
            bigint_str!(
                "1937407885261715145522756206040455121546447384489085099828343908348117672673"
            ),
            bigint_str!(
                "2010355627224183802477187221870580930152258042445852905639855522404179702985"
            ),
        )
    }

    fn point_1183() -> (BigInt, BigInt) {
        (
            bigint_str!(
                "1183418161532233795704555250127335895546712857142554564893196731153957537489"
            ),
            bigint_str!(
                "1938007580204102038458825306058547644691739966277761828724036384003180924526"
            ),
        )
    }

    fn point_1977703() -> (BigInt, BigInt) {
        (
            bigint_str!(
                "1977703130303461992863803129734853218488251484396280000763960303272760326570"
            ),
            bigint_str!(
                "2565191853811572867032277464238286011368568368717965689023024980325333517459"
            ),
        )
    }

    fn point_1977874() -> (BigInt, BigInt) {
        (
            bigint_str!(
                "1977874238339000383330315148209250828062304908491266318460063803060754089297"
            ),
            bigint_str!(
                "2969386888251099938335087541720168257053975603483053253007176033556822156706"
            ),
        )
    }

    #[test]
    fn calculate_divmod_a() {
        let a = bigint_str!(
            "11260647941622813594563746375280766662237311019551239924981511729608487775604310196863705127454617186486639011517352066501847110680463498585797912894788"
        );
        let b = bigint_str!(
            "4020711254448367604954374443741161860304516084891705811279711044808359405970"
        );
        let expected = bigint_str!(
            "2904750555256547440469454488220756360634457312540595732507835416669695939476"
        );
        assert_eq!(expected, div_mod(&a, &b, &prime_from_hex_str()).unwrap());
    }

    #[test]
    fn calculate_divmod_b() {
        let a = bigint_str!(
            "29642372811668969595956851264770043260610851505766181624574941701711520154703788233010819515917136995474951116158286220089597404329949295479559895970988"
        );
        let b = bigint_str!(
            "3443173965374276972000139705137775968422921151703548011275075734291405722262"
        );
        let expected = bigint_str!(
            "3601388548860259779932034493250169083811722919049731683411013070523752439691"
        );
        assert_eq!(expected, div_mod(&a, &b, &prime_from_hex_str()).unwrap());
    }

    #[test]
    fn calculate_divmod_c() {
        let a = bigint_str!(
            "1208267356464811040667664150251401430616174694388968865551115897173431833224432165394286799069453655049199580362994484548890574931604445970825506916876"
        );
        let b = bigint_str!(
            "1809792356889571967986805709823554331258072667897598829955472663737669990418"
        );
        let expected = bigint_str!(
            "1545825591488572374291664030703937603499513742109806697511239542787093258962"
        );
        assert_eq!(expected, div_mod(&a, &b, &prime_from_hex_str()).unwrap());
    }

    #[test]
    fn compute_safe_div() {
        let x = Felt252::from(26);
        let y = Felt252::from(13);
        assert_matches!(safe_div(&x, &y), Ok(i) if i == Felt252::from(2));
    }

    #[test]
    fn compute_safe_div_non_divisor() {
        let x = Felt252::from(25);
        let y = Felt252::from(4);
        let result = safe_div(&x, &y);
        assert_matches!(
            result,
            Err(MathError::SafeDivFail(bx)) if *bx == (Felt252::from(25), Felt252::from(4)));
    }

    #[test]
    fn compute_safe_div_by_zero() {
        let x = Felt252::from(25);
        let y = Felt252::ZERO;
        let result = safe_div(&x, &y);
        assert_matches!(result, Err(MathError::DividedByZero));
    }

    #[test]
    fn compute_safe_div_usize() {
        assert_matches!(safe_div_usize(26, 13), Ok(2));
    }

    #[test]
    fn compute_safe_div_usize_non_divisor() {
        assert_matches!(
            safe_div_usize(25, 4),
            Err(MathError::SafeDivFailUsize(bx)) if *bx == (25, 4)
        );
    }

    #[test]
    fn compute_safe_div_usize_by_zero() {
        assert_matches!(safe_div_usize(25, 0), Err(MathError::DividedByZero));
    }

    #[test]
    fn compute_line_slope_for_valid_points() {
        let expected = bigint_str!(
            "992545364708437554384321881954558327331693627531977596999212637460266617010"
        );
        assert_eq!(
            expected,
            line_slope(&point_3139(), &point_3324(), &cairo_prime()).unwrap()
        );
    }

    #[test]
    fn compute_double_slope_for_valid_point_a() {
        let alpha = bigint!(1);
        let expected = bigint_str!(
            "3601388548860259779932034493250169083811722919049731683411013070523752439691"
        );
        assert_eq!(
            expected,
            ec_double_slope(&point_3143(), &alpha, &cairo_prime()).unwrap()
        );
    }

    #[test]
    fn compute_double_slope_for_valid_point_b() {
        let alpha = bigint!(1);
        let expected = bigint_str!(
            "2904750555256547440469454488220756360634457312540595732507835416669695939476"
        );
        assert_eq!(
            expected,
            ec_double_slope(&point_1937(), &alpha, &cairo_prime()).unwrap()
        );
    }

    #[test]
    fn calculate_ec_double_for_valid_point_a() {
        let alpha = bigint!(1);
        let expected = (
            bigint_str!(
                "58460926014232092148191979591712815229424797874927791614218178721848875644"
            ),
            bigint_str!(
                "1065613861227134732854284722490492186040898336012372352512913425790457998694"
            ),
        );
        assert_eq!(
            expected,
            ec_double(point_1937(), &alpha, &cairo_prime()).unwrap()
        );
    }

    #[test]
    fn calculate_ec_double_for_valid_point_b() {
        let alpha = bigint!(1);
        assert_eq!(
            point_1937(),
            ec_double(point_3143(), &alpha, &cairo_prime()).unwrap()
        );
    }

    #[test]
    fn calculate_ec_double_for_valid_point_c() {
        let point = (
            bigint_str!(
                "634630432210960355305430036410971013200846091773294855689580772209984122075"
            ),
            bigint_str!(
                "904896178444785983993402854911777165629036333948799414977736331868834995209"
            ),
        );
        let alpha = bigint!(1);
        assert_eq!(
            point_3143(),
            ec_double(point, &alpha, &cairo_prime()).unwrap()
        );
    }

    #[test]
    fn calculate_ec_add_for_valid_points_a() {
        assert_eq!(
            point_1977874(),
            ec_add(point_1183(), point_1977703(), &cairo_prime()).unwrap()
        );
    }

    #[test]
    fn calculate_ec_add_for_valid_points_b() {
        assert_eq!(
            point_1183(),
            ec_add(point_3139(), point_3324(), &cairo_prime()).unwrap()
        );
    }

    // NOTE(review): this test uses exactly the same vectors as
    // `calculate_ec_add_for_valid_points_a`; kept as is, consider replacing the inputs.
    #[test]
    fn calculate_ec_add_for_valid_points_c() {
        assert_eq!(
            point_1977874(),
            ec_add(point_1183(), point_1977703(), &cairo_prime()).unwrap()
        );
    }

    #[test]
    fn calculate_isqrt_a() {
        let n = biguint!(81);
        assert_matches!(isqrt(&n), Ok(x) if x == biguint!(9));
    }

    #[test]
    fn calculate_isqrt_b() {
        let n = biguint_str!("4573659632505831259480");
        assert_matches!(isqrt(&BigUint::pow(&n, 2_u32)), Ok(num) if num == n);
    }

    #[test]
    fn calculate_isqrt_c() {
        let n = biguint_str!(
            "3618502788666131213697322783095070105623107215331596699973092056135872020481"
        );
        assert_matches!(isqrt(&BigUint::pow(&n, 2_u32)), Ok(inner) if inner == n);
    }

    #[test]
    fn calculate_isqrt_zero() {
        let n = BigUint::zero();
        assert_matches!(isqrt(&n), Ok(inner) if inner.is_zero());
    }

    #[test]
    fn safe_div_bigint_by_zero() {
        let x = BigInt::one();
        let y = BigInt::zero();
        assert_matches!(safe_div_bigint(&x, &y), Err(MathError::DividedByZero))
    }

    #[test]
    fn test_sqrt_prime_power() {
        let n: BigUint = 25_u32.into();
        let p: BigUint = 18446744069414584321_u128.into();
        assert_eq!(sqrt_prime_power(&n, &p), Some(5_u32.into()));
    }

    #[test]
    fn test_sqrt_prime_power_p_is_zero() {
        let n = BigUint::one();
        let p: BigUint = BigUint::zero();
        assert_eq!(sqrt_prime_power(&n, &p), None);
    }

    #[test]
    fn test_sqrt_prime_power_non_prime() {
        let p: BigUint = BigUint::from_bytes_be(&[
            69, 15, 232, 82, 215, 167, 38, 143, 173, 94, 133, 111, 1, 2, 182, 229, 110, 113, 76, 0,
            47, 110, 148, 109, 6, 133, 27, 190, 158, 197, 168, 219, 165, 254, 81, 53, 25, 34,
        ]);
        let n = BigUint::from_bytes_be(&[
            9, 13, 22, 191, 87, 62, 157, 83, 157, 85, 93, 105, 230, 187, 32, 101, 51, 181, 49, 202,
            203, 195, 76, 193, 149, 78, 109, 146, 240, 126, 182, 115, 161, 238, 30, 118, 157, 252,
        ]);

        assert_eq!(sqrt_prime_power(&n, &p), None);
    }

    #[test]
    fn test_sqrt_prime_power_none() {
        let n: BigUint = 10_u32.into();
        let p: BigUint = 602_u32.into();
        assert_eq!(sqrt_prime_power(&n, &p), None);
    }

    #[test]
    fn test_sqrt_prime_power_prime_two() {
        let n: BigUint = 25_u32.into();
        let p: BigUint = 2_u32.into();
        assert_eq!(sqrt_prime_power(&n, &p), Some(BigUint::one()));
    }

    #[test]
    fn test_sqrt_prime_power_prime_mod_8_is_5_sign_not_one() {
        let n: BigUint = 676_u32.into();
        let p: BigUint = 9956234341095173_u64.into();
        assert_eq!(
            sqrt_prime_power(&n, &p),
            Some(BigUint::from(9956234341095147_u64))
        );
    }

    #[test]
    fn test_sqrt_prime_power_prime_mod_8_is_5_sign_is_one() {
        let n: BigUint = 130283432663_u64.into();
        let p: BigUint = 743900351477_u64.into();
        assert_eq!(
            sqrt_prime_power(&n, &p),
            Some(BigUint::from(123538694848_u64))
        );
    }

    #[test]
    fn test_legendre_symbol_zero() {
        assert!(legendre_symbol(&BigUint::zero(), &BigUint::one()).is_zero())
    }

    #[test]
    fn test_is_quad_residue_prime_zero() {
        assert_eq!(
            is_quad_residue(&BigUint::one(), &BigUint::zero()),
            Err(MathError::IsQuadResidueZeroPrime)
        )
    }

    #[test]
    fn test_is_quad_residue_prime_a_one_true() {
        assert_eq!(is_quad_residue(&BigUint::one(), &BigUint::one()), Ok(true))
    }

    #[test]
    fn mul_inv_0_is_0() {
        let p = &cairo_prime();
        let x = &BigInt::zero();
        let x_inv = mul_inv(x, p);

        assert_eq!(x_inv, BigInt::zero());
    }

    #[test]
    fn igcdex_1_1() {
        assert_eq!(
            igcdex(&BigInt::one(), &BigInt::one()),
            (BigInt::zero(), BigInt::one(), BigInt::one())
        )
    }

    #[test]
    fn igcdex_0_0() {
        assert_eq!(
            igcdex(&BigInt::zero(), &BigInt::zero()),
            (BigInt::zero(), BigInt::one(), BigInt::zero())
        )
    }

    #[test]
    fn igcdex_1_0() {
        assert_eq!(
            igcdex(&BigInt::one(), &BigInt::zero()),
            (BigInt::one(), BigInt::zero(), BigInt::one())
        )
    }

    #[test]
    fn igcdex_4_6() {
        assert_eq!(
            igcdex(&BigInt::from(4), &BigInt::from(6)),
            (BigInt::from(-1), BigInt::one(), BigInt::from(2))
        )
    }

    proptest! {

        #[test]
        fn pow2_const_in_range_returns_power_of_2(x in 0..=251u32) {
            prop_assert_eq!(pow2_const(x), Felt252::TWO.pow(x));
        }

        #[test]
        fn pow2_const_oob_returns_1(x in 252u32..) {
            prop_assert_eq!(pow2_const(x), Felt252::ONE);
        }

        #[test]
        fn pow2_const_nz_in_range_returns_power_of_2(x in 0..=251u32) {
            prop_assert_eq!(Felt252::from(pow2_const_nz(x)), Felt252::TWO.pow(x));
        }

        #[test]
        fn pow2_const_nz_oob_returns_1(x in 252u32..) {
            prop_assert_eq!(Felt252::from(pow2_const_nz(x)), Felt252::ONE);
        }

        #[test]
        fn sqrt_prime_power_using_random_prime(ref x in any::<[u8; 38]>(), ref y in any::<u64>()) {
            let mut rng = SmallRng::seed_from_u64(*y);
            let x = &BigUint::from_bytes_be(x);
            let p : &BigUint = &RandPrime::gen_prime(&mut rng, 384, None);
            let x_sq = x * x;

            // A perfect square always has a root modulo a prime, so `None` is a failure
            // rather than something to skip silently.
            let root = sqrt_prime_power(&x_sq, p);
            prop_assert!(root.is_some());
            let root = root.unwrap();

            // The returned value must really be a square root of `x^2` modulo `p` ...
            prop_assert_eq!((&root * &root).mod_floor(p), x_sq.mod_floor(p));
            // ... and in a prime field the only roots are `x` and `-x`. `x` is reduced
            // first because `gen_prime` only bounds the bit size of `p` from above.
            let x_mod_p = x.mod_floor(p);
            prop_assert!(root == x_mod_p || (p - &root) == x_mod_p);
        }

        #[test]
        fn mul_inv_x_by_x_is_1(ref x in any::<[u8; 32]>()) {
            let p = &cairo_prime();
            let pos_x = &BigInt::from_bytes_be(Sign::Plus, x);
            let neg_x = &BigInt::from_bytes_be(Sign::Minus, x);
            let pos_x_inv = mul_inv(pos_x, p);
            let neg_x_inv = mul_inv(neg_x, p);

            prop_assert_eq!((pos_x * pos_x_inv).mod_floor(p), BigInt::one());
            prop_assert_eq!((neg_x * neg_x_inv).mod_floor(p), BigInt::one());
        }
    }
}