1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx,
4 TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes,
5 VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA,
6 VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing,
7 },
8 layouts::{Backend, DataMut, DataRef, Module, Scratch},
9 source::Source,
10};
11
12use crate::{
13 TakeGLWESecret, TakeGLWESecretPrepared,
14 layouts::{
15 Degree, GGLWELayoutInfos, GGLWESwitchingKey, GGLWETensorKey, GLWEInfos, GLWESecret, LWEInfos, Rank,
16 prepared::{GLWESecretPrepared, Prepare},
17 },
18};
19
20impl GGLWETensorKey<Vec<u8>> {
21 pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
22 where
23 A: GGLWELayoutInfos,
24 Module<B>:
25 SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes,
26 {
27 GLWESecretPrepared::alloc_bytes_with(module, infos.rank_out())
28 + module.vec_znx_dft_alloc_bytes(infos.rank_out().into(), 1)
29 + module.vec_znx_big_alloc_bytes(1, 1)
30 + module.vec_znx_dft_alloc_bytes(1, 1)
31 + GLWESecret::alloc_bytes_with(Degree(module.n() as u32), Rank(1))
32 + GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos)
33 }
34}
35
36impl<DataSelf: DataMut> GGLWETensorKey<DataSelf> {
37 pub fn encrypt_sk<DataSk: DataRef, B: Backend>(
38 &mut self,
39 module: &Module<B>,
40 sk: &GLWESecret<DataSk>,
41 source_xa: &mut Source,
42 source_xe: &mut Source,
43 scratch: &mut Scratch<B>,
44 ) where
45 Module<B>: SvpApplyDftToDft<B>
46 + VecZnxIdftApplyTmpA<B>
47 + VecZnxAddScalarInplace
48 + VecZnxDftAllocBytes
49 + VecZnxBigNormalize<B>
50 + VecZnxDftApply<B>
51 + SvpApplyDftToDftInplace<B>
52 + VecZnxIdftApplyConsume<B>
53 + VecZnxNormalizeTmpBytes
54 + VecZnxFillUniform
55 + VecZnxSubInplace
56 + VecZnxAddInplace
57 + VecZnxNormalizeInplace<B>
58 + VecZnxAddNormal
59 + VecZnxNormalize<B>
60 + VecZnxSub
61 + SvpPrepare<B>
62 + VecZnxSwitchRing
63 + SvpPPolAllocBytes,
64 Scratch<B>:
65 TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B> + TakeVecZnxBig<B>,
66 {
67 #[cfg(debug_assertions)]
68 {
69 assert_eq!(self.rank_out(), sk.rank());
70 assert_eq!(self.n(), sk.n());
71 }
72
73 let n: Degree = sk.n();
74 let rank: Rank = self.rank_out();
75
76 let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(n, rank);
77 sk_dft_prep.prepare(module, sk, scratch_1);
78
79 let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n.into(), rank.into(), 1);
80
81 (0..rank.into()).for_each(|i| {
82 module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
83 });
84
85 let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n.into(), 1, 1);
86 let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, Rank(1));
87 let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n.into(), 1, 1);
88
89 (0..rank.into()).for_each(|i| {
90 (i..rank.into()).for_each(|j| {
91 module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
92
93 module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
94 module.vec_znx_big_normalize(
95 self.base2k().into(),
96 &mut sk_ij.data.as_vec_znx_mut(),
97 0,
98 self.base2k().into(),
99 &sk_ij_big,
100 0,
101 scratch_5,
102 );
103
104 self.at_mut(i, j)
105 .encrypt_sk(module, &sk_ij, sk, source_xa, source_xe, scratch_5);
106 });
107 })
108 }
109}