poulpy_core/encryption/
gglwe_ksk.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft,
4        VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
5        VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub,
6        VecZnxSubInplace, VecZnxSwitchRing,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch},
9    source::Source,
10};
11
12use crate::{
13    TakeGLWESecretPrepared,
14    layouts::{
15        Degree, GGLWECiphertext, GGLWELayoutInfos, GGLWESwitchingKey, GLWEInfos, GLWESecret, LWEInfos,
16        prepared::GLWESecretPrepared,
17    },
18};
19
20impl GGLWESwitchingKey<Vec<u8>> {
21    pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
22    where
23        A: GGLWELayoutInfos,
24        Module<B>: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes,
25    {
26        (GGLWECiphertext::encrypt_sk_scratch_space(module, infos) | ScalarZnx::alloc_bytes(module.n(), 1))
27            + ScalarZnx::alloc_bytes(module.n(), infos.rank_in().into())
28            + GLWESecretPrepared::alloc_bytes(module, &infos.glwe_layout())
29    }
30
31    pub fn encrypt_pk_scratch_space<B: Backend, A>(module: &Module<B>, _infos: &A) -> usize
32    where
33        A: GGLWELayoutInfos,
34    {
35        GGLWECiphertext::encrypt_pk_scratch_space(module, _infos)
36    }
37}
38
39impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
40    #[allow(clippy::too_many_arguments)]
41    pub fn encrypt_sk<DataSkIn: DataRef, DataSkOut: DataRef, B: Backend>(
42        &mut self,
43        module: &Module<B>,
44        sk_in: &GLWESecret<DataSkIn>,
45        sk_out: &GLWESecret<DataSkOut>,
46        source_xa: &mut Source,
47        source_xe: &mut Source,
48        scratch: &mut Scratch<B>,
49    ) where
50        Module<B>: VecZnxAddScalarInplace
51            + VecZnxDftAllocBytes
52            + VecZnxBigNormalize<B>
53            + VecZnxDftApply<B>
54            + SvpApplyDftToDftInplace<B>
55            + VecZnxIdftApplyConsume<B>
56            + VecZnxNormalizeTmpBytes
57            + VecZnxFillUniform
58            + VecZnxSubInplace
59            + VecZnxAddInplace
60            + VecZnxNormalizeInplace<B>
61            + VecZnxAddNormal
62            + VecZnxNormalize<B>
63            + VecZnxSub
64            + SvpPrepare<B>
65            + VecZnxSwitchRing
66            + SvpPPolAllocBytes,
67        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B>,
68    {
69        #[cfg(debug_assertions)]
70        {
71            assert!(sk_in.n().0 <= module.n() as u32);
72            assert!(sk_out.n().0 <= module.n() as u32);
73            assert!(
74                scratch.available() >= GGLWESwitchingKey::encrypt_sk_scratch_space(module, self),
75                "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}",
76                scratch.available(),
77                GGLWESwitchingKey::encrypt_sk_scratch_space(module, self)
78            )
79        }
80
81        let n: usize = sk_in.n().max(sk_out.n()).into();
82
83        let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank().into());
84        (0..sk_in.rank().into()).for_each(|i| {
85            module.vec_znx_switch_ring(
86                &mut sk_in_tmp.as_vec_znx_mut(),
87                i,
88                &sk_in.data.as_vec_znx(),
89                i,
90            );
91        });
92
93        let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(Degree(n as u32), sk_out.rank());
94        {
95            let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1);
96            (0..sk_out.rank().into()).for_each(|i| {
97                module.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i);
98                module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0);
99            });
100        }
101
102        self.key.encrypt_sk(
103            module,
104            &sk_in_tmp,
105            &sk_out_tmp,
106            source_xa,
107            source_xe,
108            scratch_2,
109        );
110        self.sk_in_n = sk_in.n().into();
111        self.sk_out_n = sk_out.n().into();
112    }
113}