poulpy_core/layouts/prepared/
gglwe.rs

1use poulpy_hal::{
2    api::{VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes},
3    layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, VmpPMatToMut, VmpPMatToRef, ZnxInfos},
4};
5
6use crate::layouts::{
7    Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToRef, GLWEInfos, GetDegree, LWEInfos, Rank, TorusPrecision,
8};
9
10#[derive(PartialEq, Eq)]
11pub struct GGLWEPrepared<D: Data, B: Backend> {
12    pub(crate) data: VmpPMat<D, B>,
13    pub(crate) k: TorusPrecision,
14    pub(crate) base2k: Base2K,
15    pub(crate) dsize: Dsize,
16}
17
18impl<D: Data, B: Backend> LWEInfos for GGLWEPrepared<D, B> {
19    fn n(&self) -> Degree {
20        Degree(self.data.n() as u32)
21    }
22
23    fn base2k(&self) -> Base2K {
24        self.base2k
25    }
26
27    fn k(&self) -> TorusPrecision {
28        self.k
29    }
30
31    fn size(&self) -> usize {
32        self.data.size()
33    }
34}
35
36impl<D: Data, B: Backend> GLWEInfos for GGLWEPrepared<D, B> {
37    fn rank(&self) -> Rank {
38        self.rank_out()
39    }
40}
41
42impl<D: Data, B: Backend> GGLWEInfos for GGLWEPrepared<D, B> {
43    fn rank_in(&self) -> Rank {
44        Rank(self.data.cols_in() as u32)
45    }
46
47    fn rank_out(&self) -> Rank {
48        Rank(self.data.cols_out() as u32 - 1)
49    }
50
51    fn dsize(&self) -> Dsize {
52        self.dsize
53    }
54
55    fn dnum(&self) -> Dnum {
56        Dnum(self.data.rows() as u32)
57    }
58}
59
60pub trait GGLWEPreparedFactory<BE: Backend>
61where
62    Self: GetDegree + VmpPMatAlloc<BE> + VmpPMatBytesOf + VmpPrepare<BE> + VmpPrepareTmpBytes,
63{
64    fn alloc_gglwe_prepared(
65        &self,
66        base2k: Base2K,
67        k: TorusPrecision,
68        rank_in: Rank,
69        rank_out: Rank,
70        dnum: Dnum,
71        dsize: Dsize,
72    ) -> GGLWEPrepared<Vec<u8>, BE> {
73        let size: usize = k.0.div_ceil(base2k.0) as usize;
74        debug_assert!(
75            size as u32 > dsize.0,
76            "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}",
77            dsize.0
78        );
79
80        assert!(
81            dnum.0 * dsize.0 <= size as u32,
82            "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
83            dnum.0,
84            dsize.0,
85        );
86
87        GGLWEPrepared {
88            data: self.vmp_pmat_alloc(dnum.into(), rank_in.into(), (rank_out + 1).into(), size),
89            k,
90            base2k,
91            dsize,
92        }
93    }
94
95    fn alloc_gglwe_prepared_from_infos<A>(&self, infos: &A) -> GGLWEPrepared<Vec<u8>, BE>
96    where
97        A: GGLWEInfos,
98    {
99        assert_eq!(self.ring_degree(), infos.n());
100        self.alloc_gglwe_prepared(
101            infos.base2k(),
102            infos.k(),
103            infos.rank_in(),
104            infos.rank_out(),
105            infos.dnum(),
106            infos.dsize(),
107        )
108    }
109
110    fn bytes_of_gglwe_prepared(
111        &self,
112        base2k: Base2K,
113        k: TorusPrecision,
114        rank_in: Rank,
115        rank_out: Rank,
116        dnum: Dnum,
117        dsize: Dsize,
118    ) -> usize {
119        let size: usize = k.0.div_ceil(base2k.0) as usize;
120        debug_assert!(
121            size as u32 > dsize.0,
122            "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}",
123            dsize.0
124        );
125
126        assert!(
127            dnum.0 * dsize.0 <= size as u32,
128            "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
129            dnum.0,
130            dsize.0,
131        );
132
133        self.bytes_of_vmp_pmat(dnum.into(), rank_in.into(), (rank_out + 1).into(), size)
134    }
135
136    fn bytes_of_gglwe_prepared_from_infos<A>(&self, infos: &A) -> usize
137    where
138        A: GGLWEInfos,
139    {
140        assert_eq!(self.ring_degree(), infos.n());
141        self.bytes_of_gglwe_prepared(
142            infos.base2k(),
143            infos.k(),
144            infos.rank_in(),
145            infos.rank_out(),
146            infos.dnum(),
147            infos.dsize(),
148        )
149    }
150
151    fn prepare_gglwe_tmp_bytes<A>(&self, infos: &A) -> usize
152    where
153        A: GGLWEInfos,
154    {
155        self.vmp_prepare_tmp_bytes(
156            infos.dnum().into(),
157            infos.rank_in().into(),
158            (infos.rank() + 1).into(),
159            infos.size(),
160        )
161    }
162
163    fn prepare_gglwe<R, O>(&self, res: &mut R, other: &O, scratch: &mut Scratch<BE>)
164    where
165        R: GGLWEPreparedToMut<BE>,
166        O: GGLWEToRef,
167    {
168        let mut res: GGLWEPrepared<&mut [u8], BE> = res.to_mut();
169        let other: GGLWE<&[u8]> = other.to_ref();
170
171        assert_eq!(res.n(), self.ring_degree());
172        assert_eq!(other.n(), self.ring_degree());
173        assert_eq!(res.base2k, other.base2k);
174        assert_eq!(res.k, other.k);
175        assert_eq!(res.dsize, other.dsize);
176
177        self.vmp_prepare(&mut res.data, &other.data, scratch);
178    }
179}
180
181impl<BE: Backend> GGLWEPreparedFactory<BE> for Module<BE> where
182    Module<BE>: GetDegree + VmpPMatAlloc<BE> + VmpPMatBytesOf + VmpPrepare<BE> + VmpPrepareTmpBytes
183{
184}
185
186impl<B: Backend> GGLWEPrepared<Vec<u8>, B> {
187    pub fn alloc_from_infos<A, M>(module: &M, infos: &A) -> Self
188    where
189        A: GGLWEInfos,
190        M: GGLWEPreparedFactory<B>,
191    {
192        module.alloc_gglwe_prepared_from_infos(infos)
193    }
194
195    pub fn alloc<M>(
196        module: &M,
197        base2k: Base2K,
198        k: TorusPrecision,
199        rank_in: Rank,
200        rank_out: Rank,
201        dnum: Dnum,
202        dsize: Dsize,
203    ) -> Self
204    where
205        M: GGLWEPreparedFactory<B>,
206    {
207        module.alloc_gglwe_prepared(base2k, k, rank_in, rank_out, dnum, dsize)
208    }
209
210    pub fn bytes_of_from_infos<A, M>(module: &M, infos: &A) -> usize
211    where
212        A: GGLWEInfos,
213        M: GGLWEPreparedFactory<B>,
214    {
215        module.bytes_of_gglwe_prepared_from_infos(infos)
216    }
217
218    pub fn bytes_of<M>(
219        module: &M,
220        base2k: Base2K,
221        k: TorusPrecision,
222        rank_in: Rank,
223        rank_out: Rank,
224        dnum: Dnum,
225        dsize: Dsize,
226    ) -> usize
227    where
228        M: GGLWEPreparedFactory<B>,
229    {
230        module.bytes_of_gglwe_prepared(base2k, k, rank_in, rank_out, dnum, dsize)
231    }
232}
233
234impl<D: DataMut, B: Backend> GGLWEPrepared<D, B> {
235    pub fn prepare<O, M>(&mut self, module: &M, other: &O, scratch: &mut Scratch<B>)
236    where
237        O: GGLWEToRef,
238        M: GGLWEPreparedFactory<B>,
239    {
240        module.prepare_gglwe(self, other, scratch);
241    }
242}
243
244impl<B: Backend> GGLWEPrepared<Vec<u8>, B> {
245    pub fn prepare_tmp_bytes<M>(&self, module: &M) -> usize
246    where
247        M: GGLWEPreparedFactory<B>,
248    {
249        module.prepare_gglwe_tmp_bytes(self)
250    }
251}
252
253pub trait GGLWEPreparedToMut<B: Backend> {
254    fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B>;
255}
256
257impl<D: DataMut, B: Backend> GGLWEPreparedToMut<B> for GGLWEPrepared<D, B> {
258    fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B> {
259        GGLWEPrepared {
260            k: self.k,
261            base2k: self.base2k,
262            dsize: self.dsize,
263            data: self.data.to_mut(),
264        }
265    }
266}
267
268pub trait GGLWEPreparedToRef<B: Backend> {
269    fn to_ref(&self) -> GGLWEPrepared<&[u8], B>;
270}
271
272impl<D: DataRef, B: Backend> GGLWEPreparedToRef<B> for GGLWEPrepared<D, B> {
273    fn to_ref(&self) -> GGLWEPrepared<&[u8], B> {
274        GGLWEPrepared {
275            k: self.k,
276            base2k: self.base2k,
277            dsize: self.dsize,
278            data: self.data.to_ref(),
279        }
280    }
281}