poulpy_core/encryption/
ggsw_ct.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace,
4 VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxFillUniform,
5 VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, ZnxZero,
6 },
7 layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, VecZnx},
8 source::Source,
9};
10
11use crate::{
12 TakeGLWEPt,
13 layouts::{GGSWCiphertext, GLWECiphertext, Infos, prepared::GLWESecretPrepared},
14};
15
16impl GGSWCiphertext<Vec<u8>> {
17 pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize, rank: usize) -> usize
18 where
19 Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
20 {
21 let size = k.div_ceil(basek);
22 GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k)
23 + VecZnx::alloc_bytes(n, rank + 1, size)
24 + VecZnx::alloc_bytes(n, 1, size)
25 + module.vec_znx_dft_alloc_bytes(n, rank + 1, size)
26 }
27}
28
29impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
30 #[allow(clippy::too_many_arguments)]
31 pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
32 &mut self,
33 module: &Module<B>,
34 pt: &ScalarZnx<DataPt>,
35 sk: &GLWESecretPrepared<DataSk, B>,
36 source_xa: &mut Source,
37 source_xe: &mut Source,
38 sigma: f64,
39 scratch: &mut Scratch<B>,
40 ) where
41 Module<B>: VecZnxAddScalarInplace
42 + VecZnxDftAllocBytes
43 + VecZnxBigNormalize<B>
44 + VecZnxDftFromVecZnx<B>
45 + SvpApplyInplace<B>
46 + VecZnxDftToVecZnxBigConsume<B>
47 + VecZnxNormalizeTmpBytes
48 + VecZnxFillUniform
49 + VecZnxSubABInplace
50 + VecZnxAddInplace
51 + VecZnxNormalizeInplace<B>
52 + VecZnxAddNormal
53 + VecZnxNormalize<B>
54 + VecZnxSub,
55 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
56 {
57 #[cfg(debug_assertions)]
58 {
59 use poulpy_hal::api::ZnxInfos;
60
61 assert_eq!(self.rank(), sk.rank());
62 assert_eq!(self.n(), sk.n());
63 assert_eq!(pt.n(), sk.n());
64 }
65
66 let basek: usize = self.basek();
67 let k: usize = self.k();
68 let rank: usize = self.rank();
69 let digits: usize = self.digits();
70
71 let (mut tmp_pt, scratch1) = scratch.take_glwe_pt(self.n(), basek, k);
72
73 (0..self.rows()).for_each(|row_i| {
74 tmp_pt.data.zero();
75
76 module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0);
78 module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scratch1);
79
80 (0..rank + 1).for_each(|col_j| {
81 self.at_mut(row_i, col_j).encrypt_sk_internal(
84 module,
85 Some((&tmp_pt, col_j)),
86 sk,
87 source_xa,
88 source_xe,
89 sigma,
90 scratch1,
91 );
92 });
93 });
94 }
95}