poulpy_core/encryption/
ggsw_ct.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
4 VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform,
5 VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace,
6 },
7 layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx, ZnxZero},
8 source::Source,
9};
10
11use crate::{
12 TakeGLWEPt,
13 layouts::{GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GLWESecretPrepared},
14};
15
16impl GGSWCiphertext<Vec<u8>> {
17 pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
18 where
19 A: GGSWInfos,
20 Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
21 {
22 let size = infos.size();
23 GLWECiphertext::encrypt_sk_scratch_space(module, &infos.glwe_layout())
24 + VecZnx::alloc_bytes(module.n(), (infos.rank() + 1).into(), size)
25 + VecZnx::alloc_bytes(module.n(), 1, size)
26 + module.vec_znx_dft_alloc_bytes((infos.rank() + 1).into(), size)
27 }
28}
29
30impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
31 #[allow(clippy::too_many_arguments)]
32 pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
33 &mut self,
34 module: &Module<B>,
35 pt: &ScalarZnx<DataPt>,
36 sk: &GLWESecretPrepared<DataSk, B>,
37 source_xa: &mut Source,
38 source_xe: &mut Source,
39 scratch: &mut Scratch<B>,
40 ) where
41 Module<B>: VecZnxAddScalarInplace
42 + VecZnxDftAllocBytes
43 + VecZnxBigNormalize<B>
44 + VecZnxDftApply<B>
45 + SvpApplyDftToDftInplace<B>
46 + VecZnxIdftApplyConsume<B>
47 + VecZnxNormalizeTmpBytes
48 + VecZnxFillUniform
49 + VecZnxSubInplace
50 + VecZnxAddInplace
51 + VecZnxNormalizeInplace<B>
52 + VecZnxAddNormal
53 + VecZnxNormalize<B>
54 + VecZnxSub,
55 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
56 {
57 #[cfg(debug_assertions)]
58 {
59 use poulpy_hal::layouts::ZnxInfos;
60
61 assert_eq!(self.rank(), sk.rank());
62 assert_eq!(self.n(), sk.n());
63 assert_eq!(pt.n() as u32, sk.n());
64 }
65
66 let base2k: usize = self.base2k().into();
67 let rank: usize = self.rank().into();
68 let digits: usize = self.digits().into();
69
70 let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&self.glwe_layout());
71
72 (0..self.rows().into()).for_each(|row_i| {
73 tmp_pt.data.zero();
74
75 module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0);
77 module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1);
78
79 (0..rank + 1).for_each(|col_j| {
80 self.at_mut(row_i, col_j).encrypt_sk_internal(
83 module,
84 Some((&tmp_pt, col_j)),
85 sk,
86 source_xa,
87 source_xe,
88 scratch_1,
89 );
90 });
91 });
92 }
93}