poulpy_core/encryption/gglwe_atk.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft,
4        VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxAutomorphism, VecZnxBigNormalize, VecZnxDftAllocBytes,
5        VecZnxDftApply, VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace,
6        VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace, VecZnxSwitchRing,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, Scratch},
9    source::Source,
10};
11
12use crate::{
13    TakeGLWESecret, TakeGLWESecretPrepared,
14    layouts::{GGLWEAutomorphismKey, GGLWELayoutInfos, GGLWESwitchingKey, GLWEInfos, GLWESecret, LWEInfos},
15};
16
17impl GGLWEAutomorphismKey<Vec<u8>> {
18    pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
19    where
20        A: GGLWELayoutInfos,
21        Module<B>: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes,
22    {
23        assert_eq!(
24            infos.rank_in(),
25            infos.rank_out(),
26            "rank_in != rank_out is not supported for GGLWEAutomorphismKey"
27        );
28        GGLWESwitchingKey::encrypt_sk_scratch_space(module, infos) + GLWESecret::alloc_bytes(&infos.glwe_layout())
29    }
30
31    pub fn encrypt_pk_scratch_space<B: Backend, A>(module: &Module<B>, _infos: &A) -> usize
32    where
33        A: GGLWELayoutInfos,
34    {
35        assert_eq!(
36            _infos.rank_in(),
37            _infos.rank_out(),
38            "rank_in != rank_out is not supported for GGLWEAutomorphismKey"
39        );
40        GGLWESwitchingKey::encrypt_pk_scratch_space(module, _infos)
41    }
42}
43
44impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
45    #[allow(clippy::too_many_arguments)]
46    pub fn encrypt_sk<DataSk: DataRef, B: Backend>(
47        &mut self,
48        module: &Module<B>,
49        p: i64,
50        sk: &GLWESecret<DataSk>,
51        source_xa: &mut Source,
52        source_xe: &mut Source,
53        scratch: &mut Scratch<B>,
54    ) where
55        Module<B>: VecZnxAddScalarInplace
56            + VecZnxDftAllocBytes
57            + VecZnxBigNormalize<B>
58            + VecZnxDftApply<B>
59            + SvpApplyDftToDftInplace<B>
60            + VecZnxIdftApplyConsume<B>
61            + VecZnxNormalizeTmpBytes
62            + VecZnxFillUniform
63            + VecZnxSubInplace
64            + VecZnxAddInplace
65            + VecZnxNormalizeInplace<B>
66            + VecZnxAddNormal
67            + VecZnxNormalize<B>
68            + VecZnxSub
69            + SvpPrepare<B>
70            + VecZnxSwitchRing
71            + SvpPPolAllocBytes
72            + VecZnxAutomorphism,
73        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B>,
74    {
75        #[cfg(debug_assertions)]
76        {
77            use crate::layouts::{GLWEInfos, LWEInfos};
78
79            assert_eq!(self.n(), sk.n());
80            assert_eq!(self.rank_out(), self.rank_in());
81            assert_eq!(sk.rank(), self.rank_out());
82            assert!(
83                scratch.available() >= GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, self),
84                "scratch.available(): {} < AutomorphismKey::encrypt_sk_scratch_space: {:?}",
85                scratch.available(),
86                GGLWEAutomorphismKey::encrypt_sk_scratch_space(module, self)
87            )
88        }
89
90        let (mut sk_out, scratch_1) = scratch.take_glwe_secret(sk.n(), sk.rank());
91
92        {
93            (0..self.rank_out().into()).for_each(|i| {
94                module.vec_znx_automorphism(
95                    module.galois_element_inv(p),
96                    &mut sk_out.data.as_vec_znx_mut(),
97                    i,
98                    &sk.data.as_vec_znx(),
99                    i,
100                );
101            });
102        }
103
104        self.key
105            .encrypt_sk(module, sk, &sk_out, source_xa, source_xe, scratch_1);
106
107        self.p = p;
108    }
109}