poulpy_core/encryption/compressed/
gglwe_ksk.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft,
4 VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
5 VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub,
6 VecZnxSubABInplace, VecZnxSwitchRing,
7 },
8 layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch},
9 source::Source,
10};
11
12use crate::{
13 TakeGLWESecretPrepared,
14 layouts::{GGLWECiphertext, GLWESecret, compressed::GGLWESwitchingKeyCompressed, prepared::GLWESecretPrepared},
15};
16
17impl GGLWESwitchingKeyCompressed<Vec<u8>> {
18 pub fn encrypt_sk_scratch_space<B: Backend>(
19 module: &Module<B>,
20 basek: usize,
21 k: usize,
22 rank_in: usize,
23 rank_out: usize,
24 ) -> usize
25 where
26 Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + SvpPPolAllocBytes,
27 {
28 (GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k) | ScalarZnx::alloc_bytes(module.n(), 1))
29 + ScalarZnx::alloc_bytes(module.n(), rank_in)
30 + GLWESecretPrepared::bytes_of(module, rank_out)
31 }
32}
33
34impl<DataSelf: DataMut> GGLWESwitchingKeyCompressed<DataSelf> {
35 #[allow(clippy::too_many_arguments)]
36 pub fn encrypt_sk<DataSkIn: DataRef, DataSkOut: DataRef, B: Backend>(
37 &mut self,
38 module: &Module<B>,
39 sk_in: &GLWESecret<DataSkIn>,
40 sk_out: &GLWESecret<DataSkOut>,
41 seed_xa: [u8; 32],
42 source_xe: &mut Source,
43 scratch: &mut Scratch<B>,
44 ) where
45 Module<B>: SvpPrepare<B>
46 + SvpPPolAllocBytes
47 + VecZnxSwitchRing
48 + VecZnxDftAllocBytes
49 + VecZnxBigNormalize<B>
50 + VecZnxDftApply<B>
51 + SvpApplyDftToDftInplace<B>
52 + VecZnxIdftApplyConsume<B>
53 + VecZnxNormalizeTmpBytes
54 + VecZnxFillUniform
55 + VecZnxSubABInplace
56 + VecZnxAddInplace
57 + VecZnxNormalizeInplace<B>
58 + VecZnxAddNormal
59 + VecZnxNormalize<B>
60 + VecZnxSub
61 + VecZnxAddScalarInplace,
62 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B>,
63 {
64 #[cfg(debug_assertions)]
65 {
66 use crate::layouts::{GGLWESwitchingKey, Infos};
67
68 assert!(sk_in.n() <= module.n());
69 assert!(sk_out.n() <= module.n());
70 assert!(
71 scratch.available()
72 >= GGLWESwitchingKey::encrypt_sk_scratch_space(
73 module,
74 self.basek(),
75 self.k(),
76 self.rank_in(),
77 self.rank_out()
78 ),
79 "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}",
80 scratch.available(),
81 GGLWESwitchingKey::encrypt_sk_scratch_space(
82 module,
83 self.basek(),
84 self.k(),
85 self.rank_in(),
86 self.rank_out()
87 )
88 )
89 }
90
91 let n: usize = sk_in.n().max(sk_out.n());
92
93 let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank());
94 (0..sk_in.rank()).for_each(|i| {
95 module.vec_znx_switch_ring(
96 &mut sk_in_tmp.as_vec_znx_mut(),
97 i,
98 &sk_in.data.as_vec_znx(),
99 i,
100 );
101 });
102
103 let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(n, sk_out.rank());
104 {
105 let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1);
106 (0..sk_out.rank()).for_each(|i| {
107 module.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i);
108 module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0);
109 });
110 }
111
112 self.key.encrypt_sk(
113 module,
114 &sk_in_tmp,
115 &sk_out_tmp,
116 seed_xa,
117 source_xe,
118 scratch_2,
119 );
120 self.sk_in_n = sk_in.n();
121 self.sk_out_n = sk_out.n();
122 }
123}