// poulpy_core/operations/glwe.rs

1use poulpy_hal::{
2    api::{
3        BivariateTensoring, ModuleN, ScratchTakeBasic, VecZnxAdd, VecZnxAddInplace, VecZnxBigNormalize, VecZnxCopy,
4        VecZnxIdftApplyConsume, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegate, VecZnxNormalize,
5        VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
6        VecZnxSubInplace, VecZnxSubNegateInplace, VecZnxZero,
7    },
8    layouts::{Backend, Module, Scratch, VecZnx, VecZnxBig, ZnxInfos},
9    reference::vec_znx::vec_znx_rotate_inplace_tmp_bytes,
10};
11
12use crate::{
13    ScratchTakeCore,
14    layouts::{
15        GLWE, GLWEInfos, GLWEPrepared, GLWEPreparedToRef, GLWETensor, GLWETensorToMut, GLWEToMut, GLWEToRef, LWEInfos,
16        TorusPrecision,
17    },
18};
19
20pub trait GLWETensoring<BE: Backend>
21where
22    Self: BivariateTensoring<BE> + VecZnxIdftApplyConsume<BE> + VecZnxBigNormalize<BE>,
23    Scratch<BE>: ScratchTakeCore<BE>,
24{
25    /// res = (a (x) b) * 2^{k * a_base2k}
26    ///
27    /// # Requires
28    /// * a.base2k() == b.base2k()
29    /// * res.cols() >= a.cols() + b.cols() - 1
30    ///
31    /// # Behavior
32    /// * res precision is truncated to res.max_k().min(a.max_k() + b.max_k() + k * a_base2k)
33    fn glwe_tensor<R, A, B>(&self, k: i64, res: &mut R, a: &A, b: &B, scratch: &mut Scratch<BE>)
34    where
35        R: GLWETensorToMut,
36        A: GLWEToRef,
37        B: GLWEPreparedToRef<BE>,
38    {
39        let res: &mut GLWETensor<&mut [u8]> = &mut res.to_mut();
40        let a: &GLWE<&[u8]> = &a.to_ref();
41        let b: &GLWEPrepared<&[u8], BE> = &b.to_ref();
42
43        assert_eq!(a.base2k(), b.base2k());
44        assert_eq!(a.rank(), res.rank());
45
46        let res_cols: usize = res.data.cols();
47
48        // Get tmp buffer of min precision between a_prec * b_prec and res_prec
49        let (mut res_dft, scratch_1) = scratch.take_vec_znx_dft(self, res_cols, res.max_k().div_ceil(a.base2k()) as usize);
50
51        // DFT(res) = DFT(a) (x) DFT(b)
52        self.bivariate_tensoring(k, &mut res_dft, &a.data, &b.data, scratch_1);
53
54        // res = IDFT(res)
55        let res_big: VecZnxBig<&mut [u8], BE> = self.vec_znx_idft_apply_consume(res_dft);
56
57        // Normalize and switches basis if required
58        for res_col in 0..res_cols {
59            self.vec_znx_big_normalize(
60                res.base2k().into(),
61                &mut res.data,
62                res_col,
63                a.base2k().into(),
64                &res_big,
65                res_col,
66                scratch_1,
67            );
68        }
69    }
70
71    // fn glwe_relinearize<R, A, T>(&self, res: &mut R, a: &A, tsk: &T, scratch: &mut Scratch<BE>)
72    // where
73    // R: GLWEToRef,
74    // A: GLWETensorToRef,
75    // T: GLWETensorKeyPreparedToRef<BE>,
76    // {
77    // }
78}
79
80pub trait GLWEAdd
81where
82    Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero,
83{
84    fn glwe_add<R, A, B>(&self, res: &mut R, a: &A, b: &B)
85    where
86        R: GLWEToMut,
87        A: GLWEToRef,
88        B: GLWEToRef,
89    {
90        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
91        let a: &mut GLWE<&[u8]> = &mut a.to_ref();
92        let b: &GLWE<&[u8]> = &b.to_ref();
93
94        assert_eq!(a.n(), self.n() as u32);
95        assert_eq!(b.n(), self.n() as u32);
96        assert_eq!(res.n(), self.n() as u32);
97        assert_eq!(a.base2k(), b.base2k());
98        assert_eq!(res.base2k(), b.base2k());
99
100        if a.rank() == 0 {
101            assert_eq!(res.rank(), b.rank());
102        } else if b.rank() == 0 {
103            assert_eq!(res.rank(), a.rank());
104        } else {
105            assert_eq!(res.rank(), a.rank());
106            assert_eq!(res.rank(), b.rank());
107        }
108
109        let min_col: usize = (a.rank().min(b.rank()) + 1).into();
110        let max_col: usize = (a.rank().max(b.rank() + 1)).into();
111        let self_col: usize = (res.rank() + 1).into();
112
113        for i in 0..min_col {
114            self.vec_znx_add(res.data_mut(), i, a.data(), i, b.data(), i);
115        }
116
117        if a.rank() > b.rank() {
118            for i in min_col..max_col {
119                self.vec_znx_copy(res.data_mut(), i, a.data(), i);
120            }
121        } else {
122            for i in min_col..max_col {
123                self.vec_znx_copy(res.data_mut(), i, b.data(), i);
124            }
125        }
126
127        for i in max_col..self_col {
128            self.vec_znx_zero(res.data_mut(), i);
129        }
130    }
131
132    fn glwe_add_inplace<R, A>(&self, res: &mut R, a: &A)
133    where
134        R: GLWEToMut,
135        A: GLWEToRef,
136    {
137        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
138        let a: &GLWE<&[u8]> = &a.to_ref();
139
140        assert_eq!(res.n(), self.n() as u32);
141        assert_eq!(a.n(), self.n() as u32);
142        assert_eq!(res.base2k(), a.base2k());
143        assert!(res.rank() >= a.rank());
144
145        for i in 0..(a.rank() + 1).into() {
146            self.vec_znx_add_inplace(res.data_mut(), i, a.data(), i);
147        }
148    }
149}
150
151impl<BE: Backend> GLWEAdd for Module<BE> where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace + VecZnxZero {}
152
153impl<BE: Backend> GLWESub for Module<BE> where
154    Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace
155{
156}
157
158pub trait GLWESub
159where
160    Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegate + VecZnxZero + VecZnxSubInplace + VecZnxSubNegateInplace,
161{
162    fn glwe_sub<R, A, B>(&self, res: &mut R, a: &A, b: &B)
163    where
164        R: GLWEToMut,
165        A: GLWEToRef,
166        B: GLWEToRef,
167    {
168        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
169        let a: &GLWE<&[u8]> = &a.to_ref();
170        let b: &GLWE<&[u8]> = &b.to_ref();
171
172        assert_eq!(a.n(), self.n() as u32);
173        assert_eq!(b.n(), self.n() as u32);
174        assert_eq!(res.n(), self.n() as u32);
175        assert_eq!(a.base2k(), res.base2k());
176        assert_eq!(b.base2k(), res.base2k());
177
178        if a.rank() == 0 {
179            assert_eq!(res.rank(), b.rank());
180        } else if b.rank() == 0 {
181            assert_eq!(res.rank(), a.rank());
182        } else {
183            assert_eq!(res.rank(), a.rank());
184            assert_eq!(res.rank(), b.rank());
185        }
186
187        let min_col: usize = (a.rank().min(b.rank()) + 1).into();
188        let max_col: usize = (a.rank().max(b.rank() + 1)).into();
189        let self_col: usize = (res.rank() + 1).into();
190
191        for i in 0..min_col {
192            self.vec_znx_sub(res.data_mut(), i, a.data(), i, b.data(), i);
193        }
194
195        if a.rank() > b.rank() {
196            for i in min_col..max_col {
197                self.vec_znx_copy(res.data_mut(), i, a.data(), i);
198            }
199        } else {
200            for i in min_col..max_col {
201                self.vec_znx_negate(res.data_mut(), i, b.data(), i);
202            }
203        }
204
205        for i in max_col..self_col {
206            self.vec_znx_zero(res.data_mut(), i);
207        }
208    }
209
210    fn glwe_sub_inplace<R, A>(&self, res: &mut R, a: &A)
211    where
212        R: GLWEToMut,
213        A: GLWEToRef,
214    {
215        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
216        let a: &GLWE<&[u8]> = &a.to_ref();
217
218        assert_eq!(res.n(), self.n() as u32);
219        assert_eq!(a.n(), self.n() as u32);
220        assert_eq!(res.base2k(), a.base2k());
221        assert!(res.rank() == a.rank() || a.rank() == 0);
222
223        for i in 0..(a.rank() + 1).into() {
224            self.vec_znx_sub_inplace(res.data_mut(), i, a.data(), i);
225        }
226    }
227
228    fn glwe_sub_negate_inplace<R, A>(&self, res: &mut R, a: &A)
229    where
230        R: GLWEToMut,
231        A: GLWEToRef,
232    {
233        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
234        let a: &GLWE<&[u8]> = &a.to_ref();
235
236        assert_eq!(res.n(), self.n() as u32);
237        assert_eq!(a.n(), self.n() as u32);
238        assert_eq!(res.base2k(), a.base2k());
239        assert!(res.rank() == a.rank() || a.rank() == 0);
240
241        for i in 0..(a.rank() + 1).into() {
242            self.vec_znx_sub_negate_inplace(res.data_mut(), i, a.data(), i);
243        }
244    }
245}
246
247impl<BE: Backend> GLWERotate<BE> for Module<BE> where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> + VecZnxZero {}
248
249pub trait GLWERotate<BE: Backend>
250where
251    Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> + VecZnxZero,
252{
253    fn glwe_rotate_tmp_bytes(&self) -> usize {
254        vec_znx_rotate_inplace_tmp_bytes(self.n())
255    }
256
257    fn glwe_rotate<R, A>(&self, k: i64, res: &mut R, a: &A)
258    where
259        R: GLWEToMut,
260        A: GLWEToRef,
261    {
262        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
263        let a: &GLWE<&[u8]> = &a.to_ref();
264
265        assert_eq!(a.n(), self.n() as u32);
266        assert_eq!(res.n(), self.n() as u32);
267        assert!(res.rank() == a.rank() || a.rank() == 0);
268
269        let res_cols = (res.rank() + 1).into();
270        let a_cols = (a.rank() + 1).into();
271
272        for i in 0..a_cols {
273            self.vec_znx_rotate(k, res.data_mut(), i, a.data(), i);
274        }
275        for i in a_cols..res_cols {
276            self.vec_znx_zero(res.data_mut(), i);
277        }
278    }
279
280    fn glwe_rotate_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
281    where
282        R: GLWEToMut,
283        Scratch<BE>: ScratchTakeCore<BE>,
284    {
285        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
286
287        for i in 0..(res.rank() + 1).into() {
288            self.vec_znx_rotate_inplace(k, res.data_mut(), i, scratch);
289        }
290    }
291}
292
293impl<BE: Backend> GLWEMulXpMinusOne<BE> for Module<BE> where Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace<BE> {}
294
295pub trait GLWEMulXpMinusOne<BE: Backend>
296where
297    Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace<BE>,
298{
299    fn glwe_mul_xp_minus_one<R, A>(&self, k: i64, res: &mut R, a: &A)
300    where
301        R: GLWEToMut,
302        A: GLWEToRef,
303    {
304        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
305        let a: &GLWE<&[u8]> = &a.to_ref();
306
307        assert_eq!(res.n(), self.n() as u32);
308        assert_eq!(a.n(), self.n() as u32);
309        assert_eq!(res.rank(), a.rank());
310
311        for i in 0..res.rank().as_usize() + 1 {
312            self.vec_znx_mul_xp_minus_one(k, res.data_mut(), i, a.data(), i);
313        }
314    }
315
316    fn glwe_mul_xp_minus_one_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
317    where
318        R: GLWEToMut,
319    {
320        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
321
322        assert_eq!(res.n(), self.n() as u32);
323
324        for i in 0..res.rank().as_usize() + 1 {
325            self.vec_znx_mul_xp_minus_one_inplace(k, res.data_mut(), i, scratch);
326        }
327    }
328}
329
330impl<BE: Backend> GLWECopy for Module<BE> where Self: ModuleN + VecZnxCopy + VecZnxZero {}
331
332pub trait GLWECopy
333where
334    Self: ModuleN + VecZnxCopy + VecZnxZero,
335{
336    fn glwe_copy<R, A>(&self, res: &mut R, a: &A)
337    where
338        R: GLWEToMut,
339        A: GLWEToRef,
340    {
341        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
342        let a: &GLWE<&[u8]> = &a.to_ref();
343
344        assert_eq!(res.n(), self.n() as u32);
345        assert_eq!(a.n(), self.n() as u32);
346        assert!(res.rank() == a.rank() || a.rank() == 0);
347
348        let min_rank: usize = res.rank().min(a.rank()).as_usize() + 1;
349
350        for i in 0..min_rank {
351            self.vec_znx_copy(res.data_mut(), i, a.data(), i);
352        }
353
354        for i in min_rank..(res.rank() + 1).into() {
355            self.vec_znx_zero(res.data_mut(), i);
356        }
357    }
358}
359
360impl<BE: Backend> GLWEShift<BE> for Module<BE> where Self: ModuleN + VecZnxRshInplace<BE> {}
361
362pub trait GLWEShift<BE: Backend>
363where
364    Self: ModuleN + VecZnxRshInplace<BE>,
365{
366    fn glwe_rsh_tmp_byte(&self) -> usize {
367        VecZnx::rsh_tmp_bytes(self.n())
368    }
369
370    fn glwe_rsh<R>(&self, k: usize, res: &mut R, scratch: &mut Scratch<BE>)
371    where
372        R: GLWEToMut,
373        Scratch<BE>: ScratchTakeCore<BE>,
374    {
375        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
376        let base2k: usize = res.base2k().into();
377        for i in 0..res.rank().as_usize() + 1 {
378            self.vec_znx_rsh_inplace(base2k, k, res.data_mut(), i, scratch);
379        }
380    }
381}
382
383impl GLWE<Vec<u8>> {
384    pub fn rsh_tmp_bytes<M, BE: Backend>(module: &M) -> usize
385    where
386        M: GLWEShift<BE>,
387    {
388        module.glwe_rsh_tmp_byte()
389    }
390}
391
392impl<BE: Backend> GLWENormalize<BE> for Module<BE> where
393    Self: ModuleN + VecZnxNormalize<BE> + VecZnxNormalizeInplace<BE> + VecZnxNormalizeTmpBytes
394{
395}
396
397pub trait GLWENormalize<BE: Backend>
398where
399    Self: ModuleN + VecZnxNormalize<BE> + VecZnxNormalizeInplace<BE> + VecZnxNormalizeTmpBytes,
400{
401    fn glwe_normalize_tmp_bytes(&self) -> usize {
402        self.vec_znx_normalize_tmp_bytes()
403    }
404
405    /// Usage:
406    /// let mut tmp_b: Option<GLWE<&mut [u8]>> = None;
407    /// let (b_conv, scratch_1) = glwe_maybe_convert_in_place(self, b, res.base2k().as_u32(), &mut tmp_b, scratch);
408    fn glwe_maybe_cross_normalize_to_ref<'a, A>(
409        &self,
410        glwe: &'a A,
411        target_base2k: usize,
412        tmp_slot: &'a mut Option<GLWE<&'a mut [u8]>>, // caller-owned scratch-backed temp
413        scratch: &'a mut Scratch<BE>,
414    ) -> (GLWE<&'a [u8]>, &'a mut Scratch<BE>)
415    where
416        A: GLWEToRef + GLWEInfos,
417        Scratch<BE>: ScratchTakeCore<BE>,
418    {
419        // No conversion: just use the original GLWE
420        if glwe.base2k().as_usize() == target_base2k {
421            // Drop any previous temp; it's stale for this base
422            tmp_slot.take();
423            return (glwe.to_ref(), scratch);
424        }
425
426        // Conversion: allocate a temporary GLWE in scratch
427        let mut layout = glwe.glwe_layout();
428        layout.base2k = target_base2k.into();
429
430        let (tmp, scratch2) = scratch.take_glwe(&layout);
431        *tmp_slot = Some(tmp);
432
433        // Get a mutable handle to the temp and normalize into it
434        let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot
435            .as_mut()
436            .expect("tmp_slot just set to Some, but found None");
437
438        self.glwe_normalize(tmp_ref, glwe, scratch2);
439
440        // Return a trait-object view of the temp
441        (tmp_ref.to_ref(), scratch2)
442    }
443
444    /// Usage:
445    /// let mut tmp_b: Option<GLWE<&mut [u8]>> = None;
446    /// let (b_conv, scratch_1) = glwe_maybe_convert_in_place(self, b, res.base2k().as_u32(), &mut tmp_b, scratch);
447    fn glwe_maybe_cross_normalize_to_mut<'a, A>(
448        &self,
449        glwe: &'a mut A,
450        target_base2k: usize,
451        tmp_slot: &'a mut Option<GLWE<&'a mut [u8]>>, // caller-owned scratch-backed temp
452        scratch: &'a mut Scratch<BE>,
453    ) -> (GLWE<&'a mut [u8]>, &'a mut Scratch<BE>)
454    where
455        A: GLWEToMut + GLWEInfos,
456        Scratch<BE>: ScratchTakeCore<BE>,
457    {
458        // No conversion: just use the original GLWE
459        if glwe.base2k().as_usize() == target_base2k {
460            // Drop any previous temp; it's stale for this base
461            tmp_slot.take();
462            return (glwe.to_mut(), scratch);
463        }
464
465        // Conversion: allocate a temporary GLWE in scratch
466        let mut layout = glwe.glwe_layout();
467        layout.base2k = target_base2k.into();
468
469        let (tmp, scratch2) = scratch.take_glwe(&layout);
470        *tmp_slot = Some(tmp);
471
472        // Get a mutable handle to the temp and normalize into it
473        let tmp_ref: &mut GLWE<&mut [u8]> = tmp_slot
474            .as_mut()
475            .expect("tmp_slot just set to Some, but found None");
476
477        self.glwe_normalize(tmp_ref, glwe, scratch2);
478
479        // Return a trait-object view of the temp
480        (tmp_ref.to_mut(), scratch2)
481    }
482
483    fn glwe_normalize<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
484    where
485        R: GLWEToMut,
486        A: GLWEToRef,
487        Scratch<BE>: ScratchTakeCore<BE>,
488    {
489        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
490        let a: &GLWE<&[u8]> = &a.to_ref();
491
492        assert_eq!(res.n(), self.n() as u32);
493        assert_eq!(a.n(), self.n() as u32);
494        assert_eq!(res.rank(), a.rank());
495
496        for i in 0..res.rank().as_usize() + 1 {
497            self.vec_znx_normalize(
498                res.base2k().into(),
499                res.data_mut(),
500                i,
501                a.base2k().into(),
502                a.data(),
503                i,
504                scratch,
505            );
506        }
507    }
508
509    fn glwe_normalize_inplace<R>(&self, res: &mut R, scratch: &mut Scratch<BE>)
510    where
511        R: GLWEToMut,
512        Scratch<BE>: ScratchTakeCore<BE>,
513    {
514        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
515        for i in 0..res.rank().as_usize() + 1 {
516            self.vec_znx_normalize_inplace(res.base2k().into(), res.data_mut(), i, scratch);
517        }
518    }
519}
520
521#[allow(dead_code)]
522// c = op(a, b)
523fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
524    // If either operands is a ciphertext
525    if a.rank() != 0 || b.rank() != 0 {
526        // If a is a plaintext (but b ciphertext)
527        let k = if a.rank() == 0 {
528            b.k()
529        // If b is a plaintext (but a ciphertext)
530        } else if b.rank() == 0 {
531            a.k()
532        // If a & b are both ciphertexts
533        } else {
534            a.k().min(b.k())
535        };
536        k.min(c.k())
537    // If a & b are both plaintexts
538    } else {
539        c.k()
540    }
541}
542
543#[allow(dead_code)]
544// a = op(a, b)
545fn set_k_unary(a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
546    if a.rank() != 0 || b.rank() != 0 {
547        a.k().min(b.k())
548    } else {
549        a.k()
550    }
551}