// poulpy_core/encryption/compressed/gglwe_tsk.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApply, SvpApplyInplace, SvpPPolAlloc, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx,
4        TakeVecZnxBig, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigAllocBytes,
5        VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpA,
6        VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
7        VecZnxSwithcDegree,
8    },
9    layouts::{Backend, DataMut, DataRef, Module, Scratch},
10    source::Source,
11};
12
13use crate::{
14    TakeGLWESecret, TakeGLWESecretPrepared,
15    layouts::{GGLWETensorKey, GLWESecret, Infos, compressed::GGLWETensorKeyCompressed, prepared::Prepare},
16};
17
18impl GGLWETensorKeyCompressed<Vec<u8>> {
19    pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize, rank: usize) -> usize
20    where
21        Module<B>:
22            SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigAllocBytes,
23    {
24        GGLWETensorKey::encrypt_sk_scratch_space(module, n, basek, k, rank)
25    }
26}
27
28impl<DataSelf: DataMut> GGLWETensorKeyCompressed<DataSelf> {
29    pub fn encrypt_sk<DataSk: DataRef, B: Backend>(
30        &mut self,
31        module: &Module<B>,
32        sk: &GLWESecret<DataSk>,
33        seed_xa: [u8; 32],
34        source_xe: &mut Source,
35        sigma: f64,
36        scratch: &mut Scratch<B>,
37    ) where
38        Module<B>: SvpApply<B>
39            + VecZnxDftToVecZnxBigTmpA<B>
40            + VecZnxDftAllocBytes
41            + VecZnxBigNormalize<B>
42            + VecZnxDftFromVecZnx<B>
43            + SvpApplyInplace<B>
44            + VecZnxDftToVecZnxBigConsume<B>
45            + VecZnxNormalizeTmpBytes
46            + VecZnxFillUniform
47            + VecZnxSubABInplace
48            + VecZnxAddInplace
49            + VecZnxNormalizeInplace<B>
50            + VecZnxAddNormal
51            + VecZnxNormalize<B>
52            + VecZnxSub
53            + VecZnxSwithcDegree
54            + VecZnxAddScalarInplace
55            + SvpPrepare<B>
56            + SvpPPolAllocBytes
57            + SvpPPolAlloc<B>,
58        Scratch<B>: ScratchAvailable
59            + TakeScalarZnx
60            + TakeVecZnxDft<B>
61            + TakeGLWESecretPrepared<B>
62            + ScratchAvailable
63            + TakeVecZnx
64            + TakeVecZnxBig<B>,
65    {
66        #[cfg(debug_assertions)]
67        {
68            assert_eq!(self.rank(), sk.rank());
69            assert_eq!(self.n(), sk.n());
70        }
71
72        let n: usize = sk.n();
73        let rank: usize = self.rank();
74
75        let (mut sk_dft_prep, scratch1) = scratch.take_glwe_secret_prepared(n, rank);
76        sk_dft_prep.prepare(module, sk, scratch1);
77
78        let (mut sk_dft, scratch2) = scratch1.take_vec_znx_dft(n, rank, 1);
79
80        (0..rank).for_each(|i| {
81            module.vec_znx_dft_from_vec_znx(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
82        });
83
84        let (mut sk_ij_big, scratch3) = scratch2.take_vec_znx_big(n, 1, 1);
85        let (mut sk_ij, scratch4) = scratch3.take_glwe_secret(n, 1);
86        let (mut sk_ij_dft, scratch5) = scratch4.take_vec_znx_dft(n, 1, 1);
87
88        let mut source_xa: Source = Source::new(seed_xa);
89
90        (0..rank).for_each(|i| {
91            (i..rank).for_each(|j| {
92                module.svp_apply(&mut sk_ij_dft, 0, &sk_dft_prep.data, j, &sk_dft, i);
93
94                module.vec_znx_dft_to_vec_znx_big_tmp_a(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
95                module.vec_znx_big_normalize(
96                    self.basek(),
97                    &mut sk_ij.data.as_vec_znx_mut(),
98                    0,
99                    &sk_ij_big,
100                    0,
101                    scratch5,
102                );
103
104                let (seed_xa_tmp, _) = source_xa.branch();
105
106                self.at_mut(i, j)
107                    .encrypt_sk(module, &sk_ij, sk, seed_xa_tmp, source_xe, sigma, scratch5);
108            });
109        })
110    }
111}