Skip to main content

primitives/algebra/field/
field_element.rs

1use std::{
2    iter::{Product, Sum},
3    ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign},
4};
5
6use derive_more::derive::{AsMut, AsRef};
7use ff::Field;
8use hybrid_array::Array;
9use num_traits::{One, Zero};
10use rand::{distributions::Standard, prelude::Distribution, RngCore};
11use serde::{Deserialize, Serialize};
12use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
13use wincode::{SchemaRead, SchemaWrite};
14
15use crate::{
16    algebra::{
17        field::{ByteSize, FieldExtension, PrimeFieldExtension, SubfieldElement},
18        ops::{AccReduce, DefaultDotProduct, IntoWide, MulAccReduce, ReduceWide},
19        uniform_bytes::FromUniformBytes,
20    },
21    errors::PrimitiveError,
22    sharing::unauthenticated::AdditiveShares,
23};
24
25/// A field element wrapper.
26#[derive(
27    Copy, Clone, Debug, PartialOrd, PartialEq, Eq, Hash, AsRef, AsMut, SchemaRead, SchemaWrite,
28)]
29#[repr(transparent)]
30pub struct FieldElement<F: FieldExtension>(pub(crate) F);
31
32impl<F: FieldExtension> FieldElement<F> {
33    /// Construct a field element wrapper from an inner field element
34    #[inline]
35    pub fn new(inner: F) -> Self {
36        FieldElement(inner)
37    }
38
39    /// Get the inner value of the field element
40    #[inline]
41    pub fn inner(&self) -> F {
42        self.0
43    }
44
45    /// Compute the exponentiation of the given field element
46    #[inline]
47    pub fn pow(&self, exp: u64) -> Self {
48        FieldElement::new(self.0.pow([exp]))
49    }
50
51    /// Construct a field element from the given bytes
52    #[inline]
53    pub fn from_be_bytes(bytes: &[u8]) -> Result<FieldElement<F>, PrimitiveError> {
54        let mut bytes = bytes.to_vec();
55        bytes.reverse();
56        Ok(FieldElement(F::from_le_bytes(&bytes).ok_or_else(|| {
57            PrimitiveError::DeserializationFailed("Invalid field element encoding".to_string())
58        })?))
59    }
60
61    pub fn from_le_bytes(bytes: &[u8]) -> Result<FieldElement<F>, PrimitiveError> {
62        Ok(FieldElement(F::from_le_bytes(bytes).ok_or_else(|| {
63            PrimitiveError::DeserializationFailed("Invalid field element encoding".to_string())
64        })?))
65    }
66
67    /// Convert the field element to little-endian bytes
68    #[inline]
69    pub fn to_le_bytes(&self) -> Array<u8, ByteSize<F>> {
70        self.0.to_le_bytes()
71    }
72
73    /// Convert the field element to big-endian bytes
74    #[inline]
75    pub fn to_be_bytes(&self) -> Array<u8, ByteSize<F>> {
76        let mut rev = self.0.to_le_bytes();
77        rev.as_mut().reverse();
78        rev
79    }
80
81    pub fn to_biguint(&self) -> num_bigint::BigUint {
82        num_bigint::BigUint::from_bytes_le(self.to_le_bytes().as_ref())
83    }
84}
85
86impl<F: FieldExtension> Distribution<FieldElement<F>> for Standard {
87    fn sample<R: RngCore + ?Sized>(&self, rng: &mut R) -> FieldElement<F> {
88        FieldElement(F::random(rng))
89    }
90}
91
92impl<F: FieldExtension> Default for FieldElement<F> {
93    fn default() -> Self {
94        Self::zero()
95    }
96}
97
98impl<F: FieldExtension> Serialize for FieldElement<F> {
99    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
100        let bytes = self
101            .to_le_bytes()
102            .into_iter()
103            .collect::<Array<u8, F::FieldBytesSize>>();
104        serde_bytes::serialize(AsRef::<[u8]>::as_ref(&bytes), serializer)
105    }
106}
107
108impl<'de, F: FieldExtension> Deserialize<'de> for FieldElement<F> {
109    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
110        let bytes: &[u8] = serde_bytes::deserialize(deserializer)?;
111        let field_elem = FieldElement::from_le_bytes(bytes).map_err(|err| {
112            serde::de::Error::custom(format!("Failed to deserialize field element: {err:?}"))
113        })?;
114        Ok(field_elem)
115    }
116}
117
118// --------------
119// | Arithmetic |
120// --------------
121
122// === Addition === //
123
124#[macros::op_variants(owned, borrowed, flipped_commutative)]
125impl<F: FieldExtension> Add<&FieldElement<F>> for FieldElement<F> {
126    type Output = FieldElement<F>;
127
128    #[inline]
129    fn add(self, rhs: &FieldElement<F>) -> Self::Output {
130        FieldElement(self.0 + rhs.0)
131    }
132}
133
134#[macros::op_variants(owned)]
135impl<'a, F: FieldExtension> AddAssign<&'a FieldElement<F>> for FieldElement<F> {
136    #[inline]
137    fn add_assign(&mut self, rhs: &'a FieldElement<F>) {
138        *self = *self + rhs;
139    }
140}
141
142// === Subtraction === //
143
144#[macros::op_variants(owned, borrowed, flipped)]
145impl<F: FieldExtension> Sub<&FieldElement<F>> for FieldElement<F> {
146    type Output = FieldElement<F>;
147
148    #[inline]
149    fn sub(self, rhs: &FieldElement<F>) -> Self::Output {
150        FieldElement(self.0 - rhs.0)
151    }
152}
153
154#[macros::op_variants(owned)]
155impl<'a, F: FieldExtension> SubAssign<&'a FieldElement<F>> for FieldElement<F> {
156    #[inline]
157    fn sub_assign(&mut self, rhs: &'a FieldElement<F>) {
158        *self = *self - rhs;
159    }
160}
161
162// === Multiplication === //
163
164#[macros::op_variants(owned, borrowed, flipped_commutative)]
165impl<F: FieldExtension> Mul<&FieldElement<F>> for FieldElement<F> {
166    type Output = FieldElement<F>;
167
168    #[inline]
169    fn mul(self, rhs: &FieldElement<F>) -> Self::Output {
170        FieldElement(self.0 * rhs.0)
171    }
172}
173
174#[macros::op_variants(owned, borrowed, flipped)]
175impl<F: FieldExtension> Mul<&SubfieldElement<F>> for FieldElement<F> {
176    type Output = FieldElement<F>;
177
178    #[inline]
179    fn mul(self, rhs: &SubfieldElement<F>) -> Self::Output {
180        FieldElement(self.0 * rhs.0)
181    }
182}
183
184#[macros::op_variants(owned)]
185impl<'a, F: FieldExtension> MulAssign<&'a FieldElement<F>> for FieldElement<F> {
186    #[inline]
187    fn mul_assign(&mut self, rhs: &'a FieldElement<F>) {
188        *self = *self * rhs;
189    }
190}
191
192#[macros::op_variants(owned)]
193impl<'a, F: FieldExtension> MulAssign<&'a SubfieldElement<F>> for FieldElement<F> {
194    #[inline]
195    fn mul_assign(&mut self, rhs: &'a SubfieldElement<F>) {
196        *self = *self * rhs;
197    }
198}
199
200// === Negation === //
201
202#[macros::op_variants(borrowed)]
203impl<F: FieldExtension> Neg for FieldElement<F> {
204    type Output = FieldElement<F>;
205
206    #[inline]
207    fn neg(self) -> Self::Output {
208        FieldElement(-self.0)
209    }
210}
211
212// === Division === //
213
214#[macros::op_variants(owned, borrowed, flipped)]
215impl<F: FieldExtension> Div<&FieldElement<F>> for FieldElement<F> {
216    type Output = CtOption<FieldElement<F>>;
217
218    #[inline]
219    fn div(self, rhs: &FieldElement<F>) -> Self::Output {
220        rhs.0.invert().map(|inv| FieldElement(self.0 * inv))
221    }
222}
223
224// === Equality === //
225
226impl<F: FieldExtension> ConstantTimeEq for FieldElement<F> {
227    #[inline]
228    fn ct_eq(&self, other: &Self) -> Choice {
229        self.0.ct_eq(&other.0)
230    }
231}
232
233impl<F: FieldExtension> ConditionallySelectable for FieldElement<F> {
234    #[inline]
235    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
236        let selected = F::conditional_select(&a.0, &b.0, choice);
237        FieldElement(selected)
238    }
239}
240
241// === Other === //
242
243impl<F: FieldExtension> AdditiveShares for FieldElement<F> {}
244
245// ---------------
246// | Conversions |
247// ---------------
248
249impl<F: FieldExtension> From<bool> for FieldElement<F> {
250    #[inline]
251    fn from(value: bool) -> Self {
252        FieldElement(F::from(value as u64))
253    }
254}
255
256impl<F: FieldExtension> From<u8> for FieldElement<F> {
257    #[inline]
258    fn from(value: u8) -> Self {
259        FieldElement(F::from(value as u64))
260    }
261}
262
263impl<F: FieldExtension> From<u16> for FieldElement<F> {
264    #[inline]
265    fn from(value: u16) -> Self {
266        FieldElement(F::from(value as u64))
267    }
268}
269
270impl<F: FieldExtension> From<u32> for FieldElement<F> {
271    #[inline]
272    fn from(value: u32) -> Self {
273        FieldElement(F::from(value as u64))
274    }
275}
276
277impl<F: FieldExtension> From<u64> for FieldElement<F> {
278    #[inline]
279    fn from(value: u64) -> Self {
280        FieldElement(F::from(value))
281    }
282}
283
284impl<F: FieldExtension> From<u128> for FieldElement<F> {
285    #[inline]
286    fn from(value: u128) -> Self {
287        FieldElement(F::from(value))
288    }
289}
290
291impl<F: PrimeFieldExtension> From<SubfieldElement<F>> for FieldElement<F> {
292    #[inline]
293    fn from(value: SubfieldElement<F>) -> Self {
294        FieldElement(value.0)
295    }
296}
297
298// -------------------
299// | Iterator Traits |
300// -------------------
301impl<F: FieldExtension> Sum for FieldElement<F> {
302    #[inline]
303    fn sum<I: Iterator<Item = FieldElement<F>>>(iter: I) -> Self {
304        let tmp = iter.fold(<F as AccReduce>::zero_wide(), |mut acc, x| {
305            F::acc(&mut acc, x.0);
306            acc
307        });
308
309        FieldElement(F::reduce_mod_order(tmp))
310    }
311}
312
313impl<'a, F: FieldExtension> Sum<&'a FieldElement<F>> for FieldElement<F> {
314    #[inline]
315    fn sum<I: Iterator<Item = &'a FieldElement<F>>>(iter: I) -> Self {
316        let tmp = iter.fold(<F as AccReduce>::zero_wide(), |mut acc, x| {
317            F::acc(&mut acc, x.0);
318            acc
319        });
320
321        FieldElement(F::reduce_mod_order(tmp))
322    }
323}
324
325impl<F: FieldExtension> Product for FieldElement<F> {
326    #[inline]
327    fn product<I: Iterator<Item = FieldElement<F>>>(iter: I) -> Self {
328        iter.fold(FieldElement::one(), |acc, x| acc * x)
329    }
330}
331
332impl<'a, F: FieldExtension> Product<&'a FieldElement<F>> for FieldElement<F> {
333    #[inline]
334    fn product<I: Iterator<Item = &'a FieldElement<F>>>(iter: I) -> Self {
335        iter.fold(FieldElement::one(), |acc, x| acc * x)
336    }
337}
338
339impl<F: FieldExtension> FromUniformBytes for FieldElement<F> {
340    type UniformBytes = <F as FromUniformBytes>::UniformBytes;
341
342    fn from_uniform_bytes(bytes: &Array<u8, Self::UniformBytes>) -> Self {
343        Self(F::from_uniform_bytes(bytes))
344    }
345}
346
347// Dot product: FieldElement<F> x FieldElement<F>
348impl<F: FieldExtension> IntoWide<<F as MulAccReduce>::WideType> for FieldElement<F> {
349    #[inline]
350    fn to_wide(&self) -> <F as MulAccReduce>::WideType {
351        <F as MulAccReduce>::to_wide(&self.0)
352    }
353
354    #[inline]
355    fn zero_wide() -> <F as MulAccReduce>::WideType {
356        <F as MulAccReduce>::zero_wide()
357    }
358}
359
360impl<F: FieldExtension> ReduceWide<<F as MulAccReduce>::WideType> for FieldElement<F> {
361    #[inline]
362    fn reduce_mod_order(a: <F as MulAccReduce>::WideType) -> Self {
363        Self(F::reduce_mod_order(a))
364    }
365}
366
367impl<F: FieldExtension> MulAccReduce for FieldElement<F> {
368    type WideType = <F as MulAccReduce>::WideType;
369
370    #[inline]
371    fn mul_acc(acc: &mut Self::WideType, a: Self, b: Self) {
372        F::mul_acc(acc, a.0, b.0);
373    }
374}
375
376impl<F: FieldExtension> DefaultDotProduct for FieldElement<F> {}
377
378// Dot product: &FieldElement<F> x FieldElement<F>
379impl<'a, F: FieldExtension> MulAccReduce<&'a Self, Self> for FieldElement<F> {
380    type WideType = <F as MulAccReduce>::WideType;
381
382    #[inline]
383    fn mul_acc(acc: &mut Self::WideType, a: &'a Self, b: Self) {
384        F::mul_acc(acc, a.0, b.0);
385    }
386}
387
388impl<F: FieldExtension> DefaultDotProduct<&Self, Self> for FieldElement<F> {}
389
390// Dot product: FieldElement<F> x &FieldElement<F>
391impl<'a, F: FieldExtension> MulAccReduce<Self, &'a Self> for FieldElement<F> {
392    type WideType = <F as MulAccReduce>::WideType;
393
394    #[inline]
395    fn mul_acc(acc: &mut Self::WideType, a: Self, b: &'a Self) {
396        F::mul_acc(acc, a.0, b.0);
397    }
398}
399
400impl<F: FieldExtension> DefaultDotProduct<Self, &Self> for FieldElement<F> {}
401
402// Dot product: &FieldElement<F> x &FieldElement<F>
403impl<'a, 'b, F: FieldExtension> MulAccReduce<&'a Self, &'b Self> for FieldElement<F> {
404    type WideType = <F as MulAccReduce>::WideType;
405
406    #[inline]
407    fn mul_acc(acc: &mut Self::WideType, a: &'a Self, b: &'b Self) {
408        F::mul_acc(acc, a.0, b.0);
409    }
410}
411
412impl<F: FieldExtension> DefaultDotProduct<&Self, &Self> for FieldElement<F> {}
413
414// ----------------
415// | Zero and One |
416// ----------------
417
418impl<F: FieldExtension> Zero for FieldElement<F> {
419    /// The field's additive identity.
420    fn zero() -> Self {
421        FieldElement(F::ZERO)
422    }
423
424    fn is_zero(&self) -> bool {
425        self.0.is_zero().into()
426    }
427}
428
429impl<F: FieldExtension> One for FieldElement<F> {
430    /// The field's multiplicative identity.
431    fn one() -> Self {
432        FieldElement(F::ONE)
433    }
434}
435
436// ---------------
437// | Field trait |
438// ---------------
439
440impl<F: FieldExtension> Field for FieldElement<F> {
441    const ZERO: Self = FieldElement(F::ZERO);
442    const ONE: Self = FieldElement(F::ONE);
443
444    fn random(rng: impl RngCore) -> Self {
445        Self(F::random(rng))
446    }
447
448    fn square(&self) -> Self {
449        FieldElement(self.0.square())
450    }
451
452    fn double(&self) -> Self {
453        FieldElement(self.0.double())
454    }
455
456    fn invert(&self) -> CtOption<Self> {
457        self.0.invert().map(FieldElement)
458    }
459
460    fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
461        let (choice, sqrt) = F::sqrt_ratio(&num.0, &div.0);
462        (choice, FieldElement(sqrt))
463    }
464}
465
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::field::mersenne::Mersenne107;

    /// A `Mersenne107` element must survive a wincode encode/decode round trip.
    #[test]
    fn test_fieldelement_mersenne107_wincode() {
        let original = FieldElement::<Mersenne107>::new(Mersenne107::from(42u64));

        let encoded = wincode::serialize(&original).unwrap();
        let round_tripped: FieldElement<Mersenne107> = wincode::deserialize(&encoded).unwrap();

        assert_eq!(original, round_tripped);
    }
}