use crate::math::{
jacobi::legendre_symbol_given_mont_params,
misc::{crt_combine, euler_totient, is_3_mod_4},
};
use core::ops::{Add, AddAssign, Shr, Sub};
use crypto_bigint::{
modular::{MontyForm, MontyParams, SafeGcdInverter},
subtle::ConstantTimeEq,
Concat, Integer, Odd, PrecomputeInverter, Split, Uint,
};
/// Reports whether `a` is a quadratic residue modulo the prime described by the Montgomery
/// parameters `p`, i.e. whether the Legendre symbol `(a | p)` equals `1`.
///
/// NOTE(review): `a ≡ 0 (mod p)` presumably has Legendre symbol `0` and is therefore reported
/// as *not* a residue — confirm against `legendre_symbol_given_mont_params`.
pub fn is_quadratic_residue_mod_prime<const LIMBS: usize>(
    a: &Uint<LIMBS>,
    p: MontyParams<LIMBS>,
) -> bool {
    let symbol = legendre_symbol_given_mont_params(a, p);
    symbol.is_one()
}
/// Reports whether `a` is a quadratic non-residue modulo the prime described by the Montgomery
/// parameters `p`, i.e. whether the Legendre symbol `(a | p)` equals `-1`.
pub fn is_quadratic_non_residue_mod_prime<const LIMBS: usize>(
    a: &Uint<LIMBS>,
    p: MontyParams<LIMBS>,
) -> bool {
    let symbol = legendre_symbol_given_mont_params(a, p);
    symbol.is_minus_one()
}
/// Computes a square root of `a` modulo the odd prime `p` with the Tonelli–Shanks algorithm.
///
/// Convenience wrapper: builds the Montgomery parameters for `p` and delegates to
/// [`sqrt_using_tonelli_shanks_given_mtg_params`]. Returns `None` if `a` is not a quadratic
/// residue modulo `p`.
pub fn sqrt_using_tonelli_shanks<const LIMBS: usize, const WIDE_LIMBS: usize>(
    a: Uint<LIMBS>,
    p: Odd<Uint<LIMBS>>,
) -> Option<Uint<LIMBS>>
where
    Uint<LIMBS>: Concat<Output = Uint<WIDE_LIMBS>>,
    Uint<WIDE_LIMBS>: Split<Output = Uint<LIMBS>>,
{
    sqrt_using_tonelli_shanks_given_mtg_params::<LIMBS>(a, MontyParams::new(p))
}
/// Computes a square root of `a` modulo the odd prime `p` (given as Montgomery parameters)
/// using the Tonelli–Shanks algorithm.
///
/// Works for every odd prime. For primes `p ≡ 3 (mod 4)` the closed form in
/// [`sqrt_for_blum_prime_given_mtg_params`] is cheaper.
///
/// Returns `None` if [`is_quadratic_residue_mod_prime`] rejects `a`. NOTE(review): that
/// presumably includes `a ≡ 0 (mod p)`, so zero yields `None` rather than `Some(0)` — confirm
/// this is the intended contract.
///
/// `p` must be prime. Otherwise the non-residue search or the order search below may not
/// terminate.
pub fn sqrt_using_tonelli_shanks_given_mtg_params<const LIMBS: usize>(
    a: Uint<LIMBS>,
    p: MontyParams<LIMBS>,
) -> Option<Uint<LIMBS>> {
    if !is_quadratic_residue_mod_prime(&a, p) {
        return None;
    }

    // Factor p - 1 = q * 2^s with q odd.
    let mut q = p.modulus().sub(Uint::ONE);
    let mut s: u32 = 0;
    while q.is_even().into() {
        q = q.shr(1);
        s += 1;
    }

    // Any quadratic non-residue z works; search linearly starting from 2.
    let mut z = Uint::from(2_u32);
    while !is_quadratic_non_residue_mod_prime(&z, p) {
        z.add_assign(&Uint::ONE);
    }

    let one = MontyForm::one(p);
    let a_mtg = MontyForm::new(&a, p);
    let mut m = s;
    let mut c = MontyForm::new(&z, p).pow(&q);
    let mut t = a_mtg.pow(&q);
    let mut r = a_mtg.pow(&q.add(&Uint::ONE).shr(1));

    // Loop invariant: r^2 = a * t (mod p). Stop once t == 1, at which point r^2 = a.
    while t.ct_ne(&one).into() {
        // Least i >= 1 with t^(2^i) == 1.
        let mut i: u32 = 0;
        let mut t_pow = t;
        while t_pow.ct_ne(&one).into() {
            t_pow = t_pow.square();
            i += 1;
        }

        // b = c^(2^(m - i - 1)). Repeated squaring needs only (m - i - 1) squarings.
        // `pow` with the one-hot exponent would instead run a full-width exponentiation
        // over every bit of a LIMBS-sized exponent.
        let mut b = c;
        for _ in 0..(m - i - 1) {
            b = b.square();
        }
        let b_sqr = b.square();
        r = r * b;
        t = t * b_sqr;
        c = b_sqr;
        m = i;
    }
    Some(r.retrieve())
}
/// Computes a square root of `a` modulo a prime `p ≡ 3 (mod 4)`.
///
/// Builds the Montgomery parameters for `p` and delegates to
/// [`sqrt_for_blum_prime_given_mtg_params`]. The congruence `p ≡ 3 (mod 4)` is not checked here.
/// Use [`sqrt_mod_prime`] when it is not known. Returns `None` if `a` is not a quadratic
/// residue modulo `p`.
pub fn sqrt_for_blum_prime<const LIMBS: usize, const WIDE_LIMBS: usize>(
    a: Uint<LIMBS>,
    p: Odd<Uint<LIMBS>>,
) -> Option<Uint<LIMBS>>
where
    Uint<LIMBS>: Concat<Output = Uint<WIDE_LIMBS>>,
    Uint<WIDE_LIMBS>: Split<Output = Uint<LIMBS>>,
{
    sqrt_for_blum_prime_given_mtg_params::<LIMBS>(a, MontyParams::new(p))
}
/// Computes a square root of `a` modulo a prime `p ≡ 3 (mod 4)` given its Montgomery parameters.
///
/// Derives the exponent `(p + 1) / 4` and delegates to [`sqrt_for_blum_prime_given_precomp`].
/// The caller is responsible for `p ≡ 3 (mod 4)`; it is not checked here.
pub fn sqrt_for_blum_prime_given_mtg_params<const LIMBS: usize>(
    a: Uint<LIMBS>,
    p: MontyParams<LIMBS>,
) -> Option<Uint<LIMBS>> {
    // For p ≡ 3 (mod 4), a^((p + 1) / 4) squares to a whenever a is a residue.
    let exp = p.modulus().add(Uint::<LIMBS>::ONE).shr(2);
    sqrt_for_blum_prime_given_precomp(a, &exp, p)
}
/// Computes a square root of `a` modulo a prime `p ≡ 3 (mod 4)` using a precomputed exponent.
///
/// `exp` is expected to be `(p + 1) / 4`, as derived by
/// [`sqrt_for_blum_prime_given_mtg_params`]. The result is `a^exp mod p`.
/// Returns `None` if `a` is not a quadratic residue modulo `p`.
pub fn sqrt_for_blum_prime_given_precomp<const LIMBS: usize>(
    a: Uint<LIMBS>,
    exp: &Uint<LIMBS>,
    p_mtg: MontyParams<LIMBS>,
) -> Option<Uint<LIMBS>> {
    // Only residues have square roots. Checking first avoids returning a value whose square
    // is not `a`.
    is_quadratic_residue_mod_prime(&a, p_mtg)
        .then(|| MontyForm::new(&a, p_mtg).pow(exp).retrieve())
}
/// Computes a square root of `a` modulo the odd prime `p`, picking the cheaper algorithm.
///
/// - When `is_3_mod_4` reports `p ≡ 3 (mod 4)`, uses the closed-form exponentiation in
///   [`sqrt_for_blum_prime`].
/// - Otherwise, falls back to [`sqrt_using_tonelli_shanks`].
///
/// Returns `None` if `a` is not a quadratic residue modulo `p`.
pub fn sqrt_mod_prime<const LIMBS: usize, const WIDE_LIMBS: usize>(
    a: Uint<LIMBS>,
    p: Odd<Uint<LIMBS>>,
) -> Option<Uint<LIMBS>>
where
    Uint<LIMBS>: Concat<Output = Uint<WIDE_LIMBS>>,
    Uint<WIDE_LIMBS>: Split<Output = Uint<LIMBS>>,
{
    if !is_3_mod_4(p.as_ref()) {
        return sqrt_using_tonelli_shanks(a, p);
    }
    sqrt_for_blum_prime(a, p)
}
/// Same as [`sqrt_mod_prime`], but takes the prime as precomputed Montgomery parameters.
///
/// Dispatches to [`sqrt_for_blum_prime_given_mtg_params`] for `p ≡ 3 (mod 4)` and to
/// [`sqrt_using_tonelli_shanks_given_mtg_params`] otherwise. Returns `None` if `a` is not a
/// quadratic residue modulo `p`.
pub fn sqrt_mod_prime_given_mtg_params<const LIMBS: usize>(
    a: Uint<LIMBS>,
    p_mtg: MontyParams<LIMBS>,
) -> Option<Uint<LIMBS>> {
    let p_is_3_mod_4 = is_3_mod_4(p_mtg.modulus().as_ref());
    if p_is_3_mod_4 {
        sqrt_for_blum_prime_given_mtg_params(a, p_mtg)
    } else {
        sqrt_using_tonelli_shanks_given_mtg_params(a, p_mtg)
    }
}
pub fn sqrt_mod_composite_given_prime_factors<
const PRIME_LIMBS: usize,
const PRIME_PRODUCT_LIMBS: usize,
const PRIME_UNSAT_LIMBS: usize,
>(
a: Uint<PRIME_PRODUCT_LIMBS>,
p: Odd<Uint<PRIME_LIMBS>>,
q: Odd<Uint<PRIME_LIMBS>>,
) -> Option<Uint<PRIME_PRODUCT_LIMBS>>
where
Uint<PRIME_LIMBS>: Concat<Output = Uint<PRIME_PRODUCT_LIMBS>>,
Uint<PRIME_PRODUCT_LIMBS>: Split<Output = Uint<PRIME_LIMBS>>,
Odd<Uint<PRIME_LIMBS>>: PrecomputeInverter<
Inverter = SafeGcdInverter<PRIME_LIMBS, PRIME_UNSAT_LIMBS>,
Output = Uint<PRIME_LIMBS>,
>,
{
let a_mod_p = a.rem(&p.resize().to_nz().unwrap()).resize();
let a_mod_q = a.rem(&q.resize().to_nz().unwrap()).resize();
match (sqrt_mod_prime(a_mod_p, p), sqrt_mod_prime(a_mod_q, q)) {
(Some(s_p), Some(s_q)) => {
let q_mtg = MontyParams::new(q);
let p_inv = MontyForm::new(&p, q_mtg).inv().unwrap();
Some(crt_combine(&s_p, &s_q, p_inv, &p, q_mtg))
}
_ => None,
}
}
/// Computes a square root of `a` modulo `n = p * q`, with both primes given as Montgomery
/// parameters.
///
/// Computes `p^{-1} mod q` and delegates to
/// [`sqrt_mod_composite_given_prime_factors_as_mtg_params_and_p_inv`]. Returns `None` if `a` is
/// not a quadratic residue modulo both primes.
///
/// # Panics
/// If `p` has no inverse modulo `q` (e.g. `p == q`). This is checked before any root is
/// attempted.
pub fn sqrt_mod_composite_given_prime_factors_as_mtg_params<
    const PRIME_LIMBS: usize,
    const PRIME_PRODUCT_LIMBS: usize,
    const PRIME_UNSAT_LIMBS: usize,
>(
    a: Uint<PRIME_PRODUCT_LIMBS>,
    p_mtg: MontyParams<PRIME_LIMBS>,
    q_mtg: MontyParams<PRIME_LIMBS>,
) -> Option<Uint<PRIME_PRODUCT_LIMBS>>
where
    Uint<PRIME_LIMBS>: Concat<Output = Uint<PRIME_PRODUCT_LIMBS>>,
    Odd<Uint<PRIME_LIMBS>>: PrecomputeInverter<
        Inverter = SafeGcdInverter<PRIME_LIMBS, PRIME_UNSAT_LIMBS>,
        Output = Uint<PRIME_LIMBS>,
    >,
{
    // p viewed as an element of Z/qZ, then inverted for the CRT recombination.
    let p_mod_q = MontyForm::new(p_mtg.modulus(), q_mtg);
    let p_inv = p_mod_q.inv().unwrap();
    sqrt_mod_composite_given_prime_factors_as_mtg_params_and_p_inv(a, p_mtg, q_mtg, p_inv)
}
/// Computes a square root of `a` modulo `n = p * q` given both primes' Montgomery parameters
/// and the precomputed `p_inv = p^{-1} mod q`.
///
/// Takes a square root of `a mod p` and of `a mod q` via [`sqrt_mod_prime_given_mtg_params`]
/// and joins them with `crt_combine`. Returns `None` as soon as either per-prime root does
/// not exist.
pub fn sqrt_mod_composite_given_prime_factors_as_mtg_params_and_p_inv<
    const PRIME_LIMBS: usize,
    const PRIME_PRODUCT_LIMBS: usize,
>(
    a: Uint<PRIME_PRODUCT_LIMBS>,
    p_mtg: MontyParams<PRIME_LIMBS>,
    q_mtg: MontyParams<PRIME_LIMBS>,
    p_inv: MontyForm<PRIME_LIMBS>,
) -> Option<Uint<PRIME_PRODUCT_LIMBS>>
where
    Uint<PRIME_LIMBS>: Concat<Output = Uint<PRIME_PRODUCT_LIMBS>>,
{
    let a_mod_p = a.rem(&p_mtg.modulus().resize().to_nz().unwrap()).resize();
    let a_mod_q = a.rem(&q_mtg.modulus().resize().to_nz().unwrap()).resize();
    let root_mod_p = sqrt_mod_prime_given_mtg_params(a_mod_p, p_mtg)?;
    let root_mod_q = sqrt_mod_prime_given_mtg_params(a_mod_q, q_mtg)?;
    Some(crt_combine(
        &root_mod_p,
        &root_mod_q,
        p_inv,
        p_mtg.modulus(),
        q_mtg,
    ))
}
/// Computes a square root of `a` modulo the Blum integer `n = p * q` (`p ≡ q ≡ 3 (mod 4)`).
///
/// Runs [`precomputation_for_sqrt_mod_blum_integer`] and then
/// [`sqrt_mod_blum_integer_given_prime_factors_as_mtg_params_and_precomp`]. When many roots
/// are needed for the same `n`, call the precomputation once and use the `_and_precomp`
/// variant directly.
pub fn sqrt_mod_blum_integer_given_prime_factors_as_mtg_params<
    const PRIME_LIMBS: usize,
    const PRIME_PRODUCT_LIMBS: usize,
    const PRIME_UNSAT_LIMBS: usize,
>(
    a: &Uint<PRIME_PRODUCT_LIMBS>,
    p_mtg: MontyParams<PRIME_LIMBS>,
    q_mtg: MontyParams<PRIME_LIMBS>,
) -> Option<Uint<PRIME_PRODUCT_LIMBS>>
where
    Uint<PRIME_LIMBS>: Concat<Output = Uint<PRIME_PRODUCT_LIMBS>>,
    Odd<Uint<PRIME_LIMBS>>: PrecomputeInverter<
        Inverter = SafeGcdInverter<PRIME_LIMBS, PRIME_UNSAT_LIMBS>,
        Output = Uint<PRIME_LIMBS>,
    >,
{
    // (exponent mod p-1, exponent mod q-1, p^{-1} mod q)
    let precomp = precomputation_for_sqrt_mod_blum_integer::<
        PRIME_LIMBS,
        PRIME_PRODUCT_LIMBS,
        PRIME_UNSAT_LIMBS,
    >(p_mtg, q_mtg);
    sqrt_mod_blum_integer_given_prime_factors_as_mtg_params_and_precomp(
        a,
        &precomp.0,
        &precomp.1,
        precomp.2,
        p_mtg,
        q_mtg,
    )
}
/// Computes a square root of `a` modulo the Blum integer `n = p * q` using precomputed values.
///
/// `exp_mod_p_minus_1`, `exp_mod_q_minus_1` and `p_inv` are expected to come from
/// [`precomputation_for_sqrt_mod_blum_integer`]. The partial roots are:
/// - `(a mod p)^exp_mod_p_minus_1 mod p`
/// - `(a mod q)^exp_mod_q_minus_1 mod q`
///
/// They are joined with `crt_combine`. Each partial root is a power of `a` modulo its prime, so
/// it is itself a quadratic residue there whenever `a` is.
///
/// Returns `None` unless `a` is a quadratic residue modulo both `p` and `q`.
pub fn sqrt_mod_blum_integer_given_prime_factors_as_mtg_params_and_precomp<
    const PRIME_LIMBS: usize,
    const PRIME_PRODUCT_LIMBS: usize,
>(
    a: &Uint<PRIME_PRODUCT_LIMBS>,
    exp_mod_p_minus_1: &Uint<PRIME_LIMBS>,
    exp_mod_q_minus_1: &Uint<PRIME_LIMBS>,
    p_inv: MontyForm<PRIME_LIMBS>,
    p_mtg: MontyParams<PRIME_LIMBS>,
    q_mtg: MontyParams<PRIME_LIMBS>,
) -> Option<Uint<PRIME_PRODUCT_LIMBS>>
where
    Uint<PRIME_LIMBS>: Concat<Output = Uint<PRIME_PRODUCT_LIMBS>>,
{
    let a_mod_p = a.rem(&p_mtg.modulus().resize().to_nz().unwrap()).resize();
    let a_mod_q = a.rem(&q_mtg.modulus().resize().to_nz().unwrap()).resize();

    // Residuosity must hold modulo each prime factor; `&&` skips the second check on failure.
    let residue_mod_both = is_quadratic_residue_mod_prime(&a_mod_p, p_mtg)
        && is_quadratic_residue_mod_prime(&a_mod_q, q_mtg);
    if !residue_mod_both {
        return None;
    }

    let root_mod_p = MontyForm::new(&a_mod_p, p_mtg)
        .pow(exp_mod_p_minus_1)
        .retrieve();
    let root_mod_q = MontyForm::new(&a_mod_q, q_mtg)
        .pow(exp_mod_q_minus_1)
        .retrieve();
    Some(crt_combine(
        &root_mod_p,
        &root_mod_q,
        p_inv,
        p_mtg.modulus(),
        q_mtg,
    ))
}
/// Precomputes the inputs of
/// [`sqrt_mod_blum_integer_given_prime_factors_as_mtg_params_and_precomp`] for the Blum integer
/// `n = p * q`.
///
/// Returns `(exp_mod_p_minus_1, exp_mod_q_minus_1, p_inv)`:
/// - The two exponents come from [`exponents_for_sqrt_mod_blum_integer`].
/// - `p_inv` is `p^{-1} mod q` in Montgomery form.
///
/// # Panics
/// If `p` has no inverse modulo `q` (e.g. `p == q`).
pub fn precomputation_for_sqrt_mod_blum_integer<
    const PRIME_LIMBS: usize,
    const PRIME_PRODUCT_LIMBS: usize,
    const PRIME_UNSAT_LIMBS: usize,
>(
    p_mtg: MontyParams<PRIME_LIMBS>,
    q_mtg: MontyParams<PRIME_LIMBS>,
) -> (Uint<PRIME_LIMBS>, Uint<PRIME_LIMBS>, MontyForm<PRIME_LIMBS>)
where
    Uint<PRIME_LIMBS>: Concat<Output = Uint<PRIME_PRODUCT_LIMBS>>,
    Odd<Uint<PRIME_LIMBS>>: PrecomputeInverter<
        Inverter = SafeGcdInverter<PRIME_LIMBS, PRIME_UNSAT_LIMBS>,
        Output = Uint<PRIME_LIMBS>,
    >,
{
    let exponents = exponents_for_sqrt_mod_blum_integer(p_mtg, q_mtg);
    let p_mod_q = MontyForm::new(p_mtg.modulus(), q_mtg);
    (exponents.0, exponents.1, p_mod_q.inv().unwrap())
}
/// Computes the square-root exponent for the Blum integer `n = p * q`, reduced modulo `p - 1`
/// and modulo `q - 1`.
///
/// The exponent is `(φ(n) + 4) / 8`, where `φ(n)` comes from `euler_totient`, presumably
/// `(p - 1)(q - 1)`. For `p ≡ q ≡ 3 (mod 4)` that product is `≡ 4 (mod 8)`, so the division
/// is exact.
///
/// Reducing modulo `p - 1` (resp. `q - 1`) does not change `x^exp mod p` for `x` coprime to
/// `p`, by Fermat's little theorem. It also shrinks the exponent to prime width.
pub fn exponents_for_sqrt_mod_blum_integer<
    const PRIME_LIMBS: usize,
    const PRIME_PRODUCT_LIMBS: usize,
>(
    p_mtg: MontyParams<PRIME_LIMBS>,
    q_mtg: MontyParams<PRIME_LIMBS>,
) -> (Uint<PRIME_LIMBS>, Uint<PRIME_LIMBS>)
where
    Uint<PRIME_LIMBS>: Concat<Output = Uint<PRIME_PRODUCT_LIMBS>>,
{
    let phi =
        euler_totient::<PRIME_LIMBS, PRIME_PRODUCT_LIMBS>(p_mtg.modulus(), q_mtg.modulus());
    let exp: Uint<PRIME_PRODUCT_LIMBS> = phi.add(&Uint::from(4_u32)).shr(3);

    let p_minus_1 = p_mtg.modulus().sub(Uint::ONE);
    let q_minus_1 = q_mtg.modulus().sub(Uint::ONE);
    (
        exp.rem(&p_minus_1.resize().to_nz().unwrap()).resize(),
        exp.rem(&q_minus_1.resize().to_nz().unwrap()).resize(),
    )
}
/// Computes a fourth root of `a` modulo the Blum integer `n = p * q` using precomputed values.
///
/// Extracts a square root twice with
/// [`sqrt_mod_blum_integer_given_prime_factors_as_mtg_params_and_precomp`]. The precomputed
/// arguments have the same meaning as there.
///
/// The intermediate root is, modulo each prime, a power of `a`, assuming `crt_combine` returns
/// the CRT lift of the partial roots. It is therefore a quadratic residue modulo `p` and `q`
/// whenever `a` is, which is what lets the second extraction succeed.
///
/// Returns `None` if either extraction fails.
pub fn fourth_root_mod_blum_integer_given_prime_factors_as_mtg_params_and_precomp<
    const PRIME_LIMBS: usize,
    const PRIME_PRODUCT_LIMBS: usize,
>(
    a: &Uint<PRIME_PRODUCT_LIMBS>,
    exp_mod_p_minus_1: &Uint<PRIME_LIMBS>,
    exp_mod_q_minus_1: &Uint<PRIME_LIMBS>,
    p_inv: MontyForm<PRIME_LIMBS>,
    p_mtg: MontyParams<PRIME_LIMBS>,
    q_mtg: MontyParams<PRIME_LIMBS>,
) -> Option<Uint<PRIME_PRODUCT_LIMBS>>
where
    Uint<PRIME_LIMBS>: Concat<Output = Uint<PRIME_PRODUCT_LIMBS>>,
{
    // The exponents are already references; pass them through rather than re-borrowing.
    let sqrt = sqrt_mod_blum_integer_given_prime_factors_as_mtg_params_and_precomp::<
        PRIME_LIMBS,
        PRIME_PRODUCT_LIMBS,
    >(
        a,
        exp_mod_p_minus_1,
        exp_mod_q_minus_1,
        p_inv,
        p_mtg,
        q_mtg,
    )?;
    sqrt_mod_blum_integer_given_prime_factors_as_mtg_params_and_precomp::<
        PRIME_LIMBS,
        PRIME_PRODUCT_LIMBS,
    >(
        &sqrt,
        exp_mod_p_minus_1,
        exp_mod_q_minus_1,
        p_inv,
        p_mtg,
        q_mtg,
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        math::jacobi::legendre_symbol,
        util::{blum_prime, get_1024_bit_primes, get_2048_bit_primes},
    };
    use crypto_bigint::{RandomMod, U1024, U128, U2048, U256, U4096, U512, U64};
    use crypto_primes::generate_prime_with_rng;
    use rand_core::OsRng;
    use std::time::{Duration, Instant};

    #[test]
    fn square_root_tonelli_shanks() {
        let mut rng = OsRng::default();
        // Hand-picked (residue, prime) pairs.
        for (i, p) in [
            (10, 13),
            (8, 17),
            (9, 17),
            (39, 41),
            (13, 10000019),
            (392203, 852167),
            (379606557, 425172197),
            (585251669, 892950901),
            (404690348, 430183399),
            (210205747, 625380647),
        ] {
            let a = U64::from(i as u64);
            let p = U64::from(p as u64).to_odd().unwrap();
            assert!(legendre_symbol(&a, p).is_one());
            let s = sqrt_using_tonelli_shanks(a, p).unwrap();
            let s_sqr = s.square().rem(&p.resize().to_nz().unwrap()).resize();
            assert_eq!(a, s_sqr);
        }
        // Random squares: the root must square back and be ±a.
        macro_rules! check_given_prime {
            ( $num_iters:expr, $prime_type:ident, $p:ident ) => {
                let p_nz = $p.resize().to_nz().unwrap();
                for _ in 0..$num_iters {
                    let a = $prime_type::random_mod(&mut rng, $p.as_nz_ref());
                    let a_sqr = a.square().rem(&p_nz).resize();
                    let s = sqrt_using_tonelli_shanks(a_sqr, $p).unwrap();
                    let s_sqr = s.square().rem(&p_nz).resize();
                    assert_eq!(a_sqr, s_sqr, "{} {} {} {} {}", $p, a, a_sqr, s, s_sqr);
                    let c1 = a == s;
                    let c2 = a == $p.sub(s);
                    assert!(c1 || c2, "{} {} {} {}", $p, a, a_sqr, s);
                }
            };
        }
        macro_rules! check {
            ( $num_iters:expr, $prime_type:ident ) => {
                let p: $prime_type = generate_prime_with_rng(&mut rng, $prime_type::BITS);
                let p = p.to_odd().unwrap();
                check_given_prime!($num_iters, $prime_type, p);
            };
        }
        check!(1000, U64);
        check!(1000, U128);
        check!(100, U256);
        check!(100, U512);
        let (p, _) = get_1024_bit_primes();
        check_given_prime!(30, U1024, p);
        let (p, _) = get_2048_bit_primes();
        check_given_prime!(30, U2048, p);
    }

    #[test]
    fn square_root_blum_prime() {
        let mut rng = OsRng::default();
        macro_rules! check_given_prime {
            ( $num_iters:expr, $prime_type:ident, $p:ident ) => {
                let p_nz = $p.resize().to_nz().unwrap();
                let mut t1 = Duration::default();
                let mut t2 = Duration::default();
                for _ in 0..$num_iters {
                    let a = $prime_type::random_mod(&mut rng, &$p.as_nz_ref());
                    let a_sqr = a.square().rem(&p_nz).resize();
                    let start = Instant::now();
                    let s = sqrt_using_tonelli_shanks(a_sqr, $p).unwrap();
                    t1 += start.elapsed();
                    let s_sqr = s.square().rem(&p_nz).resize();
                    assert_eq!(a_sqr, s_sqr, "{} {} {} {} {}", $p, a, a_sqr, s, s_sqr);
                    let start = Instant::now();
                    let s = sqrt_for_blum_prime(a_sqr, $p).unwrap();
                    t2 += start.elapsed();
                    let s_sqr = s.square().rem(&p_nz).resize();
                    assert_eq!(a_sqr, s_sqr, "{} {} {} {} {}", $p, a, a_sqr, s, s_sqr);
                }
                println!("Time for {} iterations and for {} bit prime, Tonelli-Shanks={:?}, Special={:?}", $num_iters, $prime_type::BITS, t1, t2);
            }
        }
        macro_rules! check {
            ( $num_iters:expr, $prime_type:ident ) => {
                let p: $prime_type = blum_prime(&mut rng);
                let p = p.to_odd().unwrap();
                check_given_prime!($num_iters, $prime_type, p);
            };
        }
        check!(1000, U128);
        check!(100, U256);
        check!(100, U512);
    }

    #[test]
    fn square_root_of_non_residue_is_none() {
        let mut rng = OsRng::default();
        // -1 is a quadratic non-residue modulo every prime p ≡ 3 (mod 4), so every
        // square-root routine must report that no root exists.
        let p: U256 = blum_prime(&mut rng);
        let p = p.to_odd().unwrap();
        let minus_one = p.sub(U256::ONE);
        assert!(sqrt_for_blum_prime(minus_one, p).is_none());
        assert!(sqrt_using_tonelli_shanks(minus_one, p).is_none());
        assert!(sqrt_mod_prime(minus_one, p).is_none());
    }

    #[test]
    fn square_root_mod_composite() {
        let mut rng = OsRng::default();
        macro_rules! check_given_primes {
            ( $num_iters:expr, $prime_type:ident, $mod_type:ident, $p:ident, $q:ident ) => {
                let n = $p.widening_mul(&$q).to_odd().unwrap();
                let n_nz = n.resize().to_nz().unwrap();
                for _ in 0..$num_iters {
                    let a = $mod_type::random_mod(&mut rng, &n.to_nz().unwrap());
                    let a_sqr: $mod_type = a.square().rem(&n_nz).resize();
                    let s: $mod_type =
                        sqrt_mod_composite_given_prime_factors(a_sqr, $p, $q).unwrap();
                    let s_sqr: $mod_type = s.square().rem(&n_nz).resize();
                    assert_eq!(
                        a_sqr, s_sqr,
                        "{} {} {} {} {} {}",
                        $p, $q, a, a_sqr, s, s_sqr
                    );
                }
            };
        }
        macro_rules! check {
            ( $num_iters:expr, $prime_type:ident, $mod_type:ident ) => {
                let p: $prime_type = generate_prime_with_rng(&mut rng, $prime_type::BITS);
                let q: $prime_type = generate_prime_with_rng(&mut rng, $prime_type::BITS);
                let p = p.to_odd().unwrap();
                let q = q.to_odd().unwrap();
                check_given_primes!($num_iters, $prime_type, $mod_type, p, q);
            };
        }
        check!(1000, U128, U256);
        check!(100, U256, U512);
        check!(100, U512, U1024);
        let (p, q) = get_1024_bit_primes();
        check_given_primes!(30, U1024, U2048, p, q);
        let (p, q) = get_2048_bit_primes();
        check_given_primes!(30, U2048, U4096, p, q);
    }

    #[test]
    fn square_root_blum_integer() {
        let mut rng = OsRng::default();
        macro_rules! check_given_primes {
            ( $num_iters:expr, $prime_type:ident, $mod_type:ident, $p:ident, $q:ident ) => {
                let n = $p.widening_mul(&$q).to_odd().unwrap();
                let n_nz = n.resize().to_nz().unwrap();
                let p_mtg = MontyParams::new($p);
                let q_mtg = MontyParams::new($q);
                let (exp_mod_p_minus_1, exp_mod_q_minus_1, p_inv) = precomputation_for_sqrt_mod_blum_integer(p_mtg, q_mtg);
                let mut t1 = Duration::default();
                let mut t2 = Duration::default();
                let mut t3 = Duration::default();
                for _ in 0..$num_iters {
                    let a = $mod_type::random_mod(&mut rng, &n.to_nz().unwrap());
                    let a_sqr = a.square().rem(&n_nz).resize();
                    let start = Instant::now();
                    let s =
                        sqrt_mod_composite_given_prime_factors_as_mtg_params(a_sqr, p_mtg, q_mtg)
                            .unwrap();
                    t1 += start.elapsed();
                    let s_sqr = s.square().rem(&n_nz).resize();
                    assert_eq!(
                        a_sqr, s_sqr,
                        "{} {} {} {} {} {}",
                        $p, $q, a, a_sqr, s, s_sqr
                    );
                    let start = Instant::now();
                    let s = sqrt_mod_blum_integer_given_prime_factors_as_mtg_params(
                        &a_sqr, p_mtg, q_mtg,
                    )
                    .unwrap();
                    t2 += start.elapsed();
                    let s_sqr = s.square().rem(&n_nz).resize();
                    assert_eq!(
                        a_sqr, s_sqr,
                        "{} {} {} {} {} {}",
                        $p, $q, a, a_sqr, s, s_sqr
                    );
                    let start = Instant::now();
                    let s = sqrt_mod_blum_integer_given_prime_factors_as_mtg_params_and_precomp(
                        &a_sqr, &exp_mod_p_minus_1, &exp_mod_q_minus_1, p_inv, p_mtg, q_mtg,
                    )
                    .unwrap();
                    t3 += start.elapsed();
                    let s_sqr = s.square().rem(&n_nz).resize();
                    assert_eq!(
                        a_sqr, s_sqr,
                        "{} {} {} {} {} {}",
                        $p, $q, a, a_sqr, s, s_sqr
                    );
                }
                println!(
                    "Time for {} iterations and for {} bit prime, General={:?}, Special={:?}, Special with precomputation={:?}",
                    $num_iters,
                    $prime_type::BITS,
                    t1,
                    t2,
                    t3
                );
            };
        }
        macro_rules! check {
            ( $num_iters:expr, $prime_type:ident, $mod_type:ident ) => {
                let p: $prime_type = blum_prime(&mut rng);
                let q: $prime_type = blum_prime(&mut rng);
                let p = p.to_odd().unwrap();
                let q = q.to_odd().unwrap();
                check_given_primes!($num_iters, $prime_type, $mod_type, p, q);
            };
        }
        check!(1000, U128, U256);
        check!(1000, U256, U512);
        check!(1000, U512, U1024);
    }

    #[test]
    fn fourth_root_blum_integer() {
        let mut rng = OsRng::default();
        let p: U256 = blum_prime(&mut rng);
        let q: U256 = blum_prime(&mut rng);
        let p = p.to_odd().unwrap();
        let q = q.to_odd().unwrap();
        let n = p.widening_mul(&q).to_odd().unwrap();
        let n_nz = n.resize().to_nz().unwrap();
        let p_mtg = MontyParams::new(p);
        let q_mtg = MontyParams::new(q);
        let (exp_mod_p_minus_1, exp_mod_q_minus_1, p_inv) =
            precomputation_for_sqrt_mod_blum_integer(p_mtg, q_mtg);
        for _ in 0..100 {
            // a^4 is a fourth power by construction, so a fourth root must exist.
            let a = U512::random_mod(&mut rng, &n.to_nz().unwrap());
            let a_sqr: U512 = a.square().rem(&n_nz).resize();
            let a_fourth: U512 = a_sqr.square().rem(&n_nz).resize();
            let r = fourth_root_mod_blum_integer_given_prime_factors_as_mtg_params_and_precomp(
                &a_fourth,
                &exp_mod_p_minus_1,
                &exp_mod_q_minus_1,
                p_inv,
                p_mtg,
                q_mtg,
            )
            .unwrap();
            let r_sqr: U512 = r.square().rem(&n_nz).resize();
            let r_fourth: U512 = r_sqr.square().rem(&n_nz).resize();
            assert_eq!(a_fourth, r_fourth, "{} {} {} {}", p, q, a, r);
        }
    }
}