lambdaworks_math/field/fields/mersenne31/
field.rs

1use crate::{
2    errors::CreationError,
3    field::{
4        element::FieldElement,
5        errors::FieldError,
6        traits::{IsField, IsPrimeField},
7    },
8};
9use core::fmt::{self, Display};
10
/// Marker type for the prime field of order `2^31 - 1` (a Mersenne prime).
///
/// Field elements are stored as plain `u32` values with these invariants:
///  - bit 31 is always clear;
///  - the stored value is at most the modulus. The modulus itself may appear
///    as a second, non-canonical encoding of zero and is mapped to `0` when
///    the canonical representative is requested.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Mersenne31Field;
17
18impl Mersenne31Field {
19    fn weak_reduce(n: u32) -> u32 {
20        // To reduce 'n' to 31 bits we clear its MSB, then add it back in its reduced form.
21        let msb = n & (1 << 31);
22        let msb_reduced = msb >> 31;
23        let res = msb ^ n;
24
25        // assert msb_reduced fits within 31 bits
26        debug_assert!((res >> 31) == 0 && (msb_reduced >> 1) == 0);
27        res + msb_reduced
28    }
29
30    fn as_representative(n: &u32) -> u32 {
31        if *n == MERSENNE_31_PRIME_FIELD_ORDER {
32            0
33        } else {
34            *n
35        }
36    }
37
38    #[inline]
39    pub fn sum<I: Iterator<Item = <Self as IsField>::BaseType>>(
40        iter: I,
41    ) -> <Self as IsField>::BaseType {
42        // Delayed reduction
43        Self::from_u64(iter.map(|x| x as u64).sum::<u64>())
44    }
45
46    /// Computes a * 2^k, with 0 < k < 31
47    pub fn mul_power_two(a: u32, k: u32) -> u32 {
48        let msb = (a & (u32::MAX << (31 - k))) >> (31 - k); // The k + 1 msf shifted right .
49        let lsb = (a & (u32::MAX >> (k + 1))) << k; // The 31 - k lsb shifted left.
50        Self::weak_reduce(msb + lsb)
51    }
52
53    pub fn pow_2(a: &u32, order: u32) -> u32 {
54        let mut res = *a;
55        (0..order).for_each(|_| res = Self::square(&res));
56        res
57    }
58
59    /// TODO: See if we can optimize this function.
60    /// Computes 2a^2 - 1
61    pub fn two_square_minus_one(a: &u32) -> u32 {
62        if *a == 0 {
63            MERSENNE_31_PRIME_FIELD_ORDER - 1
64        } else {
65            Self::from_u64(((u64::from(*a) * u64::from(*a)) << 1) - 1)
66        }
67    }
68}
69
/// Order of the field: the Mersenne prime `2^31 - 1`, i.e. a `u32` with every
/// bit set except bit 31.
pub const MERSENNE_31_PRIME_FIELD_ORDER: u32 = u32::MAX >> 1;
71
72//NOTE: This implementation was inspired by and borrows from the work done by the Plonky3 team
73// https://github.com/Plonky3/Plonky3/blob/main/mersenne-31/src/lib.rs
74// Thank you for pushing this technology forward.
75impl IsField for Mersenne31Field {
76    type BaseType = u32;
77
78    /// Returns the sum of `a` and `b`.
79    fn add(a: &u32, b: &u32) -> u32 {
80        // We are using that if a and b are field elements of Mersenne31, then
81        // a + b has at most 32 bits, so we can use the weak_reduce function to take mudulus p.
82        Self::weak_reduce(a + b)
83    }
84
85    /// Returns the multiplication of `a` and `b`.
86    // Note: for powers of 2 we can perform bit shifting this would involve overriding the trait implementation
87    fn mul(a: &u32, b: &u32) -> u32 {
88        Self::from_u64(u64::from(*a) * u64::from(*b))
89    }
90
91    fn sub(a: &u32, b: &u32) -> u32 {
92        Self::weak_reduce(a + MERSENNE_31_PRIME_FIELD_ORDER - b)
93    }
94
95    /// Returns the additive inverse of `a`.
96    fn neg(a: &u32) -> u32 {
97        // NOTE: MODULUS known to have 31 bit clear
98        MERSENNE_31_PRIME_FIELD_ORDER - a
99    }
100
101    /// Returns the multiplicative inverse of `a`.
102    fn inv(x: &u32) -> Result<u32, FieldError> {
103        if *x == Self::zero() || *x == MERSENNE_31_PRIME_FIELD_ORDER {
104            return Err(FieldError::InvZeroError);
105        }
106        let p101 = Self::mul(&Self::pow_2(x, 2), x);
107        let p1111 = Self::mul(&Self::square(&p101), &p101);
108        let p11111111 = Self::mul(&Self::pow_2(&p1111, 4u32), &p1111);
109        let p111111110000 = Self::pow_2(&p11111111, 4u32);
110        let p111111111111 = Self::mul(&p111111110000, &p1111);
111        let p1111111111111111 = Self::mul(&Self::pow_2(&p111111110000, 4u32), &p11111111);
112        let p1111111111111111111111111111 =
113            Self::mul(&Self::pow_2(&p1111111111111111, 12u32), &p111111111111);
114        let p1111111111111111111111111111101 =
115            Self::mul(&Self::pow_2(&p1111111111111111111111111111, 3u32), &p101);
116        Ok(p1111111111111111111111111111101)
117    }
118
119    /// Returns the division of `a` and `b`.
120    fn div(a: &u32, b: &u32) -> Result<u32, FieldError> {
121        let b_inv = Self::inv(b).map_err(|_| FieldError::DivisionByZero)?;
122        Ok(Self::mul(a, &b_inv))
123    }
124
125    /// Returns a boolean indicating whether `a` and `b` are equal or not.
126    fn eq(a: &u32, b: &u32) -> bool {
127        Self::as_representative(a) == Self::representative(b)
128    }
129
130    /// Returns the additive neutral element.
131    fn zero() -> u32 {
132        0u32
133    }
134
135    /// Returns the multiplicative neutral element.
136    fn one() -> u32 {
137        1u32
138    }
139
140    /// Returns the element `x * 1` where 1 is the multiplicative neutral element.
141    fn from_u64(x: u64) -> u32 {
142        (((((x >> 31) + x + 1) >> 31) + x) & (MERSENNE_31_PRIME_FIELD_ORDER as u64)) as u32
143    }
144
145    /// Takes as input an element of BaseType and returns the internal representation
146    /// of that element in the field.
147    fn from_base_type(x: u32) -> u32 {
148        Self::weak_reduce(x)
149    }
150    fn double(a: &u32) -> u32 {
151        Self::weak_reduce(a << 1)
152    }
153}
154
155impl IsPrimeField for Mersenne31Field {
156    type RepresentativeType = u32;
157
158    // Since our invariant guarantees that `value` fits in 31 bits, there is only one possible value
159    // `value` that is not canonical, namely 2^31 - 1 = p = 0.
160    fn representative(x: &u32) -> u32 {
161        debug_assert!((x >> 31) == 0);
162        Self::as_representative(x)
163    }
164
165    fn field_bit_size() -> usize {
166        ((MERSENNE_31_PRIME_FIELD_ORDER - 1).ilog2() + 1) as usize
167    }
168
169    fn from_hex(hex_string: &str) -> Result<Self::BaseType, CreationError> {
170        let mut hex_string = hex_string;
171        // Remove 0x if it's on the string
172        let mut char_iterator = hex_string.chars();
173        if hex_string.len() > 2
174            && char_iterator.next().unwrap() == '0'
175            && char_iterator.next().unwrap() == 'x'
176        {
177            hex_string = &hex_string[2..];
178        }
179        u32::from_str_radix(hex_string, 16).map_err(|_| CreationError::InvalidHexString)
180    }
181
182    #[cfg(feature = "std")]
183    fn to_hex(x: &u32) -> String {
184        format!("{x:X}")
185    }
186}
187
188impl FieldElement<Mersenne31Field> {
189    #[cfg(feature = "alloc")]
190    pub fn to_bytes_le(&self) -> alloc::vec::Vec<u8> {
191        self.representative().to_le_bytes().to_vec()
192    }
193
194    #[cfg(feature = "alloc")]
195    pub fn to_bytes_be(&self) -> alloc::vec::Vec<u8> {
196        self.representative().to_be_bytes().to_vec()
197    }
198}
199
200impl Display for FieldElement<Mersenne31Field> {
201    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
202        write!(f, "{:x}", self.representative())
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    type F = Mersenne31Field;
    type FE = FieldElement<F>;

    /// Checks `two_square_minus_one` against the same expression built from
    /// the generic field element operations.
    fn check_two_square_minus_one(a: &FE) {
        let expected = a.square().double() - FE::one();
        let actual = FE::from(&F::two_square_minus_one(a.value()));
        assert_eq!(actual, expected)
    }

    // --- Helpers specific to Mersenne31 ---

    #[test]
    fn mul_power_two_is_correct() {
        let (a, k) = (3u32, 2);
        let shifted = F::mul_power_two(a, k);
        let expected = FE::from(&a) * FE::from(2).pow(k as u16);
        assert_eq!(FE::from(&shifted), expected)
    }

    #[test]
    fn mul_power_two_is_correct_2() {
        let (a, k) = (229287u32, 4);
        let shifted = F::mul_power_two(a, k);
        let expected = FE::from(&a) * FE::from(2).pow(k as u16);
        assert_eq!(FE::from(&shifted), expected)
    }

    #[test]
    fn pow_2_is_correct() {
        // Twelve squarings raise the base to the power 2^12 = 4096.
        let base = 3u32;
        let squarings = 12;
        let expected = FE::pow(&FE::from(&base), 4096u32);
        let actual = F::pow_2(&base, squarings);
        assert_eq!(FE::from(&actual), expected)
    }

    // --- Hex parsing ---

    #[test]
    fn from_hex_for_b_is_11() {
        let parsed = F::from_hex("B").unwrap();
        assert_eq!(parsed, 11);
    }

    #[test]
    fn from_hex_for_b_is_11_v2() {
        let parsed = FE::from_hex("B").unwrap();
        assert_eq!(parsed, FE::from(11));
    }

    #[test]
    fn sum_delayed_reduction() {
        let count = u32::pow(2, 16);
        let exponent = u64::pow(2, 60);

        let terms = (0..count).map(|i| F::pow(&F::weak_reduce(i), exponent));

        assert_eq!(F::from_u64(2142542785), F::sum(terms));
    }

    #[test]
    fn from_hex_for_0x1_a_is_26() {
        let parsed = F::from_hex("0x1a").unwrap();
        assert_eq!(parsed, 26);
    }

    #[test]
    fn bit_size_of_field_is_31() {
        let bits = <F as crate::field::traits::IsPrimeField>::field_bit_size();
        assert_eq!(bits, 31);
    }

    // --- Addition and negation ---

    #[test]
    fn one_plus_1_is_2() {
        let total = FE::one() + FE::one();
        assert_eq!(total, FE::from(&2u32));
    }

    #[test]
    fn neg_1_plus_1_is_0() {
        let total = -FE::one() + FE::one();
        assert_eq!(total, FE::zero());
    }

    #[test]
    fn neg_1_plus_2_is_1() {
        let total = -FE::one() + FE::from(&2u32);
        assert_eq!(total, FE::one());
    }

    #[test]
    fn max_order_plus_1_is_0() {
        let largest = FE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 1));
        assert_eq!(largest + FE::from(1), FE::from(0));
    }

    // --- Equality ---

    #[test]
    fn comparing_13_and_13_are_equal() {
        let from_base = FE::from(&13u32);
        assert_eq!(from_base, FE::from(13));
    }

    #[test]
    fn comparing_13_and_8_they_are_not_equal() {
        let from_base = FE::from(&13u32);
        assert_ne!(from_base, FE::from(8));
    }

    // --- Subtraction ---

    #[test]
    fn one_sub_1_is_0() {
        let difference = FE::one() - FE::one();
        assert_eq!(difference, FE::zero());
    }

    #[test]
    fn zero_sub_1_is_order_minus_1() {
        let difference = FE::zero() - FE::one();
        let largest = FE::from(&(MERSENNE_31_PRIME_FIELD_ORDER - 1));
        assert_eq!(difference, largest);
    }

    #[test]
    fn neg_1_sub_neg_1_is_0() {
        let lhs = -FE::one();
        let rhs = -FE::one();
        assert_eq!(lhs - rhs, FE::zero());
    }

    #[test]
    fn neg_1_sub_0_is_neg_1() {
        let difference = -FE::one() - FE::zero();
        assert_eq!(difference, -FE::one());
    }

    // --- Multiplication and powers ---

    #[test]
    fn mul_neutral_element() {
        let product = FE::one() * FE::from(&2u32);
        assert_eq!(product, FE::from(&2u32));
    }

    #[test]
    fn mul_2_3_is_6() {
        let product = FE::from(&2u32) * FE::from(&3u32);
        assert_eq!(product, FE::from(&6u32));
    }

    #[test]
    fn mul_order_neg_1() {
        // (p - 1) is -1 in the field, and (-1) * (-1) = 1.
        let lhs = FE::from(MERSENNE_31_PRIME_FIELD_ORDER as u64 - 1);
        let rhs = FE::from(MERSENNE_31_PRIME_FIELD_ORDER as u64 - 1);
        assert_eq!(lhs * rhs, FE::one());
    }

    #[test]
    fn pow_p_neg_1() {
        let two = FE::from(&2u32);
        let raised = FE::pow(&two, MERSENNE_31_PRIME_FIELD_ORDER - 1);
        assert_eq!(raised, FE::one())
    }

    // --- Inversion and division ---

    #[test]
    fn inv_0_error() {
        let zero = FE::zero();
        assert!(matches!(FE::inv(&zero), Err(FieldError::InvZeroError)));
    }

    #[test]
    fn inv_2() {
        let two = FE::from(&2u32);
        let inverse = FE::inv(&two).unwrap();
        // sage: 1 / F(2) = 1073741824
        assert_eq!(inverse, FE::from(1073741824));
    }

    #[test]
    fn pow_2_3() {
        let two = FE::from(&2u32);
        assert_eq!(FE::pow(&two, 3u64), FE::from(8));
    }

    #[test]
    fn div_1() {
        let quotient = (FE::from(&2u32) / FE::from(&1u32)).unwrap();
        assert_eq!(quotient, FE::from(&2u32));
    }

    #[test]
    fn div_4_2() {
        let quotient = (FE::from(&4u32) / FE::from(&2u32)).unwrap();
        assert_eq!(quotient, FE::from(&2u32));
    }

    #[test]
    fn div_4_3() {
        let quotient = (FE::from(&4u32) / FE::from(&3u32)).unwrap();
        // sage: F(4) / F(3) = 1431655766
        assert_eq!(quotient, FE::from(1431655766));
    }

    #[test]
    fn two_plus_its_additive_inv_is_0() {
        let two = FE::from(&2u32);
        let minus_two = -FE::from(&2u32);
        assert_eq!(two + minus_two, FE::zero());
    }

    // --- Conversions ---

    #[test]
    fn from_u64_test() {
        let converted = FE::from(1u64);
        assert_eq!(converted, FE::one());
    }

    #[test]
    fn creating_a_field_element_from_its_representative_returns_the_same_element_1() {
        let above_modulus: u32 = MERSENNE_31_PRIME_FIELD_ORDER + 1;
        let original = FE::from(&above_modulus);
        let rebuilt = FE::from(&FE::representative(&original));
        assert_eq!(original, rebuilt);
    }

    #[test]
    fn creating_a_field_element_from_its_representative_returns_the_same_element_2() {
        let above_modulus: u32 = MERSENNE_31_PRIME_FIELD_ORDER + 8;
        let original = FE::from(&above_modulus);
        let rebuilt = FE::from(&FE::representative(&original));
        assert_eq!(original, rebuilt);
    }

    #[test]
    fn from_base_type_test() {
        let converted = FE::from(&1u32);
        assert_eq!(converted, FE::one());
    }

    #[cfg(feature = "std")]
    #[test]
    fn to_hex_test() {
        let eleven = FE::from_hex("B").unwrap();
        let printed = FE::to_hex(&eleven);
        assert_eq!(printed, "B");
    }

    #[test]
    fn double_equals_add_itself() {
        let a = FE::from(1234);
        assert_eq!(a + a, a.double())
    }

    // --- two_square_minus_one ---

    #[test]
    fn two_square_minus_one_is_correct() {
        let a = FE::from(2147483650);
        check_two_square_minus_one(&a)
    }

    #[test]
    fn two_square_zero_minus_one_is_minus_one() {
        let a = FE::from(0);
        check_two_square_minus_one(&a)
    }

    #[test]
    fn two_square_p_minus_one_is_minus_one() {
        let a = FE::from(&MERSENNE_31_PRIME_FIELD_ORDER);
        check_two_square_minus_one(&a)
    }

    #[test]
    fn mul_by_inv() {
        let x = 3476715743_u32;
        let inverse = FE::from(&x).inv().unwrap();
        assert_eq!(inverse * FE::from(&x), FE::one());
    }
}