poulpy_core/encryption/compressed/
gglwe_tsk.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx,
4 TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes,
5 VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpA,
6 VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
7 VecZnxSwithcDegree,
8 },
9 layouts::{Backend, DataMut, DataRef, Module, Scratch},
10 source::Source,
11};
12
13use crate::{
14 TakeGLWESecret, TakeGLWESecretPrepared,
15 layouts::{GGLWETensorKey, GLWESecret, Infos, compressed::GGLWETensorKeyCompressed, prepared::Prepare},
16};
17
18impl GGLWETensorKeyCompressed<Vec<u8>> {
19 pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize, rank: usize) -> usize
20 where
21 Module<B>:
22 SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes,
23 {
24 GGLWETensorKey::encrypt_sk_scratch_space(module, n, basek, k, rank)
25 }
26}
27
28impl<DataSelf: DataMut> GGLWETensorKeyCompressed<DataSelf> {
29 pub fn encrypt_sk<DataSk: DataRef, B: Backend>(
30 &mut self,
31 module: &Module<B>,
32 sk: &GLWESecret<DataSk>,
33 seed_xa: [u8; 32],
34 source_xe: &mut Source,
35 sigma: f64,
36 scratch: &mut Scratch<B>,
37 ) where
38 Module<B>: SvpApply<B>
39 + VecZnxDftToVecZnxBigTmpA<B>
40 + VecZnxDftAllocBytes
41 + VecZnxBigNormalize<B>
42 + VecZnxDftFromVecZnx<B>
43 + SvpApplyInplace<B>
44 + VecZnxDftToVecZnxBigConsume<B>
45 + VecZnxNormalizeTmpBytes
46 + VecZnxFillUniform
47 + VecZnxSubABInplace
48 + VecZnxAddInplace
49 + VecZnxNormalizeInplace<B>
50 + VecZnxAddNormal
51 + VecZnxNormalize<B>
52 + VecZnxSub
53 + VecZnxSwithcDegree
54 + VecZnxAddScalarInplace
55 + SvpPrepare<B>
56 + SvpPPolAllocBytes
57 + SvpPPolAlloc<B>,
58 Scratch<B>: ScratchAvailable
59 + TakeScalarZnx
60 + TakeVecZnxDft<B>
61 + TakeGLWESecretPrepared<B>
62 + ScratchAvailable
63 + TakeVecZnx
64 + TakeVecZnxBig<B>,
65 {
66 #[cfg(debug_assertions)]
67 {
68 assert_eq!(self.rank(), sk.rank());
69 assert_eq!(self.n(), sk.n());
70 }
71
72 let n: usize = sk.n();
73 let rank: usize = self.rank();
74
75 let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_prepared(n, rank);
76 sk_dft_prep.prepare(module, sk, scratch1);
77
78 let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(n, rank, 1);
79
80 (0..rank).for_each(|i| {
81 module.vec_znx_dft_from_vec_znx(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
82 });
83
84 let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(n, 1, 1);
85 let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(n, 1);
86 let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(n, 1, 1);
87
88 let mut source_xa: Source = Source::new(seed_xa);
89
90 (0..rank).for_each(|i| {
91 (i..rank).for_each(|j| {
92 module.svp_apply(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
93
94 module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
95 module.vec_znx_big_normalize(
96 self.basek(),
97 &mut sk_ij.data.as_vec_znx_mut(),
98 0,
99 &sk_ij_big,
100 0,
101 scratch5,
102 );
103
104 let (seed_xa_tmp, _) = source_xa.branch();
105
106 self.at_mut(i, j)
107 .encrypt_sk(module, &sk_ij, sk, seed_xa_tmp, source_xe, sigma, scratch5);
108 });
109 })
110 }
111}