poulpy_core/encryption/gglwe_tsk.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx,
4        TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes,
5        VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA,
6        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, Scratch},
9    source::Source,
10};
11
12use crate::{
13    TakeGLWESecret, TakeGLWESecretPrepared,
14    layouts::{
15        GGLWESwitchingKey, GGLWETensorKey, GLWESecret, Infos,
16        prepared::{GLWESecretPrepared, Prepare},
17    },
18};
19
20impl GGLWETensorKey<Vec<u8>> {
21    pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
22    where
23        Module<B>:
24            SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes,
25    {
26        GLWESecretPrepared::bytes_of(module, rank)
27            + module.vec_znx_dft_alloc_bytes(rank, 1)
28            + module.vec_znx_big_alloc_bytes(1, 1)
29            + module.vec_znx_dft_alloc_bytes(1, 1)
30            + GLWESecret::bytes_of(module.n(), 1)
31            + GGLWESwitchingKey::encrypt_sk_scratch_space(module, basek, k, rank, rank)
32    }
33}
34
35impl<DataSelf: DataMut> GGLWETensorKey<DataSelf> {
36    pub fn encrypt_sk<DataSk: DataRef, B: Backend>(
37        &mut self,
38        module: &Module<B>,
39        sk: &GLWESecret<DataSk>,
40        source_xa: &mut Source,
41        source_xe: &mut Source,
42        scratch: &mut Scratch<B>,
43    ) where
44        Module<B>: SvpApplyDftToDft<B>
45            + VecZnxIdftApplyTmpA<B>
46            + VecZnxAddScalarInplace
47            + VecZnxDftAllocBytes
48            + VecZnxBigNormalize<B>
49            + VecZnxDftApply<B>
50            + SvpApplyDftToDftInplace<B>
51            + VecZnxIdftApplyConsume<B>
52            + VecZnxNormalizeTmpBytes
53            + VecZnxFillUniform
54            + VecZnxSubABInplace
55            + VecZnxAddInplace
56            + VecZnxNormalizeInplace<B>
57            + VecZnxAddNormal
58            + VecZnxNormalize<B>
59            + VecZnxSub
60            + SvpPrepare<B>
61            + VecZnxSwitchRing
62            + SvpPPolAllocBytes,
63        Scratch<B>:
64            TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B> + TakeVecZnxBig<B>,
65    {
66        #[cfg(debug_assertions)]
67        {
68            assert_eq!(self.rank(), sk.rank());
69            assert_eq!(self.n(), sk.n());
70        }
71
72        let n: usize = sk.n();
73
74        let rank: usize = self.rank();
75
76        let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(n, rank);
77        sk_dft_prep.prepare(module, sk, scratch_1);
78
79        let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1);
80
81        (0..rank).for_each(|i| {
82            module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
83        });
84
85        let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1);
86        let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, 1);
87        let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n, 1, 1);
88
89        (0..rank).for_each(|i| {
90            (i..rank).for_each(|j| {
91                module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
92
93                module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
94                module.vec_znx_big_normalize(
95                    self.basek(),
96                    &mut sk_ij.data.as_vec_znx_mut(),
97                    0,
98                    &sk_ij_big,
99                    0,
100                    scratch_5,
101                );
102
103                self.at_mut(i, j)
104                    .encrypt_sk(module, &sk_ij, sk, source_xa, source_xe, scratch_5);
105            });
106        })
107    }
108}