poulpy_core/external_product/
glwe.rs

use poulpy_hal::{
    api::{
        ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
        VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
    },
    layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig},
};

use crate::{
    ScratchTakeCore,
    layouts::{
        GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos,
        prepared::{GGSWPrepared, GGSWPreparedToRef},
    },
};

impl GLWE<Vec<u8>> {
    pub fn external_product_tmp_bytes<R, A, B, M, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
    where
        R: GLWEInfos,
        A: GLWEInfos,
        B: GGSWInfos,
        M: GLWEExternalProduct<BE>,
    {
        module.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos)
    }
}

impl<DataSelf: DataMut> GLWE<DataSelf> {
    pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
    where
        A: GLWEToRef,
        B: GGSWPreparedToRef<BE>,
        M: GLWEExternalProduct<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        module.glwe_external_product(self, a, b, scratch);
    }

    pub fn external_product_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
    where
        A: GGSWPreparedToRef<BE>,
        M: GLWEExternalProduct<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        module.glwe_external_product_inplace(self, a, scratch);
    }
}
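
// A minimal usage sketch (illustrative only): `module`, `glwe_out`, `glwe_in`,
// `ggsw_prep` and `ScratchOwned` stand for the caller's own setup and are
// assumptions, not definitions from this file.
//
//     let tmp_bytes = GLWE::external_product_tmp_bytes(&module, &glwe_out, &glwe_in, &ggsw_prep);
//     let mut scratch = ScratchOwned::alloc(tmp_bytes); // hypothetical scratch allocator
//     glwe_out.external_product(&module, &glwe_in, &ggsw_prep, scratch.borrow());
//     // or, overwriting the input in place:
//     glwe_in.external_product_inplace(&module, &ggsw_prep, scratch.borrow());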

pub trait GLWEExternalProduct<BE: Backend>
where
    Self: Sized
        + ModuleN
        + VecZnxDftBytesOf
        + VmpApplyDftToDftTmpBytes
        + VecZnxNormalizeTmpBytes
        + VecZnxDftApply<BE>
        + VmpApplyDftToDft<BE>
        + VmpApplyDftToDftAdd<BE>
        + VecZnxIdftApplyConsume<BE>
        + VecZnxBigNormalize<BE>
        + VecZnxNormalize<BE>,
{
    fn glwe_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
    where
        R: GLWEInfos,
        A: GLWEInfos,
        B: GGSWInfos,
    {
        // Number of input limbs once expressed in the GGSW base, per digit pass.
        let in_size: usize = a_infos
            .k()
            .div_ceil(b_infos.base2k())
            .div_ceil(b_infos.dsize().into()) as usize;
        let out_size: usize = res_infos.size();
        let ggsw_size: usize = b_infos.size();
        let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), ggsw_size);
        let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size);
        let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
            out_size,
            in_size,
            in_size,                     // rows
            (b_infos.rank() + 1).into(), // cols in
            (b_infos.rank() + 1).into(), // cols out
            ggsw_size,
        );
        let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();

        // The `vmp` and `normalize_big` scratch uses are never live at the same
        // time; `a | b` is a cheap upper bound on max(a, b).
        if a_infos.base2k() == b_infos.base2k() {
            res_dft + a_dft + (vmp | normalize_big)
        } else {
            // When the input and GGSW bases differ, an extra buffer is needed to
            // renormalize the input to the GGSW base before decomposition.
            let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size);
            res_dft + ((a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
        }
    }
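
    // Worked example of the accounting above (illustrative numbers, not from the
    // codebase): with a_infos.k() = 54, b_infos.base2k() = 18 and dsize = 2, the
    // input spans ceil(54 / 18) = 3 limbs in the GGSW base, so each of the dsize
    // digit passes feeds in_size = ceil(3 / 2) = 2 rows to the VMP.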

    fn glwe_external_product_inplace<R, D>(&self, res: &mut R, a: &D, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        D: GGSWPreparedToRef<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let rhs: &GGSWPrepared<&[u8], BE> = &a.to_ref();

        let basek_in: usize = res.base2k().into();
        let basek_ggsw: usize = rhs.base2k().into();

        #[cfg(debug_assertions)]
        {
            use poulpy_hal::api::ScratchAvailable;

            assert_eq!(rhs.rank(), res.rank());
            assert_eq!(rhs.n(), res.n());
            assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, res, rhs));
        }

        let cols: usize = (rhs.rank() + 1).into();
        let dsize: usize = rhs.dsize().into();
        let a_size: usize = (res.size() * basek_in).div_ceil(basek_ggsw);

        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // TODO: optimize
        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
        a_dft.data_mut().fill(0);

        if basek_in == basek_ggsw {
            for di in 0..dsize {
                // (res.size() + di) / dsize == (res.size() - (dsize - di - 1)).div_ceil(dsize)
                a_dft.set_size((res.size() + di) / dsize);

                // Small optimization for dsize > 2: the VMP produces some error e, and
                // since we aggregate vmp * 2^{di * B}, we also aggregate ei * 2^{di * B},
                // the largest error being ei * 2^{(dsize-1) * B}. As such we can safely
                // ignore the last dsize-2 limbs of the sum of VMP products (e.g. with
                // dsize = 3, the pass di = 0 drops one limb and the passes di >= 1 drop
                // none). It would be possible to further ignore the last dsize-1 limbs,
                // but this introduces ~0.5 to 1 bit of additional noise, and is thus not
                // done here, to keep the noise identical to the ideal functionality.
                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);

                for j in 0..cols {
                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &res.data, j);
                }

                if di == 0 {
                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
                } else {
                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
                }
            }
        } else {
            // The input base differs from the GGSW base: renormalize the input to
            // the GGSW base before the digit decomposition.
            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);

            for j in 0..cols {
                self.vec_znx_normalize(
                    basek_ggsw,
                    &mut a_conv,
                    j,
                    basek_in,
                    &res.data,
                    j,
                    scratch_3,
                );
            }

            for di in 0..dsize {
                // (a_size + di) / dsize == (a_size - (dsize - di - 1)).div_ceil(dsize)
                a_dft.set_size((a_size + di) / dsize);

                // Small optimization for dsize > 2: the VMP produces some error e, and
                // since we aggregate vmp * 2^{di * B}, we also aggregate ei * 2^{di * B},
                // the largest error being ei * 2^{(dsize-1) * B}. As such we can safely
                // ignore the last dsize-2 limbs of the sum of VMP products. It would be
                // possible to further ignore the last dsize-1 limbs, but this introduces
                // ~0.5 to 1 bit of additional noise, and is thus not done here, to keep
                // the noise identical to the ideal functionality.
                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);

                for j in 0..cols {
                    // Decompose the renormalized input, not the raw data still in `res`.
                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a_conv, j);
                }

                if di == 0 {
                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_3);
                } else {
                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_3);
                }
            }
        }

        let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);

        for j in 0..cols {
            self.vec_znx_big_normalize(
                basek_in,
                &mut res.data,
                j,
                basek_ggsw,
                &res_big,
                j,
                scratch_1,
            );
        }
    }

    fn glwe_external_product<R, A, D>(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        A: GLWEToRef,
        D: GGSWPreparedToRef<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let lhs: &GLWE<&[u8]> = &lhs.to_ref();
        let rhs: &GGSWPrepared<&[u8], BE> = &rhs.to_ref();

        let basek_in: usize = lhs.base2k().into();
        let basek_ggsw: usize = rhs.base2k().into();
        let basek_out: usize = res.base2k().into();

        #[cfg(debug_assertions)]
        {
            use poulpy_hal::api::ScratchAvailable;

            assert_eq!(rhs.rank(), lhs.rank());
            assert_eq!(rhs.rank(), res.rank());
            assert_eq!(rhs.n(), res.n());
            assert_eq!(lhs.n(), res.n());
            assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, lhs, rhs));
        }

        let cols: usize = (rhs.rank() + 1).into();
        let dsize: usize = rhs.dsize().into();

        let a_size: usize = (lhs.size() * basek_in).div_ceil(basek_ggsw);

        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); // TODO: optimize
        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
        a_dft.data_mut().fill(0);

        if basek_in == basek_ggsw {
            for di in 0..dsize {
                // (lhs.size() + di) / dsize == (lhs.size() - (dsize - di - 1)).div_ceil(dsize)
                a_dft.set_size((lhs.size() + di) / dsize);

                // Small optimization for dsize > 2: the VMP produces some error e, and
                // since we aggregate vmp * 2^{di * B}, we also aggregate ei * 2^{di * B},
                // the largest error being ei * 2^{(dsize-1) * B}. As such we can safely
                // ignore the last dsize-2 limbs of the sum of VMP products. It would be
                // possible to further ignore the last dsize-1 limbs, but this introduces
                // ~0.5 to 1 bit of additional noise, and is thus not done here, to keep
                // the noise identical to the ideal functionality.
                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);

                for j in 0..cols {
                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &lhs.data, j);
                }

                if di == 0 {
                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
                } else {
                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
                }
            }
        } else {
            // The input base differs from the GGSW base: renormalize the input to
            // the GGSW base before the digit decomposition.
            let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);

            for j in 0..cols {
                self.vec_znx_normalize(
                    basek_ggsw,
                    &mut a_conv,
                    j,
                    basek_in,
                    &lhs.data,
                    j,
                    scratch_3,
                );
            }

            for di in 0..dsize {
                // (a_size + di) / dsize == (a_size - (dsize - di - 1)).div_ceil(dsize)
                a_dft.set_size((a_size + di) / dsize);

                // Small optimization for dsize > 2: the VMP produces some error e, and
                // since we aggregate vmp * 2^{di * B}, we also aggregate ei * 2^{di * B},
                // the largest error being ei * 2^{(dsize-1) * B}. As such we can safely
                // ignore the last dsize-2 limbs of the sum of VMP products. It would be
                // possible to further ignore the last dsize-1 limbs, but this introduces
                // ~0.5 to 1 bit of additional noise, and is thus not done here, to keep
                // the noise identical to the ideal functionality.
                res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);

                for j in 0..cols {
                    self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a_conv, j);
                }

                if di == 0 {
                    self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_3);
                } else {
                    self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_3);
                }
            }
        }

        let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);

        for i in 0..cols {
            self.vec_znx_big_normalize(
                basek_out,
                res.data_mut(),
                i,
                basek_ggsw,
                &res_big,
                i,
                scratch_1,
            );
        }
    }
}
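
// Illustrative check of the digit-decomposition identity the VMP accumulation
// above relies on (plain integer arithmetic, not the library API; `base2k`,
// `dsize`, `a` and `g` are made-up values): splitting `a` into base-2^base2k
// digits and recombining the per-digit products reproduces a * g, which is
// what the vmp_apply_dft_to_dft(_add) passes compute in the DFT domain.
//
//     let (base2k, dsize) = (8u32, 3usize);
//     let (a, g) = (0x0a_b1_c2u64, 5u64);
//     let acc: u64 = (0..dsize)
//         .map(|di| ((a >> (di as u32 * base2k)) & 0xff) * g << (di as u32 * base2k))
//         .sum();
//     assert_eq!(acc, a * g);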

impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE> where
    Self: ModuleN
        + VecZnxDftBytesOf
        + VmpApplyDftToDftTmpBytes
        + VecZnxNormalizeTmpBytes
        + VecZnxDftApply<BE>
        + VmpApplyDftToDft<BE>
        + VmpApplyDftToDftAdd<BE>
        + VecZnxIdftApplyConsume<BE>
        + VecZnxBigNormalize<BE>
        + VecZnxNormalize<BE>
{
}