poulpy_core/decryption/
glwe_ct.rs1use poulpy_hal::{
2 api::{
3 DataViewMut, SvpApplyInplace, TakeVecZnxBig, TakeVecZnxDft, VecZnxBigAddInplace, VecZnxBigAddSmallInplace,
4 VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxNormalizeTmpBytes,
5 },
6 layouts::{Backend, DataMut, DataRef, Module, Scratch},
7};
8
9use crate::layouts::{GLWECiphertext, GLWEPlaintext, Infos, prepared::GLWESecretPrepared};
10
11impl GLWECiphertext<Vec<u8>> {
12 pub fn decrypt_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize) -> usize
13 where
14 Module<B>: VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
15 {
16 let size: usize = k.div_ceil(basek);
17 (module.vec_znx_normalize_tmp_bytes(n) | module.vec_znx_dft_alloc_bytes(n, 1, size))
18 + module.vec_znx_dft_alloc_bytes(n, 1, size)
19 }
20}
21
22impl<DataSelf: DataRef> GLWECiphertext<DataSelf> {
23 pub fn decrypt<DataPt: DataMut, DataSk: DataRef, B: Backend>(
24 &self,
25 module: &Module<B>,
26 pt: &mut GLWEPlaintext<DataPt>,
27 sk: &GLWESecretPrepared<DataSk, B>,
28 scratch: &mut Scratch<B>,
29 ) where
30 Module<B>: VecZnxDftFromVecZnx<B>
31 + SvpApplyInplace<B>
32 + VecZnxDftToVecZnxBigConsume<B>
33 + VecZnxBigAddInplace<B>
34 + VecZnxBigAddSmallInplace<B>
35 + VecZnxBigNormalize<B>,
36 Scratch<B>: TakeVecZnxDft<B> + TakeVecZnxBig<B>,
37 {
38 #[cfg(debug_assertions)]
39 {
40 assert_eq!(self.rank(), sk.rank());
41 assert_eq!(self.n(), sk.n());
42 assert_eq!(pt.n(), sk.n());
43 }
44
45 let cols: usize = self.rank() + 1;
46
47 let (mut c0_big, scratch_1) = scratch.take_vec_znx_big(self.n(), 1, self.size()); c0_big.data_mut().fill(0);
49
50 {
51 (1..cols).for_each(|i| {
52 let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(self.n(), 1, self.size()); module.vec_znx_dft_from_vec_znx(1, 0, &mut ci_dft, 0, &self.data, i);
55 module.svp_apply_inplace(&mut ci_dft, 0, &sk.data, i - 1);
56 let ci_big = module.vec_znx_dft_to_vec_znx_big_consume(ci_dft);
57
58 module.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0);
60 });
61 }
62
63 module.vec_znx_big_add_small_inplace(&mut c0_big, 0, &self.data, 0);
65
66 module.vec_znx_big_normalize(self.basek(), &mut pt.data, 0, &c0_big, 0, scratch_1);
68
69 pt.basek = self.basek();
70 pt.k = pt.k().min(self.k());
71 }
72}