use std::{
iter::{Product, Sum},
ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign},
};
use derive_more::derive::{AsMut, AsRef};
use ff::Field;
use hybrid_array::Array;
use num_traits::{One, Zero};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
use wincode::{SchemaRead, SchemaWrite};
use crate::{
algebra::{
field::{ByteSize, FieldExtension, PrimeFieldExtension, SubfieldElement},
ops::{AccReduce, DefaultDotProduct, IntoWide, MulAccReduce, ReduceWide},
uniform_bytes::FromUniformBytes,
},
errors::PrimitiveError,
random::{CryptoRngCore, Random},
sharing::unauthenticated::AdditiveShares,
};
/// A newtype wrapper around a value of the field (or field extension) `F`.
///
/// This file implements arithmetic, serde, constant-time, sharing and
/// wide-accumulation traits for the wrapper, each by forwarding to the
/// wrapped `F` value.
#[derive(
    Copy, Clone, Debug, PartialOrd, PartialEq, Eq, Hash, AsRef, AsMut, SchemaRead, SchemaWrite,
)]
// `repr(transparent)`: the wrapper has exactly the layout/ABI of `F`.
#[repr(transparent)]
// NOTE(review): the derived `PartialOrd` simply compares the inner `F` values;
// finite fields have no ordering compatible with their arithmetic, so treat it
// as an arbitrary order for sorting/containers only.
pub struct FieldElement<F: FieldExtension>(pub F);
impl<F: FieldExtension> FieldElement<F> {
    /// Wraps a raw field value.
    #[inline]
    pub fn new(inner: F) -> Self {
        FieldElement(inner)
    }

    /// Returns a copy of the wrapped field value.
    #[inline]
    pub fn inner(&self) -> F {
        self.0
    }

    /// Raises `self` to the power `exp`.
    ///
    /// Forwards to [`Field::pow`] with `exp` as the single 64-bit limb of the exponent.
    #[inline]
    pub fn pow(&self, exp: u64) -> Self {
        FieldElement::new(self.0.pow([exp]))
    }

    /// Decodes a big-endian byte string.
    ///
    /// The input is reversed and handed to [`Self::from_le_bytes`], so both
    /// decoders accept exactly the same encodings (up to byte order) and report
    /// the same error.
    ///
    /// # Errors
    /// Returns [`PrimitiveError::DeserializationFailed`] when `F::from_le_bytes`
    /// rejects the (reversed) bytes.
    #[inline]
    pub fn from_be_bytes(bytes: &[u8]) -> Result<FieldElement<F>, PrimitiveError> {
        // Only a little-endian decoder of `F` is used in this file, so flip a copy
        // of the input and reuse it rather than duplicating the decode/error path.
        let le_bytes: Vec<u8> = bytes.iter().rev().copied().collect();
        Self::from_le_bytes(&le_bytes)
    }

    /// Decodes a little-endian byte string via `F::from_le_bytes`.
    ///
    /// # Errors
    /// Returns [`PrimitiveError::DeserializationFailed`] when `F::from_le_bytes`
    /// rejects the bytes.
    #[inline]
    pub fn from_le_bytes(bytes: &[u8]) -> Result<FieldElement<F>, PrimitiveError> {
        Ok(FieldElement(F::from_le_bytes(bytes).ok_or_else(|| {
            PrimitiveError::DeserializationFailed("Invalid field element encoding".to_string())
        })?))
    }

    /// Encodes the wrapped value as little-endian bytes (`F::to_le_bytes`).
    #[inline]
    pub fn to_le_bytes(&self) -> Array<u8, ByteSize<F>> {
        self.0.to_le_bytes()
    }

    /// Encodes the wrapped value as big-endian bytes: the little-endian
    /// encoding with its byte order reversed.
    #[inline]
    pub fn to_be_bytes(&self) -> Array<u8, ByteSize<F>> {
        let mut rev = self.0.to_le_bytes();
        rev.as_mut().reverse();
        rev
    }

    /// Interprets the little-endian encoding as an unsigned big integer.
    #[inline]
    pub fn to_biguint(&self) -> num_bigint::BigUint {
        num_bigint::BigUint::from_bytes_le(self.to_le_bytes().as_ref())
    }
}
impl<F: FieldExtension> Random for FieldElement<F> {
#[inline]
fn random(rng: impl CryptoRngCore) -> Self {
FieldElement(Random::random(rng))
}
}
impl<F: FieldExtension> Default for FieldElement<F> {
fn default() -> Self {
Self::zero()
}
}
impl<F: FieldExtension> Serialize for FieldElement<F> {
    /// Serializes the element as its little-endian byte string.
    ///
    /// `serde_bytes` is used so formats with a native byte-string type encode
    /// it compactly instead of as a sequence of individual integers.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        // Serialize the encoding directly. Re-collecting it into a second `Array`
        // is a redundant copy, and that `FromIterator` path panics on any length
        // mismatch.
        let bytes = self.to_le_bytes();
        serde_bytes::serialize(AsRef::<[u8]>::as_ref(&bytes), serializer)
    }
}
impl<'de, F: FieldExtension> Deserialize<'de> for FieldElement<F> {
    /// Deserializes the little-endian byte string written by the `Serialize` impl.
    ///
    /// The bytes are read as a `Cow` rather than a borrowed `&[u8]`:
    /// - Deserializers that can lend their input stay zero-copy.
    /// - Deserializers that only hand out owned bytes or a sequence of integers
    ///   are also accepted; a `&[u8]` target rejects them. These include
    ///   reader-based deserializers and JSON, which is how `serde_bytes`
    ///   output looks there.
    ///
    /// # Errors
    /// Fails if the input is not a byte string/sequence, or if
    /// `FieldElement::from_le_bytes` rejects the encoding.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let bytes: std::borrow::Cow<'de, [u8]> = serde_bytes::deserialize(deserializer)?;
        FieldElement::from_le_bytes(&bytes).map_err(|err| {
            serde::de::Error::custom(format!("Failed to deserialize field element: {err:?}"))
        })
    }
}
// `FieldElement + &FieldElement`: adds the wrapped values with `F`'s addition.
// NOTE(review): `op_variants` presumably derives the owned, borrowed and
// flipped operand forms from this impl; its expansion lives in the `macros` crate.
#[macros::op_variants(owned, borrowed, flipped_commutative)]
impl<F: FieldExtension> Add<&FieldElement<F>> for FieldElement<F> {
    type Output = FieldElement<F>;
    #[inline]
    fn add(self, rhs: &FieldElement<F>) -> Self::Output {
        FieldElement(self.0 + rhs.0)
    }
}
// `FieldElement += &FieldElement`, implemented on top of the `Add` impl above.
// `op_variants(owned)` presumably adds the by-value `+= FieldElement` form.
#[macros::op_variants(owned)]
impl<'a, F: FieldExtension> AddAssign<&'a FieldElement<F>> for FieldElement<F> {
    #[inline]
    fn add_assign(&mut self, rhs: &'a FieldElement<F>) {
        *self = *self + rhs;
    }
}
// `FieldElement - &FieldElement`: subtracts the wrapped values with `F`'s subtraction.
// Subtraction is not commutative, so the macro is asked for `flipped` rather
// than `flipped_commutative`. NOTE(review): exact expansion lives in the `macros` crate.
#[macros::op_variants(owned, borrowed, flipped)]
impl<F: FieldExtension> Sub<&FieldElement<F>> for FieldElement<F> {
    type Output = FieldElement<F>;
    #[inline]
    fn sub(self, rhs: &FieldElement<F>) -> Self::Output {
        FieldElement(self.0 - rhs.0)
    }
}
// `FieldElement -= &FieldElement`, implemented on top of the `Sub` impl above.
#[macros::op_variants(owned)]
impl<'a, F: FieldExtension> SubAssign<&'a FieldElement<F>> for FieldElement<F> {
    #[inline]
    fn sub_assign(&mut self, rhs: &'a FieldElement<F>) {
        *self = *self - rhs;
    }
}
// `FieldElement * &FieldElement`: multiplies the wrapped values with `F`'s multiplication.
// NOTE(review): `op_variants` presumably generates the other operand forms.
#[macros::op_variants(owned, borrowed, flipped_commutative)]
impl<F: FieldExtension> Mul<&FieldElement<F>> for FieldElement<F> {
    type Output = FieldElement<F>;
    #[inline]
    fn mul(self, rhs: &FieldElement<F>) -> Self::Output {
        FieldElement(self.0 * rhs.0)
    }
}
// `FieldElement * &SubfieldElement`: scales an extension element by a subfield
// element using `F`'s mixed multiplication.
#[macros::op_variants(owned, borrowed, flipped)]
impl<F: FieldExtension> Mul<&SubfieldElement<F>> for FieldElement<F> {
    type Output = FieldElement<F>;
    #[inline]
    fn mul(self, rhs: &SubfieldElement<F>) -> Self::Output {
        FieldElement(self.0 * rhs.0)
    }
}
// `FieldElement *= &FieldElement`, implemented on top of the `Mul` impl above.
#[macros::op_variants(owned)]
impl<'a, F: FieldExtension> MulAssign<&'a FieldElement<F>> for FieldElement<F> {
    #[inline]
    fn mul_assign(&mut self, rhs: &'a FieldElement<F>) {
        *self = *self * rhs;
    }
}
// `FieldElement *= &SubfieldElement`, implemented on top of the subfield `Mul` impl above.
#[macros::op_variants(owned)]
impl<'a, F: FieldExtension> MulAssign<&'a SubfieldElement<F>> for FieldElement<F> {
    #[inline]
    fn mul_assign(&mut self, rhs: &'a SubfieldElement<F>) {
        *self = *self * rhs;
    }
}
// Additive inverse: negates the wrapped value with `F`'s negation.
// `op_variants(borrowed)` presumably also provides `-&FieldElement`.
#[macros::op_variants(borrowed)]
impl<F: FieldExtension> Neg for FieldElement<F> {
    type Output = FieldElement<F>;
    #[inline]
    fn neg(self) -> Self::Output {
        FieldElement(-self.0)
    }
}
// `FieldElement / &FieldElement`: multiplies `self` by the inverse of `rhs`.
// Division is fallible, so the output is a `CtOption`. It is none exactly when
// `F::invert(rhs)` is none, which per the `ff::Field` contract means `rhs` is zero.
#[macros::op_variants(owned, borrowed, flipped)]
impl<F: FieldExtension> Div<&FieldElement<F>> for FieldElement<F> {
    type Output = CtOption<FieldElement<F>>;
    #[inline]
    fn div(self, rhs: &FieldElement<F>) -> Self::Output {
        rhs.0.invert().map(|inv| FieldElement(self.0 * inv))
    }
}
impl<F: FieldExtension> ConstantTimeEq for FieldElement<F> {
    /// Equality as a `Choice`, delegated to the wrapped values' `ct_eq`.
    ///
    /// Timing guarantees are whatever `F`'s `ConstantTimeEq` impl provides.
    #[inline]
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (&self.0, &other.0);
        lhs.ct_eq(rhs)
    }
}
impl<F: FieldExtension> ConditionallySelectable for FieldElement<F> {
#[inline]
fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
let selected = F::conditional_select(&a.0, &b.0, choice);
FieldElement(selected)
}
}
impl<F: FieldExtension> AdditiveShares for FieldElement<F> {}
impl<F: FieldExtension> From<bool> for FieldElement<F> {
#[inline]
fn from(value: bool) -> Self {
FieldElement(F::from(value as u64))
}
}
impl<F: FieldExtension> From<u8> for FieldElement<F> {
#[inline]
fn from(value: u8) -> Self {
FieldElement(F::from(value as u64))
}
}
impl<F: FieldExtension> From<u16> for FieldElement<F> {
#[inline]
fn from(value: u16) -> Self {
FieldElement(F::from(value as u64))
}
}
impl<F: FieldExtension> From<u32> for FieldElement<F> {
#[inline]
fn from(value: u32) -> Self {
FieldElement(F::from(value as u64))
}
}
impl<F: FieldExtension> From<u64> for FieldElement<F> {
#[inline]
fn from(value: u64) -> Self {
FieldElement(F::from(value))
}
}
impl<F: FieldExtension> From<u128> for FieldElement<F> {
#[inline]
fn from(value: u128) -> Self {
FieldElement(F::from(value))
}
}
impl<F: PrimeFieldExtension> From<SubfieldElement<F>> for FieldElement<F> {
#[inline]
fn from(value: SubfieldElement<F>) -> Self {
FieldElement(value.0)
}
}
impl<F: FieldExtension> Sum for FieldElement<F> {
    /// Sums the elements into `F`'s wide accumulator with `F::acc` and applies
    /// `F::reduce_mod_order` once at the end.
    ///
    /// An empty iterator yields the reduction of `zero_wide()`.
    #[inline]
    fn sum<I: Iterator<Item = FieldElement<F>>>(iter: I) -> Self {
        let mut wide = <F as AccReduce>::zero_wide();
        for term in iter {
            F::acc(&mut wide, term.0);
        }
        FieldElement(F::reduce_mod_order(wide))
    }
}
impl<'a, F: FieldExtension> Sum<&'a FieldElement<F>> for FieldElement<F> {
    /// Copies each borrowed element (the type is `Copy`) and defers to the
    /// owned `Sum` impl, which accumulates wide and reduces once.
    #[inline]
    fn sum<I: Iterator<Item = &'a FieldElement<F>>>(iter: I) -> Self {
        iter.copied().sum()
    }
}
impl<F: FieldExtension> Product for FieldElement<F> {
    /// Multiplies the elements left to right starting from one; an empty
    /// iterator yields one.
    #[inline]
    fn product<I: Iterator<Item = FieldElement<F>>>(iter: I) -> Self {
        let mut running = FieldElement::one();
        for factor in iter {
            running = running * factor;
        }
        running
    }
}
impl<'a, F: FieldExtension> Product<&'a FieldElement<F>> for FieldElement<F> {
    /// Multiplies borrowed elements left to right starting from one; an empty
    /// iterator yields one.
    #[inline]
    fn product<I: Iterator<Item = &'a FieldElement<F>>>(iter: I) -> Self {
        let mut running = FieldElement::one();
        for factor in iter {
            running = running * factor;
        }
        running
    }
}
impl<F: FieldExtension> FromUniformBytes for FieldElement<F> {
type UniformBytes = <F as FromUniformBytes>::UniformBytes;
fn from_uniform_bytes(bytes: &Array<u8, Self::UniformBytes>) -> Self {
Self(F::from_uniform_bytes(bytes))
}
}
impl<F: FieldExtension> IntoWide<<F as MulAccReduce>::WideType> for FieldElement<F> {
#[inline]
fn to_wide(&self) -> <F as MulAccReduce>::WideType {
<F as MulAccReduce>::to_wide(&self.0)
}
#[inline]
fn zero_wide() -> <F as MulAccReduce>::WideType {
<F as MulAccReduce>::zero_wide()
}
}
impl<F: FieldExtension> ReduceWide<<F as MulAccReduce>::WideType> for FieldElement<F> {
#[inline]
fn reduce_mod_order(a: <F as MulAccReduce>::WideType) -> Self {
Self(F::reduce_mod_order(a))
}
}
/// Owned × owned multiply-accumulate: forwards the inner values to `F::mul_acc`,
/// which updates the wide accumulator in place.
impl<F: FieldExtension> MulAccReduce for FieldElement<F> {
    type WideType = <F as MulAccReduce>::WideType;

    #[inline]
    fn mul_acc(wide: &mut Self::WideType, lhs: Self, rhs: Self) {
        F::mul_acc(wide, lhs.0, rhs.0);
    }
}
/// Dot products over owned × owned operands use the trait's provided implementation.
impl<F: FieldExtension> DefaultDotProduct for FieldElement<F> {}

/// Borrowed × owned multiply-accumulate; the inner values are copied out and
/// forwarded to `F::mul_acc`.
impl<'a, F: FieldExtension> MulAccReduce<&'a Self, Self> for FieldElement<F> {
    type WideType = <F as MulAccReduce>::WideType;

    #[inline]
    fn mul_acc(wide: &mut Self::WideType, lhs: &'a Self, rhs: Self) {
        F::mul_acc(wide, lhs.0, rhs.0);
    }
}
/// Dot products over borrowed × owned operands use the trait's provided implementation.
impl<F: FieldExtension> DefaultDotProduct<&Self, Self> for FieldElement<F> {}

/// Owned × borrowed multiply-accumulate; the inner values are copied out and
/// forwarded to `F::mul_acc`.
impl<'a, F: FieldExtension> MulAccReduce<Self, &'a Self> for FieldElement<F> {
    type WideType = <F as MulAccReduce>::WideType;

    #[inline]
    fn mul_acc(wide: &mut Self::WideType, lhs: Self, rhs: &'a Self) {
        F::mul_acc(wide, lhs.0, rhs.0);
    }
}
/// Dot products over owned × borrowed operands use the trait's provided implementation.
impl<F: FieldExtension> DefaultDotProduct<Self, &Self> for FieldElement<F> {}

/// Borrowed × borrowed multiply-accumulate; the inner values are copied out
/// and forwarded to `F::mul_acc`.
impl<'a, 'b, F: FieldExtension> MulAccReduce<&'a Self, &'b Self> for FieldElement<F> {
    type WideType = <F as MulAccReduce>::WideType;

    #[inline]
    fn mul_acc(wide: &mut Self::WideType, lhs: &'a Self, rhs: &'b Self) {
        F::mul_acc(wide, lhs.0, rhs.0);
    }
}
/// Dot products over borrowed × borrowed operands use the trait's provided implementation.
impl<F: FieldExtension> DefaultDotProduct<&Self, &Self> for FieldElement<F> {}
impl<F: FieldExtension> Zero for FieldElement<F> {
fn zero() -> Self {
FieldElement(F::ZERO)
}
fn is_zero(&self) -> bool {
self.0.is_zero().into()
}
}
impl<F: FieldExtension> One for FieldElement<F> {
fn one() -> Self {
FieldElement(F::ONE)
}
}
impl<F: FieldExtension> Field for FieldElement<F> {
const ZERO: Self = FieldElement(F::ZERO);
const ONE: Self = FieldElement(F::ONE);
fn random(rng: impl RngCore) -> Self {
Self(<F as Field>::random(rng))
}
fn square(&self) -> Self {
FieldElement(self.0.square())
}
fn double(&self) -> Self {
FieldElement(self.0.double())
}
fn invert(&self) -> CtOption<Self> {
self.0.invert().map(FieldElement)
}
fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
let (choice, sqrt) = F::sqrt_ratio(&num.0, &div.0);
(choice, FieldElement(sqrt))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::field::mersenne::Mersenne107;

    type Fe = FieldElement<Mersenne107>;

    /// A fixed, non-trivial element shared by the tests below.
    fn sample() -> Fe {
        Fe::new(Mersenne107::from(42u64))
    }

    #[test]
    fn test_fieldelement_mersenne107_wincode() {
        let elem = sample();
        let bytes = wincode::serialize(&elem).unwrap();
        let decoded: Fe = wincode::deserialize(&bytes).unwrap();
        assert_eq!(elem, decoded);
    }

    #[test]
    fn test_fieldelement_le_bytes_roundtrip() {
        let elem = sample();
        let bytes = elem.to_le_bytes();
        let decoded = Fe::from_le_bytes(AsRef::<[u8]>::as_ref(&bytes)).unwrap();
        assert_eq!(elem, decoded);
    }

    #[test]
    fn test_fieldelement_be_bytes_roundtrip_and_order() {
        let elem = sample();
        let le = elem.to_le_bytes();
        let be = elem.to_be_bytes();
        let le_slice: &[u8] = AsRef::<[u8]>::as_ref(&le);
        let be_slice: &[u8] = AsRef::<[u8]>::as_ref(&be);
        // The big-endian encoding must be exactly the little-endian one reversed.
        assert!(le_slice.iter().eq(be_slice.iter().rev()));
        assert_eq!(Fe::from_be_bytes(be_slice).unwrap(), elem);
    }
}