poulpy_core/encryption/compressed/gglwe_tsk.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyDftToDft, SvpApplyDftToDftInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx,
4        TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes,
5        VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA,
6        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, Scratch},
9    source::Source,
10};
11
12use crate::{
13    TakeGLWESecret, TakeGLWESecretPrepared,
14    layouts::{
15        GGLWELayoutInfos, GGLWETensorKey, GLWEInfos, GLWESecret, LWEInfos, Rank, compressed::GGLWETensorKeyCompressed,
16        prepared::Prepare,
17    },
18};
19
20impl GGLWETensorKeyCompressed<Vec<u8>> {
21    pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
22    where
23        A: GGLWELayoutInfos,
24        Module<B>:
25            SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes,
26    {
27        GGLWETensorKey::encrypt_sk_scratch_space(module, infos)
28    }
29}
30
31impl<DataSelf: DataMut> GGLWETensorKeyCompressed<DataSelf> {
32    pub fn encrypt_sk<DataSk: DataRef, B: Backend>(
33        &mut self,
34        module: &Module<B>,
35        sk: &GLWESecret<DataSk>,
36        seed_xa: [u8; 32],
37        source_xe: &mut Source,
38        scratch: &mut Scratch<B>,
39    ) where
40        Module<B>: SvpApplyDftToDft<B>
41            + VecZnxIdftApplyTmpA<B>
42            + VecZnxDftAllocBytes
43            + VecZnxBigNormalize<B>
44            + VecZnxDftApply<B>
45            + SvpApplyDftToDftInplace<B>
46            + VecZnxIdftApplyConsume<B>
47            + VecZnxNormalizeTmpBytes
48            + VecZnxFillUniform
49            + VecZnxSubInplace
50            + VecZnxAddInplace
51            + VecZnxNormalizeInplace<B>
52            + VecZnxAddNormal
53            + VecZnxNormalize<B>
54            + VecZnxSub
55            + VecZnxSwitchRing
56            + VecZnxAddScalarInplace
57            + SvpPrepare<B>
58            + SvpPPolAllocBytes
59            + SvpPPolAlloc<B>,
60        Scratch<B>: ScratchAvailable
61            + TakeScalarZnx
62            + TakeVecZnxDft<B>
63            + TakeGLWESecretPrepared<B>
64            + ScratchAvailable
65            + TakeVecZnx
66            + TakeVecZnxBig<B>,
67    {
68        #[cfg(debug_assertions)]
69        {
70            assert_eq!(self.rank_out(), sk.rank());
71            assert_eq!(self.n(), sk.n());
72        }
73
74        let n: usize = sk.n().into();
75        let rank: usize = self.rank_out().into();
76
77        let (mut sk_dft_prep, scratch_1) = scratch.take_glwe_secret_prepared(sk.n(), self.rank_out());
78        sk_dft_prep.prepare(module, sk, scratch_1);
79
80        let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(n, rank, 1);
81
82        for i in 0..rank {
83            module.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
84        }
85
86        let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(n, 1, 1);
87        let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(sk.n(), Rank(1));
88        let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(n, 1, 1);
89
90        let mut source_xa: Source = Source::new(seed_xa);
91
92        for i in 0..rank {
93            for j in i..rank {
94                module.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
95
96                module.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
97                module.vec_znx_big_normalize(
98                    self.base2k().into(),
99                    &mut sk_ij.data.as_vec_znx_mut(),
100                    0,
101                    self.base2k().into(),
102                    &sk_ij_big,
103                    0,
104                    scratch_5,
105                );
106
107                let (seed_xa_tmp, _) = source_xa.branch();
108
109                self.at_mut(i, j)
110                    .encrypt_sk(module, &sk_ij, sk, seed_xa_tmp, source_xe, scratch_5);
111            }
112        }
113    }
114}