Skip to main content

primitives/algebra/field/
field_element.rs

1use std::{
2    iter::{Product, Sum},
3    ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign},
4};
5
6use derive_more::derive::{AsMut, AsRef};
7use ff::Field;
8use hybrid_array::Array;
9use num_traits::{One, Zero};
10use rand::RngCore;
11use serde::{Deserialize, Serialize};
12use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
13use wincode::{SchemaRead, SchemaWrite};
14
15use crate::{
16    algebra::{
17        field::{ByteSize, FieldExtension, PrimeFieldExtension, SubfieldElement},
18        ops::{AccReduce, DefaultDotProduct, IntoWide, MulAccReduce, ReduceWide},
19        uniform_bytes::FromUniformBytes,
20    },
21    errors::PrimitiveError,
22    random::{CryptoRngCore, Random, RandomNonZero},
23    sharing::unauthenticated::AdditiveShares,
24};
25
26/// A field element wrapper.
27#[derive(
28    Copy, Clone, Debug, PartialOrd, PartialEq, Eq, Hash, AsRef, AsMut, SchemaRead, SchemaWrite,
29)]
30#[repr(transparent)]
31pub struct FieldElement<F: FieldExtension>(pub F);
32
33// SAFETY: FieldElement<F> is #[repr(transparent)] over F.
34unsafe impl<F: FieldExtension> bytemuck::TransparentWrapper<F> for FieldElement<F> {}
35
36impl<F: FieldExtension> FieldElement<F> {
37    /// Construct a field element wrapper from an inner field element
38    #[inline]
39    pub fn new(inner: F) -> Self {
40        FieldElement(inner)
41    }
42
43    /// Get the inner value of the field element
44    #[inline]
45    pub fn inner(&self) -> F {
46        self.0
47    }
48
49    /// Compute the exponentiation of the given field element
50    #[inline]
51    pub fn pow(&self, exp: u64) -> Self {
52        FieldElement::new(self.0.pow([exp]))
53    }
54
55    /// Construct a field element from the given bytes
56    #[inline]
57    pub fn from_be_bytes(bytes: &[u8]) -> Result<FieldElement<F>, PrimitiveError> {
58        let mut bytes = bytes.to_vec();
59        bytes.reverse();
60        Ok(FieldElement(F::from_le_bytes(&bytes).ok_or_else(|| {
61            PrimitiveError::DeserializationFailed("Invalid field element encoding".to_string())
62        })?))
63    }
64
65    pub fn from_le_bytes(bytes: &[u8]) -> Result<FieldElement<F>, PrimitiveError> {
66        Ok(FieldElement(F::from_le_bytes(bytes).ok_or_else(|| {
67            PrimitiveError::DeserializationFailed("Invalid field element encoding".to_string())
68        })?))
69    }
70
71    /// Convert the field element to little-endian bytes
72    #[inline]
73    pub fn to_le_bytes(&self) -> Array<u8, ByteSize<F>> {
74        self.0.to_le_bytes()
75    }
76
77    /// Convert the field element to big-endian bytes
78    #[inline]
79    pub fn to_be_bytes(&self) -> Array<u8, ByteSize<F>> {
80        let mut rev = self.0.to_le_bytes();
81        rev.as_mut().reverse();
82        rev
83    }
84
85    pub fn to_biguint(&self) -> num_bigint::BigUint {
86        num_bigint::BigUint::from_bytes_le(self.to_le_bytes().as_ref())
87    }
88}
89
90impl<F: FieldExtension> Random for FieldElement<F> {
91    #[inline]
92    fn random(rng: impl CryptoRngCore) -> Self {
93        FieldElement(Random::random(rng))
94    }
95}
96
97impl<F: FieldExtension> RandomNonZero for FieldElement<F> {
98    fn random_non_zero(mut rng: impl CryptoRngCore) -> Result<Self, PrimitiveError> {
99        Ok(FieldElement(F::random_non_zero(&mut rng)?))
100    }
101}
102
103impl<F: FieldExtension> Default for FieldElement<F> {
104    fn default() -> Self {
105        Self::zero()
106    }
107}
108
109impl<F: FieldExtension> Serialize for FieldElement<F> {
110    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
111        let bytes = self
112            .to_le_bytes()
113            .into_iter()
114            .collect::<Array<u8, F::FieldBytesSize>>();
115        serde_bytes::serialize(AsRef::<[u8]>::as_ref(&bytes), serializer)
116    }
117}
118
119impl<'de, F: FieldExtension> Deserialize<'de> for FieldElement<F> {
120    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
121        let bytes: &[u8] = serde_bytes::deserialize(deserializer)?;
122        let field_elem = FieldElement::from_le_bytes(bytes).map_err(|err| {
123            serde::de::Error::custom(format!("Failed to deserialize field element: {err:?}"))
124        })?;
125        Ok(field_elem)
126    }
127}
128
129// --------------
130// | Arithmetic |
131// --------------
132
133// === Addition === //
134
135#[macros::op_variants(owned, borrowed, flipped_commutative)]
136impl<F: FieldExtension> Add<&FieldElement<F>> for FieldElement<F> {
137    type Output = FieldElement<F>;
138
139    #[inline]
140    fn add(self, rhs: &FieldElement<F>) -> Self::Output {
141        FieldElement(self.0 + rhs.0)
142    }
143}
144
145#[macros::op_variants(owned)]
146impl<'a, F: FieldExtension> AddAssign<&'a FieldElement<F>> for FieldElement<F> {
147    #[inline]
148    fn add_assign(&mut self, rhs: &'a FieldElement<F>) {
149        *self = *self + rhs;
150    }
151}
152
153// === Subtraction === //
154
155#[macros::op_variants(owned, borrowed, flipped)]
156impl<F: FieldExtension> Sub<&FieldElement<F>> for FieldElement<F> {
157    type Output = FieldElement<F>;
158
159    #[inline]
160    fn sub(self, rhs: &FieldElement<F>) -> Self::Output {
161        FieldElement(self.0 - rhs.0)
162    }
163}
164
165#[macros::op_variants(owned)]
166impl<'a, F: FieldExtension> SubAssign<&'a FieldElement<F>> for FieldElement<F> {
167    #[inline]
168    fn sub_assign(&mut self, rhs: &'a FieldElement<F>) {
169        *self = *self - rhs;
170    }
171}
172
173// === Multiplication === //
174
175#[macros::op_variants(owned, borrowed, flipped_commutative)]
176impl<F: FieldExtension> Mul<&FieldElement<F>> for FieldElement<F> {
177    type Output = FieldElement<F>;
178
179    #[inline]
180    fn mul(self, rhs: &FieldElement<F>) -> Self::Output {
181        FieldElement(self.0 * rhs.0)
182    }
183}
184
185#[macros::op_variants(owned, borrowed, flipped)]
186impl<F: FieldExtension> Mul<&SubfieldElement<F>> for FieldElement<F> {
187    type Output = FieldElement<F>;
188
189    #[inline]
190    fn mul(self, rhs: &SubfieldElement<F>) -> Self::Output {
191        FieldElement(self.0 * rhs.0)
192    }
193}
194
195#[macros::op_variants(owned)]
196impl<'a, F: FieldExtension> MulAssign<&'a FieldElement<F>> for FieldElement<F> {
197    #[inline]
198    fn mul_assign(&mut self, rhs: &'a FieldElement<F>) {
199        *self = *self * rhs;
200    }
201}
202
203#[macros::op_variants(owned)]
204impl<'a, F: FieldExtension> MulAssign<&'a SubfieldElement<F>> for FieldElement<F> {
205    #[inline]
206    fn mul_assign(&mut self, rhs: &'a SubfieldElement<F>) {
207        *self = *self * rhs;
208    }
209}
210
211// === Negation === //
212
213#[macros::op_variants(borrowed)]
214impl<F: FieldExtension> Neg for FieldElement<F> {
215    type Output = FieldElement<F>;
216
217    #[inline]
218    fn neg(self) -> Self::Output {
219        FieldElement(-self.0)
220    }
221}
222
223// === Division === //
224
225#[macros::op_variants(owned, borrowed, flipped)]
226impl<F: FieldExtension> Div<&FieldElement<F>> for FieldElement<F> {
227    type Output = CtOption<FieldElement<F>>;
228
229    #[inline]
230    fn div(self, rhs: &FieldElement<F>) -> Self::Output {
231        rhs.0.invert().map(|inv| FieldElement(self.0 * inv))
232    }
233}
234
235// === Equality === //
236
237impl<F: FieldExtension> ConstantTimeEq for FieldElement<F> {
238    #[inline]
239    fn ct_eq(&self, other: &Self) -> Choice {
240        self.0.ct_eq(&other.0)
241    }
242}
243
244impl<F: FieldExtension> ConditionallySelectable for FieldElement<F> {
245    #[inline]
246    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
247        let selected = F::conditional_select(&a.0, &b.0, choice);
248        FieldElement(selected)
249    }
250}
251
252// === Other === //
253
254impl<F: FieldExtension> AdditiveShares for FieldElement<F> {}
255
256// ---------------
257// | Conversions |
258// ---------------
259
260impl<F: FieldExtension> From<bool> for FieldElement<F> {
261    #[inline]
262    fn from(value: bool) -> Self {
263        FieldElement(F::from(value as u64))
264    }
265}
266
267impl<F: FieldExtension> From<u8> for FieldElement<F> {
268    #[inline]
269    fn from(value: u8) -> Self {
270        FieldElement(F::from(value as u64))
271    }
272}
273
274impl<F: FieldExtension> From<u16> for FieldElement<F> {
275    #[inline]
276    fn from(value: u16) -> Self {
277        FieldElement(F::from(value as u64))
278    }
279}
280
281impl<F: FieldExtension> From<u32> for FieldElement<F> {
282    #[inline]
283    fn from(value: u32) -> Self {
284        FieldElement(F::from(value as u64))
285    }
286}
287
288impl<F: FieldExtension> From<u64> for FieldElement<F> {
289    #[inline]
290    fn from(value: u64) -> Self {
291        FieldElement(F::from(value))
292    }
293}
294
295impl<F: FieldExtension> From<u128> for FieldElement<F> {
296    #[inline]
297    fn from(value: u128) -> Self {
298        FieldElement(F::from(value))
299    }
300}
301
302impl<F: PrimeFieldExtension> From<SubfieldElement<F>> for FieldElement<F> {
303    #[inline]
304    fn from(value: SubfieldElement<F>) -> Self {
305        FieldElement(value.0)
306    }
307}
308
309// -------------------
310// | Iterator Traits |
311// -------------------
312impl<F: FieldExtension> Sum for FieldElement<F> {
313    #[inline]
314    fn sum<I: Iterator<Item = FieldElement<F>>>(iter: I) -> Self {
315        let tmp = iter.fold(<F as AccReduce>::zero_wide(), |mut acc, x| {
316            F::acc(&mut acc, x.0);
317            acc
318        });
319
320        FieldElement(F::reduce_mod_order(tmp))
321    }
322}
323
324impl<'a, F: FieldExtension> Sum<&'a FieldElement<F>> for FieldElement<F> {
325    #[inline]
326    fn sum<I: Iterator<Item = &'a FieldElement<F>>>(iter: I) -> Self {
327        let tmp = iter.fold(<F as AccReduce>::zero_wide(), |mut acc, x| {
328            F::acc(&mut acc, x.0);
329            acc
330        });
331
332        FieldElement(F::reduce_mod_order(tmp))
333    }
334}
335
336impl<F: FieldExtension> Product for FieldElement<F> {
337    #[inline]
338    fn product<I: Iterator<Item = FieldElement<F>>>(iter: I) -> Self {
339        iter.fold(FieldElement::one(), |acc, x| acc * x)
340    }
341}
342
343impl<'a, F: FieldExtension> Product<&'a FieldElement<F>> for FieldElement<F> {
344    #[inline]
345    fn product<I: Iterator<Item = &'a FieldElement<F>>>(iter: I) -> Self {
346        iter.fold(FieldElement::one(), |acc, x| acc * x)
347    }
348}
349
350impl<F: FieldExtension> FromUniformBytes for FieldElement<F> {
351    type UniformBytes = <F as FromUniformBytes>::UniformBytes;
352
353    fn from_uniform_bytes(bytes: &Array<u8, Self::UniformBytes>) -> Self {
354        Self(F::from_uniform_bytes(bytes))
355    }
356}
357
358// Dot product: FieldElement<F> x FieldElement<F>
359impl<F: FieldExtension> IntoWide<<F as MulAccReduce>::WideType> for FieldElement<F> {
360    #[inline]
361    fn to_wide(&self) -> <F as MulAccReduce>::WideType {
362        <F as MulAccReduce>::to_wide(&self.0)
363    }
364
365    #[inline]
366    fn zero_wide() -> <F as MulAccReduce>::WideType {
367        <F as MulAccReduce>::zero_wide()
368    }
369}
370
371impl<F: FieldExtension> ReduceWide<<F as MulAccReduce>::WideType> for FieldElement<F> {
372    #[inline]
373    fn reduce_mod_order(a: <F as MulAccReduce>::WideType) -> Self {
374        Self(F::reduce_mod_order(a))
375    }
376}
377
378impl<F: FieldExtension> MulAccReduce for FieldElement<F> {
379    type WideType = <F as MulAccReduce>::WideType;
380
381    #[inline]
382    fn mul_acc(acc: &mut Self::WideType, a: Self, b: Self) {
383        F::mul_acc(acc, a.0, b.0);
384    }
385}
386
387impl<F: FieldExtension> DefaultDotProduct for FieldElement<F> {}
388
389// Dot product: &FieldElement<F> x FieldElement<F>
390impl<'a, F: FieldExtension> MulAccReduce<&'a Self, Self> for FieldElement<F> {
391    type WideType = <F as MulAccReduce>::WideType;
392
393    #[inline]
394    fn mul_acc(acc: &mut Self::WideType, a: &'a Self, b: Self) {
395        F::mul_acc(acc, a.0, b.0);
396    }
397}
398
399impl<F: FieldExtension> DefaultDotProduct<&Self, Self> for FieldElement<F> {}
400
401// Dot product: FieldElement<F> x &FieldElement<F>
402impl<'a, F: FieldExtension> MulAccReduce<Self, &'a Self> for FieldElement<F> {
403    type WideType = <F as MulAccReduce>::WideType;
404
405    #[inline]
406    fn mul_acc(acc: &mut Self::WideType, a: Self, b: &'a Self) {
407        F::mul_acc(acc, a.0, b.0);
408    }
409}
410
411impl<F: FieldExtension> DefaultDotProduct<Self, &Self> for FieldElement<F> {}
412
413// Dot product: &FieldElement<F> x &FieldElement<F>
414impl<'a, 'b, F: FieldExtension> MulAccReduce<&'a Self, &'b Self> for FieldElement<F> {
415    type WideType = <F as MulAccReduce>::WideType;
416
417    #[inline]
418    fn mul_acc(acc: &mut Self::WideType, a: &'a Self, b: &'b Self) {
419        F::mul_acc(acc, a.0, b.0);
420    }
421}
422
423impl<F: FieldExtension> DefaultDotProduct<&Self, &Self> for FieldElement<F> {}
424
425// ----------------
426// | Zero and One |
427// ----------------
428
429impl<F: FieldExtension> Zero for FieldElement<F> {
430    /// The field's additive identity.
431    fn zero() -> Self {
432        FieldElement(F::ZERO)
433    }
434
435    fn is_zero(&self) -> bool {
436        self.0.is_zero().into()
437    }
438}
439
440impl<F: FieldExtension> One for FieldElement<F> {
441    /// The field's multiplicative identity.
442    fn one() -> Self {
443        FieldElement(F::ONE)
444    }
445}
446
447// ---------------
448// | Field trait |
449// ---------------
450
451impl<F: FieldExtension> Field for FieldElement<F> {
452    const ZERO: Self = FieldElement(F::ZERO);
453    const ONE: Self = FieldElement(F::ONE);
454
455    fn random(rng: impl RngCore) -> Self {
456        Self(<F as Field>::random(rng))
457    }
458
459    fn square(&self) -> Self {
460        FieldElement(self.0.square())
461    }
462
463    fn double(&self) -> Self {
464        FieldElement(self.0.double())
465    }
466
467    fn invert(&self) -> CtOption<Self> {
468        self.0.invert().map(FieldElement)
469    }
470
471    fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
472        let (choice, sqrt) = F::sqrt_ratio(&num.0, &div.0);
473        (choice, FieldElement(sqrt))
474    }
475}
476
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::field::mersenne::Mersenne107;

    type Fe = FieldElement<Mersenne107>;

    /// A `wincode` round trip of a small element must be lossless.
    #[test]
    fn test_fieldelement_mersenne107_wincode() {
        let elem = Fe::new(Mersenne107::from(42u64));
        let bytes = wincode::serialize(&elem).unwrap();

        let decoded: Fe = wincode::deserialize(&bytes).unwrap();

        assert_eq!(elem, decoded);
    }

    /// Both byte encodings must decode back to the original element, and the big-endian
    /// encoding must be the byte-reversed little-endian one.
    #[test]
    fn test_fieldelement_mersenne107_byte_roundtrip() {
        let elem = Fe::from(123_456_789u64);
        let le = elem.to_le_bytes();
        let be = elem.to_be_bytes();
        let le_bytes: &[u8] = AsRef::<[u8]>::as_ref(&le);
        let be_bytes: &[u8] = AsRef::<[u8]>::as_ref(&be);

        let reversed: Vec<u8> = le_bytes.iter().rev().copied().collect();
        assert_eq!(reversed, be_bytes);

        assert_eq!(Fe::from_le_bytes(le_bytes).unwrap(), elem);
        assert_eq!(Fe::from_be_bytes(be_bytes).unwrap(), elem);
    }
}