// poulpy_core/encryption/gglwe_ksk.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft,
4        VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
5        VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub,
6        VecZnxSubABInplace, VecZnxSwitchRing,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch},
9    source::Source,
10};
11
12use crate::{
13    TakeGLWESecretPrepared,
14    layouts::{GGLWECiphertext, GGLWESwitchingKey, GLWESecret, prepared::GLWESecretPrepared},
15};
16
17impl GGLWESwitchingKey<Vec<u8>> {
18    pub fn encrypt_sk_scratch_space<B: Backend>(
19        module: &Module<B>,
20        basek: usize,
21        k: usize,
22        rank_in: usize,
23        rank_out: usize,
24    ) -> usize
25    where
26        Module<B>: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes,
27    {
28        (GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k) | ScalarZnx::alloc_bytes(module.n(), 1))
29            + ScalarZnx::alloc_bytes(module.n(), rank_in)
30            + GLWESecretPrepared::bytes_of(module, rank_out)
31    }
32
33    pub fn encrypt_pk_scratch_space<B: Backend>(
34        module: &Module<B>,
35        _basek: usize,
36        _k: usize,
37        _rank_in: usize,
38        _rank_out: usize,
39    ) -> usize {
40        GGLWECiphertext::encrypt_pk_scratch_space(module, _basek, _k, _rank_out)
41    }
42}
43
44impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
45    #[allow(clippy::too_many_arguments)]
46    pub fn encrypt_sk<DataSkIn: DataRef, DataSkOut: DataRef, B: Backend>(
47        &mut self,
48        module: &Module<B>,
49        sk_in: &GLWESecret<DataSkIn>,
50        sk_out: &GLWESecret<DataSkOut>,
51        source_xa: &mut Source,
52        source_xe: &mut Source,
53        scratch: &mut Scratch<B>,
54    ) where
55        Module<B>: VecZnxAddScalarInplace
56            + VecZnxDftAllocBytes
57            + VecZnxBigNormalize<B>
58            + VecZnxDftApply<B>
59            + SvpApplyDftToDftInplace<B>
60            + VecZnxIdftApplyConsume<B>
61            + VecZnxNormalizeTmpBytes
62            + VecZnxFillUniform
63            + VecZnxSubABInplace
64            + VecZnxAddInplace
65            + VecZnxNormalizeInplace<B>
66            + VecZnxAddNormal
67            + VecZnxNormalize<B>
68            + VecZnxSub
69            + SvpPrepare<B>
70            + VecZnxSwitchRing
71            + SvpPPolAllocBytes,
72        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B>,
73    {
74        #[cfg(debug_assertions)]
75        {
76            use crate::layouts::Infos;
77
78            assert!(sk_in.n() <= module.n());
79            assert!(sk_out.n() <= module.n());
80            assert!(
81                scratch.available()
82                    >= GGLWESwitchingKey::encrypt_sk_scratch_space(
83                        module,
84                        self.basek(),
85                        self.k(),
86                        self.rank_in(),
87                        self.rank_out()
88                    ),
89                "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}",
90                scratch.available(),
91                GGLWESwitchingKey::encrypt_sk_scratch_space(
92                    module,
93                    self.basek(),
94                    self.k(),
95                    self.rank_in(),
96                    self.rank_out()
97                )
98            )
99        }
100
101        let n: usize = sk_in.n().max(sk_out.n());
102
103        let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank());
104        (0..sk_in.rank()).for_each(|i| {
105            module.vec_znx_switch_ring(
106                &mut sk_in_tmp.as_vec_znx_mut(),
107                i,
108                &sk_in.data.as_vec_znx(),
109                i,
110            );
111        });
112
113        let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(n, sk_out.rank());
114        {
115            let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1);
116            (0..sk_out.rank()).for_each(|i| {
117                module.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i);
118                module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0);
119            });
120        }
121
122        self.key.encrypt_sk(
123            module,
124            &sk_in_tmp,
125            &sk_out_tmp,
126            source_xa,
127            source_xe,
128            scratch_2,
129        );
130        self.sk_in_n = sk_in.n();
131        self.sk_out_n = sk_out.n();
132    }
133}