use num_bigint::BigUint;
use num_traits::{One, Zero};
use std::fmt;
/// Iteration budget for the fixed-length loops in `tonelli_shanks_ct` (the Q/S split and the
/// main Tonelli–Shanks loop); named for a 256-bit modulus.
const FIXED_P_BIT_LENGTH: usize = 256;
/// How many consecutive candidates, starting at 2, are tried when searching for a quadratic
/// non-residue `z`.
const MAX_Z_ITERATIONS: u32 = 200;
/// Failure modes of `tonelli_shanks_ct`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TonelliShanksError {
    /// The modulus was rejected as not prime. Only trivially non-prime moduli are detected;
    /// no primality test is performed.
    NotPrime,
    /// `n` failed Euler's criterion, i.e. (for prime `p`) it has no square root modulo `p`.
    NotQuadraticResidue,
    /// Trivial input `n = 0` or `n = 1`.
    /// NOTE(review): never constructed in this file; `tonelli_shanks_ct` answers these cases
    /// with `Ok` directly.
    ZeroOrOne,
    /// The algorithm could not complete (for example, no quadratic non-residue `z` was found,
    /// or no index `i < m` existed in the main loop). The payload is the value involved.
    InvalidInput(BigUint),
}
impl fmt::Display for TonelliShanksError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TonelliShanksError::NotPrime => write!(f, "Modulus is not prime"),
TonelliShanksError::NotQuadraticResidue => {
write!(f, "n is not a quadratic residue mod p")
}
TonelliShanksError::ZeroOrOne => {
write!(f, "Trivial case: n = 0 or 1 (handled directly)")
}
TonelliShanksError::InvalidInput(x) => write!(
f,
"Invalid input or algorithmic failure (e.g., z not found, i==m) for value: {}",
x
),
}
}
}
/// Returns a copy of `val_true` when `condition` holds, otherwise a copy of `val_false`.
///
/// The choice is made arithmetically rather than with a branch:
/// `val_false ^ (mask * (val_true ^ val_false))` with `mask` equal to 0 or 1. This yields
/// `val_false` for `mask == 0` and `val_true` for `mask == 1`.
///
/// NOTE: this only removes the explicit `if`. `BigUint` XOR and multiplication run in time
/// that depends on operand sizes, so the selection is not guaranteed to be constant-time.
fn ct_select_biguint(condition: bool, val_true: &BigUint, val_false: &BigUint) -> BigUint {
    let mask = BigUint::from(u8::from(condition));
    // Operate on borrows: no need to clone the inputs just to XOR them.
    let diff = val_true ^ val_false;
    val_false ^ (mask * diff)
}
#[inline]
fn ct_eq_biguint(a: &BigUint, b: &BigUint) -> bool {
a == b
}
/// Computes `base^exp mod modulus` by right-to-left square-and-multiply.
///
/// Every round performs both the multiply and the square. [`ct_select_biguint`] decides
/// whether the product is kept, so the work per round does not depend on the exponent's bit
/// pattern.
///
/// The number of rounds is `max(max_bits, exp.bits())`. Callers that pass
/// `max_bits >= exp.bits()` get a count that is independent of `exp`. Exponent bits above
/// `max_bits` are still processed rather than silently dropped, which would give a wrong
/// result.
///
/// `BigUint` arithmetic itself is variable-time, so this is not a hardened constant-time
/// routine.
///
/// Returns 0 when `modulus` is 0 (instead of dividing by zero) or 1.
fn modpow_biguint_ct(base: &BigUint, exp: &BigUint, modulus: &BigUint, max_bits: usize) -> BigUint {
    if modulus.is_zero() || modulus.is_one() {
        return BigUint::zero();
    }
    let one = BigUint::one();
    let rounds = max_bits.max(exp.bits() as usize);

    let mut res = BigUint::one();
    let mut b_val = base % modulus;
    let mut e_val = exp.clone();
    for _ in 0..rounds {
        let e_is_odd = ct_eq_biguint(&(&e_val & &one), &one);
        // Always compute the product; keep it only when the current exponent bit is set.
        let prod = (&res * &b_val) % modulus;
        res = ct_select_biguint(e_is_odd, &prod, &res);
        e_val >>= 1u32;
        b_val = (&b_val * &b_val) % modulus;
    }
    res
}
/// Computes the two square roots of `n` modulo an odd prime `p` (Tonelli–Shanks).
///
/// The loops run a fixed number of iterations (at least `FIXED_P_BIT_LENGTH`, more when
/// `p` is wider). State updates go through [`ct_select_biguint`] instead of early exits, in
/// an attempt to keep control flow independent of the inputs. `BigUint` arithmetic and the
/// `Option`/`if` bookkeeping below are still variable-time, so this is not a hardened
/// constant-time implementation.
///
/// `n` is reduced modulo `p` first, so every representative of a residue class gives the
/// same answer.
///
/// # Returns
///
/// `Ok((r, (p - r) % p))` with `r * r ≡ n (mod p)`. When `n ≡ 0` both roots are 0. When
/// `n ≡ 1` the pair is `(1, p - 1)`.
///
/// # Errors
///
/// * [`TonelliShanksError::NotPrime`] if `p` is 0, 1, or even and not 2. No primality test
///   is run on odd `p`.
/// * [`TonelliShanksError::NotQuadraticResidue`] if `n` fails Euler's criterion.
/// * [`TonelliShanksError::InvalidInput`], which has two cases:
///   * It carries `p` if none of the `MAX_Z_ITERATIONS` candidates starting at 2 is a
///     non-residue.
///   * It carries `n` if the main loop cannot find the index `i`, or if the computed root
///     does not square back to `n`. Both can happen when `p` is not actually prime.
pub fn tonelli_shanks_ct(
    n: &BigUint,
    p: &BigUint,
) -> Result<(BigUint, BigUint), TonelliShanksError> {
    let one_val = BigUint::one();
    let zero_val = BigUint::zero();
    let two_val = BigUint::from(2u32);

    // Only trivially non-prime moduli are rejected here. An odd composite `p` is caught, at
    // best, by the consistency checks further down.
    let p_is_even = (p & &one_val).is_zero();
    if p.is_zero() || ct_eq_biguint(p, &one_val) || (p_is_even && !ct_eq_biguint(p, &two_val))
    {
        return Err(TonelliShanksError::NotPrime);
    }

    // Work with the canonical representative, so that e.g. `n == p` is treated as 0 rather
    // than reported as a non-residue.
    let n_red = n % p;
    if n_red.is_zero() {
        return Ok((BigUint::zero(), BigUint::zero()));
    }
    if ct_eq_biguint(&n_red, &one_val) {
        return Ok((BigUint::one(), p - &one_val));
    }
    // From here on p is odd and >= 3: p == 2 only admits n ≡ 0 or 1, both handled above.

    let p_bits = p.bits() as usize;
    // S (the 2-adic valuation of p - 1) is below p.bits(). This bound therefore covers every
    // loop below, even for p wider than FIXED_P_BIT_LENGTH.
    let loop_bound = FIXED_P_BIT_LENGTH.max(p_bits);

    let mut overall_error: Option<TonelliShanksError> = None;
    let p_minus_1 = p - &one_val;
    let exp_euler = &p_minus_1 >> 1u32;

    // Euler's criterion: n^((p-1)/2) == 1 exactly when n is a non-zero quadratic residue.
    let n_pow_exp = modpow_biguint_ct(&n_red, &exp_euler, p, p_bits);
    if !ct_eq_biguint(&n_pow_exp, &one_val) {
        overall_error = Some(TonelliShanksError::NotQuadraticResidue);
    }

    // Write p - 1 = Q * 2^S with Q odd, using a fixed iteration count.
    let mut q_val = p_minus_1.clone();
    let mut s_val = 0u32;
    let mut q_s_done = false;
    for _ in 0..loop_bound {
        let active = !q_s_done && overall_error.is_none();
        let q_is_even = ct_eq_biguint(&(&q_val & &one_val), &zero_val);
        let halve = q_is_even && active;
        let q_shifted = &q_val >> 1u32;
        q_val = ct_select_biguint(halve, &q_shifted, &q_val);
        s_val += u32::from(halve);
        if !halve {
            q_s_done = true;
        }
    }

    // Find a quadratic non-residue z (z^((p-1)/2) == p - 1) among 2, 3, .... Every candidate
    // is evaluated, but only the first hit is kept.
    let mut found_z_val = two_val;
    let mut z_found = false;
    for i in 2u32..(2u32 + MAX_Z_ITERATIONS) {
        let active = overall_error.is_none();
        let current_z_candidate = BigUint::from(i);
        let z_pow_exp = modpow_biguint_ct(&current_z_candidate, &exp_euler, p, p_bits);
        let capture = ct_eq_biguint(&z_pow_exp, &p_minus_1) && !z_found && active;
        found_z_val = ct_select_biguint(capture, &current_z_candidate, &found_z_val);
        if capture {
            z_found = true;
        }
    }
    if !z_found && overall_error.is_none() {
        overall_error = Some(TonelliShanksError::InvalidInput(p.clone()));
    }

    // Initial state: c = z^Q, r = n^((Q+1)/2), t = n^Q, m = S. After an earlier error the
    // state is neutral (t == 1), so the main loop does no effective work.
    let c_initial = modpow_biguint_ct(&found_z_val, &q_val, p, p_bits);
    let q_plus_1 = &q_val + &one_val;
    let q_plus_1_div_2 = &q_plus_1 >> 1u32;
    let r_initial = modpow_biguint_ct(&n_red, &q_plus_1_div_2, p, p_bits);
    let t_initial = modpow_biguint_ct(&n_red, &q_val, p, p_bits);
    let no_error = overall_error.is_none();
    let mut c_val = ct_select_biguint(no_error, &c_initial, &one_val);
    let mut r_val = ct_select_biguint(no_error, &r_initial, &zero_val);
    let mut t_val = ct_select_biguint(no_error, &t_initial, &one_val);
    let mut m_val = if no_error { s_val } else { 0 };

    // m strictly decreases on every effective iteration and starts at S < loop_bound, so
    // loop_bound iterations always suffice.
    for _ in 0..loop_bound {
        // Nothing is left to do once t == 1 (r is then a root) or an error has been recorded.
        let active_outer = overall_error.is_none() && !ct_eq_biguint(&t_val, &one_val);

        // Find the least i with 0 < i < m and t^(2^i) == 1 by repeated squaring. In an active
        // iteration t != 1, so k == 0 is never captured.
        let mut determined_i_val = m_val;
        let mut temp_t = t_val.clone();
        let mut i_found = false;
        for k in 0..loop_bound {
            let active_inner = active_outer && !i_found && k < m_val as usize;
            if active_inner && ct_eq_biguint(&temp_t, &one_val) {
                determined_i_val = k as u32;
                i_found = true;
            }
            let squared = (&temp_t * &temp_t) % p;
            temp_t = ct_select_biguint(i_found || !active_inner, &temp_t, &squared);
        }
        if active_outer && !i_found {
            // No valid i exists, which cannot happen for a prime p and a residue n.
            overall_error = Some(TonelliShanksError::InvalidInput(n.clone()));
        }

        let can_update = active_outer && overall_error.is_none() && i_found;
        // b = c^(2^(m - i - 1)). When i was not found, determined_i_val == m_val and the
        // exponent falls back to 2^0; the result is then discarded by the selects below.
        let b_exp_power = if m_val > determined_i_val {
            m_val - determined_i_val - 1
        } else {
            0
        };
        let exp_for_b = BigUint::from(1u32) << (b_exp_power as usize);
        let b_val = modpow_biguint_ct(&c_val, &exp_for_b, p, p_bits);
        let r_new = (&r_val * &b_val) % p;
        let c_new = (&b_val * &b_val) % p;
        let t_new = (&t_val * &c_new) % p;
        r_val = ct_select_biguint(can_update, &r_new, &r_val);
        c_val = ct_select_biguint(can_update, &c_new, &c_val);
        t_val = ct_select_biguint(can_update, &t_new, &t_val);
        if can_update {
            m_val = determined_i_val;
        }
    }

    if let Some(err) = overall_error {
        return Err(err);
    }
    // Defensive check: never report a value that is not actually a square root. This can
    // happen when an odd composite `p` slips past the checks above.
    if (&r_val * &r_val) % p != n_red {
        return Err(TonelliShanksError::InvalidInput(n.clone()));
    }
    let r2_final = (p - &r_val) % p;
    Ok((r_val, r2_final))
}
#[cfg(test)]
mod tests {
    use rand::Rng;
    use super::*;

    /// Calls `tonelli_shanks_ct`, asserts that both returned roots square to `n` modulo `p`,
    /// and returns the roots for further checks.
    fn checked_roots(n: &BigUint, p: &BigUint) -> (BigUint, BigUint) {
        let (r1, r2) = tonelli_shanks_ct(n, p).unwrap();
        let target = n % p;
        assert_eq!((&r1 * &r1) % p, target, "first root does not square to n mod p");
        assert_eq!((&r2 * &r2) % p, target, "second root does not square to n mod p");
        (r1, r2)
    }

    /// Asserts that the unordered pair `{r1, r2}` equals `{a, b}`.
    fn assert_root_pair(r1: &BigUint, r2: &BigUint, a: u64, b: u64) {
        let (a, b) = (BigUint::from(a), BigUint::from(b));
        assert!(
            (*r1 == a && *r2 == b) || (*r1 == b && *r2 == a),
            "expected roots {{{}, {}}}, got ({}, {})",
            a,
            b,
            r1,
            r2
        );
    }

    /// Decides residuosity of `n` with Euler's criterion and checks that `tonelli_shanks_ct`
    /// succeeds exactly for residues and reports `NotQuadraticResidue` otherwise.
    fn assert_agrees_with_euler(n: &BigUint, p: &BigUint) {
        let p_minus_1 = p - BigUint::one();
        let exp = &p_minus_1 >> 1u32;
        let legendre = modpow_biguint_ct(n, &exp, p, p.bits() as usize);
        let result = tonelli_shanks_ct(n, p);
        if legendre == p_minus_1 {
            assert!(
                matches!(result, Err(TonelliShanksError::NotQuadraticResidue)),
                "non-residue not rejected: {:?}",
                result
            );
        } else {
            assert!(result.is_ok(), "residue rejected: {:?}", result);
        }
    }

    /// Walks up from a random small start until Euler's criterion finds a non-residue (at
    /// most 100 candidates), then checks that `tonelli_shanks_ct` rejects it.
    fn assert_some_non_residue_rejected<R: Rng>(p: &BigUint, rng: &mut R) {
        let p_minus_1 = p - BigUint::one();
        let exp = &p_minus_1 >> 1u32;
        let mut candidate = BigUint::from(rng.gen_range(2u32..1000u32));
        for _ in 0..100 {
            let legendre = modpow_biguint_ct(&candidate, &exp, p, p.bits() as usize);
            if legendre == p_minus_1 {
                assert!(tonelli_shanks_ct(&candidate, p).is_err());
                return;
            }
            candidate += BigUint::one();
        }
        panic!("Failed to find a non-residue in 100 tries");
    }

    #[test]
    fn test_known_large_prime_10digit() {
        let p = BigUint::from(1_000_000_009u64);
        let n = BigUint::from(665_820_697u64);
        let (r1, r2) = checked_roots(&n, &p);
        assert_root_pair(&r1, &r2, 378_633_312, 621_366_697);
    }

    #[test]
    fn test_known_large_prime_13digit() {
        let p = BigUint::from(1_000_000_000_039u64);
        let n = BigUint::from(881_398_088_036u64);
        let (r1, r2) = checked_roots(&n, &p);
        assert_root_pair(&r1, &r2, 791_399_408_049, 208_600_591_990);
    }

    #[test]
    fn test_known_large_prime_14digit() {
        let p = BigUint::from(41_434_547_495_153u64);
        let n = BigUint::from(2u64);
        let (r1, r2) = checked_roots(&n, &p);
        assert_root_pair(&r1, &r2, 32_789_792_927_270, 8_644_754_567_883);
    }

    #[test]
    fn test_known_large_prime_6digit() {
        let p = BigUint::from(100_049u64);
        let n = BigUint::from(44_402u64);
        let (r1, r2) = checked_roots(&n, &p);
        assert_root_pair(&r1, &r2, 30_468, 69_581);
    }

    #[test]
    fn test_edge_case_zero() {
        let p = BigUint::from(101u64);
        let (r1, r2) = tonelli_shanks_ct(&BigUint::zero(), &p).unwrap();
        assert!(r1.is_zero());
        assert!(r2.is_zero());
    }

    #[test]
    fn test_edge_case_one() {
        let p = BigUint::from(101u64);
        let (r1, r2) = tonelli_shanks_ct(&BigUint::one(), &p).unwrap();
        assert_root_pair(&r1, &r2, 1, 100);
    }

    #[test]
    fn test_non_residue() {
        let res = tonelli_shanks_ct(&BigUint::from(23u64), &BigUint::from(1009u64));
        assert!(
            matches!(res, Err(TonelliShanksError::NotQuadraticResidue)),
            "Expected NotQuadraticResidue error, got {:?}",
            res
        );
    }

    #[test]
    fn test_tonelli_shanks_basic() {
        let n = BigUint::from(5u64);
        let p = BigUint::from(41u64);
        let (r1, r2) = checked_roots(&n, &p);
        assert_ne!(r1, r2);
        assert_root_pair(&r1, &r2, 13, 28);
    }

    #[test]
    fn test_tonelli_shanks_non_residue() {
        let res = tonelli_shanks_ct(&BigUint::from(3u64), &BigUint::from(7u64));
        assert!(
            matches!(res, Err(TonelliShanksError::NotQuadraticResidue)),
            "Expected NotQuadraticResidue error, got {:?}",
            res
        );
    }

    #[test]
    fn test_non_residue_hard_p() {
        let res = tonelli_shanks_ct(&BigUint::from(3u64), &BigUint::from(41u64));
        assert_eq!(res.unwrap_err(), TonelliShanksError::NotQuadraticResidue);

        let (r1, r2) = checked_roots(&BigUint::from(2u64), &BigUint::from(17u64));
        assert_root_pair(&r1, &r2, 6, 11);
    }

    #[test]
    fn test_nist_p192_prime() {
        let p =
            BigUint::parse_bytes(b"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFFFFFFFFFFFF", 16).unwrap();
        let (r1, r2) = checked_roots(&BigUint::from(4u32), &p);
        assert_ne!(r1, r2);
    }

    #[test]
    fn test_nist_p256_prime() {
        let p = BigUint::parse_bytes(
            b"FFFFFFFF00000001000000000000000000000000FFFFFFFFFFFFFFFFFFFFFFFF",
            16,
        )
        .unwrap();
        let (r1, r2) = checked_roots(&BigUint::from(4u32), &p);
        assert_ne!(r1, r2);
    }

    #[test]
    fn test_nist_p224_prime() {
        // P-224 has p - 1 divisible by a large power of two, exercising the main loop.
        let p = BigUint::parse_bytes(
            b"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF000000000000000000000001",
            16,
        )
        .unwrap();
        checked_roots(&BigUint::from(4u32), &p);
        checked_roots(&BigUint::from(9u32), &p);
        assert_agrees_with_euler(&BigUint::from(3u32), &p);

        let x = BigUint::parse_bytes(
            b"E84FB0B8E7000CB657D7973CF6B42ED78B301674276DF744AF130B3E",
            16,
        )
        .unwrap();
        let (r1, r2) = checked_roots(&((&x * &x) % &p), &p);
        let neg_x = &p - &x;
        assert!(r1 == x || r2 == x || r1 == neg_x || r2 == neg_x);
    }

    #[test]
    fn test_nist_p384_prime() {
        let p = BigUint::parse_bytes(b"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFFFF0000000000000000FFFFFFFF", 16).unwrap();
        checked_roots(&BigUint::from(4u32), &p);
        checked_roots(&BigUint::from(9u32), &p);
        assert_agrees_with_euler(&BigUint::from(3u32), &p);

        let x = BigUint::parse_bytes(b"3BF701BC9E9D36B4D5F1455343F09126F2564390F2B487365071243C61E6471FB9D2AB74657B82F9086489D9EF0F5CB5", 16).unwrap();
        let (r1, r2) = checked_roots(&((&x * &x) % &p), &p);
        let neg_x = &p - &x;
        assert!(r1 == x || r2 == x || r1 == neg_x || r2 == neg_x);
    }

    #[test]
    fn test_nist_p521_prime() {
        let p = BigUint::parse_bytes(b"01FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", 16).unwrap();
        checked_roots(&BigUint::from(4u32), &p);
        checked_roots(&BigUint::from(9u32), &p);
        assert_agrees_with_euler(&BigUint::from(3u32), &p);

        let x = BigUint::parse_bytes(b"0098E91EEF9A68452822309C52FAB453F5F117C1DA8ED796B255E9AB8F6410CCA16E59DF403A6BDC6CA467A37056B1E54B3005D8AC030DECFEB68DF18B171885D5C4", 16).unwrap();
        let (r1, r2) = checked_roots(&((&x * &x) % &p), &p);
        let neg_x = &p - &x;
        assert!(r1 == x || r2 == x || r1 == neg_x || r2 == neg_x);
    }

    #[test]
    fn test_mersenne_prime() {
        // 2^127 - 1.
        let p = BigUint::parse_bytes(b"7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", 16).unwrap();
        checked_roots(&BigUint::from(4u32), &p);
        assert_agrees_with_euler(&BigUint::from(3u32), &p);
    }

    /// Small values modulo the P-521 prime 2^521 - 1 (a Mersenne prime; it is not a safe
    /// prime, because (p - 1) / 2 = 2^520 - 1 is divisible by 3).
    #[test]
    fn test_p521_small_values() {
        let p = BigUint::parse_bytes(b"01FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", 16).unwrap();
        checked_roots(&BigUint::from(9u32), &p);
        assert_agrees_with_euler(&BigUint::from(5u32), &p);
    }

    #[test]
    fn test_random_residues_and_nonresidues_521bit() {
        let p = BigUint::parse_bytes(b"01FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF", 16).unwrap();
        let mut rng = rand::thread_rng();
        for _ in 0..5 {
            // 66 bytes = 528 bits; keeping one bit of the top byte leaves at most 521 bits.
            let mut bytes = vec![0u8; 66];
            rng.fill(&mut bytes[..]);
            bytes[0] &= 0x01;
            let x = BigUint::from_bytes_be(&bytes) % &p;
            checked_roots(&((&x * &x) % &p), &p);
            assert_some_non_residue_rejected(&p, &mut rng);
        }
    }

    /// Uses the 1024-bit MODP prime of RFC 2409 (Oakley group 2), which is a safe prime.
    #[test]
    fn test_random_1024bit_prime() {
        let p = BigUint::parse_bytes(b"FFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFF", 16).unwrap();
        let mut rng = rand::thread_rng();
        for _ in 0..3 {
            let mut bytes = vec![0u8; 128];
            rng.fill(&mut bytes[..]);
            let x = BigUint::from_bytes_be(&bytes) % &p;
            checked_roots(&((&x * &x) % &p), &p);
            assert_some_non_residue_rejected(&p, &mut rng);
        }
    }
}