poulpy_core/layouts/compressed/ggsw.rs

1use poulpy_hal::{
2    layouts::{
3        Backend, Data, DataMut, DataRef, FillUniform, MatZnx, MatZnxToMut, MatZnxToRef, Module, ReaderFrom, WriterTo, ZnxInfos,
4    },
5    source::Source,
6};
7
8use crate::layouts::{
9    Base2K, Degree, Dnum, Dsize, GGSW, GGSWInfos, GGSWToMut, GLWEInfos, LWEInfos, Rank, TorusPrecision,
10    compressed::{GLWECompressed, GLWEDecompress},
11};
12use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
13use std::fmt;
14
15#[derive(PartialEq, Eq, Clone)]
16pub struct GGSWCompressed<D: Data> {
17    pub(crate) data: MatZnx<D>,
18    pub(crate) k: TorusPrecision,
19    pub(crate) base2k: Base2K,
20    pub(crate) dsize: Dsize,
21    pub(crate) rank: Rank,
22    pub(crate) seed: Vec<[u8; 32]>,
23}
24
25pub trait GGSWCompressedSeedMut {
26    fn seed_mut(&mut self) -> &mut Vec<[u8; 32]>;
27}
28
29impl<D: DataMut> GGSWCompressedSeedMut for GGSWCompressed<D> {
30    fn seed_mut(&mut self) -> &mut Vec<[u8; 32]> {
31        &mut self.seed
32    }
33}
34
35pub trait GGSWCompressedSeed {
36    fn seed(&self) -> &Vec<[u8; 32]>;
37}
38
39impl<D: DataRef> GGSWCompressedSeed for GGSWCompressed<D> {
40    fn seed(&self) -> &Vec<[u8; 32]> {
41        &self.seed
42    }
43}
44
45impl<D: Data> LWEInfos for GGSWCompressed<D> {
46    fn n(&self) -> Degree {
47        Degree(self.data.n() as u32)
48    }
49
50    fn base2k(&self) -> Base2K {
51        self.base2k
52    }
53
54    fn k(&self) -> TorusPrecision {
55        self.k
56    }
57    fn size(&self) -> usize {
58        self.data.size()
59    }
60}
61impl<D: Data> GLWEInfos for GGSWCompressed<D> {
62    fn rank(&self) -> Rank {
63        self.rank
64    }
65}
66
67impl<D: Data> GGSWInfos for GGSWCompressed<D> {
68    fn dsize(&self) -> Dsize {
69        self.dsize
70    }
71
72    fn dnum(&self) -> Dnum {
73        Dnum(self.data.rows() as u32)
74    }
75}
76
77impl<D: DataRef> fmt::Debug for GGSWCompressed<D> {
78    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
79        write!(f, "{}", self.data)
80    }
81}
82
83impl<D: DataRef> fmt::Display for GGSWCompressed<D> {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        write!(
86            f,
87            "(GGSWCompressed: base2k={} k={} dsize={}) {}",
88            self.base2k, self.k, self.dsize, self.data
89        )
90    }
91}
92
93impl<D: DataMut> FillUniform for GGSWCompressed<D> {
94    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
95        self.data.fill_uniform(log_bound, source);
96    }
97}
98
99impl GGSWCompressed<Vec<u8>> {
100    pub fn alloc_from_infos<A>(infos: &A) -> Self
101    where
102        A: GGSWInfos,
103    {
104        Self::alloc(
105            infos.n(),
106            infos.base2k(),
107            infos.k(),
108            infos.rank(),
109            infos.dnum(),
110            infos.dsize(),
111        )
112    }
113
114    pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self {
115        let size: usize = k.0.div_ceil(base2k.0) as usize;
116        assert!(
117            size as u32 > dsize.0,
118            "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
119            dsize.0
120        );
121
122        assert!(
123            dnum.0 * dsize.0 <= size as u32,
124            "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
125            dnum.0,
126            dsize.0,
127        );
128
129        GGSWCompressed {
130            data: MatZnx::alloc(
131                n.into(),
132                dnum.into(),
133                (rank + 1).into(),
134                1,
135                k.0.div_ceil(base2k.0) as usize,
136            ),
137            k,
138            base2k,
139            dsize,
140            rank,
141            seed: vec![[0u8; 32]; dnum.as_usize() * (rank.as_usize() + 1)],
142        }
143    }
144
145    pub fn bytes_of_from_infos<A>(infos: &A) -> usize
146    where
147        A: GGSWInfos,
148    {
149        Self::bytes_of(
150            infos.n(),
151            infos.base2k(),
152            infos.k(),
153            infos.rank(),
154            infos.dnum(),
155            infos.dsize(),
156        )
157    }
158
159    pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
160        let size: usize = k.0.div_ceil(base2k.0) as usize;
161        assert!(
162            size as u32 > dsize.0,
163            "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
164            dsize.0
165        );
166
167        assert!(
168            dnum.0 * dsize.0 <= size as u32,
169            "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
170            dnum.0,
171            dsize.0,
172        );
173
174        MatZnx::bytes_of(
175            n.into(),
176            dnum.into(),
177            (rank + 1).into(),
178            1,
179            k.0.div_ceil(base2k.0) as usize,
180        )
181    }
182}
183
184impl<D: DataRef> GGSWCompressed<D> {
185    pub fn at(&self, row: usize, col: usize) -> GLWECompressed<&[u8]> {
186        let rank: usize = self.rank().into();
187        GLWECompressed {
188            data: self.data.at(row, col),
189            k: self.k,
190            base2k: self.base2k,
191            rank: self.rank,
192            seed: self.seed[row * (rank + 1) + col],
193        }
194    }
195}
196
197impl<D: DataMut> GGSWCompressed<D> {
198    pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECompressed<&mut [u8]> {
199        let rank: usize = self.rank().into();
200        GLWECompressed {
201            data: self.data.at_mut(row, col),
202            k: self.k,
203            base2k: self.base2k,
204            rank: self.rank,
205            seed: self.seed[row * (rank + 1) + col],
206        }
207    }
208}
209
210impl<D: DataMut> ReaderFrom for GGSWCompressed<D> {
211    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
212        self.k = TorusPrecision(reader.read_u32::<LittleEndian>()?);
213        self.base2k = Base2K(reader.read_u32::<LittleEndian>()?);
214        self.dsize = Dsize(reader.read_u32::<LittleEndian>()?);
215        self.rank = Rank(reader.read_u32::<LittleEndian>()?);
216        let seed_len: usize = reader.read_u32::<LittleEndian>()? as usize;
217        self.seed = vec![[0u8; 32]; seed_len];
218        for s in &mut self.seed {
219            reader.read_exact(s)?;
220        }
221        self.data.read_from(reader)
222    }
223}
224
225impl<D: DataRef> WriterTo for GGSWCompressed<D> {
226    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
227        writer.write_u32::<LittleEndian>(self.k.into())?;
228        writer.write_u32::<LittleEndian>(self.base2k.into())?;
229        writer.write_u32::<LittleEndian>(self.dsize.into())?;
230        writer.write_u32::<LittleEndian>(self.rank.into())?;
231        writer.write_u32::<LittleEndian>(self.seed.len() as u32)?;
232        for s in &self.seed {
233            writer.write_all(s)?;
234        }
235        self.data.write_to(writer)
236    }
237}
238
239pub trait GGSWDecompress
240where
241    Self: GLWEDecompress,
242{
243    fn decompress_ggsw<R, O>(&self, res: &mut R, other: &O)
244    where
245        R: GGSWToMut,
246        O: GGSWCompressedToRef,
247    {
248        let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
249        let other: &GGSWCompressed<&[u8]> = &other.to_ref();
250
251        assert_eq!(res.rank(), other.rank());
252        let dnum: usize = res.dnum().into();
253        let rank: usize = res.rank().into();
254
255        for row_i in 0..dnum {
256            for col_j in 0..rank + 1 {
257                self.decompress_glwe(&mut res.at_mut(row_i, col_j), &other.at(row_i, col_j));
258            }
259        }
260    }
261}
262
263impl<B: Backend> GGSWDecompress for Module<B> where Self: GLWEDecompress {}
264
265impl<D: DataMut> GGSW<D> {
266    pub fn decompress<O, M>(&mut self, module: &M, other: &O)
267    where
268        O: GGSWCompressedToRef,
269        M: GGSWDecompress,
270    {
271        module.decompress_ggsw(self, other);
272    }
273}
274
275pub trait GGSWCompressedToMut {
276    fn to_mut(&mut self) -> GGSWCompressed<&mut [u8]>;
277}
278
279impl<D: DataMut> GGSWCompressedToMut for GGSWCompressed<D> {
280    fn to_mut(&mut self) -> GGSWCompressed<&mut [u8]> {
281        GGSWCompressed {
282            k: self.k(),
283            base2k: self.base2k(),
284            dsize: self.dsize(),
285            rank: self.rank(),
286            seed: self.seed.clone(),
287            data: self.data.to_mut(),
288        }
289    }
290}
291
292pub trait GGSWCompressedToRef {
293    fn to_ref(&self) -> GGSWCompressed<&[u8]>;
294}
295
296impl<D: DataRef> GGSWCompressedToRef for GGSWCompressed<D> {
297    fn to_ref(&self) -> GGSWCompressed<&[u8]> {
298        GGSWCompressed {
299            k: self.k(),
300            base2k: self.base2k(),
301            dsize: self.dsize(),
302            rank: self.rank(),
303            seed: self.seed.clone(),
304            data: self.data.to_ref(),
305        }
306    }
307}