1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx,
4 TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes,
5 VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA,
6 VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing,
7 },
8 layouts::{Backend, DataMut, DataRef, Module, Scratch},
9 source::Source,
10};
11
12use crate::{
13 TakeGLWESecret, TakeGLWESecretPrepared,
14 layouts::{
15 GGLWESwitchingKey, GGLWETensorKey, GLWESecret, Infos,
16 prepared::{GLWESecretPrepared, Prepare},
17 },
18};
19
20impl GGLWETensorKey<Vec<u8>> {
21 pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
22 where
23 Module<B>:
24 SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes,
25 {
26 GLWESecretPrepared::bytes_of(module, rank)
27 + module.vec_znx_dft_alloc_bytes(rank, 1)
28 + module.vec_znx_big_alloc_bytes(1, 1)
29 + module.vec_znx_dft_alloc_bytes(1, 1)
30 + GLWESecret::bytes_of(module.n(), 1)
31 + GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank)
32 }
33}
34
35impl<DataSelf: DataMut> GGLWETensorKey<DataSelf> {
36 pub fn encrypt_sk<DataSk: DataRef, B: Backend>(
37 &mut self,
38 module: &Module<B>,
39 sk: &GLWESecret<DataSk>,
40 source_xa: &mut Source,
41 source_xe: &mut Source,
42 scratch: &mut Scratch<B>,
43 ) where
44 Module<B>: SvpApplyDftToDft<B>
45 + VecZnxIdftApplyTmpA<B>
46 + VecZnxAddScalarInplace
47 + VecZnxDftAllocBytes
48 + VecZnxBigNormalize<B>
49 + VecZnxDftApply<B>
50 + SvpApplyDftToDftInplace<B>
51 + VecZnxIdftApplyConsume<B>
52 + VecZnxNormalizeTmpBytes
53 + VecZnxFillUniform
54 + VecZnxSubABInplace
55 + VecZnxAddInplace
56 + VecZnxNormalizeInplace<B>
57 + VecZnxAddNormal
58 + VecZnxNormalize<B>
59 + VecZnxSub
60 + SvpPrepare<B>
61 + VecZnxSwitchRing
62 + SvpPPolAllocBytes,
63 Scratch<B>:
64 TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B> + TakeVecZnxBig<B>,
65 {
66 #[cfg(debug_assertions)]
67 {
68 assert_eq!(self.rank(), sk.rank());
69 assert_eq!(self.n(), sk.n());
70 }
71
72 let n: usize = sk.n();
73
74 let rank: usize = self.rank();
75
76 let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(n, rank);
77 sk_dft_prep.prepare(module, sk, scratch_1);
78
79 let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1);
80
81 (0..rank).for_each(|i| {
82 module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
83 });
84
85 let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1);
86 let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, 1);
87 let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n, 1, 1);
88
89 (0..rank).for_each(|i| {
90 (i..rank).for_each(|j| {
91 module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
92
93 module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
94 module.vec_znx_big_normalize(
95 self.basek(),
96 &mut sk_ij.data.as_vec_znx_mut(),
97 0,
98 &sk_ij_big,
99 0,
100 scratch_5,
101 );
102
103 self.at_mut(i, j)
104 .encrypt_sk(module, &sk_ij, sk, source_xa, source_xe, scratch_5);
105 });
106 })
107 }
108}