Skip to main content

poulpy_core/layouts/prepared/gglwe.rs

1use poulpy_hal::{
2    api::{VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes},
3    layouts::{Backend, Data, Module, ScratchArena, VmpPMat, VmpPMatToBackendMut, VmpPMatToBackendRef},
4};
5
6use crate::layouts::{
7    Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWEToBackendRef, GLWEInfos, GetDegree, LWEInfos, Rank, TorusPrecision,
8};
9
10/// DFT-domain (prepared) variant of [`GGLWE`].
11///
12/// Stores the gadget GLWE matrix with polynomials in the frequency domain
13/// of the backend's DFT/NTT transform, enabling O(N log N) polynomial
14/// multiplication. The underlying data is held as a [`VmpPMat`], which
15/// represents a prepared matrix suitable for vector-matrix products.
16///
17/// Tied to a specific backend via `B: Backend`.
18#[derive(PartialEq, Eq)]
19pub struct GGLWEPrepared<D: Data, B: Backend> {
20    pub(crate) data: VmpPMat<D, B>,
21    pub(crate) base2k: Base2K,
22    pub(crate) dsize: Dsize,
23}
24
25pub type GGLWEPreparedBackendRef<'a, B> = GGLWEPrepared<<B as Backend>::BufRef<'a>, B>;
26pub type GGLWEPreparedBackendMut<'a, B> = GGLWEPrepared<<B as Backend>::BufMut<'a>, B>;
27
28/// Provides LWE-level parameter accessors (degree, base2k, precision, size).
29impl<D: Data, B: Backend> LWEInfos for GGLWEPrepared<D, B> {
30    fn n(&self) -> Degree {
31        Degree(self.data.n() as u32)
32    }
33
34    fn base2k(&self) -> Base2K {
35        self.base2k
36    }
37
38    fn size(&self) -> usize {
39        self.data.size()
40    }
41}
42
43/// Provides the GLWE rank, derived from the output rank.
44impl<D: Data, B: Backend> GLWEInfos for GGLWEPrepared<D, B> {
45    fn rank(&self) -> Rank {
46        self.rank_out()
47    }
48}
49
50/// Provides GGLWE-specific parameter accessors (input/output rank, dsize, dnum).
51impl<D: Data, B: Backend> GGLWEInfos for GGLWEPrepared<D, B> {
52    fn rank_in(&self) -> Rank {
53        Rank(self.data.cols_in() as u32)
54    }
55
56    fn rank_out(&self) -> Rank {
57        Rank(self.data.cols_out() as u32 - 1)
58    }
59
60    fn dsize(&self) -> Dsize {
61        self.dsize
62    }
63
64    fn dnum(&self) -> Dnum {
65        Dnum(self.data.rows() as u32)
66    }
67}
68
69/// Factory trait for allocating and preparing [`GGLWEPrepared`] instances.
70///
71/// Requires the backend module to support VMP prepared-matrix allocation,
72/// byte-size queries, and the prepare transform.
73pub trait GGLWEPreparedFactory<BE: Backend>
74where
75    Self: GetDegree + VmpPMatAlloc<BE> + VmpPMatBytesOf + VmpPrepare<BE> + VmpPrepareTmpBytes,
76{
77    /// Allocates a new [`GGLWEPrepared`] with the given parameters.
78    ///
79    /// Panics if `dnum * dsize > ceil(k / base2k)`.
80    fn gglwe_prepared_alloc(
81        &self,
82        base2k: Base2K,
83        k: TorusPrecision,
84        rank_in: Rank,
85        rank_out: Rank,
86        dnum: Dnum,
87        dsize: Dsize,
88    ) -> GGLWEPrepared<BE::OwnedBuf, BE> {
89        let size: usize = k.0.div_ceil(base2k.0) as usize;
90        debug_assert!(
91            size as u32 > dsize.0,
92            "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}",
93            dsize.0
94        );
95
96        assert!(
97            dnum.0 * dsize.0 <= size as u32,
98            "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
99            dnum.0,
100            dsize.0,
101        );
102
103        GGLWEPrepared {
104            data: self.vmp_pmat_alloc(dnum.into(), rank_in.into(), (rank_out + 1).into(), size),
105            base2k,
106            dsize,
107        }
108    }
109
110    /// Allocates a new [`GGLWEPrepared`] matching the parameters of `infos`.
111    fn gglwe_prepared_alloc_from_infos<A>(&self, infos: &A) -> GGLWEPrepared<BE::OwnedBuf, BE>
112    where
113        A: GGLWEInfos,
114    {
115        assert_eq!(self.ring_degree(), infos.n());
116        self.gglwe_prepared_alloc(
117            infos.base2k(),
118            infos.max_k(),
119            infos.rank_in(),
120            infos.rank_out(),
121            infos.dnum(),
122            infos.dsize(),
123        )
124    }
125
126    /// Returns the byte size required to store a [`GGLWEPrepared`] with the given parameters.
127    fn gglwe_prepared_bytes_of(
128        &self,
129        base2k: Base2K,
130        k: TorusPrecision,
131        rank_in: Rank,
132        rank_out: Rank,
133        dnum: Dnum,
134        dsize: Dsize,
135    ) -> usize {
136        let size: usize = k.0.div_ceil(base2k.0) as usize;
137        debug_assert!(
138            size as u32 > dsize.0,
139            "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}",
140            dsize.0
141        );
142
143        assert!(
144            dnum.0 * dsize.0 <= size as u32,
145            "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
146            dnum.0,
147            dsize.0,
148        );
149
150        self.bytes_of_vmp_pmat(dnum.into(), rank_in.into(), (rank_out + 1).into(), size)
151    }
152
153    /// Returns the byte size required to store a [`GGLWEPrepared`] matching `infos`.
154    fn gglwe_prepared_bytes_of_from_infos<A>(&self, infos: &A) -> usize
155    where
156        A: GGLWEInfos,
157    {
158        assert_eq!(self.ring_degree(), infos.n());
159        self.gglwe_prepared_bytes_of(
160            infos.base2k(),
161            infos.max_k(),
162            infos.rank_in(),
163            infos.rank_out(),
164            infos.dnum(),
165            infos.dsize(),
166        )
167    }
168
169    /// Returns the scratch-space bytes needed by [`gglwe_prepare`](Self::gglwe_prepare).
170    fn gglwe_prepare_tmp_bytes<A>(&self, infos: &A) -> usize
171    where
172        A: GGLWEInfos,
173    {
174        let lvl_0: usize = self.vmp_prepare_tmp_bytes(
175            infos.dnum().into(),
176            infos.rank_in().into(),
177            (infos.rank() + 1).into(),
178            infos.size(),
179        );
180        lvl_0
181    }
182
183    /// Transforms a standard [`GGLWE`] into the DFT domain, writing the result into `res`.
184    ///
185    /// Both `res` and `other` must share the same ring degree, base2k, precision, and dsize.
186    fn gglwe_prepare<R, O>(&self, res: &mut R, other: &O, scratch: &mut ScratchArena<'_, BE>)
187    where
188        R: GGLWEPreparedToBackendMut<BE>,
189        O: GGLWEToBackendRef<BE>,
190    {
191        let mut res = res.to_backend_mut();
192        let other = other.to_backend_ref();
193
194        assert_eq!(res.n(), self.ring_degree());
195        assert_eq!(other.n(), self.ring_degree());
196        assert_eq!(res.base2k, other.base2k);
197        assert_eq!(res.size(), other.size());
198        assert_eq!(res.dsize, other.dsize);
199        assert!(
200            scratch.available() >= self.gglwe_prepare_tmp_bytes(&res),
201            "scratch.available(): {} < GGLWEPreparedFactory::gglwe_prepare_tmp_bytes: {}",
202            scratch.available(),
203            self.gglwe_prepare_tmp_bytes(&res)
204        );
205        self.vmp_prepare(&mut res.data, &other.data, scratch);
206    }
207}
208
209impl<BE: Backend> GGLWEPreparedFactory<BE> for Module<BE> where
210    Module<BE>: GetDegree + VmpPMatAlloc<BE> + VmpPMatBytesOf + VmpPrepare<BE> + VmpPrepareTmpBytes
211{
212}
213
// Module-only API: allocation, size queries and preparation are all provided by
// `GGLWEPreparedFactory`, which is implemented for `Module` above.
217
218pub trait GGLWEPreparedToBackendRef<B: Backend> {
219    fn to_backend_ref(&self) -> GGLWEPreparedBackendRef<'_, B>;
220}
221
222impl<B: Backend> GGLWEPreparedToBackendRef<B> for GGLWEPrepared<B::OwnedBuf, B> {
223    fn to_backend_ref(&self) -> GGLWEPreparedBackendRef<'_, B> {
224        GGLWEPrepared {
225            base2k: self.base2k,
226            dsize: self.dsize,
227            data: self.data.to_backend_ref(),
228        }
229    }
230}
231
232pub trait GGLWEPreparedToBackendMut<B: Backend> {
233    fn to_backend_mut(&mut self) -> GGLWEPreparedBackendMut<'_, B>;
234}
235
236impl<B: Backend> GGLWEPreparedToBackendMut<B> for GGLWEPrepared<B::OwnedBuf, B> {
237    fn to_backend_mut(&mut self) -> GGLWEPreparedBackendMut<'_, B> {
238        GGLWEPrepared {
239            base2k: self.base2k,
240            dsize: self.dsize,
241            data: self.data.to_backend_mut(),
242        }
243    }
244}