1use poulpy_hal::{
2 api::ScratchTakeBasic,
3 layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxView, ZnxViewMut, ZnxZero},
4};
5
6use crate::{
7 GLWEKeyswitch, ScratchTakeCore,
8 layouts::{GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToMut, LWE, LWEInfos, LWEToRef},
9};
10
11impl<BE: Backend> GLWEFromLWE<BE> for Module<BE> where Self: GLWEKeyswitch<BE> {}
12
13pub trait GLWEFromLWE<BE: Backend>
14where
15 Self: GLWEKeyswitch<BE>,
16{
17 fn glwe_from_lwe_tmp_bytes<R, A, K>(&self, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize
18 where
19 R: GLWEInfos,
20 A: LWEInfos,
21 K: GGLWEInfos,
22 {
23 let ct: usize = GLWE::bytes_of(
24 self.n().into(),
25 key_infos.base2k(),
26 lwe_infos.k().max(glwe_infos.k()),
27 1u32.into(),
28 );
29
30 let ks: usize = self.glwe_keyswitch_tmp_bytes(glwe_infos, glwe_infos, key_infos);
31 if lwe_infos.base2k() == key_infos.base2k() {
32 ct + ks
33 } else {
34 let a_conv = VecZnx::bytes_of(self.n(), 1, lwe_infos.size()) + self.vec_znx_normalize_tmp_bytes();
35 ct + a_conv + ks
36 }
37 }
38
39 fn glwe_from_lwe<R, A, K>(&self, res: &mut R, lwe: &A, ksk: &K, scratch: &mut Scratch<BE>)
40 where
41 R: GLWEToMut,
42 A: LWEToRef,
43 K: GGLWEPreparedToRef<BE> + GGLWEInfos,
44 Scratch<BE>: ScratchTakeCore<BE>,
45 {
46 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
47 let lwe: &LWE<&[u8]> = &lwe.to_ref();
48
49 assert_eq!(res.n(), self.n() as u32);
50 assert_eq!(ksk.n(), self.n() as u32);
51 assert!(lwe.n() <= self.n() as u32);
52
53 let (mut glwe, scratch_1) = scratch.take_glwe(&GLWELayout {
54 n: ksk.n(),
55 base2k: ksk.base2k(),
56 k: lwe.k(),
57 rank: 1u32.into(),
58 });
59 glwe.data.zero();
60
61 let n_lwe: usize = lwe.n().into();
62
63 if lwe.base2k() == ksk.base2k() {
64 for i in 0..lwe.size() {
65 let data_lwe: &[i64] = lwe.data.at(0, i);
66 glwe.data.at_mut(0, i)[0] = data_lwe[0];
67 glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
68 }
69 } else {
70 let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, lwe.size());
71 a_conv.zero();
72 for j in 0..lwe.size() {
73 let data_lwe: &[i64] = lwe.data.at(0, j);
74 a_conv.at_mut(0, j)[0] = data_lwe[0]
75 }
76
77 self.vec_znx_normalize(
78 ksk.base2k().into(),
79 &mut glwe.data,
80 0,
81 lwe.base2k().into(),
82 &a_conv,
83 0,
84 scratch_2,
85 );
86
87 a_conv.zero();
88 for j in 0..lwe.size() {
89 let data_lwe: &[i64] = lwe.data.at(0, j);
90 a_conv.at_mut(0, j)[..n_lwe].copy_from_slice(&data_lwe[1..]);
91 }
92
93 self.vec_znx_normalize(
94 ksk.base2k().into(),
95 &mut glwe.data,
96 1,
97 lwe.base2k().into(),
98 &a_conv,
99 0,
100 scratch_2,
101 );
102 }
103
104 self.glwe_keyswitch(res, &glwe, ksk, scratch_1);
105 }
106}
107
108impl GLWE<Vec<u8>> {
109 pub fn from_lwe_tmp_bytes<R, A, K, M, BE: Backend>(module: &M, glwe_infos: &R, lwe_infos: &A, key_infos: &K) -> usize
110 where
111 R: GLWEInfos,
112 A: LWEInfos,
113 K: GGLWEInfos,
114 M: GLWEFromLWE<BE>,
115 {
116 module.glwe_from_lwe_tmp_bytes(glwe_infos, lwe_infos, key_infos)
117 }
118}
119
120impl<D: DataMut> GLWE<D> {
121 pub fn from_lwe<A, K, M, BE: Backend>(&mut self, module: &M, lwe: &A, ksk: &K, scratch: &mut Scratch<BE>)
122 where
123 M: GLWEFromLWE<BE>,
124 A: LWEToRef,
125 K: GGLWEPreparedToRef<BE> + GGLWEInfos,
126 Scratch<BE>: ScratchTakeCore<BE>,
127 {
128 module.glwe_from_lwe(self, lwe, ksk, scratch);
129 }
130}