1use poulpy_hal::{
2 layouts::{
3 Backend, Data, FillUniform, HostDataMut, HostDataRef, MatZnx, MatZnxToBackendMut, MatZnxToBackendRef, Module, ReaderFrom,
4 WriterTo, mat_znx_at_backend_mut_from_mut, mat_znx_at_backend_ref_from_ref, mat_znx_backend_mut_from_mut,
5 mat_znx_backend_ref_from_mut,
6 },
7 source::Source,
8};
9
10use crate::layouts::{
11 Base2K, Degree, Dnum, Dsize, GGLWEInfos, GGLWEToBackendMut, GLWEInfos, LWEInfos, Rank, TorusPrecision,
12 compressed::{GLWECompressed, GLWECompressedBackendMut, GLWECompressedViewMut, GLWECompressedViewRef, GLWEDecompress},
13};
14use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
15use std::{
16 fmt,
17 ops::{Deref, DerefMut},
18};
19
20#[derive(PartialEq, Eq, Clone)]
26pub struct GGLWECompressed<D: Data> {
27 pub(crate) data: MatZnx<D>,
28 pub(crate) base2k: Base2K,
29 pub(crate) k: TorusPrecision,
30 pub(crate) rank_out: Rank,
31 pub(crate) dsize: Dsize,
32 pub(crate) seed: Vec<[u8; 32]>,
33}
34
35pub struct GGLWECompressedBackendRef<'a, BE: Backend + 'a> {
36 inner: GGLWECompressed<BE::BufRef<'a>>,
37}
38
39impl<'a, BE: Backend + 'a> GGLWECompressedBackendRef<'a, BE> {
40 pub fn from_inner(inner: GGLWECompressed<BE::BufRef<'a>>) -> Self {
41 Self { inner }
42 }
43
44 pub fn into_inner(self) -> GGLWECompressed<BE::BufRef<'a>> {
45 self.inner
46 }
47
48 pub fn at_view(&self, row: usize, col: usize) -> GLWECompressedViewRef<'_, BE> {
49 GLWECompressedViewRef::from_inner(gglwe_compressed_at_backend_ref_from_ref::<BE>(&self.inner, row, col))
50 }
51}
52
53impl<'a, BE: Backend + 'a> Deref for GGLWECompressedBackendRef<'a, BE> {
54 type Target = GGLWECompressed<BE::BufRef<'a>>;
55
56 fn deref(&self) -> &Self::Target {
57 &self.inner
58 }
59}
60
61pub struct GGLWECompressedBackendMut<'a, BE: Backend + 'a> {
62 inner: GGLWECompressed<BE::BufMut<'a>>,
63}
64
65impl<'a, BE: Backend + 'a> GGLWECompressedBackendMut<'a, BE> {
66 pub fn from_inner(inner: GGLWECompressed<BE::BufMut<'a>>) -> Self {
67 Self { inner }
68 }
69
70 pub fn into_inner(self) -> GGLWECompressed<BE::BufMut<'a>> {
71 self.inner
72 }
73
74 pub fn at_view_mut(&mut self, row: usize, col: usize) -> GLWECompressedViewMut<'_, BE> {
75 GLWECompressedViewMut::from_inner(gglwe_compressed_at_backend_mut_from_mut::<BE>(&mut self.inner, row, col))
76 }
77}
78
79impl<'a, BE: Backend + 'a> Deref for GGLWECompressedBackendMut<'a, BE> {
80 type Target = GGLWECompressed<BE::BufMut<'a>>;
81
82 fn deref(&self) -> &Self::Target {
83 &self.inner
84 }
85}
86
87impl<'a, BE: Backend + 'a> DerefMut for GGLWECompressedBackendMut<'a, BE> {
88 fn deref_mut(&mut self) -> &mut Self::Target {
89 &mut self.inner
90 }
91}
92
// Presumably implements the GGLWE info traits (LWEInfos/GLWEInfos/GGLWEInfos)
// for both backend wrappers by forwarding to their `inner` field; the macro is
// defined elsewhere — confirm its exact expansion there.
impl_gglwe_infos_for_inner!(GGLWECompressedBackendRef<'a, BE>, ['a, BE: Backend + 'a]; inner);
impl_gglwe_infos_for_inner!(GGLWECompressedBackendMut<'a, BE>, ['a, BE: Backend + 'a]; inner);
95
96impl<'a, BE: Backend + 'a> GGLWECompressedSeedMut for GGLWECompressedBackendMut<'a, BE> {
97 fn seed_mut(&mut self) -> &mut Vec<[u8; 32]> {
98 &mut self.inner.seed
99 }
100}
101
102impl<'a, BE: Backend + 'a> GGLWECompressedToBackendRef<BE> for GGLWECompressedBackendRef<'a, BE> {
103 fn to_backend_ref(&self) -> GGLWECompressedBackendRef<'_, BE> {
104 GGLWECompressedBackendRef::from_inner(GGLWECompressed {
105 k: self.inner.k,
106 base2k: self.inner.base2k,
107 dsize: self.inner.dsize,
108 seed: self.inner.seed.clone(),
109 rank_out: self.inner.rank_out,
110 data: poulpy_hal::layouts::mat_znx_backend_ref_from_ref::<BE>(&self.inner.data),
111 })
112 }
113}
114
115impl<'a, BE: Backend + 'a> GGLWECompressedToBackendRef<BE> for GGLWECompressedBackendMut<'a, BE> {
116 fn to_backend_ref(&self) -> GGLWECompressedBackendRef<'_, BE> {
117 GGLWECompressedBackendRef::from_inner(GGLWECompressed {
118 k: self.inner.k,
119 base2k: self.inner.base2k,
120 dsize: self.inner.dsize,
121 seed: self.inner.seed.clone(),
122 rank_out: self.inner.rank_out,
123 data: mat_znx_backend_ref_from_mut::<BE>(&self.inner.data),
124 })
125 }
126}
127
128impl<'a, BE: Backend + 'a> GGLWECompressedToBackendMut<BE> for GGLWECompressedBackendMut<'a, BE> {
129 fn to_backend_mut(&mut self) -> GGLWECompressedBackendMut<'_, BE> {
130 GGLWECompressedBackendMut::from_inner(GGLWECompressed {
131 k: self.inner.k,
132 base2k: self.inner.base2k,
133 dsize: self.inner.dsize,
134 seed: self.inner.seed.clone(),
135 rank_out: self.inner.rank_out,
136 data: mat_znx_backend_mut_from_mut::<BE>(&mut self.inner.data),
137 })
138 }
139}
140
141pub trait GGLWECompressedSeedMut {
143 fn seed_mut(&mut self) -> &mut Vec<[u8; 32]>;
145}
146
147impl<D: Data> GGLWECompressedSeedMut for GGLWECompressed<D> {
148 fn seed_mut(&mut self) -> &mut Vec<[u8; 32]> {
149 &mut self.seed
150 }
151}
152
153impl<D: Data> GGLWECompressedSeedMut for &mut GGLWECompressed<D> {
154 fn seed_mut(&mut self) -> &mut Vec<[u8; 32]> {
155 &mut self.seed
156 }
157}
158
159pub trait GGLWECompressedSeed {
161 fn seed(&self) -> &Vec<[u8; 32]>;
163}
164
165impl<D: HostDataRef> GGLWECompressedSeed for GGLWECompressed<D> {
166 fn seed(&self) -> &Vec<[u8; 32]> {
167 &self.seed
168 }
169}
170impl<D: Data> LWEInfos for GGLWECompressed<D> {
171 fn n(&self) -> Degree {
172 Degree(self.data.n() as u32)
173 }
174
175 fn base2k(&self) -> Base2K {
176 self.base2k
177 }
178
179 fn size(&self) -> usize {
180 self.data.size()
181 }
182}
183impl<D: Data> GLWEInfos for GGLWECompressed<D> {
184 fn rank(&self) -> Rank {
185 self.rank_out()
186 }
187}
188
189impl<D: Data> GGLWEInfos for GGLWECompressed<D> {
190 fn rank_in(&self) -> Rank {
191 Rank(self.data.cols_in() as u32)
192 }
193
194 fn rank_out(&self) -> Rank {
195 self.rank_out
196 }
197
198 fn dsize(&self) -> Dsize {
199 self.dsize
200 }
201
202 fn dnum(&self) -> Dnum {
203 Dnum(self.data.rows() as u32)
204 }
205}
206
207impl<D: Data> LWEInfos for &GGLWECompressed<D> {
208 fn n(&self) -> Degree {
209 (*self).n()
210 }
211
212 fn base2k(&self) -> Base2K {
213 (*self).base2k()
214 }
215
216 fn size(&self) -> usize {
217 (*self).size()
218 }
219}
220
221impl<D: Data> GLWEInfos for &GGLWECompressed<D> {
222 fn rank(&self) -> Rank {
223 (*self).rank()
224 }
225}
226
227impl<D: Data> GGLWEInfos for &GGLWECompressed<D> {
228 fn rank_in(&self) -> Rank {
229 (*self).rank_in()
230 }
231
232 fn rank_out(&self) -> Rank {
233 (*self).rank_out()
234 }
235
236 fn dsize(&self) -> Dsize {
237 (*self).dsize()
238 }
239
240 fn dnum(&self) -> Dnum {
241 (*self).dnum()
242 }
243}
244
245impl<D: HostDataRef> fmt::Debug for GGLWECompressed<D> {
246 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
247 write!(f, "{self}")
248 }
249}
250
251impl<D: HostDataMut> FillUniform for GGLWECompressed<D> {
252 fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
253 self.data.fill_uniform(log_bound, source);
254 }
255}
256
257impl<D: HostDataRef> fmt::Display for GGLWECompressed<D> {
258 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
259 write!(
260 f,
261 "(GGLWECompressed: base2k={} k={} dsize={}) {}",
262 self.base2k.0, self.k.0, self.dsize.0, self.data
263 )
264 }
265}
266
267impl GGLWECompressed<Vec<u8>> {
268 pub(crate) fn alloc_from_infos<A>(infos: &A) -> Self
270 where
271 A: GGLWEInfos,
272 {
273 Self::alloc(
274 infos.n(),
275 infos.base2k(),
276 infos.max_k(),
277 infos.rank_in(),
278 infos.rank_out(),
279 infos.dnum(),
280 infos.dsize(),
281 )
282 }
283
284 pub(crate) fn alloc(
286 n: Degree,
287 base2k: Base2K,
288 k: TorusPrecision,
289 rank_in: Rank,
290 rank_out: Rank,
291 dnum: Dnum,
292 dsize: Dsize,
293 ) -> Self {
294 let size: usize = k.0.div_ceil(base2k.0) as usize;
295 debug_assert!(
296 size as u32 > dsize.0,
297 "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}",
298 dsize.0
299 );
300
301 assert!(
302 dnum.0 * dsize.0 <= size as u32,
303 "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
304 dnum.0,
305 dsize.0,
306 );
307
308 GGLWECompressed {
309 data: MatZnx::from_data(
310 poulpy_hal::layouts::HostBytesBackend::alloc_bytes(MatZnx::<Vec<u8>>::bytes_of(
311 n.into(),
312 dnum.into(),
313 rank_in.into(),
314 1,
315 size,
316 )),
317 n.into(),
318 dnum.into(),
319 rank_in.into(),
320 1,
321 size,
322 ),
323 k,
324 base2k,
325 dsize,
326 rank_out,
327 seed: vec![[0u8; 32]; (dnum.0 * rank_in.0) as usize],
328 }
329 }
330
331 pub fn bytes_of_from_infos<A>(infos: &A) -> usize
333 where
334 A: GGLWEInfos,
335 {
336 Self::bytes_of(
337 infos.n(),
338 infos.base2k(),
339 infos.max_k(),
340 infos.rank_in(),
341 infos.dnum(),
342 infos.dsize(),
343 )
344 }
345
346 pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank_in: Rank, dnum: Dnum, dsize: Dsize) -> usize {
348 let size: usize = k.0.div_ceil(base2k.0) as usize;
349 debug_assert!(
350 size as u32 > dsize.0,
351 "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}",
352 dsize.0
353 );
354
355 assert!(
356 dnum.0 * dsize.0 <= size as u32,
357 "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
358 dnum.0,
359 dsize.0,
360 );
361
362 MatZnx::bytes_of(n.into(), dnum.into(), rank_in.into(), 1, k.0.div_ceil(base2k.0) as usize)
363 }
364}
365
366impl<D: HostDataMut> ReaderFrom for GGLWECompressed<D> {
367 fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
368 self.k = TorusPrecision(reader.read_u32::<LittleEndian>()?);
369 self.base2k = Base2K(reader.read_u32::<LittleEndian>()?);
370 self.dsize = Dsize(reader.read_u32::<LittleEndian>()?);
371 self.rank_out = Rank(reader.read_u32::<LittleEndian>()?);
372 let seed_len: u32 = reader.read_u32::<LittleEndian>()?;
373 self.seed = vec![[0u8; 32]; seed_len as usize];
374 for s in &mut self.seed {
375 reader.read_exact(s)?;
376 }
377 self.data.read_from(reader)
378 }
379}
380
381impl<D: HostDataRef> WriterTo for GGLWECompressed<D> {
382 fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
383 writer.write_u32::<LittleEndian>(self.k.into())?;
384 writer.write_u32::<LittleEndian>(self.base2k.into())?;
385 writer.write_u32::<LittleEndian>(self.dsize.into())?;
386 writer.write_u32::<LittleEndian>(self.rank_out.into())?;
387 writer.write_u32::<LittleEndian>(self.seed.len() as u32)?;
388 for s in &self.seed {
389 writer.write_all(s)?;
390 }
391 self.data.write_to(writer)
392 }
393}
394
395pub trait GGLWEDecompress
400where
401 Self: GLWEDecompress,
402{
403 fn decompress_gglwe<R, O>(&self, res: &mut R, other: &O)
405 where
406 R: GGLWEToBackendMut<Self::Backend> + GGLWEInfos,
407 O: GGLWECompressedToBackendRef<Self::Backend> + GGLWEInfos,
408 {
409 let mut res = res.to_backend_mut();
410 let other = other.to_backend_ref();
411
412 assert_eq!(res.dsize(), other.dsize());
413 assert!(res.dnum() <= other.dnum());
414
415 let rank_in: usize = res.rank_in().into();
416 let dnum: usize = res.dnum().into();
417 for col_i in 0..rank_in {
418 for row_i in 0..dnum {
419 let mut dst = res.at_view_mut(row_i, col_i);
420 let src = other.at_view(row_i, col_i);
421 self.decompress_glwe(&mut dst, &src);
422 }
423 }
424 }
425}
426
427impl<B: Backend> GGLWEDecompress for Module<B> where Self: GLWEDecompress {}
428
/// Conversion into a read-only backend view ([`GGLWECompressedBackendRef`]).
pub trait GGLWECompressedToBackendRef<BE: Backend> {
    /// Returns a view that borrows `self`'s matrix data. The implementations in
    /// this file copy the metadata and clone the seed table into the view.
    fn to_backend_ref(&self) -> GGLWECompressedBackendRef<'_, BE>;
}
434
435impl<BE: Backend> GGLWECompressedToBackendRef<BE> for GGLWECompressed<BE::OwnedBuf> {
436 fn to_backend_ref(&self) -> GGLWECompressedBackendRef<'_, BE> {
437 GGLWECompressedBackendRef::from_inner(GGLWECompressed {
438 k: self.max_k(),
439 base2k: self.base2k(),
440 dsize: self.dsize(),
441 seed: self.seed.clone(),
442 rank_out: self.rank_out,
443 data: <MatZnx<BE::OwnedBuf> as MatZnxToBackendRef<BE>>::to_backend_ref(&self.data),
444 })
445 }
446}
447
448impl<'b, BE: Backend + 'b> GGLWECompressedToBackendRef<BE> for &GGLWECompressed<BE::BufRef<'b>> {
449 fn to_backend_ref(&self) -> GGLWECompressedBackendRef<'_, BE> {
450 GGLWECompressedBackendRef::from_inner(GGLWECompressed {
451 k: self.max_k(),
452 base2k: self.base2k(),
453 dsize: self.dsize(),
454 seed: self.seed.clone(),
455 rank_out: self.rank_out,
456 data: poulpy_hal::layouts::mat_znx_backend_ref_from_ref::<BE>(&self.data),
457 })
458 }
459}
460
461impl<'b, BE: Backend + 'b> GGLWECompressedToBackendRef<BE> for &mut GGLWECompressed<BE::BufMut<'b>> {
462 fn to_backend_ref(&self) -> GGLWECompressedBackendRef<'_, BE> {
463 GGLWECompressedBackendRef::from_inner(GGLWECompressed {
464 k: self.max_k(),
465 base2k: self.base2k(),
466 dsize: self.dsize(),
467 seed: self.seed.clone(),
468 rank_out: self.rank_out,
469 data: mat_znx_backend_ref_from_mut::<BE>(&self.data),
470 })
471 }
472}
473
/// Conversion into a mutable backend view ([`GGLWECompressedBackendMut`]).
pub trait GGLWECompressedToBackendMut<BE: Backend>: GGLWECompressedToBackendRef<BE> {
    /// Returns a view that mutably borrows `self`'s matrix data. The
    /// implementations in this file clone the seed table, so seed edits made
    /// through the view do not reach `self`.
    fn to_backend_mut(&mut self) -> GGLWECompressedBackendMut<'_, BE>;
}
477
478impl<BE: Backend> GGLWECompressedToBackendMut<BE> for GGLWECompressed<BE::OwnedBuf> {
479 fn to_backend_mut(&mut self) -> GGLWECompressedBackendMut<'_, BE> {
480 GGLWECompressedBackendMut::from_inner(GGLWECompressed {
481 k: self.max_k(),
482 base2k: self.base2k(),
483 dsize: self.dsize(),
484 seed: self.seed.clone(),
485 rank_out: self.rank_out,
486 data: <MatZnx<BE::OwnedBuf> as MatZnxToBackendMut<BE>>::to_backend_mut(&mut self.data),
487 })
488 }
489}
490
491impl<'b, BE: Backend + 'b> GGLWECompressedToBackendMut<BE> for &mut GGLWECompressed<BE::BufMut<'b>> {
492 fn to_backend_mut(&mut self) -> GGLWECompressedBackendMut<'_, BE> {
493 GGLWECompressedBackendMut::from_inner(GGLWECompressed {
494 k: self.max_k(),
495 base2k: self.base2k(),
496 dsize: self.dsize(),
497 seed: self.seed.clone(),
498 rank_out: self.rank_out,
499 data: mat_znx_backend_mut_from_mut::<BE>(&mut self.data),
500 })
501 }
502}
503
504fn gglwe_compressed_at_backend_mut_from_mut<'a, 'b, BE: Backend>(
505 gglwe: &'a mut GGLWECompressed<BE::BufMut<'b>>,
506 row: usize,
507 col: usize,
508) -> GLWECompressedBackendMut<'a, BE> {
509 let rank_in: usize = gglwe.rank_in().into();
510 GLWECompressed {
511 base2k: gglwe.base2k,
512 rank: gglwe.rank_out,
513 data: mat_znx_at_backend_mut_from_mut::<BE>(&mut gglwe.data, row, col),
514 seed: gglwe.seed[rank_in * row + col],
515 }
516}
517
518fn gglwe_compressed_at_backend_ref_from_ref<'a, 'b, BE: Backend>(
519 gglwe: &'a GGLWECompressed<BE::BufRef<'b>>,
520 row: usize,
521 col: usize,
522) -> crate::layouts::compressed::GLWECompressedBackendRef<'a, BE> {
523 let rank_in: usize = gglwe.rank_in().into();
524 GLWECompressed {
525 base2k: gglwe.base2k,
526 rank: gglwe.rank_out,
527 data: mat_znx_at_backend_ref_from_ref::<BE>(&gglwe.data, row, col),
528 seed: gglwe.seed[rank_in * row + col],
529 }
530}