1use std::fmt;
2
3use poulpy_hal::{
4 layouts::{
5 Backend, Data, FillUniform, HostDataMut, HostDataRef, Module, ReaderFrom, TransferFrom, VecZnx, VecZnxToBackendMut,
6 VecZnxToBackendRef, WriterTo,
7 },
8 source::Source,
9};
10
11use crate::api::ModuleTransfer;
12use crate::layouts::{Base2K, Degree, TorusPrecision};
13use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
14
15pub trait LWEInfos {
20 fn n(&self) -> Degree;
22 fn log_n(&self) -> usize {
24 self.n().log2()
25 }
26 fn max_k(&self) -> TorusPrecision {
28 TorusPrecision(self.size() as u32 * self.base2k().as_u32())
29 }
30
31 fn base2k(&self) -> Base2K;
33 fn size(&self) -> usize;
35
36 fn lwe_layout(&self) -> LWELayout {
38 LWELayout {
39 n: self.n(),
40 k: self.max_k(),
41 base2k: self.base2k(),
42 }
43 }
44}
45
46pub trait SetLWEInfos {
48 fn set_base2k(&mut self, base2k: Base2K);
50}
51
52#[derive(PartialEq, Eq, Copy, Clone, Debug)]
54pub struct LWELayout {
55 pub n: Degree,
57 pub k: TorusPrecision,
59 pub base2k: Base2K,
61}
62
63impl LWEInfos for LWELayout {
64 fn base2k(&self) -> Base2K {
65 self.base2k
66 }
67
68 fn n(&self) -> Degree {
69 self.n
70 }
71
72 fn size(&self) -> usize {
73 self.k.as_usize().div_ceil(self.base2k.into())
74 }
75}
76
77#[derive(PartialEq, Eq, Clone)]
85pub struct LWE<D: Data> {
86 pub(crate) body: VecZnx<D>,
87 pub(crate) mask: VecZnx<D>,
88 pub(crate) base2k: Base2K,
89}
90
91pub type LWEBackendRef<'a, BE> = LWE<<BE as Backend>::BufRef<'a>>;
92pub type LWEBackendMut<'a, BE> = LWE<<BE as Backend>::BufMut<'a>>;
93
94impl<D: Data> LWEInfos for LWE<D> {
95 fn base2k(&self) -> Base2K {
96 self.base2k
97 }
98
99 fn n(&self) -> Degree {
100 Degree(self.mask.n() as u32)
101 }
102
103 fn size(&self) -> usize {
104 self.mask.size()
105 }
106}
107
108impl<D: Data> SetLWEInfos for LWE<D> {
109 fn set_base2k(&mut self, base2k: Base2K) {
110 self.base2k = base2k
111 }
112}
113
114impl<D: Data> LWE<D> {
115 pub fn body(&self) -> &VecZnx<D> {
117 &self.body
118 }
119
120 pub fn body_mut(&mut self) -> &mut VecZnx<D> {
122 &mut self.body
123 }
124
125 pub fn mask(&self) -> &VecZnx<D> {
127 &self.mask
128 }
129
130 pub fn mask_mut(&mut self) -> &mut VecZnx<D> {
132 &mut self.mask
133 }
134
135 fn validate_shape(&self) -> std::io::Result<()> {
136 if self.base2k.as_u32() == 0 {
137 return Err(std::io::Error::new(
138 std::io::ErrorKind::InvalidData,
139 "LWE base2k must be non-zero",
140 ));
141 }
142 if self.body.n() != 1 {
143 return Err(std::io::Error::new(
144 std::io::ErrorKind::InvalidData,
145 format!("LWE body degree must be 1, got {}", self.body.n()),
146 ));
147 }
148 if self.body.cols() != 1 {
149 return Err(std::io::Error::new(
150 std::io::ErrorKind::InvalidData,
151 format!("LWE body cols must be 1, got {}", self.body.cols()),
152 ));
153 }
154 if self.mask.cols() != 1 {
155 return Err(std::io::Error::new(
156 std::io::ErrorKind::InvalidData,
157 format!("LWE mask cols must be 1, got {}", self.mask.cols()),
158 ));
159 }
160 if self.body.size() != self.mask.size() {
161 return Err(std::io::Error::new(
162 std::io::ErrorKind::InvalidData,
163 format!(
164 "LWE body and mask sizes must match, got body.size={} mask.size={}",
165 self.body.size(),
166 self.mask.size()
167 ),
168 ));
169 }
170 if self.body.size() > self.body.max_size() {
171 return Err(std::io::Error::new(
172 std::io::ErrorKind::InvalidData,
173 format!(
174 "LWE body size must not exceed max_size, got size={} max_size={}",
175 self.body.size(),
176 self.body.max_size()
177 ),
178 ));
179 }
180 if self.mask.size() > self.mask.max_size() {
181 return Err(std::io::Error::new(
182 std::io::ErrorKind::InvalidData,
183 format!(
184 "LWE mask size must not exceed max_size, got size={} max_size={}",
185 self.mask.size(),
186 self.mask.max_size()
187 ),
188 ));
189 }
190 Ok(())
191 }
192}
193
194impl<D: HostDataRef> LWE<D> {
195 pub fn to_backend<BE, To>(&self, dst: &Module<To>) -> LWE<To::OwnedBuf>
198 where
199 BE: Backend<OwnedBuf = D>,
200 To: Backend,
201 To: TransferFrom<BE>,
202 {
203 dst.upload_lwe(self)
204 }
205}
206
207impl<D: Data> LWE<D> {
208 pub fn reinterpret<To>(self) -> LWE<To::OwnedBuf>
210 where
211 To: Backend<OwnedBuf = D>,
212 {
213 let body_shape = self.body.shape();
214 let body_data = self.body.data;
215 let mask_shape = self.mask.shape();
216 let mask_data = self.mask.data;
217 LWE {
218 body: VecZnx::from_data_with_max_size(
219 body_data,
220 body_shape.n(),
221 body_shape.cols(),
222 body_shape.size(),
223 body_shape.max_size(),
224 ),
225 mask: VecZnx::from_data_with_max_size(
226 mask_data,
227 mask_shape.n(),
228 mask_shape.cols(),
229 mask_shape.size(),
230 mask_shape.max_size(),
231 ),
232 base2k: self.base2k,
233 }
234 }
235}
236
237impl<D: HostDataRef> fmt::Debug for LWE<D> {
238 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
239 write!(f, "{self}")
240 }
241}
242
243impl<D: HostDataRef> fmt::Display for LWE<D> {
244 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245 write!(
246 f,
247 "LWE: base2k={} k={}: body={} mask={}",
248 self.base2k().0,
249 self.max_k().0,
250 self.body,
251 self.mask
252 )
253 }
254}
255
256impl<D: HostDataMut> FillUniform for LWE<D>
257where
258 VecZnx<D>: FillUniform,
259{
260 fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
261 self.mask.fill_uniform(log_bound, source);
262 }
263}
264
265impl LWE<Vec<u8>> {
266 pub(crate) fn alloc_from_infos<A>(infos: &A) -> Self
268 where
269 A: LWEInfos,
270 {
271 Self::alloc(infos.n(), infos.base2k(), infos.max_k())
272 }
273
274 pub(crate) fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision) -> Self {
280 let size: usize = k.0.div_ceil(base2k.0) as usize;
281 LWE {
282 body: VecZnx::from_data(
283 poulpy_hal::layouts::HostBytesBackend::alloc_bytes(VecZnx::<Vec<u8>>::bytes_of(1, 1, size)),
284 1,
285 1,
286 size,
287 ),
288 mask: VecZnx::from_data(
289 poulpy_hal::layouts::HostBytesBackend::alloc_bytes(VecZnx::<Vec<u8>>::bytes_of(n.as_usize(), 1, size)),
290 n.as_usize(),
291 1,
292 size,
293 ),
294 base2k,
295 }
296 }
297
298 pub fn bytes_of_from_infos<A>(infos: &A) -> usize
300 where
301 A: LWEInfos,
302 {
303 Self::bytes_of(infos.n(), infos.base2k(), infos.max_k())
304 }
305
306 pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision) -> usize {
312 let size: usize = k.0.div_ceil(base2k.0) as usize;
313 VecZnx::<Vec<u8>>::bytes_of(1, 1, size) + VecZnx::<Vec<u8>>::bytes_of(n.as_usize(), 1, size)
314 }
315}
316
317pub trait LWEToBackendRef<BE: Backend> {
318 fn to_backend_ref(&self) -> LWEBackendRef<'_, BE>;
319}
320
321impl<BE: Backend, D: Data> LWEToBackendRef<BE> for LWE<D>
322where
323 VecZnx<D>: VecZnxToBackendRef<BE>,
324{
325 fn to_backend_ref(&self) -> LWEBackendRef<'_, BE> {
326 LWE {
327 base2k: self.base2k,
328 body: self.body.to_backend_ref(),
329 mask: self.mask.to_backend_ref(),
330 }
331 }
332}
333
334pub trait LWEToBackendMut<BE: Backend>: LWEToBackendRef<BE> {
335 fn to_backend_mut(&mut self) -> LWEBackendMut<'_, BE>;
336}
337
338impl<BE: Backend, D: Data> LWEToBackendMut<BE> for LWE<D>
339where
340 VecZnx<D>: VecZnxToBackendRef<BE> + VecZnxToBackendMut<BE>,
341{
342 fn to_backend_mut(&mut self) -> LWEBackendMut<'_, BE> {
343 LWE {
344 base2k: self.base2k,
345 body: self.body.to_backend_mut(),
346 mask: self.mask.to_backend_mut(),
347 }
348 }
349}
350
351impl<'b, BE: Backend + 'b> LWEToBackendRef<BE> for &mut LWE<BE::BufMut<'b>> {
352 fn to_backend_ref(&self) -> LWEBackendRef<'_, BE> {
353 LWE {
354 base2k: self.base2k,
355 body: poulpy_hal::layouts::vec_znx_backend_ref_from_mut::<BE>(&self.body),
356 mask: poulpy_hal::layouts::vec_znx_backend_ref_from_mut::<BE>(&self.mask),
357 }
358 }
359}
360
361impl<'b, BE: Backend + 'b> LWEToBackendMut<BE> for &mut LWE<BE::BufMut<'b>> {
362 fn to_backend_mut(&mut self) -> LWEBackendMut<'_, BE> {
363 LWE {
364 base2k: self.base2k,
365 body: poulpy_hal::layouts::vec_znx_backend_mut_from_mut::<BE>(&mut self.body),
366 mask: poulpy_hal::layouts::vec_znx_backend_mut_from_mut::<BE>(&mut self.mask),
367 }
368 }
369}
370
371impl<D: HostDataMut> ReaderFrom for LWE<D> {
372 fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
374 self.base2k = Base2K(reader.read_u32::<LittleEndian>()?);
375 self.body.read_from(reader)?;
376 self.mask.read_from(reader)?;
377 self.validate_shape()
378 }
379}
380
381impl<D: HostDataRef> WriterTo for LWE<D> {
382 fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
384 writer.write_u32::<LittleEndian>(self.base2k.into())?;
385 self.body.write_to(writer)?;
386 self.mask.write_to(writer)
387 }
388}