poulpy_core/keyswitching/
lwe_ct.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4 VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes,
5 VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6 },
7 layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero},
8};
9
10use crate::{
11 TakeGLWECt,
12 layouts::{
13 GGLWELayoutInfos, GLWECiphertext, GLWECiphertextLayout, LWECiphertext, LWEInfos, Rank, TorusPrecision,
14 prepared::LWESwitchingKeyPrepared,
15 },
16};
17
18impl LWECiphertext<Vec<u8>> {
19 pub fn keyswitch_scratch_space<B: Backend, OUT, IN, KEY>(
20 module: &Module<B>,
21 out_infos: &OUT,
22 in_infos: &IN,
23 key_infos: &KEY,
24 ) -> usize
25 where
26 OUT: LWEInfos,
27 IN: LWEInfos,
28 KEY: GGLWELayoutInfos,
29 Module<B>: VecZnxDftAllocBytes
30 + VmpApplyDftToDftTmpBytes
31 + VecZnxBigNormalizeTmpBytes
32 + VmpApplyDftToDftTmpBytes
33 + VmpApplyDftToDft<B>
34 + VmpApplyDftToDftAdd<B>
35 + VecZnxDftApply<B>
36 + VecZnxIdftApplyConsume<B>
37 + VecZnxBigAddSmallInplace<B>
38 + VecZnxBigNormalize<B>
39 + VecZnxNormalizeTmpBytes,
40 {
41 let max_k: TorusPrecision = in_infos.k().max(out_infos.k());
42
43 let glwe_in_infos: GLWECiphertextLayout = GLWECiphertextLayout {
44 n: module.n().into(),
45 base2k: in_infos.base2k(),
46 k: max_k,
47 rank: Rank(1),
48 };
49
50 let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout {
51 n: module.n().into(),
52 base2k: out_infos.base2k(),
53 k: max_k,
54 rank: Rank(1),
55 };
56
57 let glwe_in: usize = GLWECiphertext::alloc_bytes(&glwe_in_infos);
58 let glwe_out: usize = GLWECiphertext::alloc_bytes(&glwe_out_infos);
59 let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, &glwe_out_infos, &glwe_in_infos, key_infos);
60
61 glwe_in + glwe_out + ks
62 }
63}
64
65impl<DLwe: DataMut> LWECiphertext<DLwe> {
66 pub fn keyswitch<A, DKs, B: Backend>(
67 &mut self,
68 module: &Module<B>,
69 a: &LWECiphertext<A>,
70 ksk: &LWESwitchingKeyPrepared<DKs, B>,
71 scratch: &mut Scratch<B>,
72 ) where
73 A: DataRef,
74 DKs: DataRef,
75 Module<B>: VecZnxDftAllocBytes
76 + VmpApplyDftToDftTmpBytes
77 + VecZnxBigNormalizeTmpBytes
78 + VmpApplyDftToDft<B>
79 + VmpApplyDftToDftAdd<B>
80 + VecZnxDftApply<B>
81 + VecZnxIdftApplyConsume<B>
82 + VecZnxBigAddSmallInplace<B>
83 + VecZnxBigNormalize<B>
84 + VecZnxNormalize<B>
85 + VecZnxNormalizeTmpBytes
86 + VecZnxCopy,
87 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
88 {
89 #[cfg(debug_assertions)]
90 {
91 assert!(self.n() <= module.n() as u32);
92 assert!(a.n() <= module.n() as u32);
93 assert!(scratch.available() >= LWECiphertext::keyswitch_scratch_space(module, self, a, ksk));
94 }
95
96 let max_k: TorusPrecision = self.k().max(a.k());
97
98 let a_size: usize = a.k().div_ceil(ksk.base2k()) as usize;
99
100 let (mut glwe_in, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout {
101 n: ksk.n(),
102 base2k: a.base2k(),
103 k: max_k,
104 rank: Rank(1),
105 });
106 glwe_in.data.zero();
107
108 let (mut glwe_out, scratch_1) = scratch_1.take_glwe_ct(&GLWECiphertextLayout {
109 n: ksk.n(),
110 base2k: self.base2k(),
111 k: max_k,
112 rank: Rank(1),
113 });
114
115 let n_lwe: usize = a.n().into();
116
117 for i in 0..a_size {
118 let data_lwe: &[i64] = a.data.at(0, i);
119 glwe_in.data.at_mut(0, i)[0] = data_lwe[0];
120 glwe_in.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
121 }
122
123 glwe_out.keyswitch(module, &glwe_in, &ksk.0, scratch_1);
124 self.sample_extract(&glwe_out);
125 }
126}