use crypto_bigint::{
CtEq as _, CtAssign, Zero, One, Odd, JacobiSymbol, NegMod, MulMod, SquareMod, BitOps, Limb,
UintRef,
};
/// Big-integer capabilities needed by [`legendre_symbol`] and [`sqrt_mod_p_vartime`].
///
/// Despite its name, this trait does not compute square roots itself. It bundles the
/// operations those functions rely on: limb access, bit access, modular negation,
/// multiplication and squaring. On top of them it provides modular exponentiation.
trait SquareRoot
where
    Self: Clone
        + Zero
        + One
        + BitOps
        + AsRef<[Limb]>
        + AsMut<[Limb]>
        + NegMod<Output = Self>
        + MulMod<Output = Self>
        + SquareMod<Output = Self>,
{
    /// Computes `self^exp mod modulus`.
    fn pow_mod(self, exp: Self, modulus: &Odd<Self>) -> Self;
}
impl<T> SquareRoot for T
where
    T: Clone
        + CtAssign
        + Zero
        + One
        + BitOps
        + AsRef<[Limb]>
        + AsMut<[Limb]>
        + NegMod<Output = T>
        + MulMod<Output = T>
        + SquareMod<Output = T>,
{
    /// Right-to-left (least-significant bit first) square-and-multiply exponentiation.
    ///
    /// Every bit position below the modulus' precision is visited. The candidate product is
    /// computed on every iteration and only conditionally kept via `ct_assign`, so the
    /// sequence of modular operations performed does not depend on the exponent's bits.
    /// Exponent bits at or above the modulus' precision are not examined.
    fn pow_mod(self, exp: Self, modulus: &Odd<Self>) -> Self {
        let nz_modulus = modulus.as_nz_ref();
        let mut acc = Self::one_like(modulus);
        // Holds `self^(2^bit)` at the start of each iteration.
        let mut base_pow = self;
        for bit in 0..modulus.as_ref().bits_precision() {
            let product = acc.mul_mod(&base_pow, nz_modulus);
            let bit_is_set = u8::from(exp.bit_vartime(bit)).ct_eq(&1);
            acc.ct_assign(&product, bit_is_set);
            base_pow = base_pow.square_mod(nz_modulus);
        }
        acc
    }
}
/// Computes the Legendre symbol `(a / p)` via Euler's criterion, `a^((p - 1) / 2) mod p`.
///
/// For prime `p` that power is `0` when `p` divides `a`. It is `1` when `a` is a non-zero
/// quadratic residue and `p - 1` otherwise. The power is classified, and the returned symbol
/// selected, with constant-time operations rather than branches. Any power other than `0` or
/// `1` is reported as [`JacobiSymbol::MinusOne`], so the answer is only meaningful for prime
/// `p`.
#[expect(private_bounds)]
pub(crate) fn legendre_symbol<U: SquareRoot>(a: U, p: &Odd<U>) -> JacobiSymbol {
    // `p` is odd, so `(p - 1) / 2` is simply `p >> 1`.
    let mut half = p.as_ref().clone();
    UintRef::new_mut(half.as_mut()).shr1_assign();

    let power = a.pow_mod(half, p);
    let limbs = <_ as AsRef<[Limb]>>::as_ref(&power);

    // The power can only be 0 or 1 if every limb above the least-significant one is zero.
    let upper_limbs_zero = limbs
        .iter()
        .skip(1)
        .fold(Limb::ZERO, |acc, limb| acc | *limb)
        .is_zero();
    let power_is_zero = upper_limbs_zero & limbs[0].is_zero();
    let power_is_one = upper_limbs_zero & limbs[0].is_one();

    // SAFETY: `transmute` only compiles when `JacobiSymbol` and `i8` have the same size. A
    // fieldless enum has no padding, and every bit pattern is a valid `i8`, so reading the
    // variants' in-memory representations as `i8` is sound.
    let (one, zero, minus_one) = unsafe {
        (
            core::mem::transmute::<JacobiSymbol, i8>(JacobiSymbol::One),
            core::mem::transmute::<JacobiSymbol, i8>(JacobiSymbol::Zero),
            core::mem::transmute::<JacobiSymbol, i8>(JacobiSymbol::MinusOne),
        )
    };

    // Default to -1 and overwrite in constant time. The two conditions are mutually exclusive
    // (the low limb cannot be both 0 and 1), so at most one assignment takes effect.
    let mut symbol = minus_one;
    symbol.ct_assign(&one, power_is_one);
    symbol.ct_assign(&zero, power_is_zero);

    // SAFETY: `symbol` is bitwise equal to one of the three representations read from valid
    // `JacobiSymbol` values above, so reinterpreting it yields a valid `JacobiSymbol`.
    unsafe { core::mem::transmute::<i8, JacobiSymbol>(symbol) }
}
/// Computes a square root of `n` modulo the odd prime `p` using the Tonelli–Shanks algorithm.
///
/// Returns:
/// - `Some(0)` if `n ≡ 0 (mod p)`.
/// - `None` if `n` is a quadratic non-residue modulo `p`.
/// - Otherwise `Some(r)` with `r² ≡ n (mod p)`. Of the two roots `r` and `p - r`, exactly one
///   is odd, and the odd one (least-significant bit set) is returned.
///
/// Loop counts depend on `n` and `p`, so this runs in variable time. It must not be used
/// where the timing of secret inputs matters.
///
/// `p` is assumed to be prime. For composite `p` the search for a non-residue may not
/// terminate, and the result is meaningless. The degenerate modulus `p = 1`, which `Odd`
/// admits, yields `Some(0)` because every integer is congruent to `0` modulo `1`.
///
/// # Panics
///
/// Panics if the Tonelli–Shanks loop invariant is violated, which can only happen when `p`
/// is not prime.
#[expect(private_bounds)]
pub(crate) fn sqrt_mod_p_vartime<U: SquareRoot>(n: U, p: &Odd<U>) -> Option<U> {
    // Every residue modulo 1 is 0. Without this guard, `legendre_symbol` reports `One` for
    // p = 1, and the scan for the lowest set bit of `p` above bit 0 below never finds one.
    if bool::from(p.as_ref().is_one()) {
        return Some(U::zero_like(p.as_ref()));
    }

    match legendre_symbol(n.clone(), p) {
        JacobiSymbol::One => {}
        JacobiSymbol::Zero => return Some(U::zero_like(p.as_ref())),
        JacobiSymbol::MinusOne => return None,
    }

    // Write p - 1 = q * 2^s with q odd. Because p is odd, s >= 1 is the index of the lowest
    // set bit of p above bit 0. The +1 separating p from p - 1 is shifted out, so q = p >> s.
    let mut s = 1;
    while !p.as_ref().bit_vartime(s) {
        s += 1;
    }
    let mut q = p.as_ref().clone();
    UintRef::new_mut(q.as_mut()).shr_assign(s);

    // q is odd, so (q + 1) / 2 == (q >> 1) + 1. Since q >> 1 < p / 2, adding 1 cannot carry out.
    let mut q_plus_1_div_2 = q.clone();
    {
        let half = UintRef::new_mut(q_plus_1_div_2.as_mut());
        half.shr1_assign();
        let carry = half.add_assign_limb(Limb::ONE);
        debug_assert!(bool::from(carry.is_zero()));
    }

    // Invariant of the main loop: r^2 ≡ n * t (mod p).
    let mut r = n.clone().pow_mod(q_plus_1_div_2, p);
    let mut t = n.pow_mod(q.clone(), p);

    // Find a quadratic non-residue z by trial, starting from 2 (1 is always a residue).
    let z = {
        let mut z = U::zero_like(p.as_ref());
        <_ as AsMut<[Limb]>>::as_mut(&mut z)[0] = Limb::from(2u8);
        while legendre_symbol(z.clone(), p) != JacobiSymbol::MinusOne {
            let _carry = UintRef::new_mut(z.as_mut()).add_assign_limb(Limb::ONE);
        }
        z
    };

    let mut m = s;
    let mut c = z.pow_mod(q, p);
    while !bool::from(t.is_one()) {
        // Least i with 0 < i < m such that t^(2^i) == 1.
        let i = {
            let mut t_pow = t.clone();
            let mut i = 0;
            while !bool::from(t_pow.is_one()) {
                i += 1;
                assert!(i < m, "Tonelli-Shanks invariant violated; is p prime?");
                t_pow = t_pow.square_mod(p.as_nz_ref());
            }
            i
        };

        // b = c^(2^(m - i - 1))
        let mut b = c;
        for _ in 0..(m - i - 1) {
            b = b.square_mod(p.as_nz_ref());
        }
        m = i;
        c = b.square_mod(p.as_nz_ref());
        t = t.mul_mod(&c, p.as_nz_ref());
        r = r.mul_mod(&b, p.as_nz_ref());
    }

    // p is odd and r != 0, so exactly one of r and p - r is odd; return that one.
    if (<_ as AsRef<[Limb]>>::as_ref(&r)[0] & Limb::ONE) != Limb::ONE {
        r = r.neg_mod(p.as_nz_ref());
    }
    Some(r)
}