1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxBig,
4 TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes, VecZnxBigNormalize,
5 VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpA, VecZnxFillUniform,
6 VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree,
7 },
8 layouts::{Backend, DataMut, DataRef, Module, Scratch},
9 source::Source,
10};
11
12use crate::{
13 TakeGLWESecret, TakeGLWESecretPrepared,
14 layouts::{
15 GGLWESwitchingKey, GGLWETensorKey, GLWESecret, Infos,
16 prepared::{GLWESecretPrepared, Prepare},
17 },
18};
19
20impl GGLWETensorKey<Vec<u8>> {
21 pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize, rank: usize) -> usize
22 where
23 Module<B>:
24 SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes,
25 {
26 GLWESecretPrepared::bytes_of(module, n, rank)
27 + module.vec_znx_dft_alloc_bytes(n, rank, 1)
28 + module.vec_znx_big_alloc_bytes(n, 1, 1)
29 + module.vec_znx_dft_alloc_bytes(n, 1, 1)
30 + GLWESecret::bytes_of(n, 1)
31 + GGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k, rank, rank)
32 }
33}
34
35impl<DataSelf: DataMut> GGLWETensorKey<DataSelf> {
36 pub fn encrypt_sk<DataSk: DataRef, B: Backend>(
37 &mut self,
38 module: &Module<B>,
39 sk: &GLWESecret<DataSk>,
40 source_xa: &mut Source,
41 source_xe: &mut Source,
42 sigma: f64,
43 scratch: &mut Scratch<B>,
44 ) where
45 Module<B>: SvpApply<B>
46 + VecZnxDftToVecZnxBigTmpA<B>
47 + VecZnxAddScalarInplace
48 + VecZnxDftAllocBytes
49 + VecZnxBigNormalize<B>
50 + VecZnxDftFromVecZnx<B>
51 + SvpApplyInplace<B>
52 + VecZnxDftToVecZnxBigConsume<B>
53 + VecZnxNormalizeTmpBytes
54 + VecZnxFillUniform
55 + VecZnxSubABInplace
56 + VecZnxAddInplace
57 + VecZnxNormalizeInplace<B>
58 + VecZnxAddNormal
59 + VecZnxNormalize<B>
60 + VecZnxSub
61 + SvpPrepare<B>
62 + VecZnxSwithcDegree
63 + SvpPPolAllocBytes,
64 Scratch<B>:
65 TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B> + TakeVecZnxBig<B>,
66 {
67 #[cfg(debug_assertions)]
68 {
69 assert_eq!(self.rank(), sk.rank());
70 assert_eq!(self.n(), sk.n());
71 }
72
73 let n: usize = sk.n();
74
75 let rank: usize = self.rank();
76
77 let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_prepared(n, rank);
78 sk_dft_prep.prepare(module, sk, scratch1);
79
80 let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(n, rank, 1);
81
82 (0..rank).for_each(|i| {
83 module.vec_znx_dft_from_vec_znx(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
84 });
85
86 let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(n, 1, 1);
87 let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(n, 1);
88 let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(n, 1, 1);
89
90 (0..rank).for_each(|i| {
91 (i..rank).for_each(|j| {
92 module.svp_apply(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
93
94 module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
95 module.vec_znx_big_normalize(
96 self.basek(),
97 &mut sk_ij.data.as_vec_znx_mut(),
98 0,
99 &sk_ij_big,
100 0,
101 scratch5,
102 );
103
104 self.at_mut(i, j)
105 .encrypt_sk(module, &sk_ij, sk, source_xa, source_xe, sigma, scratch5);
106 });
107 })
108 }
109}