// poulpy_core/operations/glwe.rs

1use poulpy_hal::{
2    api::{
3        VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace,
4        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
5        VecZnxSubInplace, VecZnxSubNegateInplace,
6    },
7    layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxZero},
8};
9
10use crate::layouts::{
11    GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEInfos, GLWELayoutSet, GLWEPlaintext, LWEInfos, TorusPrecision,
12};
13
14impl<D> GLWEOperations for GLWEPlaintext<D>
15where
16    D: DataMut,
17    GLWEPlaintext<D>: GLWECiphertextToMut + GLWEInfos,
18{
19}
20
21impl<D: DataMut> GLWEOperations for GLWECiphertext<D> where GLWECiphertext<D>: GLWECiphertextToMut + GLWEInfos {}
22
23pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Sized {
24    fn add<A, B, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A, b: &B)
25    where
26        A: GLWECiphertextToRef + GLWEInfos,
27        B: GLWECiphertextToRef + GLWEInfos,
28        Module<BACKEND>: VecZnxAdd + VecZnxCopy,
29    {
30        #[cfg(debug_assertions)]
31        {
32            assert_eq!(a.n(), self.n());
33            assert_eq!(b.n(), self.n());
34            assert_eq!(a.base2k(), b.base2k());
35            assert!(self.rank() >= a.rank().max(b.rank()));
36        }
37
38        let min_col: usize = (a.rank().min(b.rank()) + 1).into();
39        let max_col: usize = (a.rank().max(b.rank() + 1)).into();
40        let self_col: usize = (self.rank() + 1).into();
41
42        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
43        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
44        let b_ref: &GLWECiphertext<&[u8]> = &b.to_ref();
45
46        (0..min_col).for_each(|i| {
47            module.vec_znx_add(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i);
48        });
49
50        if a.rank() > b.rank() {
51            (min_col..max_col).for_each(|i| {
52                module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
53            });
54        } else {
55            (min_col..max_col).for_each(|i| {
56                module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i);
57            });
58        }
59
60        let size: usize = self_mut.size();
61        (max_col..self_col).for_each(|i| {
62            (0..size).for_each(|j| {
63                self_mut.data.zero_at(i, j);
64            });
65        });
66
67        self.set_basek(a.base2k());
68        self.set_k(set_k_binary(self, a, b));
69    }
70
71    fn add_inplace<A, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A)
72    where
73        A: GLWECiphertextToRef + GLWEInfos,
74        Module<BACKEND>: VecZnxAddInplace,
75    {
76        #[cfg(debug_assertions)]
77        {
78            assert_eq!(a.n(), self.n());
79            assert_eq!(self.base2k(), a.base2k());
80            assert!(self.rank() >= a.rank())
81        }
82
83        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
84        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
85
86        (0..(a.rank() + 1).into()).for_each(|i| {
87            module.vec_znx_add_inplace(&mut self_mut.data, i, &a_ref.data, i);
88        });
89
90        self.set_k(set_k_unary(self, a))
91    }
92
93    fn sub<A, B, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A, b: &B)
94    where
95        A: GLWECiphertextToRef + GLWEInfos,
96        B: GLWECiphertextToRef + GLWEInfos,
97        Module<BACKEND>: VecZnxSub + VecZnxCopy + VecZnxNegateInplace,
98    {
99        #[cfg(debug_assertions)]
100        {
101            assert_eq!(a.n(), self.n());
102            assert_eq!(b.n(), self.n());
103            assert_eq!(a.base2k(), b.base2k());
104            assert!(self.rank() >= a.rank().max(b.rank()));
105        }
106
107        let min_col: usize = (a.rank().min(b.rank()) + 1).into();
108        let max_col: usize = (a.rank().max(b.rank() + 1)).into();
109        let self_col: usize = (self.rank() + 1).into();
110
111        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
112        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
113        let b_ref: &GLWECiphertext<&[u8]> = &b.to_ref();
114
115        (0..min_col).for_each(|i| {
116            module.vec_znx_sub(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i);
117        });
118
119        if a.rank() > b.rank() {
120            (min_col..max_col).for_each(|i| {
121                module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
122            });
123        } else {
124            (min_col..max_col).for_each(|i| {
125                module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i);
126                module.vec_znx_negate_inplace(&mut self_mut.data, i);
127            });
128        }
129
130        let size: usize = self_mut.size();
131        (max_col..self_col).for_each(|i| {
132            (0..size).for_each(|j| {
133                self_mut.data.zero_at(i, j);
134            });
135        });
136
137        self.set_basek(a.base2k());
138        self.set_k(set_k_binary(self, a, b));
139    }
140
141    fn sub_inplace_ab<A, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A)
142    where
143        A: GLWECiphertextToRef + GLWEInfos,
144        Module<BACKEND>: VecZnxSubInplace,
145    {
146        #[cfg(debug_assertions)]
147        {
148            assert_eq!(a.n(), self.n());
149            assert_eq!(self.base2k(), a.base2k());
150            assert!(self.rank() >= a.rank())
151        }
152
153        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
154        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
155
156        (0..(a.rank() + 1).into()).for_each(|i| {
157            module.vec_znx_sub_inplace(&mut self_mut.data, i, &a_ref.data, i);
158        });
159
160        self.set_k(set_k_unary(self, a))
161    }
162
163    fn sub_inplace_ba<A, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A)
164    where
165        A: GLWECiphertextToRef + GLWEInfos,
166        Module<BACKEND>: VecZnxSubNegateInplace,
167    {
168        #[cfg(debug_assertions)]
169        {
170            assert_eq!(a.n(), self.n());
171            assert_eq!(self.base2k(), a.base2k());
172            assert!(self.rank() >= a.rank())
173        }
174
175        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
176        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
177
178        (0..(a.rank() + 1).into()).for_each(|i| {
179            module.vec_znx_sub_negate_inplace(&mut self_mut.data, i, &a_ref.data, i);
180        });
181
182        self.set_k(set_k_unary(self, a))
183    }
184
185    fn rotate<A, B: Backend>(&mut self, module: &Module<B>, k: i64, a: &A)
186    where
187        A: GLWECiphertextToRef + GLWEInfos,
188        Module<B>: VecZnxRotate,
189    {
190        #[cfg(debug_assertions)]
191        {
192            assert_eq!(a.n(), self.n());
193            assert_eq!(self.rank(), a.rank())
194        }
195
196        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
197        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
198
199        (0..(a.rank() + 1).into()).for_each(|i| {
200            module.vec_znx_rotate(k, &mut self_mut.data, i, &a_ref.data, i);
201        });
202
203        self.set_basek(a.base2k());
204        self.set_k(set_k_unary(self, a))
205    }
206
207    fn rotate_inplace<B: Backend>(&mut self, module: &Module<B>, k: i64, scratch: &mut Scratch<B>)
208    where
209        Module<B>: VecZnxRotateInplace<B>,
210    {
211        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
212
213        (0..(self_mut.rank() + 1).into()).for_each(|i| {
214            module.vec_znx_rotate_inplace(k, &mut self_mut.data, i, scratch);
215        });
216    }
217
218    fn mul_xp_minus_one<A, B: Backend>(&mut self, module: &Module<B>, k: i64, a: &A)
219    where
220        A: GLWECiphertextToRef + GLWEInfos,
221        Module<B>: VecZnxMulXpMinusOne,
222    {
223        #[cfg(debug_assertions)]
224        {
225            assert_eq!(a.n(), self.n());
226            assert_eq!(self.rank(), a.rank())
227        }
228
229        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
230        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
231
232        (0..(a.rank() + 1).into()).for_each(|i| {
233            module.vec_znx_mul_xp_minus_one(k, &mut self_mut.data, i, &a_ref.data, i);
234        });
235
236        self.set_basek(a.base2k());
237        self.set_k(set_k_unary(self, a))
238    }
239
240    fn mul_xp_minus_one_inplace<B: Backend>(&mut self, module: &Module<B>, k: i64, scratch: &mut Scratch<B>)
241    where
242        Module<B>: VecZnxMulXpMinusOneInplace<B>,
243    {
244        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
245
246        (0..(self_mut.rank() + 1).into()).for_each(|i| {
247            module.vec_znx_mul_xp_minus_one_inplace(k, &mut self_mut.data, i, scratch);
248        });
249    }
250
251    fn copy<A, B: Backend>(&mut self, module: &Module<B>, a: &A)
252    where
253        A: GLWECiphertextToRef + GLWEInfos,
254        Module<B>: VecZnxCopy,
255    {
256        #[cfg(debug_assertions)]
257        {
258            assert_eq!(self.n(), a.n());
259            assert_eq!(self.rank(), a.rank());
260        }
261
262        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
263        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
264
265        (0..(self_mut.rank() + 1).into()).for_each(|i| {
266            module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
267        });
268
269        self.set_k(a.k().min(self.max_k()));
270        self.set_basek(a.base2k());
271    }
272
273    fn rsh<B: Backend>(&mut self, module: &Module<B>, k: usize, scratch: &mut Scratch<B>)
274    where
275        Module<B>: VecZnxRshInplace<B>,
276    {
277        let base2k: usize = self.base2k().into();
278        (0..(self.rank() + 1).into()).for_each(|i| {
279            module.vec_znx_rsh_inplace(base2k, k, &mut self.to_mut().data, i, scratch);
280        })
281    }
282
283    fn normalize<A, B: Backend>(&mut self, module: &Module<B>, a: &A, scratch: &mut Scratch<B>)
284    where
285        A: GLWECiphertextToRef + GLWEInfos,
286        Module<B>: VecZnxNormalize<B>,
287    {
288        #[cfg(debug_assertions)]
289        {
290            assert_eq!(self.n(), a.n());
291            assert_eq!(self.rank(), a.rank());
292        }
293
294        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
295        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
296
297        (0..(self_mut.rank() + 1).into()).for_each(|i| {
298            module.vec_znx_normalize(
299                a.base2k().into(),
300                &mut self_mut.data,
301                i,
302                a.base2k().into(),
303                &a_ref.data,
304                i,
305                scratch,
306            );
307        });
308        self.set_basek(a.base2k());
309        self.set_k(a.k().min(self.k()));
310    }
311
312    fn normalize_inplace<B: Backend>(&mut self, module: &Module<B>, scratch: &mut Scratch<B>)
313    where
314        Module<B>: VecZnxNormalizeInplace<B>,
315    {
316        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
317        (0..(self_mut.rank() + 1).into()).for_each(|i| {
318            module.vec_znx_normalize_inplace(self_mut.base2k().into(), &mut self_mut.data, i, scratch);
319        });
320    }
321}
322
323impl GLWECiphertext<Vec<u8>> {
324    pub fn rsh_scratch_space(n: usize) -> usize {
325        VecZnx::rsh_scratch_space(n)
326    }
327}
328
329// c = op(a, b)
330fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
331    // If either operands is a ciphertext
332    if a.rank() != 0 || b.rank() != 0 {
333        // If a is a plaintext (but b ciphertext)
334        let k = if a.rank() == 0 {
335            b.k()
336        // If b is a plaintext (but a ciphertext)
337        } else if b.rank() == 0 {
338            a.k()
339        // If a & b are both ciphertexts
340        } else {
341            a.k().min(b.k())
342        };
343        k.min(c.k())
344    // If a & b are both plaintexts
345    } else {
346        c.k()
347    }
348}
349
350// a = op(a, b)
351fn set_k_unary(a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
352    if a.rank() != 0 || b.rank() != 0 {
353        a.k().min(b.k())
354    } else {
355        a.k()
356    }
357}