poulpy_core/encryption/compressed/
ggsw_ct.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
4 VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform,
5 VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace,
6 },
7 layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero},
8 source::Source,
9};
10
11use crate::{
12 TakeGLWEPt,
13 encryption::{SIGMA, glwe_encrypt_sk_internal},
14 layouts::{
15 GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, compressed::GGSWCiphertextCompressed, prepared::GLWESecretPrepared,
16 },
17};
18
19impl GGSWCiphertextCompressed<Vec<u8>> {
20 pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
21 where
22 A: GGSWInfos,
23 Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes,
24 {
25 GGSWCiphertext::encrypt_sk_scratch_space(module, infos)
26 }
27}
28
29impl<DataSelf: DataMut> GGSWCiphertextCompressed<DataSelf> {
30 #[allow(clippy::too_many_arguments)]
31 pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
32 &mut self,
33 module: &Module<B>,
34 pt: &ScalarZnx<DataPt>,
35 sk: &GLWESecretPrepared<DataSk, B>,
36 seed_xa: [u8; 32],
37 source_xe: &mut Source,
38 scratch: &mut Scratch<B>,
39 ) where
40 Module<B>: VecZnxAddScalarInplace
41 + VecZnxDftAllocBytes
42 + VecZnxBigNormalize<B>
43 + VecZnxDftApply<B>
44 + SvpApplyDftToDftInplace<B>
45 + VecZnxIdftApplyConsume<B>
46 + VecZnxNormalizeTmpBytes
47 + VecZnxFillUniform
48 + VecZnxSubInplace
49 + VecZnxAddInplace
50 + VecZnxNormalizeInplace<B>
51 + VecZnxAddNormal
52 + VecZnxNormalize<B>
53 + VecZnxSub,
54 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
55 {
56 #[cfg(debug_assertions)]
57 {
58 use poulpy_hal::layouts::ZnxInfos;
59
60 assert_eq!(self.rank(), sk.rank());
61 assert_eq!(self.n(), sk.n());
62 assert_eq!(pt.n() as u32, sk.n());
63 }
64
65 let base2k: usize = self.base2k().into();
66 let rank: usize = self.rank().into();
67 let cols: usize = rank + 1;
68 let digits: usize = self.digits().into();
69
70 let (mut tmp_pt, scratch_1) = scratch.take_glwe_pt(&self.glwe_layout());
71
72 let mut source = Source::new(seed_xa);
73
74 self.seed = vec![[0u8; 32]; self.rows().0 as usize * cols];
75
76 (0..self.rows().into()).for_each(|row_i| {
77 tmp_pt.data.zero();
78
79 module.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (digits - 1) + row_i * digits, pt, 0);
81 module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1);
82
83 (0..rank + 1).for_each(|col_j| {
84 let (seed, mut source_xa_tmp) = source.branch();
87
88 self.seed[row_i * cols + col_j] = seed;
89
90 glwe_encrypt_sk_internal(
91 module,
92 self.base2k().into(),
93 self.k().into(),
94 &mut self.at_mut(row_i, col_j).data,
95 cols,
96 true,
97 Some((&tmp_pt, col_j)),
98 sk,
99 &mut source_xa_tmp,
100 source_xe,
101 SIGMA,
102 scratch_1,
103 );
104 });
105 });
106 }
107}