poulpy_core/operations/glwe.rs

1use poulpy_hal::{
2    api::{
3        ModuleN, VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace,
4        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
5        VecZnxSubInplace, VecZnxSubNegateInplace,
6    },
7    layouts::{Backend, Module, Scratch, VecZnx, ZnxZero},
8};
9
10use crate::{
11    ScratchTakeCore,
12    layouts::{GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision},
13};
14
15pub trait GLWEAdd
16where
17    Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace,
18{
19    fn glwe_add<R, A, B>(&self, res: &mut R, a: &A, b: &B)
20    where
21        R: GLWEToMut,
22        A: GLWEToRef,
23        B: GLWEToRef,
24    {
25        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
26        let a: &mut GLWE<&[u8]> = &mut a.to_ref();
27        let b: &GLWE<&[u8]> = &b.to_ref();
28
29        assert_eq!(a.n(), self.n() as u32);
30        assert_eq!(b.n(), self.n() as u32);
31        assert_eq!(res.n(), self.n() as u32);
32        assert_eq!(a.base2k(), b.base2k());
33        assert!(res.rank() >= a.rank().max(b.rank()));
34
35        let min_col: usize = (a.rank().min(b.rank()) + 1).into();
36        let max_col: usize = (a.rank().max(b.rank() + 1)).into();
37        let self_col: usize = (res.rank() + 1).into();
38
39        (0..min_col).for_each(|i| {
40            self.vec_znx_add(res.data_mut(), i, a.data(), i, b.data(), i);
41        });
42
43        if a.rank() > b.rank() {
44            (min_col..max_col).for_each(|i| {
45                self.vec_znx_copy(res.data_mut(), i, a.data(), i);
46            });
47        } else {
48            (min_col..max_col).for_each(|i| {
49                self.vec_znx_copy(res.data_mut(), i, b.data(), i);
50            });
51        }
52
53        let size: usize = res.size();
54        (max_col..self_col).for_each(|i| {
55            (0..size).for_each(|j| {
56                res.data.zero_at(i, j);
57            });
58        });
59
60        res.set_base2k(a.base2k());
61        res.set_k(set_k_binary(res, a, b));
62    }
63
64    fn glwe_add_inplace<R, A>(&self, res: &mut R, a: &A)
65    where
66        R: GLWEToMut,
67        A: GLWEToRef,
68    {
69        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
70        let a: &GLWE<&[u8]> = &a.to_ref();
71
72        assert_eq!(res.n(), self.n() as u32);
73        assert_eq!(a.n(), self.n() as u32);
74        assert_eq!(res.base2k(), a.base2k());
75        assert!(res.rank() >= a.rank());
76
77        (0..(a.rank() + 1).into()).for_each(|i| {
78            self.vec_znx_add_inplace(res.data_mut(), i, a.data(), i);
79        });
80
81        res.set_k(set_k_unary(res, a))
82    }
83}
84
85impl<BE: Backend> GLWEAdd for Module<BE> where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace {}
86
87impl<BE: Backend> GLWESub for Module<BE> where
88    Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace
89{
90}
91
92pub trait GLWESub
93where
94    Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace,
95{
96    fn glwe_sub<R, A, B>(&self, res: &mut R, a: &A, b: &B)
97    where
98        R: GLWEToMut,
99        A: GLWEToRef,
100        B: GLWEToRef,
101    {
102        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
103        let a: &GLWE<&[u8]> = &a.to_ref();
104        let b: &GLWE<&[u8]> = &b.to_ref();
105
106        assert_eq!(a.n(), self.n() as u32);
107        assert_eq!(b.n(), self.n() as u32);
108        assert_eq!(a.base2k(), b.base2k());
109        assert!(res.rank() >= a.rank().max(b.rank()));
110
111        let min_col: usize = (a.rank().min(b.rank()) + 1).into();
112        let max_col: usize = (a.rank().max(b.rank() + 1)).into();
113        let self_col: usize = (res.rank() + 1).into();
114
115        (0..min_col).for_each(|i| {
116            self.vec_znx_sub(res.data_mut(), i, a.data(), i, b.data(), i);
117        });
118
119        if a.rank() > b.rank() {
120            (min_col..max_col).for_each(|i| {
121                self.vec_znx_copy(res.data_mut(), i, a.data(), i);
122            });
123        } else {
124            (min_col..max_col).for_each(|i| {
125                self.vec_znx_copy(res.data_mut(), i, b.data(), i);
126                self.vec_znx_negate_inplace(res.data_mut(), i);
127            });
128        }
129
130        let size: usize = res.size();
131        (max_col..self_col).for_each(|i| {
132            (0..size).for_each(|j| {
133                res.data.zero_at(i, j);
134            });
135        });
136
137        res.set_base2k(a.base2k());
138        res.set_k(set_k_binary(res, a, b));
139    }
140
141    fn glwe_sub_inplace<R, A>(&self, res: &mut R, a: &A)
142    where
143        R: GLWEToMut,
144        A: GLWEToRef,
145    {
146        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
147        let a: &GLWE<&[u8]> = &a.to_ref();
148
149        assert_eq!(res.n(), self.n() as u32);
150        assert_eq!(a.n(), self.n() as u32);
151        assert_eq!(res.base2k(), a.base2k());
152        assert!(res.rank() >= a.rank());
153
154        (0..(a.rank() + 1).into()).for_each(|i| {
155            self.vec_znx_sub_inplace(res.data_mut(), i, a.data(), i);
156        });
157
158        res.set_k(set_k_unary(res, a))
159    }
160
161    fn glwe_sub_negate_inplace<R, A>(&self, res: &mut R, a: &A)
162    where
163        R: GLWEToMut,
164        A: GLWEToRef,
165    {
166        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
167        let a: &GLWE<&[u8]> = &a.to_ref();
168
169        assert_eq!(res.n(), self.n() as u32);
170        assert_eq!(a.n(), self.n() as u32);
171        assert_eq!(res.base2k(), a.base2k());
172        assert!(res.rank() >= a.rank());
173
174        (0..(a.rank() + 1).into()).for_each(|i| {
175            self.vec_znx_sub_negate_inplace(res.data_mut(), i, a.data(), i);
176        });
177
178        res.set_k(set_k_unary(res, a))
179    }
180}
181
182impl<BE: Backend> GLWERotate<BE> for Module<BE> where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> {}
183
184pub trait GLWERotate<BE: Backend>
185where
186    Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE>,
187{
188    fn glwe_rotate<R, A>(&self, k: i64, res: &mut R, a: &A)
189    where
190        R: GLWEToMut,
191        A: GLWEToRef,
192    {
193        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
194        let a: &GLWE<&[u8]> = &a.to_ref();
195
196        assert_eq!(a.n(), self.n() as u32);
197        assert_eq!(res.rank(), a.rank());
198
199        (0..(a.rank() + 1).into()).for_each(|i| {
200            self.vec_znx_rotate(k, res.data_mut(), i, a.data(), i);
201        });
202
203        res.set_base2k(a.base2k());
204        res.set_k(set_k_unary(res, a))
205    }
206
207    fn glwe_rotate_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
208    where
209        R: GLWEToMut,
210        Scratch<BE>: ScratchTakeCore<BE>,
211    {
212        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
213
214        (0..(res.rank() + 1).into()).for_each(|i| {
215            self.vec_znx_rotate_inplace(k, res.data_mut(), i, scratch);
216        });
217    }
218}
219
220impl<BE: Backend> GLWEMulXpMinusOne<BE> for Module<BE> where Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace<BE> {}
221
222pub trait GLWEMulXpMinusOne<BE: Backend>
223where
224    Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace<BE>,
225{
226    fn glwe_mul_xp_minus_one<R, A>(&self, k: i64, res: &mut R, a: &A)
227    where
228        R: GLWEToMut,
229        A: GLWEToRef,
230    {
231        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
232        let a: &GLWE<&[u8]> = &a.to_ref();
233
234        assert_eq!(res.n(), self.n() as u32);
235        assert_eq!(a.n(), self.n() as u32);
236        assert_eq!(res.rank(), a.rank());
237
238        for i in 0..res.rank().as_usize() + 1 {
239            self.vec_znx_mul_xp_minus_one(k, res.data_mut(), i, a.data(), i);
240        }
241
242        res.set_base2k(a.base2k());
243        res.set_k(set_k_unary(res, a))
244    }
245
246    fn glwe_mul_xp_minus_one_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
247    where
248        R: GLWEToMut,
249    {
250        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
251
252        assert_eq!(res.n(), self.n() as u32);
253
254        for i in 0..res.rank().as_usize() + 1 {
255            self.vec_znx_mul_xp_minus_one_inplace(k, res.data_mut(), i, scratch);
256        }
257    }
258}
259
260impl<BE: Backend> GLWECopy for Module<BE> where Self: ModuleN + VecZnxCopy {}
261
262pub trait GLWECopy
263where
264    Self: ModuleN + VecZnxCopy,
265{
266    fn glwe_copy<R, A>(&self, res: &mut R, a: &A)
267    where
268        R: GLWEToMut,
269        A: GLWEToRef,
270    {
271        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
272        let a: &GLWE<&[u8]> = &a.to_ref();
273
274        assert_eq!(res.n(), self.n() as u32);
275        assert_eq!(a.n(), self.n() as u32);
276        assert_eq!(res.rank(), a.rank());
277
278        for i in 0..res.rank().as_usize() + 1 {
279            self.vec_znx_copy(res.data_mut(), i, a.data(), i);
280        }
281
282        res.set_k(a.k().min(res.max_k()));
283        res.set_base2k(a.base2k());
284    }
285}
286
287impl<BE: Backend> GLWEShift<BE> for Module<BE> where Self: ModuleN + VecZnxRshInplace<BE> {}
288
289pub trait GLWEShift<BE: Backend>
290where
291    Self: ModuleN + VecZnxRshInplace<BE>,
292{
293    fn glwe_rsh_tmp_byte(&self) -> usize {
294        VecZnx::rsh_tmp_bytes(self.n())
295    }
296
297    fn glwe_rsh<R>(&self, k: usize, res: &mut R, scratch: &mut Scratch<BE>)
298    where
299        R: GLWEToMut,
300        Scratch<BE>: ScratchTakeCore<BE>,
301    {
302        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
303        let base2k: usize = res.base2k().into();
304        for i in 0..res.rank().as_usize() + 1 {
305            self.vec_znx_rsh_inplace(base2k, k, res.data_mut(), i, scratch);
306        }
307    }
308}
309
310impl GLWE<Vec<u8>> {
311    pub fn rsh_tmp_bytes<M, BE: Backend>(module: &M) -> usize
312    where
313        M: GLWEShift<BE>,
314    {
315        module.glwe_rsh_tmp_byte()
316    }
317}
318
319impl<BE: Backend> GLWENormalize<BE> for Module<BE> where Self: ModuleN + VecZnxNormalize<BE> + VecZnxNormalizeInplace<BE> {}
320
321pub trait GLWENormalize<BE: Backend>
322where
323    Self: ModuleN + VecZnxNormalize<BE> + VecZnxNormalizeInplace<BE>,
324{
325    fn glwe_normalize<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
326    where
327        R: GLWEToMut,
328        A: GLWEToRef,
329        Scratch<BE>: ScratchTakeCore<BE>,
330    {
331        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
332        let a: &GLWE<&[u8]> = &a.to_ref();
333
334        assert_eq!(res.n(), self.n() as u32);
335        assert_eq!(a.n(), self.n() as u32);
336        assert_eq!(res.rank(), a.rank());
337
338        for i in 0..res.rank().as_usize() + 1 {
339            self.vec_znx_normalize(
340                res.base2k().into(),
341                res.data_mut(),
342                i,
343                a.base2k().into(),
344                a.data(),
345                i,
346                scratch,
347            );
348        }
349
350        res.set_k(a.k().min(res.k()));
351    }
352
353    fn glwe_normalize_inplace<R>(&self, res: &mut R, scratch: &mut Scratch<BE>)
354    where
355        R: GLWEToMut,
356        Scratch<BE>: ScratchTakeCore<BE>,
357    {
358        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
359        for i in 0..res.rank().as_usize() + 1 {
360            self.vec_znx_normalize_inplace(res.base2k().into(), res.data_mut(), i, scratch);
361        }
362    }
363}
364
365// c = op(a, b)
366fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
367    // If either operands is a ciphertext
368    if a.rank() != 0 || b.rank() != 0 {
369        // If a is a plaintext (but b ciphertext)
370        let k = if a.rank() == 0 {
371            b.k()
372        // If b is a plaintext (but a ciphertext)
373        } else if b.rank() == 0 {
374            a.k()
375        // If a & b are both ciphertexts
376        } else {
377            a.k().min(b.k())
378        };
379        k.min(c.k())
380    // If a & b are both plaintexts
381    } else {
382        c.k()
383    }
384}
385
386// a = op(a, b)
387fn set_k_unary(a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
388    if a.rank() != 0 || b.rank() != 0 {
389        a.k().min(b.k())
390    } else {
391        a.k()
392    }
393}