use poulpy_hal::{
    api::{
        VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace,
        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
        VecZnxSubInplace, VecZnxSubNegateInplace,
    },
    layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxZero},
};

use crate::layouts::{
    GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEInfos, GLWELayoutSet, GLWEPlaintext, LWEInfos, TorusPrecision,
};

impl<D> GLWEOperations for GLWEPlaintext<D>
where
    D: DataMut,
    GLWEPlaintext<D>: GLWECiphertextToMut + GLWEInfos,
{
}

impl<D: DataMut> GLWEOperations for GLWECiphertext<D> where GLWECiphertext<D>: GLWECiphertextToMut + GLWEInfos {}

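/// Arithmetic and structural operations shared by GLWE ciphertexts and plaintexts.
///
/// Each method works column by column on the underlying `VecZnx` data, dispatching to
/// the corresponding `poulpy_hal` module operation, and updates the `base2k`/`k`
/// metadata of the receiver where applicable.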
pub trait GLWEOperations: GLWECiphertextToMut + GLWEInfos + GLWELayoutSet + Sized {
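    /// Sets `self` to the column-wise sum of `a` and `b`. Columns present in only one
    /// operand are copied as-is; any remaining columns of `self` are zeroed.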
    fn add<A, B, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A, b: &B)
    where
        A: GLWECiphertextToRef + GLWEInfos,
        B: GLWECiphertextToRef + GLWEInfos,
        Module<BACKEND>: VecZnxAdd + VecZnxCopy,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), self.n());
            assert_eq!(b.n(), self.n());
            assert_eq!(a.base2k(), b.base2k());
            assert!(self.rank() >= a.rank().max(b.rank()));
        }

        let min_col: usize = (a.rank().min(b.rank()) + 1).into();
        let max_col: usize = (a.rank().max(b.rank()) + 1).into();
        let self_col: usize = (self.rank() + 1).into();

        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
        let b_ref: &GLWECiphertext<&[u8]> = &b.to_ref();

        // Columns present in both operands: element-wise addition.
        (0..min_col).for_each(|i| {
            module.vec_znx_add(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i);
        });

        // Columns present in only one operand: copy them from the higher-rank side.
        if a.rank() > b.rank() {
            (min_col..max_col).for_each(|i| {
                module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
            });
        } else {
            (min_col..max_col).for_each(|i| {
                module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i);
            });
        }

        // Any columns of `self` beyond both operands are zeroed.
        let size: usize = self_mut.size();
        (max_col..self_col).for_each(|i| {
            (0..size).for_each(|j| {
                self_mut.data.zero_at(i, j);
            });
        });

        self.set_basek(a.base2k());
        self.set_k(set_k_binary(self, a, b));
    }

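    /// Adds `a` to `self` in place over the columns of `a`.
    /// Requires `self.rank() >= a.rank()` and matching `base2k`.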
    fn add_inplace<A, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A)
    where
        A: GLWECiphertextToRef + GLWEInfos,
        Module<BACKEND>: VecZnxAddInplace,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), self.n());
            assert_eq!(self.base2k(), a.base2k());
            assert!(self.rank() >= a.rank());
        }

        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();

        (0..(a.rank() + 1).into()).for_each(|i| {
            module.vec_znx_add_inplace(&mut self_mut.data, i, &a_ref.data, i);
        });

        self.set_k(set_k_unary(self, a));
    }

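    /// Sets `self` to the column-wise difference `a - b`. Columns present in only one
    /// operand are copied (negated when they come from `b`); remaining columns are zeroed.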
    fn sub<A, B, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A, b: &B)
    where
        A: GLWECiphertextToRef + GLWEInfos,
        B: GLWECiphertextToRef + GLWEInfos,
        Module<BACKEND>: VecZnxSub + VecZnxCopy + VecZnxNegateInplace,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), self.n());
            assert_eq!(b.n(), self.n());
            assert_eq!(a.base2k(), b.base2k());
            assert!(self.rank() >= a.rank().max(b.rank()));
        }

        let min_col: usize = (a.rank().min(b.rank()) + 1).into();
        let max_col: usize = (a.rank().max(b.rank()) + 1).into();
        let self_col: usize = (self.rank() + 1).into();

        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
        let b_ref: &GLWECiphertext<&[u8]> = &b.to_ref();

        // Columns present in both operands: element-wise subtraction.
        (0..min_col).for_each(|i| {
            module.vec_znx_sub(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i);
        });

        // Columns present in only one operand: copy from `a`, or copy and negate from `b`.
        if a.rank() > b.rank() {
            (min_col..max_col).for_each(|i| {
                module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
            });
        } else {
            (min_col..max_col).for_each(|i| {
                module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i);
                module.vec_znx_negate_inplace(&mut self_mut.data, i);
            });
        }

        // Any columns of `self` beyond both operands are zeroed.
        let size: usize = self_mut.size();
        (max_col..self_col).for_each(|i| {
            (0..size).for_each(|j| {
                self_mut.data.zero_at(i, j);
            });
        });

        self.set_basek(a.base2k());
        self.set_k(set_k_binary(self, a, b));
    }

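    /// In-place subtraction `self = self - a`, applied over the columns of `a`.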
    fn sub_inplace_ab<A, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A)
    where
        A: GLWECiphertextToRef + GLWEInfos,
        Module<BACKEND>: VecZnxSubInplace,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), self.n());
            assert_eq!(self.base2k(), a.base2k());
            assert!(self.rank() >= a.rank());
        }

        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();

        (0..(a.rank() + 1).into()).for_each(|i| {
            module.vec_znx_sub_inplace(&mut self_mut.data, i, &a_ref.data, i);
        });

        self.set_k(set_k_unary(self, a));
    }

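    /// In-place subtraction with operands swapped, i.e. `self = a - self`,
    /// applied over the columns of `a`.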
    fn sub_inplace_ba<A, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A)
    where
        A: GLWECiphertextToRef + GLWEInfos,
        Module<BACKEND>: VecZnxSubNegateInplace,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), self.n());
            assert_eq!(self.base2k(), a.base2k());
            assert!(self.rank() >= a.rank());
        }

        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();

        (0..(a.rank() + 1).into()).for_each(|i| {
            module.vec_znx_sub_negate_inplace(&mut self_mut.data, i, &a_ref.data, i);
        });

        self.set_k(set_k_unary(self, a));
    }

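    /// Sets `self` to `a` rotated by `X^k` (negacyclic rotation of each column).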
    fn rotate<A, B: Backend>(&mut self, module: &Module<B>, k: i64, a: &A)
    where
        A: GLWECiphertextToRef + GLWEInfos,
        Module<B>: VecZnxRotate,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), self.n());
            assert_eq!(self.rank(), a.rank());
        }

        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();

        (0..(a.rank() + 1).into()).for_each(|i| {
            module.vec_znx_rotate(k, &mut self_mut.data, i, &a_ref.data, i);
        });

        self.set_basek(a.base2k());
        self.set_k(set_k_unary(self, a));
    }

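    /// Rotates `self` by `X^k` in place. The `base2k`/`k` metadata is left unchanged.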
    fn rotate_inplace<B: Backend>(&mut self, module: &Module<B>, k: i64, scratch: &mut Scratch<B>)
    where
        Module<B>: VecZnxRotateInplace<B>,
    {
        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();

        (0..(self_mut.rank() + 1).into()).for_each(|i| {
            module.vec_znx_rotate_inplace(k, &mut self_mut.data, i, scratch);
        });
    }

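    /// Sets `self` to `a * (X^k - 1)`, column by column.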
    fn mul_xp_minus_one<A, B: Backend>(&mut self, module: &Module<B>, k: i64, a: &A)
    where
        A: GLWECiphertextToRef + GLWEInfos,
        Module<B>: VecZnxMulXpMinusOne,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(a.n(), self.n());
            assert_eq!(self.rank(), a.rank());
        }

        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();

        (0..(a.rank() + 1).into()).for_each(|i| {
            module.vec_znx_mul_xp_minus_one(k, &mut self_mut.data, i, &a_ref.data, i);
        });

        self.set_basek(a.base2k());
        self.set_k(set_k_unary(self, a));
    }

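    /// Multiplies `self` by `(X^k - 1)` in place.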
    fn mul_xp_minus_one_inplace<B: Backend>(&mut self, module: &Module<B>, k: i64, scratch: &mut Scratch<B>)
    where
        Module<B>: VecZnxMulXpMinusOneInplace<B>,
    {
        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();

        (0..(self_mut.rank() + 1).into()).for_each(|i| {
            module.vec_znx_mul_xp_minus_one_inplace(k, &mut self_mut.data, i, scratch);
        });
    }

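    /// Copies `a` into `self`, adopting `a`'s `base2k` and clamping `k` to `self.max_k()`.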
    fn copy<A, B: Backend>(&mut self, module: &Module<B>, a: &A)
    where
        A: GLWECiphertextToRef + GLWEInfos,
        Module<B>: VecZnxCopy,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(self.n(), a.n());
            assert_eq!(self.rank(), a.rank());
        }

        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();

        (0..(self_mut.rank() + 1).into()).for_each(|i| {
            module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
        });

        self.set_k(a.k().min(self.max_k()));
        self.set_basek(a.base2k());
    }

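    /// Right-shifts every column of `self` by `k` bits in place.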
    fn rsh<B: Backend>(&mut self, module: &Module<B>, k: usize, scratch: &mut Scratch<B>)
    where
        Module<B>: VecZnxRshInplace<B>,
    {
        let base2k: usize = self.base2k().into();
        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
        (0..(self_mut.rank() + 1).into()).for_each(|i| {
            module.vec_znx_rsh_inplace(base2k, k, &mut self_mut.data, i, scratch);
        });
    }

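    /// Writes the normalized form of `a` into `self` (carry propagation across the
    /// base-2^k limbs), adopting `a`'s `base2k`.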
    fn normalize<A, B: Backend>(&mut self, module: &Module<B>, a: &A, scratch: &mut Scratch<B>)
    where
        A: GLWECiphertextToRef + GLWEInfos,
        Module<B>: VecZnxNormalize<B>,
    {
        #[cfg(debug_assertions)]
        {
            assert_eq!(self.n(), a.n());
            assert_eq!(self.rank(), a.rank());
        }

        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
        let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();

        (0..(self_mut.rank() + 1).into()).for_each(|i| {
            module.vec_znx_normalize(
                a.base2k().into(),
                &mut self_mut.data,
                i,
                a.base2k().into(),
                &a_ref.data,
                i,
                scratch,
            );
        });

        self.set_basek(a.base2k());
        self.set_k(a.k().min(self.k()));
    }

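    /// Normalizes `self` in place under its current `base2k`.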
    fn normalize_inplace<B: Backend>(&mut self, module: &Module<B>, scratch: &mut Scratch<B>)
    where
        Module<B>: VecZnxNormalizeInplace<B>,
    {
        let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
        (0..(self_mut.rank() + 1).into()).for_each(|i| {
            module.vec_znx_normalize_inplace(self_mut.base2k().into(), &mut self_mut.data, i, scratch);
        });
    }
}

impl GLWECiphertext<Vec<u8>> {
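    /// Returns the scratch space required by [`GLWEOperations::rsh`] for ring degree `n`.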
    pub fn rsh_scratch_space(n: usize) -> usize {
        VecZnx::rsh_scratch_space(n)
    }
}

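/// Resulting torus precision of a binary operation `c = op(a, b)`: the minimum `k`
/// among the operands of non-zero rank, clamped to `c.k()`. If both operands have
/// rank 0, the precision of `c` is kept unchanged.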
fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
    if a.rank() != 0 || b.rank() != 0 {
        let k = if a.rank() == 0 {
            b.k()
        } else if b.rank() == 0 {
            a.k()
        } else {
            a.k().min(b.k())
        };
        k.min(c.k())
    } else {
        c.k()
    }
}

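/// Resulting torus precision of a unary operation on `a` with operand `b`: the minimum
/// of the two precisions, unless both have rank 0, in which case `a.k()` is kept.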
fn set_k_unary(a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
    if a.rank() != 0 || b.rank() != 0 {
        a.k().min(b.k())
    } else {
        a.k()
    }
}