/// Fixed-width unsigned integer made of `LIMBS` 64-bit words, used as a field
/// element or scalar by the modular-arithmetic functions in this module.
///
/// `limbs[0]` is the least-significant word. The type itself does not enforce
/// reduction modulo any particular modulus.
///
/// NOTE(review): the derived `Debug` prints the raw limbs, which would expose
/// secret values (e.g. private scalars) if an element is ever logged.
#[derive(Clone, Copy, Debug)]
pub struct FieldElement<const LIMBS: usize> {
    pub limbs: [u64; LIMBS],
}

// Equality ORs together the XOR of every limb pair instead of returning at the
// first mismatch. This avoids an explicit early exit in the source; it is not a
// guarantee about the machine code the optimizer emits.
impl<const LIMBS: usize> PartialEq for FieldElement<LIMBS> {
    fn eq(&self, other: &Self) -> bool {
        let difference = self
            .limbs
            .iter()
            .zip(other.limbs.iter())
            .fold(0u64, |acc, (x, y)| acc | (x ^ y));
        difference == 0
    }
}

impl<const LIMBS: usize> Eq for FieldElement<LIMBS> {}

impl<const LIMBS: usize> FieldElement<LIMBS> {
    /// The value 0.
    pub const ZERO: Self = Self { limbs: [0u64; LIMBS] };

    /// The value 1 (low limb set to 1, all other limbs zero).
    pub const fn one() -> Self {
        let mut fe = Self::ZERO;
        fe.limbs[0] = 1;
        fe
    }

    /// Returns `true` when every limb is zero. All limbs are always inspected.
    pub fn is_zero(&self) -> bool {
        self.limbs.iter().fold(0u64, |acc, &word| acc | word) == 0
    }

    /// Parses a big-endian byte string (most-significant byte first).
    ///
    /// The last byte becomes the low byte of `limbs[0]`. Inputs shorter than
    /// `LIMBS * 8` bytes are zero-extended. For longer inputs only the trailing
    /// `LIMBS * 8` bytes are used and the leading ones are silently dropped. No
    /// reduction modulo a prime is performed.
    pub fn from_bytes_be(bytes: &[u8]) -> Self {
        let mut limbs = [0u64; LIMBS];
        // Walk from the least-significant (last) byte upward.
        for (pos, &byte) in bytes.iter().rev().take(LIMBS * 8).enumerate() {
            limbs[pos / 8] |= u64::from(byte) << (8 * (pos % 8));
        }
        Self { limbs }
    }

    /// Serializes to exactly `LIMBS * 8` big-endian bytes (leading zeros kept).
    pub fn to_bytes_be(&self) -> Vec<u8> {
        self.limbs
            .iter()
            .rev()
            .flat_map(|limb| limb.to_be_bytes())
            .collect()
    }

    /// Parses a little-endian byte string (least-significant byte first).
    ///
    /// Inputs shorter than `LIMBS * 8` bytes are zero-extended. For longer
    /// inputs only the first `LIMBS * 8` bytes are used and the rest are
    /// silently dropped. No reduction modulo a prime is performed.
    pub fn from_bytes_le(bytes: &[u8]) -> Self {
        let mut limbs = [0u64; LIMBS];
        for (pos, &byte) in bytes.iter().take(LIMBS * 8).enumerate() {
            limbs[pos / 8] |= u64::from(byte) << (8 * (pos % 8));
        }
        Self { limbs }
    }

    /// Serializes to exactly `LIMBS * 8` little-endian bytes (trailing zeros kept).
    pub fn to_bytes_le(&self) -> Vec<u8> {
        self.limbs
            .iter()
            .flat_map(|limb| limb.to_le_bytes())
            .collect()
    }
}
/// NIST P-256 field prime: p = 2^256 - 2^224 + 2^192 + 2^96 - 1.
///
/// Like every modulus constant in this block, it is stored as little-endian
/// 64-bit limbs (`[0]` is the least-significant word), the same layout as
/// `FieldElement::limbs`. The low limb ends in binary `11`, so p = 3 (mod 4)
/// and `field_sqrt_p3mod4` is applicable.
pub const P256_P: [u64; 4] = [
0xFFFF_FFFF_FFFF_FFFF,
0x0000_0000_FFFF_FFFF,
0x0000_0000_0000_0000,
0xFFFF_FFFF_0000_0001,
];
/// Order n of the P-256 base point (the scalar modulus), little-endian limbs.
pub const P256_N: [u64; 4] = [
0xF3B9_CAC2_FC63_2551,
0xBCE6_FAAD_A717_9E84,
0xFFFF_FFFF_FFFF_FFFF,
0xFFFF_FFFF_0000_0000,
];
/// NIST P-384 field prime: p = 2^384 - 2^128 - 2^96 + 2^32 - 1
/// (little-endian limbs; p = 3 (mod 4)).
pub const P384_P: [u64; 6] = [
0x0000_0000_FFFF_FFFF,
0xFFFF_FFFF_0000_0000,
0xFFFF_FFFF_FFFF_FFFE,
0xFFFF_FFFF_FFFF_FFFF,
0xFFFF_FFFF_FFFF_FFFF,
0xFFFF_FFFF_FFFF_FFFF,
];
/// Order n of the P-384 base point (the scalar modulus), little-endian limbs.
pub const P384_N: [u64; 6] = [
0xECEC_196A_CCC5_2973,
0x581A_0DB2_48B0_A77A,
0xC763_4D81_F437_2DDF,
0xFFFF_FFFF_FFFF_FFFF,
0xFFFF_FFFF_FFFF_FFFF,
0xFFFF_FFFF_FFFF_FFFF,
];
/// Curve25519 field prime: p = 2^255 - 19 (little-endian limbs).
///
/// The low limb ends in 0xED, so p = 1 (mod 4): `field_sqrt_p3mod4` does NOT
/// produce square roots for this modulus.
pub const CURVE25519_P: [u64; 4] = [
0xFFFF_FFFF_FFFF_FFED, 0xFFFF_FFFF_FFFF_FFFF,
0xFFFF_FFFF_FFFF_FFFF,
0x7FFF_FFFF_FFFF_FFFF, ];
/// Curve448 field prime: p = 2^448 - 2^224 - 1 (7 little-endian limbs;
/// the only zero bit is bit 224, i.e. bit 32 of limb 3; p = 3 (mod 4)).
pub const CURVE448_P: [u64; 7] = [
0xFFFF_FFFF_FFFF_FFFF,
0xFFFF_FFFF_FFFF_FFFF,
0xFFFF_FFFF_FFFF_FFFF,
0xFFFF_FFFE_FFFF_FFFF, 0xFFFF_FFFF_FFFF_FFFF,
0xFFFF_FFFF_FFFF_FFFF,
0xFFFF_FFFF_FFFF_FFFF,
];
/// Modular addition: returns `(a + b) mod p`.
///
/// Assumes `a` and `b` are already reduced (`< p`), so one conditional
/// subtraction of `p` suffices. Both the raw sum and `sum - p` are always
/// computed, and the result is picked with a bit mask rather than a branch.
pub fn field_add<const LIMBS: usize>(
    a: &FieldElement<LIMBS>,
    b: &FieldElement<LIMBS>,
    p: &[u64; LIMBS],
) -> FieldElement<LIMBS> {
    // sum = a + b, remembering the carry out of the top limb.
    let mut sum = [0u64; LIMBS];
    let mut carry_out = 0u64;
    for (slot, (&x, &y)) in sum.iter_mut().zip(a.limbs.iter().zip(b.limbs.iter())) {
        let wide = u128::from(x) + u128::from(y) + u128::from(carry_out);
        *slot = wide as u64;
        carry_out = (wide >> 64) as u64;
    }

    // reduced = sum - p (mod 2^(64*LIMBS)), remembering the final borrow.
    let mut reduced = [0u64; LIMBS];
    let mut borrow_out = 0u64;
    for (slot, (&s, &m)) in reduced.iter_mut().zip(sum.iter().zip(p.iter())) {
        let wide = u128::from(s)
            .wrapping_sub(u128::from(m))
            .wrapping_sub(u128::from(borrow_out));
        *slot = wide as u64;
        borrow_out = ((wide >> 64) as u64) & 1;
    }

    // Use `reduced` if the addition overflowed the limbs, or if sum >= p
    // (the subtraction did not borrow).
    let take_reduced = carry_out | (borrow_out ^ 1);
    let mask = core::hint::black_box(0u64.wrapping_sub(take_reduced));
    FieldElement {
        limbs: core::array::from_fn(|i| (reduced[i] & mask) | (sum[i] & !mask)),
    }
}
/// Modular subtraction: returns `(a - b) mod p`.
///
/// Assumes `a` and `b` are already reduced (`< p`). If `a - b` borrows out of
/// the top limb, `p` is added back. Both candidates are always computed and
/// the result is picked with a bit mask rather than a branch.
pub fn field_sub<const LIMBS: usize>(
    a: &FieldElement<LIMBS>,
    b: &FieldElement<LIMBS>,
    p: &[u64; LIMBS],
) -> FieldElement<LIMBS> {
    // diff = a - b (mod 2^(64*LIMBS)), remembering the final borrow.
    let mut diff = [0u64; LIMBS];
    let mut borrow_out = 0u64;
    for (slot, (&x, &y)) in diff.iter_mut().zip(a.limbs.iter().zip(b.limbs.iter())) {
        let wide = u128::from(x)
            .wrapping_sub(u128::from(y))
            .wrapping_sub(u128::from(borrow_out));
        *slot = wide as u64;
        borrow_out = ((wide >> 64) as u64) & 1;
    }

    // corrected = diff + p; the carry out of the top limb is irrelevant
    // because it exactly cancels the borrow above.
    let mut corrected = [0u64; LIMBS];
    let mut carry = 0u64;
    for (slot, (&d, &m)) in corrected.iter_mut().zip(diff.iter().zip(p.iter())) {
        let wide = u128::from(d) + u128::from(m) + u128::from(carry);
        *slot = wide as u64;
        carry = (wide >> 64) as u64;
    }

    // Use `corrected` exactly when the subtraction borrowed.
    let mask = core::hint::black_box(0u64.wrapping_sub(borrow_out));
    FieldElement {
        limbs: core::array::from_fn(|i| (corrected[i] & mask) | (diff[i] & !mask)),
    }
}
/// Modular negation: returns `(-a) mod p`, computed as `0 - a` with
/// `field_sub`.
///
/// A zero input yields zero (not `p`), because `0 - 0` does not borrow.
pub fn field_neg<const LIMBS: usize>(
    a: &FieldElement<LIMBS>,
    p: &[u64; LIMBS],
) -> FieldElement<LIMBS> {
    let zero = FieldElement::<LIMBS>::ZERO;
    field_sub(&zero, a, p)
}
/// Modular multiplication: returns `(a * b) mod p`.
///
/// Forms the full `2 * LIMBS`-word schoolbook product and hands it to
/// `reduce_wide`. Because `reduce_wide` reduces any double-width value, `a`
/// and `b` do not have to be reduced beforehand.
///
/// # Panics
///
/// Panics if `LIMBS > 9`. The double-width product lives in a fixed 18-word
/// buffer (the size `reduce_wide` accepts). Previously this surfaced as an
/// opaque index-out-of-bounds inside the loop; it is now an explicit check.
pub fn field_mul<const LIMBS: usize>(
    a: &FieldElement<LIMBS>,
    b: &FieldElement<LIMBS>,
    p: &[u64; LIMBS],
) -> FieldElement<LIMBS> {
    // Capacity of the fixed double-width buffer shared with `reduce_wide`.
    const WIDE_WORDS: usize = 18;
    assert!(
        2 * LIMBS <= WIDE_WORDS,
        "field_mul supports at most {} limbs, got {}",
        WIDE_WORDS / 2,
        LIMBS
    );

    let mut product = [0u64; WIDE_WORDS];
    for (i, &ai) in a.limbs.iter().enumerate() {
        let mut carry = 0u64;
        for (j, &bj) in b.limbs.iter().enumerate() {
            // Cannot overflow u128: (2^64-1)^2 + 2*(2^64-1) = 2^128 - 1.
            let uv = u128::from(ai) * u128::from(bj)
                + u128::from(product[i + j])
                + u128::from(carry);
            product[i + j] = uv as u64;
            carry = (uv >> 64) as u64;
        }
        product[i + LIMBS] = carry;
    }
    reduce_wide::<LIMBS>(&product, p)
}
/// Modular squaring: returns `a^2 mod p`.
///
/// There is no dedicated squaring routine; this simply multiplies `a` by
/// itself with `field_mul`.
pub fn field_sqr<const LIMBS: usize>(
    a: &FieldElement<LIMBS>,
    p: &[u64; LIMBS],
) -> FieldElement<LIMBS> {
    field_mul::<LIMBS>(a, a, p)
}
/// Reduces the `2 * LIMBS`-word value in `product[..2 * LIMBS]` modulo `p`.
///
/// Binary long division (shift-and-subtract), one product bit per iteration,
/// most-significant bit first. Invariant: `remainder < p` at the top of every
/// iteration. Shifting in one bit gives at most `2p - 1`, so a single
/// conditional subtraction restores the invariant. Words of `product` at index
/// `2 * LIMBS` and above are ignored.
///
/// The loop trip counts depend only on `LIMBS`, and the conditional subtraction
/// is a masked select, so the sequence of operations does not depend on the
/// data. Cost is O(128 * LIMBS * LIMBS) word operations per call.
///
/// Assumes `p` is nonzero. Panics (index out of bounds) if `LIMBS > 9`.
fn reduce_wide<const LIMBS: usize>(product: &[u64; 18], p: &[u64; LIMBS]) -> FieldElement<LIMBS> {
let double = 2 * LIMBS;
let total_bits = double * 64;
let mut remainder = FieldElement { limbs: [0u64; LIMBS] };
for i in (0..total_bits).rev() {
// remainder <<= 1; `carry` ends up holding the bit shifted out of the top limb.
let mut carry = 0u64;
for j in 0..LIMBS {
let new_carry = remainder.limbs[j] >> 63;
remainder.limbs[j] = (remainder.limbs[j] << 1) | carry;
carry = new_carry;
}
// Bring down product bit i into the freshly vacated low bit.
let word_idx = i / 64;
let bit_idx = i % 64;
let bit = (product[word_idx] >> bit_idx) & 1;
remainder.limbs[0] |= bit;
// trial = remainder - p, with trial_borrow = 1 iff remainder < p (ignoring `carry`).
let mut trial_borrow: u64 = 0;
let mut trial = [0u64; LIMBS];
for j in 0..LIMBS {
let diff = (remainder.limbs[j] as u128)
.wrapping_sub(p[j] as u128)
.wrapping_sub(trial_borrow as u128);
trial[j] = diff as u64;
trial_borrow = ((diff >> 64) as u64) & 1;
}
// Subtract if the shifted value overflowed the limbs (then it is >= 2^(64*LIMBS) > p)
// or if it is >= p. In the overflow case the wrapped `trial` is still the correct value.
let need_sub = carry | (1u64.wrapping_sub(trial_borrow));
// black_box discourages the optimizer from turning this mask back into a branch.
let mask = core::hint::black_box(0u64.wrapping_sub(need_sub));
let inv_mask = !mask;
for j in 0..LIMBS {
remainder.limbs[j] = (trial[j] & mask) | (remainder.limbs[j] & inv_mask);
}
}
remainder
}
/// Modular inverse via Fermat's little theorem: returns `a^(p-2) mod p`.
///
/// Only meaningful when `p` is prime. For `p > 2` a zero input yields zero
/// (zero raised to a positive power) rather than an error. Callers that must
/// reject zero have to check `is_zero()` themselves.
pub fn field_inv<const LIMBS: usize>(
    a: &FieldElement<LIMBS>,
    p: &[u64; LIMBS],
) -> FieldElement<LIMBS> {
    // exponent = p - 2: subtract 2 from the low limb and ripple any borrow upward.
    let mut exponent = *p;
    let mut pending = 2u64;
    for limb in exponent.iter_mut() {
        let (value, underflowed) = limb.overflowing_sub(pending);
        *limb = value;
        pending = u64::from(underflowed);
    }
    field_pow(a, &exponent, p)
}
/// Modular exponentiation: returns `base^exp mod p`.
///
/// Left-to-right square-and-multiply over all `64 * LIMBS` exponent bits,
/// including leading zeros. The multiplication is performed on every step,
/// and the exponent bit only drives a masked select. The sequence of field
/// operations therefore does not depend on `exp`. An all-zero exponent
/// returns one.
///
/// The select mask now goes through `core::hint::black_box`, matching the
/// other masked selects in this module (`field_add`, `field_sub`,
/// `reduce_wide`). This discourages the optimizer from turning the select
/// back into a branch on the exponent bit. That is a best-effort hint, not a
/// constant-time guarantee.
pub fn field_pow<const LIMBS: usize>(
    base: &FieldElement<LIMBS>,
    exp: &[u64; LIMBS],
    p: &[u64; LIMBS],
) -> FieldElement<LIMBS> {
    let mut result = FieldElement::<LIMBS>::one();
    for i in (0..LIMBS).rev() {
        for bit in (0..64).rev() {
            result = field_sqr(&result, p);
            let b = (exp[i] >> bit) & 1;
            // Always multiply; keep the product only when the bit is set.
            let product = field_mul(&result, base, p);
            let mask = core::hint::black_box(0u64.wrapping_sub(b));
            let inv = !mask;
            for j in 0..LIMBS {
                result.limbs[j] = (product.limbs[j] & mask) | (result.limbs[j] & inv);
            }
        }
    }
    result
}
/// Candidate square root modulo a prime `p` with `p % 4 == 3`: returns
/// `a^((p+1)/4) mod p`.
///
/// If `a` is a quadratic residue, the result `y` satisfies `y^2 == a`. If it
/// is not, the result is meaningless (its square is `-a`), and this function
/// does not report that. Callers must square the result and compare.
///
/// The requirement `p % 4 == 3` is now checked with a debug assertion. It
/// holds for `P256_P`, `P384_P` and `CURVE448_P`, but NOT for `CURVE25519_P`,
/// which would otherwise silently produce a wrong root.
///
/// The carry out of `p + 1` is folded into the divide-by-4 shift instead of
/// being dropped, so the exponent is exact for every limb pattern.
pub fn field_sqrt_p3mod4<const LIMBS: usize>(
    a: &FieldElement<LIMBS>,
    p: &[u64; LIMBS],
) -> FieldElement<LIMBS> {
    debug_assert!(
        LIMBS > 0 && p[0] & 0x3 == 0x3,
        "field_sqrt_p3mod4 requires p % 4 == 3"
    );

    // exp = p + 1; `carry` holds the bit (if any) that spills past the top limb.
    let mut exp = [0u64; LIMBS];
    let mut carry: u128 = 1;
    for i in 0..LIMBS {
        let sum = p[i] as u128 + carry;
        exp[i] = sum as u64;
        carry = sum >> 64;
    }

    // exp >>= 2, walking from the top limb down. The spilled carry acts as the
    // limb "above" the array, so its low bits shift into the top limb.
    let mut prev_lo = carry as u64;
    for i in (0..LIMBS).rev() {
        let new_prev = exp[i] & 0x3;
        exp[i] = (exp[i] >> 2) | (prev_lo << 62);
        prev_lo = new_prev;
    }

    field_pow(a, &exp, p)
}
/// Scalar addition: returns `(a + b) mod n` (e.g. with `n = P256_N`).
///
/// Identical arithmetic to `field_add` with `n` as the modulus, so `a` and `b`
/// must already be reduced (`< n`).
pub fn scalar_add<const LIMBS: usize>(
    a: &FieldElement<LIMBS>,
    b: &FieldElement<LIMBS>,
    n: &[u64; LIMBS],
) -> FieldElement<LIMBS> {
    field_add::<LIMBS>(a, b, n)
}
/// Scalar multiplication: returns `(a * b) mod n` (e.g. with `n = P256_N`).
///
/// Identical arithmetic to `field_mul` with `n` as the modulus.
pub fn scalar_mul<const LIMBS: usize>(
    a: &FieldElement<LIMBS>,
    b: &FieldElement<LIMBS>,
    n: &[u64; LIMBS],
) -> FieldElement<LIMBS> {
    field_mul::<LIMBS>(a, b, n)
}
/// Scalar inverse: returns `a^(n-2) mod n` via `field_inv`.
///
/// Correct only when `n` is prime (as the P-256 and P-384 group orders are).
/// A zero input yields zero rather than an error.
pub fn scalar_inv<const LIMBS: usize>(
    a: &FieldElement<LIMBS>,
    n: &[u64; LIMBS],
) -> FieldElement<LIMBS> {
    field_inv::<LIMBS>(a, n)
}
/// Returns `true` iff `a < n` when both are read as unsigned integers.
///
/// Implemented as a full-width subtraction `a - n` whose final borrow is the
/// answer; every limb is always processed.
///
/// NOTE: zero passes this check. Callers that need a scalar in `[1, n)` (for
/// example ECDSA `r`/`s` or private keys) must additionally reject
/// `a.is_zero()`.
pub fn scalar_is_valid<const LIMBS: usize>(a: &FieldElement<LIMBS>, n: &[u64; LIMBS]) -> bool {
    let final_borrow = a
        .limbs
        .iter()
        .zip(n.iter())
        .fold(0u64, |borrow, (&x, &m)| {
            let wide = u128::from(x)
                .wrapping_sub(u128::from(m))
                .wrapping_sub(u128::from(borrow));
            ((wide >> 64) as u64) & 1
        });
    final_borrow == 1
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes an even-length hex string into bytes (test helper; panics on bad input).
    fn hex_to_bytes(hex: &str) -> Vec<u8> {
        (0..hex.len())
            .step_by(2)
            .map(|i| u8::from_str_radix(&hex[i..i + 2], 16).unwrap())
            .collect()
    }

    /// Returns `m - 1` for a modulus whose low limb is nonzero (true of every
    /// modulus constant in this module).
    fn minus_one<const LIMBS: usize>(m: &[u64; LIMBS]) -> FieldElement<LIMBS> {
        let mut fe = FieldElement { limbs: *m };
        fe.limbs[0] -= 1;
        fe
    }

    const P256_GX: &str = "6B17D1F2E12C4247F8BCE6E563A440F277037D812DEB33A0F4A13945D898C296";
    const P256_GY: &str = "4FE342E2FE1A7F9B8EE7EB4A7C0F9E162BCE33576B315ECECBB6406837BF51F5";
    const P256_B: &str = "5AC635D8AA3A93E7B3EBBD55769886BC651D06B0CC53B0F63BCE3C3E27D2604B";

    #[test]
    fn test_p256_field_add_sub() {
        let a = FieldElement::<4>::from_bytes_be(&[1]);
        let b = FieldElement::<4>::from_bytes_be(&[2]);
        let sum = field_add(&a, &b, &P256_P);
        assert_eq!(sum, FieldElement::<4>::from_bytes_be(&[3]));
        let diff = field_sub(&sum, &b, &P256_P);
        assert_eq!(diff, a);
    }

    #[test]
    fn test_p256_field_add_carry_out() {
        // (p-1) + (p-1) overflows 256 bits; the result must still be p - 2.
        let pm1 = minus_one(&P256_P);
        let mut pm2 = FieldElement::<4> { limbs: P256_P };
        pm2.limbs[0] -= 2;
        assert_eq!(field_add(&pm1, &pm1, &P256_P), pm2);
    }

    #[test]
    fn test_p256_field_sub_and_neg() {
        let zero = FieldElement::<4>::ZERO;
        let one = FieldElement::<4>::one();
        let pm1 = minus_one(&P256_P);
        assert_eq!(field_sub(&zero, &one, &P256_P), pm1);
        assert_eq!(field_neg(&one, &P256_P), pm1);
        assert!(field_neg(&zero, &P256_P).is_zero());
        assert!(field_add(&one, &field_neg(&one, &P256_P), &P256_P).is_zero());
    }

    #[test]
    fn test_p256_field_mul() {
        let a = FieldElement::<4>::from_bytes_be(&[3]);
        let b = FieldElement::<4>::from_bytes_be(&[7]);
        let prod = field_mul(&a, &b, &P256_P);
        assert_eq!(prod, FieldElement::<4>::from_bytes_be(&[21]));
    }

    #[test]
    fn test_p256_mul_large() {
        let pm1 = minus_one(&P256_P);
        let result = field_mul(&pm1, &pm1, &P256_P);
        assert_eq!(result, FieldElement::<4>::one(), "(p-1)^2 should be 1");
    }

    #[test]
    fn test_mul_wrap_other_moduli() {
        // (m-1)^2 = 1 (mod m) for every modulus; exercises 4-, 6- and 7-limb paths.
        let p384m1 = minus_one(&P384_P);
        assert_eq!(field_mul(&p384m1, &p384m1, &P384_P), FieldElement::<6>::one());
        let p25519m1 = minus_one(&CURVE25519_P);
        assert_eq!(
            field_mul(&p25519m1, &p25519m1, &CURVE25519_P),
            FieldElement::<4>::one()
        );
        let p448m1 = minus_one(&CURVE448_P);
        assert_eq!(field_mul(&p448m1, &p448m1, &CURVE448_P), FieldElement::<7>::one());
    }

    #[test]
    fn test_p256_pow_small() {
        let two = FieldElement::<4>::from_bytes_be(&[2]);
        let exp = [10u64, 0, 0, 0];
        let result = field_pow(&two, &exp, &P256_P);
        assert_eq!(
            result,
            FieldElement::<4>::from_bytes_be(&[0x04, 0x00]),
            "2^10 should be 1024"
        );
    }

    #[test]
    fn test_p256_field_mul_known() {
        let gx = FieldElement::<4>::from_bytes_be(&hex_to_bytes(P256_GX));
        let gy = FieldElement::<4>::from_bytes_be(&hex_to_bytes(P256_GY));
        let product = field_mul(&gx, &gy, &P256_P);
        let gy_inv = field_inv(&gy, &P256_P);
        let should_be_gx = field_mul(&product, &gy_inv, &P256_P);
        assert_eq!(should_be_gx, gx, "field_mul or field_inv inconsistency");
    }

    #[test]
    fn test_p256_generator_on_curve() {
        // y^2 == x^3 - 3x + b for the P-256 base point.
        let gx = FieldElement::<4>::from_bytes_be(&hex_to_bytes(P256_GX));
        let gy = FieldElement::<4>::from_bytes_be(&hex_to_bytes(P256_GY));
        let b = FieldElement::<4>::from_bytes_be(&hex_to_bytes(P256_B));
        let x3 = field_mul(&field_sqr(&gx, &P256_P), &gx, &P256_P);
        let two_x = field_add(&gx, &gx, &P256_P);
        let three_x = field_add(&two_x, &gx, &P256_P);
        let rhs = field_add(&field_sub(&x3, &three_x, &P256_P), &b, &P256_P);
        assert_eq!(field_sqr(&gy, &P256_P), rhs);
    }

    #[test]
    fn test_p256_field_inv() {
        let a = FieldElement::<4>::from_bytes_be(&[42]);
        let a_inv = field_inv(&a, &P256_P);
        let product = field_mul(&a, &a_inv, &P256_P);
        assert_eq!(product, FieldElement::<4>::one());
    }

    #[test]
    fn test_p256_field_wrap() {
        let pm1 = minus_one(&P256_P);
        let one = FieldElement::<4>::one();
        let result = field_add(&pm1, &one, &P256_P);
        assert!(result.is_zero());
    }

    #[test]
    fn test_bytes_roundtrip() {
        let bytes = hex_to_bytes(P256_GX);
        let fe = FieldElement::<4>::from_bytes_be(&bytes);
        let out = fe.to_bytes_be();
        assert_eq!(out, bytes);
    }

    #[test]
    fn test_bytes_roundtrip_le() {
        let bytes: Vec<u8> = (1u8..=32).collect();
        let fe = FieldElement::<4>::from_bytes_le(&bytes);
        assert_eq!(fe.to_bytes_le(), bytes);
        let reversed: Vec<u8> = bytes.iter().rev().copied().collect();
        assert_eq!(FieldElement::<4>::from_bytes_be(&reversed), fe);
    }

    #[test]
    fn test_scalar_inv_p256() {
        let a = FieldElement::<4>::from_bytes_be(&[7]);
        let a_inv = scalar_inv(&a, &P256_N);
        let product = scalar_mul(&a, &a_inv, &P256_N);
        assert_eq!(product, FieldElement::<4>::one());
    }

    #[test]
    fn test_scalar_add_wraps_mod_n() {
        let nm1 = minus_one(&P256_N);
        let two = FieldElement::<4>::from_bytes_be(&[2]);
        assert_eq!(scalar_add(&nm1, &two, &P256_N), FieldElement::<4>::one());
    }

    #[test]
    fn test_scalar_is_valid_p256() {
        let n = FieldElement::<4> { limbs: P256_N };
        assert!(scalar_is_valid(&minus_one(&P256_N), &P256_N));
        assert!(!scalar_is_valid(&n, &P256_N));
        assert!(!scalar_is_valid(
            &FieldElement::<4> { limbs: [u64::MAX; 4] },
            &P256_N
        ));
    }

    #[test]
    fn test_field_sqrt_p3mod4_p256() {
        let four = FieldElement::<4>::from_bytes_be(&[4]);
        let y = field_sqrt_p3mod4(&four, &P256_P);
        let y2 = field_sqr(&y, &P256_P);
        assert_eq!(y2, four);
        let twenty_five = FieldElement::<4>::from_bytes_be(&[25]);
        let y = field_sqrt_p3mod4(&twenty_five, &P256_P);
        let y2 = field_sqr(&y, &P256_P);
        assert_eq!(y2, twenty_five);
    }

    #[test]
    fn test_field_sqrt_p3mod4_p384() {
        let four = FieldElement::<6>::from_bytes_be(&[4]);
        let y = field_sqrt_p3mod4(&four, &P384_P);
        let y2 = field_sqr(&y, &P384_P);
        assert_eq!(y2, four);
    }

    #[test]
    fn test_field_sqrt_p3mod4_non_residue_is_caught_by_squaring() {
        let three = FieldElement::<4>::from_bytes_be(&[3]);
        let y = field_sqrt_p3mod4(&three, &P256_P);
        let y2 = field_sqr(&y, &P256_P);
        assert_ne!(
            y2, three,
            "3 is known non-residue on P-256; sqrt should not satisfy y^2 == 3"
        );
    }
}