Skip to main content

poulpy_hal/layouts/
module.rs

1use std::{
2    fmt::{Debug, Display},
3    marker::PhantomData,
4    ptr::NonNull,
5};
6
7use bytemuck::Pod;
8use rand_distr::num_traits::Zero;
9
10use crate::layouts::{Data, Location, MatZnx, ScalarZnx, VecZnx};
11use crate::{
12    GALOISGENERATOR,
13    api::{ModuleLogN, ModuleN},
14};
15
16/// Core trait that every backend (CPU, GPU, FPGA, ...) must implement.
17///
18/// Defines the scalar types used for DFT-domain (`ScalarPrep`) and
19/// extended-precision (`ScalarBig`) representations, as well as the
20/// opaque `Handle` type that holds backend-specific precomputed state
21/// (e.g. FFT twiddle factors).
22///
23/// # Safety
24///
25/// [`destroy`](Backend::destroy) is called during [`Module`] drop and must
26/// correctly deallocate the handle without double-free.
27#[allow(clippy::missing_safety_doc)]
28pub trait Backend: Sized + Sync + Send {
29    /// Scalar type for extended-precision (big) polynomial representations.
30    type ScalarBig: Copy + Zero + Display + Debug + Pod;
31    /// Scalar type for DFT-domain (prepared) polynomial representations.
32    type ScalarPrep: Copy + Zero + Display + Debug + Pod;
33    /// Owned backend storage for layouts and scratch.
34    ///
35    /// This buffer may be host-resident or device-resident. It is intentionally
36    /// no longer required to expose direct host byte slices.
37    type OwnedBuf: Data + Send + Sync;
38    /// Shared borrowed view into backend-owned storage.
39    type BufRef<'a>: Data + Sync
40    where
41        Self: 'a;
42    /// Mutable borrowed view into backend-owned storage.
43    type BufMut<'a>: Data + Send
44    where
45        Self: 'a;
46    /// Opaque backend handle type (e.g. precomputed FFT twiddle factors).
47    type Handle: 'static;
48    /// Residency of this backend's buffers — [`Host`](crate::layouts::Host)
49    /// or [`Device`](crate::layouts::Device).
50    type Location: Location;
51    /// Allocates a backend-owned byte buffer of `len` bytes.
52    fn alloc_bytes(len: usize) -> Self::OwnedBuf;
53    /// Allocates a zero-initialized backend-owned byte buffer of `len` bytes.
54    ///
55    /// Backends may override this with a device-native implementation
56    /// (e.g. `cudaMemset`-backed allocation). The default implementation
57    /// falls back to allocating first and then zero-filling through the
58    /// existing host upload path.
59    fn alloc_zeroed_bytes(len: usize) -> Self::OwnedBuf {
60        let mut buf = Self::alloc_bytes(len);
61        let zeros = vec![0u8; len];
62        Self::copy_from_host(&mut buf, &zeros);
63        buf
64    }
65    /// Uploads or copies host bytes into backend-owned storage.
66    fn from_host_bytes(bytes: &[u8]) -> Self::OwnedBuf;
67    /// Wraps/Uploads a host-owned byte buffer into backend-owned storage.
68    ///
69    /// Backends may override this for a zero-copy fast path when the input is
70    /// already in a compatible host representation.
71    fn from_bytes(bytes: Vec<u8>) -> Self::OwnedBuf;
72    /// Copies the contents of a backend-owned buffer into a fresh host `Vec<u8>`.
73    ///
74    /// For host backends this is typically a simple clone of the underlying
75    /// storage; for device backends it performs a device-to-host download.
76    fn to_host_bytes(buf: &Self::OwnedBuf) -> Vec<u8>;
77    /// Copies the contents of a backend-owned buffer into a host byte slice.
78    ///
79    /// `dst.len()` must equal the byte length of `buf`.
80    fn copy_to_host(buf: &Self::OwnedBuf, dst: &mut [u8]);
81    /// Copies a host byte slice into a backend-owned buffer.
82    ///
83    /// `src.len()` must equal the byte length of `buf`.
84    fn copy_from_host(buf: &mut Self::OwnedBuf, src: &[u8]);
85    /// Returns the number of bytes stored in a backend-owned buffer.
86    fn len_bytes(buf: &Self::OwnedBuf) -> usize;
87    /// Borrows a shared backend-native view over an owned buffer.
88    fn view(buf: &Self::OwnedBuf) -> Self::BufRef<'_>;
89    /// Reborrows an existing shared backend-native view.
90    fn view_ref<'a, 'b>(buf: &'a Self::BufRef<'b>) -> Self::BufRef<'a>
91    where
92        Self: 'b;
93    /// Reborrows a mutable backend-native view as a shared backend-native view.
94    fn view_ref_mut<'a, 'b>(buf: &'a Self::BufMut<'b>) -> Self::BufRef<'a>
95    where
96        Self: 'b;
97    /// Reborrows an existing mutable backend-native view.
98    fn view_mut_ref<'a, 'b>(buf: &'a mut Self::BufMut<'b>) -> Self::BufMut<'a>
99    where
100        Self: 'b;
101    /// Borrows a mutable backend-native view over an owned buffer.
102    fn view_mut(buf: &mut Self::OwnedBuf) -> Self::BufMut<'_>;
103    /// Borrows a shared sub-region of an owned buffer.
104    fn region(buf: &Self::OwnedBuf, offset: usize, len: usize) -> Self::BufRef<'_>;
105    /// Borrows a mutable sub-region of an owned buffer.
106    fn region_mut(buf: &mut Self::OwnedBuf, offset: usize, len: usize) -> Self::BufMut<'_>;
107    /// Reborrows a shared sub-region of an existing shared backend-native view.
108    fn region_ref<'a, 'b>(buf: &'a Self::BufRef<'b>, offset: usize, len: usize) -> Self::BufRef<'a>
109    where
110        Self: 'b;
111    /// Reborrows a shared sub-region of an existing mutable backend-native view.
112    fn region_ref_mut<'a, 'b>(buf: &'a Self::BufMut<'b>, offset: usize, len: usize) -> Self::BufRef<'a>
113    where
114        Self: 'b;
115    /// Reborrows a mutable sub-region of an existing mutable backend-native view.
116    fn region_mut_ref<'a, 'b>(buf: &'a mut Self::BufMut<'b>, offset: usize, len: usize) -> Self::BufMut<'a>
117    where
118        Self: 'b;
119    /// Bytes size of `ScalarBig`.
120    fn size_of_scalar_big() -> usize {
121        size_of::<Self::ScalarBig>()
122    }
123    /// Bytes size of `ScalarPrep`.
124    fn size_of_scalar_prep() -> usize {
125        size_of::<Self::ScalarPrep>()
126    }
127
128    /// Required alignment (in bytes) for scratch-arena carved regions.
129    ///
130    /// Default to 64 (one CPU cache line). Device backends should override this
131    /// to match their native memory alignment requirement (e.g. 128 for CUDA,
132    /// 256 for ROCm). `ScratchArena::align_up` uses this constant so that
133    /// carved regions satisfy both alignment and SIMD requirements.
134    const SCRATCH_ALIGN: usize = 64;
135
136    /// Byte size of a [`crate::layouts::VecZnxDft`] buffer.
137    fn bytes_of_vec_znx_dft(n: usize, cols: usize, size: usize) -> usize {
138        n * cols * size * Self::size_of_scalar_prep()
139    }
140    /// Byte size of a [`crate::layouts::VecZnxBig`] buffer.
141    fn bytes_of_vec_znx_big(n: usize, cols: usize, size: usize) -> usize {
142        n * cols * size * Self::size_of_scalar_big()
143    }
144    /// Byte size of a [`crate::layouts::SvpPPol`] buffer.
145    fn bytes_of_svp_ppol(n: usize, cols: usize) -> usize {
146        n * cols * Self::size_of_scalar_prep()
147    }
148    /// Byte size of a [`crate::layouts::VmpPMat`] buffer.
149    fn bytes_of_vmp_pmat(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
150        n * rows * cols_in * cols_out * size * Self::size_of_scalar_prep()
151    }
152    /// Byte size of a [`crate::layouts::CnvPVecL`] buffer.
153    fn bytes_of_cnv_pvec_left(n: usize, cols: usize, size: usize) -> usize {
154        n * cols * size * Self::size_of_scalar_prep()
155    }
156    /// Byte size of a [`crate::layouts::CnvPVecR`] buffer.
157    fn bytes_of_cnv_pvec_right(n: usize, cols: usize, size: usize) -> usize {
158        n * cols * size * Self::size_of_scalar_prep()
159    }
160    /// Deallocates a backend handle.
161    ///
162    /// # Safety
163    ///
164    /// `handle` must be a valid, non-dangling pointer that was previously
165    /// returned by the backend's allocation routine. Must not be called
166    /// more than once on the same handle.
167    unsafe fn destroy(handle: NonNull<Self::Handle>);
168}
169
170/// Primary entry point for all polynomial operations over `Z[X]/(X^N + 1)`.
171///
172/// A `Module` pairs a ring degree `N` (always a power of two) with a
173/// backend-specific handle that holds any required precomputed state. All
174/// [`api`](crate::api) trait methods are dispatched through this type.
175///
176/// The module **owns** its handle; dropping the `Module` calls
177/// [`Backend::destroy`].
178#[repr(C)]
179pub struct Module<B: Backend> {
180    ptr: NonNull<B::Handle>,
181    n: u64,
182    _marker: PhantomData<B>,
183}
184
185unsafe impl<B: Backend> Sync for Module<B> {}
186unsafe impl<B: Backend> Send for Module<B> {}
187
188impl<B: Backend> Module<B> {
189    /// Creates a backend module for ring degree `N`.
190    #[inline]
191    pub fn new(n: u64) -> Self
192    where
193        Self: crate::api::ModuleNew<B>,
194    {
195        crate::api::ModuleNew::new(n)
196    }
197
198    /// Creates a module from a [`NonNull`] backend handle.
199    ///
200    /// # Safety
201    ///
202    /// `ptr` must point to a valid, fully initialized backend handle whose
203    /// lifetime is transferred to this `Module` (it will be destroyed on drop).
204    #[allow(clippy::missing_safety_doc)]
205    #[inline]
206    pub unsafe fn from_nonnull(ptr: NonNull<B::Handle>, n: u64) -> Self {
207        assert!(n.is_power_of_two(), "n must be a power of two, got {n}");
208        Self {
209            ptr,
210            n,
211            _marker: PhantomData,
212        }
213    }
214
215    /// Construct from a raw pointer managed elsewhere.
216    /// SAFETY: `ptr` must be non-null and remain valid for the lifetime of this Module.
217    #[inline]
218    #[allow(clippy::missing_safety_doc)]
219    pub unsafe fn from_raw_parts(ptr: *mut B::Handle, n: u64) -> Self {
220        assert!(n.is_power_of_two(), "n must be a power of two, got {n}");
221        Self {
222            ptr: NonNull::new(ptr).expect("null module ptr"),
223            n,
224            _marker: PhantomData,
225        }
226    }
227
228    /// Returns the raw pointer to the backend handle.
229    #[allow(clippy::missing_safety_doc)]
230    #[inline]
231    pub unsafe fn ptr(&self) -> *mut <B as Backend>::Handle {
232        self.ptr.as_ptr()
233    }
234
235    /// Returns the ring degree `N`.
236    #[inline]
237    pub fn n(&self) -> usize {
238        self.n as usize
239    }
240
241    /// Allocates a zero-initialized backend-owned [`ScalarZnx`].
242    #[inline]
243    pub fn scalar_znx_alloc(&self, cols: usize) -> ScalarZnx<B::OwnedBuf> {
244        let n = self.n();
245        let len = ScalarZnx::<Vec<u8>>::bytes_of(n, cols);
246        let bytes = B::alloc_zeroed_bytes(len);
247        ScalarZnx::from_data(bytes, n, cols)
248    }
249
250    /// Allocates a zero-initialized backend-owned [`VecZnx`].
251    #[inline]
252    pub fn vec_znx_alloc(&self, cols: usize, size: usize) -> VecZnx<B::OwnedBuf> {
253        self.vec_znx_alloc_with_max_size(cols, size, size)
254    }
255
256    /// Allocates a zero-initialized backend-owned [`VecZnx`] with explicit limb capacity.
257    #[inline]
258    pub fn vec_znx_alloc_with_max_size(&self, cols: usize, size: usize, max_size: usize) -> VecZnx<B::OwnedBuf> {
259        let n = self.n();
260        let len = VecZnx::<Vec<u8>>::bytes_of(n, cols, max_size);
261        let bytes = B::alloc_zeroed_bytes(len);
262        VecZnx::from_data_with_max_size(bytes, n, cols, size, max_size)
263    }
264
265    /// Allocates a zero-initialized backend-owned [`MatZnx`].
266    #[inline]
267    pub fn mat_znx_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnx<B::OwnedBuf> {
268        let n = self.n();
269        let len = MatZnx::<Vec<u8>>::bytes_of(n, rows, cols_in, cols_out, size);
270        let bytes = B::alloc_zeroed_bytes(len);
271        MatZnx::from_data(bytes, n, rows, cols_in, cols_out, size)
272    }
273
274    /// Returns the raw pointer to the backend handle.
275    #[inline]
276    pub fn as_mut_ptr(&self) -> *mut B::Handle {
277        self.ptr.as_ptr()
278    }
279
280    /// Returns `log2(N)`.
281    #[inline]
282    pub fn log_n(&self) -> usize {
283        (usize::BITS - (self.n() - 1).leading_zeros()) as _
284    }
285
286    /// Reinterprets this `Module<B>` as a `Module<Other>` sharing the same
287    /// backend `Handle` type.
288    ///
289    /// This is a zero-cost view used to forward API calls to a compatible
290    /// source backend without rebuilding the handle.
291    #[inline]
292    pub fn reinterpret<Other>(&self) -> &Module<Other>
293    where
294        Other: Backend<Handle = B::Handle>,
295    {
296        // Safety: Module is #[repr(C)] and only contains an optional NonNull<Handle>,
297        // a u64, and a ZST PhantomData. When `Handle` matches, the layout is identical.
298        unsafe { &*(self as *const Self as *const Module<Other>) }
299    }
300
301    /// Mutable version of [`Module::reinterpret`].
302    #[inline]
303    pub fn reinterpret_mut<Other>(&mut self) -> &mut Module<Other>
304    where
305        Other: Backend<Handle = B::Handle>,
306    {
307        // Safety: see Module::reinterpret.
308        unsafe { &mut *(self as *mut Self as *mut Module<Other>) }
309    }
310}
311
312/// Returns the cyclotomic order `2N` for the ring `Z[X]/(X^N + 1)`.
313pub trait CyclotomicOrder
314where
315    Self: ModuleN,
316{
317    /// Returns `2N`, the order of the cyclotomic polynomial `X^N + 1`.
318    fn cyclotomic_order(&self) -> i64 {
319        (self.n() << 1) as _
320    }
321}
322
323impl<BE: Backend> ModuleLogN for Module<BE> where Self: ModuleN {}
324
325impl<BE: Backend> CyclotomicOrder for Module<BE> where Self: ModuleN {}
326
327/// Computes [`GALOISGENERATOR`]`^|generator| * sign(generator) mod cyclotomic_order`.
328///
329/// Returns `1` when `generator == 0`.
330///
331/// # Panics (debug)
332///
333/// Debug-asserts that `cyclotomic_order` is a positive power of two.
334#[inline(always)]
335pub fn galois_element(generator: i64, cyclotomic_order: i64) -> i64 {
336    debug_assert!(
337        cyclotomic_order > 0 && (cyclotomic_order as u64).is_power_of_two(),
338        "cyclotomic_order must be a power of two, got {cyclotomic_order}"
339    );
340
341    if generator == 0 {
342        return 1;
343    }
344
345    let g_exp: u64 = mod_exp_u64(GALOISGENERATOR, generator.unsigned_abs() as usize) & (cyclotomic_order - 1) as u64;
346    g_exp as i64 * generator.signum()
347}
348
349/// Galois group operations on the cyclotomic ring `Z[X]/(X^N + 1)`.
350///
351/// The Galois group `(Z/2NZ)*` acts on polynomials via the automorphisms
352/// `X -> X^k` for odd `k`. This trait provides methods to compute
353/// Galois elements and their inverses from a signed generator exponent.
354pub trait GaloisElement
355where
356    Self: CyclotomicOrder,
357{
358    /// Returns [`GALOISGENERATOR`]`^|generator| * sign(generator) mod 2N`.
359    fn galois_element(&self, generator: i64) -> i64 {
360        galois_element(generator, self.cyclotomic_order())
361    }
362
363    /// Returns the inverse of `gal_el` in the Galois group `(Z/2NZ)*`.
364    ///
365    /// # Panics
366    ///
367    /// Panics if `gal_el == 0`.
368    fn galois_element_inv(&self, gal_el: i64) -> i64 {
369        if gal_el == 0 {
370            panic!("cannot invert 0")
371        }
372
373        let g_exp: u64 =
374            mod_exp_u64(gal_el.unsigned_abs(), (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1) as u64;
375        g_exp as i64 * gal_el.signum()
376    }
377}
378
379impl<BE: Backend> GaloisElement for Module<BE> where Self: CyclotomicOrder {}
380
381impl<B: Backend> Drop for Module<B> {
382    fn drop(&mut self) {
383        unsafe { B::destroy(self.ptr) }
384    }
385}
386
/// Raises `x` to the power `e` modulo `2^64`.
///
/// Uses binary exponentiation over the significant bits of `e`, with wrapping
/// multiplication so that every product is implicitly reduced mod `2^64`.
/// `e == 0` yields `1` for any `x`.
pub fn mod_exp_u64(x: u64, e: usize) -> u64 {
    let significant_bits = usize::BITS - e.leading_zeros();
    let mut acc: u64 = 1;
    let mut square = x;
    for bit in 0..significant_bits {
        if (e >> bit) & 1 != 0 {
            acc = acc.wrapping_mul(square);
        }
        square = square.wrapping_mul(square);
    }
    acc
}