1use poulpy_hal::{
2 api::{
3 ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4 VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft,
5 VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6 },
7 layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, ZnxView, ZnxViewMut, ZnxZero},
8};
9
10use crate::{
11 TakeGLWECt,
12 layouts::{
13 GGLWELayoutInfos, GLWECiphertext, GLWECiphertextLayout, GLWEInfos, LWECiphertext, LWEInfos,
14 prepared::LWEToGLWESwitchingKeyPrepared,
15 },
16};
17
18impl GLWECiphertext<Vec<u8>> {
19 pub fn from_lwe_scratch_space<B: Backend, OUT, IN, KEY>(
20 module: &Module<B>,
21 glwe_infos: &OUT,
22 lwe_infos: &IN,
23 key_infos: &KEY,
24 ) -> usize
25 where
26 OUT: GLWEInfos,
27 IN: LWEInfos,
28 KEY: GGLWELayoutInfos,
29 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
30 {
31 let ct: usize = GLWECiphertext::alloc_bytes_with(
32 module.n().into(),
33 key_infos.base2k(),
34 lwe_infos.k().max(glwe_infos.k()),
35 1u32.into(),
36 );
37 let ks: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, glwe_infos, key_infos);
38 if lwe_infos.base2k() == key_infos.base2k() {
39 ct + ks
40 } else {
41 let a_conv = VecZnx::alloc_bytes(module.n(), 1, lwe_infos.size()) + module.vec_znx_normalize_tmp_bytes();
42 ct + a_conv + ks
43 }
44 }
45}
46
47impl<D: DataMut> GLWECiphertext<D> {
48 pub fn from_lwe<DLwe, DKsk, B: Backend>(
49 &mut self,
50 module: &Module<B>,
51 lwe: &LWECiphertext<DLwe>,
52 ksk: &LWEToGLWESwitchingKeyPrepared<DKsk, B>,
53 scratch: &mut Scratch<B>,
54 ) where
55 DLwe: DataRef,
56 DKsk: DataRef,
57 Module<B>: VecZnxDftAllocBytes
58 + VmpApplyDftToDftTmpBytes
59 + VecZnxBigNormalizeTmpBytes
60 + VmpApplyDftToDft<B>
61 + VmpApplyDftToDftAdd<B>
62 + VecZnxDftApply<B>
63 + VecZnxIdftApplyConsume<B>
64 + VecZnxBigAddSmallInplace<B>
65 + VecZnxBigNormalize<B>
66 + VecZnxNormalize<B>
67 + VecZnxNormalizeTmpBytes,
68 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeGLWECt + TakeVecZnx,
69 {
70 #[cfg(debug_assertions)]
71 {
72 assert_eq!(self.n(), module.n() as u32);
73 assert_eq!(ksk.n(), module.n() as u32);
74 assert!(lwe.n() <= module.n() as u32);
75 }
76
77 let (mut glwe, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout {
78 n: ksk.n(),
79 base2k: ksk.base2k(),
80 k: lwe.k(),
81 rank: 1u32.into(),
82 });
83 glwe.data.zero();
84
85 let n_lwe: usize = lwe.n().into();
86
87 if lwe.base2k() == ksk.base2k() {
88 for i in 0..lwe.size() {
89 let data_lwe: &[i64] = lwe.data.at(0, i);
90 glwe.data.at_mut(0, i)[0] = data_lwe[0];
91 glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
92 }
93 } else {
94 let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), 1, lwe.size());
95 a_conv.zero();
96 for j in 0..lwe.size() {
97 let data_lwe: &[i64] = lwe.data.at(0, j);
98 a_conv.at_mut(0, j)[0] = data_lwe[0]
99 }
100
101 module.vec_znx_normalize(
102 ksk.base2k().into(),
103 &mut glwe.data,
104 0,
105 lwe.base2k().into(),
106 &a_conv,
107 0,
108 scratch_2,
109 );
110
111 a_conv.zero();
112 for j in 0..lwe.size() {
113 let data_lwe: &[i64] = lwe.data.at(0, j);
114 a_conv.at_mut(0, j)[..n_lwe].copy_from_slice(&data_lwe[1..]);
115 }
116
117 module.vec_znx_normalize(
118 ksk.base2k().into(),
119 &mut glwe.data,
120 1,
121 lwe.base2k().into(),
122 &a_conv,
123 0,
124 scratch_2,
125 );
126 }
127
128 self.keyswitch(module, &glwe, &ksk.0, scratch_1);
129 }
130}