#![allow(clippy::wrong_self_convention)]
use crate::bignum::{MontModulus, Uint};
use crate::ct::{Choice, ConstantTimeEq, ConstantTimeLess};
/// Field-element representation: a 256-bit unsigned integer stored as four
/// 64-bit limbs (see the `[u64; 4]` limb arrays used throughout this file).
pub(crate) type Fe = Uint<4>;
/// A field element paired with a [`Choice`] flag recording whether it is valid.
///
/// Operations such as square roots and decoding always produce a candidate
/// value; the flag says whether that candidate should be used.
#[derive(Clone, Copy)]
pub(crate) struct CtOption {
    value: Fe,
    is_some: Choice,
}

impl CtOption {
    /// Bundles `value` with its validity flag.
    #[inline]
    pub(crate) fn new(value: Fe, is_some: Choice) -> Self {
        Self { value, is_some }
    }

    /// Converts into a standard [`Option`], yielding `Some(value)` exactly when
    /// the flag is set. This inspects the flag with an ordinary branch.
    #[inline]
    pub(crate) fn into_option(self) -> Option<Fe> {
        match bool::from(self.is_some) {
            true => Some(self.value),
            false => None,
        }
    }
}
/// Big-endian hex encoding (64 digits) of the secp256k1 field prime
/// `p = 2^256 - 2^32 - 977`, decoded by `p()`.
pub(crate) const P_HEX: &str = "fffffffffffffffffffffffffffffffffffffffffffffffffffffffefffffc2f";
/// Decodes a 64-digit big-endian hex string into an [`Fe`].
///
/// Upper- and lower-case digits are both accepted. No reduction modulo `p`
/// is performed here; the bytes are handed to `Fe::from_be_bytes` as-is.
///
/// # Panics
///
/// Panics if `hex` is not exactly 64 bytes long, or if any byte is not an
/// ASCII hex digit.
pub(crate) fn fe_from_hex(hex: &str) -> Fe {
    let h = hex.as_bytes();
    assert!(h.len() == 64, "field hex must be 64 chars");
    // `hex_nibble` maps unknown bytes to 0, so a typo would otherwise silently
    // produce a different (wrong) constant. Reject it up front instead.
    assert!(
        h.iter().all(u8::is_ascii_hexdigit),
        "field hex must contain only hex digits"
    );
    let mut bytes = [0u8; 32];
    for (byte, pair) in bytes.iter_mut().zip(h.chunks_exact(2)) {
        *byte = (hex_nibble(pair[0]) << 4) | hex_nibble(pair[1]);
    }
    Fe::from_be_bytes(&bytes)
}
/// Decodes a single ASCII hex digit (either case) to its value in `0..=15`.
///
/// Any byte that is not a hex digit decodes to `0`; callers that must reject
/// malformed input have to validate it separately.
const fn hex_nibble(c: u8) -> u8 {
    match c {
        // Offsets are folded into one constant: 'a' - ('a' - 10) == 10, etc.
        b'a'..=b'f' => c - (b'a' - 10),
        b'A'..=b'F' => c - (b'A' - 10),
        b'0'..=b'9' => c - b'0',
        _ => 0,
    }
}
/// Returns the field prime `p`, decoded from [`P_HEX`].
#[inline]
pub(crate) fn p() -> Fe {
    fe_from_hex(P_HEX)
}

/// Exponent `(p + 1) / 4` used to compute square-root candidates.
///
/// Because `p ≡ 3 (mod 4)`, `a^((p+1)/4)` squares back to `a` whenever `a` is a
/// square; the backends confirm each candidate by squaring it.
fn sqrt_exponent() -> Fe {
    // `p + 1` fits in 256 bits (p < 2^256 - 1); two 1-bit shifts divide by 4.
    p().wrapping_add(&Fe::ONE).shr1().shr1()
}

/// Exponent `p - 2` used for Fermat inversion (`a^(p-2) = a^-1` for `a != 0`).
fn p_minus_2() -> Fe {
    let two = Fe::from_u64(2);
    p().wrapping_sub(&two)
}
/// Arithmetic over the secp256k1 base field, abstracted so that interchangeable
/// backends (the generic `GenericMont` and the specialised `Secp256k1Field`
/// in this file) can be cross-checked against each other.
///
/// NOTE(review): the implementations in this file assume operands are already
/// reduced into `[0, p)` (e.g. the native `add`/`negate` rely on it); confirm
/// this contract for any new implementor.
pub(crate) trait FieldBackend {
/// Additive identity.
fn zero(&self) -> Fe;
/// Multiplicative identity.
fn one(&self) -> Fe;
/// `a + b mod p`.
fn add(&self, a: &Fe, b: &Fe) -> Fe;
/// `a - b mod p`.
fn sub(&self, a: &Fe, b: &Fe) -> Fe;
/// `a * b mod p`.
fn mul(&self, a: &Fe, b: &Fe) -> Fe;
/// `a^2 mod p`. Defaults to `mul(a, a)`; a backend may override it with a
/// dedicated squaring routine.
#[inline]
fn square(&self, a: &Fe) -> Fe {
self.mul(a, a)
}
/// `-a mod p`.
fn negate(&self, a: &Fe) -> Fe;
/// Multiplicative inverse of `a`. Both implementations in this file compute
/// `a^(p-2)`.
fn invert(&self, a: &Fe) -> Fe;
/// Square-root candidate for `a`, flagged valid only when the candidate
/// squares back to `a` (as both implementations in this file check).
fn sqrt(&self, a: &Fe) -> CtOption;
/// Decodes 32 big-endian bytes. In both implementations here, the result is
/// flagged valid only when the encoded integer is below `p`.
fn from_bytes_be(&self, bytes: &[u8; 32]) -> CtOption;
/// Encodes `a` as 32 big-endian bytes.
fn to_bytes_be(&self, a: &Fe) -> [u8; 32];
}
/// Reference field backend built on the generic [`MontModulus`] arithmetic.
///
/// The `dead_code` allowance suggests it is only exercised by the tests, where
/// it serves as the oracle for the specialised backend.
#[cfg_attr(not(test), allow(dead_code))]
pub(crate) struct GenericMont {
    fp: MontModulus<4>,
}

impl GenericMont {
    /// Builds a backend for the field prime returned by `p()`.
    #[cfg_attr(not(test), allow(dead_code))]
    pub(crate) fn new() -> Self {
        let fp = MontModulus::new(p());
        Self { fp }
    }
}
impl FieldBackend for GenericMont {
#[inline]
fn zero(&self) -> Fe {
Fe::ZERO
}
#[inline]
fn one(&self) -> Fe {
Fe::ONE
}
#[inline]
fn add(&self, a: &Fe, b: &Fe) -> Fe {
self.fp.add_mod(a, b)
}
#[inline]
fn sub(&self, a: &Fe, b: &Fe) -> Fe {
self.fp.sub_mod(a, b)
}
#[inline]
fn mul(&self, a: &Fe, b: &Fe) -> Fe {
self.fp.mul_mod(a, b)
}
#[inline]
fn negate(&self, a: &Fe) -> Fe {
self.fp.sub_mod(&Fe::ZERO, a)
}
fn invert(&self, a: &Fe) -> Fe {
let p_minus_2 = p().wrapping_sub(&Fe::from_u64(2));
self.fp.pow(a, &p_minus_2)
}
fn sqrt(&self, a: &Fe) -> CtOption {
let cand = self.fp.pow(a, &sqrt_exponent());
let ok = self.mul(&cand, &cand).ct_eq(a);
CtOption::new(cand, ok)
}
fn from_bytes_be(&self, bytes: &[u8; 32]) -> CtOption {
let v = Fe::from_be_bytes(bytes);
let in_range = v.ct_lt(&p());
CtOption::new(v, in_range)
}
fn to_bytes_be(&self, a: &Fe) -> [u8; 32] {
let mut out = [0u8; 32];
a.write_be_bytes(&mut out);
out
}
}
/// `2^256 mod p = 2^32 + 977` (since `p = 2^256 - C`). A carry out of bit 256
/// is folded back into the low limbs by multiplying it by this constant. It is
/// typed `u128` because it is only ever used inside 128-bit accumulators.
const C: u128 = 0x1_0000_03D1;
/// The field prime `p` as little-endian 64-bit limbs (index 0 is least
/// significant). Written out by hand; must agree with `P_HEX`.
const P_LIMBS: [u64; 4] = [
0xFFFF_FFFE_FFFF_FC2F,
0xFFFF_FFFF_FFFF_FFFF,
0xFFFF_FFFF_FFFF_FFFF,
0xFFFF_FFFF_FFFF_FFFF,
];
/// Specialised secp256k1 field backend working directly on 64-bit limbs and
/// exploiting the special form `p = 2^256 - C` for reduction.
///
/// It carries no state; all constants are compile-time values.
pub(crate) struct Secp256k1Field;

impl Secp256k1Field {
    /// Returns the (zero-sized) backend handle.
    #[inline]
    pub(crate) fn new() -> Self {
        Self
    }
}
/// Adds `carry · C` to the 256-bit value `r`, i.e. folds `carry · 2^256` back in
/// using `2^256 ≡ C (mod p)`.
///
/// Returns the new limbs and the carry that spilled out of the top limb.
#[inline]
fn fold_carry(mut r: [u64; 4], carry: u64) -> ([u64; 4], u64) {
    // carry·C < 2^97, so the accumulator never overflows u128.
    let mut acc = u128::from(carry) * C;
    for limb in r.iter_mut() {
        acc += u128::from(*limb);
        *limb = acc as u64;
        acc >>= 64;
    }
    (r, acc as u64)
}
/// Computes `r - p` (mod 2^256) and a selection mask saying whether the
/// 257-bit value `hi·2^256 + r` is at least `p`.
///
/// Returns `(difference, mask)` where `mask` is all-ones when the value is
/// `>= p` (so the difference should be kept) and zero otherwise.
#[inline]
fn sub_p_mask(r: &[u64; 4], hi: u64) -> ([u64; 4], u64) {
    let mut out = [0u64; 4];
    let mut borrow: u64 = 0;
    for ((o, &x), &pl) in out.iter_mut().zip(r).zip(&P_LIMBS) {
        let tmp = u128::from(x).wrapping_sub(u128::from(pl) + u128::from(borrow));
        *o = tmp as u64;
        borrow = ((tmp >> 64) & 1) as u64;
    }
    // value >= p  <=>  the extra limb is non-zero OR the 256-bit subtraction
    // did not borrow. `hi | -hi` has its top bit set exactly when hi != 0.
    let hi_nonzero = (hi | hi.wrapping_neg()) >> 63;
    let ge = hi_nonzero | (borrow ^ 1);
    // Turn the 0/1 flag into a 0/all-ones mask arithmetically rather than via
    // `if`, so the select does not rely on the optimizer eliding a branch.
    (out, ge.wrapping_neg())
}
/// Limb-wise mask select: returns `b` where `mask` bits are set and `a`
/// elsewhere. With `mask` all-ones this yields `b`; with zero it yields `a`.
#[inline]
fn select(a: &[u64; 4], b: &[u64; 4], mask: u64) -> [u64; 4] {
    let mut picked = [0u64; 4];
    for ((dst, &lhs), &rhs) in picked.iter_mut().zip(a).zip(b) {
        *dst = (lhs & !mask) | (rhs & mask);
    }
    picked
}
/// Conditionally subtracts `p` from the 257-bit value `hi·2^256 + r`, in two
/// masked passes.
///
/// The first pass takes the extra carry limb `hi` into account; the second
/// treats it as already absorbed. Each pass keeps `r - p` only when the value
/// was at least `p`.
#[inline]
fn reduce_once(r: [u64; 4], hi: u64) -> [u64; 4] {
    let (diff, mask) = sub_p_mask(&r, hi);
    let first = select(&r, &diff, mask);
    let (diff, mask) = sub_p_mask(&first, 0);
    select(&first, &diff, mask)
}
/// Reduces a 512-bit product (little-endian limbs) modulo `p`.
///
/// Uses `2^256 ≡ C (mod p)`: the high half is multiplied by `C` and added into
/// the low half. The remaining overflow is folded back in three times via
/// `fold_carry`, and a final masked conditional subtraction brings the result
/// below `p`.
#[inline]
fn reduce512(t: [u64; 8]) -> [u64; 4] {
    let (low, high) = t.split_at(4);
    let mut r = [0u64; 4];
    let mut carry: u128 = 0;
    for ((out, &lo), &hi) in r.iter_mut().zip(low).zip(high) {
        let acc = u128::from(lo) + carry + u128::from(hi) * C;
        *out = acc as u64;
        carry = acc >> 64;
    }
    let mut overflow = carry as u64;
    for _ in 0..3 {
        let (folded, next) = fold_carry(r, overflow);
        r = folded;
        overflow = next;
    }
    reduce_once(r, overflow)
}
/// Schoolbook 256×256 → 512-bit multiplication over little-endian limbs.
#[inline]
fn mul_wide(a: &[u64; 4], b: &[u64; 4]) -> [u64; 8] {
    let mut t = [0u64; 8];
    for (i, &ai) in a.iter().enumerate() {
        let mut carry: u128 = 0;
        for (j, &bj) in b.iter().enumerate() {
            // (2^64-1) + (2^64-1)^2 + (2^64-1) = 2^128 - 1: never overflows u128.
            let acc = u128::from(t[i + j]) + u128::from(ai) * u128::from(bj) + carry;
            t[i + j] = acc as u64;
            carry = acc >> 64;
        }
        t[i + 4] = carry as u64;
    }
    t
}
impl Secp256k1Field {
    /// Multiplies two limb arrays modulo `p`: full 512-bit product followed by
    /// the special-form reduction.
    #[inline]
    fn mul_limbs(a: &[u64; 4], b: &[u64; 4]) -> [u64; 4] {
        let product = mul_wide(a, b);
        reduce512(product)
    }
}
/// Native backend: limb arithmetic specialised to `p = 2^256 - C`.
///
/// Operands are assumed to be reduced into `[0, p)`. Corrections (conditional
/// add/subtract of `p`, the zero case of negation) are applied with arithmetic
/// masks rather than data-dependent branches.
impl FieldBackend for Secp256k1Field {
    #[inline]
    fn zero(&self) -> Fe {
        Fe::ZERO
    }

    #[inline]
    fn one(&self) -> Fe {
        Fe::ONE
    }

    /// `a + b mod p`: a 257-bit sum, then masked conditional subtraction of `p`.
    #[inline]
    fn add(&self, a: &Fe, b: &Fe) -> Fe {
        let a = a.as_limbs();
        let b = b.as_limbs();
        let mut r = [0u64; 4];
        let mut carry: u128 = 0;
        for ((out, &x), &y) in r.iter_mut().zip(a).zip(b) {
            let acc = u128::from(x) + u128::from(y) + carry;
            *out = acc as u64;
            carry = acc >> 64;
        }
        Fe::from_limbs(reduce_once(r, carry as u64))
    }

    /// `a - b mod p`: a 256-bit subtraction, then `p` added back under a mask
    /// when the subtraction borrowed.
    #[inline]
    fn sub(&self, a: &Fe, b: &Fe) -> Fe {
        let a = a.as_limbs();
        let b = b.as_limbs();
        let mut r = [0u64; 4];
        let mut borrow: u64 = 0;
        for ((out, &x), &y) in r.iter_mut().zip(a).zip(b) {
            let tmp = u128::from(x).wrapping_sub(u128::from(y) + u128::from(borrow));
            *out = tmp as u64;
            borrow = ((tmp >> 64) & 1) as u64;
        }
        // All-ones iff the subtraction wrapped (a < b), derived arithmetically
        // instead of with `if` on operand-dependent data.
        let mask = borrow.wrapping_neg();
        let mut out = [0u64; 4];
        let mut carry: u128 = 0;
        for ((o, &x), &pl) in out.iter_mut().zip(&r).zip(&P_LIMBS) {
            let acc = u128::from(x) + u128::from(pl & mask) + carry;
            *o = acc as u64;
            carry = acc >> 64;
        }
        // The final carry is dropped on purpose: when p is added back to a
        // wrapped difference the sum crosses 2^256 exactly once.
        Fe::from_limbs(out)
    }

    #[inline]
    fn mul(&self, a: &Fe, b: &Fe) -> Fe {
        Fe::from_limbs(Self::mul_limbs(a.as_limbs(), b.as_limbs()))
    }

    /// `-a mod p`, computed as `p - a` with `0` mapped to `0` rather than `p`.
    #[inline]
    fn negate(&self, a: &Fe) -> Fe {
        let a = a.as_limbs();
        let mut r = [0u64; 4];
        let mut borrow: u64 = 0;
        for ((out, &pl), &x) in r.iter_mut().zip(&P_LIMBS).zip(a) {
            let tmp = u128::from(pl).wrapping_sub(u128::from(x) + u128::from(borrow));
            *out = tmp as u64;
            borrow = ((tmp >> 64) & 1) as u64;
        }
        // All-ones iff a == 0: `nz | -nz` has its top bit set exactly when nz != 0.
        let nz = a[0] | a[1] | a[2] | a[3];
        let zero_mask = ((nz | nz.wrapping_neg()) >> 63).wrapping_sub(1);
        Fe::from_limbs(select(&r, &[0u64; 4], zero_mask))
    }

    /// Fermat inversion `a^(p-2)`.
    fn invert(&self, a: &Fe) -> Fe {
        self.pow(a, &p_minus_2())
    }

    /// Candidate root `a^((p+1)/4)`, flagged valid only if it squares back to `a`.
    fn sqrt(&self, a: &Fe) -> CtOption {
        let cand = self.pow(a, &sqrt_exponent());
        let ok = self.square(&cand).ct_eq(a);
        CtOption::new(cand, ok)
    }

    /// Decodes big-endian bytes; the flag is set only for canonical values `< p`.
    fn from_bytes_be(&self, bytes: &[u8; 32]) -> CtOption {
        let v = Fe::from_be_bytes(bytes);
        let in_range = v.ct_lt(&p());
        CtOption::new(v, in_range)
    }

    fn to_bytes_be(&self, a: &Fe) -> [u8; 32] {
        let mut out = [0u8; 32];
        a.write_be_bytes(&mut out);
        out
    }
}
impl Secp256k1Field {
fn pow(&self, base: &Fe, exp: &Fe) -> Fe {
let exp = exp.as_limbs();
let mut acc = Fe::ONE;
let mut i = 4;
while i > 0 {
i -= 1;
let limb = exp[i];
let mut bit = 64;
while bit > 0 {
bit -= 1;
acc = self.square(&acc);
let prod = self.mul(&acc, base);
if (limb >> bit) & 1 == 1 {
acc = prod;
}
}
}
acc
}
}
#[cfg(test)]
mod backend_tests {
    // Differential tests: the specialised `Secp256k1Field` backend must agree
    // byte-for-byte with the generic `GenericMont` reference on every operation.
    use super::*;

    /// SplitMix64 PRNG: deterministic, so a failing case is reproducible from its seed.
    struct SplitMix64(u64);

    impl SplitMix64 {
        fn next_u64(&mut self) -> u64 {
            self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
            let mut z = self.0;
            z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            z ^ (z >> 31)
        }
    }

    /// Draws four random limbs and reduces the value modulo `p` via `Fe::reduce`.
    fn rand_fe(rng: &mut SplitMix64) -> Fe {
        let limbs = [
            rng.next_u64(),
            rng.next_u64(),
            rng.next_u64(),
            rng.next_u64(),
        ];
        Fe::from_limbs(limbs).reduce(&p())
    }

    /// Big-endian encoding used to compare results from the two backends.
    fn bytes(a: &Fe) -> [u8; 32] {
        let mut out = [0u8; 32];
        a.write_be_bytes(&mut out);
        out
    }

    /// Hand-picked operands around zero, the modulus, limb boundaries and the
    /// reduction constant `C`. Every entry is distinct and already below `p`.
    fn edge_cases() -> [Fe; 12] {
        let prime = p();
        let p_minus_1 = prime.wrapping_sub(&Fe::ONE);
        let p_minus_2 = prime.wrapping_sub(&Fe::from_u64(2));
        // 2^256 - 1 reduces to C - 1.
        let all_ones = Fe::from_limbs([u64::MAX; 4]).reduce(&prime);
        // Only the low limb saturated: carry out of limb 0 on add/mul.
        let low_limb_max = Fe::from_limbs([u64::MAX, 0, 0, 0]);
        // High limbs saturated, low limb clear: agrees with p except in limb 0.
        let high_limbs_max = Fe::from_limbs([0, u64::MAX, u64::MAX, u64::MAX]);
        let mid = Fe::from_limbs([0, 0, 0, 0x8000_0000_0000_0000]);
        [
            Fe::ZERO,
            Fe::ONE,
            Fe::from_u64(2),
            Fe::from_u64(7),
            p_minus_1,
            p_minus_2,
            all_ones,
            low_limb_max,
            mid,
            Fe::from_u64(0xFFFF_FFFF),
            Fe::from_limbs([C as u64, 0, 0, 0]),
            high_limbs_max,
        ]
    }

    /// Asserts both backends agree on add, sub and mul for `(a, b)`.
    fn check_pair(g: &GenericMont, n: &Secp256k1Field, a: &Fe, b: &Fe) {
        assert_eq!(
            bytes(&g.add(a, b)),
            bytes(&n.add(a, b)),
            "add mismatch: a={:x?} b={:x?}",
            a.as_limbs(),
            b.as_limbs()
        );
        assert_eq!(
            bytes(&g.sub(a, b)),
            bytes(&n.sub(a, b)),
            "sub mismatch: a={:x?} b={:x?}",
            a.as_limbs(),
            b.as_limbs()
        );
        assert_eq!(
            bytes(&g.mul(a, b)),
            bytes(&n.mul(a, b)),
            "mul mismatch: a={:x?} b={:x?}",
            a.as_limbs(),
            b.as_limbs()
        );
    }

    /// Asserts both backends agree on square, negate, invert and sqrt for `a`.
    fn check_unary(g: &GenericMont, n: &Secp256k1Field, a: &Fe) {
        assert_eq!(
            bytes(&g.square(a)),
            bytes(&n.square(a)),
            "square mismatch: a={:x?}",
            a.as_limbs()
        );
        assert_eq!(
            bytes(&g.negate(a)),
            bytes(&n.negate(a)),
            "negate mismatch: a={:x?}",
            a.as_limbs()
        );
        assert_eq!(
            bytes(&g.invert(a)),
            bytes(&n.invert(a)),
            "invert mismatch: a={:x?}",
            a.as_limbs()
        );
        let gs = g.sqrt(a);
        let ns = n.sqrt(a);
        assert_eq!(
            bool::from(gs.is_some),
            bool::from(ns.is_some),
            "sqrt is_square mismatch: a={:x?}",
            a.as_limbs()
        );
        if bool::from(gs.is_some) {
            assert_eq!(
                bytes(&gs.value),
                bytes(&ns.value),
                "sqrt root mismatch: a={:x?}",
                a.as_limbs()
            );
        }
    }

    #[test]
    fn hard_coded_constants_match_p() {
        // P_LIMBS and C are written out by hand; they must agree with the
        // hex-decoded prime that the generic backend is built from.
        let prime = p();
        assert_eq!(prime.as_limbs(), &P_LIMBS, "P_LIMBS disagrees with P_HEX");
        assert_eq!(
            Fe::ZERO.wrapping_sub(&prime).as_limbs(),
            &[C as u64, 0, 0, 0],
            "C is not 2^256 - p"
        );
        // sqrt_exponent() = (p + 1) / 4 is only a valid root exponent when p ≡ 3 (mod 4).
        assert_eq!(P_LIMBS[0] & 3, 3, "p is not 3 mod 4");
    }

    #[test]
    fn native_matches_generic_edge_cases() {
        let g = GenericMont::new();
        let n = Secp256k1Field::new();
        let cases = edge_cases();
        for a in &cases {
            check_unary(&g, &n, a);
            for b in &cases {
                check_pair(&g, &n, a, b);
            }
        }
    }

    #[test]
    fn native_matches_generic_random_batch() {
        let g = GenericMont::new();
        let n = Secp256k1Field::new();
        let mut rng = SplitMix64(0x0123_4567_89AB_CDEF);
        for _ in 0..100_000 {
            let a = rand_fe(&mut rng);
            let b = rand_fe(&mut rng);
            check_pair(&g, &n, &a, &b);
        }
    }

    #[test]
    fn native_matches_generic_random_unary() {
        let g = GenericMont::new();
        let n = Secp256k1Field::new();
        let mut rng = SplitMix64(0xDEAD_BEEF_CAFE_F00D);
        for _ in 0..20_000 {
            let a = rand_fe(&mut rng);
            check_unary(&g, &n, &a);
        }
    }

    #[test]
    fn native_sqrt_roundtrips_squares() {
        let g = GenericMont::new();
        let n = Secp256k1Field::new();
        let mut rng = SplitMix64(0xA5A5_5A5A_F0F0_0F0F);
        for _ in 0..5_000 {
            let a = rand_fe(&mut rng);
            let sq = n.square(&a);
            let r = n.sqrt(&sq);
            assert!(
                bool::from(r.is_some),
                "square has no root: a={:x?}",
                a.as_limbs()
            );
            assert_eq!(
                bytes(&n.square(&r.value)),
                bytes(&sq),
                "native sqrt does not round-trip: a={:x?}",
                a.as_limbs()
            );
            let gr = g.sqrt(&sq);
            assert_eq!(
                bytes(&gr.value),
                bytes(&r.value),
                "generic and native sqrt disagree: a={:x?}",
                a.as_limbs()
            );
        }
    }
}