// poulpy_core/layouts/compressed/ggsw_ct.rs
use poulpy_hal::{
2 api::{VecZnxCopy, VecZnxFillUniform},
3 layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, WriterTo, ZnxInfos},
4 source::Source,
5};
6
7use crate::layouts::{
8 Base2K, Degree, Digits, GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision,
9 compressed::{Decompress, GLWECiphertextCompressed},
10};
11use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
12use std::fmt;
13
14#[derive(PartialEq, Eq, Clone)]
15pub struct GGSWCiphertextCompressed<D: Data> {
16 pub(crate) data: MatZnx<D>,
17 pub(crate) k: TorusPrecision,
18 pub(crate) base2k: Base2K,
19 pub(crate) digits: Digits,
20 pub(crate) rank: Rank,
21 pub(crate) seed: Vec<[u8; 32]>,
22}
23
24impl<D: Data> LWEInfos for GGSWCiphertextCompressed<D> {
25 fn n(&self) -> Degree {
26 Degree(self.data.n() as u32)
27 }
28
29 fn base2k(&self) -> Base2K {
30 self.base2k
31 }
32
33 fn k(&self) -> TorusPrecision {
34 self.k
35 }
36 fn size(&self) -> usize {
37 self.data.size()
38 }
39}
40impl<D: Data> GLWEInfos for GGSWCiphertextCompressed<D> {
41 fn rank(&self) -> Rank {
42 self.rank
43 }
44}
45
46impl<D: Data> GGSWInfos for GGSWCiphertextCompressed<D> {
47 fn digits(&self) -> Digits {
48 self.digits
49 }
50
51 fn rows(&self) -> Rows {
52 Rows(self.data.rows() as u32)
53 }
54}
55
56impl<D: DataRef> fmt::Debug for GGSWCiphertextCompressed<D> {
57 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
58 write!(f, "{}", self.data)
59 }
60}
61
62impl<D: DataRef> fmt::Display for GGSWCiphertextCompressed<D> {
63 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64 write!(
65 f,
66 "(GGSWCiphertextCompressed: base2k={} k={} digits={}) {}",
67 self.base2k, self.k, self.digits, self.data
68 )
69 }
70}
71
72impl<D: DataMut> FillUniform for GGSWCiphertextCompressed<D> {
73 fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
74 self.data.fill_uniform(log_bound, source);
75 }
76}
77
78impl GGSWCiphertextCompressed<Vec<u8>> {
79 pub fn alloc<A>(infos: &A) -> Self
80 where
81 A: GGSWInfos,
82 {
83 Self::alloc_with(
84 infos.n(),
85 infos.base2k(),
86 infos.k(),
87 infos.rows(),
88 infos.digits(),
89 infos.rank(),
90 )
91 }
92
93 pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> Self {
94 let size: usize = k.0.div_ceil(base2k.0) as usize;
95 debug_assert!(
96 size as u32 > digits.0,
97 "invalid ggsw: ceil(k/base2k): {size} <= digits: {}",
98 digits.0
99 );
100
101 assert!(
102 rows.0 * digits.0 <= size as u32,
103 "invalid ggsw: rows: {} * digits:{} > ceil(k/base2k): {size}",
104 rows.0,
105 digits.0,
106 );
107
108 Self {
109 data: MatZnx::alloc(
110 n.into(),
111 rows.into(),
112 (rank + 1).into(),
113 1,
114 k.0.div_ceil(base2k.0) as usize,
115 ),
116 k,
117 base2k,
118 digits,
119 rank,
120 seed: Vec::new(),
121 }
122 }
123
124 pub fn alloc_bytes<A>(infos: &A) -> usize
125 where
126 A: GGSWInfos,
127 {
128 Self::alloc_bytes_with(
129 infos.n(),
130 infos.base2k(),
131 infos.k(),
132 infos.rows(),
133 infos.digits(),
134 infos.rank(),
135 )
136 }
137
138 pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> usize {
139 let size: usize = k.0.div_ceil(base2k.0) as usize;
140 debug_assert!(
141 size as u32 > digits.0,
142 "invalid ggsw: ceil(k/base2k): {size} <= digits: {}",
143 digits.0
144 );
145
146 assert!(
147 rows.0 * digits.0 <= size as u32,
148 "invalid ggsw: rows: {} * digits:{} > ceil(k/base2k): {size}",
149 rows.0,
150 digits.0,
151 );
152
153 MatZnx::alloc_bytes(
154 n.into(),
155 rows.into(),
156 (rank + 1).into(),
157 1,
158 k.0.div_ceil(base2k.0) as usize,
159 )
160 }
161}
162
163impl<D: DataRef> GGSWCiphertextCompressed<D> {
164 pub fn at(&self, row: usize, col: usize) -> GLWECiphertextCompressed<&[u8]> {
165 let rank: usize = self.rank().into();
166 GLWECiphertextCompressed {
167 data: self.data.at(row, col),
168 k: self.k,
169 base2k: self.base2k,
170 rank: self.rank,
171 seed: self.seed[row * (rank + 1) + col],
172 }
173 }
174}
175
176impl<D: DataMut> GGSWCiphertextCompressed<D> {
177 pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertextCompressed<&mut [u8]> {
178 let rank: usize = self.rank().into();
179 GLWECiphertextCompressed {
180 data: self.data.at_mut(row, col),
181 k: self.k,
182 base2k: self.base2k,
183 rank: self.rank,
184 seed: self.seed[row * (rank + 1) + col],
185 }
186 }
187}
188
189impl<D: DataMut> ReaderFrom for GGSWCiphertextCompressed<D> {
190 fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
191 self.k = TorusPrecision(reader.read_u32::<LittleEndian>()?);
192 self.base2k = Base2K(reader.read_u32::<LittleEndian>()?);
193 self.digits = Digits(reader.read_u32::<LittleEndian>()?);
194 self.rank = Rank(reader.read_u32::<LittleEndian>()?);
195 let seed_len: usize = reader.read_u32::<LittleEndian>()? as usize;
196 self.seed = vec![[0u8; 32]; seed_len];
197 for s in &mut self.seed {
198 reader.read_exact(s)?;
199 }
200 self.data.read_from(reader)
201 }
202}
203
204impl<D: DataRef> WriterTo for GGSWCiphertextCompressed<D> {
205 fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
206 writer.write_u32::<LittleEndian>(self.k.into())?;
207 writer.write_u32::<LittleEndian>(self.base2k.into())?;
208 writer.write_u32::<LittleEndian>(self.digits.into())?;
209 writer.write_u32::<LittleEndian>(self.rank.into())?;
210 writer.write_u32::<LittleEndian>(self.seed.len() as u32)?;
211 for s in &self.seed {
212 writer.write_all(s)?;
213 }
214 self.data.write_to(writer)
215 }
216}
217
218impl<D: DataMut, B: Backend, DR: DataRef> Decompress<B, GGSWCiphertextCompressed<DR>> for GGSWCiphertext<D>
219where
220 Module<B>: VecZnxFillUniform + VecZnxCopy,
221{
222 fn decompress(&mut self, module: &Module<B>, other: &GGSWCiphertextCompressed<DR>) {
223 #[cfg(debug_assertions)]
224 {
225 assert_eq!(self.rank(), other.rank())
226 }
227
228 let rows: usize = self.rows().into();
229 let rank: usize = self.rank().into();
230 (0..rows).for_each(|row_i| {
231 (0..rank + 1).for_each(|col_j| {
232 self.at_mut(row_i, col_j)
233 .decompress(module, &other.at(row_i, col_j));
234 });
235 });
236 }
237}