poulpy_core/encryption/
gglwe_ksk.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft,
4        VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx,
5        VecZnxDftToVecZnxBigConsume, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes,
6        VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch},
9    source::Source,
10};
11
12use crate::{
13    TakeGLWESecretPrepared,
14    layouts::{GGLWECiphertext, GGLWESwitchingKey, GLWESecret, prepared::GLWESecretPrepared},
15};
16
17impl GGLWESwitchingKey<Vec<u8>> {
18    pub fn encrypt_sk_scratch_space<B: Backend>(
19        module: &Module<B>,
20        n: usize,
21        basek: usize,
22        k: usize,
23        rank_in: usize,
24        rank_out: usize,
25    ) -> usize
26    where
27        Module<B>: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes,
28    {
29        (GGLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) | ScalarZnx::alloc_bytes(n, 1))
30            + ScalarZnx::alloc_bytes(n, rank_in)
31            + GLWESecretPrepared::bytes_of(module, n, rank_out)
32    }
33
34    pub fn encrypt_pk_scratch_space<B: Backend>(
35        module: &Module<B>,
36        _n: usize,
37        _basek: usize,
38        _k: usize,
39        _rank_in: usize,
40        _rank_out: usize,
41    ) -> usize {
42        GGLWECiphertext::encrypt_pk_scratch_space(module, _n, _basek, _k, _rank_out)
43    }
44}
45
46impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
47    #[allow(clippy::too_many_arguments)]
48    pub fn encrypt_sk<DataSkIn: DataRef, DataSkOut: DataRef, B: Backend>(
49        &mut self,
50        module: &Module<B>,
51        sk_in: &GLWESecret<DataSkIn>,
52        sk_out: &GLWESecret<DataSkOut>,
53        source_xa: &mut Source,
54        source_xe: &mut Source,
55        sigma: f64,
56        scratch: &mut Scratch<B>,
57    ) where
58        Module<B>: VecZnxAddScalarInplace
59            + VecZnxDftAllocBytes
60            + VecZnxBigNormalize<B>
61            + VecZnxDftFromVecZnx<B>
62            + SvpApplyInplace<B>
63            + VecZnxDftToVecZnxBigConsume<B>
64            + VecZnxNormalizeTmpBytes
65            + VecZnxFillUniform
66            + VecZnxSubABInplace
67            + VecZnxAddInplace
68            + VecZnxNormalizeInplace<B>
69            + VecZnxAddNormal
70            + VecZnxNormalize<B>
71            + VecZnxSub
72            + SvpPrepare<B>
73            + VecZnxSwithcDegree
74            + SvpPPolAllocBytes,
75        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B>,
76    {
77        #[cfg(debug_assertions)]
78        {
79            use crate::layouts::Infos;
80
81            assert!(sk_in.n() <= module.n());
82            assert!(sk_out.n() <= module.n());
83            assert!(
84                scratch.available()
85                    >= GGLWESwitchingKey::encrypt_sk_scratch_space(
86                        module,
87                        sk_out.n(),
88                        self.basek(),
89                        self.k(),
90                        self.rank_in(),
91                        self.rank_out()
92                    ),
93                "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}",
94                scratch.available(),
95                GGLWESwitchingKey::encrypt_sk_scratch_space(
96                    module,
97                    sk_out.n(),
98                    self.basek(),
99                    self.k(),
100                    self.rank_in(),
101                    self.rank_out()
102                )
103            )
104        }
105
106        let n: usize = sk_in.n().max(sk_out.n());
107
108        let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(n, sk_in.rank());
109        (0..sk_in.rank()).for_each(|i| {
110            module.vec_znx_switch_degree(
111                &mut sk_in_tmp.as_vec_znx_mut(),
112                i,
113                &sk_in.data.as_vec_znx(),
114                i,
115            );
116        });
117
118        let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_prepared(n, sk_out.rank());
119        {
120            let (mut tmp, _) = scratch2.take_scalar_znx(n, 1);
121            (0..sk_out.rank()).for_each(|i| {
122                module.vec_znx_switch_degree(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i);
123                module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0);
124            });
125        }
126
127        self.key.encrypt_sk(
128            module,
129            &sk_in_tmp,
130            &sk_out_tmp,
131            source_xa,
132            source_xe,
133            sigma,
134            scratch2,
135        );
136        self.sk_in_n = sk_in.n();
137        self.sk_out_n = sk_out.n();
138    }
139}