poulpy_core/encryption/
ggsw_ct.rs1use poulpy_hal::{
2 api::{
3 DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
4 VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxFillUniform, VecZnxNormalize,
5 VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
6 },
7 layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, ZnxZero},
8 source::Source,
9};
10
11use crate::{
12 TakeGLWEPt,
13 layouts::{GGSWCiphertext, GLWECiphertext, Infos, prepared::GLWESecretPrepared},
14};
15
16impl GGSWCiphertext<Vec<u8>> {
17 pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
18 where
19 Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
20 {
21 let size = k.div_ceil(basek);
22 GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
23 + VecZnx::alloc_bytes(module.n(), rank + 1, size)
24 + VecZnx::alloc_bytes(module.n(), 1, size)
25 + module.vec_znx_dft_alloc_bytes(rank + 1, size)
26 }
27}
28
29impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
30 #[allow(clippy::too_many_arguments)]
31 pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
32 &mut self,
33 module: &Module<B>,
34 pt: &ScalarZnx<DataPt>,
35 sk: &GLWESecretPrepared<DataSk, B>,
36 source_xa: &mut Source,
37 source_xe: &mut Source,
38 scratch: &mut Scratch<B>,
39 ) where
40 Module<B>: VecZnxAddScalarInplace
41 + VecZnxDftAllocBytes
42 + VecZnxBigNormalize<B>
43 + DFT<B>
44 + SvpApplyInplace<B>
45 + IDFTConsume<B>
46 + VecZnxNormalizeTmpBytes
47 + VecZnxFillUniform
48 + VecZnxSubABInplace
49 + VecZnxAddInplace
50 + VecZnxNormalizeInplace<B>
51 + VecZnxAddNormal
52 + VecZnxNormalize<B>
53 + VecZnxSub,
54 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
55 {
56 #[cfg(debug_assertions)]
57 {
58 use poulpy_hal::layouts::ZnxInfos;
59
60 assert_eq!(self.rank(), sk.rank());
61 assert_eq!(self.n(), sk.n());
62 assert_eq!(pt.n(), sk.n());
63 }
64
65 let basek: usize = self.basek();
66 let k: usize = self.k();
67 let rank: usize = self.rank();
68 let digits: usize = self.digits();
69
70 let (mut tmp_pt, scratch1) = scratch.take_glwe_pt(self.n(), basek, k);
71
72 (0..self.rows()).for_each(|row_i| {
73 tmp_pt.data.zero();
74
75 module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0);
77 module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch1);
78
79 (0..rank + 1).for_each(|col_j| {
80 self.at_mut(row_i, col_j).encrypt_sk_internal(
83 module,
84 Some((&tmp_pt, col_j)),
85 sk,
86 source_xa,
87 source_xe,
88 scratch1,
89 );
90 });
91 });
92 }
93}