poulpy_core/encryption/compressed/
gglwe_tsk.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx,
4 TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes,
5 VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA,
6 VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing,
7 },
8 layouts::{Backend, DataMut, DataRef, Module, Scratch},
9 source::Source,
10};
11
12use crate::{
13 TakeGLWESecret, TakeGLWESecretPrepared,
14 layouts::{GGLWETensorKey, GLWESecret, Infos, compressed::GGLWETensorKeyCompressed, prepared::Prepare},
15};
16
17impl GGLWETensorKeyCompressed<Vec<u8>> {
18 pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
19 where
20 Module<B>:
21 SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes,
22 {
23 GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k, rank)
24 }
25}
26
27impl<DataSelf: DataMut> GGLWETensorKeyCompressed<DataSelf> {
28 pub fn encrypt_sk<DataSk: DataRef, B: Backend>(
29 &mut self,
30 module: &Module<B>,
31 sk: &GLWESecret<DataSk>,
32 seed_xa: [u8; 32],
33 source_xe: &mut Source,
34 scratch: &mut Scratch<B>,
35 ) where
36 Module<B>: SvpApplyDftToDft<B>
37 + VecZnxIdftApplyTmpA<B>
38 + VecZnxDftAllocBytes
39 + VecZnxBigNormalize<B>
40 + VecZnxDftApply<B>
41 + SvpApplyDftToDftInplace<B>
42 + VecZnxIdftApplyConsume<B>
43 + VecZnxNormalizeTmpBytes
44 + VecZnxFillUniform
45 + VecZnxSubABInplace
46 + VecZnxAddInplace
47 + VecZnxNormalizeInplace<B>
48 + VecZnxAddNormal
49 + VecZnxNormalize<B>
50 + VecZnxSub
51 + VecZnxSwitchRing
52 + VecZnxAddScalarInplace
53 + SvpPrepare<B>
54 + SvpPPolAllocBytes
55 + SvpPPolAlloc<B>,
56 Scratch<B>: ScratchAvailable
57 + TakeScalarZnx
58 + TakeVecZnxDft<B>
59 + TakeGLWESecretPrepared<B>
60 + ScratchAvailable
61 + TakeVecZnx
62 + TakeVecZnxBig<B>,
63 {
64 #[cfg(debug_assertions)]
65 {
66 assert_eq!(self.rank(), sk.rank());
67 assert_eq!(self.n(), sk.n());
68 }
69
70 let n: usize = sk.n();
71 let rank: usize = self.rank();
72
73 let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(n, rank);
74 sk_dft_prep.prepare(module, sk, scratch_1);
75
76 let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1);
77
78 (0..rank).for_each(|i| {
79 module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
80 });
81
82 let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1);
83 let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, 1);
84 let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n, 1, 1);
85
86 let mut source_xa: Source = Source::new(seed_xa);
87
88 (0..rank).for_each(|i| {
89 (i..rank).for_each(|j| {
90 module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
91
92 module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
93 module.vec_znx_big_normalize(
94 self.basek(),
95 &mut sk_ij.data.as_vec_znx_mut(),
96 0,
97 &sk_ij_big,
98 0,
99 scratch_5,
100 );
101
102 let (seed_xa_tmp, _) = source_xa.branch();
103
104 self.at_mut(i, j)
105 .encrypt_sk(module, &sk_ij, sk, seed_xa_tmp, source_xe, scratch_5);
106 });
107 })
108 }
109}