use poulpy_hal::{
    api::{
        ScratchAvailable, ScratchTakeBasic, VecZnxBigAddSmallInplace, VecZnxBigBytesOf, VecZnxBigNormalize,
        VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
    },
    layouts::{Backend, DataMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VecZnxDftToRef, VecZnxToRef},
};

use crate::{
    GGLWEProduct, GLWECopy, ScratchTakeCore,
    layouts::{
        GGLWE, GGLWEInfos, GGLWEToGGSWKeyPrepared, GGLWEToGGSWKeyPreparedToRef, GGLWEToRef, GGSW, GGSWInfos, GGSWToMut, GLWE,
        GLWEInfos, LWEInfos,
    },
};

impl GGSW<Vec<u8>> {
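    /// Returns the number of scratch bytes required by [`GGSW::from_gglwe`].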
    pub fn from_gglwe_tmp_bytes<R, A, M, BE: Backend>(module: &M, res_infos: &R, tsk_infos: &A) -> usize
    where
        M: GGSWFromGGLWE<BE>,
        R: GGSWInfos,
        A: GGLWEInfos,
    {
        module.ggsw_from_gglwe_tmp_bytes(res_infos, tsk_infos)
    }
}

impl<D: DataMut> GGSW<D> {
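    /// Fills `self` with the GGSW expansion of `gglwe`, using the prepared
    /// conversion key `tsk`. Thin wrapper around [`GGSWFromGGLWE::ggsw_from_gglwe`].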
    pub fn from_gglwe<G, M, T, BE: Backend>(&mut self, module: &M, gglwe: &G, tsk: &T, scratch: &mut Scratch<BE>)
    where
        M: GGSWFromGGLWE<BE>,
        G: GGLWEToRef,
        T: GGLWEToGGSWKeyPreparedToRef<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        module.ggsw_from_gglwe(self, gglwe, tsk, scratch);
    }
}

impl<BE: Backend> GGSWFromGGLWE<BE> for Module<BE>
where
    Self: GGSWExpandRows<BE> + GLWECopy,
{
    fn ggsw_from_gglwe_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize
    where
        R: GGSWInfos,
        A: GGLWEInfos,
    {
        self.ggsw_expand_rows_tmp_bytes(res_infos, tsk_infos)
    }

    fn ggsw_from_gglwe<R, A, T>(&self, res: &mut R, a: &A, tsk: &T, scratch: &mut Scratch<BE>)
    where
        R: GGSWToMut,
        A: GGLWEToRef,
        T: GGLWEToGGSWKeyPreparedToRef<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
        let a: &GGLWE<&[u8]> = &a.to_ref();
        let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();

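        // All operands must share the ring degree and decomposition parameters.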
        assert_eq!(res.rank(), a.rank_out());
        assert_eq!(res.dnum(), a.dnum());
        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(tsk.n(), self.n() as u32);
        assert_eq!(res.base2k(), a.base2k());

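        // Column 0 of the GGSW is the input GGLWE itself: copy it row by row.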
        for row in 0..res.dnum().as_usize() {
            self.glwe_copy(&mut res.at_mut(row, 0), &a.at(row, 0));
        }

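        // Derive the remaining columns from column 0 using the prepared key.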
        self.ggsw_expand_rows(res, tsk, scratch);
    }
}

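/// GGLWE-to-GGSW conversion, generic over the backend.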
pub trait GGSWFromGGLWE<BE: Backend> {
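    /// Returns the number of scratch bytes required by [`Self::ggsw_from_gglwe`].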
    fn ggsw_from_gglwe_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize
    where
        R: GGSWInfos,
        A: GGLWEInfos;

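    /// Writes into `res` the GGSW expansion of the GGLWE ciphertext `a`,
    /// using the prepared conversion key `tsk`.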
    fn ggsw_from_gglwe<R, A, T>(&self, res: &mut R, a: &A, tsk: &T, scratch: &mut Scratch<BE>)
    where
        R: GGSWToMut,
        A: GGLWEToRef,
        T: GGLWEToGGSWKeyPreparedToRef<BE>,
        Scratch<BE>: ScratchTakeCore<BE>;
}

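/// Expansion of a GGSW whose column 0 is already populated: the remaining
/// columns are derived row by row through a prepared GGLWE-to-GGSW key.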
pub trait GGSWExpandRows<BE: Backend> {
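    /// Returns the number of scratch bytes required by [`Self::ggsw_expand_rows`].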
    fn ggsw_expand_rows_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize
    where
        R: GGSWInfos,
        A: GGLWEInfos;

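    /// Expands column 0 of `res` into its columns `1..=rank`, for every row.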
    fn ggsw_expand_rows<R, T>(&self, res: &mut R, tsk: &T, scratch: &mut Scratch<BE>)
    where
        R: GGSWToMut,
        T: GGLWEToGGSWKeyPreparedToRef<BE>,
        Scratch<BE>: ScratchTakeCore<BE>;
}

impl<BE: Backend> GGSWExpandRows<BE> for Module<BE>
where
    Self: GGLWEProduct<BE>
        + VecZnxBigNormalize<BE>
        + VecZnxBigNormalizeTmpBytes
        + VecZnxBigBytesOf
        + VecZnxDftBytesOf
        + VecZnxDftApply<BE>
        + VecZnxNormalize<BE>
        + VecZnxBigAddSmallInplace<BE>
        + VecZnxIdftApplyConsume<BE>
        + VecZnxCopy,
{
    fn ggsw_expand_rows_tmp_bytes<R, A>(&self, res_infos: &R, tsk_infos: &A) -> usize
    where
        R: GGSWInfos,
        A: GGLWEInfos,
    {
        let base2k_tsk: usize = tsk_infos.base2k().into();

        let rank: usize = res_infos.rank().into();
        let cols: usize = rank + 1;

        let res_size: usize = res_infos.size();
        let a_size: usize = res_infos.max_k().as_usize().div_ceil(base2k_tsk);

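        // a_0, a_dft, res_dft and the gglwe product are live at the same time;
        // the final normalization pass only needs its own temporary.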
        let a_0: usize = VecZnx::bytes_of(self.n(), 1, a_size);
        let a_dft: usize = self.bytes_of_vec_znx_dft(cols - 1, a_size);
        let res_dft: usize = self.bytes_of_vec_znx_dft(cols, a_size);
        let gglwe_prod: usize = self.gglwe_product_dft_tmp_bytes(res_size, a_size, tsk_infos);
        let normalize: usize = self.vec_znx_big_normalize_tmp_bytes();

        (a_0 + a_dft + res_dft + gglwe_prod).max(normalize)
    }

    fn ggsw_expand_rows<R, T>(&self, res: &mut R, tsk: &T, scratch: &mut Scratch<BE>)
    where
        R: GGSWToMut,
        T: GGLWEToGGSWKeyPreparedToRef<BE>,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
        let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();

        let base2k_res: usize = res.base2k().into();
        let base2k_tsk: usize = tsk.base2k().into();

        assert!(scratch.available() >= self.ggsw_expand_rows_tmp_bytes(res, tsk));

        let rank: usize = res.rank().into();
        let cols: usize = rank + 1;

        let res_conv_size: usize = res.max_k().as_usize().div_ceil(base2k_tsk);

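        // DFT buffer for columns 1.. of each row, plus a small temporary for base conversion.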
        let (mut a_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols - 1, res_conv_size);
        let (mut a_0, scratch_2) = scratch_1.take_vec_znx(self.n(), 1, res_conv_size);

        for row in 0..res.dnum().as_usize() {
            let glwe_mi_1: &GLWE<&[u8]> = &res.at(row, 0);

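            // When the output and key bases match, the limbs can be fed to the DFT
            // directly; otherwise they are first renormalized to the key's base.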
            if base2k_res == base2k_tsk {
                for col_i in 0..cols - 1 {
                    self.vec_znx_dft_apply(1, 0, &mut a_dft, col_i, glwe_mi_1.data(), col_i + 1);
                }
                self.vec_znx_copy(&mut a_0, 0, glwe_mi_1.data(), 0);
            } else {
                for i in 0..cols - 1 {
                    self.vec_znx_normalize(
                        base2k_tsk,
                        &mut a_0,
                        0,
                        base2k_res,
                        glwe_mi_1.data(),
                        i + 1,
                        scratch_2,
                    );
                    self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a_0, 0);
                }
                self.vec_znx_normalize(
                    base2k_tsk,
                    &mut a_0,
                    0,
                    base2k_res,
                    glwe_mi_1.data(),
                    0,
                    scratch_2,
                );
            }

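            // Expand this row into columns 1..=rank.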
            ggsw_expand_row_internal(self, row, res, &a_0, &a_dft, tsk, scratch_2);
        }
    }
}

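/// Computes row `row` of columns `1..=rank` of `res` from the DFT of columns
/// `1..` of that row (`a_dft`) and its base-converted column 0 (`a_0`).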
fn ggsw_expand_row_internal<M, R, C, A, T, BE: Backend>(
    module: &M,
    row: usize,
    res: &mut R,
    a_0: &C,
    a_dft: &A,
    tsk: &T,
    scratch: &mut Scratch<BE>,
) where
    R: GGSWToMut,
    C: VecZnxToRef,
    A: VecZnxDftToRef<BE>,
    M: GGLWEProduct<BE> + VecZnxIdftApplyConsume<BE> + VecZnxBigAddSmallInplace<BE> + VecZnxBigNormalize<BE>,
    T: GGLWEToGGSWKeyPreparedToRef<BE>,
    Scratch<BE>: ScratchTakeCore<BE>,
{
    let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
    let a_0: &VecZnx<&[u8]> = &a_0.to_ref();
    let a_dft: &VecZnxDft<&[u8], BE> = &a_dft.to_ref();
    let tsk: &GGLWEToGGSWKeyPrepared<&[u8], BE> = &tsk.to_ref();
    let cols: usize = res.rank().as_usize() + 1;

    for col in 1..cols {
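        // Take the gglwe product of the row's DFT columns with the prepared key for column `col`.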
        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(module, cols, tsk.size());

        module.gglwe_product_dft(&mut res_dft, a_dft, tsk.at(col - 1), scratch_1);

        let mut res_big: VecZnxBig<&mut [u8], BE> = module.vec_znx_idft_apply_consume(res_dft);

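        // Add the base-converted column 0 of the input row at position `col`.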
        module.vec_znx_big_add_small_inplace(&mut res_big, col, a_0, 0);

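        // Normalize each column back to the output base and write it into `res`.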
        for j in 0..cols {
            module.vec_znx_big_normalize(
                res.base2k().as_usize(),
                res.at_mut(row, col).data_mut(),
                j,
                tsk.base2k().as_usize(),
                &res_big,
                j,
                scratch_1,
            );
        }
    }
}