prime_field/
lib.rs

1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
2#![doc = include_str!("../README.md")]
3#![no_std]
4
5pub use subtle;
6pub use zeroize;
7pub use rand_core;
8pub use crypto_bigint;
9pub use ff;
10
11#[doc(hidden)]
12pub mod __prime_field_private {
13  pub use paste;
14  pub use ff_group_tests;
15
16  use crypto_bigint::{Word, Uint, modular::ConstMontyParams};
17
18  /// Remove the "0x"-prefix from a hex string.
19  ///
20  /// May panic if the string isn't valid hex.
21  pub const fn hex_str_without_prefix(hex: &str) -> &str {
22    if hex.len() < 2 {
23      return hex;
24    }
25
26    if hex.as_bytes()[1] == b'x' {
27      assert!(hex.as_bytes()[0] == b'0', "invalid hex string for modulus");
28      hex.split_at(2).1
29    } else {
30      hex
31    }
32  }
33
34  pub const fn uint_to_u64_words<const LIMBS: usize, const WORDS: usize>(
35    value: Uint<LIMBS>,
36  ) -> [u64; WORDS] {
37    let mut res = [0u64; WORDS];
38    let mut i = 0;
39    while i < Uint::<LIMBS>::LIMBS {
40      let word: Word = value.as_limbs()[i].0;
41      let bits = i * (Word::BITS as usize);
42      let j = bits / (u64::BITS as usize);
43      res[j] |= word << (bits % (u64::BITS as usize));
44      if (j + 1) < WORDS {
45        if let Some(remaining_bits) =
46          ((bits % (u64::BITS as usize)) + (Word::BITS as usize)).checked_sub(u64::BITS as usize)
47        {
48          if remaining_bits != 0 {
49            res[j + 1] |= word >> ((Word::BITS as usize) - remaining_bits);
50          }
51        }
52      }
53      i += 1;
54    }
55    res
56  }
57
58  pub const fn u64_words_to_uint<const LIMBS: usize, const WORDS: usize>(
59    words: [u64; WORDS],
60  ) -> Uint<LIMBS> {
61    let mut reconstruction = Uint::<LIMBS>::ZERO;
62    let mut i = 0;
63    while i < WORDS {
64      reconstruction = reconstruction
65        .bitor(&Uint::<LIMBS>::from_u64(words[i]).shl_vartime((i * (u64::BITS as usize)) as u32));
66      i += 1;
67    }
68    reconstruction
69  }
70
71  #[allow(non_snake_case)]
72  pub const fn calculate_S<const LIMBS: usize, P: ConstMontyParams<LIMBS>>() -> u32 {
73    let mut i = 0;
74    loop {
75      let bit = P::MODULUS.as_ref().wrapping_sub(&Uint::<LIMBS>::ONE).bit_vartime(i);
76      if !bit {
77        i += 1;
78        continue;
79      }
80      break;
81    }
82    i
83  }
84}
85
/// Define an odd prime field and implement the `ff` traits for it.
///
/// The implemented traits are `Field`, `PrimeField`, `PrimeFieldBits` and `FromUniformBytes`.
///
/// Arguments:
/// - `$name`: the identifier of the generated field type.
/// - `$modulus_as_be_hex`: the modulus as a big-endian hex string, optionally `0x`-prefixed. An
///   odd number of hex digits is accepted.
/// - `$multiplicative_generator_as_be_hex`: the multiplicative generator as a big-endian hex
///   string, optionally `0x`-prefixed. It may not have more digits than the modulus once padded
///   to the underlying `Uint`'s width.
/// - `$big_endian`: a `bool` literal. `true` makes `PrimeField::Repr` big-endian; `false` makes
///   it little-endian.
///
/// This macro does not itself verify that the modulus is prime, nor that the generator generates
/// the multiplicative group.
#[macro_export]
macro_rules! odd_prime_field {
  (
    $name: ident,
    $modulus_as_be_hex: expr,
    $multiplicative_generator_as_be_hex: expr,
    $big_endian: literal
  ) => {
    // `$crate` (rather than a hard-coded `prime_field`) keeps the expansion working when this
    // crate is renamed in the caller's `Cargo.toml` or reached through a re-export
    $crate::__prime_field_private::paste::paste! {
      mod [<$name __prime_field_private>] {
        use core::{
          ops::*,
          iter::{Sum, Product},
        };
        use $crate::{
          subtle::{
            Choice, CtOption, ConstantTimeEq, ConditionallySelectable, ConditionallyNegatable,
          },
          zeroize::Zeroize,
          rand_core::RngCore,
          crypto_bigint::{
            Limb, Encoding, Integer, Uint,
            modular::{ConstMontyParams, ConstMontyForm},
            impl_modulus,
          },
          ff::*,
          __prime_field_private::*,
        };

        const MODULUS_WITHOUT_PREFIX: &str = hex_str_without_prefix($modulus_as_be_hex);
        const MULTIPLICATIVE_GENERATOR_WITHOUT_PREFIX: &str =
          hex_str_without_prefix($multiplicative_generator_as_be_hex);

        // The number of bytes needed to hold the modulus. This rounds up so that a modulus written
        // with an odd number of hex digits doesn't lose its most-significant nibble.
        const MODULUS_BYTES: usize = MODULUS_WITHOUT_PREFIX.len().div_ceil(2);
        type UnderlyingUint = Uint<{ MODULUS_BYTES.div_ceil(Limb::BYTES) }>;

        // Left-pad the modulus' hex with '0' so it spans the underlying `Uint`'s full width, as
        // `impl_modulus!`/`from_be_hex` expect a hex string of exactly that width
        const PADDED_MODULUS_WITHOUT_PREFIX_BYTES: [u8; 2 * UnderlyingUint::BYTES] = {
          let mut res = [b'0'; 2 * UnderlyingUint::BYTES];
          let start = (2 * UnderlyingUint::BYTES) - MODULUS_WITHOUT_PREFIX.len();
          let mut i = start;
          while i < (2 * UnderlyingUint::BYTES) {
            res[i] = MODULUS_WITHOUT_PREFIX.as_bytes()[i - start];
            i += 1;
          }
          res
        };
        const PADDED_MODULUS_WITHOUT_PREFIX: &str = {
          match core::str::from_utf8(&PADDED_MODULUS_WITHOUT_PREFIX_BYTES) {
            Ok(res) => res,
            Err(_) => panic!("couldn't successfully pad modulus"),
          }
        };

        // The same left-padding, for the multiplicative generator
        const PADDED_MULTIPLICATIVE_GENERATOR_WITHOUT_PREFIX_BYTES:
          [u8; 2 * UnderlyingUint::BYTES] = {
          let mut res = [b'0'; 2 * UnderlyingUint::BYTES];
          let start = (2 * UnderlyingUint::BYTES) - MULTIPLICATIVE_GENERATOR_WITHOUT_PREFIX.len();
          let mut i = start;
          while i < (2 * UnderlyingUint::BYTES) {
            res[i] = MULTIPLICATIVE_GENERATOR_WITHOUT_PREFIX.as_bytes()[i - start];
            i += 1;
          }
          res
        };
        const PADDED_MULTIPLICATIVE_GENERATOR_WITHOUT_PREFIX: &str = {
          match core::str::from_utf8(&PADDED_MULTIPLICATIVE_GENERATOR_WITHOUT_PREFIX_BYTES) {
            Ok(res) => res,
            Err(_) => panic!("couldn't successfully pad multiplicative generator"),
          }
        };

        impl_modulus!(Params, UnderlyingUint, PADDED_MODULUS_WITHOUT_PREFIX);
        // Field elements are stored in Montgomery form
        type Underlying = ConstMontyForm<Params, { UnderlyingUint::LIMBS }>;

        const MODULUS: &UnderlyingUint = Params::MODULUS.as_ref();
        const MODULUS_MINUS_ONE: UnderlyingUint = MODULUS.wrapping_sub(&UnderlyingUint::ONE);
        // `x^(p - 2)` is `x^-1` for non-zero `x` (Fermat's little theorem)
        const MODULUS_MINUS_TWO: UnderlyingUint = MODULUS.wrapping_sub(&UnderlyingUint::from_u8(2));
        // The odd factor `T` of `p - 1 = T * 2^S`
        const T: UnderlyingUint = MODULUS_MINUS_ONE.shr_vartime($name::S);

        /// An element of a prime field generated by `prime_field::odd_prime_field!`.
        #[derive(Clone, Copy, Eq, Debug)]
        pub struct $name(Underlying);

        impl Default for $name {
          fn default() -> Self {
            Self::ZERO
          }
        }

        impl $name {
          /// Create a field element from the `Uint` type underlying it.
          ///
          /// `value` is taken modulo the field's modulus, so it needn't be canonical.
          pub const fn from(value: &UnderlyingUint) -> Self {
            $name(Underlying::new(value))
          }
        }
        impl From<u8> for $name {
          fn from(value: u8) -> Self {
            Self::from(&UnderlyingUint::from(value))
          }
        }
        impl From<u16> for $name {
          fn from(value: u16) -> Self {
            Self::from(&UnderlyingUint::from(value))
          }
        }
        impl From<u32> for $name {
          fn from(value: u32) -> Self {
            Self::from(&UnderlyingUint::from(value))
          }
        }
        impl From<u64> for $name {
          fn from(value: u64) -> Self {
            Self::from(&UnderlyingUint::from(value))
          }
        }

        impl ConstantTimeEq for $name {
          fn ct_eq(&self, other: &Self) -> Choice {
            self.0.ct_eq(&other.0)
          }
        }
        // Equality is defined via the constant-time comparison
        impl PartialEq for $name {
          fn eq(&self, other: &Self) -> bool {
            bool::from(self.ct_eq(other))
          }
        }

        impl ConditionallySelectable for $name {
          fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
            Self(<_>::conditional_select(&a.0, &b.0, choice))
          }
        }
        impl ConditionallyNegatable for $name {
          fn conditional_negate(&mut self, negate: Choice) {
            self.0.conditional_negate(negate)
          }
        }

        impl Zeroize for $name {
          fn zeroize(&mut self) {
            self.0.zeroize();
          }
        }

        impl Neg for $name {
          type Output = Self;
          fn neg(self) -> Self {
            Self(-self.0)
          }
        }

        impl Add for $name {
          type Output = Self;
          fn add(self, other: Self) -> Self {
            Self(self.0 + other.0)
          }
        }
        impl Sub for $name {
          type Output = Self;
          fn sub(self, other: Self) -> Self {
            Self(self.0 - other.0)
          }
        }
        impl Mul for $name {
          type Output = Self;
          fn mul(self, other: Self) -> Self {
            Self(self.0 * other.0)
          }
        }
        impl AddAssign for $name {
          fn add_assign(&mut self, other: Self) {
            self.0 += other.0;
          }
        }
        impl SubAssign for $name {
          fn sub_assign(&mut self, other: Self) {
            self.0 -= other.0;
          }
        }
        impl MulAssign for $name {
          fn mul_assign(&mut self, other: Self) {
            self.0 *= other.0;
          }
        }
        impl Sum for $name {
          fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
            let mut res = Self::ZERO;
            for item in iter {
              res += item;
            }
            res
          }
        }
        impl Product for $name {
          fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
            let mut res = Self::ONE;
            for item in iter {
              res *= item;
            }
            res
          }
        }
        impl<'a> Add<&'a Self> for $name {
          type Output = Self;
          fn add(self, other: &'a Self) -> Self {
            Self(self.0 + other.0)
          }
        }
        impl<'a> Sub<&'a Self> for $name {
          type Output = Self;
          fn sub(self, other: &'a Self) -> Self {
            Self(self.0 - other.0)
          }
        }
        impl<'a> Mul<&'a Self> for $name {
          type Output = Self;
          fn mul(self, other: &'a Self) -> Self {
            Self(self.0 * other.0)
          }
        }
        impl<'a> Sum<&'a Self> for $name {
          fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
            let mut res = Self::ZERO;
            for item in iter {
              res += item;
            }
            res
          }
        }
        impl<'a> Product<&'a Self> for $name {
          fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
            let mut res = Self::ONE;
            for item in iter {
              res *= item;
            }
            res
          }
        }
        impl<'a> AddAssign<&'a Self> for $name {
          fn add_assign(&mut self, other: &'a Self) {
            self.0 += other.0;
          }
        }
        impl<'a> SubAssign<&'a Self> for $name {
          fn sub_assign(&mut self, other: &'a Self) {
            self.0 -= other.0;
          }
        }
        impl<'a> MulAssign<&'a Self> for $name {
          fn mul_assign(&mut self, other: &'a Self) {
            self.0 *= other.0;
          }
        }

        impl Field for $name {
          const ZERO: Self = Self(Underlying::ZERO);
          const ONE: Self = Self(Underlying::ONE);
          /// Sample a field element by reducing `2 * MODULUS_BYTES` random bytes.
          fn random(mut rng: impl RngCore) -> Self {
            let mut bytes = [0; 2 * MODULUS_BYTES];
            rng.fill_bytes(&mut bytes);
            Self::from_uniform_bytes(&bytes)
          }
          fn square(&self) -> Self {
            Self(self.0.square())
          }
          fn double(&self) -> Self {
            Self(self.0.double())
          }
          fn invert(&self) -> CtOption<Self> {
            CtOption::from(self.0.inv()).map(Self)
          }
          /// Compute a square root, returning the root whose canonical integer is even.
          ///
          /// The `CtOption` is none if `self` isn't a quadratic residue.
          fn sqrt(&self) -> CtOption<Self> {
            // The algorithm is chosen by the modulus' residue modulo 8. As these are constants,
            // the branch taken is fixed for a given field.
            const THREE_MOD_FOUR: bool = (MODULUS.as_words()[0] % 4) == 3;
            const ONE_MOD_EIGHT: bool = (MODULUS.as_words()[0] % 8) == 1;
            const FIVE_MOD_EIGHT: bool = (MODULUS.as_words()[0] % 8) == 5;
            // Every odd modulus falls into exactly one of the cases handled below
            const _: () = assert!(
              THREE_MOD_FOUR || ONE_MOD_EIGHT || FIVE_MOD_EIGHT,
              "modulus wasn't odd"
            );

            let sqrt = if THREE_MOD_FOUR {
              // p = 3 (mod 4): a candidate root is a^((p + 1) / 4), and (p + 1) / 4 = (p >> 2) + 1
              const SQRT_EXP: UnderlyingUint =
                MODULUS.shr_vartime(2).wrapping_add(&UnderlyingUint::ONE);
              Self(self.0.pow(&SQRT_EXP))
            } else if ONE_MOD_EIGHT {
              // p = 1 (mod 8): Tonelli-Shanks, via `ff`'s helper, which takes (T - 1) / 2 as u64
              // words
              const TM1D2: UnderlyingUint = (T.wrapping_sub(&UnderlyingUint::ONE)).shr_vartime(1);
              const TM1D2_WORDS_LEN: usize = UnderlyingUint::BITS.div_ceil(u64::BITS) as usize;
              const TM1D2_WORDS: [u64; TM1D2_WORDS_LEN] = uint_to_u64_words(TM1D2);

              // Sanity-check the word conversion by round-tripping it. Every limb is compared, as
              // `LIMBS` exceeds `TM1D2_WORDS_LEN` when limbs are narrower than 64 bits.
              const TM1D2_RECONSTRUCTION: UnderlyingUint = u64_words_to_uint(TM1D2_WORDS);
              const RECONSTRUCTION_EQUALS_VALUE: bool = {
                let reconstruction = TM1D2_RECONSTRUCTION.as_words();
                let value = TM1D2.as_words();
                let mut i = 0;
                let mut res = true;
                while i < UnderlyingUint::LIMBS {
                  res &= reconstruction[i] == value[i];
                  i += 1;
                }
                res
              };
              const _: () = assert!(
                RECONSTRUCTION_EQUALS_VALUE,
                "`u64` word conversion of (T - 1) / 2 didn't round-trip"
              );

              // Non-residues are mapped to zero here and rejected by the final check below
              helpers::sqrt_tonelli_shanks::<Self, _>(self, TM1D2_WORDS).unwrap_or(Self::ZERO)
            } else {
              // p = 5 (mod 8): Atkin's algorithm, with exponent (p - 5) / 8 = p >> 3
              const SQRT_EXP: UnderlyingUint = MODULUS.shr_vartime(3);
              let upsilon = self.double().0.pow(&SQRT_EXP);
              let i = (upsilon.square() * &self.0).double();
              Self(upsilon * self.0 * (i - Self::ONE.0))
            };

            // Normalize to the even root (as p is odd, exactly one of `sqrt`, `-sqrt` is even)
            let sqrt = <_>::conditional_select(&sqrt, &-sqrt, sqrt.0.retrieve().is_odd());
            // The candidate is only a root if it squares back to `self`
            CtOption::new(sqrt, sqrt.square().ct_eq(self))
          }
          fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
            helpers::sqrt_ratio_generic(num, div)
          }
        }

        /// The canonical byte encoding of a field element, `MODULUS_BYTES` long.
        #[derive(Clone, Copy)]
        pub struct Repr([u8; MODULUS_BYTES]);
        impl Default for Repr {
          fn default() -> Self {
            Self([0; _])
          }
        }
        impl AsRef<[u8]> for Repr {
          fn as_ref(&self) -> &[u8] {
            self.0.as_ref()
          }
        }
        impl AsMut<[u8]> for Repr {
          fn as_mut(&mut self) -> &mut [u8] {
            self.0.as_mut()
          }
        }

        impl PrimeField for $name {
          type Repr = Repr;
          const MODULUS: &str = $modulus_as_be_hex;
          const NUM_BITS: u32 = MODULUS.bits();
          const CAPACITY: u32 = Self::NUM_BITS - 1;
          // 2^(p - 2) = 2^-1
          const TWO_INV: Self =
            Self(Underlying::new(&UnderlyingUint::from_u8(2)).pow(&MODULUS_MINUS_TWO));
          const MULTIPLICATIVE_GENERATOR: Self = Self(
            Underlying::new(
              &UnderlyingUint::from_be_hex(PADDED_MULTIPLICATIVE_GENERATOR_WITHOUT_PREFIX)
            )
          );
          const S: u32 = calculate_S::<_, Params>();
          // g^T, a 2^S-th root of unity
          const ROOT_OF_UNITY: Self = Self(Self::MULTIPLICATIVE_GENERATOR.0.pow(&T));
          const ROOT_OF_UNITY_INV: Self = Self(Self::ROOT_OF_UNITY.0.pow(&MODULUS_MINUS_TWO));
          // g^(2^S)
          const DELTA: Self = {
            let two_to_the_s = UnderlyingUint::ONE.shl_vartime(Self::S);
            Self(Self::MULTIPLICATIVE_GENERATOR.0.pow(&two_to_the_s))
          };

          /// Encode the canonical value into `MODULUS_BYTES` bytes, in the endianness selected by
          /// the macro's `$big_endian` argument.
          fn to_repr(&self) -> Self::Repr {
            let mut res = Repr([0; _]);
            if $big_endian {
              // The most-significant bytes of the full-width encoding are zero padding
              res.0.copy_from_slice(
                &self.0.retrieve().to_be_bytes()[(UnderlyingUint::BYTES - MODULUS_BYTES) ..]
              );
            } else {
              res.0.copy_from_slice(&self.0.retrieve().to_le_bytes()[.. MODULUS_BYTES]);
            }
            res
          }
          /// Decode a `Repr`, returning none if the encoding isn't canonical (it isn't less than
          /// the modulus).
          fn from_repr(repr: Self::Repr) -> CtOption<Self> {
            let mut expanded_repr = [0; UnderlyingUint::BYTES];
            let result = Self(if $big_endian {
              expanded_repr[(UnderlyingUint::BYTES - MODULUS_BYTES) .. ].copy_from_slice(&repr.0);
              Underlying::new(&UnderlyingUint::from_be_bytes(expanded_repr))
            } else {
              expanded_repr[.. MODULUS_BYTES].copy_from_slice(&repr.0);
              Underlying::new(&UnderlyingUint::from_le_bytes(expanded_repr))
            });
            // `Underlying::new` reduces, so a non-canonical input won't re-encode identically
            CtOption::new(result, result.to_repr().0.ct_eq(&repr.0))
          }
          fn is_odd(&self) -> Choice {
            self.0.retrieve().is_odd()
          }
        }

        impl PrimeFieldBits for $name {
          type ReprBits = [u8; UnderlyingUint::BYTES];
          fn to_le_bits(&self) -> FieldBits<Self::ReprBits> {
            self.0.retrieve().to_le_bytes().into()
          }
          fn char_le_bits() -> FieldBits<Self::ReprBits> {
            MODULUS.to_le_bytes().into()
          }
        }

        impl FromUniformBytes<{ 2 * MODULUS_BYTES }> for $name {
          /// Interpret `bytes` as a little-endian integer and reduce it modulo the modulus.
          fn from_uniform_bytes(bytes: &[u8; 2 * MODULUS_BYTES]) -> Self {
            let mut expanded_wide_repr = [0; 2 * UnderlyingUint::BYTES];
            expanded_wide_repr[.. (2 * MODULUS_BYTES)].copy_from_slice(bytes);
            let bytes = expanded_wide_repr;

            // Split the double-width integer into `lo + hi * 2^BITS`
            let lo =
              Underlying::new(&UnderlyingUint::from_le_slice(&bytes[.. UnderlyingUint::BYTES]));
            let hi =
              Underlying::new(&UnderlyingUint::from_le_slice(&bytes[UnderlyingUint::BYTES ..]));
            // 2^BITS, reduced modulo the modulus, computed by repeated doubling
            const HI: Underlying = {
              let mut res = Underlying::new(&UnderlyingUint::ONE);
              let mut i = 0;
              while i < UnderlyingUint::BITS {
                res = res.double();
                i += 1;
              }
              res
            };
            Self(lo + (hi * HI))
          }
        }

        // The modulus' bit length plus a security margin of half as many bits again.
        // NOTE(review): for very small moduli (~6-8 bits), `BITS_PLUS_SECURITY_LEVEL_BYTES` can
        // equal `2 * MODULUS_BYTES`, making this impl conflict with the one above.
        const BITS_PLUS_SECURITY_LEVEL: usize =
          (MODULUS.bits() + MODULUS.bits().div_ceil(2)) as usize;
        const BITS_PLUS_SECURITY_LEVEL_BYTES: usize = BITS_PLUS_SECURITY_LEVEL.div_ceil(8);
        impl FromUniformBytes<{ BITS_PLUS_SECURITY_LEVEL_BYTES }> for $name {
          /// Zero-extend `bytes` (little-endian) to `2 * MODULUS_BYTES` and reduce.
          fn from_uniform_bytes(bytes: &[u8; BITS_PLUS_SECURITY_LEVEL_BYTES]) -> Self {
            let mut larger = [0; 2 * MODULUS_BYTES];
            larger[.. BITS_PLUS_SECURITY_LEVEL_BYTES].copy_from_slice(bytes);
            Self::from_uniform_bytes(&larger)
          }
        }

        // As this is expanded within the invoking crate, `feature = "std"` refers to that crate's
        // features, not this one's
        #[cfg(feature = "std")]
        #[test]
        fn test() {
          use $crate::__prime_field_private::ff_group_tests;
          use ff_group_tests::prime_field::test_prime_field_bits;
          test_prime_field_bits::<_, $name>(&mut $crate::rand_core::OsRng);
        }
      }

      pub use [<$name __prime_field_private>]::$name;
    }
  };
}