use super::U256;
impl U256 {
    /// Computes `self^exp mod modulus` using left-to-right square-and-multiply.
    ///
    /// The base is first reduced with `modulo`, then all 256 exponent bit
    /// positions are visited from bit 255 down to bit 0. Every iteration
    /// performs both the squaring and the multiplication, whatever the
    /// exponent bit is; the bit only decides which of the two values is kept,
    /// via `ct_select`. The loop count therefore never depends on `exp`.
    ///
    /// NOTE(review): the `_ct` suffix suggests this is meant to be
    /// constant-time with respect to `exp`; that also requires `square_mod`,
    /// `mul_mod`, `bit` and `ct_select` to be constant-time, which cannot be
    /// confirmed from this block.
    #[inline]
    pub fn pow_mod_ct(&self, exp: &U256, modulus: &U256) -> U256 {
        let reduced_base = self.modulo(modulus);
        let mut acc = U256::ONE;
        for bit_index in (0..256u32).rev() {
            let squared = acc.square_mod(modulus);
            // Always compute the product so the work done is the same for
            // set and clear exponent bits.
            let with_base = squared.mul_mod(&reduced_base, modulus);
            // NOTE(review): assumes `a.ct_select(&b, flag)` yields `a` when
            // `flag` is set and `b` otherwise -- confirm against its definition.
            acc = with_base.ct_select(&squared, exp.bit(bit_index));
        }
        acc
    }
}
#[cfg(test)]
mod ai_tests {
    use super::*;

    /// Modulus shared by every test in this module. The Fermat test below
    /// relies on it being prime (the value matches the secp256k1 field prime).
    const P: U256 = U256::from_be_limbs([
        0xFFFFFFFFFFFFFFFF,
        0xFFFFFFFFFFFFFFFF,
        0xFFFFFFFFFFFFFFFF,
        0xFFFFFFFEFFFFFC2F,
    ]);

    /// Any base raised to the zero exponent gives one.
    #[test]
    fn exp_zero() {
        let forty_two = U256::from_be_limbs([0, 0, 0, 42]);
        let got = forty_two.pow_mod_ct(&U256::ZERO, &P);
        assert_eq!(got, U256::ONE);
    }

    /// Exponent one returns the (already reduced) base unchanged.
    #[test]
    fn exp_one() {
        let forty_two = U256::from_be_limbs([0, 0, 0, 42]);
        let got = forty_two.pow_mod_ct(&U256::ONE, &P);
        assert_eq!(got, forty_two);
    }

    /// 2^10 = 1024, far below the modulus, so no reduction takes place.
    #[test]
    fn two_pow_ten() {
        let two = U256::from_be_limbs([0, 0, 0, 2]);
        let ten = U256::from_be_limbs([0, 0, 0, 10]);
        let expected = U256::from_be_limbs([0, 0, 0, 1024]);
        assert_eq!(two.pow_mod_ct(&ten, &P), expected);
    }

    /// 3^5 = 243.
    #[test]
    fn three_pow_five() {
        let three = U256::from_be_limbs([0, 0, 0, 3]);
        let five = U256::from_be_limbs([0, 0, 0, 5]);
        let expected = U256::from_be_limbs([0, 0, 0, 243]);
        assert_eq!(three.pow_mod_ct(&five, &P), expected);
    }

    /// Fermat's little theorem: a^(p-1) = 1 (mod p) for prime p and a not
    /// divisible by p.
    #[test]
    fn fermats_little_theorem() {
        let seven = U256::from_be_limbs([0, 0, 0, 7]);
        let order = P - U256::ONE;
        let got = seven.pow_mod_ct(&order, &P);
        assert_eq!(got, U256::ONE);
    }

    /// 0^0 is expected to be one.
    #[test]
    fn zero_pow_zero() {
        let got = U256::ZERO.pow_mod_ct(&U256::ZERO, &P);
        assert_eq!(got, U256::ONE);
    }

    /// Zero raised to a non-zero exponent stays zero.
    #[test]
    fn zero_base() {
        let five = U256::from_be_limbs([0, 0, 0, 5]);
        let got = U256::ZERO.pow_mod_ct(&five, &P);
        assert_eq!(got, U256::ZERO);
    }

    /// `pow_mod_ct` and `pow_mod_vt` must agree on every (base, exponent) pair.
    #[test]
    fn matches_vt() {
        let cases = [
            (
                U256::from_be_limbs([0, 0, 0, 2]),
                U256::from_be_limbs([0, 0, 0, 10]),
            ),
            (U256::from_be_limbs([0, 0, 0, 7]), P - U256::ONE),
            (U256::ZERO, U256::ZERO),
            (
                U256::from_be_limbs([0, 0, 0, 13]),
                U256::from_be_limbs([0, 0, 0, 7]),
            ),
        ];
        cases.iter().for_each(|(base, exp)| {
            let ct = base.pow_mod_ct(exp, &P);
            let vt = base.pow_mod_vt(exp, &P);
            assert_eq!(ct, vt, "CT/VT mismatch for {base:?}^{exp:?}");
        });
    }

    /// 13^7 computed by seven successive modular multiplications must equal
    /// the exponentiation result.
    #[test]
    fn matches_iterative_mul() {
        let thirteen = U256::from_be_limbs([0, 0, 0, 13]);
        let seven = U256::from_be_limbs([0, 0, 0, 7]);
        let expected = (0..7).fold(U256::ONE, |product, _| product.mul_mod(&thirteen, &P));
        assert_eq!(thirteen.pow_mod_ct(&seven, &P), expected);
    }
}