1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft,
4 VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx,
5 VecZnxDftToVecZnxBigConsume, VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes,
6 VecZnxSub, VecZnxSubABInplace, VecZnxSwithcDegree,
7 },
8 layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch},
9 source::Source,
10};
11
12use crate::{
13 TakeGLWESecretPrepared,
14 layouts::{GGLWECiphertext, GGLWESwitchingKey, GLWESecret, prepared::GLWESecretPrepared},
15};
16
17impl GGLWESwitchingKey<Vec<u8>> {
18 pub fn encrypt_sk_scratch_space<B: Backend>(
19 module: &Module<B>,
20 n: usize,
21 basek: usize,
22 k: usize,
23 rank_in: usize,
24 rank_out: usize,
25 ) -> usize
26 where
27 Module<B>: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes,
28 {
29 (GGLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k) | ScalarZnx::alloc_bytes(n, 1))
30 + ScalarZnx::alloc_bytes(n, rank_in)
31 + GLWESecretPrepared::bytes_of(module, n, rank_out)
32 }
33
34 pub fn encrypt_pk_scratch_space<B: Backend>(
35 module: &Module<B>,
36 _n: usize,
37 _basek: usize,
38 _k: usize,
39 _rank_in: usize,
40 _rank_out: usize,
41 ) -> usize {
42 GGLWECiphertext::encrypt_pk_scratch_space(module, _n, _basek, _k, _rank_out)
43 }
44}
45
46impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
47 #[allow(clippy::too_many_arguments)]
48 pub fn encrypt_sk<DataSkIn: DataRef, DataSkOut: DataRef, B: Backend>(
49 &mut self,
50 module: &Module<B>,
51 sk_in: &GLWESecret<DataSkIn>,
52 sk_out: &GLWESecret<DataSkOut>,
53 source_xa: &mut Source,
54 source_xe: &mut Source,
55 sigma: f64,
56 scratch: &mut Scratch<B>,
57 ) where
58 Module<B>: VecZnxAddScalarInplace
59 + VecZnxDftAllocBytes
60 + VecZnxBigNormalize<B>
61 + VecZnxDftFromVecZnx<B>
62 + SvpApplyInplace<B>
63 + VecZnxDftToVecZnxBigConsume<B>
64 + VecZnxNormalizeTmpBytes
65 + VecZnxFillUniform
66 + VecZnxSubABInplace
67 + VecZnxAddInplace
68 + VecZnxNormalizeInplace<B>
69 + VecZnxAddNormal
70 + VecZnxNormalize<B>
71 + VecZnxSub
72 + SvpPrepare<B>
73 + VecZnxSwithcDegree
74 + SvpPPolAllocBytes,
75 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B>,
76 {
77 #[cfg(debug_assertions)]
78 {
79 use crate::layouts::Infos;
80
81 assert!(sk_in.n() <= module.n());
82 assert!(sk_out.n() <= module.n());
83 assert!(
84 scratch.available()
85 >= GGLWESwitchingKey::encrypt_sk_scratch_space(
86 module,
87 sk_out.n(),
88 self.basek(),
89 self.k(),
90 self.rank_in(),
91 self.rank_out()
92 ),
93 "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}",
94 scratch.available(),
95 GGLWESwitchingKey::encrypt_sk_scratch_space(
96 module,
97 sk_out.n(),
98 self.basek(),
99 self.k(),
100 self.rank_in(),
101 self.rank_out()
102 )
103 )
104 }
105
106 let n: usize = sk_in.n().max(sk_out.n());
107
108 let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(n, sk_in.rank());
109 (0..sk_in.rank()).for_each(|i| {
110 module.vec_znx_switch_degree(
111 &mut sk_in_tmp.as_vec_znx_mut(),
112 i,
113 &sk_in.data.as_vec_znx(),
114 i,
115 );
116 });
117
118 let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_prepared(n, sk_out.rank());
119 {
120 let (mut tmp, _) = scratch2.take_scalar_znx(n, 1);
121 (0..sk_out.rank()).for_each(|i| {
122 module.vec_znx_switch_degree(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i);
123 module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0);
124 });
125 }
126
127 self.key.encrypt_sk(
128 module,
129 &sk_in_tmp,
130 &sk_out_tmp,
131 source_xa,
132 source_xe,
133 sigma,
134 scratch2,
135 );
136 self.sk_in_n = sk_in.n();
137 self.sk_out_n = sk_out.n();
138 }
139}