poulpy_core/encryption/compressed/
gglwe_ksk.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft,
4        VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx,
5        VecZnxDftToVecZnxBigConsume, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes,
6        VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch},
9    source::Source,
10};
11
12use crate::{
13    TakeGLWESecretPrepared,
14    layouts::{GGLWECiphertext, GLWESecret, compressed::GGLWESwitchingKeyCompressed, prepared::GLWESecretPrepared},
15};
16
17impl GGLWESwitchingKeyCompressed<Vec<u8>> {
18    pub fn encrypt_sk_scratch_space<B: Backend>(
19        module: &Module<B>,
20        n: usize,
21        basek: usize,
22        k: usize,
23        rank_in: usize,
24        rank_out: usize,
25    ) -> usize
26    where
27        Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + SvpPPolAllocBytes,
28    {
29        (GGLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) | ScalarZnx::alloc_bytes(n, 1))
30            + ScalarZnx::alloc_bytes(n, rank_in)
31            + GLWESecretPrepared::bytes_of(module, n, rank_out)
32    }
33}
34
35impl<DataSelf: DataMut> GGLWESwitchingKeyCompressed<DataSelf> {
36    #[allow(clippy::too_many_arguments)]
37    pub fn encrypt_sk<DataSkIn: DataRef, DataSkOut: DataRef, B: Backend>(
38        &mut self,
39        module: &Module<B>,
40        sk_in: &GLWESecret<DataSkIn>,
41        sk_out: &GLWESecret<DataSkOut>,
42        seed_xa: [u8; 32],
43        source_xe: &mut Source,
44        sigma: f64,
45        scratch: &mut Scratch<B>,
46    ) where
47        Module<B>: SvpPrepare<B>
48            + SvpPPolAllocBytes
49            + VecZnxSwithcDegree
50            + VecZnxDftAllocBytes
51            + VecZnxBigNormalize<B>
52            + VecZnxDftFromVecZnx<B>
53            + SvpApplyInplace<B>
54            + VecZnxDftToVecZnxBigConsume<B>
55            + VecZnxNormalizeTmpBytes
56            + VecZnxFillUniform
57            + VecZnxSubABInplace
58            + VecZnxAddInplace
59            + VecZnxNormalizeInplace<B>
60            + VecZnxAddNormal
61            + VecZnxNormalize<B>
62            + VecZnxSub
63            + VecZnxAddScalarInplace,
64        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B>,
65    {
66        #[cfg(debug_assertions)]
67        {
68            use crate::layouts::{GGLWESwitchingKey, Infos};
69
70            assert!(sk_in.n() <= module.n());
71            assert!(sk_out.n() <= module.n());
72            assert!(
73                scratch.available()
74                    >= GGLWESwitchingKey::encrypt_sk_scratch_space(
75                        module,
76                        sk_out.n(),
77                        self.basek(),
78                        self.k(),
79                        self.rank_in(),
80                        self.rank_out()
81                    ),
82                "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}",
83                scratch.available(),
84                GGLWESwitchingKey::encrypt_sk_scratch_space(
85                    module,
86                    sk_out.n(),
87                    self.basek(),
88                    self.k(),
89                    self.rank_in(),
90                    self.rank_out()
91                )
92            )
93        }
94
95        let n: usize = sk_in.n().max(sk_out.n());
96
97        let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(n, sk_in.rank());
98        (0..sk_in.rank()).for_each(|i| {
99            module.vec_znx_switch_degree(
100                &mut sk_in_tmp.as_vec_znx_mut(),
101                i,
102                &sk_in.data.as_vec_znx(),
103                i,
104            );
105        });
106
107        let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_prepared(n, sk_out.rank());
108        {
109            let (mut tmp, _) = scratch2.take_scalar_znx(n, 1);
110            (0..sk_out.rank()).for_each(|i| {
111                module.vec_znx_switch_degree(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i);
112                module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0);
113            });
114        }
115
116        self.key.encrypt_sk(
117            module,
118            &sk_in_tmp,
119            &sk_out_tmp,
120            seed_xa,
121            source_xe,
122            sigma,
123            scratch2,
124        );
125        self.sk_in_n = sk_in.n();
126        self.sk_out_n = sk_out.n();
127    }
128}