use poulpy_hal::{
    api::{
        BivariateTensoring, ModuleN, ScratchTakeBasic, VecZnxAdd, VecZnxAddInplace, VecZnxBigNormalize, VecZnxCopy,
        VecZnxIdftApplyConsume, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNormalize,
        VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
        VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxZero,
    },
    layouts::{Backend, Module, Scratch, VecZnx, VecZnxBig, ZnxInfos},
    reference::vec_znx::vec_znx_rotate_inplace_tmp_bytes,
};

use crate::{
    ScratchTakeCore,
    layouts::{
        GLWE, GLWEInfos, GLWEPrepared, GLWEPreparedToRef, GLWETensor, GLWETensorToMut, GLWEToMut, GLWEToRef, LWEInfos,
        TorusPrecision,
    },
};

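/// Tensor product of GLWE ciphertexts: multiplies a [`GLWE`] with a [`GLWEPrepared`]
/// operand and stores the result in a [`GLWETensor`].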
pub trait GLWETensoring<BE: Backend>
where
    Self: BivariateTensoring<BE> + VecZnxIdftApplyConsume<BE> + VecZnxBigNormalize<BE>,
    Scratch<BE>: ScratchTakeCore<BE>,
{
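    /// Tensors `a` with the prepared operand `b` into `res`: the bivariate product is
    /// accumulated in the DFT domain, mapped back with an inverse DFT, and each column
    /// is then normalized to `res`'s `base2k`. The parameter `k` is forwarded to the
    /// underlying [`BivariateTensoring`] call.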
    fn glwe_tensor<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
    where
        R: GLWETensorToMut,
        A: GLWEToRef,
        B: GLWEPreparedToRef<BE>,
    {
        let res: &mut GLWETensor<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();
        let b: &GLWEPrepared<&[u8], BE> = &b.to_ref();

        assert_eq!(a.base2k(), b.base2k());
        assert_eq!(a.rank(), res.rank());

        let res_cols: usize = res.data.cols();

        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, res_cols, res.max_k().div_ceil(a.base2k()) as usize);

        self.bivariate_tensoring(k, &mut res_dft, &a.data, &b.data, scratch_1);

        let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);

        for res_col in 0..res_cols {
            self.vec_znx_big_normalize(
                res.base2k().into(),
                &mut res.data,
                res_col,
                a.base2k().into(),
                &res_big,
                res_col,
                scratch_1,
            );
        }
    }
}

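/// Column-wise addition of GLWE ciphertexts.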
pub trait GLWEAdd
where
    Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero,
{
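    /// Computes `res = a + b` column by column. A rank-0 operand only contributes its
    /// body column: the remaining columns are copied from the other operand, and any
    /// columns of `res` beyond both operands are zeroed.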
    fn glwe_add<R, A, B>(&self, res: &mut R, a: &A, b: &B)
    where
        R: GLWEToMut,
        A: GLWEToRef,
        B: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();
        let b: &GLWE<&[u8]> = &b.to_ref();

        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(b.n(), self.n() as u32);
        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.base2k(), b.base2k());
        assert_eq!(res.base2k(), b.base2k());

        if a.rank() == 0 {
            assert_eq!(res.rank(), b.rank());
        } else if b.rank() == 0 {
            assert_eq!(res.rank(), a.rank());
        } else {
            assert_eq!(res.rank(), a.rank());
            assert_eq!(res.rank(), b.rank());
        }

        let min_col: usize = (a.rank().min(b.rank()) + 1).into();
        let max_col: usize = (a.rank().max(b.rank()) + 1).into();
        let self_col: usize = (res.rank() + 1).into();

        for i in 0..min_col {
            self.vec_znx_add(res.data_mut(), i, a.data(), i, b.data(), i);
        }

        if a.rank() > b.rank() {
            for i in min_col..max_col {
                self.vec_znx_copy(res.data_mut(), i, a.data(), i);
            }
        } else {
            for i in min_col..max_col {
                self.vec_znx_copy(res.data_mut(), i, b.data(), i);
            }
        }

        for i in max_col..self_col {
            self.vec_znx_zero(res.data_mut(), i);
        }
    }

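    /// Computes `res += a` over the columns of `a`; `res` must have at least the rank of `a`.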
    fn glwe_add_inplace<R, A>(&self, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();

        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.base2k(), a.base2k());
        assert!(res.rank() >= a.rank());

        for i in 0..(a.rank() + 1).into() {
            self.vec_znx_add_inplace(res.data_mut(), i, a.data(), i);
        }
    }
}

impl<BE: Backend> GLWEAdd for Module<BE> where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero {}

impl<BE: Backend> GLWESub for Module<BE> where
    Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace
{
}

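/// Column-wise subtraction of GLWE ciphertexts.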
pub trait GLWESub
where
    Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace,
{
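    /// Computes `res = a - b` column by column. Columns present only in `a` are copied,
    /// columns present only in `b` are negated, and any remaining columns of `res` are zeroed.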
    fn glwe_sub<R, A, B>(&self, res: &mut R, a: &A, b: &B)
    where
        R: GLWEToMut,
        A: GLWEToRef,
        B: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();
        let b: &GLWE<&[u8]> = &b.to_ref();

        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(b.n(), self.n() as u32);
        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.base2k(), res.base2k());
        assert_eq!(b.base2k(), res.base2k());

        if a.rank() == 0 {
            assert_eq!(res.rank(), b.rank());
        } else if b.rank() == 0 {
            assert_eq!(res.rank(), a.rank());
        } else {
            assert_eq!(res.rank(), a.rank());
            assert_eq!(res.rank(), b.rank());
        }

        let min_col: usize = (a.rank().min(b.rank()) + 1).into();
        let max_col: usize = (a.rank().max(b.rank()) + 1).into();
        let self_col: usize = (res.rank() + 1).into();

        for i in 0..min_col {
            self.vec_znx_sub(res.data_mut(), i, a.data(), i, b.data(), i);
        }

        if a.rank() > b.rank() {
            for i in min_col..max_col {
                self.vec_znx_copy(res.data_mut(), i, a.data(), i);
            }
        } else {
            for i in min_col..max_col {
                self.vec_znx_negate(res.data_mut(), i, b.data(), i);
            }
        }

        for i in max_col..self_col {
            self.vec_znx_zero(res.data_mut(), i);
        }
    }

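    /// Computes `res -= a` over the columns of `a`.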
    fn glwe_sub_inplace<R, A>(&self, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();

        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.base2k(), a.base2k());
        assert!(res.rank() == a.rank() || a.rank() == 0);

        for i in 0..(a.rank() + 1).into() {
            self.vec_znx_sub_inplace(res.data_mut(), i, a.data(), i);
        }
    }

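    /// In-place sub-negate: each column of `res` is replaced by the corresponding column
    /// of `a` minus `res` (via [`VecZnxSubNegateInplace`]).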
    fn glwe_sub_negate_inplace<R, A>(&self, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();

        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.base2k(), a.base2k());
        assert!(res.rank() == a.rank() || a.rank() == 0);

        for i in 0..(a.rank() + 1).into() {
            self.vec_znx_sub_negate_inplace(res.data_mut(), i, a.data(), i);
        }
    }
}

impl<BE: Backend> GLWERotate<BE> for Module<BE> where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> + VecZnxZero {}

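/// Rotation of GLWE ciphertexts, i.e. multiplication of every column by the monomial `X^k`.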
pub trait GLWERotate<BE: Backend>
where
    Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> + VecZnxZero,
{
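    /// Scratch space (in bytes) required by [`Self::glwe_rotate_inplace`].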
    fn glwe_rotate_tmp_bytes(&self) -> usize {
        vec_znx_rotate_inplace_tmp_bytes(self.n())
    }

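    /// Rotates every column of `a` by `k` positions into `res`; columns of `res` that
    /// have no counterpart in `a` are zeroed.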
    fn glwe_rotate<R, A>(&self, k: i64, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();

        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.n(), self.n() as u32);
        assert!(res.rank() == a.rank() || a.rank() == 0);

        let res_cols: usize = (res.rank() + 1).into();
        let a_cols: usize = (a.rank() + 1).into();

        for i in 0..a_cols {
            self.vec_znx_rotate(k, res.data_mut(), i, a.data(), i);
        }
        for i in a_cols..res_cols {
            self.vec_znx_zero(res.data_mut(), i);
        }
    }

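    /// Rotates every column of `res` by `k` positions in place, using `scratch` as temporary space.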
    fn glwe_rotate_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();

        for i in 0..(res.rank() + 1).into() {
            self.vec_znx_rotate_inplace(k, res.data_mut(), i, scratch);
        }
    }
}

impl<BE: Backend> GLWEMulXpMinusOne<BE> for Module<BE> where Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace<BE> {}

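/// Multiplication of a GLWE ciphertext by `X^k - 1`.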
pub trait GLWEMulXpMinusOne<BE: Backend>
where
    Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace<BE>,
{
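    /// Computes `res = a * (X^k - 1)` column by column.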
    fn glwe_mul_xp_minus_one<R, A>(&self, k: i64, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();

        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.rank(), a.rank());

        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_mul_xp_minus_one(k, res.data_mut(), i, a.data(), i);
        }
    }

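    /// Computes `res *= (X^k - 1)` in place, column by column, using `scratch` as temporary space.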
    fn glwe_mul_xp_minus_one_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();

        assert_eq!(res.n(), self.n() as u32);

        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_mul_xp_minus_one_inplace(k, res.data_mut(), i, scratch);
        }
    }
}

impl<BE: Backend> GLWECopy for Module<BE> where Self: ModuleN + VecZnxCopy + VecZnxZero {}

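/// Copy of a GLWE ciphertext into another, zero-filling any extra columns of the destination.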
pub trait GLWECopy
where
    Self: ModuleN + VecZnxCopy + VecZnxZero,
{
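    /// Copies the columns of `a` into `res`; any remaining columns of `res` are zeroed.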
    fn glwe_copy<R, A>(&self, res: &mut R, a: &A)
    where
        R: GLWEToMut,
        A: GLWEToRef,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();

        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert!(res.rank() == a.rank() || a.rank() == 0);

        let min_rank: usize = res.rank().min(a.rank()).as_usize() + 1;

        for i in 0..min_rank {
            self.vec_znx_copy(res.data_mut(), i, a.data(), i);
        }

        for i in min_rank..(res.rank() + 1).into() {
            self.vec_znx_zero(res.data_mut(), i);
        }
    }
}

impl<BE: Backend> GLWEShift<BE> for Module<BE> where Self: ModuleN + VecZnxRshInplace<BE> {}

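/// In-place right shift of GLWE ciphertexts.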
pub trait GLWEShift<BE: Backend>
where
    Self: ModuleN + VecZnxRshInplace<BE>,
{
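    /// Scratch space (in bytes) required by [`Self::glwe_rsh`].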
    fn glwe_rsh_tmp_byte(&self) -> usize {
        VecZnx::rsh_tmp_bytes(self.n())
    }

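    /// Right-shifts every column of `res` by `k` bits in place, with respect to its `base2k` decomposition.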
    fn glwe_rsh<R>(&self, k: usize, res: &mut R, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let base2k: usize = res.base2k().into();
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_rsh_inplace(base2k, k, res.data_mut(), i, scratch);
        }
    }
}

impl GLWE<Vec<u8>> {
    pub fn rsh_tmp_bytes<M, BE: Backend>(module: &M) -> usize
    where
        M: GLWEShift<BE>,
    {
        module.glwe_rsh_tmp_byte()
    }
}

impl<BE: Backend> GLWENormalize<BE> for Module<BE> where
    Self: ModuleN + VecZnxNormalize<BE> + VecZnxNormalizeInplace<BE> + VecZnxNormalizeTmpBytes
{
}

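/// Normalization of GLWE ciphertexts, including conversion between `base2k` representations.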
pub trait GLWENormalize<BE: Backend>
where
    Self: ModuleN + VecZnxNormalize<BE> + VecZnxNormalizeInplace<BE> + VecZnxNormalizeTmpBytes,
{
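    /// Scratch space (in bytes) required by the normalization methods of this trait.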
    fn glwe_normalize_tmp_bytes(&self) -> usize {
        self.vec_znx_normalize_tmp_bytes()
    }

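    /// Returns `glwe` as a read-only reference when its `base2k` already matches
    /// `target_base2k`; otherwise normalizes it into a temporary GLWE taken from
    /// `scratch`, kept alive through `tmp_slot`, and returns a reference to that copy
    /// together with the remaining scratch space.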
    fn glwe_maybe_cross_normalize_to_ref<'a, A>(
        &self,
        glwe: &'a A,
        target_base2k: usize,
        tmp_slot: &'a mut Option<GLWE<&'a mut [u8]>>,
        scratch: &'a mut Scratch<BE>,
    ) -> (GLWE<&'a [u8]>, &'a mut Scratch<BE>)
    where
        A: GLWEToRef + GLWEInfos,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        if glwe.base2k().as_usize() == target_base2k {
            tmp_slot.take();
            return (glwe.to_ref(), scratch);
        }

        let mut layout = glwe.glwe_layout();
        layout.base2k = target_base2k.into();

        let (tmp, scratch2) = scratch.take_glwe(&layout);
        *tmp_slot = Some(tmp);

        let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot
            .as_mut()
            .expect("tmp_slot just set to Some, but found None");

        self.glwe_normalize(tmp_ref, glwe, scratch2);

        (tmp_ref.to_ref(), scratch2)
    }

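    /// Mutable counterpart of [`Self::glwe_maybe_cross_normalize_to_ref`]: returns `glwe`
    /// itself when no conversion is needed, or a freshly normalized temporary otherwise.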
    fn glwe_maybe_cross_normalize_to_mut<'a, A>(
        &self,
        glwe: &'a mut A,
        target_base2k: usize,
        tmp_slot: &'a mut Option<GLWE<&'a mut [u8]>>,
        scratch: &'a mut Scratch<BE>,
    ) -> (GLWE<&'a mut [u8]>, &'a mut Scratch<BE>)
    where
        A: GLWEToMut + GLWEInfos,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        if glwe.base2k().as_usize() == target_base2k {
            tmp_slot.take();
            return (glwe.to_mut(), scratch);
        }

        let mut layout = glwe.glwe_layout();
        layout.base2k = target_base2k.into();

        let (tmp, scratch2) = scratch.take_glwe(&layout);
        *tmp_slot = Some(tmp);

        let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot
            .as_mut()
            .expect("tmp_slot just set to Some, but found None");

        self.glwe_normalize(tmp_ref, glwe, scratch2);

        (tmp_ref.to_mut(), scratch2)
    }

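    /// Normalizes `a` into `res`, converting from `a`'s `base2k` to `res`'s `base2k`, column by column.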
    fn glwe_normalize<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        A: GLWEToRef,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        let a: &GLWE<&[u8]> = &a.to_ref();

        assert_eq!(res.n(), self.n() as u32);
        assert_eq!(a.n(), self.n() as u32);
        assert_eq!(res.rank(), a.rank());

        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_normalize(
                res.base2k().into(),
                res.data_mut(),
                i,
                a.base2k().into(),
                a.data(),
                i,
                scratch,
            );
        }
    }

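    /// Normalizes every column of `res` in place with respect to its own `base2k`.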
    fn glwe_normalize_inplace<R>(&self, res: &mut R, scratch: &mut Scratch<BE>)
    where
        R: GLWEToMut,
        Scratch<BE>: ScratchTakeCore<BE>,
    {
        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
        for i in 0..res.rank().as_usize() + 1 {
            self.vec_znx_normalize_inplace(res.base2k().into(), res.data_mut(), i, scratch);
        }
    }
}

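/// Selects the output torus precision of a binary operation `c = op(a, b)`: rank-0
/// operands do not constrain the result, otherwise the precision is the minimum of the
/// contributing operands' precisions, capped by `c`'s own precision.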
#[allow(dead_code)]
fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
    if a.rank() != 0 || b.rank() != 0 {
        let k = if a.rank() == 0 {
            b.k()
        } else if b.rank() == 0 {
            a.k()
        } else {
            a.k().min(b.k())
        };
        k.min(c.k())
    } else {
        c.k()
    }
}

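/// Same selection for in-place operations on `a` with operand `b`: the minimum of both
/// precisions, unless both ranks are 0, in which case `a`'s precision is kept.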
#[allow(dead_code)]
fn set_k_unary(a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
    if a.rank() != 0 || b.rank() != 0 {
        a.k().min(b.k())
    } else {
        a.k()
    }
}