1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApplyDftToDftInplace, SvpPPolAllocBytes, SvpPrepare, TakeScalarZnx, TakeVecZnx, TakeVecZnxDft,
4 VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply,
5 VecZnxFillUniform, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub,
6 VecZnxSubInplace, VecZnxSwitchRing,
7 },
8 layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch},
9 source::Source,
10};
11
12use crate::{
13 TakeGLWESecretPrepared,
14 layouts::{
15 Degree, GGLWECiphertext, GGLWELayoutInfos, GGLWESwitchingKey, GLWEInfos, GLWESecret, LWEInfos,
16 prepared::GLWESecretPrepared,
17 },
18};
19
20impl GGLWESwitchingKey<Vec<u8>> {
21 pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
22 where
23 A: GGLWELayoutInfos,
24 Module<B>: SvpPPolAllocBytes + VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes,
25 {
26 (GGLWECiphertext::encrypt_sk_scratch_space(module, infos) | ScalarZnx::alloc_bytes(module.n(), 1))
27 + ScalarZnx::alloc_bytes(module.n(), infos.rank_in().into())
28 + GLWESecretPrepared::alloc_bytes(module, &infos.glwe_layout())
29 }
30
31 pub fn encrypt_pk_scratch_space<B: Backend, A>(module: &Module<B>, _infos: &A) -> usize
32 where
33 A: GGLWELayoutInfos,
34 {
35 GGLWECiphertext::encrypt_pk_scratch_space(module, _infos)
36 }
37}
38
39impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
40 #[allow(clippy::too_many_arguments)]
41 pub fn encrypt_sk<DataSkIn: DataRef, DataSkOut: DataRef, B: Backend>(
42 &mut self,
43 module: &Module<B>,
44 sk_in: &GLWESecret<DataSkIn>,
45 sk_out: &GLWESecret<DataSkOut>,
46 source_xa: &mut Source,
47 source_xe: &mut Source,
48 scratch: &mut Scratch<B>,
49 ) where
50 Module<B>: VecZnxAddScalarInplace
51 + VecZnxDftAllocBytes
52 + VecZnxBigNormalize<B>
53 + VecZnxDftApply<B>
54 + SvpApplyDftToDftInplace<B>
55 + VecZnxIdftApplyConsume<B>
56 + VecZnxNormalizeTmpBytes
57 + VecZnxFillUniform
58 + VecZnxSubInplace
59 + VecZnxAddInplace
60 + VecZnxNormalizeInplace<B>
61 + VecZnxAddNormal
62 + VecZnxNormalize<B>
63 + VecZnxSub
64 + SvpPrepare<B>
65 + VecZnxSwitchRing
66 + SvpPPolAllocBytes,
67 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx + TakeScalarZnx + TakeGLWESecretPrepared<B>,
68 {
69 #[cfg(debug_assertions)]
70 {
71 assert!(sk_in.n().0 <= module.n() as u32);
72 assert!(sk_out.n().0 <= module.n() as u32);
73 assert!(
74 scratch.available() >= GGLWESwitchingKey::encrypt_sk_scratch_space(module, self),
75 "scratch.available()={} < GLWESwitchingKey::encrypt_sk_scratch_space={}",
76 scratch.available(),
77 GGLWESwitchingKey::encrypt_sk_scratch_space(module, self)
78 )
79 }
80
81 let n: usize = sk_in.n().max(sk_out.n()).into();
82
83 let (mut sk_in_tmp, scratch_1) = scratch.take_scalar_znx(n, sk_in.rank().into());
84 (0..sk_in.rank().into()).for_each(|i| {
85 module.vec_znx_switch_ring(
86 &mut sk_in_tmp.as_vec_znx_mut(),
87 i,
88 &sk_in.data.as_vec_znx(),
89 i,
90 );
91 });
92
93 let (mut sk_out_tmp, scratch_2) = scratch_1.take_glwe_secret_prepared(Degree(n as u32), sk_out.rank());
94 {
95 let (mut tmp, _) = scratch_2.take_scalar_znx(n, 1);
96 (0..sk_out.rank().into()).for_each(|i| {
97 module.vec_znx_switch_ring(&mut tmp.as_vec_znx_mut(), 0, &sk_out.data.as_vec_znx(), i);
98 module.svp_prepare(&mut sk_out_tmp.data, i, &tmp, 0);
99 });
100 }
101
102 self.key.encrypt_sk(
103 module,
104 &sk_in_tmp,
105 &sk_out_tmp,
106 source_xa,
107 source_xe,
108 scratch_2,
109 );
110 self.sk_in_n = sk_in.n().into();
111 self.sk_out_n = sk_out.n().into();
112 }
113}