poulpy_core/encryption/compressed/gglwe_tsk.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx,
4        TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes,
5        VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA,
6        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, VecZnxSwitchRing,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, Scratch},
9    source::Source,
10};
11
12use crate::{
13    TakeGLWESecret, TakeGLWESecretPrepared,
14    layouts::{GGLWETensorKey, GLWESecret, Infos, compressed::GGLWETensorKeyCompressed, prepared::Prepare},
15};
16
17impl GGLWETensorKeyCompressed<Vec<u8>> {
18    pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize, rank: usize) -> usize
19    where
20        Module<B>:
21            SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes,
22    {
23        GGLWETensorKey::encrypt_sk_scratch_space(module, basek, k, rank)
24    }
25}
26
27impl<DataSelf: DataMut> GGLWETensorKeyCompressed<DataSelf> {
28    pub fn encrypt_sk<DataSk: DataRef, B: Backend>(
29        &mut self,
30        module: &Module<B>,
31        sk: &GLWESecret<DataSk>,
32        seed_xa: [u8; 32],
33        source_xe: &mut Source,
34        scratch: &mut Scratch<B>,
35    ) where
36        Module<B>: SvpApplyDftToDft<B>
37            + VecZnxIdftApplyTmpA<B>
38            + VecZnxDftAllocBytes
39            + VecZnxBigNormalize<B>
40            + VecZnxDftApply<B>
41            + SvpApplyDftToDftInplace<B>
42            + VecZnxIdftApplyConsume<B>
43            + VecZnxNormalizeTmpBytes
44            + VecZnxFillUniform
45            + VecZnxSubABInplace
46            + VecZnxAddInplace
47            + VecZnxNormalizeInplace<B>
48            + VecZnxAddNormal
49            + VecZnxNormalize<B>
50            + VecZnxSub
51            + VecZnxSwitchRing
52            + VecZnxAddScalarInplace
53            + SvpPrepare<B>
54            + SvpPPolAllocBytes
55            + SvpPPolAlloc<B>,
56        Scratch<B>: ScratchAvailable
57            + TakeScalarZnx
58            + TakeVecZnxDft<B>
59            + TakeGLWESecretPrepared<B>
60            + ScratchAvailable
61            + TakeVecZnx
62            + TakeVecZnxBig<B>,
63    {
64        #[cfg(debug_assertions)]
65        {
66            assert_eq!(self.rank(), sk.rank());
67            assert_eq!(self.n(), sk.n());
68        }
69
70        let n: usize = sk.n();
71        let rank: usize = self.rank();
72
73        let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(n, rank);
74        sk_dft_prep.prepare(module, sk, scratch_1);
75
76        let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1);
77
78        (0..rank).for_each(|i| {
79            module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
80        });
81
82        let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1);
83        let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(n, 1);
84        let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n, 1, 1);
85
86        let mut source_xa: Source = Source::new(seed_xa);
87
88        (0..rank).for_each(|i| {
89            (i..rank).for_each(|j| {
90                module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
91
92                module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
93                module.vec_znx_big_normalize(
94                    self.basek(),
95                    &mut sk_ij.data.as_vec_znx_mut(),
96                    0,
97                    &sk_ij_big,
98                    0,
99                    scratch_5,
100                );
101
102                let (seed_xa_tmp, _) = source_xa.branch();
103
104                self.at_mut(i, j)
105                    .encrypt_sk(module, &sk_ij, sk, seed_xa_tmp, source_xe, scratch_5);
106            });
107        })
108    }
109}