poulpy_core/encryption/gglwe_tsk.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxBig,
4        TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes, VecZnxBigNormalize,
5        VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpA, VecZnxFillUniform,
6        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, Scratch},
9    source::Source,
10};
11
12use crate::{
13    TakeGLWESecret, TakeGLWESecretPrepared,
14    layouts::{
15        GGLWESwitchingKey, GGLWETensorKey, GLWESecret, Infos,
16        prepared::{GLWESecretPrepared, Prepare},
17    },
18};
19
20impl GGLWETensorKey<Vec<u8>> {
21    pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize, rank: usize) -> usize
22    where
23        Module<B>:
24            SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes,
25    {
26        GLWESecretPrepared::bytes_of(module, n, rank)
27            + module.vec_znx_dft_alloc_bytes(n, rank, 1)
28            + module.vec_znx_big_alloc_bytes(n, 1, 1)
29            + module.vec_znx_dft_alloc_bytes(n, 1, 1)
30            + GLWESecret::bytes_of(n, 1)
31            + GGLWESwitchingKey::encrypt_sk_scratch_space(module, n, basek, k, rank, rank)
32    }
33}
34
35impl<DataSelf: DataMut> GGLWETensorKey<DataSelf> {
36    pub fn encrypt_sk<DataSk: DataRef, B: Backend>(
37        &mut self,
38        module: &Module<B>,
39        sk: &GLWESecret<DataSk>,
40        source_xa: &mut Source,
41        source_xe: &mut Source,
42        sigma: f64,
43        scratch: &mut Scratch<B>,
44    ) where
45        Module<B>: SvpApply<B>
46            + VecZnxDftToVecZnxBigTmpA<B>
47            + VecZnxAddScalarInplace
48            + VecZnxDftAllocBytes
49            + VecZnxBigNormalize<B>
50            + VecZnxDftFromVecZnx<B>
51            + SvpApplyInplace<B>
52            + VecZnxDftToVecZnxBigConsume<B>
53            + VecZnxNormalizeTmpBytes
54            + VecZnxFillUniform
55            + VecZnxSubABInplace
56            + VecZnxAddInplace
57            + VecZnxNormalizeInplace<B>
58            + VecZnxAddNormal
59            + VecZnxNormalize<B>
60            + VecZnxSub
61            + SvpPrepare<B>
62            + VecZnxSwithcDegree
63            + SvpPPolAllocBytes,
64        Scratch<B>:
65            TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B> + TakeVecZnxBig<B>,
66    {
67        #[cfg(debug_assertions)]
68        {
69            assert_eq!(self.rank(), sk.rank());
70            assert_eq!(self.n(), sk.n());
71        }
72
73        let n: usize = sk.n();
74
75        let rank: usize = self.rank();
76
77        let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_prepared(n, rank);
78        sk_dft_prep.prepare(module, sk, scratch1);
79
80        let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(n, rank, 1);
81
82        (0..rank).for_each(|i| {
83            module.vec_znx_dft_from_vec_znx(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
84        });
85
86        let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(n, 1, 1);
87        let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(n, 1);
88        let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(n, 1, 1);
89
90        (0..rank).for_each(|i| {
91            (i..rank).for_each(|j| {
92                module.svp_apply(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
93
94                module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
95                module.vec_znx_big_normalize(
96                    self.basek(),
97                    &mut sk_ij.data.as_vec_znx_mut(),
98                    0,
99                    &sk_ij_big,
100                    0,
101                    scratch5,
102                );
103
104                self.at_mut(i, j)
105                    .encrypt_sk(module, &sk_ij, sk, source_xa, source_xe, sigma, scratch5);
106            });
107        })
108    }
109}