Skip to main content

primitives/algebra/field/mersenne/
m107.rs

1use std::{
2    iter::{Product, Sum},
3    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
4};
5
6use crypto_bigint::rand_core::RngCore;
7use ff::{Field, PrimeField};
8use hybrid_array::Array;
9use rand::Rng;
10use serde::{Deserialize, Serialize};
11use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption};
12use typenum::{U1, U14, U16};
13use wincode::{SchemaRead, SchemaWrite};
14
15use crate::{
16    algebra::{
17        field::FieldExtension,
18        ops::{AccReduce, ReduceWide},
19        uniform_bytes::FromUniformBytes,
20    },
21    random::{CryptoRngCore, Random},
22    types::{HeapArray, Positive},
23};
24
/// `ff`-derived implementation of the same prime field. `Mersenne107` converts into this type to
/// reuse its `invert`, `sqrt` and `sqrt_ratio` implementations.
mod ff_impl {
    use ff::PrimeField;
    use serde::{Deserialize, Serialize};

    /// Prime field with modulus `162259276829213363391578010288127`, which equals `2^107 - 1`.
    /// The declared generator is 3, and the byte representation is little-endian.
    /// The value is held in two `u64` limbs.
    #[derive(PrimeField, Serialize, Deserialize)]
    #[PrimeFieldModulus = "162259276829213363391578010288127"]
    #[PrimeFieldGenerator = "3"]
    #[PrimeFieldReprEndianness = "little"]
    pub struct Mersenne107FF([u64; 2]);
}
35
/// An element of the prime field `GF(p)` with `p = 2^107 - 1` (see [`Mersenne107::MODULUS`]).
///
/// The value is stored in a `u128`. The constructors and arithmetic in this file aim to keep it
/// in the canonical range `[0, MODULUS)`. The derived `Eq`/`Hash`/`Ord` compare the raw integer,
/// so they are only meaningful for canonical values.
/// NOTE(review): the field is `pub(super)`, so the parent module can construct non-canonical
/// values directly.
#[derive(
    Copy, Clone, Default, Debug, SchemaRead, SchemaWrite, PartialEq, Eq, Hash, Ord, PartialOrd,
)]
// `repr(C)` over a single `u128` gives a fixed, padding-free layout. The `bytemuck::Pod` impl in
// this file relies on that.
#[repr(C)]
pub struct Mersenne107(pub(super) u128);
41
42impl Serialize for Mersenne107 {
43    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
44    where
45        S: serde::Serializer,
46    {
47        self.as_le_array().serialize(serializer)
48    }
49}
50
51impl<'de> Deserialize<'de> for Mersenne107 {
52    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
53    where
54        D: serde::Deserializer<'de>,
55    {
56        let arr = <[u8; 14]>::deserialize(deserializer)?;
57        Self::from_canonical_bytes(&arr).ok_or_else(|| {
58            serde::de::Error::custom("Invalid Mersenne107 canonical byte representation")
59        })
60    }
61}
62
63impl Mersenne107 {
64    pub const NUM_BITS: usize = 107;
65    pub const MODULUS: u128 = (1u128 << Self::NUM_BITS) - 1;
66    pub const MAX: u128 = Self::MODULUS - 1;
67
68    fn as_le_array(&self) -> [u8; 14] {
69        let mut arr = [0u8; 14];
70        arr[..14].copy_from_slice(&self.0.to_le_bytes()[..14]);
71        arr
72    }
73
74    fn from_canonical_bytes(arr: &[u8; 14]) -> Option<Self> {
75        let mut tmp = [0u8; 16];
76        tmp[..14].copy_from_slice(arr);
77        let val = u128::from_le_bytes(tmp);
78        (val < Self::MODULUS).then_some(Self(val))
79    }
80}
81
82///////////////////////////////////////////////////////////////////////////////////////////////////
83// Arithmetic ops - multiplication
84///////////////////////////////////////////////////////////////////////////////////////////////////
85
86#[macros::op_variants(owned)]
87impl<'a> MulAssign<&'a Mersenne107> for Mersenne107 {
88    #[inline]
89    fn mul_assign(&mut self, rhs: &'a Mersenne107) {
90        self.0 = super::m107_ops::mul(self.0, rhs.0);
91    }
92}
93
94#[macros::op_variants(owned)]
95impl<'a> Mul<&'a Mersenne107> for Mersenne107 {
96    type Output = Self;
97    #[inline]
98    fn mul(self, rhs: &'a Mersenne107) -> Self::Output {
99        let mut res = self;
100        res.mul_assign(rhs);
101        res
102    }
103}
104
105///////////////////////////////////////////////////////////////////////////////////////////////////
106// Arithmetic ops - addition
107///////////////////////////////////////////////////////////////////////////////////////////////////
108
109#[macros::op_variants(owned)]
110impl<'a> AddAssign<&'a Mersenne107> for Mersenne107 {
111    #[inline]
112    fn add_assign(&mut self, rhs: &'a Mersenne107) {
113        self.0 += rhs.0;
114        super::m107_ops::reduce_mod_1bit_inplace(&mut self.0);
115    }
116}
117#[macros::op_variants(owned)]
118impl<'a> Add<&'a Mersenne107> for Mersenne107 {
119    type Output = Self;
120
121    #[inline]
122    fn add(self, rhs: &'a Mersenne107) -> Self::Output {
123        let mut res = self;
124        res.add_assign(rhs);
125        res
126    }
127}
128
129///////////////////////////////////////////////////////////////////////////////////////////////////
130// Arithmetic ops - substraction
131///////////////////////////////////////////////////////////////////////////////////////////////////
132
133#[macros::op_variants(owned)]
134impl<'a> SubAssign<&'a Mersenne107> for Mersenne107 {
135    #[inline]
136    fn sub_assign(&mut self, rhs: &'a Mersenne107) {
137        self.0 += Self::MODULUS - rhs.0;
138        super::m107_ops::reduce_mod_1bit_inplace(&mut self.0);
139    }
140}
141
142#[macros::op_variants(owned)]
143impl<'a> Sub<&'a Mersenne107> for Mersenne107 {
144    type Output = Self;
145
146    #[inline]
147    fn sub(mut self, rhs: &'a Mersenne107) -> Self::Output {
148        self.sub_assign(rhs);
149        self
150    }
151}
152
153///////////////////////////////////////////////////////////////////////////////////////////////////
154// Arithmetic ops - negation
155///////////////////////////////////////////////////////////////////////////////////////////////////
156
157#[macros::op_variants(borrowed)]
158impl Neg for Mersenne107 {
159    type Output = Mersenne107;
160
161    fn neg(self) -> Self::Output {
162        Self(super::m107_ops::reduce_mod_1bit(Self::MODULUS - self.0))
163    }
164}
165
166///////////////////////////////////////////////////////////////////////////////////////////////////
167// Constant time
168///////////////////////////////////////////////////////////////////////////////////////////////////
169
170impl ConditionallySelectable for Mersenne107 {
171    fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
172        Self(u128::conditional_select(&a.0, &b.0, choice))
173    }
174}
175
176impl ConstantTimeEq for Mersenne107 {
177    fn ct_eq(&self, other: &Self) -> Choice {
178        self.0.ct_eq(&other.0)
179    }
180}
181
182///////////////////////////////////////////////////////////////////////////////////////////////////
183// Iterator operations
184///////////////////////////////////////////////////////////////////////////////////////////////////
185
186impl Sum for Mersenne107 {
187    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
188        let t = iter.fold(<Self as AccReduce>::zero_wide(), |mut acc, x| {
189            Self::acc(&mut acc, &x);
190            acc
191        });
192        Self::reduce_mod_order(t)
193    }
194}
195
196impl<'a> Sum<&'a Self> for Mersenne107 {
197    fn sum<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
198        let t = iter.fold(<Self as AccReduce>::zero_wide(), |mut acc, x| {
199            Self::acc(&mut acc, x);
200            acc
201        });
202        Self::reduce_mod_order(t)
203    }
204}
205
206impl<'a> Product<&'a Self> for Mersenne107 {
207    fn product<I: Iterator<Item = &'a Self>>(iter: I) -> Self {
208        iter.fold(Self::ONE, |acc, x| acc * x)
209    }
210}
211
212impl Product for Mersenne107 {
213    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
214        iter.fold(Self::ONE, |acc, x| acc * x)
215    }
216}
217
218///////////////////////////////////////////////////////////////////////////////////////////////////
219// Field trait
220///////////////////////////////////////////////////////////////////////////////////////////////////
221
222impl Field for Mersenne107 {
223    const ZERO: Self = Mersenne107(0);
224    const ONE: Self = Mersenne107(1);
225
226    fn random(mut rng: impl RngCore) -> Self {
227        let tmp = rng.gen::<u128>();
228        Self(super::m107_ops::reduce_mod(tmp)) // the probability is skewed because M107 modulus
229                                               // does not
230                                               // divide 2^128 exactly
231    }
232
233    fn square(&self) -> Self {
234        *self * self // TODO: optimize ?
235    }
236
237    fn double(&self) -> Self {
238        Self(super::m107_ops::reduce_mod_1bit(self.0 << 1))
239    }
240
241    fn invert(&self) -> CtOption<Self> {
242        // Fallback to ff implementation
243        // TODO: see if we can optimize this without ff
244        let val: ff_impl::Mersenne107FF = self.into();
245        let inv = val.invert();
246        inv.map(|v| v.into())
247    }
248
249    fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
250        let num = num.into();
251        let div = div.into();
252        let (choice, val) = ff_impl::Mersenne107FF::sqrt_ratio(&num, &div);
253        (choice, val.into())
254    }
255
256    fn sqrt(&self) -> CtOption<Self> {
257        // Fallback to ff implementation
258        // TODO: see if we can optimize this without ff
259        let val: ff_impl::Mersenne107FF = self.into();
260        let inv = val.sqrt();
261        inv.map(|v| v.into())
262    }
263}
264
265///////////////////////////////////////////////////////////////////////////////////////////////////
266// Field extension trait
267///////////////////////////////////////////////////////////////////////////////////////////////////
268
269impl FieldExtension for Mersenne107 {
270    type Subfield = Self;
271    type Degree = U1;
272    type FieldBitSize = typenum::U<{ Mersenne107::NUM_BITS }>;
273    type FieldBytesSize = U14;
274
275    fn to_subfield_elements(&self) -> impl ExactSizeIterator<Item = Self::Subfield> {
276        std::iter::once(*self)
277    }
278
279    fn from_subfield_elements(elems: &[Self::Subfield]) -> Option<Self> {
280        if elems.len() == 1 {
281            elems.first().copied()
282        } else {
283            None
284        }
285    }
286
287    fn to_le_bytes(&self) -> Array<u8, Self::FieldBytesSize> {
288        self.as_le_array().into()
289    }
290
291    fn from_le_bytes(bytes: &[u8]) -> Option<Self> {
292        if bytes.len() == 14 {
293            let arr: &[u8; 14] = bytes.try_into().expect("This should never fail");
294            Self::from_canonical_bytes(arr)
295        } else {
296            None
297        }
298    }
299
300    fn mul_by_subfield(&self, other: &Self::Subfield) -> Self {
301        *self * other
302    }
303
304    fn generator() -> Self {
305        Self(3u128) // 3^(2^107-2) == 1 mod 2^107-1
306    }
307}
308
309impl Random for Mersenne107 {
310    fn random(mut rng: impl CryptoRngCore) -> Self {
311        let tmp = rng.gen::<u128>();
312        Self(super::m107_ops::reduce_mod(tmp))
313    }
314
315    fn random_array<M: Positive>(mut rng: impl CryptoRngCore) -> HeapArray<Self, M> {
316        let mut buf = HeapArray::<Self, M>::default().into_box_bytes();
317        rng.fill_bytes(&mut buf);
318        let mut tmp = HeapArray::from_box_bytes(buf);
319        tmp.iter_mut()
320            .for_each(|v: &mut Self| super::m107_ops::reduce_mod_inplace(&mut v.0));
321        tmp
322    }
323}
324
325unsafe impl bytemuck::Zeroable for Mersenne107 {}
326unsafe impl bytemuck::Pod for Mersenne107 {}
327
328impl FromUniformBytes for Mersenne107 {
329    type UniformBytes = U16;
330    fn from_uniform_bytes(bytes: &hybrid_array::Array<u8, Self::UniformBytes>) -> Self {
331        let mut val = u128::from_le_bytes(bytes.0);
332        super::m107_ops::reduce_mod_inplace(&mut val);
333        Self(val)
334    }
335}
336
337impl From<u64> for Mersenne107 {
338    fn from(val: u64) -> Self {
339        Self(val as u128)
340    }
341}
342
343impl From<u128> for Mersenne107 {
344    fn from(val: u128) -> Self {
345        Self(super::m107_ops::reduce_mod(val))
346    }
347}
348
349impl From<ff_impl::Mersenne107FF> for Mersenne107 {
350    fn from(val: ff_impl::Mersenne107FF) -> Self {
351        Self::from_le_bytes(&val.to_repr().as_ref()[..14]).unwrap()
352    }
353}
354
355impl<'a> From<&'a Mersenne107> for ff_impl::Mersenne107FF {
356    fn from(val: &'a Mersenne107) -> Self {
357        Self::from_repr(ff_impl::Mersenne107FFRepr(val.0.to_le_bytes())).unwrap()
358    }
359}
360
#[cfg(test)]
mod test {
    use ff::Field;
    use num_bigint::BigInt;
    use typenum::Unsigned;

    use crate::{
        algebra::field::{
            mersenne::{m107::Mersenne107, test::bigint_to_m107},
            FieldExtension,
        },
        random::test_rng,
    };

    /// Number of random samples drawn by each randomized test.
    type M = typenum::U1000;

    #[test]
    fn test_neg() {
        fn test_internal(a: Mersenne107) {
            let exp = bigint_to_m107(-BigInt::from(a.0));
            let act = -a;
            assert_eq!(exp, act, "a = {a:?}");
        }

        let mut rng = test_rng();
        for _ in 0..M::to_usize() {
            let a = Mersenne107::random(&mut rng);
            test_internal(a);
        }

        // Corner cases
        test_internal(Mersenne107::ZERO);
        test_internal(Mersenne107::ONE);
        test_internal(Mersenne107(Mersenne107::MAX));
    }

    #[test]
    fn test_invert() {
        fn test_internal(a: Mersenne107) {
            let a_inv = a.invert().unwrap();
            let act = a * a_inv;
            assert_eq!(Mersenne107::ONE, act, "a = {a:?}");
        }

        let mut rng = test_rng();
        for _ in 0..M::to_usize() {
            let a = Mersenne107::random(&mut rng);
            if a == Mersenne107::ZERO {
                continue;
            }
            test_internal(a);
        }

        // Corner cases
        test_internal(Mersenne107::ONE);
        test_internal(Mersenne107(Mersenne107::MAX));

        // Zero has no inverse.
        assert!(Mersenne107::ZERO.invert().into_option().is_none());
    }

    #[test]
    fn test_sqrt() {
        fn test_internal(a: Mersenne107) {
            // Non-residues have no square root; nothing to check for them.
            let Some(a_sqrt) = a.sqrt().into_option() else {
                return;
            };
            let act = a_sqrt * a_sqrt;
            assert_eq!(a, act, "a = {a:?}");
        }

        let mut rng = test_rng();
        for _ in 0..M::to_usize() {
            let a = Mersenne107::random(&mut rng);
            test_internal(a);
        }

        // Corner cases
        test_internal(Mersenne107::ZERO);
        test_internal(Mersenne107::ONE);
        test_internal(Mersenne107(Mersenne107::MAX));
    }

    #[test]
    fn test_double_and_square() {
        fn test_internal(a: Mersenne107) {
            let exp_double = bigint_to_m107(BigInt::from(a.0) * BigInt::from(2u8));
            assert_eq!(exp_double, a.double(), "a = {a:?}");
            assert_eq!(a * a, a.square(), "a = {a:?}");
        }

        let mut rng = test_rng();
        for _ in 0..M::to_usize() {
            let a = Mersenne107::random(&mut rng);
            test_internal(a);
        }

        // Corner cases
        test_internal(Mersenne107::ZERO);
        test_internal(Mersenne107::ONE);
        test_internal(Mersenne107(Mersenne107::MAX));
    }

    #[test]
    fn test_sum_and_product() {
        let mut rng = test_rng();
        let elems: Vec<Mersenne107> = (0..M::to_usize())
            .map(|_| Mersenne107::random(&mut rng))
            .collect();

        let exp_sum = elems.iter().fold(Mersenne107::ZERO, |acc, x| acc + x);
        assert_eq!(exp_sum, elems.iter().sum::<Mersenne107>());
        assert_eq!(exp_sum, elems.iter().copied().sum::<Mersenne107>());

        let exp_prod = elems.iter().fold(Mersenne107::ONE, |acc, x| acc * x);
        assert_eq!(exp_prod, elems.iter().product::<Mersenne107>());
        assert_eq!(exp_prod, elems.iter().copied().product::<Mersenne107>());

        // Summing many maximal values stresses the wide accumulator before the final reduction.
        let maxes = vec![Mersenne107(Mersenne107::MAX); M::to_usize()];
        let exp_max_sum =
            bigint_to_m107(BigInt::from(Mersenne107::MAX) * BigInt::from(M::to_usize()));
        assert_eq!(exp_max_sum, maxes.iter().sum::<Mersenne107>());

        // Empty iterators yield the respective identities.
        let empty: Vec<Mersenne107> = Vec::new();
        assert_eq!(Mersenne107::ZERO, empty.iter().sum::<Mersenne107>());
        assert_eq!(Mersenne107::ONE, empty.iter().product::<Mersenne107>());
    }

    #[test]
    fn test_canonical_bytes_decoding() {
        let value = Mersenne107::from(123456789u64);
        let bytes = value.to_le_bytes();
        assert_eq!(Mersenne107::from_le_bytes(&bytes), Some(value));

        // The largest canonical value round-trips.
        let max = Mersenne107(Mersenne107::MAX);
        assert_eq!(Mersenne107::from_le_bytes(&max.to_le_bytes()), Some(max));

        // Encoding of the modulus p = 2^107 - 1 is non-canonical and must be rejected.
        let modulus_bytes = [
            0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x07,
        ];
        assert_eq!(Mersenne107::from_le_bytes(&modulus_bytes), None);

        // Inputs of the wrong length are rejected rather than truncated or padded.
        assert_eq!(Mersenne107::from_le_bytes(&[0u8; 13]), None);
        assert_eq!(Mersenne107::from_le_bytes(&[0u8; 16]), None);
        assert_eq!(Mersenne107::from_le_bytes(&[]), None);
    }

    /// Checks `a $op b` against arbitrary-precision arithmetic reduced by `bigint_to_m107`.
    /// It runs on random operands plus the 0 / 1 / MAX corner cases.
    macro_rules! test_op {
        ($op:tt) => {
            fn test_internal(a: Mersenne107, b: Mersenne107) {
                let exp = bigint_to_m107(BigInt::from(a.0) $op BigInt::from(b.0));
                let act = a $op b;
                assert_eq!(exp, act, "a = {a:?}, b = {b:?}");
            }

            let mut rng = test_rng();
            for _ in 0..M::to_usize() {
                let a = Mersenne107::random(&mut rng);
                let b = Mersenne107::random(&mut rng);
                test_internal(a, b);
            }

            // Corner cases
            test_internal(Mersenne107::ZERO, Mersenne107::ZERO);
            test_internal(Mersenne107::ZERO, Mersenne107::ONE);
            test_internal(Mersenne107::ONE, Mersenne107::ZERO);
            test_internal(Mersenne107::ONE, Mersenne107::ONE);
            test_internal(Mersenne107::ZERO, Mersenne107(Mersenne107::MAX));
            test_internal(Mersenne107::ONE, Mersenne107(Mersenne107::MAX));
            test_internal(Mersenne107(Mersenne107::MAX), Mersenne107::ZERO);
            test_internal(Mersenne107(Mersenne107::MAX), Mersenne107::ONE);
            test_internal(Mersenne107(Mersenne107::MAX), Mersenne107(Mersenne107::MAX));
        };
    }

    #[test]
    fn test_mul() {
        test_op!(*);
    }

    #[test]
    fn test_add() {
        test_op!(+);
    }

    #[test]
    fn test_sub() {
        test_op!(-);
    }
}