1use poulpy_hal::{
2 api::{
3 VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace,
4 VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
5 VecZnxSubABInplace, VecZnxSubBAInplace,
6 },
7 layouts::{Backend, DataMut, Module, Scratch, VecZnx, ZnxZero},
8};
9
10use crate::layouts::{GLWECiphertext, GLWECiphertextToMut, GLWECiphertextToRef, GLWEPlaintext, Infos, SetMetaData};
11
12impl<D> GLWEOperations for GLWEPlaintext<D>
13where
14 D: DataMut,
15 GLWEPlaintext<D>: GLWECiphertextToMut + Infos + SetMetaData,
16{
17}
18
19impl<D: DataMut> GLWEOperations for GLWECiphertext<D> where GLWECiphertext<D>: GLWECiphertextToMut + Infos + SetMetaData {}
20
21pub trait GLWEOperations: GLWECiphertextToMut + SetMetaData + Sized {
22 fn add<A, B, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A, b: &B)
23 where
24 A: GLWECiphertextToRef,
25 B: GLWECiphertextToRef,
26 Module<BACKEND>: VecZnxAdd + VecZnxCopy,
27 {
28 #[cfg(debug_assertions)]
29 {
30 assert_eq!(a.n(), self.n());
31 assert_eq!(b.n(), self.n());
32 assert_eq!(a.basek(), b.basek());
33 assert!(self.rank() >= a.rank().max(b.rank()));
34 }
35
36 let min_col: usize = a.rank().min(b.rank()) + 1;
37 let max_col: usize = a.rank().max(b.rank() + 1);
38 let self_col: usize = self.rank() + 1;
39
40 let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
41 let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
42 let b_ref: &GLWECiphertext<&[u8]> = &b.to_ref();
43
44 (0..min_col).for_each(|i| {
45 module.vec_znx_add(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i);
46 });
47
48 if a.rank() > b.rank() {
49 (min_col..max_col).for_each(|i| {
50 module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
51 });
52 } else {
53 (min_col..max_col).for_each(|i| {
54 module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i);
55 });
56 }
57
58 let size: usize = self_mut.size();
59 (max_col..self_col).for_each(|i| {
60 (0..size).for_each(|j| {
61 self_mut.data.zero_at(i, j);
62 });
63 });
64
65 self.set_basek(a.basek());
66 self.set_k(set_k_binary(self, a, b));
67 }
68
69 fn add_inplace<A, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A)
70 where
71 A: GLWECiphertextToRef + Infos,
72 Module<BACKEND>: VecZnxAddInplace,
73 {
74 #[cfg(debug_assertions)]
75 {
76 assert_eq!(a.n(), self.n());
77 assert_eq!(self.basek(), a.basek());
78 assert!(self.rank() >= a.rank())
79 }
80
81 let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
82 let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
83
84 (0..a.rank() + 1).for_each(|i| {
85 module.vec_znx_add_inplace(&mut self_mut.data, i, &a_ref.data, i);
86 });
87
88 self.set_k(set_k_unary(self, a))
89 }
90
91 fn sub<A, B, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A, b: &B)
92 where
93 A: GLWECiphertextToRef,
94 B: GLWECiphertextToRef,
95 Module<BACKEND>: VecZnxSub + VecZnxCopy + VecZnxNegateInplace,
96 {
97 #[cfg(debug_assertions)]
98 {
99 assert_eq!(a.n(), self.n());
100 assert_eq!(b.n(), self.n());
101 assert_eq!(a.basek(), b.basek());
102 assert!(self.rank() >= a.rank().max(b.rank()));
103 }
104
105 let min_col: usize = a.rank().min(b.rank()) + 1;
106 let max_col: usize = a.rank().max(b.rank() + 1);
107 let self_col: usize = self.rank() + 1;
108
109 let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
110 let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
111 let b_ref: &GLWECiphertext<&[u8]> = &b.to_ref();
112
113 (0..min_col).for_each(|i| {
114 module.vec_znx_sub(&mut self_mut.data, i, &a_ref.data, i, &b_ref.data, i);
115 });
116
117 if a.rank() > b.rank() {
118 (min_col..max_col).for_each(|i| {
119 module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
120 });
121 } else {
122 (min_col..max_col).for_each(|i| {
123 module.vec_znx_copy(&mut self_mut.data, i, &b_ref.data, i);
124 module.vec_znx_negate_inplace(&mut self_mut.data, i);
125 });
126 }
127
128 let size: usize = self_mut.size();
129 (max_col..self_col).for_each(|i| {
130 (0..size).for_each(|j| {
131 self_mut.data.zero_at(i, j);
132 });
133 });
134
135 self.set_basek(a.basek());
136 self.set_k(set_k_binary(self, a, b));
137 }
138
139 fn sub_inplace_ab<A, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A)
140 where
141 A: GLWECiphertextToRef + Infos,
142 Module<BACKEND>: VecZnxSubABInplace,
143 {
144 #[cfg(debug_assertions)]
145 {
146 assert_eq!(a.n(), self.n());
147 assert_eq!(self.basek(), a.basek());
148 assert!(self.rank() >= a.rank())
149 }
150
151 let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
152 let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
153
154 (0..a.rank() + 1).for_each(|i| {
155 module.vec_znx_sub_ab_inplace(&mut self_mut.data, i, &a_ref.data, i);
156 });
157
158 self.set_k(set_k_unary(self, a))
159 }
160
161 fn sub_inplace_ba<A, BACKEND: Backend>(&mut self, module: &Module<BACKEND>, a: &A)
162 where
163 A: GLWECiphertextToRef + Infos,
164 Module<BACKEND>: VecZnxSubBAInplace,
165 {
166 #[cfg(debug_assertions)]
167 {
168 assert_eq!(a.n(), self.n());
169 assert_eq!(self.basek(), a.basek());
170 assert!(self.rank() >= a.rank())
171 }
172
173 let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
174 let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
175
176 (0..a.rank() + 1).for_each(|i| {
177 module.vec_znx_sub_ba_inplace(&mut self_mut.data, i, &a_ref.data, i);
178 });
179
180 self.set_k(set_k_unary(self, a))
181 }
182
183 fn rotate<A, B: Backend>(&mut self, module: &Module<B>, k: i64, a: &A)
184 where
185 A: GLWECiphertextToRef + Infos,
186 Module<B>: VecZnxRotate,
187 {
188 #[cfg(debug_assertions)]
189 {
190 assert_eq!(a.n(), self.n());
191 assert_eq!(self.rank(), a.rank())
192 }
193
194 let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
195 let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
196
197 (0..a.rank() + 1).for_each(|i| {
198 module.vec_znx_rotate(k, &mut self_mut.data, i, &a_ref.data, i);
199 });
200
201 self.set_basek(a.basek());
202 self.set_k(set_k_unary(self, a))
203 }
204
205 fn rotate_inplace<B: Backend>(&mut self, module: &Module<B>, k: i64, scratch: &mut Scratch<B>)
206 where
207 Module<B>: VecZnxRotateInplace<B>,
208 {
209 let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
210
211 (0..self_mut.rank() + 1).for_each(|i| {
212 module.vec_znx_rotate_inplace(k, &mut self_mut.data, i, scratch);
213 });
214 }
215
216 fn mul_xp_minus_one<A, B: Backend>(&mut self, module: &Module<B>, k: i64, a: &A)
217 where
218 A: GLWECiphertextToRef + Infos,
219 Module<B>: VecZnxMulXpMinusOne,
220 {
221 #[cfg(debug_assertions)]
222 {
223 assert_eq!(a.n(), self.n());
224 assert_eq!(self.rank(), a.rank())
225 }
226
227 let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
228 let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
229
230 (0..a.rank() + 1).for_each(|i| {
231 module.vec_znx_mul_xp_minus_one(k, &mut self_mut.data, i, &a_ref.data, i);
232 });
233
234 self.set_basek(a.basek());
235 self.set_k(set_k_unary(self, a))
236 }
237
238 fn mul_xp_minus_one_inplace<B: Backend>(&mut self, module: &Module<B>, k: i64, scratch: &mut Scratch<B>)
239 where
240 Module<B>: VecZnxMulXpMinusOneInplace<B>,
241 {
242 let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
243
244 (0..self_mut.rank() + 1).for_each(|i| {
245 module.vec_znx_mul_xp_minus_one_inplace(k, &mut self_mut.data, i, scratch);
246 });
247 }
248
249 fn copy<A, B: Backend>(&mut self, module: &Module<B>, a: &A)
250 where
251 A: GLWECiphertextToRef + Infos,
252 Module<B>: VecZnxCopy,
253 {
254 #[cfg(debug_assertions)]
255 {
256 assert_eq!(self.n(), a.n());
257 assert_eq!(self.rank(), a.rank());
258 }
259
260 let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
261 let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
262
263 (0..self_mut.rank() + 1).for_each(|i| {
264 module.vec_znx_copy(&mut self_mut.data, i, &a_ref.data, i);
265 });
266
267 self.set_k(a.k().min(self.size() * self.basek()));
268 self.set_basek(a.basek());
269 }
270
271 fn rsh<B: Backend>(&mut self, module: &Module<B>, k: usize, scratch: &mut Scratch<B>)
272 where
273 Module<B>: VecZnxRshInplace<B>,
274 {
275 let basek: usize = self.basek();
276 (0..self.cols()).for_each(|i| {
277 module.vec_znx_rsh_inplace(basek, k, &mut self.to_mut().data, i, scratch);
278 })
279 }
280
281 fn normalize<A, B: Backend>(&mut self, module: &Module<B>, a: &A, scratch: &mut Scratch<B>)
282 where
283 A: GLWECiphertextToRef,
284 Module<B>: VecZnxNormalize<B>,
285 {
286 #[cfg(debug_assertions)]
287 {
288 assert_eq!(self.n(), a.n());
289 assert_eq!(self.rank(), a.rank());
290 }
291
292 let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
293 let a_ref: &GLWECiphertext<&[u8]> = &a.to_ref();
294
295 (0..self_mut.rank() + 1).for_each(|i| {
296 module.vec_znx_normalize(a.basek(), &mut self_mut.data, i, &a_ref.data, i, scratch);
297 });
298 self.set_basek(a.basek());
299 self.set_k(a.k().min(self.k()));
300 }
301
302 fn normalize_inplace<B: Backend>(&mut self, module: &Module<B>, scratch: &mut Scratch<B>)
303 where
304 Module<B>: VecZnxNormalizeInplace<B>,
305 {
306 let self_mut: &mut GLWECiphertext<&mut [u8]> = &mut self.to_mut();
307 (0..self_mut.rank() + 1).for_each(|i| {
308 module.vec_znx_normalize_inplace(self_mut.basek(), &mut self_mut.data, i, scratch);
309 });
310 }
311}
312
313impl GLWECiphertext<Vec<u8>> {
314 pub fn rsh_scratch_space(n: usize) -> usize {
315 VecZnx::rsh_scratch_space(n)
316 }
317}
318
319fn set_k_binary(c: &impl Infos, a: &impl Infos, b: &impl Infos) -> usize {
321 if a.rank() != 0 || b.rank() != 0 {
323 let k = if a.rank() == 0 {
325 b.k()
326 } else if b.rank() == 0 {
328 a.k()
329 } else {
331 a.k().min(b.k())
332 };
333 k.min(c.k())
334 } else {
336 c.k()
337 }
338}
339
340fn set_k_unary(a: &impl Infos, b: &impl Infos) -> usize {
342 if a.rank() != 0 || b.rank() != 0 {
343 a.k().min(b.k())
344 } else {
345 a.k()
346 }
347}