1use poulpy_hal::{
2 api::{
3 ModuleN, ScratchTakeBasic, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalize,
4 VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
5 },
6 layouts::{Backend, DataMut, DataViewMut, Module, Scratch, VecZnx, VecZnxBig},
7};
8
9use crate::{
10 ScratchTakeCore,
11 layouts::{
12 GGSWInfos, GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos,
13 prepared::{GGSWPrepared, GGSWPreparedToRef},
14 },
15};
16
17impl GLWE<Vec<u8>> {
18 pub fn external_product_tmp_bytes<R, A, B, M, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
19 where
20 R: GLWEInfos,
21 A: GLWEInfos,
22 B: GGSWInfos,
23 M: GLWEExternalProduct<BE>,
24 {
25 module.glwe_external_product_tmp_bytes(res_infos, a_infos, b_infos)
26 }
27}
28
29impl<DataSelf: DataMut> GLWE<DataSelf> {
30 pub fn external_product<A, B, M, BE: Backend>(&mut self, module: &M, a: &A, b: &B, scratch: &mut Scratch<BE>)
31 where
32 A: GLWEToRef,
33 B: GGSWPreparedToRef<BE>,
34 M: GLWEExternalProduct<BE>,
35 Scratch<BE>: ScratchTakeCore<BE>,
36 {
37 module.glwe_external_product(self, a, b, scratch);
38 }
39
40 pub fn external_product_inplace<A, M, BE: Backend>(&mut self, module: &M, a: &A, scratch: &mut Scratch<BE>)
41 where
42 A: GGSWPreparedToRef<BE>,
43 M: GLWEExternalProduct<BE>,
44 Scratch<BE>: ScratchTakeCore<BE>,
45 {
46 module.glwe_external_product_inplace(self, a, scratch);
47 }
48}
49
50pub trait GLWEExternalProduct<BE: Backend>
51where
52 Self: Sized
53 + ModuleN
54 + VecZnxDftBytesOf
55 + VmpApplyDftToDftTmpBytes
56 + VecZnxNormalizeTmpBytes
57 + VecZnxDftApply<BE>
58 + VmpApplyDftToDft<BE>
59 + VmpApplyDftToDftAdd<BE>
60 + VecZnxIdftApplyConsume<BE>
61 + VecZnxBigNormalize<BE>
62 + VecZnxNormalize<BE>,
63{
64 fn glwe_external_product_tmp_bytes<R, A, B>(&self, res_infos: &R, a_infos: &A, b_infos: &B) -> usize
65 where
66 R: GLWEInfos,
67 A: GLWEInfos,
68 B: GGSWInfos,
69 {
70 let in_size: usize = a_infos
71 .k()
72 .div_ceil(b_infos.base2k())
73 .div_ceil(b_infos.dsize().into()) as usize;
74 let out_size: usize = res_infos.size();
75 let ggsw_size: usize = b_infos.size();
76 let res_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), ggsw_size);
77 let a_dft: usize = self.bytes_of_vec_znx_dft((b_infos.rank() + 1).into(), in_size);
78 let vmp: usize = self.vmp_apply_dft_to_dft_tmp_bytes(
79 out_size,
80 in_size,
81 in_size, (b_infos.rank() + 1).into(), (b_infos.rank() + 1).into(), ggsw_size,
85 );
86 let normalize_big: usize = self.vec_znx_normalize_tmp_bytes();
87
88 if a_infos.base2k() == b_infos.base2k() {
89 res_dft + a_dft + (vmp | normalize_big)
90 } else {
91 let normalize_conv: usize = VecZnx::bytes_of(self.n(), (b_infos.rank() + 1).into(), in_size);
92 res_dft + ((a_dft + normalize_conv + (self.vec_znx_normalize_tmp_bytes() | vmp)) | normalize_big)
93 }
94 }
95
96 fn glwe_external_product_inplace<R, D>(&self, res: &mut R, a: &D, scratch: &mut Scratch<BE>)
97 where
98 R: GLWEToMut,
99 D: GGSWPreparedToRef<BE>,
100 Scratch<BE>: ScratchTakeCore<BE>,
101 {
102 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
103 let rhs: &GGSWPrepared<&[u8], BE> = &a.to_ref();
104
105 let basek_in: usize = res.base2k().into();
106 let basek_ggsw: usize = rhs.base2k().into();
107
108 #[cfg(debug_assertions)]
109 {
110 use poulpy_hal::api::ScratchAvailable;
111
112 assert_eq!(rhs.rank(), res.rank());
113 assert_eq!(rhs.n(), res.n());
114 assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, res, rhs));
115 }
116
117 let cols: usize = (rhs.rank() + 1).into();
118 let dsize: usize = rhs.dsize().into();
119 let a_size: usize = (res.size() * basek_in).div_ceil(basek_ggsw);
120
121 let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
123 a_dft.data_mut().fill(0);
124
125 if basek_in == basek_ggsw {
126 for di in 0..dsize {
127 a_dft.set_size((res.size() + di) / dsize);
129
130 res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
138
139 for j in 0..cols {
140 self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &res.data, j);
141 }
142
143 if di == 0 {
144 self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
145 } else {
146 self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
147 }
148 }
149 } else {
150 let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
151
152 for j in 0..cols {
153 self.vec_znx_normalize(
154 basek_ggsw,
155 &mut a_conv,
156 j,
157 basek_in,
158 &res.data,
159 j,
160 scratch_3,
161 );
162 }
163
164 for di in 0..dsize {
165 a_dft.set_size((res.size() + di) / dsize);
167
168 res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
176
177 for j in 0..cols {
178 self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &res.data, j);
179 }
180
181 if di == 0 {
182 self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
183 } else {
184 self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
185 }
186 }
187 }
188
189 let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
190
191 for j in 0..cols {
192 self.vec_znx_big_normalize(
193 basek_in,
194 &mut res.data,
195 j,
196 basek_ggsw,
197 &res_big,
198 j,
199 scratch_1,
200 );
201 }
202 }
203
204 fn glwe_external_product<R, A, D>(&self, res: &mut R, lhs: &A, rhs: &D, scratch: &mut Scratch<BE>)
205 where
206 R: GLWEToMut,
207 A: GLWEToRef,
208 D: GGSWPreparedToRef<BE>,
209 Scratch<BE>: ScratchTakeCore<BE>,
210 {
211 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
212 let lhs: &GLWE<&[u8]> = &lhs.to_ref();
213
214 let rhs: &GGSWPrepared<&[u8], BE> = &rhs.to_ref();
215
216 let basek_in: usize = lhs.base2k().into();
217 let basek_ggsw: usize = rhs.base2k().into();
218 let basek_out: usize = res.base2k().into();
219
220 #[cfg(debug_assertions)]
221 {
222 use poulpy_hal::api::ScratchAvailable;
223
224 assert_eq!(rhs.rank(), lhs.rank());
225 assert_eq!(rhs.rank(), res.rank());
226 assert_eq!(rhs.n(), res.n());
227 assert_eq!(lhs.n(), res.n());
228 assert!(scratch.available() >= self.glwe_external_product_tmp_bytes(res, lhs, rhs));
229 }
230
231 let cols: usize = (rhs.rank() + 1).into();
232 let dsize: usize = rhs.dsize().into();
233
234 let a_size: usize = (lhs.size() * basek_in).div_ceil(basek_ggsw);
235
236 let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, cols, rhs.size()); let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, cols, a_size.div_ceil(dsize));
238 a_dft.data_mut().fill(0);
239
240 if basek_in == basek_ggsw {
241 for di in 0..dsize {
242 a_dft.set_size((lhs.size() + di) / dsize);
244
245 res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
253
254 for j in 0..cols {
255 self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &lhs.data, j);
256 }
257
258 if di == 0 {
259 self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_2);
260 } else {
261 self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_2);
262 }
263 }
264 } else {
265 let (mut a_conv, scratch_3) = scratch_2.take_vec_znx(self.n(), cols, a_size);
266
267 for j in 0..cols {
268 self.vec_znx_normalize(
269 basek_ggsw,
270 &mut a_conv,
271 j,
272 basek_in,
273 &lhs.data,
274 j,
275 scratch_3,
276 );
277 }
278
279 for di in 0..dsize {
280 a_dft.set_size((a_size + di) / dsize);
282
283 res_dft.set_size(rhs.size() - ((dsize - di) as isize - 2).max(0) as usize);
291
292 for j in 0..cols {
293 self.vec_znx_dft_apply(dsize, dsize - 1 - di, &mut a_dft, j, &a_conv, j);
294 }
295
296 if di == 0 {
297 self.vmp_apply_dft_to_dft(&mut res_dft, &a_dft, &rhs.data, scratch_3);
298 } else {
299 self.vmp_apply_dft_to_dft_add(&mut res_dft, &a_dft, &rhs.data, di, scratch_3);
300 }
301 }
302 }
303
304 let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
305
306 (0..cols).for_each(|i| {
307 self.vec_znx_big_normalize(
308 basek_out,
309 res.data_mut(),
310 i,
311 basek_ggsw,
312 &res_big,
313 i,
314 scratch_1,
315 );
316 });
317 }
318}
319
320impl<BE: Backend> GLWEExternalProduct<BE> for Module<BE> where
321 Self: ModuleN
322 + VecZnxDftBytesOf
323 + VmpApplyDftToDftTmpBytes
324 + VecZnxNormalizeTmpBytes
325 + VecZnxDftApply<BE>
326 + VmpApplyDftToDft<BE>
327 + VmpApplyDftToDftAdd<BE>
328 + VecZnxIdftApplyConsume<BE>
329 + VecZnxBigNormalize<BE>
330 + VecZnxNormalize<BE>
331 + VecZnxDftBytesOf
332 + VmpApplyDftToDftTmpBytes
333 + VecZnxNormalizeTmpBytes
334{
335}