1use poulpy_hal::{
2 layouts::{
3 Backend, Data, FillUniform, HostDataMut, HostDataRef, MatZnx, MatZnxToBackendMut, MatZnxToBackendRef, Module, ReaderFrom,
4 WriterTo, mat_znx_at_backend_mut_from_mut, mat_znx_at_backend_ref_from_ref, mat_znx_backend_mut_from_mut,
5 mat_znx_backend_ref_from_mut,
6 },
7 source::Source,
8};
9
10use crate::layouts::{
11 Base2K, Degree, Dnum, Dsize, GGSWInfos, GGSWToBackendMut, GLWEInfos, LWEInfos, Rank, TorusPrecision,
12 compressed::{
13 GLWECompressed, GLWECompressedBackendMut, GLWECompressedBackendRef, GLWECompressedViewMut, GLWECompressedViewRef,
14 GLWEDecompress,
15 },
16};
17use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
18use std::{
19 fmt,
20 ops::{Deref, DerefMut},
21};
22
23#[derive(PartialEq, Eq, Clone)]
29pub struct GGSWCompressed<D: Data> {
30 pub(crate) data: MatZnx<D>,
31 pub(crate) k: TorusPrecision,
32 pub(crate) base2k: Base2K,
33 pub(crate) dsize: Dsize,
34 pub(crate) rank: Rank,
35 pub(crate) seed: Vec<[u8; 32]>,
36}
37
38pub struct GGSWCompressedBackendRef<'a, BE: Backend + 'a> {
39 inner: GGSWCompressed<BE::BufRef<'a>>,
40}
41
42impl<'a, BE: Backend + 'a> GGSWCompressedBackendRef<'a, BE> {
43 pub fn from_inner(inner: GGSWCompressed<BE::BufRef<'a>>) -> Self {
44 Self { inner }
45 }
46
47 pub fn into_inner(self) -> GGSWCompressed<BE::BufRef<'a>> {
48 self.inner
49 }
50
51 pub fn at_view(&self, row: usize, col: usize) -> GLWECompressedViewRef<'_, BE> {
52 GLWECompressedViewRef::from_inner(ggsw_compressed_at_backend_ref_from_ref::<BE>(&self.inner, row, col))
53 }
54}
55
56impl<'a, BE: Backend + 'a> Deref for GGSWCompressedBackendRef<'a, BE> {
57 type Target = GGSWCompressed<BE::BufRef<'a>>;
58
59 fn deref(&self) -> &Self::Target {
60 &self.inner
61 }
62}
63
64pub struct GGSWCompressedBackendMut<'a, BE: Backend + 'a> {
65 inner: GGSWCompressed<BE::BufMut<'a>>,
66}
67
68impl<'a, BE: Backend + 'a> GGSWCompressedBackendMut<'a, BE> {
69 pub fn from_inner(inner: GGSWCompressed<BE::BufMut<'a>>) -> Self {
70 Self { inner }
71 }
72
73 pub fn into_inner(self) -> GGSWCompressed<BE::BufMut<'a>> {
74 self.inner
75 }
76
77 pub fn at_view_mut(&mut self, row: usize, col: usize) -> GLWECompressedViewMut<'_, BE> {
78 GLWECompressedViewMut::from_inner(ggsw_compressed_at_backend_mut_from_mut::<BE>(&mut self.inner, row, col))
79 }
80}
81
82impl<'a, BE: Backend + 'a> Deref for GGSWCompressedBackendMut<'a, BE> {
83 type Target = GGSWCompressed<BE::BufMut<'a>>;
84
85 fn deref(&self) -> &Self::Target {
86 &self.inner
87 }
88}
89
90impl<'a, BE: Backend + 'a> DerefMut for GGSWCompressedBackendMut<'a, BE> {
91 fn deref_mut(&mut self) -> &mut Self::Target {
92 &mut self.inner
93 }
94}
95
96impl<'a, BE: Backend + 'a> LWEInfos for GGSWCompressedBackendRef<'a, BE> {
97 fn base2k(&self) -> Base2K {
98 self.inner.base2k()
99 }
100
101 fn n(&self) -> Degree {
102 self.inner.n()
103 }
104
105 fn size(&self) -> usize {
106 self.inner.size()
107 }
108}
109
110impl<'a, BE: Backend + 'a> GLWEInfos for GGSWCompressedBackendRef<'a, BE> {
111 fn rank(&self) -> Rank {
112 self.inner.rank()
113 }
114}
115
116impl<'a, BE: Backend + 'a> GGSWInfos for GGSWCompressedBackendRef<'a, BE> {
117 fn dnum(&self) -> Dnum {
118 self.inner.dnum()
119 }
120
121 fn dsize(&self) -> Dsize {
122 self.inner.dsize()
123 }
124}
125
126impl<'a, BE: Backend + 'a> LWEInfos for GGSWCompressedBackendMut<'a, BE> {
127 fn base2k(&self) -> Base2K {
128 self.inner.base2k()
129 }
130
131 fn n(&self) -> Degree {
132 self.inner.n()
133 }
134
135 fn size(&self) -> usize {
136 self.inner.size()
137 }
138}
139
140impl<'a, BE: Backend + 'a> GLWEInfos for GGSWCompressedBackendMut<'a, BE> {
141 fn rank(&self) -> Rank {
142 self.inner.rank()
143 }
144}
145
146impl<'a, BE: Backend + 'a> GGSWInfos for GGSWCompressedBackendMut<'a, BE> {
147 fn dnum(&self) -> Dnum {
148 self.inner.dnum()
149 }
150
151 fn dsize(&self) -> Dsize {
152 self.inner.dsize()
153 }
154}
155
156impl<'a, BE: Backend + 'a> GGSWCompressedSeedMut for GGSWCompressedBackendMut<'a, BE> {
157 fn seed_mut(&mut self) -> &mut Vec<[u8; 32]> {
158 &mut self.inner.seed
159 }
160}
161
/// Mutable access to the per-entry seeds of a compressed GGSW ciphertext.
pub trait GGSWCompressedSeedMut {
    /// Returns a mutable reference to the vector of 32-byte seeds.
    fn seed_mut(&mut self) -> &mut Vec<[u8; 32]>;
}
167
168impl<D: Data> GGSWCompressedSeedMut for GGSWCompressed<D> {
169 fn seed_mut(&mut self) -> &mut Vec<[u8; 32]> {
170 &mut self.seed
171 }
172}
173
/// Read-only access to the per-entry seeds of a compressed GGSW ciphertext.
pub trait GGSWCompressedSeed {
    /// Returns a shared reference to the vector of 32-byte seeds.
    fn seed(&self) -> &Vec<[u8; 32]>;
}
179
180impl<D: HostDataRef> GGSWCompressedSeed for GGSWCompressed<D> {
181 fn seed(&self) -> &Vec<[u8; 32]> {
182 &self.seed
183 }
184}
185
186impl<D: Data> LWEInfos for GGSWCompressed<D> {
187 fn n(&self) -> Degree {
188 Degree(self.data.n() as u32)
189 }
190
191 fn base2k(&self) -> Base2K {
192 self.base2k
193 }
194
195 fn size(&self) -> usize {
196 self.data.size()
197 }
198}
199impl<D: Data> GLWEInfos for GGSWCompressed<D> {
200 fn rank(&self) -> Rank {
201 self.rank
202 }
203}
204
205impl<D: Data> GGSWInfos for GGSWCompressed<D> {
206 fn dsize(&self) -> Dsize {
207 self.dsize
208 }
209
210 fn dnum(&self) -> Dnum {
211 Dnum(self.data.rows() as u32)
212 }
213}
214
215impl<D: Data> LWEInfos for &GGSWCompressed<D> {
216 fn n(&self) -> Degree {
217 (**self).n()
218 }
219
220 fn base2k(&self) -> Base2K {
221 (**self).base2k()
222 }
223
224 fn size(&self) -> usize {
225 (**self).size()
226 }
227}
228
229impl<D: Data> GLWEInfos for &GGSWCompressed<D> {
230 fn rank(&self) -> Rank {
231 (**self).rank()
232 }
233}
234
235impl<D: Data> GGSWInfos for &GGSWCompressed<D> {
236 fn dsize(&self) -> Dsize {
237 (**self).dsize()
238 }
239
240 fn dnum(&self) -> Dnum {
241 (**self).dnum()
242 }
243}
244
245impl<D: HostDataRef> fmt::Debug for GGSWCompressed<D> {
246 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
247 write!(f, "{}", self.data)
248 }
249}
250
251impl<D: HostDataRef> fmt::Display for GGSWCompressed<D> {
252 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253 write!(
254 f,
255 "(GGSWCompressed: base2k={} k={} dsize={}) {}",
256 self.base2k, self.k, self.dsize, self.data
257 )
258 }
259}
260
261impl<D: HostDataMut> FillUniform for GGSWCompressed<D> {
262 fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
263 self.data.fill_uniform(log_bound, source);
264 }
265}
266
267impl GGSWCompressed<Vec<u8>> {
268 pub(crate) fn alloc_from_infos<A>(infos: &A) -> Self
270 where
271 A: GGSWInfos,
272 {
273 Self::alloc(
274 infos.n(),
275 infos.base2k(),
276 infos.max_k(),
277 infos.rank(),
278 infos.dnum(),
279 infos.dsize(),
280 )
281 }
282
283 pub(crate) fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self {
285 let size: usize = k.0.div_ceil(base2k.0) as usize;
286 assert!(
287 size as u32 > dsize.0,
288 "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
289 dsize.0
290 );
291
292 assert!(
293 dnum.0 * dsize.0 <= size as u32,
294 "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
295 dnum.0,
296 dsize.0,
297 );
298
299 GGSWCompressed {
300 data: MatZnx::from_data(
301 poulpy_hal::layouts::HostBytesBackend::alloc_bytes(MatZnx::<Vec<u8>>::bytes_of(
302 n.into(),
303 dnum.into(),
304 (rank + 1).into(),
305 1,
306 size,
307 )),
308 n.into(),
309 dnum.into(),
310 (rank + 1).into(),
311 1,
312 size,
313 ),
314 k,
315 base2k,
316 dsize,
317 rank,
318 seed: vec![[0u8; 32]; dnum.as_usize() * (rank.as_usize() + 1)],
319 }
320 }
321
322 pub fn bytes_of_from_infos<A>(infos: &A) -> usize
324 where
325 A: GGSWInfos,
326 {
327 Self::bytes_of(
328 infos.n(),
329 infos.base2k(),
330 infos.max_k(),
331 infos.rank(),
332 infos.dnum(),
333 infos.dsize(),
334 )
335 }
336
337 pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
339 let size: usize = k.0.div_ceil(base2k.0) as usize;
340 assert!(
341 size as u32 > dsize.0,
342 "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
343 dsize.0
344 );
345
346 assert!(
347 dnum.0 * dsize.0 <= size as u32,
348 "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
349 dnum.0,
350 dsize.0,
351 );
352
353 MatZnx::bytes_of(n.into(), dnum.into(), (rank + 1).into(), 1, k.0.div_ceil(base2k.0) as usize)
354 }
355}
356
357impl<D: HostDataRef> GGSWCompressed<D> {
358 pub fn at(&self, row: usize, col: usize) -> GLWECompressed<&[u8]> {
360 let rank: usize = self.rank().into();
361 GLWECompressed {
362 data: self.data.at(row, col),
363 base2k: self.base2k,
364 rank: self.rank,
365 seed: self.seed[row * (rank + 1) + col],
366 }
367 }
368}
369
370impl<D: HostDataMut> GGSWCompressed<D> {
371 pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECompressed<&mut [u8]> {
373 let rank: usize = self.rank().into();
374 GLWECompressed {
375 data: self.data.at_mut(row, col),
376 base2k: self.base2k,
377 rank: self.rank,
378 seed: self.seed[row * (rank + 1) + col],
379 }
380 }
381}
382
383impl<D: HostDataMut> ReaderFrom for GGSWCompressed<D> {
384 fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
385 self.k = TorusPrecision(reader.read_u32::<LittleEndian>()?);
386 self.base2k = Base2K(reader.read_u32::<LittleEndian>()?);
387 self.dsize = Dsize(reader.read_u32::<LittleEndian>()?);
388 self.rank = Rank(reader.read_u32::<LittleEndian>()?);
389 let seed_len: usize = reader.read_u32::<LittleEndian>()? as usize;
390 self.seed = vec![[0u8; 32]; seed_len];
391 for s in &mut self.seed {
392 reader.read_exact(s)?;
393 }
394 self.data.read_from(reader)
395 }
396}
397
398impl<D: HostDataRef> WriterTo for GGSWCompressed<D> {
399 fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
400 writer.write_u32::<LittleEndian>(self.k.into())?;
401 writer.write_u32::<LittleEndian>(self.base2k.into())?;
402 writer.write_u32::<LittleEndian>(self.dsize.into())?;
403 writer.write_u32::<LittleEndian>(self.rank.into())?;
404 writer.write_u32::<LittleEndian>(self.seed.len() as u32)?;
405 for s in &self.seed {
406 writer.write_all(s)?;
407 }
408 self.data.write_to(writer)
409 }
410}
411
412pub trait GGSWDecompress
417where
418 Self: GLWEDecompress,
419{
420 fn decompress_ggsw<R, O>(&self, res: &mut R, other: &O)
422 where
423 R: GGSWToBackendMut<Self::Backend> + GGSWInfos,
424 O: GGSWCompressedToBackendRef<Self::Backend> + GGSWInfos,
425 {
426 let mut res = res.to_backend_mut();
427 let other = other.to_backend_ref();
428
429 assert_eq!(res.rank(), other.rank());
430 let dnum: usize = res.dnum().into();
431 let rank: usize = res.rank().into();
432
433 for row_i in 0..dnum {
434 for col_j in 0..rank + 1 {
435 let mut dst = res.at_view_mut(row_i, col_j);
436 let src = other.at_view(row_i, col_j);
437 self.decompress_glwe(&mut dst, &src);
438 }
439 }
440 }
441}
442
443impl<B: Backend> GGSWDecompress for Module<B> where Self: GLWEDecompress {}
444
445pub trait GGSWCompressedToBackendRef<BE: Backend> {
448 fn to_backend_ref(&self) -> GGSWCompressedBackendRef<'_, BE>;
449}
450
451impl<BE: Backend> GGSWCompressedToBackendRef<BE> for GGSWCompressed<BE::OwnedBuf> {
452 fn to_backend_ref(&self) -> GGSWCompressedBackendRef<'_, BE> {
453 GGSWCompressedBackendRef::from_inner(GGSWCompressed {
454 k: self.max_k(),
455 base2k: self.base2k(),
456 dsize: self.dsize(),
457 rank: self.rank(),
458 seed: self.seed.clone(),
459 data: <MatZnx<BE::OwnedBuf> as MatZnxToBackendRef<BE>>::to_backend_ref(&self.data),
460 })
461 }
462}
463
464impl<'b, BE: Backend + 'b> GGSWCompressedToBackendRef<BE> for &GGSWCompressed<BE::BufRef<'b>> {
465 fn to_backend_ref(&self) -> GGSWCompressedBackendRef<'_, BE> {
466 GGSWCompressedBackendRef::from_inner(GGSWCompressed {
467 k: self.max_k(),
468 base2k: self.base2k(),
469 dsize: self.dsize(),
470 rank: self.rank(),
471 seed: self.seed.clone(),
472 data: poulpy_hal::layouts::mat_znx_backend_ref_from_ref::<BE>(&self.data),
473 })
474 }
475}
476
477impl<'b, BE: Backend + 'b> GGSWCompressedToBackendRef<BE> for &mut GGSWCompressed<BE::BufMut<'b>> {
478 fn to_backend_ref(&self) -> GGSWCompressedBackendRef<'_, BE> {
479 GGSWCompressedBackendRef::from_inner(GGSWCompressed {
480 k: self.max_k(),
481 base2k: self.base2k(),
482 dsize: self.dsize(),
483 rank: self.rank(),
484 seed: self.seed.clone(),
485 data: mat_znx_backend_ref_from_mut::<BE>(&self.data),
486 })
487 }
488}
489
490pub trait GGSWCompressedToBackendMut<BE: Backend>: GGSWCompressedToBackendRef<BE> {
491 fn to_backend_mut(&mut self) -> GGSWCompressedBackendMut<'_, BE>;
492}
493
494impl<BE: Backend> GGSWCompressedToBackendMut<BE> for GGSWCompressed<BE::OwnedBuf> {
495 fn to_backend_mut(&mut self) -> GGSWCompressedBackendMut<'_, BE> {
496 GGSWCompressedBackendMut::from_inner(GGSWCompressed {
497 k: self.max_k(),
498 base2k: self.base2k(),
499 dsize: self.dsize(),
500 rank: self.rank(),
501 seed: self.seed.clone(),
502 data: <MatZnx<BE::OwnedBuf> as MatZnxToBackendMut<BE>>::to_backend_mut(&mut self.data),
503 })
504 }
505}
506
507impl<'b, BE: Backend + 'b> GGSWCompressedToBackendMut<BE> for &mut GGSWCompressed<BE::BufMut<'b>> {
508 fn to_backend_mut(&mut self) -> GGSWCompressedBackendMut<'_, BE> {
509 GGSWCompressedBackendMut::from_inner(GGSWCompressed {
510 k: self.max_k(),
511 base2k: self.base2k(),
512 dsize: self.dsize(),
513 rank: self.rank(),
514 seed: self.seed.clone(),
515 data: mat_znx_backend_mut_from_mut::<BE>(&mut self.data),
516 })
517 }
518}
519
520fn ggsw_compressed_at_backend_mut_from_mut<'a, 'b, BE: Backend>(
521 ggsw: &'a mut GGSWCompressed<BE::BufMut<'b>>,
522 row: usize,
523 col: usize,
524) -> GLWECompressedBackendMut<'a, BE> {
525 let rank: usize = ggsw.rank().into();
526 GLWECompressed {
527 data: mat_znx_at_backend_mut_from_mut::<BE>(&mut ggsw.data, row, col),
528 base2k: ggsw.base2k,
529 rank: ggsw.rank,
530 seed: ggsw.seed[row * (rank + 1) + col],
531 }
532}
533
534fn ggsw_compressed_at_backend_ref_from_ref<'a, 'b, BE: Backend>(
535 ggsw: &'a GGSWCompressed<BE::BufRef<'b>>,
536 row: usize,
537 col: usize,
538) -> GLWECompressedBackendRef<'a, BE> {
539 let rank: usize = ggsw.rank().into();
540 GLWECompressed {
541 data: mat_znx_at_backend_ref_from_ref::<BE>(&ggsw.data, row, col),
542 base2k: ggsw.base2k,
543 rank: ggsw.rank,
544 seed: ggsw.seed[row * (rank + 1) + col],
545 }
546}