poulpy_core/keyswitching/
lwe_ct.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4 VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, ZnxView,
5 ZnxViewMut, ZnxZero,
6 },
7 layouts::{Backend, DataMut, DataRef, Module, Scratch},
8};
9
10use crate::{
11 TakeGLWECt,
12 layouts::{GLWECiphertext, Infos, LWECiphertext, prepared::LWESwitchingKeyPrepared},
13};
14
15impl LWECiphertext<Vec<u8>> {
16 pub fn keyswitch_scratch_space<B: Backend>(
17 module: &Module<B>,
18 n: usize,
19 basek: usize,
20 k_lwe_out: usize,
21 k_lwe_in: usize,
22 k_ksk: usize,
23 ) -> usize
24 where
25 Module<B>: VecZnxDftAllocBytes
26 + VmpApplyTmpBytes
27 + VecZnxBigNormalizeTmpBytes
28 + VmpApplyTmpBytes
29 + VmpApply<B>
30 + VmpApplyAdd<B>
31 + VecZnxDftFromVecZnx<B>
32 + VecZnxDftToVecZnxBigConsume<B>
33 + VecZnxBigAddSmallInplace<B>
34 + VecZnxBigNormalize<B>,
35 {
36 GLWECiphertext::bytes_of(n, basek, k_lwe_out.max(k_lwe_in), 1)
37 + GLWECiphertext::keyswitch_inplace_scratch_space(module, n, basek, k_lwe_out, k_ksk, 1, 1)
38 }
39}
40
41impl<DLwe: DataMut> LWECiphertext<DLwe> {
42 pub fn keyswitch<A, DKs, B: Backend>(
43 &mut self,
44 module: &Module<B>,
45 a: &LWECiphertext<A>,
46 ksk: &LWESwitchingKeyPrepared<DKs, B>,
47 scratch: &mut Scratch<B>,
48 ) where
49 A: DataRef,
50 DKs: DataRef,
51 Module<B>: VecZnxDftAllocBytes
52 + VmpApplyTmpBytes
53 + VecZnxBigNormalizeTmpBytes
54 + VmpApply<B>
55 + VmpApplyAdd<B>
56 + VecZnxDftFromVecZnx<B>
57 + VecZnxDftToVecZnxBigConsume<B>
58 + VecZnxBigAddSmallInplace<B>
59 + VecZnxBigNormalize<B>,
60 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
61 {
62 #[cfg(debug_assertions)]
63 {
64 assert!(self.n() <= module.n());
65 assert!(a.n() <= module.n());
66 assert_eq!(self.basek(), a.basek());
67 }
68
69 let max_k: usize = self.k().max(a.k());
70 let basek: usize = self.basek();
71
72 let (mut glwe, scratch1) = scratch.take_glwe_ct(ksk.n(), basek, max_k, 1);
73 glwe.data.zero();
74
75 let n_lwe: usize = a.n();
76
77 (0..a.size()).for_each(|i| {
78 let data_lwe: &[i64] = a.data.at(0, i);
79 glwe.data.at_mut(0, i)[0] = data_lwe[0];
80 glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
81 });
82
83 glwe.keyswitch_inplace(module, &ksk.0, scratch1);
84
85 self.sample_extract(&glwe);
86 }
87}