poulpy_core/encryption/gglwe_tsk.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx,
4        TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes,
5        VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA,
6        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, Scratch},
9    source::Source,
10};
11
12use crate::{
13    TakeGLWESecret, TakeGLWESecretPrepared,
14    layouts::{
15        Degree, GGLWELayoutInfos, GGLWESwitchingKey, GGLWETensorKey, GLWEInfos, GLWESecret, LWEInfos, Rank,
16        prepared::{GLWESecretPrepared, Prepare},
17    },
18};
19
20impl GGLWETensorKey<Vec<u8>> {
21    pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
22    where
23        A: GGLWELayoutInfos,
24        Module<B>:
25            SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes,
26    {
27        GLWESecretPrepared::alloc_bytes_with(module, infos.rank_out())
28            + module.vec_znx_dft_alloc_bytes(infos.rank_out().into(), 1)
29            + module.vec_znx_big_alloc_bytes(1, 1)
30            + module.vec_znx_dft_alloc_bytes(1, 1)
31            + GLWESecret::alloc_bytes_with(Degree(module.n() as u32), Rank(1))
32            + GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos)
33    }
34}
35
36impl<DataSelf: DataMut> GGLWETensorKey<DataSelf> {
37    pub fn encrypt_sk<DataSk: DataRef, B: Backend>(
38        &mut self,
39        module: &Module<B>,
40        sk: &GLWESecret<DataSk>,
41        source_xa: &mut Source,
42        source_xe: &mut Source,
43        scratch: &mut Scratch<B>,
44    ) where
45        Module<B>: SvpApplyDftToDft<B>
46            + VecZnxIdftApplyTmpA<B>
47            + VecZnxAddScalarInplace
48            + VecZnxDftAllocBytes
49            + VecZnxBigNormalize<B>
50            + VecZnxDftApply<B>
51            + SvpApplyDftToDftInplace<B>
52            + VecZnxIdftApplyConsume<B>
53            + VecZnxNormalizeTmpBytes
54            + VecZnxFillUniform
55            + VecZnxSubInplace
56            + VecZnxAddInplace
57            + VecZnxNormalizeInplace<B>
58            + VecZnxAddNormal
59            + VecZnxNormalize<B>
60            + VecZnxSub
61            + SvpPrepare<B>
62            + VecZnxSwitchRing
63            + SvpPPolAllocBytes,
64        Scratch<B>:
65            TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B> + TakeVecZnxBig<B>,
66    {
67        #[cfg(debug_assertions)]
68        {
69            assert_eq!(self.rank_out(), sk.rank());
70            assert_eq!(self.n(), sk.n());
71        }
72
73        let n: Degree = sk.n();
74        let rank: Rank = self.rank_out();
75
76        let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(n, rank);
77        sk_dft_prep.prepare(module, sk, scratch_1);
78
79        let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n.into(), rank.into(), 1);
80
81        (0..rank.into()).for_each(|i| {
82            module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
83        });
84
85        let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n.into(), 1, 1);
86        let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, Rank(1));
87        let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n.into(), 1, 1);
88
89        (0..rank.into()).for_each(|i| {
90            (i..rank.into()).for_each(|j| {
91                module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
92
93                module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
94                module.vec_znx_big_normalize(
95                    self.base2k().into(),
96                    &mut sk_ij.data.as_vec_znx_mut(),
97                    0,
98                    self.base2k().into(),
99                    &sk_ij_big,
100                    0,
101                    scratch_5,
102                );
103
104                self.at_mut(i, j)
105                    .encrypt_sk(module, &sk_ij, sk, source_xa, source_xe, scratch_5);
106            });
107        })
108    }
109}