Skip to main content

primitives/algebra/field/
field_element.rs

1use std::{
2    iter::{Product, Sum},
3    ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign},
4};
5
6use derive_more::derive::{AsMut, AsRef};
7use ff::Field;
8use hybrid_array::Array;
9use num_traits::{One, Zero};
10use rand::RngCore;
11use serde::{Deserialize, Serialize};
12use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
13use wincode::{SchemaRead, SchemaWrite};
14
15use crate::{
16    algebra::{
17        field::{ByteSize, FieldExtension, PrimeFieldExtension, SubfieldElement},
18        ops::{AccReduce, DefaultDotProduct, IntoWide, MulAccReduce, ReduceWide},
19        uniform_bytes::FromUniformBytes,
20    },
21    errors::PrimitiveError,
22    random::{CryptoRngCore, Random},
23    sharing::unauthenticated::AdditiveShares,
24};
25
26/// A field element wrapper.
27#[derive(
28    Copy, Clone, Debug, PartialOrd, PartialEq, Eq, Hash, AsRef, AsMut, SchemaRead, SchemaWrite,
29)]
30#[repr(transparent)]
31pub struct FieldElement<F: FieldExtension>(pub F);
32
33// SAFETY: FieldElement<F> is #[repr(transparent)] over F.
34unsafe impl<F: FieldExtension> bytemuck::TransparentWrapper<F> for FieldElement<F> {}
35
36impl<F: FieldExtension> FieldElement<F> {
37    /// Construct a field element wrapper from an inner field element
38    #[inline]
39    pub fn new(inner: F) -> Self {
40        FieldElement(inner)
41    }
42
43    /// Get the inner value of the field element
44    #[inline]
45    pub fn inner(&self) -> F {
46        self.0
47    }
48
49    /// Compute the exponentiation of the given field element
50    #[inline]
51    pub fn pow(&self, exp: u64) -> Self {
52        FieldElement::new(self.0.pow([exp]))
53    }
54
55    /// Construct a field element from the given bytes
56    #[inline]
57    pub fn from_be_bytes(bytes: &[u8]) -> Result<FieldElement<F>, PrimitiveError> {
58        let mut bytes = bytes.to_vec();
59        bytes.reverse();
60        Ok(FieldElement(F::from_le_bytes(&bytes).ok_or_else(|| {
61            PrimitiveError::DeserializationFailed("Invalid field element encoding".to_string())
62        })?))
63    }
64
65    pub fn from_le_bytes(bytes: &[u8]) -> Result<FieldElement<F>, PrimitiveError> {
66        Ok(FieldElement(F::from_le_bytes(bytes).ok_or_else(|| {
67            PrimitiveError::DeserializationFailed("Invalid field element encoding".to_string())
68        })?))
69    }
70
71    /// Convert the field element to little-endian bytes
72    #[inline]
73    pub fn to_le_bytes(&self) -> Array<u8, ByteSize<F>> {
74        self.0.to_le_bytes()
75    }
76
77    /// Convert the field element to big-endian bytes
78    #[inline]
79    pub fn to_be_bytes(&self) -> Array<u8, ByteSize<F>> {
80        let mut rev = self.0.to_le_bytes();
81        rev.as_mut().reverse();
82        rev
83    }
84
85    pub fn to_biguint(&self) -> num_bigint::BigUint {
86        num_bigint::BigUint::from_bytes_le(self.to_le_bytes().as_ref())
87    }
88}
89
90impl<F: FieldExtension> Random for FieldElement<F> {
91    #[inline]
92    fn random(rng: impl CryptoRngCore) -> Self {
93        FieldElement(Random::random(rng))
94    }
95}
96
97impl<F: FieldExtension> Default for FieldElement<F> {
98    fn default() -> Self {
99        Self::zero()
100    }
101}
102
103impl<F: FieldExtension> Serialize for FieldElement<F> {
104    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
105        let bytes = self
106            .to_le_bytes()
107            .into_iter()
108            .collect::<Array<u8, F::FieldBytesSize>>();
109        serde_bytes::serialize(AsRef::<[u8]>::as_ref(&bytes), serializer)
110    }
111}
112
113impl<'de, F: FieldExtension> Deserialize<'de> for FieldElement<F> {
114    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
115        let bytes: &[u8] = serde_bytes::deserialize(deserializer)?;
116        let field_elem = FieldElement::from_le_bytes(bytes).map_err(|err| {
117            serde::de::Error::custom(format!("Failed to deserialize field element: {err:?}"))
118        })?;
119        Ok(field_elem)
120    }
121}
122
123// --------------
124// | Arithmetic |
125// --------------
126
127// === Addition === //
128
129#[macros::op_variants(owned, borrowed, flipped_commutative)]
130impl<F: FieldExtension> Add<&FieldElement<F>> for FieldElement<F> {
131    type Output = FieldElement<F>;
132
133    #[inline]
134    fn add(self, rhs: &FieldElement<F>) -> Self::Output {
135        FieldElement(self.0 + rhs.0)
136    }
137}
138
139#[macros::op_variants(owned)]
140impl<'a, F: FieldExtension> AddAssign<&'a FieldElement<F>> for FieldElement<F> {
141    #[inline]
142    fn add_assign(&mut self, rhs: &'a FieldElement<F>) {
143        *self = *self + rhs;
144    }
145}
146
147// === Subtraction === //
148
149#[macros::op_variants(owned, borrowed, flipped)]
150impl<F: FieldExtension> Sub<&FieldElement<F>> for FieldElement<F> {
151    type Output = FieldElement<F>;
152
153    #[inline]
154    fn sub(self, rhs: &FieldElement<F>) -> Self::Output {
155        FieldElement(self.0 - rhs.0)
156    }
157}
158
159#[macros::op_variants(owned)]
160impl<'a, F: FieldExtension> SubAssign<&'a FieldElement<F>> for FieldElement<F> {
161    #[inline]
162    fn sub_assign(&mut self, rhs: &'a FieldElement<F>) {
163        *self = *self - rhs;
164    }
165}
166
167// === Multiplication === //
168
169#[macros::op_variants(owned, borrowed, flipped_commutative)]
170impl<F: FieldExtension> Mul<&FieldElement<F>> for FieldElement<F> {
171    type Output = FieldElement<F>;
172
173    #[inline]
174    fn mul(self, rhs: &FieldElement<F>) -> Self::Output {
175        FieldElement(self.0 * rhs.0)
176    }
177}
178
179#[macros::op_variants(owned, borrowed, flipped)]
180impl<F: FieldExtension> Mul<&SubfieldElement<F>> for FieldElement<F> {
181    type Output = FieldElement<F>;
182
183    #[inline]
184    fn mul(self, rhs: &SubfieldElement<F>) -> Self::Output {
185        FieldElement(self.0 * rhs.0)
186    }
187}
188
189#[macros::op_variants(owned)]
190impl<'a, F: FieldExtension> MulAssign<&'a FieldElement<F>> for FieldElement<F> {
191    #[inline]
192    fn mul_assign(&mut self, rhs: &'a FieldElement<F>) {
193        *self = *self * rhs;
194    }
195}
196
197#[macros::op_variants(owned)]
198impl<'a, F: FieldExtension> MulAssign<&'a SubfieldElement<F>> for FieldElement<F> {
199    #[inline]
200    fn mul_assign(&mut self, rhs: &'a SubfieldElement<F>) {
201        *self = *self * rhs;
202    }
203}
204
205// === Negation === //
206
207#[macros::op_variants(borrowed)]
208impl<F: FieldExtension> Neg for FieldElement<F> {
209    type Output = FieldElement<F>;
210
211    #[inline]
212    fn neg(self) -> Self::Output {
213        FieldElement(-self.0)
214    }
215}
216
217// === Division === //
218
219#[macros::op_variants(owned, borrowed, flipped)]
220impl<F: FieldExtension> Div<&FieldElement<F>> for FieldElement<F> {
221    type Output = CtOption<FieldElement<F>>;
222
223    #[inline]
224    fn div(self, rhs: &FieldElement<F>) -> Self::Output {
225        rhs.0.invert().map(|inv| FieldElement(self.0 * inv))
226    }
227}
228
229// === Equality === //
230
231impl<F: FieldExtension> ConstantTimeEq for FieldElement<F> {
232    #[inline]
233    fn ct_eq(&self, other: &Self) -> Choice {
234        self.0.ct_eq(&other.0)
235    }
236}
237
238impl<F: FieldExtension> ConditionallySelectable for FieldElement<F> {
239    #[inline]
240    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
241        let selected = F::conditional_select(&a.0, &b.0, choice);
242        FieldElement(selected)
243    }
244}
245
246// === Other === //
247
248impl<F: FieldExtension> AdditiveShares for FieldElement<F> {}
249
250// ---------------
251// | Conversions |
252// ---------------
253
254impl<F: FieldExtension> From<bool> for FieldElement<F> {
255    #[inline]
256    fn from(value: bool) -> Self {
257        FieldElement(F::from(value as u64))
258    }
259}
260
261impl<F: FieldExtension> From<u8> for FieldElement<F> {
262    #[inline]
263    fn from(value: u8) -> Self {
264        FieldElement(F::from(value as u64))
265    }
266}
267
268impl<F: FieldExtension> From<u16> for FieldElement<F> {
269    #[inline]
270    fn from(value: u16) -> Self {
271        FieldElement(F::from(value as u64))
272    }
273}
274
275impl<F: FieldExtension> From<u32> for FieldElement<F> {
276    #[inline]
277    fn from(value: u32) -> Self {
278        FieldElement(F::from(value as u64))
279    }
280}
281
282impl<F: FieldExtension> From<u64> for FieldElement<F> {
283    #[inline]
284    fn from(value: u64) -> Self {
285        FieldElement(F::from(value))
286    }
287}
288
289impl<F: FieldExtension> From<u128> for FieldElement<F> {
290    #[inline]
291    fn from(value: u128) -> Self {
292        FieldElement(F::from(value))
293    }
294}
295
296impl<F: PrimeFieldExtension> From<SubfieldElement<F>> for FieldElement<F> {
297    #[inline]
298    fn from(value: SubfieldElement<F>) -> Self {
299        FieldElement(value.0)
300    }
301}
302
303// -------------------
304// | Iterator Traits |
305// -------------------
306impl<F: FieldExtension> Sum for FieldElement<F> {
307    #[inline]
308    fn sum<I: Iterator<Item = FieldElement<F>>>(iter: I) -> Self {
309        let tmp = iter.fold(<F as AccReduce>::zero_wide(), |mut acc, x| {
310            F::acc(&mut acc, x.0);
311            acc
312        });
313
314        FieldElement(F::reduce_mod_order(tmp))
315    }
316}
317
318impl<'a, F: FieldExtension> Sum<&'a FieldElement<F>> for FieldElement<F> {
319    #[inline]
320    fn sum<I: Iterator<Item = &'a FieldElement<F>>>(iter: I) -> Self {
321        let tmp = iter.fold(<F as AccReduce>::zero_wide(), |mut acc, x| {
322            F::acc(&mut acc, x.0);
323            acc
324        });
325
326        FieldElement(F::reduce_mod_order(tmp))
327    }
328}
329
330impl<F: FieldExtension> Product for FieldElement<F> {
331    #[inline]
332    fn product<I: Iterator<Item = FieldElement<F>>>(iter: I) -> Self {
333        iter.fold(FieldElement::one(), |acc, x| acc * x)
334    }
335}
336
337impl<'a, F: FieldExtension> Product<&'a FieldElement<F>> for FieldElement<F> {
338    #[inline]
339    fn product<I: Iterator<Item = &'a FieldElement<F>>>(iter: I) -> Self {
340        iter.fold(FieldElement::one(), |acc, x| acc * x)
341    }
342}
343
344impl<F: FieldExtension> FromUniformBytes for FieldElement<F> {
345    type UniformBytes = <F as FromUniformBytes>::UniformBytes;
346
347    fn from_uniform_bytes(bytes: &Array<u8, Self::UniformBytes>) -> Self {
348        Self(F::from_uniform_bytes(bytes))
349    }
350}
351
352// Dot product: FieldElement<F> x FieldElement<F>
353impl<F: FieldExtension> IntoWide<<F as MulAccReduce>::WideType> for FieldElement<F> {
354    #[inline]
355    fn to_wide(&self) -> <F as MulAccReduce>::WideType {
356        <F as MulAccReduce>::to_wide(&self.0)
357    }
358
359    #[inline]
360    fn zero_wide() -> <F as MulAccReduce>::WideType {
361        <F as MulAccReduce>::zero_wide()
362    }
363}
364
365impl<F: FieldExtension> ReduceWide<<F as MulAccReduce>::WideType> for FieldElement<F> {
366    #[inline]
367    fn reduce_mod_order(a: <F as MulAccReduce>::WideType) -> Self {
368        Self(F::reduce_mod_order(a))
369    }
370}
371
372impl<F: FieldExtension> MulAccReduce for FieldElement<F> {
373    type WideType = <F as MulAccReduce>::WideType;
374
375    #[inline]
376    fn mul_acc(acc: &mut Self::WideType, a: Self, b: Self) {
377        F::mul_acc(acc, a.0, b.0);
378    }
379}
380
381impl<F: FieldExtension> DefaultDotProduct for FieldElement<F> {}
382
383// Dot product: &FieldElement<F> x FieldElement<F>
384impl<'a, F: FieldExtension> MulAccReduce<&'a Self, Self> for FieldElement<F> {
385    type WideType = <F as MulAccReduce>::WideType;
386
387    #[inline]
388    fn mul_acc(acc: &mut Self::WideType, a: &'a Self, b: Self) {
389        F::mul_acc(acc, a.0, b.0);
390    }
391}
392
393impl<F: FieldExtension> DefaultDotProduct<&Self, Self> for FieldElement<F> {}
394
395// Dot product: FieldElement<F> x &FieldElement<F>
396impl<'a, F: FieldExtension> MulAccReduce<Self, &'a Self> for FieldElement<F> {
397    type WideType = <F as MulAccReduce>::WideType;
398
399    #[inline]
400    fn mul_acc(acc: &mut Self::WideType, a: Self, b: &'a Self) {
401        F::mul_acc(acc, a.0, b.0);
402    }
403}
404
405impl<F: FieldExtension> DefaultDotProduct<Self, &Self> for FieldElement<F> {}
406
407// Dot product: &FieldElement<F> x &FieldElement<F>
408impl<'a, 'b, F: FieldExtension> MulAccReduce<&'a Self, &'b Self> for FieldElement<F> {
409    type WideType = <F as MulAccReduce>::WideType;
410
411    #[inline]
412    fn mul_acc(acc: &mut Self::WideType, a: &'a Self, b: &'b Self) {
413        F::mul_acc(acc, a.0, b.0);
414    }
415}
416
417impl<F: FieldExtension> DefaultDotProduct<&Self, &Self> for FieldElement<F> {}
418
419// ----------------
420// | Zero and One |
421// ----------------
422
423impl<F: FieldExtension> Zero for FieldElement<F> {
424    /// The field's additive identity.
425    fn zero() -> Self {
426        FieldElement(F::ZERO)
427    }
428
429    fn is_zero(&self) -> bool {
430        self.0.is_zero().into()
431    }
432}
433
434impl<F: FieldExtension> One for FieldElement<F> {
435    /// The field's multiplicative identity.
436    fn one() -> Self {
437        FieldElement(F::ONE)
438    }
439}
440
441// ---------------
442// | Field trait |
443// ---------------
444
445impl<F: FieldExtension> Field for FieldElement<F> {
446    const ZERO: Self = FieldElement(F::ZERO);
447    const ONE: Self = FieldElement(F::ONE);
448
449    fn random(rng: impl RngCore) -> Self {
450        Self(<F as Field>::random(rng))
451    }
452
453    fn square(&self) -> Self {
454        FieldElement(self.0.square())
455    }
456
457    fn double(&self) -> Self {
458        FieldElement(self.0.double())
459    }
460
461    fn invert(&self) -> CtOption<Self> {
462        self.0.invert().map(FieldElement)
463    }
464
465    fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
466        let (choice, sqrt) = F::sqrt_ratio(&num.0, &div.0);
467        (choice, FieldElement(sqrt))
468    }
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::field::mersenne::Mersenne107;

    /// A small fixed element used by all tests.
    fn sample() -> FieldElement<Mersenne107> {
        FieldElement::<Mersenne107>::new(Mersenne107::from(42u64))
    }

    #[test]
    fn test_fieldelement_mersenne107_wincode() {
        let elem = sample();
        let bytes = wincode::serialize(&elem).unwrap();

        let decoded: FieldElement<Mersenne107> = wincode::deserialize(&bytes).unwrap();

        assert_eq!(elem, decoded);
    }

    #[test]
    fn test_fieldelement_le_bytes_roundtrip() {
        let elem = sample();
        let bytes = elem.to_le_bytes();

        let decoded = FieldElement::<Mersenne107>::from_le_bytes(&bytes[..]).unwrap();

        assert_eq!(elem, decoded);
    }

    #[test]
    fn test_fieldelement_be_bytes_roundtrip() {
        let elem = sample();
        let le = elem.to_le_bytes();
        let be = elem.to_be_bytes();

        // The big-endian encoding must be the byte-reversal of the little-endian one.
        assert!(le.iter().eq(be.iter().rev()));

        let decoded = FieldElement::<Mersenne107>::from_be_bytes(&be[..]).unwrap();
        assert_eq!(elem, decoded);
    }

    #[test]
    fn test_fieldelement_to_biguint() {
        assert_eq!(sample().to_biguint(), num_bigint::BigUint::from(42u64));
    }
}