tfhe_ntt/lib.rs

//! tfhe-ntt is a pure-Rust, high-performance number theoretic transform library that processes
//! vectors whose sizes are powers of two.
//!
//! This library provides three kinds of NTT:
//! - The prime NTT computes the transform in a field $\mathbb{Z}/p\mathbb{Z}$ with $p$ prime,
//!   allowing for arithmetic operations on the polynomial modulo $p$.
//! - The native NTT internally computes prime NTTs for several primes, simulating arithmetic
//!   modulo the product of those primes, and truncates the result when the inverse transform is
//!   requested. The truncated result is guaranteed to match the computation performed with
//!   wrapping arithmetic, as long as the exact integer result is smaller than half the product of
//!   the primes in absolute value. In particular, it can be used to multiply two polynomials with
//!   arbitrary coefficients, returning the result in wrapping arithmetic.
//! - The native binary NTT is similar to the native NTT, but is optimized for the case where one
//!   of the operands of the multiplication has coefficients in $\lbrace 0, 1 \rbrace$.
//!
//! # Features
//!
//! - `std` (default): This enables runtime arch detection for accelerated SIMD instructions.
//! - `nightly`: This enables unstable Rust features to further speed up the NTT, by enabling AVX512
//!   instructions on CPUs that support them. This feature requires a nightly Rust toolchain.
//!
//! # Example
//!
//! ```
//! use tfhe_ntt::prime32::Plan;
//!
//! const N: usize = 32;
//! let p = 1062862849;
//! let plan = Plan::try_new(N, p).unwrap();
//!
//! let data = [
//!     0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24,
//!     25, 26, 27, 28, 29, 30, 31,
//! ];
//!
//! let mut transformed_fwd = data;
//! plan.fwd(&mut transformed_fwd);
//!
//! let mut transformed_inv = transformed_fwd;
//! plan.inv(&mut transformed_inv);
//!
//! for (&actual, expected) in transformed_inv
//!     .iter()
//!     .zip(data.iter().map(|x| x * N as u32))
//! {
//!     assert_eq!(expected, actual);
//! }
//! ```
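//!
//! The negacyclic convolution of two polynomials modulo `X^N + 1` can be computed by transforming
//! both operands, multiplying pointwise modulo `p`, and applying the inverse transform. A sketch
//! of this usage (it assumes, consistently with the example above, that the transform is the
//! negacyclic one described in the module docs and that the inverse transform returns the result
//! scaled by `N`):
//!
//! ```
//! use tfhe_ntt::prime32::Plan;
//!
//! const N: usize = 32;
//! let p = 1062862849;
//! let plan = Plan::try_new(N, p).unwrap();
//!
//! let lhs = [1u32; N];
//! let rhs = [2u32; N];
//!
//! // Schoolbook negacyclic convolution modulo p, as a reference: coefficients that wrap
//! // around X^N are subtracted, since X^N = -1.
//! let mut expected = [0u32; N];
//! for i in 0..N {
//!     for j in 0..N {
//!         let prod = (lhs[i] as u64 * rhs[j] as u64 % p as u64) as u32;
//!         if i + j < N {
//!             expected[i + j] = ((expected[i + j] as u64 + prod as u64) % p as u64) as u32;
//!         } else {
//!             expected[i + j - N] =
//!                 ((expected[i + j - N] as u64 + (p - prod) as u64) % p as u64) as u32;
//!         }
//!     }
//! }
//!
//! let mut lhs_fourier = lhs;
//! let mut rhs_fourier = rhs;
//! plan.fwd(&mut lhs_fourier);
//! plan.fwd(&mut rhs_fourier);
//!
//! let mut product = [0u32; N];
//! for i in 0..N {
//!     product[i] = (lhs_fourier[i] as u64 * rhs_fourier[i] as u64 % p as u64) as u32;
//! }
//! plan.inv(&mut product);
//!
//! for i in 0..N {
//!     assert_eq!(product[i], (expected[i] as u64 * N as u64 % p as u64) as u32);
//! }
//! ```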

#![cfg_attr(not(feature = "std"), no_std)]
#![allow(clippy::too_many_arguments, clippy::let_unit_value)]
#![cfg_attr(docsrs, feature(doc_cfg))]

/// Implementation notes:
///
/// We use `NullaryFnOnce` instead of a closure because we need the `#[inline(always)]`
/// annotation, which doesn't always work with closures for some reason.
///
/// Shoup modular multiplication
/// <https://pdfs.semanticscholar.org/e000/fa109f1b2a6a3e52e04462bac4b7d58140c9.pdf>
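///
/// A minimal scalar sketch of the Shoup multiplication by a constant `w` modulo `p`, assuming
/// the precomputed value `w_shoup = floor(w * 2^64 / p)` (names are illustrative; the vectorized
/// butterflies in this crate follow the same pattern):
/// ```ignore
/// fn mul_mod_shoup(p: u64, w: u64, w_shoup: u64, x: u64) -> u64 {
///     let q = ((w_shoup as u128 * x as u128) >> 64) as u64; // approximation of w * x / p
///     let r = w.wrapping_mul(x).wrapping_sub(q.wrapping_mul(p)); // w * x - q * p, in [0, 2p)
///     if r >= p {
///         r - p
///     } else {
///         r
///     }
/// }
/// ```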
///
/// Lemire modular reduction
/// <https://lemire.me/blog/2019/02/08/faster-remainders-when-the-divisor-is-a-constant-beating-compilers-and-libdivide/>
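///
/// A sketch of the Lemire remainder computation for a constant 32-bit divisor `d` (names are
/// illustrative; see the blog post above):
/// ```ignore
/// let c: u64 = u64::MAX / d as u64 + 1; // ceil(2^64 / d) when d is not a power of two
/// let lowbits = c.wrapping_mul(n as u64); // fractional part of n / d, scaled by 2^64
/// let n_mod_d = ((lowbits as u128 * d as u128) >> 64) as u32;
/// ```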
///
/// Barrett reduction
/// <https://arxiv.org/pdf/2103.16400.pdf> Algorithm 8
///
/// Chinese remainder theorem solution:
/// The Art of Computer Programming (Donald E. Knuth), section 4.3.2
///
/// Implementation notes on the single Barrett reduction code can be found at
/// <https://github.com/zama-ai/tfhe-rs/blob/main/implementation_notes/tfhe-ntt/gh_issue_2037_barrett_range.md>
/// or from the repo root at: implementation_notes/tfhe-ntt/gh_issue_2037_barrett_range.md
#[allow(dead_code)]
fn implementation_notes() {}

use u256_impl::u256;

#[allow(unused_imports)]
use pulp::*;

#[doc(hidden)]
pub mod prime;
mod roots;
mod u256_impl;

/// Fast division by a constant divisor.
pub mod fastdiv;
/// 32-bit negacyclic NTT for a prime modulus.
pub mod prime32;
/// 64-bit negacyclic NTT for a prime modulus.
pub mod prime64;

/// Negacyclic NTT for multiplying two polynomials with values less than `2^128`.
pub mod native128;
/// Negacyclic NTT for multiplying two polynomials with values less than `2^32`.
pub mod native32;
/// Negacyclic NTT for multiplying two polynomials with values less than `2^64`.
pub mod native64;

/// Negacyclic NTT for multiplying a polynomial with values less than `2^128` with a binary
/// polynomial.
pub mod native_binary128;
/// Negacyclic NTT for multiplying a polynomial with values less than `2^32` with a binary
/// polynomial.
pub mod native_binary32;
/// Negacyclic NTT for multiplying a polynomial with values less than `2^64` with a binary
/// polynomial.
pub mod native_binary64;

pub mod product;

// Fn arguments are (simd, z0, z1, w, w_shoup, p, neg_p, two_p).
trait Butterfly<S: Copy, V: Copy>: Copy + Fn(S, V, V, V, V, V, V, V) -> (V, V) {}
impl<F: Copy + Fn(S, V, V, V, V, V, V, V) -> (V, V), S: Copy, V: Copy> Butterfly<S, V> for F {}

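/// Returns `i` with its lowest `nbits` bits reversed.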
#[inline]
fn bit_rev(nbits: u32, i: usize) -> usize {
    i.reverse_bits() >> (usize::BITS - nbits)
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
struct V3(pulp::x86::V3);

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
struct V4(pulp::x86::V4);

#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
pulp::simd_type! {
    struct V4IFma {
        pub sse: "sse",
        pub sse2: "sse2",
        pub fxsr: "fxsr",
        pub sse3: "sse3",
        pub ssse3: "ssse3",
        pub sse4_1: "sse4.1",
        pub sse4_2: "sse4.2",
        pub popcnt: "popcnt",
        pub avx: "avx",
        pub avx2: "avx2",
        pub bmi1: "bmi1",
        pub bmi2: "bmi2",
        pub fma: "fma",
        pub lzcnt: "lzcnt",
        pub avx512f: "avx512f",
        pub avx512bw: "avx512bw",
        pub avx512cd: "avx512cd",
        pub avx512dq: "avx512dq",
        pub avx512vl: "avx512vl",
        pub avx512ifma: "avx512ifma",
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl V4 {
    #[inline]
    pub fn try_new() -> Option<Self> {
        pulp::x86::V4::try_new().map(Self)
    }

    /// Returns separately two vectors containing the low 64 bits of the result,
    /// and the high 64 bits of the result.
    #[inline(always)]
    pub fn widening_mul_u64x8(self, a: u64x8, b: u64x8) -> (u64x8, u64x8) {
        // https://stackoverflow.com/a/28827013
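        // Split each 64-bit lane as x = 2^32 * x_hi + x_lo and y = 2^32 * y_hi + y_lo; then
        // x * y = 2^64 * z_hi_hi + 2^32 * (z_lo_hi + z_hi_lo) + z_lo_lo, assembled below from
        // 32x32 -> 64-bit multiplies with explicit carry propagation.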
        let avx = self.avx512f;
        let x = cast(a);
        let y = cast(b);

        let lo_mask = avx._mm512_set1_epi64(0x0000_0000_FFFF_FFFFu64 as _);
        let x_hi = avx._mm512_shuffle_epi32::<0b1011_0001>(x);
        let y_hi = avx._mm512_shuffle_epi32::<0b1011_0001>(y);

        let z_lo_lo = avx._mm512_mul_epu32(x, y);
        let z_lo_hi = avx._mm512_mul_epu32(x, y_hi);
        let z_hi_lo = avx._mm512_mul_epu32(x_hi, y);
        let z_hi_hi = avx._mm512_mul_epu32(x_hi, y_hi);

        let z_lo_lo_shift = avx._mm512_srli_epi64::<32>(z_lo_lo);

        let sum_tmp = avx._mm512_add_epi64(z_lo_hi, z_lo_lo_shift);
        let sum_lo = avx._mm512_and_si512(sum_tmp, lo_mask);
        let sum_mid = avx._mm512_srli_epi64::<32>(sum_tmp);

        let sum_mid2 = avx._mm512_add_epi64(z_hi_lo, sum_lo);
        let sum_mid2_hi = avx._mm512_srli_epi64::<32>(sum_mid2);
        let sum_hi = avx._mm512_add_epi64(z_hi_hi, sum_mid);

        let prod_hi = avx._mm512_add_epi64(sum_hi, sum_mid2_hi);
        let prod_lo = avx._mm512_add_epi64(
            avx._mm512_slli_epi64::<32>(avx._mm512_add_epi64(z_lo_hi, z_hi_lo)),
            z_lo_lo,
        );

        (cast(prod_lo), cast(prod_hi))
    }

    /// Multiplies the low 32 bits of each 64 bit integer and returns the 64 bit result.
    #[inline(always)]
    pub fn mul_low_32_bits_u64x8(self, a: u64x8, b: u64x8) -> u64x8 {
        pulp::cast(self.avx512f._mm512_mul_epu32(pulp::cast(a), pulp::cast(b)))
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl V4IFma {
    /// Returns separately two vectors containing the low 52 bits of the result,
    /// and the high 52 bits of the result.
    #[inline(always)]
    pub fn widening_mul_u52x8(self, a: u64x8, b: u64x8) -> (u64x8, u64x8) {
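        // _mm512_madd52lo_epu64 / _mm512_madd52hi_epu64 multiply the low 52 bits of each lane
        // and add the low (resp. high) 52 bits of the 104-bit product to the accumulator; with a
        // zero accumulator this yields the plain widening 52-bit product.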
        let a = cast(a);
        let b = cast(b);
        let zero = cast(self.splat_u64x8(0));
        (
            cast(self.avx512ifma._mm512_madd52lo_epu64(zero, a, b)),
            cast(self.avx512ifma._mm512_madd52hi_epu64(zero, a, b)),
        )
    }

    /// (a * b + c) mod 2^52 for each 52 bit integer in a, b, and c.
    #[inline(always)]
    pub fn wrapping_mul_add_u52x8(self, a: u64x8, b: u64x8, c: u64x8) -> u64x8 {
        self.and_u64x8(
            cast(
                self.avx512ifma
                    ._mm512_madd52lo_epu64(cast(c), cast(a), cast(b)),
            ),
            self.splat_u64x8((1u64 << 52) - 1),
        )
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
trait SupersetOfV4: Copy {
    fn get_v4(self) -> V4;
    fn vectorize(self, f: impl pulp::NullaryFnOnce);
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl SupersetOfV4 for V4 {
    #[inline(always)]
    fn get_v4(self) -> V4 {
        self
    }
    #[inline(always)]
    fn vectorize(self, f: impl pulp::NullaryFnOnce) {
        self.0.vectorize(f);
    }
}
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl SupersetOfV4 for V4IFma {
    #[inline(always)]
    fn get_v4(self) -> V4 {
        *self
    }
    #[inline(always)]
    fn vectorize(self, f: impl pulp::NullaryFnOnce) {
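        // Inherent methods take precedence over trait methods, so this resolves to the inherent
        // `vectorize` generated by `pulp::simd_type!`, not to a recursive call of this method.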
        self.vectorize(f);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
impl V3 {
    #[inline]
    pub fn try_new() -> Option<Self> {
        pulp::x86::V3::try_new().map(Self)
    }

    /// Returns separately two vectors containing the low 64 bits of the result,
    /// and the high 64 bits of the result.
    #[inline(always)]
    pub fn widening_mul_u64x4(self, a: u64x4, b: u64x4) -> (u64x4, u64x4) {
        // https://stackoverflow.com/a/28827013
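        // Same 32-bit limb decomposition as `widening_mul_u64x8` above, on 256-bit AVX2 vectors.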
        let avx = self.avx;
        let avx2 = self.avx2;
        let x = cast(a);
        let y = cast(b);
        let lo_mask = avx._mm256_set1_epi64x(0x0000_0000_FFFF_FFFFu64 as _);
        let x_hi = avx2._mm256_shuffle_epi32::<0b10110001>(x);
        let y_hi = avx2._mm256_shuffle_epi32::<0b10110001>(y);

        let z_lo_lo = avx2._mm256_mul_epu32(x, y);
        let z_lo_hi = avx2._mm256_mul_epu32(x, y_hi);
        let z_hi_lo = avx2._mm256_mul_epu32(x_hi, y);
        let z_hi_hi = avx2._mm256_mul_epu32(x_hi, y_hi);

        let z_lo_lo_shift = avx2._mm256_srli_epi64::<32>(z_lo_lo);

        let sum_tmp = avx2._mm256_add_epi64(z_lo_hi, z_lo_lo_shift);
        let sum_lo = avx2._mm256_and_si256(sum_tmp, lo_mask);
        let sum_mid = avx2._mm256_srli_epi64::<32>(sum_tmp);

        let sum_mid2 = avx2._mm256_add_epi64(z_hi_lo, sum_lo);
        let sum_mid2_hi = avx2._mm256_srli_epi64::<32>(sum_mid2);
        let sum_hi = avx2._mm256_add_epi64(z_hi_hi, sum_mid);

        let prod_hi = avx2._mm256_add_epi64(sum_hi, sum_mid2_hi);
        let prod_lo = avx2._mm256_add_epi64(
            avx2._mm256_slli_epi64::<32>(avx2._mm256_add_epi64(z_lo_hi, z_hi_lo)),
            z_lo_lo,
        );

        (cast(prod_lo), cast(prod_hi))
    }

    /// Multiplies the low 32 bits of each 64 bit integer and returns the 64 bit result.
    #[inline(always)]
    pub fn mul_low_32_bits_u64x4(self, a: u64x4, b: u64x4) -> u64x4 {
        pulp::cast(self.avx2._mm256_mul_epu32(pulp::cast(a), pulp::cast(b)))
    }

    // a * (b mod 2^32) mod 2^64 for each element in a and b.
    #[inline(always)]
    pub fn wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(self, a: u64x4, b: u64x4) -> u64x4 {
        let a = cast(a);
        let b = cast(b);
        let avx2 = self.avx2;
        let x_hi = avx2._mm256_shuffle_epi32::<0b10110001>(a);
        let z_lo_lo = avx2._mm256_mul_epu32(a, b);
        let z_hi_lo = avx2._mm256_mul_epu32(x_hi, b);
        cast(avx2._mm256_add_epi64(avx2._mm256_slli_epi64::<32>(z_hi_lo), z_lo_lo))
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl core::ops::Deref for V4 {
    type Target = pulp::x86::V4;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl core::ops::Deref for V4IFma {
    type Target = V4;

    #[inline]
    fn deref(&self) -> &Self::Target {
        let Self {
            sse,
            sse2,
            fxsr,
            sse3,
            ssse3,
            sse4_1,
            sse4_2,
            popcnt,
            avx,
            avx2,
            bmi1,
            bmi2,
            fma,
            lzcnt,
            avx512f,
            avx512bw,
            avx512cd,
            avx512dq,
            avx512vl,
            avx512ifma: _,
        } = *self;
        let simd_ref = (pulp::x86::V4 {
            sse,
            sse2,
            fxsr,
            sse3,
            ssse3,
            sse4_1,
            sse4_2,
            popcnt,
            avx,
            avx2,
            bmi1,
            bmi2,
            fma,
            lzcnt,
            avx512f,
            avx512bw,
            avx512cd,
            avx512dq,
            avx512vl,
        })
        .to_ref();

        // SAFETY
        // `pulp::x86::V4` and `crate::V4` have the same layout, since the latter is
        // #[repr(transparent)].
        unsafe { &*(simd_ref as *const pulp::x86::V4 as *const V4) }
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
impl core::ops::Deref for V3 {
    type Target = pulp::x86::V3;

    #[inline]
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

// The magic constants are such that, for all x < 2^64,
// x / P_i == ((x * P_i_MAGIC) >> 64) >> P_i_MAGIC_SHIFT
//
// This can be used to implement the modulo operation in constant time, to avoid side channel
// attacks; it can also speed up the operation x % P_i, since the compiler doesn't manage to
// vectorize it on its own.
//
// How to regenerate them:
// run `cargo test generate_primes -- --nocapture`
//
// copy-paste the generated primes into this function
// ```
// pub fn codegen(x: u64) -> u64 {
//     x / $PRIME
// }
// ```
//
// look at the generated assembly for `codegen`
// and extract the primes that satisfy the desired property
//
// The asm should look like this on x86_64:
// ```
// mov rax, rdi
// movabs rcx, P_MAGIC (as i64 signed value)
// mul rcx
// mov rax, rdx
// shr rax, P_MAGIC_SHIFT
// ret
// ```
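//
// A quick sanity check of the property, using the P0/P0_MAGIC/P0_MAGIC_SHIFT triple defined
// below:
// ```
// let x = 0x0123_4567_89ab_cdefu64;
// let hi = ((x as u128 * primes32::P0_MAGIC as u128) >> 64) as u64;
// assert_eq!(hi >> primes32::P0_MAGIC_SHIFT, x / primes32::P0 as u64);
// ```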
#[allow(dead_code)]
pub(crate) mod primes32 {
    use crate::{
        fastdiv::{Div32, Div64},
        prime::exp_mod32,
    };

    pub const P0: u32 = 0b0011_1111_0101_1010_0000_0000_0000_0001;
    pub const P1: u32 = 0b0011_1111_0101_1101_0000_0000_0000_0001;
    pub const P2: u32 = 0b0011_1111_0111_0110_0000_0000_0000_0001;
    pub const P3: u32 = 0b0011_1111_1000_0010_0000_0000_0000_0001;
    pub const P4: u32 = 0b0011_1111_1010_1100_0000_0000_0000_0001;
    pub const P5: u32 = 0b0011_1111_1010_1111_0000_0000_0000_0001;
    pub const P6: u32 = 0b0011_1111_1011_0001_0000_0000_0000_0001;
    pub const P7: u32 = 0b0011_1111_1011_1011_0000_0000_0000_0001;
    pub const P8: u32 = 0b0011_1111_1101_1110_0000_0000_0000_0001;
    pub const P9: u32 = 0b0011_1111_1111_1100_0000_0000_0000_0001;

    pub const P0_MAGIC: u64 = 9317778228489988551;
    pub const P1_MAGIC: u64 = 4658027473943558643;
    pub const P2_MAGIC: u64 = 1162714878353869247;
    pub const P3_MAGIC: u64 = 4647426722536610861;
    pub const P4_MAGIC: u64 = 9270903515973367219;
    pub const P5_MAGIC: u64 = 2317299382174935855;
    pub const P6_MAGIC: u64 = 9268060552616330319;
    pub const P7_MAGIC: u64 = 2315594963384859737;
    pub const P8_MAGIC: u64 = 9242552129100825291;
    pub const P9_MAGIC: u64 = 576601523622774689;

    pub const P0_MAGIC_SHIFT: u32 = 29;
    pub const P1_MAGIC_SHIFT: u32 = 28;
    pub const P2_MAGIC_SHIFT: u32 = 26;
    pub const P3_MAGIC_SHIFT: u32 = 28;
    pub const P4_MAGIC_SHIFT: u32 = 29;
    pub const P5_MAGIC_SHIFT: u32 = 27;
    pub const P6_MAGIC_SHIFT: u32 = 29;
    pub const P7_MAGIC_SHIFT: u32 = 27;
    pub const P8_MAGIC_SHIFT: u32 = 29;
    pub const P9_MAGIC_SHIFT: u32 = 25;

    const fn mul_mod(modulus: u32, a: u32, b: u32) -> u32 {
        let wide = a as u64 * b as u64;
        (wide % modulus as u64) as u32
    }

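    // Modular inverse by Fermat's little theorem: x^(p - 2) == x^(-1) (mod p) for prime p.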
    const fn inv_mod(modulus: u32, x: u32) -> u32 {
        exp_mod32(Div32::new(modulus), x, modulus - 2)
    }

    const fn shoup(modulus: u32, w: u32) -> u32 {
        (((w as u64) << 32) / modulus as u64) as u32
    }

    const fn mul_mod64(modulus: u64, a: u64, b: u64) -> u64 {
        let wide = a as u128 * b as u128;
        (wide % modulus as u128) as u64
    }

    const fn exp_mod64(modulus: u64, base: u64, pow: u64) -> u64 {
        crate::prime::exp_mod64(Div64::new(modulus), base, pow)
    }

    const fn shoup64(modulus: u64, w: u64) -> u64 {
        (((w as u128) << 64) / modulus as u128) as u64
    }

    pub const P0_INV_MOD_P1: u32 = inv_mod(P1, P0);
    pub const P0_INV_MOD_P1_SHOUP: u32 = shoup(P1, P0_INV_MOD_P1);
    pub const P01_INV_MOD_P2: u32 = inv_mod(P2, mul_mod(P2, P0, P1));
    pub const P01_INV_MOD_P2_SHOUP: u32 = shoup(P2, P01_INV_MOD_P2);
    pub const P012_INV_MOD_P3: u32 = inv_mod(P3, mul_mod(P3, mul_mod(P3, P0, P1), P2));
    pub const P012_INV_MOD_P3_SHOUP: u32 = shoup(P3, P012_INV_MOD_P3);
    pub const P0123_INV_MOD_P4: u32 =
        inv_mod(P4, mul_mod(P4, mul_mod(P4, mul_mod(P4, P0, P1), P2), P3));
    pub const P0123_INV_MOD_P4_SHOUP: u32 = shoup(P4, P0123_INV_MOD_P4);

    pub const P0_MOD_P2_SHOUP: u32 = shoup(P2, P0);
    pub const P0_MOD_P3_SHOUP: u32 = shoup(P3, P0);
    pub const P1_MOD_P3_SHOUP: u32 = shoup(P3, P1);
    pub const P0_MOD_P4_SHOUP: u32 = shoup(P4, P0);
    pub const P1_MOD_P4_SHOUP: u32 = shoup(P4, P1);
    pub const P2_MOD_P4_SHOUP: u32 = shoup(P4, P2);

    pub const P1_INV_MOD_P2: u32 = inv_mod(P2, P1);
    pub const P1_INV_MOD_P2_SHOUP: u32 = shoup(P2, P1_INV_MOD_P2);
    pub const P3_INV_MOD_P4: u32 = inv_mod(P4, P3);
    pub const P3_INV_MOD_P4_SHOUP: u32 = shoup(P4, P3_INV_MOD_P4);
    pub const P12: u64 = P1 as u64 * P2 as u64;
    pub const P34: u64 = P3 as u64 * P4 as u64;
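    // Inverses modulo the composite moduli (P12 = P1 * P2, P34 = P3 * P4, ...) are computed with
    // Euler's theorem: x^phi(m) == 1 (mod m) when gcd(x, m) == 1, so x^(phi(m) - 1) == x^(-1),
    // with phi(P12) = (P1 - 1) * (P2 - 1) since P1 and P2 are distinct primes.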
    pub const P0_INV_MOD_P12: u64 =
        exp_mod64(P12, P0 as u64, (P1 as u64 - 1) * (P2 as u64 - 1) - 1);
    pub const P0_INV_MOD_P12_SHOUP: u64 = shoup64(P12, P0_INV_MOD_P12);
    pub const P0_MOD_P34_SHOUP: u64 = shoup64(P34, P0 as u64);
    pub const P012_INV_MOD_P34: u64 = exp_mod64(
        P34,
        mul_mod64(P34, P0 as u64, P12),
        (P3 as u64 - 1) * (P4 as u64 - 1) - 1,
    );
    pub const P012_INV_MOD_P34_SHOUP: u64 = shoup64(P34, P012_INV_MOD_P34);

    pub const P2_INV_MOD_P3: u32 = inv_mod(P3, P2);
    pub const P2_INV_MOD_P3_SHOUP: u32 = shoup(P3, P2_INV_MOD_P3);
    pub const P4_INV_MOD_P5: u32 = inv_mod(P5, P4);
    pub const P4_INV_MOD_P5_SHOUP: u32 = shoup(P5, P4_INV_MOD_P5);
    pub const P6_INV_MOD_P7: u32 = inv_mod(P7, P6);
    pub const P6_INV_MOD_P7_SHOUP: u32 = shoup(P7, P6_INV_MOD_P7);
    pub const P8_INV_MOD_P9: u32 = inv_mod(P9, P8);
    pub const P8_INV_MOD_P9_SHOUP: u32 = shoup(P9, P8_INV_MOD_P9);

    pub const P01: u64 = P0 as u64 * P1 as u64;
    pub const P23: u64 = P2 as u64 * P3 as u64;
    pub const P45: u64 = P4 as u64 * P5 as u64;
    pub const P67: u64 = P6 as u64 * P7 as u64;
    pub const P89: u64 = P8 as u64 * P9 as u64;

    pub const P01_MOD_P45_SHOUP: u64 = shoup64(P45, P01);
    pub const P01_MOD_P67_SHOUP: u64 = shoup64(P67, P01);
    pub const P01_MOD_P89_SHOUP: u64 = shoup64(P89, P01);

    pub const P23_MOD_P67_SHOUP: u64 = shoup64(P67, P23);
    pub const P23_MOD_P89_SHOUP: u64 = shoup64(P89, P23);

    pub const P45_MOD_P89_SHOUP: u64 = shoup64(P89, P45);

    pub const P01_INV_MOD_P23: u64 = exp_mod64(P23, P01, (P2 as u64 - 1) * (P3 as u64 - 1) - 1);
    pub const P01_INV_MOD_P23_SHOUP: u64 = shoup64(P23, P01_INV_MOD_P23);
    pub const P0123_INV_MOD_P45: u64 = exp_mod64(
        P45,
        mul_mod64(P45, P01, P23),
        (P4 as u64 - 1) * (P5 as u64 - 1) - 1,
    );
    pub const P0123_INV_MOD_P45_SHOUP: u64 = shoup64(P45, P0123_INV_MOD_P45);
    pub const P012345_INV_MOD_P67: u64 = exp_mod64(
        P67,
        mul_mod64(P67, mul_mod64(P67, P01, P23), P45),
        (P6 as u64 - 1) * (P7 as u64 - 1) - 1,
    );
    pub const P012345_INV_MOD_P67_SHOUP: u64 = shoup64(P67, P012345_INV_MOD_P67);
    pub const P01234567_INV_MOD_P89: u64 = exp_mod64(
        P89,
        mul_mod64(P89, mul_mod64(P89, mul_mod64(P89, P01, P23), P45), P67),
        (P8 as u64 - 1) * (P9 as u64 - 1) - 1,
    );
    pub const P01234567_INV_MOD_P89_SHOUP: u64 = shoup64(P89, P01234567_INV_MOD_P89);

    pub const P0123: u128 = u128::wrapping_mul(P01 as u128, P23 as u128);
    pub const P012345: u128 = u128::wrapping_mul(P0123, P45 as u128);
    pub const P01234567: u128 = u128::wrapping_mul(P012345, P67 as u128);
    pub const P0123456789: u128 = u128::wrapping_mul(P01234567, P89 as u128);
}

#[allow(dead_code)]
pub(crate) mod primes52 {
    use crate::fastdiv::Div64;

    pub const P0: u64 = 0b0011_1111_1111_1111_1111_1111_1110_0111_0111_0000_0000_0000_0001;
    pub const P1: u64 = 0b0011_1111_1111_1111_1111_1111_1110_1011_1001_0000_0000_0000_0001;
    pub const P2: u64 = 0b0011_1111_1111_1111_1111_1111_1110_1100_1000_0000_0000_0000_0001;
    pub const P3: u64 = 0b0011_1111_1111_1111_1111_1111_1111_1000_1011_0000_0000_0000_0001;
    pub const P4: u64 = 0b0011_1111_1111_1111_1111_1111_1111_1011_1000_0000_0000_0000_0001;
    pub const P5: u64 = 0b0011_1111_1111_1111_1111_1111_1111_1100_0111_0000_0000_0000_0001;

    pub const P0_MAGIC: u64 = 9223372247845040859;
    pub const P1_MAGIC: u64 = 4611686106205779591;
    pub const P2_MAGIC: u64 = 4611686102179247601;
    pub const P3_MAGIC: u64 = 2305843024917166187;
    pub const P4_MAGIC: u64 = 4611686037754736721;
    pub const P5_MAGIC: u64 = 4611686033728204851;

    pub const P0_MAGIC_SHIFT: u32 = 49;
    pub const P1_MAGIC_SHIFT: u32 = 48;
    pub const P2_MAGIC_SHIFT: u32 = 48;
    pub const P3_MAGIC_SHIFT: u32 = 47;
    pub const P4_MAGIC_SHIFT: u32 = 48;
    pub const P5_MAGIC_SHIFT: u32 = 48;

    const fn mul_mod(modulus: u64, a: u64, b: u64) -> u64 {
        let wide = a as u128 * b as u128;
        (wide % modulus as u128) as u64
    }

    const fn inv_mod(modulus: u64, x: u64) -> u64 {
        crate::prime::exp_mod64(Div64::new(modulus), x, modulus - 2)
    }

    const fn shoup(modulus: u64, w: u64) -> u64 {
        (((w as u128) << 52) / modulus as u128) as u64
    }

    pub const P0_INV_MOD_P1: u64 = inv_mod(P1, P0);
    pub const P0_INV_MOD_P1_SHOUP: u64 = shoup(P1, P0_INV_MOD_P1);

    pub const P01_INV_MOD_P2: u64 = inv_mod(P2, mul_mod(P2, P0, P1));
    pub const P01_INV_MOD_P2_SHOUP: u64 = shoup(P2, P01_INV_MOD_P2);
    pub const P012_INV_MOD_P3: u64 = inv_mod(P3, mul_mod(P3, mul_mod(P3, P0, P1), P2));
    pub const P012_INV_MOD_P3_SHOUP: u64 = shoup(P3, P012_INV_MOD_P3);
    pub const P0123_INV_MOD_P4: u64 =
        inv_mod(P4, mul_mod(P4, mul_mod(P4, mul_mod(P4, P0, P1), P2), P3));
    pub const P0123_INV_MOD_P4_SHOUP: u64 = shoup(P4, P0123_INV_MOD_P4);

    pub const P0_MOD_P2_SHOUP: u64 = shoup(P2, P0);
    pub const P0_MOD_P3_SHOUP: u64 = shoup(P3, P0);
    pub const P1_MOD_P3_SHOUP: u64 = shoup(P3, P1);
    pub const P0_MOD_P4_SHOUP: u64 = shoup(P4, P0);
    pub const P1_MOD_P4_SHOUP: u64 = shoup(P4, P1);
    pub const P2_MOD_P4_SHOUP: u64 = shoup(P4, P2);
}

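// `izip!` zips an arbitrary number of iterators and flattens the nested tuples produced by the
// chained `Iterator::zip` calls, so that `for (a, b, c) in izip!(xs, ys, zs) { ... }` works as
// expected.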
macro_rules! izip {
    (@ __closure @ ($a:expr)) => { |a| (a,) };
    (@ __closure @ ($a:expr, $b:expr)) => { |(a, b)| (a, b) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr)) => { |((a, b), c)| (a, b, c) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr)) => { |(((a, b), c), d)| (a, b, c, d) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr)) => { |((((a, b), c), d), e)| (a, b, c, d, e) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr)) => { |(((((a, b), c), d), e), f)| (a, b, c, d, e, f) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr)) => { |((((((a, b), c), d), e), f), g)| (a, b, c, d, e, f, g) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr)) => { |(((((((a, b), c), d), e), f), g), h)| (a, b, c, d, e, f, g, h) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr)) => { |((((((((a, b), c), d), e), f), g), h), i)| (a, b, c, d, e, f, g, h, i) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr)) => { |(((((((((a, b), c), d), e), f), g), h), i), j)| (a, b, c, d, e, f, g, h, i, j) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr)) => { |((((((((((a, b), c), d), e), f), g), h), i), j), k)| (a, b, c, d, e, f, g, h, i, j, k) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr)) => { |(((((((((((a, b), c), d), e), f), g), h), i), j), k), l)| (a, b, c, d, e, f, g, h, i, j, k, l) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr)) => { |((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m)| (a, b, c, d, e, f, g, h, i, j, k, l, m) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr, $n:expr)) => { |(((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m), n)| (a, b, c, d, e, f, g, h, i, j, k, l, m, n) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr, $n:expr, $o:expr)) => { |((((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m), n), o)| (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o) };

    ( $first:expr $(,)?) => {
        {
            ::core::iter::IntoIterator::into_iter($first)
        }
    };
    ( $first:expr, $($rest:expr),+ $(,)?) => {
        {
            ::core::iter::IntoIterator::into_iter($first)
                $(.zip($rest))*
                .map(crate::izip!(@ __closure @ ($first, $($rest),*)))
        }
    };
}
pub(crate) use izip;

#[cfg(test)]
mod tests {
    use crate::prime::largest_prime_in_arithmetic_progression64;
    use rand::random;

    #[test]
    fn test_barrett32() {
        let p =
            largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 30, 1 << 31).unwrap() as u32;

        let big_q: u32 = p.ilog2() + 1;
        let big_l: u32 = big_q + 31;
        let k: u32 = ((1u128 << big_l) / p as u128).try_into().unwrap();

        for _ in 0..10000 {
            let a = random::<u32>() % p;
            let b = random::<u32>() % p;

            let d = a as u64 * b as u64;
            // Q < 31
            // d < 2^(2Q)
            // (d >> (Q-1)) < 2^(Q+1)         -> c1 fits in u32
            let c1 = (d >> (big_q - 1)) as u32;
            // c2 = c1 * k < 2^(Q+33); c3 = c2 >> 32
            let c3 = ((c1 as u64 * k as u64) >> 32) as u32;
            let c = (d as u32).wrapping_sub(p.wrapping_mul(c3));
            let c = if c >= p { c - p } else { c };
            assert_eq!(c as u64, d % p as u64);
        }
    }

    #[test]
    fn test_barrett52() {
        let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 50, 1 << 51).unwrap();

        let big_q: u32 = p.ilog2() + 1;
        let big_l: u32 = big_q + 51;
        let k: u64 = ((1u128 << big_l) / p as u128).try_into().unwrap();

        for _ in 0..10000 {
            let a = random::<u64>() % p;
            let b = random::<u64>() % p;

            let d = a as u128 * b as u128;
            // Q < 51
            // d < 2^(2Q)
            // (d >> (Q-1)) < 2^(Q+1)         -> c1 fits in u64
            let c1 = (d >> (big_q - 1)) as u64;
            // c2 = c1 * k < 2^(Q+53); c3 = c2 >> 52
            let c3 = ((c1 as u128 * k as u128) >> 52) as u64;
            let c = (d as u64).wrapping_sub(p.wrapping_mul(c3));
            let c = if c >= p { c - p } else { c };
            assert_eq!(c as u128, d % p as u128);
        }
    }

    #[test]
    fn test_barrett64() {
        let p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 1 << 62, 1 << 63).unwrap();

        let big_q: u32 = p.ilog2() + 1;
        let big_l: u32 = big_q + 63;
        let k: u64 = ((1u128 << big_l) / p as u128).try_into().unwrap();

        for _ in 0..10000 {
            let a = random::<u64>() % p;
            let b = random::<u64>() % p;

            let d = a as u128 * b as u128;
            // Q < 63
            // d < 2^(2Q)
            // (d >> (Q-1)) < 2^(Q+1)         -> c1 fits in u64
            let c1 = (d >> (big_q - 1)) as u64;
            // c2 = c1 * k < 2^(Q+65); c3 = c2 >> 64
            let c3 = ((c1 as u128 * k as u128) >> 64) as u64;
            let c = (d as u64).wrapping_sub(p.wrapping_mul(c3));
            let c = if c >= p { c - p } else { c };
            assert_eq!(c as u128, d % p as u128);
        }
    }

    // primes should be of the form x * LARGEST_POLYNOMIAL_SIZE (= 2^16) + 1
    // primes should be < 2^30 or < 2^50, for NTT efficiency
    // primes should satisfy the magic property documented above the primes32 module
    // primes should be as large as possible
    #[cfg(feature = "std")]
    #[test]
    fn generate_primes() {
        let mut p = 1u64 << 30;
        for _ in 0..100 {
            p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, p - 1).unwrap();
            println!("{p:#034b}");
        }

        let mut p = 1u64 << 50;
        for _ in 0..100 {
            p = largest_prime_in_arithmetic_progression64(1 << 16, 1, 0, p - 1).unwrap();
            println!("{p:#054b}");
        }
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(test)]
mod x86_tests {
    use super::*;
    use rand::random as rnd;

    #[test]
    fn test_widening_mul() {
        if let Some(simd) = crate::V3::try_new() {
            let a = u64x4(rnd(), rnd(), rnd(), rnd());
            let b = u64x4(rnd(), rnd(), rnd(), rnd());
            let (lo, hi) = simd.widening_mul_u64x4(a, b);
            assert_eq!(
                lo,
                u64x4(
                    u64::wrapping_mul(a.0, b.0),
                    u64::wrapping_mul(a.1, b.1),
                    u64::wrapping_mul(a.2, b.2),
                    u64::wrapping_mul(a.3, b.3),
                ),
            );
            assert_eq!(
                hi,
                u64x4(
                    ((a.0 as u128 * b.0 as u128) >> 64) as u64,
                    ((a.1 as u128 * b.1 as u128) >> 64) as u64,
                    ((a.2 as u128 * b.2 as u128) >> 64) as u64,
                    ((a.3 as u128 * b.3 as u128) >> 64) as u64,
                ),
            );
        }

        #[cfg(feature = "nightly")]
        if let Some(simd) = crate::V4::try_new() {
            let a = u64x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
            let b = u64x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
            let (lo, hi) = simd.widening_mul_u64x8(a, b);
            assert_eq!(
                lo,
                u64x8(
                    u64::wrapping_mul(a.0, b.0),
                    u64::wrapping_mul(a.1, b.1),
                    u64::wrapping_mul(a.2, b.2),
                    u64::wrapping_mul(a.3, b.3),
                    u64::wrapping_mul(a.4, b.4),
                    u64::wrapping_mul(a.5, b.5),
                    u64::wrapping_mul(a.6, b.6),
                    u64::wrapping_mul(a.7, b.7),
                ),
            );
            assert_eq!(
                hi,
                u64x8(
                    ((a.0 as u128 * b.0 as u128) >> 64) as u64,
                    ((a.1 as u128 * b.1 as u128) >> 64) as u64,
                    ((a.2 as u128 * b.2 as u128) >> 64) as u64,
                    ((a.3 as u128 * b.3 as u128) >> 64) as u64,
                    ((a.4 as u128 * b.4 as u128) >> 64) as u64,
                    ((a.5 as u128 * b.5 as u128) >> 64) as u64,
                    ((a.6 as u128 * b.6 as u128) >> 64) as u64,
                    ((a.7 as u128 * b.7 as u128) >> 64) as u64,
                ),
            );
        }
    }

    #[test]
    fn test_mul_low_32_bits() {
        if let Some(simd) = crate::V3::try_new() {
            let a = u64x4(rnd(), rnd(), rnd(), rnd());
            let b = u64x4(rnd(), rnd(), rnd(), rnd());
            let res = simd.mul_low_32_bits_u64x4(a, b);
            assert_eq!(
                res,
                u64x4(
                    a.0 as u32 as u64 * b.0 as u32 as u64,
                    a.1 as u32 as u64 * b.1 as u32 as u64,
                    a.2 as u32 as u64 * b.2 as u32 as u64,
                    a.3 as u32 as u64 * b.3 as u32 as u64,
                ),
            );
        }
        #[cfg(feature = "nightly")]
        if let Some(simd) = crate::V4::try_new() {
            let a = u64x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
            let b = u64x8(rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd(), rnd());
            let res = simd.mul_low_32_bits_u64x8(a, b);
            assert_eq!(
                res,
                u64x8(
                    a.0 as u32 as u64 * b.0 as u32 as u64,
                    a.1 as u32 as u64 * b.1 as u32 as u64,
                    a.2 as u32 as u64 * b.2 as u32 as u64,
                    a.3 as u32 as u64 * b.3 as u32 as u64,
                    a.4 as u32 as u64 * b.4 as u32 as u64,
                    a.5 as u32 as u64 * b.5 as u32 as u64,
                    a.6 as u32 as u64 * b.6 as u32 as u64,
                    a.7 as u32 as u64 * b.7 as u32 as u64,
                ),
            );
        }
    }

    #[test]
    fn test_mul_lhs_with_low_32_bits_of_rhs() {
        if let Some(simd) = crate::V3::try_new() {
            let a = u64x4(rnd(), rnd(), rnd(), rnd());
            let b = u64x4(rnd(), rnd(), rnd(), rnd());
            let res = simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(a, b);
            assert_eq!(
                res,
                u64x4(
                    u64::wrapping_mul(a.0, b.0 as u32 as u64),
                    u64::wrapping_mul(a.1, b.1 as u32 as u64),
                    u64::wrapping_mul(a.2, b.2 as u32 as u64),
                    u64::wrapping_mul(a.3, b.3 as u32 as u64),
                ),
            );
        }
    }
}