// poulpy_core/operations/glwe.rs

1use poulpy_hal::{
2    api::{
3        VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace,
4        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
5        VecZnxSubABInplace, VecZnxSubBAInplace,
6    },
7    layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxZero},
8};
9
10use crate::layouts::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEPlaintext, Infos, SetMetaData};
11
12impl<D> GLWEOperations for GLWEPlaintext<D>
13where
14    D: DataMut,
15    GLWEPlaintext<D>: GLWECiphertextToMut + Infos + SetMetaData,
16{
17}
18
19impl<D: DataMut> GLWEOperations for GLWECiphertext<D> where GLWECiphertext<D>: GLWECiphertextToMut + Infos + SetMetaData {}
20
21pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
22    fn add<A, B, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A, b: &B)
23    where
24        A: GLWECiphertextToRef,
25        B: GLWECiphertextToRef,
26        Module<BACKEND>: VecZnxAdd + VecZnxCopy,
27    {
28        #[cfg(debug_assertions)]
29        {
30            assert_eq!(a.n(), self.n());
31            assert_eq!(b.n(), self.n());
32            assert_eq!(a.basek(), b.basek());
33            assert!(self.rank() >= a.rank().max(b.rank()));
34        }
35
36        let min_col: usize = a.rank().min(b.rank()) + 1;
37        let max_col: usize = a.rank().max(b.rank() + 1);
38        let self_col: usize = self.rank() + 1;
39
40        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
41        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
42        let b_ref: &GLWECiphertext<&[u8]> = &b.to_ref();
43
44        (0..min_col).for_each(|i| {
45            module.vec_znx_add(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i);
46        });
47
48        if a.rank() > b.rank() {
49            (min_col..max_col).for_each(|i| {
50                module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
51            });
52        } else {
53            (min_col..max_col).for_each(|i| {
54                module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i);
55            });
56        }
57
58        let size: usize = self_mut.size();
59        (max_col..self_col).for_each(|i| {
60            (0..size).for_each(|j| {
61                self_mut.data.zero_at(i, j);
62            });
63        });
64
65        self.set_basek(a.basek());
66        self.set_k(set_k_binary(self, a, b));
67    }
68
69    fn add_inplace<A, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A)
70    where
71        A: GLWECiphertextToRef + Infos,
72        Module<BACKEND>: VecZnxAddInplace,
73    {
74        #[cfg(debug_assertions)]
75        {
76            assert_eq!(a.n(), self.n());
77            assert_eq!(self.basek(), a.basek());
78            assert!(self.rank() >= a.rank())
79        }
80
81        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
82        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
83
84        (0..a.rank() + 1).for_each(|i| {
85            module.vec_znx_add_inplace(&mut self_mut.data, i, &a_ref.data, i);
86        });
87
88        self.set_k(set_k_unary(self, a))
89    }
90
91    fn sub<A, B, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A, b: &B)
92    where
93        A: GLWECiphertextToRef,
94        B: GLWECiphertextToRef,
95        Module<BACKEND>: VecZnxSub + VecZnxCopy + VecZnxNegateInplace,
96    {
97        #[cfg(debug_assertions)]
98        {
99            assert_eq!(a.n(), self.n());
100            assert_eq!(b.n(), self.n());
101            assert_eq!(a.basek(), b.basek());
102            assert!(self.rank() >= a.rank().max(b.rank()));
103        }
104
105        let min_col: usize = a.rank().min(b.rank()) + 1;
106        let max_col: usize = a.rank().max(b.rank() + 1);
107        let self_col: usize = self.rank() + 1;
108
109        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
110        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
111        let b_ref: &GLWECiphertext<&[u8]> = &b.to_ref();
112
113        (0..min_col).for_each(|i| {
114            module.vec_znx_sub(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i);
115        });
116
117        if a.rank() > b.rank() {
118            (min_col..max_col).for_each(|i| {
119                module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
120            });
121        } else {
122            (min_col..max_col).for_each(|i| {
123                module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i);
124                module.vec_znx_negate_inplace(&mut self_mut.data, i);
125            });
126        }
127
128        let size: usize = self_mut.size();
129        (max_col..self_col).for_each(|i| {
130            (0..size).for_each(|j| {
131                self_mut.data.zero_at(i, j);
132            });
133        });
134
135        self.set_basek(a.basek());
136        self.set_k(set_k_binary(self, a, b));
137    }
138
139    fn sub_inplace_ab<A, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A)
140    where
141        A: GLWECiphertextToRef + Infos,
142        Module<BACKEND>: VecZnxSubABInplace,
143    {
144        #[cfg(debug_assertions)]
145        {
146            assert_eq!(a.n(), self.n());
147            assert_eq!(self.basek(), a.basek());
148            assert!(self.rank() >= a.rank())
149        }
150
151        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
152        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
153
154        (0..a.rank() + 1).for_each(|i| {
155            module.vec_znx_sub_ab_inplace(&mut self_mut.data, i, &a_ref.data, i);
156        });
157
158        self.set_k(set_k_unary(self, a))
159    }
160
161    fn sub_inplace_ba<A, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A)
162    where
163        A: GLWECiphertextToRef + Infos,
164        Module<BACKEND>: VecZnxSubBAInplace,
165    {
166        #[cfg(debug_assertions)]
167        {
168            assert_eq!(a.n(), self.n());
169            assert_eq!(self.basek(), a.basek());
170            assert!(self.rank() >= a.rank())
171        }
172
173        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
174        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
175
176        (0..a.rank() + 1).for_each(|i| {
177            module.vec_znx_sub_ba_inplace(&mut self_mut.data, i, &a_ref.data, i);
178        });
179
180        self.set_k(set_k_unary(self, a))
181    }
182
183    fn rotate<A, B: Backend>(&mut self, module: &Module<B>, k: i64, a: &A)
184    where
185        A: GLWECiphertextToRef + Infos,
186        Module<B>: VecZnxRotate,
187    {
188        #[cfg(debug_assertions)]
189        {
190            assert_eq!(a.n(), self.n());
191            assert_eq!(self.rank(), a.rank())
192        }
193
194        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
195        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
196
197        (0..a.rank() + 1).for_each(|i| {
198            module.vec_znx_rotate(k, &mut self_mut.data, i, &a_ref.data, i);
199        });
200
201        self.set_basek(a.basek());
202        self.set_k(set_k_unary(self, a))
203    }
204
205    fn rotate_inplace<B: Backend>(&mut self, module: &Module<B>, k: i64, scratch: &mut Scratch<B>)
206    where
207        Module<B>: VecZnxRotateInplace<B>,
208    {
209        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
210
211        (0..self_mut.rank() + 1).for_each(|i| {
212            module.vec_znx_rotate_inplace(k, &mut self_mut.data, i, scratch);
213        });
214    }
215
216    fn mul_xp_minus_one<A, B: Backend>(&mut self, module: &Module<B>, k: i64, a: &A)
217    where
218        A: GLWECiphertextToRef + Infos,
219        Module<B>: VecZnxMulXpMinusOne,
220    {
221        #[cfg(debug_assertions)]
222        {
223            assert_eq!(a.n(), self.n());
224            assert_eq!(self.rank(), a.rank())
225        }
226
227        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
228        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
229
230        (0..a.rank() + 1).for_each(|i| {
231            module.vec_znx_mul_xp_minus_one(k, &mut self_mut.data, i, &a_ref.data, i);
232        });
233
234        self.set_basek(a.basek());
235        self.set_k(set_k_unary(self, a))
236    }
237
238    fn mul_xp_minus_one_inplace<B: Backend>(&mut self, module: &Module<B>, k: i64, scratch: &mut Scratch<B>)
239    where
240        Module<B>: VecZnxMulXpMinusOneInplace<B>,
241    {
242        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
243
244        (0..self_mut.rank() + 1).for_each(|i| {
245            module.vec_znx_mul_xp_minus_one_inplace(k, &mut self_mut.data, i, scratch);
246        });
247    }
248
249    fn copy<A, B: Backend>(&mut self, module: &Module<B>, a: &A)
250    where
251        A: GLWECiphertextToRef + Infos,
252        Module<B>: VecZnxCopy,
253    {
254        #[cfg(debug_assertions)]
255        {
256            assert_eq!(self.n(), a.n());
257            assert_eq!(self.rank(), a.rank());
258        }
259
260        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
261        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
262
263        (0..self_mut.rank() + 1).for_each(|i| {
264            module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
265        });
266
267        self.set_k(a.k().min(self.size() * self.basek()));
268        self.set_basek(a.basek());
269    }
270
271    fn rsh<B: Backend>(&mut self, module: &Module<B>, k: usize, scratch: &mut Scratch<B>)
272    where
273        Module<B>: VecZnxRshInplace<B>,
274    {
275        let basek: usize = self.basek();
276        (0..self.cols()).for_each(|i| {
277            module.vec_znx_rsh_inplace(basek, k, &mut self.to_mut().data, i, scratch);
278        })
279    }
280
281    fn normalize<A, B: Backend>(&mut self, module: &Module<B>, a: &A, scratch: &mut Scratch<B>)
282    where
283        A: GLWECiphertextToRef,
284        Module<B>: VecZnxNormalize<B>,
285    {
286        #[cfg(debug_assertions)]
287        {
288            assert_eq!(self.n(), a.n());
289            assert_eq!(self.rank(), a.rank());
290        }
291
292        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
293        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
294
295        (0..self_mut.rank() + 1).for_each(|i| {
296            module.vec_znx_normalize(a.basek(), &mut self_mut.data, i, &a_ref.data, i, scratch);
297        });
298        self.set_basek(a.basek());
299        self.set_k(a.k().min(self.k()));
300    }
301
302    fn normalize_inplace<B: Backend>(&mut self, module: &Module<B>, scratch: &mut Scratch<B>)
303    where
304        Module<B>: VecZnxNormalizeInplace<B>,
305    {
306        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
307        (0..self_mut.rank() + 1).for_each(|i| {
308            module.vec_znx_normalize_inplace(self_mut.basek(), &mut self_mut.data, i, scratch);
309        });
310    }
311}
312
impl GLWECiphertext<Vec<u8>> {
    /// Scratch-space size required by `GLWEOperations::rsh` for ring
    /// dimension `n`. Delegates to `VecZnx::rsh_scratch_space`
    /// (presumably a size in bytes — confirm against its definition).
    pub fn rsh_scratch_space(n: usize) -> usize {
        VecZnx::rsh_scratch_space(n)
    }
}
318
319// c = op(a, b)
320fn set_k_binary(c: &impl Infos, a: &impl Infos, b: &impl Infos) -> usize {
321    // If either operands is a ciphertext
322    if a.rank() != 0 || b.rank() != 0 {
323        // If a is a plaintext (but b ciphertext)
324        let k = if a.rank() == 0 {
325            b.k()
326        // If b is a plaintext (but a ciphertext)
327        } else if b.rank() == 0 {
328            a.k()
329        // If a & b are both ciphertexts
330        } else {
331            a.k().min(b.k())
332        };
333        k.min(c.k())
334    // If a & b are both plaintexts
335    } else {
336        c.k()
337    }
338}
339
340// a = op(a, b)
341fn set_k_unary(a: &impl Infos, b: &impl Infos) -> usize {
342    if a.rank() != 0 || b.rank() != 0 {
343        a.k().min(b.k())
344    } else {
345        a.k()
346    }
347}