use crate::{
Choice, CtOption, Uint,
modular::{FixedMontyForm, FixedMontyParams, prime_params::PrimeParams},
};
/// Computes a square root of `monty_value` modulo the prime modulus of `monty_params`.
///
/// Both the input and the output are Montgomery-form representations with respect to
/// `monty_params`.
///
/// The computation starts from `b = value^sqrt_exp` and then branches on `prime_params.s`.
/// `s` = 1, 2, 3 and 4 each have a dedicated closed-form correction. Any larger `s` uses a
/// Tonelli–Shanks-style loop whose iteration count depends only on `s`.
///
/// NOTE(review): `s` is presumably the 2-adic valuation of `p - 1` (`p - 1 = 2^s * t`, `t` odd),
/// and `sqrt_exp` the matching exponent: `(p + 1) / 4` when `s == 1`, `(t - 1) / 2` otherwise.
/// Confirm this against `PrimeParams`.
///
/// Data-dependent choices inside each branch are made with `Choice`-driven selects
/// (`monty_select`) rather than `if`. The only `match` is on `s`, and the exponent passed to
/// `pow_vartime` is `sqrt_exp`; both come from `prime_params`, not from the input value.
///
/// Returns a `CtOption` that is some exactly when the computed candidate `x` satisfies
/// `x * x == value`, so inputs without a square root yield none. No normalization is applied
/// to choose between `x` and `-x`.
#[must_use]
pub const fn sqrt_montgomery_form<const LIMBS: usize>(
    monty_value: &Uint<LIMBS>,
    monty_params: &FixedMontyParams<LIMBS>,
    prime_params: &PrimeParams<LIMBS>,
) -> CtOption<Uint<LIMBS>> {
    // Wrap the raw Montgomery representation together with the modulus parameters.
    let value = FixedMontyForm::from_montgomery(*monty_value, monty_params);
    let b = value.pow_vartime(&prime_params.sqrt_exp);
    let x = match prime_params.s.get() {
        // s == 1: `b` itself is the candidate root. This relies on `sqrt_exp` having been chosen
        // for this case (presumably `(p + 1) / 4`; TODO confirm in `PrimeParams`).
        1 => {
            b
        }
        // The s >= 2 branches share this setup:
        //   cb   = value * b
        //   zeta = cb * b = value * b^2
        // Hence cb^2 = value * zeta. Each branch picks a correction `m` with zeta * m^2 == 1
        // (possible when `value` is a square), so that (cb * m)^2 == value.
        2 => {
            let ru = FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
            let cb = value.mul(&b);
            let zeta = cb.mul(&b);
            // If zeta == 1, `cb` is already a root; otherwise correct by `ru`. This works when
            // ru^2 == -1, i.e. `ru` is a primitive 4th root of unity (TODO confirm that
            // `PrimeParams` guarantees this for s == 2).
            let is_one = Uint::eq(zeta.as_montgomery(), monty_params.one());
            monty_select(&cb.mul(&ru), &cb, is_one)
        }
        3 => {
            let ru = FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
            // NOTE(review): `monty_root_unity_p2` is presumably ru^2, by analogy with
            // `ru_4 = ru_2.square()` in the s == 4 branch. That would make `ru_3` equal ru^3.
            let ru_2 =
                FixedMontyForm::from_montgomery(prime_params.monty_root_unity_p2, monty_params);
            let ru_3 = ru.mul(&ru_2);
            let cb = value.mul(&b);
            let zeta = cb.mul(&b);
            // Correction table, applied as successive selects (a later match overrides):
            //   zeta == 1     -> m = 1
            //   -zeta == 1    -> m = ru_2
            //   zeta == ru_2  -> m = ru_3
            //   otherwise     -> m = ru
            // Assume `ru` is a primitive 8th root of unity. Then for zeta = ru^(2k), k = 0..3,
            // this gives m = ru^((4 - k) mod 4), so zeta * m^2 == 1.
            let mut m = monty_select(
                &ru,
                &FixedMontyForm::one(ru.params()),
                Uint::eq(zeta.as_montgomery(), monty_params.one()),
            );
            m = monty_select(
                &m,
                &ru_2,
                Uint::eq(zeta.neg().as_montgomery(), monty_params.one()),
            );
            m = monty_select(&m, &ru_3, monty_eq(&zeta, &ru_2));
            cb.mul(&m)
        }
        4 => {
            let ru = FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
            // NOTE(review): presumably ru^2 (see the s == 3 branch). Then ru_4 = ru^4 and
            // ru_6 = ru^6.
            let ru_2 =
                FixedMontyForm::from_montgomery(prime_params.monty_root_unity_p2, monty_params);
            let ru_4 = ru_2.square();
            let ru_6 = ru_2.mul(&ru_4);
            let cb = value.mul(&b);
            let zeta = cb.mul(&b);
            let neg_zeta = zeta.neg();
            // Comparisons reused by more than one select below.
            let zeta_b = monty_eq(&zeta, &ru_2);
            let neg_zeta_b = monty_eq(&neg_zeta, &ru_2);
            let zeta_d = monty_eq(&zeta, &ru_6);
            // Correction table, applied as successive selects:
            //   -zeta == ru_2 or -zeta == ru_4  -> m = ru_2
            //   -zeta == 1    or  zeta == ru_6  -> m = ru_4
            //    zeta == ru_2 or  zeta == ru_4  -> m = ru_6
            //   otherwise                       -> m = 1
            // Then m is multiplied by `ru` when zeta is ru_2 or ru_6, or -zeta is ru_2 or ru_6.
            // Assume `ru` is a primitive 16th root of unity. Then for zeta = ru^(2k), k = 0..7,
            // this gives m = ru^((8 - k) mod 8), so zeta * m^2 == 1.
            let mut m = monty_select(
                &FixedMontyForm::one(ru.params()),
                &ru_2,
                neg_zeta_b.or(monty_eq(&neg_zeta, &ru_4)),
            );
            m = monty_select(
                &m,
                &ru_4,
                Uint::eq(neg_zeta.as_montgomery(), monty_params.one()).or(zeta_d),
            );
            m = monty_select(&m, &ru_6, zeta_b.or(monty_eq(&zeta, &ru_4)));
            m = monty_select(
                &m,
                &m.mul(&ru),
                zeta_b
                    .or(zeta_d)
                    .or(neg_zeta_b)
                    .or(monty_eq(&neg_zeta, &ru_6)),
            );
            cb.mul(&m)
        }
        // s >= 5: Tonelli–Shanks-style loop with a fixed iteration schedule.
        // The outer loop runs exactly `s` times (max_v = s, s - 1, ..., 1). The inner loop runs
        // max(max_v - 2, 0) times. Neither count depends on the input value.
        _ => {
            // Start from the same pair as the branches above:
            //   x = value * b   (candidate root)
            //   d = x * b       (value * b^2)
            // `z` starts as the stored root of unity.
            let mut x = value.mul(&b);
            let mut d = x.mul(&b);
            let mut z =
                FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params);
            let mut v = prime_params.s.get();
            let mut max_v = v;
            while max_v >= 1 {
                let mut k = 1;
                let mut tmp = d.square();
                let mut j_less_than_v = Choice::TRUE;
                let mut j = 2;
                while j < max_v {
                    // Each iteration performs exactly one squaring. It squares `tmp` while `tmp`
                    // is not yet one, and squares `z` once `tmp` has reached one.
                    let tmp_is_one = Uint::eq(tmp.as_montgomery(), monty_params.one());
                    let squared = monty_select(&tmp, &z, tmp_is_one).square();
                    tmp = monty_select(&squared, &tmp, tmp_is_one);
                    // Becomes false once `j == v` has been seen, and stays false for the rest of
                    // this inner loop.
                    j_less_than_v = j_less_than_v.and(Choice::from_u32_eq(j, v).not());
                    // `z` is replaced by z^2 only if `tmp` was already one and `j == v` has not
                    // been seen yet.
                    z = monty_select(&z, &squared, tmp_is_one.and(j_less_than_v));
                    // `k` ends up as the last `j` at which `tmp` was not yet one (1 if none).
                    k = tmp_is_one.select_u32(j, k);
                    j += 1;
                }
                // Despite its name, `b_is_one` tests `d`. The candidate `x` is multiplied by `z`
                // unless `d` is already one.
                let b_is_one = Uint::eq(d.as_montgomery(), monty_params.one());
                x = monty_select(&x.mul(&z), &x, b_is_one);
                // Advance the state: z <- z^2, then d <- d * z (with the new z), then v <- k.
                z = z.square();
                d = d.mul(&z);
                v = k;
                max_v -= 1;
            }
            x
        }
    };
    // Accept the candidate only if it actually squares to the input. Non-residues are rejected
    // here because no candidate can square to them.
    CtOption::new(x.to_montgomery(), monty_eq(&x.square(), &value))
}
/// Compares two Montgomery-form values, returning the outcome as a `Choice`.
///
/// Only the Montgomery representations are compared; the parameters attached to `a` and `b`
/// are not inspected. Callers should therefore only compare values built from the same
/// `FixedMontyParams`.
const fn monty_eq<const LIMBS: usize>(
    a: &FixedMontyForm<LIMBS>,
    b: &FixedMontyForm<LIMBS>,
) -> Choice {
    let lhs_repr = a.as_montgomery();
    let rhs_repr = b.as_montgomery();
    Uint::eq(lhs_repr, rhs_repr)
}
/// Chooses between two Montgomery-form values according to `c`, without branching on it.
///
/// The selection follows `Uint::select` on the Montgomery representations: `a` when `c` is
/// false, `b` when `c` is true. The result always carries `a`'s parameters; `b`'s parameters
/// are not consulted.
const fn monty_select<const LIMBS: usize>(
    a: &FixedMontyForm<LIMBS>,
    b: &FixedMontyForm<LIMBS>,
    c: Choice,
) -> FixedMontyForm<LIMBS> {
    let shared_params = a.params();
    let chosen_repr = Uint::select(a.as_montgomery(), b.as_montgomery(), c);
    FixedMontyForm::from_montgomery(chosen_repr, shared_params)
}
#[cfg(test)]
mod tests {
    use super::sqrt_montgomery_form;
    use crate::{
        Odd, U256, U576, Uint,
        modular::{FixedMontyForm, FixedMontyParams, PrimeParams},
    };

    /// Returns the root of unity stored in `prime_params`, converted out of Montgomery form.
    fn root_of_unity<const LIMBS: usize>(
        monty_params: &FixedMontyParams<LIMBS>,
        prime_params: &PrimeParams<LIMBS>,
    ) -> Uint<LIMBS> {
        FixedMontyForm::from_montgomery(prime_params.monty_root_unity, monty_params).retrieve()
    }

    /// Exercises `sqrt_montgomery_form` for a single modulus.
    ///
    /// For every `i` in the round range, this checks two things:
    /// - `i * i` has a square root, and that root is `i` or `-i`.
    /// - For `i != 0`, `generator * i * i` has no square root.
    ///
    /// Finally it checks that the generator itself has no square root. The generator must be a
    /// quadratic non-residue; multiplying it by a nonzero square keeps it one, which is why the
    /// per-round negative check is valid.
    fn test_monty_sqrt<const LIMBS: usize>(
        monty_params: FixedMontyParams<LIMBS>,
        prime_params: PrimeParams<LIMBS>,
    ) {
        let modulus = monty_params.modulus.get();
        let generator = Uint::from_u32(prime_params.generator.get());
        let gen_monty = FixedMontyForm::new(&generator, &monty_params);
        // Miri is very slow, so only a couple of rounds are run there.
        let rounds = if cfg!(miri) { 1..=2 } else { 0..=256 };
        for i in rounds {
            let s = i * i;
            let s_monty = FixedMontyForm::new(&Uint::from_u32(s), &monty_params);
            let rt_monty =
                sqrt_montgomery_form(s_monty.as_montgomery(), &monty_params, &prime_params)
                    .expect("no sqrt found");
            let rt = FixedMontyForm::from_montgomery(rt_monty, &monty_params).retrieve();
            let i_uint = Uint::from_u32(i);
            assert!(
                Uint::eq(&rt, &i_uint)
                    .or(Uint::eq(&rt, &modulus.wrapping_sub(&i_uint)))
                    .to_bool_vartime(),
                "sqrt({s}) returned a value that is neither {i} nor -{i}"
            );

            // A non-residue times a nonzero square is still a non-residue.
            if i != 0 {
                let non_residue = gen_monty.mul(&s_monty);
                assert!(
                    sqrt_montgomery_form(non_residue.as_montgomery(), &monty_params, &prime_params)
                        .is_none()
                        .to_bool_vartime(),
                    "generator * {s} unexpectedly has a square root"
                );
            }
        }
        assert!(
            sqrt_montgomery_form(gen_monty.as_montgomery(), &monty_params, &prime_params)
                .is_none()
                .to_bool_vartime(),
            "generator {} unexpectedly has a square root",
            prime_params.generator.get()
        );
    }

    /// s = 1: NIST P-256 base field prime.
    #[test]
    fn mod_sqrt_s_1() {
        let monty_params = FixedMontyParams::new_vartime(Odd::<U256>::from_be_hex(
            "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff",
        ));
        let prime_params = PrimeParams::new_vartime(&monty_params, 6);
        assert_eq!(prime_params.s.get(), 1);
        assert_eq!(prime_params.generator.get(), 6);
        assert_eq!(
            root_of_unity(&monty_params, &prime_params),
            U256::from_be_hex("FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFE")
        );
        test_monty_sqrt(monty_params, prime_params);
    }

    /// s = 2: the prime 2^255 - 19.
    #[test]
    fn mod_sqrt_s_2() {
        let monty_params = FixedMontyParams::new_vartime(Odd::<U256>::from_be_hex(
            "7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed",
        ));
        let prime_params = PrimeParams::new_vartime(&monty_params, 2);
        assert_eq!(prime_params.s.get(), 2);
        assert_eq!(prime_params.generator.get(), 2);
        assert_eq!(
            root_of_unity(&monty_params, &prime_params),
            U256::from_be_hex("2B8324804FC1DF0B2B4D00993DFBD7A72F431806AD2FE478C4EE1B274A0EA0B0")
        );
        test_monty_sqrt(monty_params, prime_params);
    }

    /// s = 3: NIST P-521 group order.
    #[test]
    fn mod_sqrt_s_3() {
        let monty_params = FixedMontyParams::new_vartime(Odd::<U576>::from_be_hex(
            "00000000000001fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffa51868783bf2f966b7fcc0148f709a5d03bb5c9b8899c47aebb6fb71e91386409",
        ));
        let prime_params = PrimeParams::new_vartime(&monty_params, 3);
        assert_eq!(prime_params.s.get(), 3);
        assert_eq!(prime_params.generator.get(), 3);
        assert_eq!(
            root_of_unity(&monty_params, &prime_params),
            U576::from_be_hex(
                "000000000000009a0a650d44b28c17f3d708ad2fa8c4fbc7e6000d7c12dafa92fcc5673a3055276d535f79ff391dcdbcd998b7836647d3a72472b3da861ac810a7f9c7b7b63e2205"
            )
        );
        test_monty_sqrt(monty_params, prime_params);
    }

    /// s = 4: NIST P-256 group order.
    #[test]
    fn mod_sqrt_s_4() {
        let monty_params = FixedMontyParams::new_vartime(Odd::<U256>::from_be_hex(
            "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551",
        ));
        let prime_params = PrimeParams::new_vartime(&monty_params, 7);
        assert_eq!(prime_params.s.get(), 4);
        assert_eq!(prime_params.generator.get(), 7);
        assert_eq!(
            root_of_unity(&monty_params, &prime_params),
            U256::from_be_hex("ffc97f062a770992ba807ace842a3dfc1546cad004378daf0592d7fbb41e6602")
        );
        test_monty_sqrt(monty_params, prime_params);
    }

    /// s = 6 (general loop): secp256k1 group order.
    #[test]
    fn mod_sqrt_s_6() {
        let monty_params = FixedMontyParams::new_vartime(Odd::<U256>::from_be_hex(
            "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141",
        ));
        let prime_params = PrimeParams::new_vartime(&monty_params, 7);
        assert_eq!(prime_params.s.get(), 6);
        assert_eq!(prime_params.generator.get(), 7);
        assert_eq!(
            root_of_unity(&monty_params, &prime_params),
            U256::from_be_hex("0C1DC060E7A91986DF9879A3FBC483A898BDEAB680756045992F4B5402B052F2")
        );
        test_monty_sqrt(monty_params, prime_params);
    }
}