Skip to main content

primitives/algebra/field/
field_element.rs

1use std::{
2    iter::{Product, Sum},
3    ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign},
4};
5
6use derive_more::derive::{AsMut, AsRef};
7use ff::Field;
8use hybrid_array::Array;
9use num_traits::{One, Zero};
10use rand::RngCore;
11use serde::{Deserialize, Serialize};
12use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
13use wincode::{SchemaRead, SchemaWrite};
14
15use crate::{
16    algebra::{
17        field::{ByteSize, FieldExtension, PrimeFieldExtension, SubfieldElement},
18        ops::{AccReduce, DefaultDotProduct, IntoWide, MulAccReduce, ReduceWide},
19        uniform_bytes::FromUniformBytes,
20    },
21    errors::PrimitiveError,
22    random::{CryptoRngCore, Random},
23    sharing::unauthenticated::AdditiveShares,
24};
25
26/// A field element wrapper.
27#[derive(
28    Copy, Clone, Debug, PartialOrd, PartialEq, Eq, Hash, AsRef, AsMut, SchemaRead, SchemaWrite,
29)]
30#[repr(transparent)]
31pub struct FieldElement<F: FieldExtension>(pub F);
32
33impl<F: FieldExtension> FieldElement<F> {
34    /// Construct a field element wrapper from an inner field element
35    #[inline]
36    pub fn new(inner: F) -> Self {
37        FieldElement(inner)
38    }
39
40    /// Get the inner value of the field element
41    #[inline]
42    pub fn inner(&self) -> F {
43        self.0
44    }
45
46    /// Compute the exponentiation of the given field element
47    #[inline]
48    pub fn pow(&self, exp: u64) -> Self {
49        FieldElement::new(self.0.pow([exp]))
50    }
51
52    /// Construct a field element from the given bytes
53    #[inline]
54    pub fn from_be_bytes(bytes: &[u8]) -> Result<FieldElement<F>, PrimitiveError> {
55        let mut bytes = bytes.to_vec();
56        bytes.reverse();
57        Ok(FieldElement(F::from_le_bytes(&bytes).ok_or_else(|| {
58            PrimitiveError::DeserializationFailed("Invalid field element encoding".to_string())
59        })?))
60    }
61
62    pub fn from_le_bytes(bytes: &[u8]) -> Result<FieldElement<F>, PrimitiveError> {
63        Ok(FieldElement(F::from_le_bytes(bytes).ok_or_else(|| {
64            PrimitiveError::DeserializationFailed("Invalid field element encoding".to_string())
65        })?))
66    }
67
68    /// Convert the field element to little-endian bytes
69    #[inline]
70    pub fn to_le_bytes(&self) -> Array<u8, ByteSize<F>> {
71        self.0.to_le_bytes()
72    }
73
74    /// Convert the field element to big-endian bytes
75    #[inline]
76    pub fn to_be_bytes(&self) -> Array<u8, ByteSize<F>> {
77        let mut rev = self.0.to_le_bytes();
78        rev.as_mut().reverse();
79        rev
80    }
81
82    pub fn to_biguint(&self) -> num_bigint::BigUint {
83        num_bigint::BigUint::from_bytes_le(self.to_le_bytes().as_ref())
84    }
85}
86
87impl<F: FieldExtension> Random for FieldElement<F> {
88    #[inline]
89    fn random(rng: impl CryptoRngCore) -> Self {
90        FieldElement(Random::random(rng))
91    }
92}
93
94impl<F: FieldExtension> Default for FieldElement<F> {
95    fn default() -> Self {
96        Self::zero()
97    }
98}
99
100impl<F: FieldExtension> Serialize for FieldElement<F> {
101    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
102        let bytes = self
103            .to_le_bytes()
104            .into_iter()
105            .collect::<Array<u8, F::FieldBytesSize>>();
106        serde_bytes::serialize(AsRef::<[u8]>::as_ref(&bytes), serializer)
107    }
108}
109
110impl<'de, F: FieldExtension> Deserialize<'de> for FieldElement<F> {
111    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
112        let bytes: &[u8] = serde_bytes::deserialize(deserializer)?;
113        let field_elem = FieldElement::from_le_bytes(bytes).map_err(|err| {
114            serde::de::Error::custom(format!("Failed to deserialize field element: {err:?}"))
115        })?;
116        Ok(field_elem)
117    }
118}
119
120// --------------
121// | Arithmetic |
122// --------------
123
124// === Addition === //
125
126#[macros::op_variants(owned, borrowed, flipped_commutative)]
127impl<F: FieldExtension> Add<&FieldElement<F>> for FieldElement<F> {
128    type Output = FieldElement<F>;
129
130    #[inline]
131    fn add(self, rhs: &FieldElement<F>) -> Self::Output {
132        FieldElement(self.0 + rhs.0)
133    }
134}
135
136#[macros::op_variants(owned)]
137impl<'a, F: FieldExtension> AddAssign<&'a FieldElement<F>> for FieldElement<F> {
138    #[inline]
139    fn add_assign(&mut self, rhs: &'a FieldElement<F>) {
140        *self = *self + rhs;
141    }
142}
143
144// === Subtraction === //
145
146#[macros::op_variants(owned, borrowed, flipped)]
147impl<F: FieldExtension> Sub<&FieldElement<F>> for FieldElement<F> {
148    type Output = FieldElement<F>;
149
150    #[inline]
151    fn sub(self, rhs: &FieldElement<F>) -> Self::Output {
152        FieldElement(self.0 - rhs.0)
153    }
154}
155
156#[macros::op_variants(owned)]
157impl<'a, F: FieldExtension> SubAssign<&'a FieldElement<F>> for FieldElement<F> {
158    #[inline]
159    fn sub_assign(&mut self, rhs: &'a FieldElement<F>) {
160        *self = *self - rhs;
161    }
162}
163
164// === Multiplication === //
165
166#[macros::op_variants(owned, borrowed, flipped_commutative)]
167impl<F: FieldExtension> Mul<&FieldElement<F>> for FieldElement<F> {
168    type Output = FieldElement<F>;
169
170    #[inline]
171    fn mul(self, rhs: &FieldElement<F>) -> Self::Output {
172        FieldElement(self.0 * rhs.0)
173    }
174}
175
176#[macros::op_variants(owned, borrowed, flipped)]
177impl<F: FieldExtension> Mul<&SubfieldElement<F>> for FieldElement<F> {
178    type Output = FieldElement<F>;
179
180    #[inline]
181    fn mul(self, rhs: &SubfieldElement<F>) -> Self::Output {
182        FieldElement(self.0 * rhs.0)
183    }
184}
185
186#[macros::op_variants(owned)]
187impl<'a, F: FieldExtension> MulAssign<&'a FieldElement<F>> for FieldElement<F> {
188    #[inline]
189    fn mul_assign(&mut self, rhs: &'a FieldElement<F>) {
190        *self = *self * rhs;
191    }
192}
193
194#[macros::op_variants(owned)]
195impl<'a, F: FieldExtension> MulAssign<&'a SubfieldElement<F>> for FieldElement<F> {
196    #[inline]
197    fn mul_assign(&mut self, rhs: &'a SubfieldElement<F>) {
198        *self = *self * rhs;
199    }
200}
201
202// === Negation === //
203
204#[macros::op_variants(borrowed)]
205impl<F: FieldExtension> Neg for FieldElement<F> {
206    type Output = FieldElement<F>;
207
208    #[inline]
209    fn neg(self) -> Self::Output {
210        FieldElement(-self.0)
211    }
212}
213
214// === Division === //
215
216#[macros::op_variants(owned, borrowed, flipped)]
217impl<F: FieldExtension> Div<&FieldElement<F>> for FieldElement<F> {
218    type Output = CtOption<FieldElement<F>>;
219
220    #[inline]
221    fn div(self, rhs: &FieldElement<F>) -> Self::Output {
222        rhs.0.invert().map(|inv| FieldElement(self.0 * inv))
223    }
224}
225
226// === Equality === //
227
228impl<F: FieldExtension> ConstantTimeEq for FieldElement<F> {
229    #[inline]
230    fn ct_eq(&self, other: &Self) -> Choice {
231        self.0.ct_eq(&other.0)
232    }
233}
234
235impl<F: FieldExtension> ConditionallySelectable for FieldElement<F> {
236    #[inline]
237    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
238        let selected = F::conditional_select(&a.0, &b.0, choice);
239        FieldElement(selected)
240    }
241}
242
243// === Other === //
244
245impl<F: FieldExtension> AdditiveShares for FieldElement<F> {}
246
247// ---------------
248// | Conversions |
249// ---------------
250
251impl<F: FieldExtension> From<bool> for FieldElement<F> {
252    #[inline]
253    fn from(value: bool) -> Self {
254        FieldElement(F::from(value as u64))
255    }
256}
257
258impl<F: FieldExtension> From<u8> for FieldElement<F> {
259    #[inline]
260    fn from(value: u8) -> Self {
261        FieldElement(F::from(value as u64))
262    }
263}
264
265impl<F: FieldExtension> From<u16> for FieldElement<F> {
266    #[inline]
267    fn from(value: u16) -> Self {
268        FieldElement(F::from(value as u64))
269    }
270}
271
272impl<F: FieldExtension> From<u32> for FieldElement<F> {
273    #[inline]
274    fn from(value: u32) -> Self {
275        FieldElement(F::from(value as u64))
276    }
277}
278
279impl<F: FieldExtension> From<u64> for FieldElement<F> {
280    #[inline]
281    fn from(value: u64) -> Self {
282        FieldElement(F::from(value))
283    }
284}
285
286impl<F: FieldExtension> From<u128> for FieldElement<F> {
287    #[inline]
288    fn from(value: u128) -> Self {
289        FieldElement(F::from(value))
290    }
291}
292
293impl<F: PrimeFieldExtension> From<SubfieldElement<F>> for FieldElement<F> {
294    #[inline]
295    fn from(value: SubfieldElement<F>) -> Self {
296        FieldElement(value.0)
297    }
298}
299
300// -------------------
301// | Iterator Traits |
302// -------------------
303impl<F: FieldExtension> Sum for FieldElement<F> {
304    #[inline]
305    fn sum<I: Iterator<Item = FieldElement<F>>>(iter: I) -> Self {
306        let tmp = iter.fold(<F as AccReduce>::zero_wide(), |mut acc, x| {
307            F::acc(&mut acc, x.0);
308            acc
309        });
310
311        FieldElement(F::reduce_mod_order(tmp))
312    }
313}
314
315impl<'a, F: FieldExtension> Sum<&'a FieldElement<F>> for FieldElement<F> {
316    #[inline]
317    fn sum<I: Iterator<Item = &'a FieldElement<F>>>(iter: I) -> Self {
318        let tmp = iter.fold(<F as AccReduce>::zero_wide(), |mut acc, x| {
319            F::acc(&mut acc, x.0);
320            acc
321        });
322
323        FieldElement(F::reduce_mod_order(tmp))
324    }
325}
326
327impl<F: FieldExtension> Product for FieldElement<F> {
328    #[inline]
329    fn product<I: Iterator<Item = FieldElement<F>>>(iter: I) -> Self {
330        iter.fold(FieldElement::one(), |acc, x| acc * x)
331    }
332}
333
334impl<'a, F: FieldExtension> Product<&'a FieldElement<F>> for FieldElement<F> {
335    #[inline]
336    fn product<I: Iterator<Item = &'a FieldElement<F>>>(iter: I) -> Self {
337        iter.fold(FieldElement::one(), |acc, x| acc * x)
338    }
339}
340
341impl<F: FieldExtension> FromUniformBytes for FieldElement<F> {
342    type UniformBytes = <F as FromUniformBytes>::UniformBytes;
343
344    fn from_uniform_bytes(bytes: &Array<u8, Self::UniformBytes>) -> Self {
345        Self(F::from_uniform_bytes(bytes))
346    }
347}
348
349// Dot product: FieldElement<F> x FieldElement<F>
350impl<F: FieldExtension> IntoWide<<F as MulAccReduce>::WideType> for FieldElement<F> {
351    #[inline]
352    fn to_wide(&self) -> <F as MulAccReduce>::WideType {
353        <F as MulAccReduce>::to_wide(&self.0)
354    }
355
356    #[inline]
357    fn zero_wide() -> <F as MulAccReduce>::WideType {
358        <F as MulAccReduce>::zero_wide()
359    }
360}
361
362impl<F: FieldExtension> ReduceWide<<F as MulAccReduce>::WideType> for FieldElement<F> {
363    #[inline]
364    fn reduce_mod_order(a: <F as MulAccReduce>::WideType) -> Self {
365        Self(F::reduce_mod_order(a))
366    }
367}
368
369impl<F: FieldExtension> MulAccReduce for FieldElement<F> {
370    type WideType = <F as MulAccReduce>::WideType;
371
372    #[inline]
373    fn mul_acc(acc: &mut Self::WideType, a: Self, b: Self) {
374        F::mul_acc(acc, a.0, b.0);
375    }
376}
377
378impl<F: FieldExtension> DefaultDotProduct for FieldElement<F> {}
379
380// Dot product: &FieldElement<F> x FieldElement<F>
381impl<'a, F: FieldExtension> MulAccReduce<&'a Self, Self> for FieldElement<F> {
382    type WideType = <F as MulAccReduce>::WideType;
383
384    #[inline]
385    fn mul_acc(acc: &mut Self::WideType, a: &'a Self, b: Self) {
386        F::mul_acc(acc, a.0, b.0);
387    }
388}
389
390impl<F: FieldExtension> DefaultDotProduct<&Self, Self> for FieldElement<F> {}
391
392// Dot product: FieldElement<F> x &FieldElement<F>
393impl<'a, F: FieldExtension> MulAccReduce<Self, &'a Self> for FieldElement<F> {
394    type WideType = <F as MulAccReduce>::WideType;
395
396    #[inline]
397    fn mul_acc(acc: &mut Self::WideType, a: Self, b: &'a Self) {
398        F::mul_acc(acc, a.0, b.0);
399    }
400}
401
402impl<F: FieldExtension> DefaultDotProduct<Self, &Self> for FieldElement<F> {}
403
404// Dot product: &FieldElement<F> x &FieldElement<F>
405impl<'a, 'b, F: FieldExtension> MulAccReduce<&'a Self, &'b Self> for FieldElement<F> {
406    type WideType = <F as MulAccReduce>::WideType;
407
408    #[inline]
409    fn mul_acc(acc: &mut Self::WideType, a: &'a Self, b: &'b Self) {
410        F::mul_acc(acc, a.0, b.0);
411    }
412}
413
414impl<F: FieldExtension> DefaultDotProduct<&Self, &Self> for FieldElement<F> {}
415
416// ----------------
417// | Zero and One |
418// ----------------
419
420impl<F: FieldExtension> Zero for FieldElement<F> {
421    /// The field's additive identity.
422    fn zero() -> Self {
423        FieldElement(F::ZERO)
424    }
425
426    fn is_zero(&self) -> bool {
427        self.0.is_zero().into()
428    }
429}
430
431impl<F: FieldExtension> One for FieldElement<F> {
432    /// The field's multiplicative identity.
433    fn one() -> Self {
434        FieldElement(F::ONE)
435    }
436}
437
438// ---------------
439// | Field trait |
440// ---------------
441
442impl<F: FieldExtension> Field for FieldElement<F> {
443    const ZERO: Self = FieldElement(F::ZERO);
444    const ONE: Self = FieldElement(F::ONE);
445
446    fn random(rng: impl RngCore) -> Self {
447        Self(<F as Field>::random(rng))
448    }
449
450    fn square(&self) -> Self {
451        FieldElement(self.0.square())
452    }
453
454    fn double(&self) -> Self {
455        FieldElement(self.0.double())
456    }
457
458    fn invert(&self) -> CtOption<Self> {
459        self.0.invert().map(FieldElement)
460    }
461
462    fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
463        let (choice, sqrt) = F::sqrt_ratio(&num.0, &div.0);
464        (choice, FieldElement(sqrt))
465    }
466}
467
#[cfg(test)]
mod tests {
    use super::*;
    use crate::algebra::field::mersenne::Mersenne107;

    /// wincode encoding followed by decoding must reproduce the original element.
    #[test]
    fn test_fieldelement_mersenne107_wincode() {
        let elem = FieldElement::<Mersenne107>::new(Mersenne107::from(42u64));
        let bytes = wincode::serialize(&elem).unwrap();

        let decoded: FieldElement<Mersenne107> = wincode::deserialize(&bytes).unwrap();

        assert_eq!(elem, decoded);
    }

    /// The hand-written endianness helpers must round-trip, and the big-endian
    /// encoding must be exactly the little-endian one reversed.
    #[test]
    fn test_fieldelement_mersenne107_byte_roundtrip() {
        let elem = FieldElement::<Mersenne107>::new(Mersenne107::from(42u64));

        let le = elem.to_le_bytes();
        let be = elem.to_be_bytes();
        let le_slice: &[u8] = le.as_ref();
        let be_slice: &[u8] = be.as_ref();

        assert!(le_slice.iter().rev().eq(be_slice.iter()));

        let from_le = FieldElement::<Mersenne107>::from_le_bytes(le_slice).unwrap();
        assert_eq!(elem, from_le);

        let from_be = FieldElement::<Mersenne107>::from_be_bytes(be_slice).unwrap();
        assert_eq!(elem, from_be);
    }
}