Skip to main content

poulpy_core/layouts/compressed/
ggsw.rs

1use poulpy_hal::{
2    layouts::{
3        Backend, Data, FillUniform, HostDataMut, HostDataRef, MatZnx, MatZnxToBackendMut, MatZnxToBackendRef, Module, ReaderFrom,
4        WriterTo, mat_znx_at_backend_mut_from_mut, mat_znx_at_backend_ref_from_ref, mat_znx_backend_mut_from_mut,
5        mat_znx_backend_ref_from_mut,
6    },
7    source::Source,
8};
9
10use crate::layouts::{
11    Base2K, Degree, Dnum, Dsize, GGSWInfos, GGSWToBackendMut, GLWEInfos, LWEInfos, Rank, TorusPrecision,
12    compressed::{
13        GLWECompressed, GLWECompressedBackendMut, GLWECompressedBackendRef, GLWECompressedViewMut, GLWECompressedViewRef,
14        GLWEDecompress,
15    },
16};
17use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
18use std::{
19    fmt,
20    ops::{Deref, DerefMut},
21};
22
23/// Seed-compressed GGSW (gadget GSW) ciphertext layout.
24///
25/// Stores only the body components of a [`GGSW`] ciphertext; the mask
26/// polynomials are regenerated deterministically from 32-byte PRNG
27/// seeds during decompression.
28#[derive(PartialEq, Eq, Clone)]
29pub struct GGSWCompressed<D: Data> {
30    pub(crate) data: MatZnx<D>,
31    pub(crate) k: TorusPrecision,
32    pub(crate) base2k: Base2K,
33    pub(crate) dsize: Dsize,
34    pub(crate) rank: Rank,
35    pub(crate) seed: Vec<[u8; 32]>,
36}
37
38pub struct GGSWCompressedBackendRef<'a, BE: Backend + 'a> {
39    inner: GGSWCompressed<BE::BufRef<'a>>,
40}
41
42impl<'a, BE: Backend + 'a> GGSWCompressedBackendRef<'a, BE> {
43    pub fn from_inner(inner: GGSWCompressed<BE::BufRef<'a>>) -> Self {
44        Self { inner }
45    }
46
47    pub fn into_inner(self) -> GGSWCompressed<BE::BufRef<'a>> {
48        self.inner
49    }
50
51    pub fn at_view(&self, row: usize, col: usize) -> GLWECompressedViewRef<'_, BE> {
52        GLWECompressedViewRef::from_inner(ggsw_compressed_at_backend_ref_from_ref::<BE>(&self.inner, row, col))
53    }
54}
55
56impl<'a, BE: Backend + 'a> Deref for GGSWCompressedBackendRef<'a, BE> {
57    type Target = GGSWCompressed<BE::BufRef<'a>>;
58
59    fn deref(&self) -> &Self::Target {
60        &self.inner
61    }
62}
63
64pub struct GGSWCompressedBackendMut<'a, BE: Backend + 'a> {
65    inner: GGSWCompressed<BE::BufMut<'a>>,
66}
67
68impl<'a, BE: Backend + 'a> GGSWCompressedBackendMut<'a, BE> {
69    pub fn from_inner(inner: GGSWCompressed<BE::BufMut<'a>>) -> Self {
70        Self { inner }
71    }
72
73    pub fn into_inner(self) -> GGSWCompressed<BE::BufMut<'a>> {
74        self.inner
75    }
76
77    pub fn at_view_mut(&mut self, row: usize, col: usize) -> GLWECompressedViewMut<'_, BE> {
78        GLWECompressedViewMut::from_inner(ggsw_compressed_at_backend_mut_from_mut::<BE>(&mut self.inner, row, col))
79    }
80}
81
82impl<'a, BE: Backend + 'a> Deref for GGSWCompressedBackendMut<'a, BE> {
83    type Target = GGSWCompressed<BE::BufMut<'a>>;
84
85    fn deref(&self) -> &Self::Target {
86        &self.inner
87    }
88}
89
90impl<'a, BE: Backend + 'a> DerefMut for GGSWCompressedBackendMut<'a, BE> {
91    fn deref_mut(&mut self) -> &mut Self::Target {
92        &mut self.inner
93    }
94}
95
96impl<'a, BE: Backend + 'a> LWEInfos for GGSWCompressedBackendRef<'a, BE> {
97    fn base2k(&self) -> Base2K {
98        self.inner.base2k()
99    }
100
101    fn n(&self) -> Degree {
102        self.inner.n()
103    }
104
105    fn size(&self) -> usize {
106        self.inner.size()
107    }
108}
109
110impl<'a, BE: Backend + 'a> GLWEInfos for GGSWCompressedBackendRef<'a, BE> {
111    fn rank(&self) -> Rank {
112        self.inner.rank()
113    }
114}
115
116impl<'a, BE: Backend + 'a> GGSWInfos for GGSWCompressedBackendRef<'a, BE> {
117    fn dnum(&self) -> Dnum {
118        self.inner.dnum()
119    }
120
121    fn dsize(&self) -> Dsize {
122        self.inner.dsize()
123    }
124}
125
126impl<'a, BE: Backend + 'a> LWEInfos for GGSWCompressedBackendMut<'a, BE> {
127    fn base2k(&self) -> Base2K {
128        self.inner.base2k()
129    }
130
131    fn n(&self) -> Degree {
132        self.inner.n()
133    }
134
135    fn size(&self) -> usize {
136        self.inner.size()
137    }
138}
139
140impl<'a, BE: Backend + 'a> GLWEInfos for GGSWCompressedBackendMut<'a, BE> {
141    fn rank(&self) -> Rank {
142        self.inner.rank()
143    }
144}
145
146impl<'a, BE: Backend + 'a> GGSWInfos for GGSWCompressedBackendMut<'a, BE> {
147    fn dnum(&self) -> Dnum {
148        self.inner.dnum()
149    }
150
151    fn dsize(&self) -> Dsize {
152        self.inner.dsize()
153    }
154}
155
156impl<'a, BE: Backend + 'a> GGSWCompressedSeedMut for GGSWCompressedBackendMut<'a, BE> {
157    fn seed_mut(&mut self) -> &mut Vec<[u8; 32]> {
158        &mut self.inner.seed
159    }
160}
161
/// Mutable access to the PRNG seed table of a compressed GGSW.
pub trait GGSWCompressedSeedMut {
    /// Returns the vector of 32-byte PRNG seeds for in-place modification.
    fn seed_mut(&mut self) -> &mut Vec<[u8; 32]>;
}
167
168impl<D: Data> GGSWCompressedSeedMut for GGSWCompressed<D> {
169    fn seed_mut(&mut self) -> &mut Vec<[u8; 32]> {
170        &mut self.seed
171    }
172}
173
/// Read-only access to the PRNG seed table of a compressed GGSW.
pub trait GGSWCompressedSeed {
    /// Returns the vector of 32-byte PRNG seeds.
    fn seed(&self) -> &Vec<[u8; 32]>;
}
179
180impl<D: HostDataRef> GGSWCompressedSeed for GGSWCompressed<D> {
181    fn seed(&self) -> &Vec<[u8; 32]> {
182        &self.seed
183    }
184}
185
186impl<D: Data> LWEInfos for GGSWCompressed<D> {
187    fn n(&self) -> Degree {
188        Degree(self.data.n() as u32)
189    }
190
191    fn base2k(&self) -> Base2K {
192        self.base2k
193    }
194
195    fn size(&self) -> usize {
196        self.data.size()
197    }
198}
199impl<D: Data> GLWEInfos for GGSWCompressed<D> {
200    fn rank(&self) -> Rank {
201        self.rank
202    }
203}
204
205impl<D: Data> GGSWInfos for GGSWCompressed<D> {
206    fn dsize(&self) -> Dsize {
207        self.dsize
208    }
209
210    fn dnum(&self) -> Dnum {
211        Dnum(self.data.rows() as u32)
212    }
213}
214
215impl<D: Data> LWEInfos for &GGSWCompressed<D> {
216    fn n(&self) -> Degree {
217        (**self).n()
218    }
219
220    fn base2k(&self) -> Base2K {
221        (**self).base2k()
222    }
223
224    fn size(&self) -> usize {
225        (**self).size()
226    }
227}
228
229impl<D: Data> GLWEInfos for &GGSWCompressed<D> {
230    fn rank(&self) -> Rank {
231        (**self).rank()
232    }
233}
234
235impl<D: Data> GGSWInfos for &GGSWCompressed<D> {
236    fn dsize(&self) -> Dsize {
237        (**self).dsize()
238    }
239
240    fn dnum(&self) -> Dnum {
241        (**self).dnum()
242    }
243}
244
245impl<D: HostDataRef> fmt::Debug for GGSWCompressed<D> {
246    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
247        write!(f, "{}", self.data)
248    }
249}
250
251impl<D: HostDataRef> fmt::Display for GGSWCompressed<D> {
252    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
253        write!(
254            f,
255            "(GGSWCompressed: base2k={} k={} dsize={}) {}",
256            self.base2k, self.k, self.dsize, self.data
257        )
258    }
259}
260
261impl<D: HostDataMut> FillUniform for GGSWCompressed<D> {
262    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
263        self.data.fill_uniform(log_bound, source);
264    }
265}
266
267impl GGSWCompressed<Vec<u8>> {
268    /// Allocates a new compressed GGSW by copying parameters from an existing info provider.
269    pub(crate) fn alloc_from_infos<A>(infos: &A) -> Self
270    where
271        A: GGSWInfos,
272    {
273        Self::alloc(
274            infos.n(),
275            infos.base2k(),
276            infos.max_k(),
277            infos.rank(),
278            infos.dnum(),
279            infos.dsize(),
280        )
281    }
282
283    /// Allocates a new compressed GGSW with the given parameters.
284    pub(crate) fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self {
285        let size: usize = k.0.div_ceil(base2k.0) as usize;
286        assert!(
287            size as u32 > dsize.0,
288            "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
289            dsize.0
290        );
291
292        assert!(
293            dnum.0 * dsize.0 <= size as u32,
294            "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
295            dnum.0,
296            dsize.0,
297        );
298
299        GGSWCompressed {
300            data: MatZnx::from_data(
301                poulpy_hal::layouts::HostBytesBackend::alloc_bytes(MatZnx::<Vec<u8>>::bytes_of(
302                    n.into(),
303                    dnum.into(),
304                    (rank + 1).into(),
305                    1,
306                    size,
307                )),
308                n.into(),
309                dnum.into(),
310                (rank + 1).into(),
311                1,
312                size,
313            ),
314            k,
315            base2k,
316            dsize,
317            rank,
318            seed: vec![[0u8; 32]; dnum.as_usize() * (rank.as_usize() + 1)],
319        }
320    }
321
322    /// Returns the serialized byte size by copying parameters from an existing info provider.
323    pub fn bytes_of_from_infos<A>(infos: &A) -> usize
324    where
325        A: GGSWInfos,
326    {
327        Self::bytes_of(
328            infos.n(),
329            infos.base2k(),
330            infos.max_k(),
331            infos.rank(),
332            infos.dnum(),
333            infos.dsize(),
334        )
335    }
336
337    /// Returns the serialized byte size for a compressed GGSW with the given parameters.
338    pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
339        let size: usize = k.0.div_ceil(base2k.0) as usize;
340        assert!(
341            size as u32 > dsize.0,
342            "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
343            dsize.0
344        );
345
346        assert!(
347            dnum.0 * dsize.0 <= size as u32,
348            "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
349            dnum.0,
350            dsize.0,
351        );
352
353        MatZnx::bytes_of(n.into(), dnum.into(), (rank + 1).into(), 1, k.0.div_ceil(base2k.0) as usize)
354    }
355}
356
357impl<D: HostDataRef> GGSWCompressed<D> {
358    /// Returns an immutably-borrowed compressed GLWE at the given row and column.
359    pub fn at(&self, row: usize, col: usize) -> GLWECompressed<&[u8]> {
360        let rank: usize = self.rank().into();
361        GLWECompressed {
362            data: self.data.at(row, col),
363            base2k: self.base2k,
364            rank: self.rank,
365            seed: self.seed[row * (rank + 1) + col],
366        }
367    }
368}
369
370impl<D: HostDataMut> GGSWCompressed<D> {
371    /// Returns a mutably-borrowed compressed GLWE at the given row and column.
372    pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECompressed<&mut [u8]> {
373        let rank: usize = self.rank().into();
374        GLWECompressed {
375            data: self.data.at_mut(row, col),
376            base2k: self.base2k,
377            rank: self.rank,
378            seed: self.seed[row * (rank + 1) + col],
379        }
380    }
381}
382
383impl<D: HostDataMut> ReaderFrom for GGSWCompressed<D> {
384    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
385        self.k = TorusPrecision(reader.read_u32::<LittleEndian>()?);
386        self.base2k = Base2K(reader.read_u32::<LittleEndian>()?);
387        self.dsize = Dsize(reader.read_u32::<LittleEndian>()?);
388        self.rank = Rank(reader.read_u32::<LittleEndian>()?);
389        let seed_len: usize = reader.read_u32::<LittleEndian>()? as usize;
390        self.seed = vec![[0u8; 32]; seed_len];
391        for s in &mut self.seed {
392            reader.read_exact(s)?;
393        }
394        self.data.read_from(reader)
395    }
396}
397
398impl<D: HostDataRef> WriterTo for GGSWCompressed<D> {
399    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
400        writer.write_u32::<LittleEndian>(self.k.into())?;
401        writer.write_u32::<LittleEndian>(self.base2k.into())?;
402        writer.write_u32::<LittleEndian>(self.dsize.into())?;
403        writer.write_u32::<LittleEndian>(self.rank.into())?;
404        writer.write_u32::<LittleEndian>(self.seed.len() as u32)?;
405        for s in &self.seed {
406            writer.write_all(s)?;
407        }
408        self.data.write_to(writer)
409    }
410}
411
412/// Trait for decompressing a [`GGSWCompressed`] into a standard [`GGSW`].
413///
414/// Iterates over every (row, column) entry, decompressing each
415/// compressed GLWE individually via [`GLWEDecompress`].
416pub trait GGSWDecompress
417where
418    Self: GLWEDecompress,
419{
420    /// Decompresses `other` into `res`.
421    fn decompress_ggsw<R, O>(&self, res: &mut R, other: &O)
422    where
423        R: GGSWToBackendMut<Self::Backend> + GGSWInfos,
424        O: GGSWCompressedToBackendRef<Self::Backend> + GGSWInfos,
425    {
426        let mut res = res.to_backend_mut();
427        let other = other.to_backend_ref();
428
429        assert_eq!(res.rank(), other.rank());
430        let dnum: usize = res.dnum().into();
431        let rank: usize = res.rank().into();
432
433        for row_i in 0..dnum {
434            for col_j in 0..rank + 1 {
435                let mut dst = res.at_view_mut(row_i, col_j);
436                let src = other.at_view(row_i, col_j);
437                self.decompress_glwe(&mut dst, &src);
438            }
439        }
440    }
441}
442
443impl<B: Backend> GGSWDecompress for Module<B> where Self: GLWEDecompress {}
444
445// module-only API: decompression is provided by `GGSWDecompress` on `Module`.
446
447pub trait GGSWCompressedToBackendRef<BE: Backend> {
448    fn to_backend_ref(&self) -> GGSWCompressedBackendRef<'_, BE>;
449}
450
451impl<BE: Backend> GGSWCompressedToBackendRef<BE> for GGSWCompressed<BE::OwnedBuf> {
452    fn to_backend_ref(&self) -> GGSWCompressedBackendRef<'_, BE> {
453        GGSWCompressedBackendRef::from_inner(GGSWCompressed {
454            k: self.max_k(),
455            base2k: self.base2k(),
456            dsize: self.dsize(),
457            rank: self.rank(),
458            seed: self.seed.clone(),
459            data: <MatZnx<BE::OwnedBuf> as MatZnxToBackendRef<BE>>::to_backend_ref(&self.data),
460        })
461    }
462}
463
464impl<'b, BE: Backend + 'b> GGSWCompressedToBackendRef<BE> for &GGSWCompressed<BE::BufRef<'b>> {
465    fn to_backend_ref(&self) -> GGSWCompressedBackendRef<'_, BE> {
466        GGSWCompressedBackendRef::from_inner(GGSWCompressed {
467            k: self.max_k(),
468            base2k: self.base2k(),
469            dsize: self.dsize(),
470            rank: self.rank(),
471            seed: self.seed.clone(),
472            data: poulpy_hal::layouts::mat_znx_backend_ref_from_ref::<BE>(&self.data),
473        })
474    }
475}
476
477impl<'b, BE: Backend + 'b> GGSWCompressedToBackendRef<BE> for &mut GGSWCompressed<BE::BufMut<'b>> {
478    fn to_backend_ref(&self) -> GGSWCompressedBackendRef<'_, BE> {
479        GGSWCompressedBackendRef::from_inner(GGSWCompressed {
480            k: self.max_k(),
481            base2k: self.base2k(),
482            dsize: self.dsize(),
483            rank: self.rank(),
484            seed: self.seed.clone(),
485            data: mat_znx_backend_ref_from_mut::<BE>(&self.data),
486        })
487    }
488}
489
490pub trait GGSWCompressedToBackendMut<BE: Backend>: GGSWCompressedToBackendRef<BE> {
491    fn to_backend_mut(&mut self) -> GGSWCompressedBackendMut<'_, BE>;
492}
493
494impl<BE: Backend> GGSWCompressedToBackendMut<BE> for GGSWCompressed<BE::OwnedBuf> {
495    fn to_backend_mut(&mut self) -> GGSWCompressedBackendMut<'_, BE> {
496        GGSWCompressedBackendMut::from_inner(GGSWCompressed {
497            k: self.max_k(),
498            base2k: self.base2k(),
499            dsize: self.dsize(),
500            rank: self.rank(),
501            seed: self.seed.clone(),
502            data: <MatZnx<BE::OwnedBuf> as MatZnxToBackendMut<BE>>::to_backend_mut(&mut self.data),
503        })
504    }
505}
506
507impl<'b, BE: Backend + 'b> GGSWCompressedToBackendMut<BE> for &mut GGSWCompressed<BE::BufMut<'b>> {
508    fn to_backend_mut(&mut self) -> GGSWCompressedBackendMut<'_, BE> {
509        GGSWCompressedBackendMut::from_inner(GGSWCompressed {
510            k: self.max_k(),
511            base2k: self.base2k(),
512            dsize: self.dsize(),
513            rank: self.rank(),
514            seed: self.seed.clone(),
515            data: mat_znx_backend_mut_from_mut::<BE>(&mut self.data),
516        })
517    }
518}
519
520fn ggsw_compressed_at_backend_mut_from_mut<'a, 'b, BE: Backend>(
521    ggsw: &'a mut GGSWCompressed<BE::BufMut<'b>>,
522    row: usize,
523    col: usize,
524) -> GLWECompressedBackendMut<'a, BE> {
525    let rank: usize = ggsw.rank().into();
526    GLWECompressed {
527        data: mat_znx_at_backend_mut_from_mut::<BE>(&mut ggsw.data, row, col),
528        base2k: ggsw.base2k,
529        rank: ggsw.rank,
530        seed: ggsw.seed[row * (rank + 1) + col],
531    }
532}
533
534fn ggsw_compressed_at_backend_ref_from_ref<'a, 'b, BE: Backend>(
535    ggsw: &'a GGSWCompressed<BE::BufRef<'b>>,
536    row: usize,
537    col: usize,
538) -> GLWECompressedBackendRef<'a, BE> {
539    let rank: usize = ggsw.rank().into();
540    GLWECompressed {
541        data: mat_znx_at_backend_ref_from_ref::<BE>(&ggsw.data, row, col),
542        base2k: ggsw.base2k,
543        rank: ggsw.rank,
544        seed: ggsw.seed[row * (rank + 1) + col],
545    }
546}