1use poulpy_hal::{
2 layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, ReaderFrom, WriterTo, ZnxInfos},
3 source::Source,
4};
5use std::fmt;
6
7use crate::layouts::{Base2K, BuildError, Degree, Digits, GLWECiphertext, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision};
8
9pub trait GGSWInfos
10where
11 Self: GLWEInfos,
12{
13 fn rows(&self) -> Rows;
14 fn digits(&self) -> Digits;
15 fn layout(&self) -> GGSWCiphertextLayout {
16 GGSWCiphertextLayout {
17 n: self.n(),
18 base2k: self.base2k(),
19 k: self.k(),
20 rank: self.rank(),
21 rows: self.rows(),
22 digits: self.digits(),
23 }
24 }
25}
26
27#[derive(PartialEq, Eq, Copy, Clone, Debug)]
28pub struct GGSWCiphertextLayout {
29 pub n: Degree,
30 pub base2k: Base2K,
31 pub k: TorusPrecision,
32 pub rows: Rows,
33 pub digits: Digits,
34 pub rank: Rank,
35}
36
37impl LWEInfos for GGSWCiphertextLayout {
38 fn base2k(&self) -> Base2K {
39 self.base2k
40 }
41
42 fn k(&self) -> TorusPrecision {
43 self.k
44 }
45
46 fn n(&self) -> Degree {
47 self.n
48 }
49}
50impl GLWEInfos for GGSWCiphertextLayout {
51 fn rank(&self) -> Rank {
52 self.rank
53 }
54}
55
56impl GGSWInfos for GGSWCiphertextLayout {
57 fn digits(&self) -> Digits {
58 self.digits
59 }
60
61 fn rows(&self) -> Rows {
62 self.rows
63 }
64}
65
66#[derive(PartialEq, Eq, Clone)]
67pub struct GGSWCiphertext<D: Data> {
68 pub(crate) data: MatZnx<D>,
69 pub(crate) k: TorusPrecision,
70 pub(crate) base2k: Base2K,
71 pub(crate) digits: Digits,
72}
73
74impl<D: Data> LWEInfos for GGSWCiphertext<D> {
75 fn n(&self) -> Degree {
76 Degree(self.data.n() as u32)
77 }
78
79 fn base2k(&self) -> Base2K {
80 self.base2k
81 }
82
83 fn k(&self) -> TorusPrecision {
84 self.k
85 }
86
87 fn size(&self) -> usize {
88 self.data.size()
89 }
90}
91
92impl<D: Data> GLWEInfos for GGSWCiphertext<D> {
93 fn rank(&self) -> Rank {
94 Rank(self.data.cols_out() as u32 - 1)
95 }
96}
97
98impl<D: Data> GGSWInfos for GGSWCiphertext<D> {
99 fn digits(&self) -> Digits {
100 self.digits
101 }
102
103 fn rows(&self) -> Rows {
104 Rows(self.data.rows() as u32)
105 }
106}
107
108pub struct GGSWCiphertextBuilder<D: Data> {
109 data: Option<MatZnx<D>>,
110 base2k: Option<Base2K>,
111 k: Option<TorusPrecision>,
112 digits: Option<Digits>,
113}
114
115impl<D: Data> GGSWCiphertext<D> {
116 #[inline]
117 pub fn builder() -> GGSWCiphertextBuilder<D> {
118 GGSWCiphertextBuilder {
119 data: None,
120 base2k: None,
121 k: None,
122 digits: None,
123 }
124 }
125}
126
127impl GGSWCiphertextBuilder<Vec<u8>> {
128 #[inline]
129 pub fn layout<A>(mut self, infos: &A) -> Self
130 where
131 A: GGSWInfos,
132 {
133 debug_assert!(
134 infos.size() as u32 > infos.digits().0,
135 "invalid ggsw: ceil(k/base2k): {} <= digits: {}",
136 infos.size(),
137 infos.digits()
138 );
139
140 assert!(
141 infos.rows().0 * infos.digits().0 <= infos.size() as u32,
142 "invalid ggsw: rows: {} * digits:{} > ceil(k/base2k): {}",
143 infos.rows(),
144 infos.digits(),
145 infos.size(),
146 );
147
148 self.data = Some(MatZnx::alloc(
149 infos.n().into(),
150 infos.rows().into(),
151 (infos.rank() + 1).into(),
152 (infos.rank() + 1).into(),
153 infos.size(),
154 ));
155 self.base2k = Some(infos.base2k());
156 self.k = Some(infos.k());
157 self.digits = Some(infos.digits());
158 self
159 }
160}
161
162impl<D: Data> GGSWCiphertextBuilder<D> {
163 #[inline]
164 pub fn data(mut self, data: MatZnx<D>) -> Self {
165 self.data = Some(data);
166 self
167 }
168 #[inline]
169 pub fn base2k(mut self, base2k: Base2K) -> Self {
170 self.base2k = Some(base2k);
171 self
172 }
173 #[inline]
174 pub fn k(mut self, k: TorusPrecision) -> Self {
175 self.k = Some(k);
176 self
177 }
178
179 #[inline]
180 pub fn digits(mut self, digits: Digits) -> Self {
181 self.digits = Some(digits);
182 self
183 }
184
185 pub fn build(self) -> Result<GGSWCiphertext<D>, BuildError> {
186 let data: MatZnx<D> = self.data.ok_or(BuildError::MissingData)?;
187 let base2k: Base2K = self.base2k.ok_or(BuildError::MissingBase2K)?;
188 let k: TorusPrecision = self.k.ok_or(BuildError::MissingK)?;
189 let digits: Digits = self.digits.ok_or(BuildError::MissingDigits)?;
190
191 if base2k == 0_u32 {
192 return Err(BuildError::ZeroBase2K);
193 }
194
195 if digits == 0_u32 {
196 return Err(BuildError::ZeroBase2K);
197 }
198
199 if k == 0_u32 {
200 return Err(BuildError::ZeroTorusPrecision);
201 }
202
203 if data.n() == 0 {
204 return Err(BuildError::ZeroDegree);
205 }
206
207 if data.cols() == 0 {
208 return Err(BuildError::ZeroCols);
209 }
210
211 if data.size() == 0 {
212 return Err(BuildError::ZeroLimbs);
213 }
214
215 Ok(GGSWCiphertext {
216 data,
217 base2k,
218 k,
219 digits,
220 })
221 }
222}
223
224impl<D: DataRef> fmt::Debug for GGSWCiphertext<D> {
225 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
226 write!(f, "{}", self.data)
227 }
228}
229
230impl<D: DataRef> fmt::Display for GGSWCiphertext<D> {
231 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232 write!(
233 f,
234 "(GGSWCiphertext: k: {} base2k: {} digits: {}) {}",
235 self.k().0,
236 self.base2k().0,
237 self.digits().0,
238 self.data
239 )
240 }
241}
242
243impl<D: DataMut> FillUniform for GGSWCiphertext<D> {
244 fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
245 self.data.fill_uniform(log_bound, source);
246 }
247}
248
249impl<D: DataRef> GGSWCiphertext<D> {
250 pub fn at(&self, row: usize, col: usize) -> GLWECiphertext<&[u8]> {
251 GLWECiphertext::builder()
252 .data(self.data.at(row, col))
253 .base2k(self.base2k())
254 .k(self.k())
255 .build()
256 .unwrap()
257 }
258}
259
260impl<D: DataMut> GGSWCiphertext<D> {
261 pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertext<&mut [u8]> {
262 GLWECiphertext::builder()
263 .base2k(self.base2k())
264 .k(self.k())
265 .data(self.data.at_mut(row, col))
266 .build()
267 .unwrap()
268 }
269}
270
271impl GGSWCiphertext<Vec<u8>> {
272 pub fn alloc<A>(infos: &A) -> Self
273 where
274 A: GGSWInfos,
275 {
276 Self::alloc_with(
277 infos.n(),
278 infos.base2k(),
279 infos.k(),
280 infos.rows(),
281 infos.digits(),
282 infos.rank(),
283 )
284 }
285
286 pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> Self {
287 let size: usize = k.0.div_ceil(base2k.0) as usize;
288 debug_assert!(
289 size as u32 > digits.0,
290 "invalid ggsw: ceil(k/base2k): {size} <= digits: {}",
291 digits.0
292 );
293
294 assert!(
295 rows.0 * digits.0 <= size as u32,
296 "invalid ggsw: rows: {} * digits:{} > ceil(k/base2k): {size}",
297 rows.0,
298 digits.0,
299 );
300
301 Self {
302 data: MatZnx::alloc(
303 n.into(),
304 rows.into(),
305 (rank + 1).into(),
306 (rank + 1).into(),
307 k.0.div_ceil(base2k.0) as usize,
308 ),
309 k,
310 base2k,
311 digits,
312 }
313 }
314
315 pub fn alloc_bytes<A>(infos: &A) -> usize
316 where
317 A: GGSWInfos,
318 {
319 Self::alloc_bytes_with(
320 infos.n(),
321 infos.base2k(),
322 infos.k(),
323 infos.rows(),
324 infos.digits(),
325 infos.rank(),
326 )
327 }
328
329 pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> usize {
330 let size: usize = k.0.div_ceil(base2k.0) as usize;
331 debug_assert!(
332 size as u32 > digits.0,
333 "invalid ggsw: ceil(k/base2k): {size} <= digits: {}",
334 digits.0
335 );
336
337 assert!(
338 rows.0 * digits.0 <= size as u32,
339 "invalid ggsw: rows: {} * digits:{} > ceil(k/base2k): {size}",
340 rows.0,
341 digits.0,
342 );
343
344 MatZnx::alloc_bytes(
345 n.into(),
346 rows.into(),
347 (rank + 1).into(),
348 (rank + 1).into(),
349 k.0.div_ceil(base2k.0) as usize,
350 )
351 }
352}
353
354use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
355
356impl<D: DataMut> ReaderFrom for GGSWCiphertext<D> {
357 fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
358 self.k = TorusPrecision(reader.read_u32::<LittleEndian>()?);
359 self.base2k = Base2K(reader.read_u32::<LittleEndian>()?);
360 self.digits = Digits(reader.read_u32::<LittleEndian>()?);
361 self.data.read_from(reader)
362 }
363}
364
365impl<D: DataRef> WriterTo for GGSWCiphertext<D> {
366 fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
367 writer.write_u32::<LittleEndian>(self.k.into())?;
368 writer.write_u32::<LittleEndian>(self.base2k.into())?;
369 writer.write_u32::<LittleEndian>(self.digits.into())?;
370 self.data.write_to(writer)
371 }
372}