Skip to main content

commonware_math/fields/
goldilocks.rs

1use crate::algebra::{Additive, Field, FieldNTT, Multiplicative, Object, Random, Ring};
2use commonware_codec::{FixedSize, Read, Write};
3use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
4use rand_core::CryptoRngCore;
5
/// The Goldilocks prime `P = 2^64 - 2^32 + 1`.
///
/// `P` is prime, so the integers modulo `P` form a field with exactly `P` elements.
/// In hexadecimal it is `0xFFFF_FFFF_0000_0001`.
const P: u64 = 0xFFFF_FFFF_0000_0001;
10
/// An element of the [Goldilocks field](https://xn--2-umb.com/22/goldilocks/).
///
/// The wrapped `u64` is expected to be canonical, i.e. strictly less than the
/// modulus; the arithmetic in this module is written assuming reduced operands,
/// and the constructors reduce or reject out-of-range values.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct F(u64);
14
15impl FixedSize for F {
16    const SIZE: usize = u64::SIZE;
17}
18
19impl Write for F {
20    fn write(&self, buf: &mut impl bytes::BufMut) {
21        self.0.write(buf)
22    }
23}
24
25impl Read for F {
26    type Cfg = <u64 as Read>::Cfg;
27
28    fn read_cfg(
29        buf: &mut impl bytes::Buf,
30        cfg: &Self::Cfg,
31    ) -> Result<Self, commonware_codec::Error> {
32        let x = u64::read_cfg(buf, cfg)?;
33        if x >= P {
34            return Err(commonware_codec::Error::Invalid("F", "out of range"));
35        }
36        Ok(Self(x))
37    }
38}
39
40impl core::fmt::Debug for F {
41    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
42        write!(f, "{:016X}", self.0)
43    }
44}
45
#[cfg(any(test, feature = "arbitrary"))]
impl arbitrary::Arbitrary<'_> for F {
    /// Draws an arbitrary `u64` and folds it into the field by reducing modulo `P`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        u.arbitrary::<u64>().map(Self::reduce_64)
    }
}
53
54impl F {
55    // The following constants are not randomly chosen, but computed in a specific
56    // way. They could be computed at compile time, with each definition actually
57    // doing the computation, but to avoid burdening compilation, we instead enforce
58    // where they originate from with tests.
59
60    /// Any non-zero element x = GENERATOR^k, for some k.
61    ///
62    /// This is chosen such that GENERATOR^((P - 1) / 64) = 8.
63    #[cfg(test)]
64    pub const GENERATOR: Self = Self(0xd64f951101aff9bf);
65
66    /// An element of order 2^32.
67    ///
68    /// This is specifically chosen such that ROOT_OF_UNITY^(2^26) = 8.
69    ///
70    /// That enables optimizations when doing NTTs, and things like that.
71    pub const ROOT_OF_UNITY: Self = Self(0xee41f5320c4ea145);
72
73    /// An element guaranteed not to be any power of [Self::ROOT_OF_UNITY].
74    pub const NOT_ROOT_OF_UNITY: Self = Self(0x79bc2f50acd74161);
75
76    /// The inverse of [Self::NOT_ROOT_OF_UNITY].
77    pub const NOT_ROOT_OF_UNITY_INV: Self = Self(0x1036c4023580ce8d);
78
79    /// The zero element of the field.
80    ///
81    /// This is the identity for addition.
82    const ZERO: Self = Self(0);
83
84    /// The one element of the field.
85    ///
86    /// This is the identity for multiplication.
87    const ONE: Self = Self(1);
88
89    const fn add_inner(self, b: Self) -> Self {
90        // We want to calculate self + b mod P.
91        // At a high level, this can be done by adding self + b, as integers,
92        // and then subtracting P as long as the result >= P.
93        //
94        // How many times do we need to do this?
95        //
96        // self <= P - 1
97        // b <= P - 1
98        // ∴ self + b <= 2P - 2
99        // ∴ self + b - P <= P - 1
100        //
101        // So, we need to subtract P at most once.
102
103        // addition + 2^64 * overflow = self + b
104        let (addition, overflow) = self.0.overflowing_add(b.0);
105        // In the case of overflow = 1, addition + 2^64 > P, so we need to
106        // subtract. The result of this subtraction will be < 2^64,
107        // so we can compute it by calculating addition - P, wrapping around.
108        let (subtraction, underflow) = addition.overflowing_sub(P);
109        // In the case of overflow, we use the subtraction (as mentioned above).
110        // Otherwise, use the subtraction as long as we didn't underflow
111        if overflow || !underflow {
112            Self(subtraction)
113        } else {
114            Self(addition)
115        }
116    }
117
118    const fn sub_inner(self, b: Self) -> Self {
119        // The strategy here is to perform the subtraction, and then (maybe) add back P.
120        // If no underflow happened, the result is reduced, since both values were < P.
121        // If an underflow happened, the largest result we can have is -1. Adding
122        // P gives us P - 1, which is < P, so everything works.
123        let (subtraction, underflow) = self.0.overflowing_sub(b.0);
124        if underflow {
125            Self(subtraction.wrapping_add(P))
126        } else {
127            Self(subtraction)
128        }
129    }
130
131    const fn reduce_64(x: u64) -> Self {
132        // 2 * P > 2^64 - 1 (by a long margin)
133        // We thus need to subtract P at most once.
134        let (subtraction, underflow) = x.overflowing_sub(P);
135        if underflow {
136            Self(x)
137        } else {
138            Self(subtraction)
139        }
140    }
141
142    /// Reduce a 128 bit integer into a field element.
143    const fn reduce_128(x: u128) -> Self {
144        // We exploit special properties of the field.
145        //
146        // First, 2^64 = 2^32 - 1 mod P.
147        //
148        // Second, 2^96 = 2^32(2^32 - 1) = 2^64 - 2^32 = -1 mod P.
149        //
150        // Thus, if we write a 128 bit integer x as:
151        //     x = c 2^96 + b 2^64 + a
152        // We have:
153        //     x = b (2^32 - 1) + (a - c) mod P
154        // And this expression will be our strategy for performing the reduction.
155        let a = x as u64;
156        let b = ((x >> 64) & 0xFF_FF_FF_FF) as u64;
157        let c = (x >> 96) as u64;
158
159        // While we lean on existing code, we need to be careful because some of
160        // these types are partially reduced.
161        //
162        // First, if we look at a - c, the end result with our field code can
163        // be any 64 bit value (consider c = 0). We can also make the same assumption
164        // for (b << 32) - b. The question then becomes, is Field(x) + Field(y)
165        // ok even if both x and y are arbitrary u64 values?
166        //
167        // Yes. Even if x and y have the maximum value, a single subtraction of P
168        // would suffice to make their sum < P. Thus, our strategy for field addition
169        // will always work.
170        //
171        // Note: (b << 32) - b = b * (2^32 - 1). Since b <= 2^32 - 1, this is at most
172        // (2^32 - 1)^2 = 2^64 - 2^33 + 1 < 2^64. Since b << 32 >= b always,
173        // this subtraction will never underflow.
174        Self(a).sub_inner(Self(c)).add_inner(Self((b << 32) - b))
175    }
176
177    const fn mul_inner(self, b: Self) -> Self {
178        // We do a u64 x u64 -> u128 multiplication, then reduce mod P
179        Self::reduce_128((self.0 as u128) * (b.0 as u128))
180    }
181
182    const fn neg_inner(self) -> Self {
183        Self::ZERO.sub_inner(self)
184    }
185
186    /// Return the multiplicative inverse of a field element.
187    ///
188    /// [Self::zero] will return [Self::zero].
189    pub fn inv(self) -> Self {
190        self.exp(&[P - 2])
191    }
192
193    /// Convert a stream of u64s into a stream of field elements.
194    pub fn stream_from_u64s(inner: impl Iterator<Item = u64>) -> impl Iterator<Item = Self> {
195        struct Iter<I> {
196            acc: u128,
197            acc_bits: u32,
198            inner: I,
199        }
200
201        impl<I: Iterator<Item = u64>> Iterator for Iter<I> {
202            type Item = F;
203
204            fn next(&mut self) -> Option<Self::Item> {
205                while self.acc_bits < 63 {
206                    let Some(x) = self.inner.next() else {
207                        break;
208                    };
209                    let x = u128::from(x);
210                    self.acc |= x << self.acc_bits;
211                    self.acc_bits += 64;
212                }
213                if self.acc_bits > 0 {
214                    self.acc_bits = self.acc_bits.saturating_sub(63);
215                    let out = F((self.acc as u64) & ((1 << 63) - 1));
216                    self.acc >>= 63;
217                    return Some(out);
218                }
219                None
220            }
221        }
222
223        Iter {
224            acc: 0,
225            acc_bits: 0,
226            inner,
227        }
228    }
229
230    /// Convert a stream produced by [F::stream_from_u64s] back to the original stream.
231    ///
232    /// This may produce a single extra 0 element.
233    pub fn stream_to_u64s(inner: impl Iterator<Item = Self>) -> impl Iterator<Item = u64> {
234        struct Iter<I> {
235            acc: u128,
236            acc_bits: u32,
237            inner: I,
238        }
239
240        impl<I: Iterator<Item = F>> Iterator for Iter<I> {
241            type Item = u64;
242
243            fn next(&mut self) -> Option<Self::Item> {
244                // Try and fill acc with 64 bits of data.
245                while self.acc_bits < 64 {
246                    let Some(F(x)) = self.inner.next() else {
247                        break;
248                    };
249                    // Ignore any upper bits of x
250                    let x = u128::from(x & ((1 << 63) - 1));
251                    self.acc |= x << self.acc_bits;
252                    self.acc_bits += 63;
253                }
254                if self.acc_bits > 0 {
255                    self.acc_bits = self.acc_bits.saturating_sub(64);
256                    let out = self.acc as u64;
257                    self.acc >>= 64;
258                    return Some(out);
259                }
260                None
261            }
262        }
263        Iter {
264            acc: 0,
265            acc_bits: 0,
266            inner,
267        }
268    }
269
270    /// How many elements are used to encode a given number of bits?
271    ///
272    /// This is based on what [F::stream_from_u64s] does.
273    pub const fn bits_to_elements(bits: usize) -> usize {
274        bits.div_ceil(63)
275    }
276
277    /// Convert this element to little-endian bytes.
278    pub const fn to_le_bytes(&self) -> [u8; 8] {
279        self.0.to_le_bytes()
280    }
281}
282
283impl Object for F {}
284
285impl Random for F {
286    fn random(mut rng: impl CryptoRngCore) -> Self {
287        // this fails only about once every 2^32 attempts
288        loop {
289            let x = rng.next_u64();
290            if x < P {
291                return Self(x);
292            }
293        }
294    }
295}
296
297impl Add for F {
298    type Output = Self;
299
300    fn add(self, b: Self) -> Self::Output {
301        self.add_inner(b)
302    }
303}
304
305impl<'a> Add<&'a Self> for F {
306    type Output = Self;
307
308    fn add(self, rhs: &'a Self) -> Self::Output {
309        self + *rhs
310    }
311}
312
313impl<'a> AddAssign<&'a Self> for F {
314    fn add_assign(&mut self, rhs: &'a Self) {
315        *self = *self + rhs
316    }
317}
318
319impl<'a> Sub<&'a Self> for F {
320    type Output = Self;
321
322    fn sub(self, rhs: &'a Self) -> Self::Output {
323        self - *rhs
324    }
325}
326
327impl<'a> SubAssign<&'a Self> for F {
328    fn sub_assign(&mut self, rhs: &'a Self) {
329        *self = *self - rhs;
330    }
331}
332
333impl Additive for F {
334    fn zero() -> Self {
335        Self::ZERO
336    }
337}
338
339impl Sub for F {
340    type Output = Self;
341
342    fn sub(self, b: Self) -> Self::Output {
343        self.sub_inner(b)
344    }
345}
346
347impl Mul for F {
348    type Output = Self;
349
350    fn mul(self, b: Self) -> Self::Output {
351        Self::mul_inner(self, b)
352    }
353}
354
355impl<'a> Mul<&'a Self> for F {
356    type Output = Self;
357
358    fn mul(self, rhs: &'a Self) -> Self::Output {
359        self * *rhs
360    }
361}
362
363impl<'a> MulAssign<&'a Self> for F {
364    fn mul_assign(&mut self, rhs: &'a Self) {
365        *self = *self * rhs;
366    }
367}
368
369impl Multiplicative for F {}
370
371impl Neg for F {
372    type Output = Self;
373
374    fn neg(self) -> Self::Output {
375        self.neg_inner()
376    }
377}
378
379impl From<u64> for F {
380    fn from(value: u64) -> Self {
381        Self::reduce_64(value)
382    }
383}
384
385impl Ring for F {
386    fn one() -> Self {
387        Self::ONE
388    }
389}
390
391impl Field for F {
392    fn inv(&self) -> Self {
393        Self::inv(*self)
394    }
395}
396
397impl FieldNTT for F {
398    const MAX_LG_ROOT_ORDER: u8 = 32;
399
400    fn root_of_unity(lg: u8) -> Option<Self> {
401        if lg > Self::MAX_LG_ROOT_ORDER {
402            return None;
403        }
404        let mut out = Self::ROOT_OF_UNITY;
405        for _ in 0..(Self::MAX_LG_ROOT_ORDER - lg) {
406            out = out * out;
407        }
408        Some(out)
409    }
410
411    fn coset_shift() -> Self {
412        Self::NOT_ROOT_OF_UNITY
413    }
414
415    fn coset_shift_inv() -> Self {
416        Self::NOT_ROOT_OF_UNITY_INV
417    }
418
419    fn div_2(&self) -> Self {
420        if self.0 & 1 == 0 {
421            Self(self.0 >> 1)
422        } else {
423            let (addition, carry) = self.0.overflowing_add(P);
424            Self((u64::from(carry) << 63) | (addition >> 1))
425        }
426    }
427}
428
#[cfg(any(test, feature = "fuzz"))]
pub mod fuzz {
    use super::*;
    use crate::algebra::test_suites;
    use arbitrary::{Arbitrary, Unstructured};
    use commonware_codec::{Encode as _, ReadExt as _};

    /// A `u64` that is not a canonical field element (drawn from `P..=u64::MAX`).
    #[derive(Debug)]
    pub struct NonCanonicalU64(pub u64);

    impl Arbitrary<'_> for NonCanonicalU64 {
        fn arbitrary(u: &mut Unstructured<'_>) -> arbitrary::Result<Self> {
            Ok(Self(u.int_in_range(P..=u64::MAX)?))
        }
    }

    /// One fuzzing scenario for the Goldilocks field.
    #[derive(Debug, Arbitrary)]
    pub enum Plan {
        /// `u64`s -> field elements -> `u64`s must reproduce the input.
        StreamRoundtrip(Vec<u64>),
        /// Decoding a non-canonical value must be rejected.
        ReadRejectsOutOfRange(NonCanonicalU64),
        /// Run the generic field / NTT property suite.
        FuzzField,
    }

    impl Plan {
        /// Executes the plan, panicking if any checked invariant is violated.
        pub fn run(self, u: &mut Unstructured<'_>) -> arbitrary::Result<()> {
            match self {
                Self::StreamRoundtrip(data) => {
                    let elements: Vec<F> = F::stream_from_u64s(data.iter().copied()).collect();
                    // The element count must match the documented sizing helper.
                    assert_eq!(elements.len(), F::bits_to_elements(64 * data.len()));

                    let mut roundtrip: Vec<u64> =
                        F::stream_to_u64s(elements.into_iter()).collect();
                    // Decoding may append at most one trailing word, and that word
                    // can only consist of the zero padding added by the encoder.
                    assert!(roundtrip.len() <= data.len() + 1);
                    assert!(roundtrip.iter().skip(data.len()).all(|&word| word == 0));
                    roundtrip.truncate(data.len());
                    assert_eq!(data, roundtrip);
                }
                Self::ReadRejectsOutOfRange(NonCanonicalU64(x)) => {
                    let result = F::read(&mut x.encode());
                    assert!(matches!(
                        result,
                        Err(commonware_codec::Error::Invalid("F", "out of range"))
                    ));
                }
                Self::FuzzField => {
                    test_suites::fuzz_field_ntt::<F>(u)?;
                }
            }
            Ok(())
        }
    }

    #[test]
    fn test_fuzz() {
        commonware_invariants::minifuzz::test(|u| u.arbitrary::<Plan>()?.run(u));
    }

    #[test]
    fn test_read_cfg_rejects_modulus_regression_case() {
        let mut u = Unstructured::new(&[]);
        Plan::ReadRejectsOutOfRange(NonCanonicalU64(P))
            .run(&mut u)
            .expect("regression plan should succeed");
    }
}
490
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_generator_calculation() {
        assert_eq!(F::GENERATOR, F(7).exp(&[133]));
    }

    #[test]
    fn test_root_of_unity_calculation() {
        assert_eq!(F::ROOT_OF_UNITY, F::GENERATOR.exp(&[(P - 1) >> 32]));
    }

    #[test]
    fn test_not_root_of_unity_calculation() {
        assert_eq!(F::NOT_ROOT_OF_UNITY, F::GENERATOR.exp(&[1 << 32]));
    }

    #[test]
    fn test_not_root_of_unity_inv_calculation() {
        assert_eq!(F::NOT_ROOT_OF_UNITY * F::NOT_ROOT_OF_UNITY_INV, F::one());
    }

    #[test]
    fn test_root_of_unity_exp() {
        assert_eq!(F::ROOT_OF_UNITY.exp(&[1 << 26]), F(8));
    }

    /// `reduce_128` must agree with naive `%` on edge values. This includes
    /// `u128::MAX` and the largest product of two reduced elements, where the
    /// final addition in the reduction is at its tightest.
    #[test]
    fn test_reduce_128_matches_naive_modulo() {
        let p = u128::from(P);
        let cases: [u128; 10] = [
            0,
            1,
            p - 1,
            p,
            p + 1,
            1u128 << 64,
            1u128 << 96,
            (p - 1) * (p - 1),
            u128::from(u64::MAX) * u128::from(u64::MAX),
            u128::MAX,
        ];
        for x in cases {
            assert_eq!(u128::from(F::reduce_128(x).0), x % p, "x = {x:#x}");
        }
    }

    /// Halving then doubling must give back the original element, for both parities.
    #[test]
    fn test_div_2_roundtrip() {
        for x in [0, 1, 2, P - 2, P - 1] {
            let x = F(x);
            let half = x.div_2();
            assert_eq!(half + half, x);
        }
    }

    /// Boundary behavior of `root_of_unity`.
    #[test]
    fn test_root_of_unity_orders() {
        assert_eq!(F::root_of_unity(F::MAX_LG_ROOT_ORDER + 1), None);
        assert_eq!(F::root_of_unity(0), Some(F::one()));
        assert_eq!(F::root_of_unity(1), Some(-F::one()));
        assert_eq!(F::root_of_unity(32), Some(F::ROOT_OF_UNITY));
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<F>
        }
    }
}