1use poulpy_hal::{
2 api::{
3 DFT, IDFTConsume, ScratchAvailable, SvpApplyInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx,
4 TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes,
5 VecZnxFillUniform, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
6 VecZnxSwithcDegree,
7 },
8 layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch},
9 source::Source,
10};
11
12use crate::{
13 TakeGLWESecretPrepared,
14 layouts::{GGLWECiphertext, GGLWESwitchingKey, GLWESecret, prepared::GLWESecretPrepared},
15};
16
17impl GGLWESwitchingKey<Vec<u8>> {
18 pub fn encrypt_sk_scratch_space<B: Backend>(
19 module: &Module<B>,
20 basek: usize,
21 k: usize,
22 rank_in: usize,
23 rank_out: usize,
24 ) -> usize
25 where
26 Module<B>: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes,
27 {
28 (GGLWECiphertext::encrypt_sk_scratch_space(module, basek, k) | ScalarZnx::alloc_bytes(module.n(), 1))
29 + ScalarZnx::alloc_bytes(module.n(), rank_in)
30 + GLWESecretPrepared::bytes_of(module, rank_out)
31 }
32
33 pub fn encrypt_pk_scratch_space<B: Backend>(
34 module: &Module<B>,
35 _basek: usize,
36 _k: usize,
37 _rank_in: usize,
38 _rank_out: usize,
39 ) -> usize {
40 GGLWECiphertext::encrypt_pk_scratch_space(module, _basek, _k, _rank_out)
41 }
42}
43
44impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
45 #[allow(clippy::too_many_arguments)]
46 pub fn encrypt_sk<DataSkIn: DataRef, DataSkOut: DataRef, B: Backend>(
47 &mut self,
48 module: &Module<B>,
49 sk_in: &GLWESecret<DataSkIn>,
50 sk_out: &GLWESecret<DataSkOut>,
51 source_xa: &mut Source,
52 source_xe: &mut Source,
53 scratch: &mut Scratch<B>,
54 ) where
55 Module<B>: VecZnxAddScalarInplace
56 + VecZnxDftAllocBytes
57 + VecZnxBigNormalize<B>
58 + DFT<B>
59 + SvpApplyInplace<B>
60 + IDFTConsume<B>
61 + VecZnxNormalizeTmpBytes
62 + VecZnxFillUniform
63 + VecZnxSubABInplace
64 + VecZnxAddInplace
65 + VecZnxNormalizeInplace<B>
66 + VecZnxAddNormal
67 + VecZnxNormalize<B>
68 + VecZnxSub
69 + SvpPrepare<B>
70 + VecZnxSwithcDegree
71 + SvpPPolAllocBytes,
72 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B>,
73 {
74 #[cfg(debug_assertions)]
75 {
76 use crate::layouts::Infos;
77
78 assert!(sk_in.n() <= module.n());
79 assert!(sk_out.n() <= module.n());
80 assert!(
81 scratch.available()
82 >= GGLWESwitchingKey::encrypt_sk_scratch_space(
83 module,
84 self.basek(),
85 self.k(),
86 self.rank_in(),
87 self.rank_out()
88 ),
89 "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}",
90 scratch.available(),
91 GGLWESwitchingKey::encrypt_sk_scratch_space(
92 module,
93 self.basek(),
94 self.k(),
95 self.rank_in(),
96 self.rank_out()
97 )
98 )
99 }
100
101 let n: usize = sk_in.n().max(sk_out.n());
102
103 let (mut sk_in_tmp, scratch1) = scratch.take_scalar_znx(n, sk_in.rank());
104 (0..sk_in.rank()).for_each(|i| {
105 module.vec_znx_switch_degree(
106 &mut sk_in_tmp.as_vec_znx_mut(),
107 i,
108 &sk_in.data.as_vec_znx(),
109 i,
110 );
111 });
112
113 let (mut sk_out_tmp, scratch2) = scratch1.take_glwe_secret_prepared(n, sk_out.rank());
114 {
115 let (mut tmp, _) = scratch2.take_scalar_znx(n, 1);
116 (0..sk_out.rank()).for_each(|i| {
117 module.vec_znx_switch_degree(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i);
118 module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0);
119 });
120 }
121
122 self.key.encrypt_sk(
123 module,
124 &sk_in_tmp,
125 &sk_out_tmp,
126 source_xa,
127 source_xe,
128 scratch2,
129 );
130 self.sk_in_n = sk_in.n();
131 self.sk_out_n = sk_out.n();
132 }
133}