poulpy_core/encryption/compressed/
gglwe_ksk.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft,
4 VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx,
5 VecZnxDftToVecZnxBigConsume, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes,
6 VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree,
7 },
8 layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch},
9 source::Source,
10};
11
12use crate::{
13 TakeGLWESecretPrepared,
14 layouts::{GGLWECiphertext, GLWESecret, compressed::GGLWESwitchingKeyCompressed, prepared::GLWESecretPrepared},
15};
16
17impl GGLWESwitchingKeyCompressed<Vec<u8>> {
18 pub fn encrypt_sk_scratch_space<B: Backend>(
19 module: &Module<B>,
20 n: usize,
21 basek: usize,
22 k: usize,
23 rank_in: usize,
24 rank_out: usize,
25 ) -> usize
26 where
27 Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes + SvpPPolAllocBytes,
28 {
29 (GGLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) | ScalarZnx::alloc_bytes(n, 1))
30 + ScalarZnx::alloc_bytes(n, rank_in)
31 + GLWESecretPrepared::bytes_of(module, n, rank_out)
32 }
33}
34
35impl<DataSelf: DataMut> GGLWESwitchingKeyCompressed<DataSelf> {
36 #[allow(clippy::too_many_arguments)]
37 pub fn encrypt_sk<DataSkIn: DataRef, DataSkOut: DataRef, B: Backend>(
38 &mut self,
39 module: &Module<B>,
40 sk_in: &GLWESecret<DataSkIn>,
41 sk_out: &GLWESecret<DataSkOut>,
42 seed_xa: [u8; 32],
43 source_xe: &mut Source,
44 sigma: f64,
45 scratch: &mut Scratch<B>,
46 ) where
47 Module<B>: SvpPrepare<B>
48 + SvpPPolAllocBytes
49 + VecZnxSwithcDegree
50 + VecZnxDftAllocBytes
51 + VecZnxBigNormalize<B>
52 + VecZnxDftFromVecZnx<B>
53 + SvpApplyInplace<B>
54 + VecZnxDftToVecZnxBigConsume<B>
55 + VecZnxNormalizeTmpBytes
56 + VecZnxFillUniform
57 + VecZnxSubABInplace
58 + VecZnxAddInplace
59 + VecZnxNormalizeInplace<B>
60 + VecZnxAddNormal
61 + VecZnxNormalize<B>
62 + VecZnxSub
63 + VecZnxAddScalarInplace,
64 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B>,
65 {
66 #[cfg(debug_assertions)]
67 {
68 use crate::layouts::{GGLWESwitchingKey, Infos};
69
70 assert!(sk_in.n() <= module.n());
71 assert!(sk_out.n() <= module.n());
72 assert!(
73 scratch.available()
74 >= GGLWESwitchingKey::encrypt_sk_scratch_space(
75 module,
76 sk_out.n(),
77 self.basek(),
78 self.k(),
79 self.rank_in(),
80 self.rank_out()
81 ),
82 "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}",
83 scratch.available(),
84 GGLWESwitchingKey::encrypt_sk_scratch_space(
85 module,
86 sk_out.n(),
87 self.basek(),
88 self.k(),
89 self.rank_in(),
90 self.rank_out()
91 )
92 )
93 }
94
95 let n: usize = sk_in.n().max(sk_out.n());
96
97 let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(n, sk_in.rank());
98 (0..sk_in.rank()).for_each(|i| {
99 module.vec_znx_switch_degree(
100 &mut sk_in_tmp.as_vec_znx_mut(),
101 i,
102 &sk_in.data.as_vec_znx(),
103 i,
104 );
105 });
106
107 let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_prepared(n, sk_out.rank());
108 {
109 let (mut tmp, _) = scratch2.take_scalar_znx(n, 1);
110 (0..sk_out.rank()).for_each(|i| {
111 module.vec_znx_switch_degree(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i);
112 module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0);
113 });
114 }
115
116 self.key.encrypt_sk(
117 module,
118 &sk_in_tmp,
119 &sk_out_tmp,
120 seed_xa,
121 source_xe,
122 sigma,
123 scratch2,
124 );
125 self.sk_in_n = sk_in.n();
126 self.sk_out_n = sk_out.n();
127 }
128}