poulpy_core/layouts/compressed/ggsw.rs
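
//! Seed-compressed GGSW ciphertext layout and its (de)compression helpers.
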
use poulpy_hal::{
    layouts::{
        Backend, Data, DataMut, DataRef, FillUniform, MatZnx, MatZnxToMut, MatZnxToRef, Module, ReaderFrom, WriterTo, ZnxInfos,
    },
    source::Source,
};

use crate::layouts::{
    Base2K, Degree, Dnum, Dsize, GGSW, GGSWInfos, GGSWToMut, GLWEInfos, LWEInfos, Rank, TorusPrecision,
    compressed::{GLWECompressed, GLWEDecompress},
};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use std::fmt;

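/// GGSW ciphertext in seed-compressed form.
///
/// The backing `MatZnx` stores a `dnum x (rank + 1)` matrix of compressed GLWE
/// entries, and `seed` holds the matching 32-byte seed for each entry
/// (indexed as `row * (rank + 1) + col`, see `at` / `at_mut`).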
#[derive(PartialEq, Eq, Clone)]
pub struct GGSWCompressed<D: Data> {
    pub(crate) data: MatZnx<D>,
    pub(crate) k: TorusPrecision,
    pub(crate) base2k: Base2K,
    pub(crate) dsize: Dsize,
    pub(crate) rank: Rank,
    pub(crate) seed: Vec<[u8; 32]>,
}

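/// Mutable access to the per-entry seeds of a compressed GGSW.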
pub trait GGSWCompressedSeedMut {
    fn seed_mut(&mut self) -> &mut Vec<[u8; 32]>;
}

impl<D: DataMut> GGSWCompressedSeedMut for GGSWCompressed<D> {
    fn seed_mut(&mut self) -> &mut Vec<[u8; 32]> {
        &mut self.seed
    }
}

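/// Read-only access to the per-entry seeds of a compressed GGSW.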
pub trait GGSWCompressedSeed {
    fn seed(&self) -> &Vec<[u8; 32]>;
}

impl<D: DataRef> GGSWCompressedSeed for GGSWCompressed<D> {
    fn seed(&self) -> &Vec<[u8; 32]> {
        &self.seed
    }
}

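// Metadata accessors: values are read from the stored fields or derived from the backing `MatZnx`.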
impl<D: Data> LWEInfos for GGSWCompressed<D> {
    fn n(&self) -> Degree {
        Degree(self.data.n() as u32)
    }

    fn base2k(&self) -> Base2K {
        self.base2k
    }

    fn k(&self) -> TorusPrecision {
        self.k
    }

    fn size(&self) -> usize {
        self.data.size()
    }
}

impl<D: Data> GLWEInfos for GGSWCompressed<D> {
    fn rank(&self) -> Rank {
        self.rank
    }
}

impl<D: Data> GGSWInfos for GGSWCompressed<D> {
    fn dsize(&self) -> Dsize {
        self.dsize
    }

    fn dnum(&self) -> Dnum {
        Dnum(self.data.rows() as u32)
    }
}

impl<D: DataRef> fmt::Debug for GGSWCompressed<D> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{}", self.data)
    }
}

impl<D: DataRef> fmt::Display for GGSWCompressed<D> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "(GGSWCompressed: base2k={} k={} dsize={}) {}",
            self.base2k, self.k, self.dsize, self.data
        )
    }
}

impl<D: DataMut> FillUniform for GGSWCompressed<D> {
    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
        self.data.fill_uniform(log_bound, source);
    }
}

impl GGSWCompressed<Vec<u8>> {
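    /// Allocates a `GGSWCompressed` with the parameters described by `infos`.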
    pub fn alloc_from_infos<A>(infos: &A) -> Self
    where
        A: GGSWInfos,
    {
        Self::alloc(
            infos.n(),
            infos.base2k(),
            infos.k(),
            infos.rank(),
            infos.dnum(),
            infos.dsize(),
        )
    }

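    /// Allocates a `GGSWCompressed` for the given parameters.
    ///
    /// Panics if `ceil(k / base2k) <= dsize` or if `dnum * dsize > ceil(k / base2k)`.
    ///
    /// Minimal sketch of a call with illustrative parameter values (marked
    /// `ignore` because the surrounding crate setup is assumed):
    ///
    /// ```ignore
    /// let ct = GGSWCompressed::alloc(
    ///     Degree(1024),
    ///     Base2K(12),
    ///     TorusPrecision(54),
    ///     Rank(1),
    ///     Dnum(4),
    ///     Dsize(1),
    /// );
    /// ```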
    pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self {
        let size: usize = k.0.div_ceil(base2k.0) as usize;
        assert!(
            size as u32 > dsize.0,
            "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
            dsize.0
        );

        assert!(
            dnum.0 * dsize.0 <= size as u32,
            "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
            dnum.0,
            dsize.0,
        );

        GGSWCompressed {
            data: MatZnx::alloc(
                n.into(),
                dnum.into(),
                (rank + 1).into(),
                1,
                k.0.div_ceil(base2k.0) as usize,
            ),
            k,
            base2k,
            dsize,
            rank,
            seed: vec![[0u8; 32]; dnum.as_usize() * (rank.as_usize() + 1)],
        }
    }

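    /// Returns the number of bytes needed to back a `GGSWCompressed` with the parameters described by `infos`.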
    pub fn bytes_of_from_infos<A>(infos: &A) -> usize
    where
        A: GGSWInfos,
    {
        Self::bytes_of(
            infos.n(),
            infos.base2k(),
            infos.k(),
            infos.rank(),
            infos.dnum(),
            infos.dsize(),
        )
    }

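    /// Returns the number of bytes needed to back a `GGSWCompressed` with the given parameters.
    ///
    /// Applies the same parameter checks as `alloc`.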
    pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
        let size: usize = k.0.div_ceil(base2k.0) as usize;
        assert!(
            size as u32 > dsize.0,
            "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
            dsize.0
        );

        assert!(
            dnum.0 * dsize.0 <= size as u32,
            "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
            dnum.0,
            dsize.0,
        );

        MatZnx::bytes_of(
            n.into(),
            dnum.into(),
            (rank + 1).into(),
            1,
            k.0.div_ceil(base2k.0) as usize,
        )
    }
}

impl<D: DataRef> GGSWCompressed<D> {
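    /// Returns the compressed GLWE entry at (`row`, `col`), borrowing its data and copying its seed.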
    pub fn at(&self, row: usize, col: usize) -> GLWECompressed<&[u8]> {
        let rank: usize = self.rank().into();
        GLWECompressed {
            data: self.data.at(row, col),
            k: self.k,
            base2k: self.base2k,
            rank: self.rank,
            seed: self.seed[row * (rank + 1) + col],
        }
    }
}

impl<D: DataMut> GGSWCompressed<D> {
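    /// Returns the compressed GLWE entry at (`row`, `col`) with mutable access to its data (the seed is copied by value).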
    pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECompressed<&mut [u8]> {
        let rank: usize = self.rank().into();
        GLWECompressed {
            data: self.data.at_mut(row, col),
            k: self.k,
            base2k: self.base2k,
            rank: self.rank,
            seed: self.seed[row * (rank + 1) + col],
        }
    }
}

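// Serialized layout: k, base2k, dsize, rank and the seed count as little-endian u32,
// followed by the raw 32-byte seeds and finally the `MatZnx` payload.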
impl<D: DataMut> ReaderFrom for GGSWCompressed<D> {
    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
        self.k = TorusPrecision(reader.read_u32::<LittleEndian>()?);
        self.base2k = Base2K(reader.read_u32::<LittleEndian>()?);
        self.dsize = Dsize(reader.read_u32::<LittleEndian>()?);
        self.rank = Rank(reader.read_u32::<LittleEndian>()?);
        let seed_len: usize = reader.read_u32::<LittleEndian>()? as usize;
        self.seed = vec![[0u8; 32]; seed_len];
        for s in &mut self.seed {
            reader.read_exact(s)?;
        }
        self.data.read_from(reader)
    }
}

impl<D: DataRef> WriterTo for GGSWCompressed<D> {
    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
        writer.write_u32::<LittleEndian>(self.k.into())?;
        writer.write_u32::<LittleEndian>(self.base2k.into())?;
        writer.write_u32::<LittleEndian>(self.dsize.into())?;
        writer.write_u32::<LittleEndian>(self.rank.into())?;
        writer.write_u32::<LittleEndian>(self.seed.len() as u32)?;
        for s in &self.seed {
            writer.write_all(s)?;
        }
        self.data.write_to(writer)
    }
}

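/// Decompression of a `GGSWCompressed` into a full `GGSW`, performed entry by entry
/// through `GLWEDecompress` on each of the `dnum * (rank + 1)` compressed GLWE rows.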
pub trait GGSWDecompress
where
    Self: GLWEDecompress,
{
    fn decompress_ggsw<R, O>(&self, res: &mut R, other: &O)
    where
        R: GGSWToMut,
        O: GGSWCompressedToRef,
    {
        let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
        let other: &GGSWCompressed<&[u8]> = &other.to_ref();

        assert_eq!(res.rank(), other.rank());
        let dnum: usize = res.dnum().into();
        let rank: usize = res.rank().into();

        for row_i in 0..dnum {
            for col_j in 0..rank + 1 {
                self.decompress_glwe(&mut res.at_mut(row_i, col_j), &other.at(row_i, col_j));
            }
        }
    }
}

impl<B: Backend> GGSWDecompress for Module<B> where Self: GLWEDecompress {}

impl<D: DataMut> GGSW<D> {
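    /// Decompresses `other` into `self` using `module`.
    ///
    /// Minimal usage sketch, assuming `ggsw` is a `GGSW` with matching parameters
    /// that was allocated elsewhere:
    ///
    /// ```ignore
    /// ggsw.decompress(&module, &ggsw_compressed);
    /// ```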
    pub fn decompress<O, M>(&mut self, module: &M, other: &O)
    where
        O: GGSWCompressedToRef,
        M: GGSWDecompress,
    {
        module.decompress_ggsw(self, other);
    }
}

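/// Re-borrows a compressed GGSW as a `GGSWCompressed<&mut [u8]>` view (the seeds are cloned).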
pub trait GGSWCompressedToMut {
    fn to_mut(&mut self) -> GGSWCompressed<&mut [u8]>;
}

impl<D: DataMut> GGSWCompressedToMut for GGSWCompressed<D> {
    fn to_mut(&mut self) -> GGSWCompressed<&mut [u8]> {
        GGSWCompressed {
            k: self.k(),
            base2k: self.base2k(),
            dsize: self.dsize(),
            rank: self.rank(),
            seed: self.seed.clone(),
            data: self.data.to_mut(),
        }
    }
}

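/// Re-borrows a compressed GGSW as a `GGSWCompressed<&[u8]>` view (the seeds are cloned).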
pub trait GGSWCompressedToRef {
    fn to_ref(&self) -> GGSWCompressed<&[u8]>;
}

impl<D: DataRef> GGSWCompressedToRef for GGSWCompressed<D> {
    fn to_ref(&self) -> GGSWCompressed<&[u8]> {
        GGSWCompressed {
            k: self.k(),
            base2k: self.base2k(),
            dsize: self.dsize(),
            rank: self.rank(),
            seed: self.seed.clone(),
            data: self.data.to_ref(),
        }
    }
}