fhe_math/zq/mod.rs

#![warn(missing_docs, unused_imports)]

//! Ring operations for moduli up to 62 bits.

pub mod primes;

use std::ops::Deref;

use crate::errors::{Error, Result};
use fhe_util::{is_prime, transcode_from_bytes, transcode_to_bytes};
use itertools::{izip, Itertools};
use num_bigint::BigUint;
use num_traits::cast::ToPrimitive;
use pulp::Arch;
use rand::{distr::Uniform, CryptoRng, Rng, RngCore};
/// cond ? on_true : on_false
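///
/// Branchless select: when `cond` is true, `mask` is all ones and the result
/// is `on_true`; when `cond` is false, `mask` is zero and the result is
/// `on_false`. No branch depends on `cond`.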
const fn const_time_cond_select(on_true: u64, on_false: u64, cond: bool) -> u64 {
    let mask = -(cond as i64) as u64;
    let diff = on_true ^ on_false;
    (diff & mask) ^ on_false
}

/// Structure encapsulating an integer modulus up to 62 bits.
#[derive(Debug, Clone)]
pub struct Modulus {
    pub(crate) p: u64,
    barrett_hi: u64,
    barrett_lo: u64,
    leading_zeros: u32,
    pub(crate) supports_opt: bool,
    distribution: Uniform<u64>,
    arch: Arch,
}

// We need to declare Eq manually because of the `Uniform` member.
impl Eq for Modulus {}

impl PartialEq for Modulus {
    fn eq(&self, other: &Self) -> bool {
        self.p == other.p
    }
}

// Override the dereference to return the underlying modulus.
impl Deref for Modulus {
    type Target = u64;

    fn deref(&self) -> &Self::Target {
        &self.p
    }
}

impl Modulus {
    /// Create a modulus from an integer of at most 62 bits.
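    ///
    /// # Examples
    ///
    /// A minimal usage sketch (marked `ignore` since it assumes the crate
    /// exposes this module as `fhe_math::zq`):
    ///
    /// ```ignore
    /// use fhe_math::zq::Modulus;
    ///
    /// let p = Modulus::new(17).unwrap();
    /// assert_eq!(p.add(16, 3), 2);
    /// assert_eq!(p.mul(5, 7), 1); // 35 = 2 * 17 + 1
    /// assert!(Modulus::new(1).is_err()); // moduli smaller than 2 are rejected
    /// ```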
    pub fn new(p: u64) -> Result<Self> {
        if p < 2 || (p >> 62) != 0 {
            Err(Error::InvalidModulus(p))
        } else {
            let barrett = ((BigUint::from(1u64) << 128usize) / p).to_u128().unwrap(); // 2^128 / p
            Ok(Self {
                p,
                barrett_hi: (barrett >> 64) as u64,
                barrett_lo: barrett as u64,
                leading_zeros: p.leading_zeros(),
                supports_opt: primes::supports_opt(p),
                distribution: Uniform::new(0, p).unwrap(),
                arch: Arch::new(),
            })
        }
    }

    /// Performs the modular addition of a and b in constant time.
    /// Aborts if a >= p or b >= p in debug mode.
    pub const fn add(&self, a: u64, b: u64) -> u64 {
        debug_assert!(a < self.p && b < self.p);
        Self::reduce1(a + b, self.p)
    }

    /// Performs the modular addition of a and b in variable time.
    /// Aborts if a >= p or b >= p in debug mode.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the values being added.
    pub const unsafe fn add_vt(&self, a: u64, b: u64) -> u64 {
        debug_assert!(a < self.p && b < self.p);
        Self::reduce1_vt(a + b, self.p)
    }

    /// Performs the modular subtraction of a and b in constant time.
    /// Aborts if a >= p or b >= p in debug mode.
    pub const fn sub(&self, a: u64, b: u64) -> u64 {
        debug_assert!(a < self.p && b < self.p);
        Self::reduce1(a + self.p - b, self.p)
    }

    /// Performs the modular subtraction of a and b in variable time.
    /// Aborts if a >= p or b >= p in debug mode.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the values being subtracted.
    const unsafe fn sub_vt(&self, a: u64, b: u64) -> u64 {
        debug_assert!(a < self.p && b < self.p);
        Self::reduce1_vt(a + self.p - b, self.p)
    }

    /// Performs the modular multiplication of a and b in constant time.
    /// Aborts if a >= p or b >= p in debug mode.
    pub const fn mul(&self, a: u64, b: u64) -> u64 {
        debug_assert!(a < self.p && b < self.p);
        self.reduce_u128((a as u128) * (b as u128))
    }

    /// Performs the modular multiplication of a and b in variable time.
    /// Aborts if a >= p or b >= p in debug mode.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the values being multiplied.
    const unsafe fn mul_vt(&self, a: u64, b: u64) -> u64 {
        debug_assert!(a < self.p && b < self.p);
        Self::reduce1_vt(self.lazy_reduce_u128((a as u128) * (b as u128)), self.p)
    }

    /// Optimized modular multiplication of a and b in constant time.
    ///
    /// Aborts if a >= p or b >= p in debug mode.
    pub const fn mul_opt(&self, a: u64, b: u64) -> u64 {
        debug_assert!(self.supports_opt);
        debug_assert!(a < self.p && b < self.p);

        self.reduce_opt_u128((a as u128) * (b as u128))
    }

    /// Optimized modular multiplication of a and b in variable time.
    /// Aborts if a >= p or b >= p in debug mode.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the values being multiplied.
    const unsafe fn mul_opt_vt(&self, a: u64, b: u64) -> u64 {
        debug_assert!(self.supports_opt);
        debug_assert!(a < self.p && b < self.p);

        self.reduce_opt_u128_vt((a as u128) * (b as u128))
    }

    /// Modular negation in constant time.
    ///
    /// Aborts if a >= p in debug mode.
    pub const fn neg(&self, a: u64) -> u64 {
        debug_assert!(a < self.p);
        Self::reduce1(self.p - a, self.p)
    }

    /// Modular negation in variable time.
    /// Aborts if a >= p in debug mode.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the value being negated.
    const unsafe fn neg_vt(&self, a: u64) -> u64 {
        debug_assert!(a < self.p);
        Self::reduce1_vt(self.p - a, self.p)
    }

    /// Compute the Shoup representation of a.
    ///
    /// Aborts if a >= p in debug mode.
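    ///
    /// The representation is `floor(a * 2^64 / p)`, a fixed-point approximation
    /// of `a / p` that lets [`Self::mul_shoup`] trade the division in a modular
    /// multiplication for a multiplication and a shift.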
    pub const fn shoup(&self, a: u64) -> u64 {
        debug_assert!(a < self.p);

        (((a as u128) << 64) / (self.p as u128)) as u64
    }

    /// Shoup multiplication of a and b in constant time.
    ///
    /// Aborts if b >= p or b_shoup != shoup(b) in debug mode.
    pub const fn mul_shoup(&self, a: u64, b: u64, b_shoup: u64) -> u64 {
        Self::reduce1(self.lazy_mul_shoup(a, b, b_shoup), self.p)
    }

    /// Shoup multiplication of a and b in variable time.
    /// Aborts if b >= p or b_shoup != shoup(b) in debug mode.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the values being multiplied.
    const unsafe fn mul_shoup_vt(&self, a: u64, b: u64, b_shoup: u64) -> u64 {
        Self::reduce1_vt(self.lazy_mul_shoup(a, b, b_shoup), self.p)
    }

    /// Lazy Shoup multiplication of a and b in constant time.
    /// The output is in the interval [0, 2 * p).
    ///
    /// Aborts if b >= p or b_shoup != shoup(b) in debug mode.
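    ///
    /// With `b_shoup = floor(b * 2^64 / p)`, the quotient estimate
    /// `q = floor(a * b_shoup / 2^64)` satisfies
    /// `floor(a * b / p) - 1 <= q <= floor(a * b / p)`, so the remainder
    /// `r = a * b - q * p` lies in `[0, 2 * p)`.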
    pub const fn lazy_mul_shoup(&self, a: u64, b: u64, b_shoup: u64) -> u64 {
        debug_assert!(b < self.p);
        debug_assert!(b_shoup == self.shoup(b));

        let q = ((a as u128) * (b_shoup as u128)) >> 64;
        let r = ((a as u128) * (b as u128) - q * (self.p as u128)) as u64;

        debug_assert!(r < 2 * self.p);

        r
    }

    /// Modular addition of vectors in place in constant time.
    ///
    /// Aborts if a and b differ in size, and if any of their values is >= p in
    /// debug mode.
    pub fn add_vec(&self, a: &mut [u64], b: &[u64]) {
        debug_assert_eq!(a.len(), b.len());
        self.arch.dispatch(|| {
            izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.add(*ai, *bi))
        })
    }

    /// Modular addition of vectors in place in variable time.
    /// Aborts if a and b differ in size, and if any of their values is >= p in
    /// debug mode.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the values being added.
    pub unsafe fn add_vec_vt(&self, a: &mut [u64], b: &[u64]) {
        let n = a.len();
        debug_assert_eq!(n, b.len());

        let p = self.p;
        macro_rules! add_at {
            ($idx:expr) => {
                *a.get_unchecked_mut($idx) =
                    Self::reduce1_vt(*a.get_unchecked_mut($idx) + *b.get_unchecked($idx), p);
            };
        }

        if n % 16 == 0 {
            self.arch.dispatch(|| {
                for i in 0..n / 16 {
                    add_at!(16 * i);
                    add_at!(16 * i + 1);
                    add_at!(16 * i + 2);
                    add_at!(16 * i + 3);
                    add_at!(16 * i + 4);
                    add_at!(16 * i + 5);
                    add_at!(16 * i + 6);
                    add_at!(16 * i + 7);
                    add_at!(16 * i + 8);
                    add_at!(16 * i + 9);
                    add_at!(16 * i + 10);
                    add_at!(16 * i + 11);
                    add_at!(16 * i + 12);
                    add_at!(16 * i + 13);
                    add_at!(16 * i + 14);
                    add_at!(16 * i + 15);
                }
            })
        } else {
            self.arch.dispatch(|| {
                izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.add_vt(*ai, *bi))
            })
        }
    }

    /// Modular subtraction of vectors in place in constant time.
    ///
    /// Aborts if a and b differ in size, and if any of their values is >= p in
    /// debug mode.
    pub fn sub_vec(&self, a: &mut [u64], b: &[u64]) {
        debug_assert_eq!(a.len(), b.len());
        self.arch.dispatch(|| {
            izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.sub(*ai, *bi))
        })
    }

    /// Modular subtraction of vectors in place in variable time.
    /// Aborts if a and b differ in size, and if any of their values is >= p in
    /// debug mode.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the values being subtracted.
    pub unsafe fn sub_vec_vt(&self, a: &mut [u64], b: &[u64]) {
        let n = a.len();
        debug_assert_eq!(n, b.len());

        let p = self.p;
        macro_rules! sub_at {
            ($idx:expr) => {
                *a.get_unchecked_mut($idx) =
                    Self::reduce1_vt(p + *a.get_unchecked_mut($idx) - *b.get_unchecked($idx), p);
            };
        }

        if n % 16 == 0 {
            self.arch.dispatch(|| {
                for i in 0..n / 16 {
                    sub_at!(16 * i);
                    sub_at!(16 * i + 1);
                    sub_at!(16 * i + 2);
                    sub_at!(16 * i + 3);
                    sub_at!(16 * i + 4);
                    sub_at!(16 * i + 5);
                    sub_at!(16 * i + 6);
                    sub_at!(16 * i + 7);
                    sub_at!(16 * i + 8);
                    sub_at!(16 * i + 9);
                    sub_at!(16 * i + 10);
                    sub_at!(16 * i + 11);
                    sub_at!(16 * i + 12);
                    sub_at!(16 * i + 13);
                    sub_at!(16 * i + 14);
                    sub_at!(16 * i + 15);
                }
            })
        } else {
            self.arch.dispatch(|| {
                izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.sub_vt(*ai, *bi))
            })
        }
    }

    /// Modular multiplication of vectors in place in constant time.
    ///
    /// Aborts if a and b differ in size, and if any of their values is >= p in
    /// debug mode.
    pub fn mul_vec(&self, a: &mut [u64], b: &[u64]) {
        debug_assert_eq!(a.len(), b.len());

        if self.supports_opt {
            self.arch.dispatch(|| {
                izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_opt(*ai, *bi))
            })
        } else {
            self.arch.dispatch(|| {
                izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul(*ai, *bi))
            })
        }
    }

    /// Modular scalar multiplication of vectors in place in constant time.
    ///
    /// Aborts if any of the values in a is >= p in debug mode.
    pub fn scalar_mul_vec(&self, a: &mut [u64], b: u64) {
        let b_shoup = self.shoup(b);
        self.arch.dispatch(|| {
            a.iter_mut()
                .for_each(|ai| *ai = self.mul_shoup(*ai, b, b_shoup))
        })
    }

    /// Modular scalar multiplication of vectors in place in variable time.
    /// Aborts if any of the values in a is >= p in debug mode.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the values being multiplied.
    pub unsafe fn scalar_mul_vec_vt(&self, a: &mut [u64], b: u64) {
        let b_shoup = self.shoup(b);
        self.arch.dispatch(|| {
            a.iter_mut()
                .for_each(|ai| *ai = self.mul_shoup_vt(*ai, b, b_shoup))
        })
    }

    /// Modular multiplication of vectors in place in variable time.
    /// Aborts if a and b differ in size, and if any of their values is >= p in
    /// debug mode.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the values being multiplied.
    pub unsafe fn mul_vec_vt(&self, a: &mut [u64], b: &[u64]) {
        debug_assert_eq!(a.len(), b.len());

        if self.supports_opt {
            self.arch.dispatch(|| {
                izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_opt_vt(*ai, *bi))
            })
        } else {
            self.arch.dispatch(|| {
                izip!(a.iter_mut(), b.iter()).for_each(|(ai, bi)| *ai = self.mul_vt(*ai, *bi))
            })
        }
    }

    /// Compute the Shoup representation of a vector.
    ///
    /// Aborts if any of the values of the vector is >= p in debug mode.
    pub fn shoup_vec(&self, a: &[u64]) -> Vec<u64> {
        self.arch
            .dispatch(|| a.iter().map(|ai| self.shoup(*ai)).collect_vec())
    }

    /// Shoup modular multiplication of vectors in place in constant time.
    ///
    /// Aborts if a and b differ in size, and if any of their values is >= p in
    /// debug mode.
    pub fn mul_shoup_vec(&self, a: &mut [u64], b: &[u64], b_shoup: &[u64]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), b_shoup.len());
        debug_assert_eq!(&b_shoup, &self.shoup_vec(b));

        self.arch.dispatch(|| {
            izip!(a.iter_mut(), b.iter(), b_shoup.iter())
                .for_each(|(ai, bi, bi_shoup)| *ai = self.mul_shoup(*ai, *bi, *bi_shoup))
        })
    }

    /// Shoup modular multiplication of vectors in place in variable time.
    /// Aborts if a and b differ in size, and if any of their values is >= p in
    /// debug mode.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the values being multiplied.
    pub unsafe fn mul_shoup_vec_vt(&self, a: &mut [u64], b: &[u64], b_shoup: &[u64]) {
        debug_assert_eq!(a.len(), b.len());
        debug_assert_eq!(a.len(), b_shoup.len());
        debug_assert_eq!(&b_shoup, &self.shoup_vec(b));

        self.arch.dispatch(|| {
            izip!(a.iter_mut(), b.iter(), b_shoup.iter())
                .for_each(|(ai, bi, bi_shoup)| *ai = self.mul_shoup_vt(*ai, *bi, *bi_shoup))
        })
    }

    /// Reduce a vector in place in constant time.
    pub fn reduce_vec(&self, a: &mut [u64]) {
        self.arch
            .dispatch(|| a.iter_mut().for_each(|ai| *ai = self.reduce(*ai)))
    }

    /// Center a value modulo p as i64 in variable time.
    /// TODO: test this, and consider making it constant time.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the value being centered.
    const unsafe fn center_vt(&self, a: u64) -> i64 {
        debug_assert!(a < self.p);

        if a >= self.p >> 1 {
            (a as i64) - (self.p as i64)
        } else {
            a as i64
        }
    }

    /// Center a vector in variable time.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the values being centered.
    pub unsafe fn center_vec_vt(&self, a: &[u64]) -> Vec<i64> {
        self.arch
            .dispatch(|| a.iter().map(|ai| self.center_vt(*ai)).collect_vec())
    }

    /// Reduce a vector in place in variable time.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the values being reduced.
    pub unsafe fn reduce_vec_vt(&self, a: &mut [u64]) {
        self.arch
            .dispatch(|| a.iter_mut().for_each(|ai| *ai = self.reduce_vt(*ai)))
    }

    /// Modular reduction of a i64 in constant time.
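    ///
    /// The input is first lifted into a non-negative value by adding `p << 64`,
    /// a multiple of `p` that is larger than `|i64::MIN|`, so the result also
    /// equals `a mod p` for negative inputs.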
    const fn reduce_i64(&self, a: i64) -> u64 {
        self.reduce_u128((((self.p as i128) << 64) + (a as i128)) as u128)
    }

    /// Modular reduction of a i64 in variable time.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the values being reduced.
    const unsafe fn reduce_i64_vt(&self, a: i64) -> u64 {
        self.reduce_u128_vt((((self.p as i128) << 64) + (a as i128)) as u128)
    }

    /// Reduce a vector of signed integers in constant time, returning the
    /// result in a new vector.
    pub fn reduce_vec_i64(&self, a: &[i64]) -> Vec<u64> {
        self.arch
            .dispatch(|| a.iter().map(|ai| self.reduce_i64(*ai)).collect_vec())
    }

    /// Reduce a vector of signed integers in variable time, returning the
    /// result in a new vector.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the values being reduced.
    pub unsafe fn reduce_vec_i64_vt(&self, a: &[i64]) -> Vec<u64> {
        self.arch
            .dispatch(|| a.iter().map(|ai| self.reduce_i64_vt(*ai)).collect())
    }

    /// Reduce a vector in constant time.
    pub fn reduce_vec_new(&self, a: &[u64]) -> Vec<u64> {
        self.arch
            .dispatch(|| a.iter().map(|ai| self.reduce(*ai)).collect())
    }

    /// Reduce a vector in variable time.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the values being reduced.
    pub unsafe fn reduce_vec_new_vt(&self, a: &[u64]) -> Vec<u64> {
        self.arch
            .dispatch(|| a.iter().map(|bi| self.reduce_vt(*bi)).collect())
    }

    /// Modular negation of a vector in place in constant time.
    ///
    /// Aborts if any of the values in the vector is >= p in debug mode.
    pub fn neg_vec(&self, a: &mut [u64]) {
        self.arch
            .dispatch(|| a.iter_mut().for_each(|ai| *ai = self.neg(*ai)))
    }

    /// Modular negation of a vector in place in variable time.
    /// Aborts if any of the values in the vector is >= p in debug mode.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the values being negated.
    pub unsafe fn neg_vec_vt(&self, a: &mut [u64]) {
        self.arch
            .dispatch(|| a.iter_mut().for_each(|ai| *ai = self.neg_vt(*ai)))
    }

    /// Modular exponentiation in variable time.
    ///
    /// Aborts if a >= p or n >= p in debug mode.
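    ///
    /// # Examples
    ///
    /// An illustrative sketch (marked `ignore` since it assumes the crate
    /// exposes this module as `fhe_math::zq`):
    ///
    /// ```ignore
    /// use fhe_math::zq::Modulus;
    ///
    /// let p = Modulus::new(17).unwrap();
    /// assert_eq!(p.pow(3, 4), 13); // 3^4 = 81 = 4 * 17 + 13
    /// assert_eq!(p.pow(2, 16), 1); // Fermat: 2^(17 - 1) = 1 mod 17
    /// ```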
    pub fn pow(&self, a: u64, n: u64) -> u64 {
        debug_assert!(a < self.p && n < self.p);

        if n == 0 {
            1
        } else if n == 1 {
            a
        } else {
            let mut r = a;
            let mut i = (62 - n.leading_zeros()) as isize;
            while i >= 0 {
                r = self.mul(r, r);
                if (n >> i) & 1 == 1 {
                    r = self.mul(r, a);
                }
                i -= 1;
            }
            r
        }
    }

    /// Modular inversion in variable time.
    ///
    /// Returns None if p is not prime or a = 0.
    /// Aborts if a >= p in debug mode.
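    ///
    /// The inverse is computed as `a^(p - 2) mod p`, which equals `a^(-1)` by
    /// Fermat's little theorem when `p` is prime. For example, modulo 17 the
    /// inverse of 3 is 6 since `3 * 6 = 18 ≡ 1 (mod 17)`.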
    pub fn inv(&self, a: u64) -> std::option::Option<u64> {
        if !is_prime(self.p) || a == 0 {
            None
        } else {
            let r = self.pow(a, self.p - 2);
            debug_assert_eq!(self.mul(a, r), 1);
            Some(r)
        }
    }

    /// Modular reduction of a u128 in constant time.
    pub const fn reduce_u128(&self, a: u128) -> u64 {
        Self::reduce1(self.lazy_reduce_u128(a), self.p)
    }

    /// Modular reduction of a u128 in variable time.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the value being reduced.
    pub const unsafe fn reduce_u128_vt(&self, a: u128) -> u64 {
        Self::reduce1_vt(self.lazy_reduce_u128(a), self.p)
    }

    /// Modular reduction of a u64 in constant time.
    pub const fn reduce(&self, a: u64) -> u64 {
        Self::reduce1(self.lazy_reduce(a), self.p)
    }

    /// Modular reduction of a u64 in variable time.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the value being reduced.
    pub const unsafe fn reduce_vt(&self, a: u64) -> u64 {
        Self::reduce1_vt(self.lazy_reduce(a), self.p)
    }

    /// Optimized modular reduction of a u128 in constant time.
    pub const fn reduce_opt_u128(&self, a: u128) -> u64 {
        debug_assert!(self.supports_opt);
        Self::reduce1(self.lazy_reduce_opt_u128(a), self.p)
    }

    /// Optimized modular reduction of a u128 in variable time.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the value being reduced.
    pub(crate) const unsafe fn reduce_opt_u128_vt(&self, a: u128) -> u64 {
        debug_assert!(self.supports_opt);
        Self::reduce1_vt(self.lazy_reduce_opt_u128(a), self.p)
    }

    /// Optimized modular reduction of a u64 in constant time.
    pub const fn reduce_opt(&self, a: u64) -> u64 {
        Self::reduce1(self.lazy_reduce_opt(a), self.p)
    }

    /// Optimized modular reduction of a u64 in variable time.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the value being reduced.
    pub const unsafe fn reduce_opt_vt(&self, a: u64) -> u64 {
        Self::reduce1_vt(self.lazy_reduce_opt(a), self.p)
    }

    /// Return x mod p in constant time.
    /// Aborts if x >= 2 * p in debug mode.
    pub(crate) const fn reduce1(x: u64, p: u64) -> u64 {
        debug_assert!(p >> 63 == 0);
        debug_assert!(x < 2 * p);

        let r = const_time_cond_select(x, x.wrapping_sub(p), x < p);

        debug_assert!(r == x % p);

        r
    }

    /// Return x mod p in variable time.
    /// Aborts if x >= 2 * p in debug mode.
    ///
    /// # Safety
    /// This function is not constant time and its timing may reveal information
    /// about the value being reduced.
    #[cfg(any(target_os = "macos", target_feature = "avx2"))]
    pub(crate) const unsafe fn reduce1_vt(x: u64, p: u64) -> u64 {
        debug_assert!(p >> 63 == 0);
        debug_assert!(x < 2 * p);

        if x >= p {
            x - p
        } else {
            x
        }
    }

    #[cfg(all(not(target_os = "macos"), not(target_feature = "avx2")))]
    #[inline]
    pub(crate) const unsafe fn reduce1_vt(x: u64, p: u64) -> u64 {
        Self::reduce1(x, p)
    }

    /// Lazy modular reduction of a in constant time.
    /// The output is in the interval [0, 2 * p).
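    ///
    /// This is a Barrett reduction: `barrett_hi` and `barrett_lo` are the two
    /// 64-bit halves of `floor(2^128 / p)`, and the quotient estimate `q`
    /// computed below is within one of `floor(a / p)`, so `r = a - q * p`
    /// lies in `[0, 2 * p)`.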
    pub const fn lazy_reduce_u128(&self, a: u128) -> u64 {
        let a_lo = a as u64;
        let a_hi = (a >> 64) as u64;
        let p_lo_lo = ((a_lo as u128) * (self.barrett_lo as u128)) >> 64;
        let p_hi_lo = (a_hi as u128) * (self.barrett_lo as u128);
        let p_lo_hi = (a_lo as u128) * (self.barrett_hi as u128);

        let q = ((p_lo_hi + p_hi_lo + p_lo_lo) >> 64) + (a_hi as u128) * (self.barrett_hi as u128);
        let r = (a - q * (self.p as u128)) as u64;

        debug_assert!((r as u128) < 2 * (self.p as u128));
        debug_assert!(r % self.p == (a % (self.p as u128)) as u64);

        r
    }

    /// Lazy modular reduction of a in constant time.
    /// The output is in the interval [0, 2 * p).
    pub const fn lazy_reduce(&self, a: u64) -> u64 {
        let p_lo_lo = ((a as u128) * (self.barrett_lo as u128)) >> 64;
        let p_lo_hi = (a as u128) * (self.barrett_hi as u128);

        let q = (p_lo_hi + p_lo_lo) >> 64;
        let r = (a as u128 - q * (self.p as u128)) as u64;

        debug_assert!((r as u128) < 2 * (self.p as u128));
        debug_assert!(r % self.p == a % self.p);

        r
    }

    /// Lazy optimized modular reduction of a in constant time.
    /// The output is in the interval [0, 2 * p).
    ///
    /// Aborts if the input is >= p * p in debug mode.
    pub const fn lazy_reduce_opt_u128(&self, a: u128) -> u64 {
        debug_assert!(a < (self.p as u128) * (self.p as u128));

        let q = (((self.barrett_lo as u128) * (a >> 64)) + (a << self.leading_zeros)) >> 64;
        let r = (a - q * (self.p as u128)) as u64;

        debug_assert!((r as u128) < 2 * (self.p as u128));
        debug_assert!(r % self.p == (a % (self.p as u128)) as u64);

        r
    }

    /// Lazy optimized modular reduction of a in constant time.
    /// The output is in the interval [0, 2 * p).
    const fn lazy_reduce_opt(&self, a: u64) -> u64 {
        let q = a >> (64 - self.leading_zeros);
        let r = ((a as u128) - (q as u128) * (self.p as u128)) as u64;

        debug_assert!((r as u128) < 2 * (self.p as u128));
        debug_assert!(r % self.p == a % self.p);

        r
    }

    /// Lazy modular reduction of a vector in constant time.
    /// The output coefficients are in the interval [0, 2 * p).
    pub fn lazy_reduce_vec(&self, a: &mut [u64]) {
        if self.supports_opt {
            a.iter_mut().for_each(|ai| *ai = self.lazy_reduce_opt(*ai))
        } else {
            a.iter_mut().for_each(|ai| *ai = self.lazy_reduce(*ai))
        }
    }

    /// Returns a random vector.
    pub fn random_vec<R: RngCore + CryptoRng>(&self, size: usize, rng: &mut R) -> Vec<u64> {
        rng.sample_iter(self.distribution).take(size).collect_vec()
    }

    /// Length of the serialization of a vector of size `size`.
    ///
    /// Panics if the size is not a multiple of 8.
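    ///
    /// For example, with a 62-bit modulus each element uses 62 bits, so a
    /// vector of 8 elements serializes to `62 * 8 / 8 = 62` bytes.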
    pub const fn serialization_length(&self, size: usize) -> usize {
        assert!(size % 8 == 0);
        let p_nbits = 64 - (self.p - 1).leading_zeros() as usize;
        p_nbits * size / 8
    }

    /// Serialize a vector of elements of length a multiple of 8.
    ///
    /// Panics if the length of the vector is not a multiple of 8.
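    ///
    /// # Examples
    ///
    /// Round-trip sketch (marked `ignore` since it assumes the crate exposes
    /// this module as `fhe_math::zq`):
    ///
    /// ```ignore
    /// use fhe_math::zq::Modulus;
    ///
    /// let p = Modulus::new(4611686018326724609).unwrap();
    /// let a = vec![1u64, 2, 3, 4, 5, 6, 7, 8];
    /// let bytes = p.serialize_vec(&a);
    /// assert_eq!(p.deserialize_vec(&bytes), a);
    /// ```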
    pub fn serialize_vec(&self, a: &[u64]) -> Vec<u8> {
        let p_nbits = 64 - (self.p - 1).leading_zeros() as usize;
        transcode_to_bytes(a, p_nbits)
    }

    /// Deserialize a vector of bytes into a vector of elements mod p.
    pub fn deserialize_vec(&self, b: &[u8]) -> Vec<u64> {
        let p_nbits = 64 - (self.p - 1).leading_zeros() as usize;
        transcode_from_bytes(b, p_nbits)
    }
}

#[cfg(test)]
mod tests {
    use super::{primes, Modulus};
    use itertools::{izip, Itertools};
    use proptest::collection::vec as prop_vec;
    use proptest::prelude::{
        any, prop_assert, prop_assert_eq, prop_assert_ne, prop_assume, proptest, BoxedStrategy,
        Just, Strategy,
    };
    use rand::{rng, RngCore};

    // Utility functions for the proptests.

    fn valid_moduli() -> impl Strategy<Value = Modulus> {
        any::<u64>().prop_filter_map("filter invalid moduli", |p| Modulus::new(p).ok())
    }

    fn vecs() -> BoxedStrategy<(Vec<u64>, Vec<u64>)> {
        prop_vec(any::<u64>(), 1..100)
            .prop_flat_map(|vec| {
                let len = vec.len();
                (Just(vec), prop_vec(any::<u64>(), len))
            })
            .boxed()
    }

    proptest! {
        #[test]
        fn constructor(p: u64) {
            // 63 and 64-bit integers do not work.
            prop_assert!(Modulus::new(p | (1u64 << 62)).is_err());
            prop_assert!(Modulus::new(p | (1u64 << 63)).is_err());

            // p = 0 and p = 1 do not work.
            prop_assert!(Modulus::new(0u64).is_err());
            prop_assert!(Modulus::new(1u64).is_err());

            // Otherwise, all moduli should work.
            prop_assume!(p >> 2 >= 2);
            let q = Modulus::new(p >> 2);
            prop_assert!(q.is_ok());
            prop_assert_eq!(*q.unwrap(), p >> 2);
        }

        #[test]
        fn neg(p in valid_moduli(), mut a: u64) {
            a = p.reduce(a);
            prop_assert_eq!(p.neg(a), (*p - a) % *p);
            unsafe { prop_assert_eq!(p.neg_vt(a), (*p - a) % *p) }

            #[cfg(debug_assertions)]
            {
                prop_assert!(std::panic::catch_unwind(|| p.neg(*p)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.neg(*p + 1)).is_err());
            }
        }

        #[test]
        fn add(p in valid_moduli(), mut a: u64, mut b: u64) {
            a = p.reduce(a);
            b = p.reduce(b);
            prop_assert_eq!(p.add(a, b), (a + b) % *p);
            unsafe { prop_assert_eq!(p.add_vt(a, b), (a + b) % *p) }

            #[cfg(debug_assertions)]
            {
                prop_assert!(std::panic::catch_unwind(|| p.add(*p, a)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.add(a, *p)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.add(*p + 1, a)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.add(a, *p + 1)).is_err());
            }
        }

        #[test]
        fn sub(p in valid_moduli(), mut a: u64, mut b: u64) {
            a = p.reduce(a);
            b = p.reduce(b);
            prop_assert_eq!(p.sub(a, b), (a + *p - b) % *p);
            unsafe { prop_assert_eq!(p.sub_vt(a, b), (a + *p - b) % *p) }

            #[cfg(debug_assertions)]
            {
                prop_assert!(std::panic::catch_unwind(|| p.sub(*p, a)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.sub(a, *p)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.sub(*p + 1, a)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.sub(a, *p + 1)).is_err());
            }
        }

        #[test]
        fn mul(p in valid_moduli(), mut a: u64, mut b: u64) {
            a = p.reduce(a);
            b = p.reduce(b);
            prop_assert_eq!(p.mul(a, b) as u128, ((a as u128) * (b as u128)) % (*p as u128));
            unsafe { prop_assert_eq!(p.mul_vt(a, b) as u128, ((a as u128) * (b as u128)) % (*p as u128)) }

            #[cfg(debug_assertions)]
            {
                prop_assert!(std::panic::catch_unwind(|| p.mul(*p, a)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.mul(a, *p)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.mul(*p + 1, a)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.mul(a, *p + 1)).is_err());
            }
        }

        #[test]
        fn mul_shoup(p in valid_moduli(), mut a: u64, mut b: u64) {
            a = p.reduce(a);
            b = p.reduce(b);

            // Compute shoup representation
            let b_shoup = p.shoup(b);

            #[cfg(debug_assertions)]
            {
                prop_assert!(std::panic::catch_unwind(|| p.shoup(*p)).is_err());
                prop_assert!(std::panic::catch_unwind(|| p.shoup(*p + 1)).is_err());
            }

            // Check that the multiplication yields the expected result
            prop_assert_eq!(p.mul_shoup(a, b, b_shoup) as u128, ((a as u128) * (b as u128)) % (*p as u128));
            unsafe { prop_assert_eq!(p.mul_shoup_vt(a, b, b_shoup) as u128, ((a as u128) * (b as u128)) % (*p as u128)) }

            // Check that the multiplication with incorrect b_shoup panics in debug mode
            #[cfg(debug_assertions)]
            {
                prop_assert!(std::panic::catch_unwind(|| p.mul_shoup(a, *p, b_shoup)).is_err());
                prop_assume!(a != b);
                prop_assert!(std::panic::catch_unwind(|| p.mul_shoup(a, a, b_shoup)).is_err());
            }
        }

        #[test]
        fn reduce(p in valid_moduli(), a: u64) {
            prop_assert_eq!(p.reduce(a), a % *p);
            unsafe { prop_assert_eq!(p.reduce_vt(a), a % *p) }
            if p.supports_opt {
                prop_assert_eq!(p.reduce_opt(a), a % *p);
                unsafe { prop_assert_eq!(p.reduce_opt_vt(a), a % *p) }
            }
        }

        #[test]
        fn lazy_reduce(p in valid_moduli(), a: u64) {
            prop_assert!(p.lazy_reduce(a) < 2 * *p);
            prop_assert_eq!(p.lazy_reduce(a) % *p, p.reduce(a));
        }

        #[test]
        fn reduce_i64(p in valid_moduli(), a: i64) {
            let b = if a < 0 { p.neg(p.reduce(-a as u64)) } else { p.reduce(a as u64) };
            prop_assert_eq!(p.reduce_i64(a), b);
            unsafe { prop_assert_eq!(p.reduce_i64_vt(a), b) }
        }

        #[test]
        fn reduce_u128(p in valid_moduli(), mut a: u128) {
            prop_assert_eq!(p.reduce_u128(a) as u128, a % (*p as u128));
            unsafe { prop_assert_eq!(p.reduce_u128_vt(a) as u128, a % (*p as u128)) }
            if p.supports_opt {
                let p_square = (*p as u128) * (*p as u128);
                a %= p_square;
                prop_assert_eq!(p.reduce_opt_u128(a) as u128, a % (*p as u128));
                unsafe { prop_assert_eq!(p.reduce_opt_u128_vt(a) as u128, a % (*p as u128)) }
            }
        }

        #[test]
        fn add_vec(p in valid_moduli(), (mut a, mut b) in vecs()) {
            p.reduce_vec(&mut a);
            p.reduce_vec(&mut b);
            let c = a.clone();
            p.add_vec(&mut a, &b);
            prop_assert_eq!(a.clone(), izip!(b.iter(), c.iter()).map(|(bi, ci)| p.add(*bi, *ci)).collect_vec());
            a.clone_from(&c);
            unsafe { p.add_vec_vt(&mut a, &b) }
            prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.add(*bi, *ci)).collect_vec());
        }

        #[test]
        fn sub_vec(p in valid_moduli(), (mut a, mut b) in vecs()) {
            p.reduce_vec(&mut a);
            p.reduce_vec(&mut b);
            let c = a.clone();
            p.sub_vec(&mut a, &b);
            prop_assert_eq!(a.clone(), izip!(b.iter(), c.iter()).map(|(bi, ci)| p.sub(*ci, *bi)).collect_vec());
            a.clone_from(&c);
            unsafe { p.sub_vec_vt(&mut a, &b) }
            prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.sub(*ci, *bi)).collect_vec());
        }

        #[test]
        fn mul_vec(p in valid_moduli(), (mut a, mut b) in vecs()) {
            p.reduce_vec(&mut a);
            p.reduce_vec(&mut b);
            let c = a.clone();
            p.mul_vec(&mut a, &b);
            prop_assert_eq!(a.clone(), izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec());
            a.clone_from(&c);
            unsafe { p.mul_vec_vt(&mut a, &b); }
            prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec());
        }

        #[test]
        fn scalar_mul_vec(p in valid_moduli(), mut a: Vec<u64>, mut b: u64) {
            p.reduce_vec(&mut a);
            b = p.reduce(b);
            let c = a.clone();

            p.scalar_mul_vec(&mut a, b);
            prop_assert_eq!(a.clone(), c.iter().map(|ci| p.mul(*ci, b)).collect_vec());

            a.clone_from(&c);
            unsafe { p.scalar_mul_vec_vt(&mut a, b) }
            prop_assert_eq!(a, c.iter().map(|ci| p.mul(*ci, b)).collect_vec());
        }

        #[test]
        fn mul_shoup_vec(p in valid_moduli(), (mut a, mut b) in vecs()) {
            p.reduce_vec(&mut a);
            p.reduce_vec(&mut b);
            let b_shoup = p.shoup_vec(&b);
            let c = a.clone();
            p.mul_shoup_vec(&mut a, &b, &b_shoup);
            prop_assert_eq!(a.clone(), izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec());
            a.clone_from(&c);
            unsafe { p.mul_shoup_vec_vt(&mut a, &b, &b_shoup) }
            prop_assert_eq!(a, izip!(b.iter(), c.iter()).map(|(bi, ci)| p.mul(*ci, *bi)).collect_vec());
        }

        #[test]
        fn reduce_vec(p in valid_moduli(), a: Vec<u64>) {
            let mut b = a.clone();
            p.reduce_vec(&mut b);
            prop_assert_eq!(b.clone(), a.iter().map(|ai| p.reduce(*ai)).collect_vec());

            b.clone_from(&a);
            unsafe { p.reduce_vec_vt(&mut b) }
            prop_assert_eq!(b, a.iter().map(|ai| p.reduce(*ai)).collect_vec());
        }

        #[test]
        fn lazy_reduce_vec(p in valid_moduli(), a: Vec<u64>) {
            let mut b = a.clone();
            p.lazy_reduce_vec(&mut b);
            prop_assert!(b.iter().all(|bi| *bi < 2 * *p));
            prop_assert!(izip!(a, b).all(|(ai, bi)| bi % *p == ai % *p));
        }

        #[test]
        fn reduce_vec_new(p in valid_moduli(), a: Vec<u64>) {
            let b = p.reduce_vec_new(&a);
            prop_assert_eq!(b, a.iter().map(|ai| p.reduce(*ai)).collect_vec());
            prop_assert_eq!(p.reduce_vec_new(&a), unsafe { p.reduce_vec_new_vt(&a) });
        }

        #[test]
        fn reduce_vec_i64(p in valid_moduli(), a: Vec<i64>) {
            let b = p.reduce_vec_i64(&a);
            prop_assert_eq!(b, a.iter().map(|ai| p.reduce_i64(*ai)).collect_vec());
            let b = unsafe { p.reduce_vec_i64_vt(&a) };
            prop_assert_eq!(b, a.iter().map(|ai| p.reduce_i64(*ai)).collect_vec());
        }

        #[test]
        fn neg_vec(p in valid_moduli(), mut a: Vec<u64>) {
            p.reduce_vec(&mut a);
            let mut b = a.clone();
            p.neg_vec(&mut b);
            prop_assert_eq!(b.clone(), a.iter().map(|ai| p.neg(*ai)).collect_vec());
            b.clone_from(&a);
            unsafe { p.neg_vec_vt(&mut b); }
            prop_assert_eq!(b, a.iter().map(|ai| p.neg(*ai)).collect_vec());
        }

        #[test]
        fn random_vec(p in valid_moduli(), size in 1..1000usize) {
            let mut rng = rng();

            let v = p.random_vec(size, &mut rng);
            prop_assert_eq!(v.len(), size);

            let w = p.random_vec(size, &mut rng);
            prop_assert_eq!(w.len(), size);

            if (*p).leading_zeros() <= 30 {
                prop_assert_ne!(v, w); // Fails with probability at most 2^(-33), since p >= 2^33.
            }
        }

        #[test]
        fn serialize(p in valid_moduli(), mut a in prop_vec(any::<u64>(), 8)) {
            p.reduce_vec(&mut a);
            let b = p.serialize_vec(&a);
            let c = p.deserialize_vec(&b);
            prop_assert_eq!(a, c);
        }
    }

    // TODO: Make a proptest.
    #[test]
    fn mul_opt() {
        let ntests = 100;
        let mut rng = rand::rng();

        #[allow(clippy::single_element_loop)]
        for p in [4611686018326724609] {
            let q = Modulus::new(p).unwrap();
            assert!(primes::supports_opt(p));

            assert_eq!(q.mul_opt(0, 1), 0);
            assert_eq!(q.mul_opt(1, 1), 1);
            assert_eq!(q.mul_opt(2 % p, 3 % p), 6 % p);
            assert_eq!(q.mul_opt(p - 1, 1), p - 1);
            assert_eq!(q.mul_opt(p - 1, 2 % p), p - 2);

            #[cfg(debug_assertions)]
            {
                assert!(std::panic::catch_unwind(|| q.mul_opt(p, 1)).is_err());
                assert!(std::panic::catch_unwind(|| q.mul_opt(p << 1, 1)).is_err());
                assert!(std::panic::catch_unwind(|| q.mul_opt(0, p)).is_err());
                assert!(std::panic::catch_unwind(|| q.mul_opt(0, p << 1)).is_err());
            }

            for _ in 0..ntests {
                let a = rng.next_u64() % p;
                let b = rng.next_u64() % p;
                assert_eq!(
                    q.mul_opt(a, b),
                    (((a as u128) * (b as u128)) % (p as u128)) as u64
                );
            }
        }
    }

    // TODO: Make a proptest.
    #[test]
    fn pow() {
        let ntests = 10;
        let mut rng = rand::rng();

        for p in [2u64, 3, 17, 1987, 4611686018326724609] {
            let q = Modulus::new(p).unwrap();

            assert_eq!(q.pow(p - 1, 0), 1);
            assert_eq!(q.pow(p - 1, 1), p - 1);
            assert_eq!(q.pow(p - 1, 2 % p), 1);
            assert_eq!(q.pow(1, p - 2), 1);
            assert_eq!(q.pow(1, p - 1), 1);

            #[cfg(debug_assertions)]
            {
                assert!(std::panic::catch_unwind(|| q.pow(p, 1)).is_err());
                assert!(std::panic::catch_unwind(|| q.pow(p << 1, 1)).is_err());
                assert!(std::panic::catch_unwind(|| q.pow(0, p)).is_err());
                assert!(std::panic::catch_unwind(|| q.pow(0, p << 1)).is_err());
            }

            for _ in 0..ntests {
                let a = rng.next_u64() % p;
                let b = (rng.next_u64() % p) % 1000;
                let mut c = b;
                let mut r = 1;
                while c > 0 {
                    r = q.mul(r, a);
                    c -= 1;
                }
                assert_eq!(q.pow(a, b), r);
            }
        }
    }

    // TODO: Make a proptest.
    #[test]
    fn inv() {
        let ntests = 100;
        let mut rng = rand::rng();

        for p in [2u64, 3, 17, 1987, 4611686018326724609] {
            let q = Modulus::new(p).unwrap();

            assert!(q.inv(0).is_none());
            assert_eq!(q.inv(1).unwrap(), 1);
            assert_eq!(q.inv(p - 1).unwrap(), p - 1);

            #[cfg(debug_assertions)]
            {
                assert!(std::panic::catch_unwind(|| q.inv(p)).is_err());
                assert!(std::panic::catch_unwind(|| q.inv(p << 1)).is_err());
            }

            for _ in 0..ntests {
                let a = rng.next_u64() % p;
                let b = q.inv(a);

                if a == 0 {
                    assert!(b.is_none())
                } else {
                    assert!(b.is_some());
                    assert_eq!(q.mul(a, b.unwrap()), 1)
                }
            }
        }
    }
}