poulpy_core/keyswitching/
lwe_ct.rs1use poulpy_hal::{
2 api::{
3 DFT, IDFTConsume, ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
4 VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
5 },
6 layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero},
7};
8
9use crate::{
10 TakeGLWECt,
11 layouts::{GLWECiphertext, Infos, LWECiphertext, prepared::LWESwitchingKeyPrepared},
12};
13
14impl LWECiphertext<Vec<u8>> {
15 pub fn keyswitch_scratch_space<B: Backend>(
16 module: &Module<B>,
17 basek: usize,
18 k_lwe_out: usize,
19 k_lwe_in: usize,
20 k_ksk: usize,
21 ) -> usize
22 where
23 Module<B>: VecZnxDftAllocBytes
24 + VmpApplyDftToDftTmpBytes
25 + VecZnxBigNormalizeTmpBytes
26 + VmpApplyDftToDftTmpBytes
27 + VmpApplyDftToDft<B>
28 + VmpApplyDftToDftAdd<B>
29 + DFT<B>
30 + IDFTConsume<B>
31 + VecZnxBigAddSmallInplace<B>
32 + VecZnxBigNormalize<B>,
33 {
34 GLWECiphertext::bytes_of(module.n(), basek, k_lwe_out.max(k_lwe_in), 1)
35 + GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_lwe_out, k_ksk, 1, 1)
36 }
37}
38
39impl<DLwe: DataMut> LWECiphertext<DLwe> {
40 pub fn keyswitch<A, DKs, B: Backend>(
41 &mut self,
42 module: &Module<B>,
43 a: &LWECiphertext<A>,
44 ksk: &LWESwitchingKeyPrepared<DKs, B>,
45 scratch: &mut Scratch<B>,
46 ) where
47 A: DataRef,
48 DKs: DataRef,
49 Module<B>: VecZnxDftAllocBytes
50 + VmpApplyDftToDftTmpBytes
51 + VecZnxBigNormalizeTmpBytes
52 + VmpApplyDftToDft<B>
53 + VmpApplyDftToDftAdd<B>
54 + DFT<B>
55 + IDFTConsume<B>
56 + VecZnxBigAddSmallInplace<B>
57 + VecZnxBigNormalize<B>,
58 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
59 {
60 #[cfg(debug_assertions)]
61 {
62 assert!(self.n() <= module.n());
63 assert!(a.n() <= module.n());
64 assert_eq!(self.basek(), a.basek());
65 }
66
67 let max_k: usize = self.k().max(a.k());
68 let basek: usize = self.basek();
69
70 let (mut glwe, scratch1) = scratch.take_glwe_ct(ksk.n(), basek, max_k, 1);
71 glwe.data.zero();
72
73 let n_lwe: usize = a.n();
74
75 (0..a.size()).for_each(|i| {
76 let data_lwe: &[i64] = a.data.at(0, i);
77 glwe.data.at_mut(0, i)[0] = data_lwe[0];
78 glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
79 });
80
81 glwe.keyswitch_inplace(module, &ksk.0, scratch1);
82
83 self.sample_extract(&glwe);
84 }
85}