//! AVX-512 / AVX-512-IFMA accelerated CPU backends for the Poulpy lattice cryptography library.
//!
//! This crate provides three backend implementations for [`poulpy_hal`]:
//!
//! - `FFT64Avx512`: f64 FFT backend, gated on `enable-avx512f`.
//! - `NTT120Avx512`: Q120 NTT backend over four ~30-bit CRT primes, gated on `enable-avx512f`.
//! - `NTT126Ifma`: Q126 NTT backend over three ~42-bit CRT primes, gated on `enable-ifma`.
//!
//! # Architecture
//!
//! `poulpy_hal` defines a hardware abstraction layer (HAL) via the
//! [`Backend`](poulpy_hal::layouts::Backend) trait and open extension point
//! (OEP) traits in [`poulpy_hal::oep`]. This crate implements those extension
//! points with AVX-512F, AVX-512-IFMA, AVX2/FMA, and scalar/reference fallback
//! paths depending on the backend and operation family.
//!
//! The internal modules are organized by operation domain:
//!
//! | Module               | Domain                                                     |
//! |----------------------|------------------------------------------------------------|
//! | `fft64`              | `FFT64Avx512` backend and REIM FFT table wrappers          |
//! | `znx_avx512`         | AVX-512F single ring element arithmetic                    |
//! | `ntt120_avx512`      | `NTT120Avx512` NTT, VMP, convolution, and DFT kernels      |
//! | `ntt126_ifma`        | `NTT126Ifma` IFMA NTT, VMP, SVP, convolution, and DFT code |
//! | `hal_impl`           | HAL OEP implementations and default wiring                 |
//! | `vec_znx_big_avx512` | AVX-512F i128 accumulator helpers                          |
//!
//! # Scalar types
//!
//! - `FFT64Avx512`: `ScalarPrep = f64`, `ScalarBig = i64`.
//! - `NTT120Avx512`: `ScalarPrep = Q120bScalar`, `ScalarBig = i128`.
//! - `NTT126Ifma`: `ScalarPrep = Q120bScalar`, `ScalarBig = i128`.
//!
//! # CPU requirements
//!
//! `FFT64Avx512` and `NTT120Avx512` require x86-64 with AVX-512F. The FFT64
//! backend also uses AVX2 and FMA kernels and checks for those features at
//! module construction.
//!
//! `NTT126Ifma` additionally requires AVX-512-IFMA, AVX-512VL, BMI2, and ADX.
//! Runtime CPU feature detection is performed in
//! [`Module::new()`](poulpy_hal::api::ModuleNew::new); missing runtime features
//! cause a descriptive panic.
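/*!
Runtime detection of the kind described above can be sketched with the
standard library's `is_x86_feature_detected!` macro. This is a minimal
illustration, not this crate's actual check: `missing_ifma_features` and
`check_ifma_support` are hypothetical names, and the feature list simply
mirrors the `NTT126Ifma` requirements stated above.

```rust
/// Hypothetical helper: names of the `NTT126Ifma` CPU features that the
/// running CPU does not report.
#[cfg(target_arch = "x86_64")]
pub fn missing_ifma_features() -> Vec<&'static str> {
    let mut missing = Vec::new();
    // The macro takes a string literal, so each feature is checked explicitly.
    if !is_x86_feature_detected!("avx512f") {
        missing.push("avx512f");
    }
    if !is_x86_feature_detected!("avx512ifma") {
        missing.push("avx512ifma");
    }
    if !is_x86_feature_detected!("avx512vl") {
        missing.push("avx512vl");
    }
    if !is_x86_feature_detected!("bmi2") {
        missing.push("bmi2");
    }
    if !is_x86_feature_detected!("adx") {
        missing.push("adx");
    }
    missing
}

/// On non-x86-64 targets none of these features can be present.
#[cfg(not(target_arch = "x86_64"))]
pub fn missing_ifma_features() -> Vec<&'static str> {
    vec!["x86_64 target"]
}

/// Hypothetical pre-construction check that returns a descriptive error
/// instead of panicking.
pub fn check_ifma_support() -> Result<(), String> {
    let missing = missing_ifma_features();
    if missing.is_empty() {
        Ok(())
    } else {
        Err(format!(
            "NTT126Ifma unavailable, missing CPU features: {}",
            missing.join(", ")
        ))
    }
}
```

A check like `check_ifma_support()` would let a caller pick a different
backend before construction instead of relying on the panic.
*/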
//!
//! # Compile-time requirements
//!
//! Backends are opt-in through Cargo features and matching target features:
//!
//! ```text
//! RUSTFLAGS="-C target-feature=+avx512f" \
//!     cargo build --release --features enable-avx512f
//!
//! RUSTFLAGS="-C target-feature=+avx512f,+avx512ifma,+avx512vl,+bmi2,+adx" \
//!     cargo build --release --features enable-ifma
//! ```
//!
//! If neither feature is enabled, this crate compiles as an empty shell so the
//! workspace remains portable on machines without AVX-512. Code that imports
//! AVX-512 backend types must enable the feature that exports them.
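/*!
The target-feature half of these requirements can be observed from any crate
with `cfg!`. The sketch below is hypothetical (it is not part of this crate's
API): it reports which backends the target features of the current build
would allow, and it does not look at this crate's Cargo features.

```rust
/// Hypothetical helper: backends whose documented target features are all
/// enabled for this build (for example through `RUSTFLAGS`).
pub fn backends_allowed_by_target_features() -> Vec<&'static str> {
    let mut allowed = Vec::new();
    let avx512f = cfg!(all(target_arch = "x86_64", target_feature = "avx512f"));
    // FFT64Avx512 also relies on AVX2 and FMA kernels.
    if avx512f && cfg!(all(target_feature = "avx2", target_feature = "fma")) {
        allowed.push("FFT64Avx512");
    }
    if avx512f {
        allowed.push("NTT120Avx512");
    }
    let ifma_extras = cfg!(all(
        target_feature = "avx512ifma",
        target_feature = "avx512vl",
        target_feature = "bmi2",
        target_feature = "adx"
    ));
    if avx512f && ifma_extras {
        allowed.push("NTT126Ifma");
    }
    allowed
}
```

On a default x86-64 target built without extra `RUSTFLAGS`, the list is
typically empty, which is the situation in which the `compile_error!` checks
in this file reject the `enable-avx512f` and `enable-ifma` features.
*/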
//!
//! # Correctness guarantees
//!
//! Operations are deterministic across runs. FFT operations are constrained to
//! preserve the rounding behavior expected by the reference backend, while NTT
//! operations are exact modulo their CRT prime sets.
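/*!
Being exact modulo a CRT prime set means each coefficient is held as residues
modulo pairwise-coprime primes and is uniquely determined modulo their
product. A minimal recombination sketch follows; the four primes are
well-known NTT-friendly primes below 2^30 chosen for illustration, not the
primes used by `NTT120Avx512`, and `crt_reconstruct` is a hypothetical name.

```rust
/// Modular exponentiation, assuming m < 2^64 so that u128 products cannot
/// overflow.
fn pow_mod(mut base: u128, mut exp: u128, m: u128) -> u128 {
    let mut acc = 1 % m;
    base %= m;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * base % m;
        }
        base = base * base % m;
        exp >>= 1;
    }
    acc
}

/// Recombines residues into the unique value in [0, M), where M is the
/// product of the (distinct, prime) moduli.
pub fn crt_reconstruct(residues: &[u64], primes: &[u64]) -> u128 {
    assert_eq!(residues.len(), primes.len());
    let m: u128 = primes.iter().map(|&p| p as u128).product();
    let mut x: u128 = 0;
    for (&r, &p) in residues.iter().zip(primes) {
        let p = p as u128;
        let m_i = m / p; // product of the other primes
        let inv = pow_mod(m_i % p, p - 2, p); // Fermat inverse, p prime
        let coeff = (r as u128 % p) * inv % p; // < p, so coeff * m_i < M
        x = (x + coeff * m_i) % m;
    }
    x
}
```

Reducing `r * inv` modulo `p` before multiplying by `M / p` keeps every
intermediate below `2M`, which is why a product of four primes below 2^30
fits in `u128` here.
*/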
//!
//! Integer overflow in limb arithmetic is intentional where the bivariate
//! representation relies on wrapping arithmetic to propagate carries correctly
//! across base-2^k limbs.
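/*!
Carry propagation across base-2^k limbs can be illustrated on a plain `i64`
slice. This is a minimal sketch under stated assumptions: limbs are ordered
least-significant first, digits are normalized to the centered range
[-2^(k-1), 2^(k-1)), and `normalize_base2k` is a hypothetical name. Poulpy's
own limb order and normalization kernels may differ.

```rust
/// Normalizes signed base-2^k limbs (least-significant first, an assumption)
/// so that every digit lies in [-2^(k-1), 2^(k-1)); returns the carry out of
/// the top limb.
pub fn normalize_base2k(limbs: &mut [i64], k: u32) -> i64 {
    assert!((1..63).contains(&k), "k must be in 1..63");
    let base = 1i64 << k;
    let half = base >> 1;
    let mut carry = 0i64;
    for limb in limbs.iter_mut() {
        // Wrapping add: overflow wraps instead of panicking in debug builds.
        let v = limb.wrapping_add(carry);
        // Low k bits of v (two's complement), then recentered.
        let mut digit = v & (base - 1);
        if digit >= half {
            digit -= base;
        }
        // Absent wrap-around, v - digit is a multiple of 2^k, so the
        // arithmetic shift is an exact division.
        carry = v.wrapping_sub(digit) >> k;
        *limb = digit;
    }
    carry
}
```

Each step rewrites `v` as `digit + carry * 2^k`, so the represented value
(including the returned top carry) is unchanged.
*/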
//!
//! # Safety invariants
//!
//! Unsafe kernels require:
//!
//! - the selected backend's CPU features to be enabled at compile time and present at runtime,
//! - input and output layouts to have matching shapes and documented bounds,
//! - buffers to satisfy the alignment required by `poulpy_hal::DEFAULTALIGN`.
//!
//! Violating those invariants may cause undefined behavior, panics, or silent
//! arithmetic errors.
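/*!
One way to honor these invariants at a call site is a checked wrapper that
validates shapes and alignment before entering a kernel. A minimal sketch:
`ASSUMED_ALIGN` is a hypothetical stand-in for `poulpy_hal::DEFAULTALIGN`
(whose value is not stated here), and the scalar element-wise add stands in
for an unsafe SIMD kernel.

```rust
/// Hypothetical stand-in for `poulpy_hal::DEFAULTALIGN`; the real value may differ.
pub const ASSUMED_ALIGN: usize = 64;

fn is_aligned<T>(s: &[T]) -> bool {
    (s.as_ptr() as usize) % ASSUMED_ALIGN == 0
}

/// Checks matching lengths and alignment, then runs a placeholder kernel.
pub fn checked_add(out: &mut [i64], a: &[i64], b: &[i64]) -> Result<(), String> {
    if a.len() != b.len() || a.len() != out.len() {
        return Err(format!(
            "shape mismatch: out={}, a={}, b={}",
            out.len(),
            a.len(),
            b.len()
        ));
    }
    if !(is_aligned(out) && is_aligned(a) && is_aligned(b)) {
        return Err(format!("buffers must be {}-byte aligned", ASSUMED_ALIGN));
    }
    // A real backend would call its unsafe SIMD kernel here; this placeholder
    // is safe scalar code with the same contract.
    for ((o, x), y) in out.iter_mut().zip(a).zip(b) {
        *o = x.wrapping_add(*y);
    }
    Ok(())
}
```
*/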
//!
//! # Threading and concurrency
//!
//! Backend marker types are zero-sized and `Send + Sync`. `Module<BE>` values
//! hold immutable precomputed tables after construction. Operations take
//! mutable output references, so normal Rust borrowing rules prevent data races
//! at the API boundary.
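/*!
This sharing model can be sketched with scoped threads: one immutable module
is borrowed by several workers, and each worker gets its own `&mut` output
buffer. `Marker`, `ToyModule`, and `scale_in_parallel` are hypothetical
stand-ins, not this crate's types.

```rust
use std::marker::PhantomData;
use std::thread;

/// Hypothetical zero-sized backend marker.
pub struct Marker;

/// Hypothetical module holding a table that is read-only after construction.
pub struct ToyModule<B> {
    table: Vec<f64>,
    _backend: PhantomData<B>,
}

impl<B> ToyModule<B> {
    pub fn new(n: usize) -> Self {
        let table: Vec<f64> = (1..=n).map(|i| i as f64).collect();
        ToyModule { table, _backend: PhantomData }
    }

    /// Reads the shared table and writes only to the caller's output buffer.
    pub fn scale(&self, out: &mut [f64], a: &[f64]) {
        for ((o, x), t) in out.iter_mut().zip(a).zip(&self.table) {
            *o = x * t;
        }
    }
}

/// Compile-time check that a type is `Send + Sync`.
const fn assert_send_sync<T: Send + Sync>() {}
const _: () = assert_send_sync::<Marker>();
const _: () = assert_send_sync::<ToyModule<Marker>>();

/// Runs one worker per input; the borrow checker forbids two workers from
/// sharing an output buffer.
pub fn scale_in_parallel(module: &ToyModule<Marker>, inputs: &[Vec<f64>]) -> Vec<Vec<f64>> {
    let mut outputs: Vec<Vec<f64>> = inputs.iter().map(|a| vec![0.0; a.len()]).collect();
    thread::scope(|s| {
        for (out, a) in outputs.iter_mut().zip(inputs) {
            s.spawn(move || module.scale(out, a));
        }
    });
    outputs
}
```
*/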
//!
//! # Feature flags
//!
//! - `enable-avx512f`: exports `FFT64Avx512` and `NTT120Avx512`.
//! - `enable-ifma`: implies `enable-avx512f` and also exports `NTT126Ifma`.
//! - `enable-ckks`: wires these backends into `poulpy-ckks` defaults.
//!
//! # Platform support
//!
//! - Required: x86-64.
//! - `FFT64Avx512`: AVX-512F + AVX2 + FMA.
//! - `NTT120Avx512`: AVX-512F.
//! - `NTT126Ifma`: AVX-512F + AVX-512-IFMA + AVX-512VL + BMI2 + ADX.
//! - Non-x86 targets and x86-64 CPUs without the selected feature set are not supported.
//!
//! # Usage
//!
//! The public backend marker types are used as type parameters to HAL, core,
//! CKKS, and bin-FHE generic APIs. Application code usually selects one of
//! these types in the backend-owning crate or benchmark harness.
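/*!
Selecting a backend by type parameter can be sketched as follows.
`ToyBackend`, `FftLike`, and `NttLike` are hypothetical and much simpler than
`poulpy_hal::layouts::Backend`; the associated types only echo the
`ScalarPrep`/`ScalarBig` pairs listed under *Scalar types*, and `[u64; 4]` is
a placeholder, not the layout of `Q120bScalar`.

```rust
use std::mem::size_of;

/// Hypothetical, simplified backend trait with associated scalar types.
pub trait ToyBackend {
    type ScalarPrep: Copy + Default;
    type ScalarBig: Copy + Default;
    const NAME: &'static str;
}

/// Hypothetical marker with FFT64-style scalars.
pub struct FftLike;
impl ToyBackend for FftLike {
    type ScalarPrep = f64;
    type ScalarBig = i64;
    const NAME: &'static str = "fft-like";
}

/// Hypothetical marker with NTT-style scalars (placeholder prepared type).
pub struct NttLike;
impl ToyBackend for NttLike {
    type ScalarPrep = [u64; 4];
    type ScalarBig = i128;
    const NAME: &'static str = "ntt-like";
}

/// Generic code is written once; the caller picks the backend by type.
pub fn describe<B: ToyBackend>() -> String {
    format!(
        "{}: prep={} bytes, big={} bytes",
        B::NAME,
        size_of::<B::ScalarPrep>(),
        size_of::<B::ScalarBig>()
    )
}
```

Switching from `describe::<FftLike>()` to `describe::<NttLike>()` changes only
the type argument, which is the same shape as choosing one of this crate's
marker types.
*/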
//!
//! # Versioning and stability
//!
//! The public API consists of the backend marker types, FFT table wrappers, and
//! the `ntt126_ifma_api` support exports used by benchmarks. Other items are
//! implementation details.

#[cfg(all(feature = "enable-avx512f", not(docsrs), not(target_arch = "x86_64")))]
compile_error!("feature `enable-avx512f` requires target_arch = \"x86_64\".");

#[cfg(all(
    feature = "enable-avx512f",
    not(docsrs),
    target_arch = "x86_64",
    not(target_feature = "avx512f")
))]
compile_error!("feature `enable-avx512f` requires AVX512F. Build with RUSTFLAGS=\"-C target-feature=+avx512f\".");

#[cfg(all(
    feature = "enable-ifma",
    not(docsrs),
    target_arch = "x86_64",
    not(target_feature = "avx512ifma")
))]
compile_error!(
    "feature `enable-ifma` requires AVX512-IFMA. Build with RUSTFLAGS=\"-C target-feature=+avx512f,+avx512ifma,+avx512vl,+bmi2,+adx\"."
);

#[cfg(all(
    feature = "enable-ifma",
    not(docsrs),
    target_arch = "x86_64",
    not(target_feature = "avx512vl")
))]
compile_error!(
    "feature `enable-ifma` requires AVX512VL. Build with RUSTFLAGS=\"-C target-feature=+avx512f,+avx512ifma,+avx512vl,+bmi2,+adx\"."
);

#[cfg(all(feature = "enable-ifma", not(docsrs), target_arch = "x86_64", not(target_feature = "bmi2")))]
compile_error!(
    "feature `enable-ifma` requires BMI2. Build with RUSTFLAGS=\"-C target-feature=+avx512f,+avx512ifma,+avx512vl,+bmi2,+adx\"."
);

#[cfg(all(feature = "enable-ifma", not(docsrs), target_arch = "x86_64", not(target_feature = "adx")))]
compile_error!(
    "feature `enable-ifma` requires ADX. Build with RUSTFLAGS=\"-C target-feature=+avx512f,+avx512ifma,+avx512vl,+bmi2,+adx\"."
);

#[cfg(feature = "enable-avx512f")]
mod fft64;
#[cfg(feature = "enable-avx512f")]
mod hal_impl;
#[cfg(feature = "enable-avx512f")]
mod ntt120_avx512;
#[cfg(feature = "enable-avx512f")]
mod znx_avx512;

#[cfg(feature = "enable-avx512f")]
mod vec_znx_big_avx512;

#[cfg(feature = "enable-ifma")]
mod ntt126_ifma;

#[cfg(feature = "enable-avx512f")]
pub use fft64::{FFT64Avx512, FFT64Avx512ReimTable, ReimFFTAvx512, ReimIFFTAvx512};
#[cfg(feature = "enable-avx512f")]
pub use ntt120_avx512::NTT120Avx512;
#[cfg(feature = "enable-ifma")]
pub use ntt126_ifma::NTT126Ifma;

/// Public surface for tools that drive [`NTT126Ifma`] kernels directly (e.g. the
/// benches): the precomputed twiddle tables, the prime set, and the
/// [`Ntt126IfmaDFTExecute`](ntt126_ifma_api::Ntt126IfmaDFTExecute) trait used to
/// dispatch a forward / inverse NTT.
///
/// The scalar test oracles for the IFMA SIMD kernels live under
/// `crate::ntt126_ifma::reference` and are not re-exported.
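/**
The forward / inverse pair described above can be illustrated with a naive
O(n^2) cyclic NTT over one toy prime, driven by precomputed twiddle tables.
This is a sketch only: `ToyNttTable`, the prime 17, and the root 2 are
assumptions for illustration, and it does not reproduce the three-prime,
vectorized, or twiddle layout of `Ntt126IfmaTable`.

```rust
/// Modular exponentiation; assumes p < 2^32 so u64 products cannot overflow.
fn pow_mod(mut base: u64, mut exp: u64, p: u64) -> u64 {
    let mut acc = 1 % p;
    base %= p;
    while exp > 0 {
        if exp & 1 == 1 {
            acc = acc * base % p;
        }
        base = base * base % p;
        exp >>= 1;
    }
    acc
}

/// Hypothetical precomputed tables for a naive cyclic NTT of length n modulo p.
pub struct ToyNttTable {
    p: u64,
    n: usize,
    fwd: Vec<u64>, // omega^i mod p
    inv: Vec<u64>, // omega^(-i) mod p
    n_inv: u64,    // n^(-1) mod p
}

impl ToyNttTable {
    /// `omega` must be a primitive n-th root of unity modulo the prime `p`.
    pub fn new(p: u64, n: usize, omega: u64) -> Self {
        assert!(p < (1 << 32), "sketch assumes p < 2^32");
        assert_eq!(pow_mod(omega, n as u64, p), 1, "omega^n must be 1 mod p");
        let omega_inv = pow_mod(omega, p - 2, p); // Fermat inverse
        let fwd: Vec<u64> = (0..n as u64).map(|i| pow_mod(omega, i, p)).collect();
        let inv: Vec<u64> = (0..n as u64).map(|i| pow_mod(omega_inv, i, p)).collect();
        let n_inv = pow_mod(n as u64 % p, p - 2, p);
        ToyNttTable { p, n, fwd, inv, n_inv }
    }

    /// out[k] = sum over j of a[j] times tw[(j k) mod n], reduced mod p;
    /// valid because omega^n = 1.
    fn transform(&self, a: &[u64], tw: &[u64]) -> Vec<u64> {
        assert_eq!(a.len(), self.n);
        (0..self.n)
            .map(|k| {
                a.iter()
                    .enumerate()
                    .fold(0, |acc, (j, &x)| (acc + x % self.p * tw[(j * k) % self.n]) % self.p)
            })
            .collect()
    }

    pub fn forward(&self, a: &[u64]) -> Vec<u64> {
        self.transform(a, &self.fwd)
    }

    pub fn inverse(&self, a: &[u64]) -> Vec<u64> {
        self.transform(a, &self.inv)
            .into_iter()
            .map(|x| x * self.n_inv % self.p)
            .collect()
    }
}
```

With `ToyNttTable::new(17, 8, 2)`, `inverse(&forward(&x))` returns `x`,
since 2 has multiplicative order 8 modulo 17.
*/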
#[cfg(feature = "enable-ifma")]
pub mod ntt126_ifma_api {
    pub use crate::ntt126_ifma::primes::{PrimeSetNtt126Ifma, Primes42};
    pub use crate::ntt126_ifma::tables::{Ntt126IfmaTable, Ntt126IfmaTableInv};
    pub use crate::ntt126_ifma::traits::Ntt126IfmaDFTExecute;
}

#[cfg(all(feature = "enable-avx512f", feature = "enable-ckks"))]
mod ckks_impl;
#[cfg(feature = "enable-avx512f")]
mod core_impl;

#[cfg(all(test, feature = "enable-avx512f", feature = "enable-ckks"))]
mod tests;

// --- TransferFrom impls ---
#[cfg(feature = "enable-avx512f")]
mod transfer_impls {
    use poulpy_cpu_ref::{FFT64Ref, NTT120Ref};
    use poulpy_hal::layouts::{Backend, TransferFrom};

    #[cfg(feature = "enable-ifma")]
    use crate::NTT126Ifma;
    use crate::{FFT64Avx512, NTT120Avx512};

    impl TransferFrom<FFT64Avx512> for FFT64Avx512 {
        fn transfer_buf(src: &Vec<u8>) -> Vec<u8> {
            FFT64Avx512::from_host_bytes(&FFT64Avx512::to_host_bytes(src))
        }
    }
    impl TransferFrom<FFT64Ref> for FFT64Avx512 {
        fn transfer_buf(src: &Vec<u8>) -> Vec<u8> {
            FFT64Avx512::from_host_bytes(&FFT64Ref::to_host_bytes(src))
        }
    }

    impl TransferFrom<NTT120Avx512> for NTT120Avx512 {
        fn transfer_buf(src: &Vec<u8>) -> Vec<u8> {
            NTT120Avx512::from_host_bytes(&NTT120Avx512::to_host_bytes(src))
        }
    }
    impl TransferFrom<NTT120Ref> for NTT120Avx512 {
        fn transfer_buf(src: &Vec<u8>) -> Vec<u8> {
            NTT120Avx512::from_host_bytes(&NTT120Ref::to_host_bytes(src))
        }
    }

    // Cross-family: coefficient-domain buffers are compatible.
    // Prepared layouts must not be transferred directly; transfer the
    // non-prepared form and re-prepare on the destination backend.
    impl TransferFrom<NTT120Ref> for FFT64Avx512 {
        fn transfer_buf(src: &Vec<u8>) -> Vec<u8> {
            FFT64Avx512::from_host_bytes(&NTT120Ref::to_host_bytes(src))
        }
    }
    impl TransferFrom<NTT120Avx512> for FFT64Avx512 {
        fn transfer_buf(src: &Vec<u8>) -> Vec<u8> {
            FFT64Avx512::from_host_bytes(&NTT120Avx512::to_host_bytes(src))
        }
    }
    impl TransferFrom<FFT64Ref> for NTT120Avx512 {
        fn transfer_buf(src: &Vec<u8>) -> Vec<u8> {
            NTT120Avx512::from_host_bytes(&FFT64Ref::to_host_bytes(src))
        }
    }
    impl TransferFrom<FFT64Avx512> for NTT120Avx512 {
        fn transfer_buf(src: &Vec<u8>) -> Vec<u8> {
            NTT120Avx512::from_host_bytes(&FFT64Avx512::to_host_bytes(src))
        }
    }

    #[cfg(feature = "enable-ifma")]
    impl TransferFrom<NTT126Ifma> for NTT126Ifma {
        fn transfer_buf(src: &Vec<u8>) -> Vec<u8> {
            NTT126Ifma::from_host_bytes(&NTT126Ifma::to_host_bytes(src))
        }
    }
    #[cfg(feature = "enable-ifma")]
    impl TransferFrom<NTT120Ref> for NTT126Ifma {
        fn transfer_buf(src: &Vec<u8>) -> Vec<u8> {
            NTT126Ifma::from_host_bytes(&NTT120Ref::to_host_bytes(src))
        }
    }
    #[cfg(feature = "enable-ifma")]
    impl TransferFrom<FFT64Ref> for NTT126Ifma {
        fn transfer_buf(src: &Vec<u8>) -> Vec<u8> {
            NTT126Ifma::from_host_bytes(&FFT64Ref::to_host_bytes(src))
        }
    }
    #[cfg(feature = "enable-ifma")]
    impl TransferFrom<NTT120Avx512> for NTT126Ifma {
        fn transfer_buf(src: &Vec<u8>) -> Vec<u8> {
            NTT126Ifma::from_host_bytes(&NTT120Avx512::to_host_bytes(src))
        }
    }
    #[cfg(feature = "enable-ifma")]
    impl TransferFrom<FFT64Avx512> for NTT126Ifma {
        fn transfer_buf(src: &Vec<u8>) -> Vec<u8> {
            NTT126Ifma::from_host_bytes(&FFT64Avx512::to_host_bytes(src))
        }
    }
    #[cfg(feature = "enable-ifma")]
    impl TransferFrom<NTT126Ifma> for FFT64Avx512 {
        fn transfer_buf(src: &Vec<u8>) -> Vec<u8> {
            FFT64Avx512::from_host_bytes(&NTT126Ifma::to_host_bytes(src))
        }
    }
    #[cfg(feature = "enable-ifma")]
    impl TransferFrom<NTT126Ifma> for NTT120Avx512 {
        fn transfer_buf(src: &Vec<u8>) -> Vec<u8> {
            NTT120Avx512::from_host_bytes(&NTT126Ifma::to_host_bytes(src))
        }
    }
}
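
/*
Design note (sketch): every `transfer_buf` in `transfer_impls` has the shape
`Dst::from_host_bytes(&Src::to_host_bytes(src))`, so each backend only needs a
conversion to and from a common host byte format rather than one conversion
per backend pair. The code below illustrates that pattern with hypothetical
types: `HostBytes`, `LeBackend`, and `BeBackend` are not this crate's or
`poulpy_hal`'s API, and the byte-reversed device layout is invented purely to
make the conversion visible.

```rust
/// Hypothetical conversion between a backend's buffer bytes and host bytes.
pub trait HostBytes {
    fn to_host_bytes(buf: &[u8]) -> Vec<u8>;
    fn from_host_bytes(host: &[u8]) -> Vec<u8>;
}

/// Hypothetical backend whose buffers already use the host layout.
pub struct LeBackend;
impl HostBytes for LeBackend {
    fn to_host_bytes(buf: &[u8]) -> Vec<u8> {
        buf.to_vec()
    }
    fn from_host_bytes(host: &[u8]) -> Vec<u8> {
        host.to_vec()
    }
}

/// Reverses the bytes of every 8-byte word (its own inverse).
fn reverse_words(bytes: &[u8]) -> Vec<u8> {
    assert_eq!(bytes.len() % 8, 0, "buffer must hold whole 8-byte words");
    bytes.chunks_exact(8).flat_map(|w| w.iter().rev().copied()).collect()
}

/// Hypothetical backend that stores each 8-byte coefficient byte-reversed.
pub struct BeBackend;
impl HostBytes for BeBackend {
    fn to_host_bytes(buf: &[u8]) -> Vec<u8> {
        reverse_words(buf)
    }
    fn from_host_bytes(host: &[u8]) -> Vec<u8> {
        reverse_words(host)
    }
}

/// Same shape as the `transfer_buf` bodies in `transfer_impls`.
pub fn transfer_buf<Src: HostBytes, Dst: HostBytes>(src: &[u8]) -> Vec<u8> {
    Dst::from_host_bytes(&Src::to_host_bytes(src))
}
```

The sketch covers coefficient-domain bytes only; as the comment in
`transfer_impls` notes, prepared layouts should be transferred in their
non-prepared form and re-prepared on the destination backend.
*/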