poulpy_core/encryption/compressed/
gglwe_tsk.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx,
4 TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes,
5 VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA,
6 VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing,
7 },
8 layouts::{Backend, DataMut, DataRef, Module, Scratch},
9 source::Source,
10};
11
12use crate::{
13 TakeGLWESecret, TakeGLWESecretPrepared,
14 layouts::{
15 GGLWELayoutInfos, GGLWETensorKey, GLWEInfos, GLWESecret, LWEInfos, Rank, compressed::GGLWETensorKeyCompressed,
16 prepared::Prepare,
17 },
18};
19
20impl GGLWETensorKeyCompressed<Vec<u8>> {
21 pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
22 where
23 A: GGLWELayoutInfos,
24 Module<B>:
25 SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes,
26 {
27 GGLWETensorKey::encrypt_sk_scratch_space(module, infos)
28 }
29}
30
31impl<DataSelf: DataMut> GGLWETensorKeyCompressed<DataSelf> {
32 pub fn encrypt_sk<DataSk: DataRef, B: Backend>(
33 &mut self,
34 module: &Module<B>,
35 sk: &GLWESecret<DataSk>,
36 seed_xa: [u8; 32],
37 source_xe: &mut Source,
38 scratch: &mut Scratch<B>,
39 ) where
40 Module<B>: SvpApplyDftToDft<B>
41 + VecZnxIdftApplyTmpA<B>
42 + VecZnxDftAllocBytes
43 + VecZnxBigNormalize<B>
44 + VecZnxDftApply<B>
45 + SvpApplyDftToDftInplace<B>
46 + VecZnxIdftApplyConsume<B>
47 + VecZnxNormalizeTmpBytes
48 + VecZnxFillUniform
49 + VecZnxSubInplace
50 + VecZnxAddInplace
51 + VecZnxNormalizeInplace<B>
52 + VecZnxAddNormal
53 + VecZnxNormalize<B>
54 + VecZnxSub
55 + VecZnxSwitchRing
56 + VecZnxAddScalarInplace
57 + SvpPrepare<B>
58 + SvpPPolAllocBytes
59 + SvpPPolAlloc<B>,
60 Scratch<B>: ScratchAvailable
61 + TakeScalarZnx
62 + TakeVecZnxDft<B>
63 + TakeGLWESecretPrepared<B>
64 + ScratchAvailable
65 + TakeVecZnx
66 + TakeVecZnxBig<B>,
67 {
68 #[cfg(debug_assertions)]
69 {
70 assert_eq!(self.rank_out(), sk.rank());
71 assert_eq!(self.n(), sk.n());
72 }
73
74 let n: usize = sk.n().into();
75 let rank: usize = self.rank_out().into();
76
77 let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(sk.n(), self.rank_out());
78 sk_dft_prep.prepare(module, sk, scratch_1);
79
80 let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1);
81
82 for i in 0..rank {
83 module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
84 }
85
86 let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1);
87 let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(sk.n(), Rank(1));
88 let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n, 1, 1);
89
90 let mut source_xa: Source = Source::new(seed_xa);
91
92 for i in 0..rank {
93 for j in i..rank {
94 module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
95
96 module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
97 module.vec_znx_big_normalize(
98 self.base2k().into(),
99 &mut sk_ij.data.as_vec_znx_mut(),
100 0,
101 self.base2k().into(),
102 &sk_ij_big,
103 0,
104 scratch_5,
105 );
106
107 let (seed_xa_tmp, _) = source_xa.branch();
108
109 self.at_mut(i, j)
110 .encrypt_sk(module, &sk_ij, sk, seed_xa_tmp, source_xe, scratch_5);
111 }
112 }
113 }
114}