poulpy_hal/layouts/module.rs
1use std::{
2 fmt::{Debug, Display},
3 marker::PhantomData,
4 ptr::NonNull,
5};
6
7use bytemuck::Pod;
8use rand_distr::num_traits::Zero;
9
10use crate::layouts::{Data, Location, MatZnx, ScalarZnx, VecZnx};
11use crate::{
12 GALOISGENERATOR,
13 api::{ModuleLogN, ModuleN},
14};
15
16/// Core trait that every backend (CPU, GPU, FPGA, ...) must implement.
17///
18/// Defines the scalar types used for DFT-domain (`ScalarPrep`) and
19/// extended-precision (`ScalarBig`) representations, as well as the
20/// opaque `Handle` type that holds backend-specific precomputed state
21/// (e.g. FFT twiddle factors).
22///
23/// # Safety
24///
25/// [`destroy`](Backend::destroy) is called during [`Module`] drop and must
26/// correctly deallocate the handle without double-free.
27#[allow(clippy::missing_safety_doc)]
28pub trait Backend: Sized + Sync + Send {
29 /// Scalar type for extended-precision (big) polynomial representations.
30 type ScalarBig: Copy + Zero + Display + Debug + Pod;
31 /// Scalar type for DFT-domain (prepared) polynomial representations.
32 type ScalarPrep: Copy + Zero + Display + Debug + Pod;
33 /// Owned backend storage for layouts and scratch.
34 ///
35 /// This buffer may be host-resident or device-resident. It is intentionally
36 /// no longer required to expose direct host byte slices.
37 type OwnedBuf: Data + Send + Sync;
38 /// Shared borrowed view into backend-owned storage.
39 type BufRef<'a>: Data + Sync
40 where
41 Self: 'a;
42 /// Mutable borrowed view into backend-owned storage.
43 type BufMut<'a>: Data + Send
44 where
45 Self: 'a;
46 /// Opaque backend handle type (e.g. precomputed FFT twiddle factors).
47 type Handle: 'static;
48 /// Residency of this backend's buffers — [`Host`](crate::layouts::Host)
49 /// or [`Device`](crate::layouts::Device).
50 type Location: Location;
51 /// Allocates a backend-owned byte buffer of `len` bytes.
52 fn alloc_bytes(len: usize) -> Self::OwnedBuf;
53 /// Allocates a zero-initialized backend-owned byte buffer of `len` bytes.
54 ///
55 /// Backends may override this with a device-native implementation
56 /// (e.g. `cudaMemset`-backed allocation). The default implementation
57 /// falls back to allocating first and then zero-filling through the
58 /// existing host upload path.
59 fn alloc_zeroed_bytes(len: usize) -> Self::OwnedBuf {
60 let mut buf = Self::alloc_bytes(len);
61 let zeros = vec![0u8; len];
62 Self::copy_from_host(&mut buf, &zeros);
63 buf
64 }
65 /// Uploads or copies host bytes into backend-owned storage.
66 fn from_host_bytes(bytes: &[u8]) -> Self::OwnedBuf;
67 /// Wraps/Uploads a host-owned byte buffer into backend-owned storage.
68 ///
69 /// Backends may override this for a zero-copy fast path when the input is
70 /// already in a compatible host representation.
71 fn from_bytes(bytes: Vec<u8>) -> Self::OwnedBuf;
72 /// Copies the contents of a backend-owned buffer into a fresh host `Vec<u8>`.
73 ///
74 /// For host backends this is typically a simple clone of the underlying
75 /// storage; for device backends it performs a device-to-host download.
76 fn to_host_bytes(buf: &Self::OwnedBuf) -> Vec<u8>;
77 /// Copies the contents of a backend-owned buffer into a host byte slice.
78 ///
79 /// `dst.len()` must equal the byte length of `buf`.
80 fn copy_to_host(buf: &Self::OwnedBuf, dst: &mut [u8]);
81 /// Copies a host byte slice into a backend-owned buffer.
82 ///
83 /// `src.len()` must equal the byte length of `buf`.
84 fn copy_from_host(buf: &mut Self::OwnedBuf, src: &[u8]);
85 /// Returns the number of bytes stored in a backend-owned buffer.
86 fn len_bytes(buf: &Self::OwnedBuf) -> usize;
87 /// Borrows a shared backend-native view over an owned buffer.
88 fn view(buf: &Self::OwnedBuf) -> Self::BufRef<'_>;
89 /// Reborrows an existing shared backend-native view.
90 fn view_ref<'a, 'b>(buf: &'a Self::BufRef<'b>) -> Self::BufRef<'a>
91 where
92 Self: 'b;
93 /// Reborrows a mutable backend-native view as a shared backend-native view.
94 fn view_ref_mut<'a, 'b>(buf: &'a Self::BufMut<'b>) -> Self::BufRef<'a>
95 where
96 Self: 'b;
97 /// Reborrows an existing mutable backend-native view.
98 fn view_mut_ref<'a, 'b>(buf: &'a mut Self::BufMut<'b>) -> Self::BufMut<'a>
99 where
100 Self: 'b;
101 /// Borrows a mutable backend-native view over an owned buffer.
102 fn view_mut(buf: &mut Self::OwnedBuf) -> Self::BufMut<'_>;
103 /// Borrows a shared sub-region of an owned buffer.
104 fn region(buf: &Self::OwnedBuf, offset: usize, len: usize) -> Self::BufRef<'_>;
105 /// Borrows a mutable sub-region of an owned buffer.
106 fn region_mut(buf: &mut Self::OwnedBuf, offset: usize, len: usize) -> Self::BufMut<'_>;
107 /// Reborrows a shared sub-region of an existing shared backend-native view.
108 fn region_ref<'a, 'b>(buf: &'a Self::BufRef<'b>, offset: usize, len: usize) -> Self::BufRef<'a>
109 where
110 Self: 'b;
111 /// Reborrows a shared sub-region of an existing mutable backend-native view.
112 fn region_ref_mut<'a, 'b>(buf: &'a Self::BufMut<'b>, offset: usize, len: usize) -> Self::BufRef<'a>
113 where
114 Self: 'b;
115 /// Reborrows a mutable sub-region of an existing mutable backend-native view.
116 fn region_mut_ref<'a, 'b>(buf: &'a mut Self::BufMut<'b>, offset: usize, len: usize) -> Self::BufMut<'a>
117 where
118 Self: 'b;
119 /// Bytes size of `ScalarBig`.
120 fn size_of_scalar_big() -> usize {
121 size_of::<Self::ScalarBig>()
122 }
123 /// Bytes size of `ScalarPrep`.
124 fn size_of_scalar_prep() -> usize {
125 size_of::<Self::ScalarPrep>()
126 }
127
128 /// Required alignment (in bytes) for scratch-arena carved regions.
129 ///
130 /// Default to 64 (one CPU cache line). Device backends should override this
131 /// to match their native memory alignment requirement (e.g. 128 for CUDA,
132 /// 256 for ROCm). `ScratchArena::align_up` uses this constant so that
133 /// carved regions satisfy both alignment and SIMD requirements.
134 const SCRATCH_ALIGN: usize = 64;
135
136 /// Byte size of a [`crate::layouts::VecZnxDft`] buffer.
137 fn bytes_of_vec_znx_dft(n: usize, cols: usize, size: usize) -> usize {
138 n * cols * size * Self::size_of_scalar_prep()
139 }
140 /// Byte size of a [`crate::layouts::VecZnxBig`] buffer.
141 fn bytes_of_vec_znx_big(n: usize, cols: usize, size: usize) -> usize {
142 n * cols * size * Self::size_of_scalar_big()
143 }
144 /// Byte size of a [`crate::layouts::SvpPPol`] buffer.
145 fn bytes_of_svp_ppol(n: usize, cols: usize) -> usize {
146 n * cols * Self::size_of_scalar_prep()
147 }
148 /// Byte size of a [`crate::layouts::VmpPMat`] buffer.
149 fn bytes_of_vmp_pmat(n: usize, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> usize {
150 n * rows * cols_in * cols_out * size * Self::size_of_scalar_prep()
151 }
152 /// Byte size of a [`crate::layouts::CnvPVecL`] buffer.
153 fn bytes_of_cnv_pvec_left(n: usize, cols: usize, size: usize) -> usize {
154 n * cols * size * Self::size_of_scalar_prep()
155 }
156 /// Byte size of a [`crate::layouts::CnvPVecR`] buffer.
157 fn bytes_of_cnv_pvec_right(n: usize, cols: usize, size: usize) -> usize {
158 n * cols * size * Self::size_of_scalar_prep()
159 }
160 /// Deallocates a backend handle.
161 ///
162 /// # Safety
163 ///
164 /// `handle` must be a valid, non-dangling pointer that was previously
165 /// returned by the backend's allocation routine. Must not be called
166 /// more than once on the same handle.
167 unsafe fn destroy(handle: NonNull<Self::Handle>);
168}
169
170/// Primary entry point for all polynomial operations over `Z[X]/(X^N + 1)`.
171///
172/// A `Module` pairs a ring degree `N` (always a power of two) with a
173/// backend-specific handle that holds any required precomputed state. All
174/// [`api`](crate::api) trait methods are dispatched through this type.
175///
176/// The module **owns** its handle; dropping the `Module` calls
177/// [`Backend::destroy`].
178#[repr(C)]
179pub struct Module<B: Backend> {
180 ptr: NonNull<B::Handle>,
181 n: u64,
182 _marker: PhantomData<B>,
183}
184
185unsafe impl<B: Backend> Sync for Module<B> {}
186unsafe impl<B: Backend> Send for Module<B> {}
187
188impl<B: Backend> Module<B> {
189 /// Creates a backend module for ring degree `N`.
190 #[inline]
191 pub fn new(n: u64) -> Self
192 where
193 Self: crate::api::ModuleNew<B>,
194 {
195 crate::api::ModuleNew::new(n)
196 }
197
198 /// Creates a module from a [`NonNull`] backend handle.
199 ///
200 /// # Safety
201 ///
202 /// `ptr` must point to a valid, fully initialized backend handle whose
203 /// lifetime is transferred to this `Module` (it will be destroyed on drop).
204 #[allow(clippy::missing_safety_doc)]
205 #[inline]
206 pub unsafe fn from_nonnull(ptr: NonNull<B::Handle>, n: u64) -> Self {
207 assert!(n.is_power_of_two(), "n must be a power of two, got {n}");
208 Self {
209 ptr,
210 n,
211 _marker: PhantomData,
212 }
213 }
214
215 /// Construct from a raw pointer managed elsewhere.
216 /// SAFETY: `ptr` must be non-null and remain valid for the lifetime of this Module.
217 #[inline]
218 #[allow(clippy::missing_safety_doc)]
219 pub unsafe fn from_raw_parts(ptr: *mut B::Handle, n: u64) -> Self {
220 assert!(n.is_power_of_two(), "n must be a power of two, got {n}");
221 Self {
222 ptr: NonNull::new(ptr).expect("null module ptr"),
223 n,
224 _marker: PhantomData,
225 }
226 }
227
228 /// Returns the raw pointer to the backend handle.
229 #[allow(clippy::missing_safety_doc)]
230 #[inline]
231 pub unsafe fn ptr(&self) -> *mut <B as Backend>::Handle {
232 self.ptr.as_ptr()
233 }
234
235 /// Returns the ring degree `N`.
236 #[inline]
237 pub fn n(&self) -> usize {
238 self.n as usize
239 }
240
241 /// Allocates a zero-initialized backend-owned [`ScalarZnx`].
242 #[inline]
243 pub fn scalar_znx_alloc(&self, cols: usize) -> ScalarZnx<B::OwnedBuf> {
244 let n = self.n();
245 let len = ScalarZnx::<Vec<u8>>::bytes_of(n, cols);
246 let bytes = B::alloc_zeroed_bytes(len);
247 ScalarZnx::from_data(bytes, n, cols)
248 }
249
250 /// Allocates a zero-initialized backend-owned [`VecZnx`].
251 #[inline]
252 pub fn vec_znx_alloc(&self, cols: usize, size: usize) -> VecZnx<B::OwnedBuf> {
253 self.vec_znx_alloc_with_max_size(cols, size, size)
254 }
255
256 /// Allocates a zero-initialized backend-owned [`VecZnx`] with explicit limb capacity.
257 #[inline]
258 pub fn vec_znx_alloc_with_max_size(&self, cols: usize, size: usize, max_size: usize) -> VecZnx<B::OwnedBuf> {
259 let n = self.n();
260 let len = VecZnx::<Vec<u8>>::bytes_of(n, cols, max_size);
261 let bytes = B::alloc_zeroed_bytes(len);
262 VecZnx::from_data_with_max_size(bytes, n, cols, size, max_size)
263 }
264
265 /// Allocates a zero-initialized backend-owned [`MatZnx`].
266 #[inline]
267 pub fn mat_znx_alloc(&self, rows: usize, cols_in: usize, cols_out: usize, size: usize) -> MatZnx<B::OwnedBuf> {
268 let n = self.n();
269 let len = MatZnx::<Vec<u8>>::bytes_of(n, rows, cols_in, cols_out, size);
270 let bytes = B::alloc_zeroed_bytes(len);
271 MatZnx::from_data(bytes, n, rows, cols_in, cols_out, size)
272 }
273
274 /// Returns the raw pointer to the backend handle.
275 #[inline]
276 pub fn as_mut_ptr(&self) -> *mut B::Handle {
277 self.ptr.as_ptr()
278 }
279
280 /// Returns `log2(N)`.
281 #[inline]
282 pub fn log_n(&self) -> usize {
283 (usize::BITS - (self.n() - 1).leading_zeros()) as _
284 }
285
286 /// Reinterprets this `Module<B>` as a `Module<Other>` sharing the same
287 /// backend `Handle` type.
288 ///
289 /// This is a zero-cost view used to forward API calls to a compatible
290 /// source backend without rebuilding the handle.
291 #[inline]
292 pub fn reinterpret<Other>(&self) -> &Module<Other>
293 where
294 Other: Backend<Handle = B::Handle>,
295 {
296 // Safety: Module is #[repr(C)] and only contains an optional NonNull<Handle>,
297 // a u64, and a ZST PhantomData. When `Handle` matches, the layout is identical.
298 unsafe { &*(self as *const Self as *const Module<Other>) }
299 }
300
301 /// Mutable version of [`Module::reinterpret`].
302 #[inline]
303 pub fn reinterpret_mut<Other>(&mut self) -> &mut Module<Other>
304 where
305 Other: Backend<Handle = B::Handle>,
306 {
307 // Safety: see Module::reinterpret.
308 unsafe { &mut *(self as *mut Self as *mut Module<Other>) }
309 }
310}
311
312/// Returns the cyclotomic order `2N` for the ring `Z[X]/(X^N + 1)`.
313pub trait CyclotomicOrder
314where
315 Self: ModuleN,
316{
317 /// Returns `2N`, the order of the cyclotomic polynomial `X^N + 1`.
318 fn cyclotomic_order(&self) -> i64 {
319 (self.n() << 1) as _
320 }
321}
322
323impl<BE: Backend> ModuleLogN for Module<BE> where Self: ModuleN {}
324
325impl<BE: Backend> CyclotomicOrder for Module<BE> where Self: ModuleN {}
326
327/// Computes [`GALOISGENERATOR`]`^|generator| * sign(generator) mod cyclotomic_order`.
328///
329/// Returns `1` when `generator == 0`.
330///
331/// # Panics (debug)
332///
333/// Debug-asserts that `cyclotomic_order` is a positive power of two.
334#[inline(always)]
335pub fn galois_element(generator: i64, cyclotomic_order: i64) -> i64 {
336 debug_assert!(
337 cyclotomic_order > 0 && (cyclotomic_order as u64).is_power_of_two(),
338 "cyclotomic_order must be a power of two, got {cyclotomic_order}"
339 );
340
341 if generator == 0 {
342 return 1;
343 }
344
345 let g_exp: u64 = mod_exp_u64(GALOISGENERATOR, generator.unsigned_abs() as usize) & (cyclotomic_order - 1) as u64;
346 g_exp as i64 * generator.signum()
347}
348
349/// Galois group operations on the cyclotomic ring `Z[X]/(X^N + 1)`.
350///
351/// The Galois group `(Z/2NZ)*` acts on polynomials via the automorphisms
352/// `X -> X^k` for odd `k`. This trait provides methods to compute
353/// Galois elements and their inverses from a signed generator exponent.
354pub trait GaloisElement
355where
356 Self: CyclotomicOrder,
357{
358 /// Returns [`GALOISGENERATOR`]`^|generator| * sign(generator) mod 2N`.
359 fn galois_element(&self, generator: i64) -> i64 {
360 galois_element(generator, self.cyclotomic_order())
361 }
362
363 /// Returns the inverse of `gal_el` in the Galois group `(Z/2NZ)*`.
364 ///
365 /// # Panics
366 ///
367 /// Panics if `gal_el == 0`.
368 fn galois_element_inv(&self, gal_el: i64) -> i64 {
369 if gal_el == 0 {
370 panic!("cannot invert 0")
371 }
372
373 let g_exp: u64 =
374 mod_exp_u64(gal_el.unsigned_abs(), (self.cyclotomic_order() - 1) as usize) & (self.cyclotomic_order() - 1) as u64;
375 g_exp as i64 * gal_el.signum()
376 }
377}
378
379impl<BE: Backend> GaloisElement for Module<BE> where Self: CyclotomicOrder {}
380
381impl<B: Backend> Drop for Module<B> {
382 fn drop(&mut self) {
383 unsafe { B::destroy(self.ptr) }
384 }
385}
386
387/// Computes `x^e mod 2^64` using square-and-multiply with wrapping arithmetic.
388pub fn mod_exp_u64(x: u64, e: usize) -> u64 {
389 let mut y: u64 = 1;
390 let mut x_pow: u64 = x;
391 let mut exp = e;
392 while exp > 0 {
393 if exp & 1 == 1 {
394 y = y.wrapping_mul(x_pow);
395 }
396 x_pow = x_pow.wrapping_mul(x_pow);
397 exp >>= 1;
398 }
399 y
400}