Skip to main content

poulpy_core/layouts/prepared/
ggsw.rs

1use poulpy_hal::{
2    api::{VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes, VmpZero},
3    layouts::{Backend, Data, HostDataRef, Module, ScratchArena, VmpPMat, VmpPMatToBackendMut, VmpPMatToBackendRef},
4};
5
6use crate::layouts::{
7    Base2K, Degree, Dnum, Dsize, GGSWInfos, GGSWToBackendRef, GLWEInfos, GetDegree, LWEInfos, Rank, TorusPrecision,
8};
9
10/// DFT-domain (prepared) variant of [`GGSW`].
11///
12/// Stores the GGSW gadget matrix with polynomials in the frequency domain
13/// of the backend's DFT/NTT transform, enabling O(N log N) polynomial
14/// operations. Tied to a specific backend via `B: Backend`.
15#[derive(PartialEq, Eq)]
16pub struct GGSWPrepared<D: Data, B: Backend> {
17    pub(crate) data: VmpPMat<D, B>,
18    pub(crate) base2k: Base2K,
19    pub(crate) dsize: Dsize,
20}
21
22pub type GGSWPreparedBackendRef<'a, B> = GGSWPrepared<<B as Backend>::BufRef<'a>, B>;
23pub type GGSWPreparedBackendMut<'a, B> = GGSWPrepared<<B as Backend>::BufMut<'a>, B>;
24
25impl<D: Data, B: Backend> LWEInfos for GGSWPrepared<D, B> {
26    fn n(&self) -> Degree {
27        Degree(self.data.n() as u32)
28    }
29
30    fn base2k(&self) -> Base2K {
31        self.base2k
32    }
33
34    fn size(&self) -> usize {
35        self.data.size()
36    }
37}
38
39impl<D: Data, B: Backend> LWEInfos for &GGSWPrepared<D, B> {
40    fn n(&self) -> Degree {
41        Degree(self.data.n() as u32)
42    }
43
44    fn base2k(&self) -> Base2K {
45        self.base2k
46    }
47
48    fn size(&self) -> usize {
49        self.data.size()
50    }
51}
52
53impl<D: Data, B: Backend> GLWEInfos for GGSWPrepared<D, B> {
54    fn rank(&self) -> Rank {
55        Rank(self.data.cols_out() as u32 - 1)
56    }
57}
58
59impl<D: Data, B: Backend> GLWEInfos for &GGSWPrepared<D, B> {
60    fn rank(&self) -> Rank {
61        Rank(self.data.cols_out() as u32 - 1)
62    }
63}
64
65impl<D: Data, B: Backend> GGSWInfos for GGSWPrepared<D, B> {
66    fn dsize(&self) -> Dsize {
67        self.dsize
68    }
69
70    fn dnum(&self) -> Dnum {
71        Dnum(self.data.rows() as u32)
72    }
73}
74
75impl<D: Data, B: Backend> GGSWInfos for &GGSWPrepared<D, B> {
76    fn dsize(&self) -> Dsize {
77        self.dsize
78    }
79
80    fn dnum(&self) -> Dnum {
81        Dnum(self.data.rows() as u32)
82    }
83}
84
85/// Trait for allocating and preparing DFT-domain GGSW ciphertexts.
86pub trait GGSWPreparedFactory<B: Backend>
87where
88    Self: GetDegree + VmpPMatAlloc<B> + VmpPMatBytesOf + VmpPrepareTmpBytes + VmpPrepare<B> + VmpZero<B>,
89{
90    /// Allocates a new prepared GGSW with the given parameters.
91    fn ggsw_prepared_alloc(
92        &self,
93        base2k: Base2K,
94        k: TorusPrecision,
95        dnum: Dnum,
96        dsize: Dsize,
97        rank: Rank,
98    ) -> GGSWPrepared<B::OwnedBuf, B> {
99        let size: usize = k.0.div_ceil(base2k.0) as usize;
100        debug_assert!(
101            size as u32 > dsize.0,
102            "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
103            dsize.0
104        );
105
106        assert!(
107            dnum.0 * dsize.0 <= size as u32,
108            "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
109            dnum.0,
110            dsize.0,
111        );
112
113        GGSWPrepared {
114            data: self.vmp_pmat_alloc(
115                dnum.into(),
116                (rank + 1).into(),
117                (rank + 1).into(),
118                k.0.div_ceil(base2k.0) as usize,
119            ),
120            base2k,
121            dsize,
122        }
123    }
124
125    fn ggsw_prepared_alloc_from_infos<A>(&self, infos: &A) -> GGSWPrepared<B::OwnedBuf, B>
126    where
127        A: GGSWInfos,
128    {
129        assert_eq!(self.ring_degree(), infos.n());
130        self.ggsw_prepared_alloc(infos.base2k(), infos.max_k(), infos.dnum(), infos.dsize(), infos.rank())
131    }
132
133    fn ggsw_prepared_bytes_of(&self, base2k: Base2K, k: TorusPrecision, dnum: Dnum, dsize: Dsize, rank: Rank) -> usize {
134        let size: usize = k.0.div_ceil(base2k.0) as usize;
135        debug_assert!(
136            size as u32 > dsize.0,
137            "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
138            dsize.0
139        );
140
141        assert!(
142            dnum.0 * dsize.0 <= size as u32,
143            "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
144            dnum.0,
145            dsize.0,
146        );
147
148        self.bytes_of_vmp_pmat(dnum.into(), (rank + 1).into(), (rank + 1).into(), size)
149    }
150
151    fn ggsw_prepared_bytes_of_from_infos<A>(&self, infos: &A) -> usize
152    where
153        A: GGSWInfos,
154    {
155        assert_eq!(self.ring_degree(), infos.n());
156        self.ggsw_prepared_bytes_of(infos.base2k(), infos.max_k(), infos.dnum(), infos.dsize(), infos.rank())
157    }
158
159    fn ggsw_prepare_tmp_bytes<A>(&self, infos: &A) -> usize
160    where
161        A: GGSWInfos,
162    {
163        assert_eq!(self.ring_degree(), infos.n());
164        let lvl_0: usize = self.vmp_prepare_tmp_bytes(
165            infos.dnum().into(),
166            (infos.rank() + 1).into(),
167            (infos.rank() + 1).into(),
168            infos.size(),
169        );
170        lvl_0
171    }
172    fn ggsw_prepare<R, O>(&self, res: &mut R, other: &O, scratch: &mut ScratchArena<'_, B>)
173    where
174        R: GGSWPreparedToBackendMut<B>,
175        O: GGSWToBackendRef<B>,
176    {
177        let mut res = res.to_backend_mut();
178        let other = other.to_backend_ref();
179        assert_eq!(res.n(), self.ring_degree());
180        assert_eq!(other.n(), self.ring_degree());
181        assert_eq!(res.base2k, other.base2k);
182        assert_eq!(res.dsize, other.dsize);
183        assert!(
184            scratch.available() >= self.ggsw_prepare_tmp_bytes(&res),
185            "scratch.available(): {} < GGSWPreparedFactory::ggsw_prepare_tmp_bytes: {}",
186            scratch.available(),
187            self.ggsw_prepare_tmp_bytes(&res)
188        );
189        self.vmp_prepare(&mut res.data, &other.data, scratch);
190    }
191
192    fn ggsw_zero<R>(&self, res: &mut R)
193    where
194        R: GGSWPreparedToBackendMut<B>,
195    {
196        let mut res = res.to_backend_mut();
197        self.vmp_zero(&mut res.data);
198    }
199}
200
201impl<B: Backend> GGSWPreparedFactory<B> for Module<B> where
202    Self: GetDegree + VmpPMatAlloc<B> + VmpPMatBytesOf + VmpPrepareTmpBytes + VmpPrepare<B> + VmpZero<B>
203{
204}
205
206// module-only API: allocation/size helpers are provided by `GGSWPreparedFactory` on `Module`.
207
208impl<D: HostDataRef, B: Backend> GGSWPrepared<D, B> {
209    pub fn data(&self) -> &VmpPMat<D, B> {
210        &self.data
211    }
212}
213
214// module-only API: preparation sizing is provided by `GGSWPreparedFactory` on `Module`.
215
216// module-only API: preparation and zeroing are provided by `GGSWPreparedFactory` on `Module`.
217
218pub trait GGSWPreparedToBackendRef<B: Backend> {
219    fn to_backend_ref(&self) -> GGSWPreparedBackendRef<'_, B>;
220}
221
222impl<B: Backend> GGSWPreparedToBackendRef<B> for GGSWPrepared<B::OwnedBuf, B> {
223    fn to_backend_ref(&self) -> GGSWPreparedBackendRef<'_, B> {
224        GGSWPrepared {
225            base2k: self.base2k,
226            dsize: self.dsize,
227            data: self.data.to_backend_ref(),
228        }
229    }
230}
231
232pub trait GGSWPreparedToBackendMut<B: Backend> {
233    fn to_backend_mut(&mut self) -> GGSWPreparedBackendMut<'_, B>;
234}
235
236impl<B: Backend> GGSWPreparedToBackendMut<B> for GGSWPrepared<B::OwnedBuf, B> {
237    fn to_backend_mut(&mut self) -> GGSWPreparedBackendMut<'_, B> {
238        GGSWPrepared {
239            base2k: self.base2k,
240            dsize: self.dsize,
241            data: self.data.to_backend_mut(),
242        }
243    }
244}