use crate::{FieldBytes, NistP521, Uint};
use core::{
cmp::Ordering,
fmt::{self, Debug},
iter::{Product, Sum},
ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
};
use elliptic_curve::{
Error, FieldBytesEncoding, Generate,
array::Array,
bigint::{Word, cpubits, modular::Retrieve},
ff::{self, Field, PrimeField},
ops::Invert,
rand_core::TryRng,
subtle::{Choice, ConditionallySelectable, ConstantTimeEq, ConstantTimeLess, CtOption},
zeroize::DefaultIsZeroes,
};
use primefield::bigint::{self, Limb, Odd};
// Select the word-size-specific field arithmetic backend at compile time.
// Both files provide the `fiat_p521_*` types and functions used below; the
// 32-bit backend uses 19 limbs and the 64-bit backend 9 limbs (see
// `FieldElement::LIMBS`). Formatting and some lints are disabled for these
// modules — NOTE(review): presumably because they are machine-generated
// (fiat-crypto naming); confirm before editing them by hand.
cpubits! {
32 => {
#[allow(clippy::needless_lifetimes, clippy::unnecessary_cast)]
#[allow(dead_code)]
#[rustfmt::skip]
#[path = "field/p521_32.rs"]
mod field_impl;
}
64 => {
#[allow(clippy::needless_lifetimes, clippy::unnecessary_cast)]
#[allow(dead_code)]
#[rustfmt::skip]
#[path = "field/p521_64.rs"]
mod field_impl;
}
}
mod loose;
use self::field_impl::*;
pub(crate) use self::loose::LooseFieldElement;
// Big-endian hex encoding of the field modulus p = 2^521 - 1. It is
// left-padded with zeros to the full width of `Uint`, which differs between
// 32-bit and 64-bit targets, because `Uint::from_be_hex` expects a string
// covering the whole integer width. It is also exposed as
// `PrimeField::MODULUS`.
const MODULUS_HEX: &str = {
cpubits! {
32 => {
"000001ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
}
64 => {
"00000000000001ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"
}
}
};
/// The field modulus p = 2^521 - 1 as a [`Uint`], decoded from `MODULUS_HEX`.
pub(crate) const MODULUS: Uint = Uint::from_be_hex(MODULUS_HEX);
/// Element of the prime field of order p = 2^521 - 1 (see `MODULUS`).
///
/// Wraps a fiat-crypto "tight" limb vector (`fiat_p521_tight_field_element`).
/// Equality (`ct_eq`) is evaluated on the serialized bytes from `to_bytes`
/// rather than on limbs. NOTE(review): this suggests the limb form is not
/// unique per value, so avoid comparing or inspecting limbs directly.
#[derive(Clone, Copy)]
pub struct FieldElement(pub(crate) fiat_p521_tight_field_element);
impl FieldElement {
pub const ZERO: Self = Self::from_u64(0);
pub const ONE: Self = Self::from_u64(1);
cpubits! {
32 => { const LIMBS: usize = 19; }
64 => { const LIMBS: usize = 9; }
}
pub fn from_bytes(repr: &FieldBytes) -> CtOption<Self> {
let uint = <Uint as FieldBytesEncoding<NistP521>>::decode_field_bytes(repr);
Self::from_uint(uint)
}
pub fn from_slice(slice: &[u8]) -> elliptic_curve::Result<Self> {
let field_bytes = FieldBytes::try_from(slice).map_err(|_| Error)?;
Self::from_bytes(&field_bytes).into_option().ok_or(Error)
}
pub fn from_uint(uint: Uint) -> CtOption<Self> {
let is_some = uint.ct_lt(&MODULUS);
CtOption::new(Self::from_uint_unchecked(uint), is_some)
}
pub(crate) const fn from_hex(hex: &str) -> Self {
assert!(
hex.len() == 521usize.div_ceil(8) * 2,
"hex is the wrong length (expected 132 hex chars)"
);
let mut hex_bytes = [b'0'; { Uint::BITS as usize / 4 }];
let offset = hex_bytes.len() - hex.len();
let mut i = 0;
while i < hex.len() {
hex_bytes[i + offset] = hex.as_bytes()[i];
i += 1;
}
let uint = match core::str::from_utf8(&hex_bytes) {
Ok(padded_hex) => Uint::from_be_hex(padded_hex),
Err(_) => panic!("invalid hex string"),
};
assert!(matches!(uint.cmp_vartime(&MODULUS), Ordering::Less));
Self::from_uint_unchecked(uint)
}
pub const fn from_u64(w: u64) -> Self {
Self::from_uint_unchecked(Uint::from_u64(w))
}
pub(crate) const fn from_uint_unchecked(w: Uint) -> Self {
let le_bytes_wide = w.to_le_bytes();
let mut le_bytes = [0u8; 66];
let mut i = 0;
while i < le_bytes.len() {
le_bytes[i] = le_bytes_wide.as_slice()[i];
i += 1;
}
let mut out = fiat_p521_tight_field_element([0; Self::LIMBS]);
fiat_p521_from_bytes(&mut out, &le_bytes);
Self(out)
}
pub const fn to_bytes(self) -> FieldBytes {
const BYTES: usize = 66;
let mut ret = [0u8; BYTES];
fiat_p521_to_bytes(&mut ret, &self.0);
let mut i = 0;
while i < (BYTES / 2) {
let j = BYTES - i - 1;
let tmp = ret[i];
ret[i] = ret[j];
ret[j] = tmp;
i += 1;
}
Array(ret)
}
pub fn is_odd(&self) -> Choice {
Choice::from(self.0[0] as u8 & 1)
}
pub fn is_even(&self) -> Choice {
!self.is_odd()
}
pub fn is_zero(&self) -> Choice {
self.ct_eq(&Self::ZERO)
}
#[inline]
pub const fn add_loose(&self, rhs: &Self) -> LooseFieldElement {
let mut out = fiat_p521_loose_field_element([0; Self::LIMBS]);
fiat_p521_add(&mut out, &self.0, &rhs.0);
LooseFieldElement(out)
}
#[inline]
#[must_use]
pub const fn double_loose(&self) -> LooseFieldElement {
self.add_loose(self)
}
#[inline]
pub const fn sub_loose(&self, rhs: &Self) -> LooseFieldElement {
let mut out = fiat_p521_loose_field_element([0; Self::LIMBS]);
fiat_p521_sub(&mut out, &self.0, &rhs.0);
LooseFieldElement(out)
}
#[inline]
pub const fn neg_loose(&self) -> LooseFieldElement {
let mut out = fiat_p521_loose_field_element([0; Self::LIMBS]);
fiat_p521_opp(&mut out, &self.0);
LooseFieldElement(out)
}
#[inline]
pub const fn add(&self, rhs: &Self) -> Self {
let mut out = fiat_p521_tight_field_element([0; Self::LIMBS]);
fiat_p521_carry_add(&mut out, &self.0, &rhs.0);
Self(out)
}
#[inline]
pub const fn sub(&self, rhs: &Self) -> Self {
let mut out = fiat_p521_tight_field_element([0; Self::LIMBS]);
fiat_p521_carry_sub(&mut out, &self.0, &rhs.0);
Self(out)
}
#[inline]
pub const fn neg(&self) -> Self {
let mut out = fiat_p521_tight_field_element([0; Self::LIMBS]);
fiat_p521_carry_opp(&mut out, &self.0);
Self(out)
}
#[inline]
#[must_use]
pub const fn double(&self) -> Self {
self.add(self)
}
#[inline]
pub const fn multiply(&self, rhs: &Self) -> Self {
self.relax().multiply(&rhs.relax())
}
#[inline]
pub const fn square(&self) -> Self {
self.relax().square()
}
const fn sqn(&self, n: usize) -> Self {
self.sqn_vartime(n)
}
pub const fn pow_vartime<const RHS_LIMBS: usize>(&self, exp: &bigint::Uint<RHS_LIMBS>) -> Self {
let mut res = Self::ONE;
let mut i = RHS_LIMBS;
while i > 0 {
i -= 1;
let mut j = Limb::BITS;
while j > 0 {
j -= 1;
res = res.square();
if ((exp.as_limbs()[i].0 >> j) & 1) == 1 {
res = res.multiply(self);
}
}
}
res
}
pub const fn sqn_vartime(&self, n: usize) -> Self {
let mut x = *self;
let mut i = 0;
while i < n {
x = x.square();
i += 1;
}
x
}
pub fn invert(&self) -> CtOption<Self> {
self.to_uint()
.invert_odd_mod(const { &Odd::from_be_hex(MODULUS_HEX) })
.map(Self::from_uint_unchecked)
.into()
}
pub fn invert_vartime(&self) -> CtOption<Self> {
self.to_uint()
.invert_odd_mod_vartime(const { &Odd::from_be_hex(MODULUS_HEX) })
.map(Self::from_uint_unchecked)
.into()
}
const fn invert_unwrap(&self) -> Self {
Self::from_uint_unchecked(
self.to_uint()
.invert_odd_mod(const { &Odd::from_be_hex(MODULUS_HEX) })
.expect_copied("input should be non-zero"),
)
}
pub fn sqrt(&self) -> CtOption<Self> {
let sqrt = self.sqn(519);
CtOption::new(sqrt, sqrt.square().ct_eq(self))
}
#[inline]
pub const fn relax(&self) -> LooseFieldElement {
let mut out = fiat_p521_loose_field_element([0; Self::LIMBS]);
fiat_p521_relax(&mut out, &self.0);
LooseFieldElement(out)
}
#[inline]
pub(crate) const fn to_uint(self) -> Uint {
let field_bytes = self.to_bytes();
let mut uint_bytes = [0u8; Uint::LIMBS * Limb::BYTES];
let offset = uint_bytes.len() - field_bytes.0.len();
let mut i = 0;
while i < field_bytes.0.len() {
uint_bytes[i + offset] = field_bytes.0[i];
i += 1
}
Uint::from_be_slice(&uint_bytes)
}
}
impl AsRef<fiat_p521_tight_field_element> for FieldElement {
fn as_ref(&self) -> &fiat_p521_tight_field_element {
&self.0
}
}
impl Default for FieldElement {
fn default() -> Self {
Self::ZERO
}
}
impl Debug for FieldElement {
    /// Renders as `FieldElement(0x…)`, using upper-case hex of the big-endian
    /// byte encoding.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let bytes = self.to_bytes();
        let formatter = base16ct::HexDisplay(&bytes);
        let mut tuple = f.debug_tuple("FieldElement");
        tuple.field(&format_args!("0x{formatter:X}"));
        tuple.finish()
    }
}
impl PartialEq for FieldElement {
    /// Compares the represented field values, using the constant-time
    /// [`ConstantTimeEq`] implementation.
    fn eq(&self, rhs: &Self) -> bool {
        bool::from(self.ct_eq(rhs))
    }
}

impl Eq for FieldElement {}
impl From<u32> for FieldElement {
fn from(n: u32) -> FieldElement {
Self::from_uint_unchecked(Uint::from(n))
}
}
impl From<u64> for FieldElement {
fn from(n: u64) -> FieldElement {
Self::from_uint_unchecked(Uint::from(n))
}
}
impl From<u128> for FieldElement {
fn from(n: u128) -> FieldElement {
Self::from_uint_unchecked(Uint::from(n))
}
}
impl ConditionallySelectable for FieldElement {
fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
let out = <[Word; Self::LIMBS]>::conditional_select(&a.0.0, &b.0.0, choice);
Self(fiat_p521_tight_field_element(out))
}
}
impl ConstantTimeEq for FieldElement {
    /// Constant-time equality on the serialized (`to_bytes`) encodings rather
    /// than on raw limbs.
    fn ct_eq(&self, other: &Self) -> Choice {
        self.to_bytes().ct_eq(&other.to_bytes())
    }
}
// Opt into `zeroize` support: zeroizing overwrites the value with
// `Default::default()`, which is `FieldElement::ZERO`.
impl DefaultIsZeroes for FieldElement {}
impl Field for FieldElement {
    const ZERO: Self = Self::ZERO;
    const ONE: Self = Self::ONE;

    /// Sample a uniformly random field element by rejection sampling.
    ///
    /// [`FieldBytes`] is 66 bytes (528 bits) wide, but the modulus is
    /// p = 2^521 - 1. Rejecting full 528-bit draws that are `>= p` would accept
    /// only about 1 in 128 attempts. Instead, the 7 unused high bits of the
    /// leading (most significant, big-endian) byte are cleared first. The masked
    /// value is uniform on `[0, 2^521)`, and only the single value `2^521 - 1 = p`
    /// is rejected. The output therefore stays uniform on `[0, p)`, and the loop
    /// almost always finishes after one draw.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `rng`.
    fn try_random<R: TryRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
        let mut bytes = <FieldBytes>::default();

        loop {
            rng.try_fill_bytes(&mut bytes)?;
            // Keep only bit 520 in the top byte so the candidate is < 2^521.
            bytes[0] &= 0x01;
            if let Some(fe) = Self::from_bytes(&bytes).into() {
                return Ok(fe);
            }
        }
    }

    fn is_zero(&self) -> Choice {
        Self::ZERO.ct_eq(self)
    }

    fn square(&self) -> Self {
        self.square()
    }

    fn double(&self) -> Self {
        self.double()
    }

    fn invert(&self) -> CtOption<Self> {
        self.invert()
    }

    fn sqrt(&self) -> CtOption<Self> {
        self.sqrt()
    }

    fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
        ff::helpers::sqrt_ratio_generic(num, div)
    }
}
impl Generate for FieldElement {
fn try_generate_from_rng<R: TryRng + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
Self::try_random(rng)
}
}
impl PrimeField for FieldElement {
type Repr = FieldBytes;
const MODULUS: &'static str = MODULUS_HEX;
const NUM_BITS: u32 = 521;
const CAPACITY: u32 = 520;
const TWO_INV: Self = Self::from_u64(2).invert_unwrap();
const MULTIPLICATIVE_GENERATOR: Self = Self::from_u64(3);
const S: u32 = 1;
const ROOT_OF_UNITY: Self = Self::from_hex(
"01fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe",
);
const ROOT_OF_UNITY_INV: Self = Self::ROOT_OF_UNITY.invert_unwrap();
const DELTA: Self = Self::from_u64(9);
#[inline]
fn from_repr(bytes: FieldBytes) -> CtOption<Self> {
Self::from_bytes(&bytes)
}
#[inline]
fn to_repr(&self) -> FieldBytes {
self.to_bytes()
}
#[inline]
fn is_odd(&self) -> Choice {
self.is_odd()
}
}
// Operator forms of field addition; all delegate to the inherent
// `FieldElement::add`, which carries the result back to tight form.
impl Add for FieldElement {
    type Output = FieldElement;

    #[inline]
    fn add(self, rhs: FieldElement) -> FieldElement {
        FieldElement::add(&self, &rhs)
    }
}

impl Add<&FieldElement> for FieldElement {
    type Output = FieldElement;

    #[inline]
    fn add(self, rhs: &FieldElement) -> FieldElement {
        FieldElement::add(&self, rhs)
    }
}

impl Add<&FieldElement> for &FieldElement {
    type Output = FieldElement;

    #[inline]
    fn add(self, rhs: &FieldElement) -> FieldElement {
        FieldElement::add(self, rhs)
    }
}

impl AddAssign<FieldElement> for FieldElement {
    #[inline]
    fn add_assign(&mut self, other: FieldElement) {
        *self = FieldElement::add(self, &other);
    }
}

impl AddAssign<&FieldElement> for FieldElement {
    #[inline]
    fn add_assign(&mut self, other: &FieldElement) {
        *self = FieldElement::add(self, other);
    }
}
// Operator forms of field subtraction; all delegate to the inherent
// `FieldElement::sub`, which carries the result back to tight form.
impl Sub for FieldElement {
    type Output = FieldElement;

    #[inline]
    fn sub(self, rhs: FieldElement) -> FieldElement {
        FieldElement::sub(&self, &rhs)
    }
}

impl Sub<&FieldElement> for FieldElement {
    type Output = FieldElement;

    #[inline]
    fn sub(self, rhs: &FieldElement) -> FieldElement {
        FieldElement::sub(&self, rhs)
    }
}

impl Sub<&FieldElement> for &FieldElement {
    type Output = FieldElement;

    #[inline]
    fn sub(self, rhs: &FieldElement) -> FieldElement {
        FieldElement::sub(self, rhs)
    }
}

impl SubAssign<FieldElement> for FieldElement {
    #[inline]
    fn sub_assign(&mut self, other: FieldElement) {
        *self = FieldElement::sub(self, &other);
    }
}

impl SubAssign<&FieldElement> for FieldElement {
    #[inline]
    fn sub_assign(&mut self, other: &FieldElement) {
        *self = FieldElement::sub(self, other);
    }
}
// Operator forms of field multiplication: both operands are relaxed to the
// loose representation and multiplied there.
impl Mul for FieldElement {
    type Output = FieldElement;

    #[inline]
    fn mul(self, rhs: FieldElement) -> FieldElement {
        let (lhs, rhs) = (self.relax(), rhs.relax());
        lhs.mul(&rhs)
    }
}

impl Mul<&FieldElement> for FieldElement {
    type Output = FieldElement;

    #[inline]
    fn mul(self, rhs: &FieldElement) -> FieldElement {
        let (lhs, rhs) = (self.relax(), rhs.relax());
        lhs.mul(&rhs)
    }
}

impl Mul<&FieldElement> for &FieldElement {
    type Output = FieldElement;

    #[inline]
    fn mul(self, rhs: &FieldElement) -> FieldElement {
        let (lhs, rhs) = (self.relax(), rhs.relax());
        lhs.mul(&rhs)
    }
}

impl MulAssign for FieldElement {
    #[inline]
    fn mul_assign(&mut self, other: FieldElement) {
        *self = Mul::mul(*self, other);
    }
}

impl MulAssign<&FieldElement> for FieldElement {
    #[inline]
    fn mul_assign(&mut self, other: &FieldElement) {
        *self = Mul::mul(*self, other);
    }
}
impl Neg for FieldElement {
type Output = FieldElement;
#[inline]
fn neg(self) -> FieldElement {
Self::neg(&self)
}
}
impl Sum for FieldElement {
fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.reduce(Add::add).unwrap_or(Self::ZERO)
}
}
impl<'a> Sum<&'a FieldElement> for FieldElement {
fn sum<I: Iterator<Item = &'a FieldElement>>(iter: I) -> Self {
iter.copied().sum()
}
}
impl Product for FieldElement {
fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
iter.reduce(Mul::mul).unwrap_or(Self::ONE)
}
}
impl<'a> Product<&'a FieldElement> for FieldElement {
fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
iter.copied().product()
}
}
impl Invert for FieldElement {
type Output = CtOption<Self>;
fn invert(&self) -> CtOption<Self> {
self.invert()
}
fn invert_vartime(&self) -> CtOption<Self> {
self.invert_vartime()
}
}
impl Retrieve for FieldElement {
    type Output = Uint;

    /// Return the canonical integer value of this element.
    fn retrieve(&self) -> Uint {
        Self::to_uint(*self)
    }
}
#[cfg(test)]
mod tests {
    use super::{FieldElement, Uint};
    use hex_literal::hex;

    primefield::test_primefield!(FieldElement, Uint);

    /// 66 bytes of `0xFF` encode a value far above the modulus and must be rejected.
    #[test]
    fn decode_invalid_field_element_returns_err() {
        let too_large = hex!(
            "FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF"
        );
        let decoded = FieldElement::from_bytes(&too_large.into());
        assert!(decoded.into_option().is_none());
    }

    /// `sqn(n)` must agree with `n` successive calls to `square`, including `n = 0`.
    #[test]
    fn sqn_edge_cases() {
        let base = FieldElement::from_u64(5);
        let mut expected = base;
        for n in 0..3 {
            assert_eq!(base.sqn(n), expected);
            expected = expected.square();
        }
    }
}