1use poulpy_hal::{
2 api::{
3 ModuleN, VecZnxAdd, VecZnxAddInplace, VecZnxCopy, VecZnxMulXpMinusOne, VecZnxMulXpMinusOneInplace, VecZnxNegateInplace,
4 VecZnxNormalize, VecZnxNormalizeInplace, VecZnxRotate, VecZnxRotateInplace, VecZnxRshInplace, VecZnxSub,
5 VecZnxSubInplace, VecZnxSubNegateInplace,
6 },
7 layouts::{Backend, Module, Scratch, VecZnx, ZnxZero},
8};
9
10use crate::{
11 ScratchTakeCore,
12 layouts::{GLWE, GLWEInfos, GLWEToMut, GLWEToRef, LWEInfos, SetGLWEInfos, TorusPrecision},
13};
14
15pub trait GLWEAdd
16where
17 Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace,
18{
19 fn glwe_add<R, A, B>(&self, res: &mut R, a: &A, b: &B)
20 where
21 R: GLWEToMut,
22 A: GLWEToRef,
23 B: GLWEToRef,
24 {
25 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
26 let a: &mut GLWE<&[u8]> = &mut a.to_ref();
27 let b: &GLWE<&[u8]> = &b.to_ref();
28
29 assert_eq!(a.n(), self.n() as u32);
30 assert_eq!(b.n(), self.n() as u32);
31 assert_eq!(res.n(), self.n() as u32);
32 assert_eq!(a.base2k(), b.base2k());
33 assert!(res.rank() >= a.rank().max(b.rank()));
34
35 let min_col: usize = (a.rank().min(b.rank()) + 1).into();
36 let max_col: usize = (a.rank().max(b.rank() + 1)).into();
37 let self_col: usize = (res.rank() + 1).into();
38
39 (0..min_col).for_each(|i| {
40 self.vec_znx_add(res.data_mut(), i, a.data(), i, b.data(), i);
41 });
42
43 if a.rank() > b.rank() {
44 (min_col..max_col).for_each(|i| {
45 self.vec_znx_copy(res.data_mut(), i, a.data(), i);
46 });
47 } else {
48 (min_col..max_col).for_each(|i| {
49 self.vec_znx_copy(res.data_mut(), i, b.data(), i);
50 });
51 }
52
53 let size: usize = res.size();
54 (max_col..self_col).for_each(|i| {
55 (0..size).for_each(|j| {
56 res.data.zero_at(i, j);
57 });
58 });
59
60 res.set_base2k(a.base2k());
61 res.set_k(set_k_binary(res, a, b));
62 }
63
64 fn glwe_add_inplace<R, A>(&self, res: &mut R, a: &A)
65 where
66 R: GLWEToMut,
67 A: GLWEToRef,
68 {
69 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
70 let a: &GLWE<&[u8]> = &a.to_ref();
71
72 assert_eq!(res.n(), self.n() as u32);
73 assert_eq!(a.n(), self.n() as u32);
74 assert_eq!(res.base2k(), a.base2k());
75 assert!(res.rank() >= a.rank());
76
77 (0..(a.rank() + 1).into()).for_each(|i| {
78 self.vec_znx_add_inplace(res.data_mut(), i, a.data(), i);
79 });
80
81 res.set_k(set_k_unary(res, a))
82 }
83}
84
// Blanket impl: every `Module<BE>` whose backend provides the required vec-znx primitives gets `GLWEAdd` for free.
impl<BE: Backend> GLWEAdd for Module<BE> where Self: ModuleN + VecZnxAdd + VecZnxCopy + VecZnxAddInplace {}
86
// Blanket impl: every `Module<BE>` whose backend provides the required vec-znx primitives gets `GLWESub` for free.
impl<BE: Backend> GLWESub for Module<BE> where
    Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace
{
}
91
92pub trait GLWESub
93where
94 Self: ModuleN + VecZnxSub + VecZnxCopy + VecZnxNegateInplace + VecZnxSubInplace + VecZnxSubNegateInplace,
95{
96 fn glwe_sub<R, A, B>(&self, res: &mut R, a: &A, b: &B)
97 where
98 R: GLWEToMut,
99 A: GLWEToRef,
100 B: GLWEToRef,
101 {
102 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
103 let a: &GLWE<&[u8]> = &a.to_ref();
104 let b: &GLWE<&[u8]> = &b.to_ref();
105
106 assert_eq!(a.n(), self.n() as u32);
107 assert_eq!(b.n(), self.n() as u32);
108 assert_eq!(a.base2k(), b.base2k());
109 assert!(res.rank() >= a.rank().max(b.rank()));
110
111 let min_col: usize = (a.rank().min(b.rank()) + 1).into();
112 let max_col: usize = (a.rank().max(b.rank() + 1)).into();
113 let self_col: usize = (res.rank() + 1).into();
114
115 (0..min_col).for_each(|i| {
116 self.vec_znx_sub(res.data_mut(), i, a.data(), i, b.data(), i);
117 });
118
119 if a.rank() > b.rank() {
120 (min_col..max_col).for_each(|i| {
121 self.vec_znx_copy(res.data_mut(), i, a.data(), i);
122 });
123 } else {
124 (min_col..max_col).for_each(|i| {
125 self.vec_znx_copy(res.data_mut(), i, b.data(), i);
126 self.vec_znx_negate_inplace(res.data_mut(), i);
127 });
128 }
129
130 let size: usize = res.size();
131 (max_col..self_col).for_each(|i| {
132 (0..size).for_each(|j| {
133 res.data.zero_at(i, j);
134 });
135 });
136
137 res.set_base2k(a.base2k());
138 res.set_k(set_k_binary(res, a, b));
139 }
140
141 fn glwe_sub_inplace<R, A>(&self, res: &mut R, a: &A)
142 where
143 R: GLWEToMut,
144 A: GLWEToRef,
145 {
146 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
147 let a: &GLWE<&[u8]> = &a.to_ref();
148
149 assert_eq!(res.n(), self.n() as u32);
150 assert_eq!(a.n(), self.n() as u32);
151 assert_eq!(res.base2k(), a.base2k());
152 assert!(res.rank() >= a.rank());
153
154 (0..(a.rank() + 1).into()).for_each(|i| {
155 self.vec_znx_sub_inplace(res.data_mut(), i, a.data(), i);
156 });
157
158 res.set_k(set_k_unary(res, a))
159 }
160
161 fn glwe_sub_negate_inplace<R, A>(&self, res: &mut R, a: &A)
162 where
163 R: GLWEToMut,
164 A: GLWEToRef,
165 {
166 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
167 let a: &GLWE<&[u8]> = &a.to_ref();
168
169 assert_eq!(res.n(), self.n() as u32);
170 assert_eq!(a.n(), self.n() as u32);
171 assert_eq!(res.base2k(), a.base2k());
172 assert!(res.rank() >= a.rank());
173
174 (0..(a.rank() + 1).into()).for_each(|i| {
175 self.vec_znx_sub_negate_inplace(res.data_mut(), i, a.data(), i);
176 });
177
178 res.set_k(set_k_unary(res, a))
179 }
180}
181
// Blanket impl: every `Module<BE>` whose backend provides vec-znx rotation gets `GLWERotate` for free.
impl<BE: Backend> GLWERotate<BE> for Module<BE> where Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE> {}
183
184pub trait GLWERotate<BE: Backend>
185where
186 Self: ModuleN + VecZnxRotate + VecZnxRotateInplace<BE>,
187{
188 fn glwe_rotate<R, A>(&self, k: i64, res: &mut R, a: &A)
189 where
190 R: GLWEToMut,
191 A: GLWEToRef,
192 {
193 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
194 let a: &GLWE<&[u8]> = &a.to_ref();
195
196 assert_eq!(a.n(), self.n() as u32);
197 assert_eq!(res.rank(), a.rank());
198
199 (0..(a.rank() + 1).into()).for_each(|i| {
200 self.vec_znx_rotate(k, res.data_mut(), i, a.data(), i);
201 });
202
203 res.set_base2k(a.base2k());
204 res.set_k(set_k_unary(res, a))
205 }
206
207 fn glwe_rotate_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
208 where
209 R: GLWEToMut,
210 Scratch<BE>: ScratchTakeCore<BE>,
211 {
212 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
213
214 (0..(res.rank() + 1).into()).for_each(|i| {
215 self.vec_znx_rotate_inplace(k, res.data_mut(), i, scratch);
216 });
217 }
218}
219
// Blanket impl: every `Module<BE>` whose backend provides the (X^p - 1) vec-znx primitives gets `GLWEMulXpMinusOne` for free.
impl<BE: Backend> GLWEMulXpMinusOne<BE> for Module<BE> where Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace<BE> {}
221
222pub trait GLWEMulXpMinusOne<BE: Backend>
223where
224 Self: ModuleN + VecZnxMulXpMinusOne + VecZnxMulXpMinusOneInplace<BE>,
225{
226 fn glwe_mul_xp_minus_one<R, A>(&self, k: i64, res: &mut R, a: &A)
227 where
228 R: GLWEToMut,
229 A: GLWEToRef,
230 {
231 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
232 let a: &GLWE<&[u8]> = &a.to_ref();
233
234 assert_eq!(res.n(), self.n() as u32);
235 assert_eq!(a.n(), self.n() as u32);
236 assert_eq!(res.rank(), a.rank());
237
238 for i in 0..res.rank().as_usize() + 1 {
239 self.vec_znx_mul_xp_minus_one(k, res.data_mut(), i, a.data(), i);
240 }
241
242 res.set_base2k(a.base2k());
243 res.set_k(set_k_unary(res, a))
244 }
245
246 fn glwe_mul_xp_minus_one_inplace<R>(&self, k: i64, res: &mut R, scratch: &mut Scratch<BE>)
247 where
248 R: GLWEToMut,
249 {
250 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
251
252 assert_eq!(res.n(), self.n() as u32);
253
254 for i in 0..res.rank().as_usize() + 1 {
255 self.vec_znx_mul_xp_minus_one_inplace(k, res.data_mut(), i, scratch);
256 }
257 }
258}
259
// Blanket impl: every `Module<BE>` whose backend provides vec-znx copy gets `GLWECopy` for free.
impl<BE: Backend> GLWECopy for Module<BE> where Self: ModuleN + VecZnxCopy {}
261
262pub trait GLWECopy
263where
264 Self: ModuleN + VecZnxCopy,
265{
266 fn glwe_copy<R, A>(&self, res: &mut R, a: &A)
267 where
268 R: GLWEToMut,
269 A: GLWEToRef,
270 {
271 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
272 let a: &GLWE<&[u8]> = &a.to_ref();
273
274 assert_eq!(res.n(), self.n() as u32);
275 assert_eq!(a.n(), self.n() as u32);
276 assert_eq!(res.rank(), a.rank());
277
278 for i in 0..res.rank().as_usize() + 1 {
279 self.vec_znx_copy(res.data_mut(), i, a.data(), i);
280 }
281
282 res.set_k(a.k().min(res.max_k()));
283 res.set_base2k(a.base2k());
284 }
285}
286
// Blanket impl: every `Module<BE>` whose backend provides the in-place right-shift primitive gets `GLWEShift` for free.
impl<BE: Backend> GLWEShift<BE> for Module<BE> where Self: ModuleN + VecZnxRshInplace<BE> {}
288
289pub trait GLWEShift<BE: Backend>
290where
291 Self: ModuleN + VecZnxRshInplace<BE>,
292{
293 fn glwe_rsh_tmp_byte(&self) -> usize {
294 VecZnx::rsh_tmp_bytes(self.n())
295 }
296
297 fn glwe_rsh<R>(&self, k: usize, res: &mut R, scratch: &mut Scratch<BE>)
298 where
299 R: GLWEToMut,
300 Scratch<BE>: ScratchTakeCore<BE>,
301 {
302 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
303 let base2k: usize = res.base2k().into();
304 for i in 0..res.rank().as_usize() + 1 {
305 self.vec_znx_rsh_inplace(base2k, k, res.data_mut(), i, scratch);
306 }
307 }
308}
309
310impl GLWE<Vec<u8>> {
311 pub fn rsh_tmp_bytes<M, BE: Backend>(module: &M) -> usize
312 where
313 M: GLWEShift<BE>,
314 {
315 module.glwe_rsh_tmp_byte()
316 }
317}
318
// Blanket impl: every `Module<BE>` whose backend provides vec-znx normalization gets `GLWENormalize` for free.
impl<BE: Backend> GLWENormalize<BE> for Module<BE> where Self: ModuleN + VecZnxNormalize<BE> + VecZnxNormalizeInplace<BE> {}
320
321pub trait GLWENormalize<BE: Backend>
322where
323 Self: ModuleN + VecZnxNormalize<BE> + VecZnxNormalizeInplace<BE>,
324{
325 fn glwe_normalize<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
326 where
327 R: GLWEToMut,
328 A: GLWEToRef,
329 Scratch<BE>: ScratchTakeCore<BE>,
330 {
331 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
332 let a: &GLWE<&[u8]> = &a.to_ref();
333
334 assert_eq!(res.n(), self.n() as u32);
335 assert_eq!(a.n(), self.n() as u32);
336 assert_eq!(res.rank(), a.rank());
337
338 for i in 0..res.rank().as_usize() + 1 {
339 self.vec_znx_normalize(
340 res.base2k().into(),
341 res.data_mut(),
342 i,
343 a.base2k().into(),
344 a.data(),
345 i,
346 scratch,
347 );
348 }
349
350 res.set_k(a.k().min(res.k()));
351 }
352
353 fn glwe_normalize_inplace<R>(&self, res: &mut R, scratch: &mut Scratch<BE>)
354 where
355 R: GLWEToMut,
356 Scratch<BE>: ScratchTakeCore<BE>,
357 {
358 let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
359 for i in 0..res.rank().as_usize() + 1 {
360 self.vec_znx_normalize_inplace(res.base2k().into(), res.data_mut(), i, scratch);
361 }
362 }
363}
364
365fn set_k_binary(c: &impl GLWEInfos, a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
367 if a.rank() != 0 || b.rank() != 0 {
369 let k = if a.rank() == 0 {
371 b.k()
372 } else if b.rank() == 0 {
374 a.k()
375 } else {
377 a.k().min(b.k())
378 };
379 k.min(c.k())
380 } else {
382 c.k()
383 }
384}
385
386fn set_k_unary(a: &impl GLWEInfos, b: &impl GLWEInfos) -> TorusPrecision {
388 if a.rank() != 0 || b.rank() != 0 {
389 a.k().min(b.k())
390 } else {
391 a.k()
392 }
393}