poulpy_core/keyswitching/
lwe_ct.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4 VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
5 VmpApplyDftToDftTmpBytes,
6 },
7 layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero},
8};
9
10use crate::{
11 TakeGLWECt,
12 layouts::{GLWECiphertext, Infos, LWECiphertext, prepared::LWESwitchingKeyPrepared},
13};
14
15impl LWECiphertext<Vec<u8>> {
16 pub fn keyswitch_scratch_space<B: Backend>(
17 module: &Module<B>,
18 basek: usize,
19 k_lwe_out: usize,
20 k_lwe_in: usize,
21 k_ksk: usize,
22 ) -> usize
23 where
24 Module<B>: VecZnxDftAllocBytes
25 + VmpApplyDftToDftTmpBytes
26 + VecZnxBigNormalizeTmpBytes
27 + VmpApplyDftToDftTmpBytes
28 + VmpApplyDftToDft<B>
29 + VmpApplyDftToDftAdd<B>
30 + VecZnxDftApply<B>
31 + VecZnxIdftApplyConsume<B>
32 + VecZnxBigAddSmallInplace<B>
33 + VecZnxBigNormalize<B>,
34 {
35 GLWECiphertext::bytes_of(module.n(), basek, k_lwe_out.max(k_lwe_in), 1)
36 + GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_lwe_out, k_ksk, 1, 1)
37 }
38}
39
40impl<DLwe: DataMut> LWECiphertext<DLwe> {
41 pub fn keyswitch<A, DKs, B: Backend>(
42 &mut self,
43 module: &Module<B>,
44 a: &LWECiphertext<A>,
45 ksk: &LWESwitchingKeyPrepared<DKs, B>,
46 scratch: &mut Scratch<B>,
47 ) where
48 A: DataRef,
49 DKs: DataRef,
50 Module<B>: VecZnxDftAllocBytes
51 + VmpApplyDftToDftTmpBytes
52 + VecZnxBigNormalizeTmpBytes
53 + VmpApplyDftToDft<B>
54 + VmpApplyDftToDftAdd<B>
55 + VecZnxDftApply<B>
56 + VecZnxIdftApplyConsume<B>
57 + VecZnxBigAddSmallInplace<B>
58 + VecZnxBigNormalize<B>,
59 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
60 {
61 #[cfg(debug_assertions)]
62 {
63 assert!(self.n() <= module.n());
64 assert!(a.n() <= module.n());
65 assert_eq!(self.basek(), a.basek());
66 }
67
68 let max_k: usize = self.k().max(a.k());
69 let basek: usize = self.basek();
70
71 let (mut glwe, scratch_1) = scratch.take_glwe_ct(ksk.n(), basek, max_k, 1);
72 glwe.data.zero();
73
74 let n_lwe: usize = a.n();
75
76 (0..a.size()).for_each(|i| {
77 let data_lwe: &[i64] = a.data.at(0, i);
78 glwe.data.at_mut(0, i)[0] = data_lwe[0];
79 glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
80 });
81
82 glwe.keyswitch_inplace(module, &ksk.0, scratch_1);
83
84 self.sample_extract(&glwe);
85 }
86}