Skip to main content

poulpy_core/layouts/
lwe.rs

1use std::fmt;
2
3use poulpy_hal::{
4    layouts::{
5        Backend, Data, FillUniform, HostDataMut, HostDataRef, Module, ReaderFrom, TransferFrom, VecZnx, VecZnxToBackendMut,
6        VecZnxToBackendRef, WriterTo,
7    },
8    source::Source,
9};
10
11use crate::api::ModuleTransfer;
12use crate::layouts::{Base2K, Degree, TorusPrecision};
13use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
14
15/// Trait providing the parameter accessors for an LWE ciphertext.
16///
17/// An LWE ciphertext is a scalar (non-polynomial) ciphertext consisting of
18/// a body `b` and a mask `(a_1, ..., a_n)`.
19pub trait LWEInfos {
20    /// Returns the LWE dimension, i.e. the number of mask elements (= GLWE ring degree N).
21    fn n(&self) -> Degree;
22    /// Returns `log2(n)`.
23    fn log_n(&self) -> usize {
24        self.n().log2()
25    }
26    /// Returns the maximum torus precision representable with the current limb decomposition.
27    fn max_k(&self) -> TorusPrecision {
28        TorusPrecision(self.size() as u32 * self.base2k().as_u32())
29    }
30
31    /// Returns the base-2-log of the limb width used for the RNS/CRT representation.
32    fn base2k(&self) -> Base2K;
33    /// Returns the number of limbs, i.e. `ceil(k / base2k)`.
34    fn size(&self) -> usize;
35
36    /// Returns a plain-data [`LWELayout`] snapshot of the current parameters.
37    fn lwe_layout(&self) -> LWELayout {
38        LWELayout {
39            n: self.n(),
40            k: self.max_k(),
41            base2k: self.base2k(),
42        }
43    }
44}
45
46/// Trait for mutating LWE parameters in place.
47pub trait SetLWEInfos {
48    /// Sets the limb width `base2k`.
49    fn set_base2k(&mut self, base2k: Base2K);
50}
51
52/// Plain-data snapshot of the parameters that describe an [`LWE`] ciphertext.
53#[derive(PartialEq, Eq, Copy, Clone, Debug)]
54pub struct LWELayout {
55    /// Ring degree (LWE dimension).
56    pub n: Degree,
57    /// Torus precision.
58    pub k: TorusPrecision,
59    /// Base-2-log of the limb width.
60    pub base2k: Base2K,
61}
62
63impl LWEInfos for LWELayout {
64    fn base2k(&self) -> Base2K {
65        self.base2k
66    }
67
68    fn n(&self) -> Degree {
69        self.n
70    }
71
72    fn size(&self) -> usize {
73        self.k.as_usize().div_ceil(self.base2k.into())
74    }
75}
76
77/// A scalar (non-polynomial) LWE ciphertext.
78///
79/// Stored as two separate [`VecZnx`] buffers:
80/// - `body`: degree-0 polynomial (n = 1) holding the scalar body `b`.
81/// - `mask`: degree-n polynomial (n = lwe_dim) holding the mask `(a_1, ..., a_n)`.
82///
83/// `D: Data` is the storage backend (e.g. `Vec<u8>`, `&[u8]`, `&mut [u8]`).
84#[derive(PartialEq, Eq, Clone)]
85pub struct LWE<D: Data> {
86    pub(crate) body: VecZnx<D>,
87    pub(crate) mask: VecZnx<D>,
88    pub(crate) base2k: Base2K,
89}
90
91pub type LWEBackendRef<'a, BE> = LWE<<BE as Backend>::BufRef<'a>>;
92pub type LWEBackendMut<'a, BE> = LWE<<BE as Backend>::BufMut<'a>>;
93
94impl<D: Data> LWEInfos for LWE<D> {
95    fn base2k(&self) -> Base2K {
96        self.base2k
97    }
98
99    fn n(&self) -> Degree {
100        Degree(self.mask.n() as u32)
101    }
102
103    fn size(&self) -> usize {
104        self.mask.size()
105    }
106}
107
108impl<D: Data> SetLWEInfos for LWE<D> {
109    fn set_base2k(&mut self, base2k: Base2K) {
110        self.base2k = base2k
111    }
112}
113
114impl<D: Data> LWE<D> {
115    /// Returns a shared reference to the body [`VecZnx`] (n = 1).
116    pub fn body(&self) -> &VecZnx<D> {
117        &self.body
118    }
119
120    /// Returns a mutable reference to the body [`VecZnx`] (n = 1).
121    pub fn body_mut(&mut self) -> &mut VecZnx<D> {
122        &mut self.body
123    }
124
125    /// Returns a shared reference to the mask [`VecZnx`] (n = lwe_dim).
126    pub fn mask(&self) -> &VecZnx<D> {
127        &self.mask
128    }
129
130    /// Returns a mutable reference to the mask [`VecZnx`] (n = lwe_dim).
131    pub fn mask_mut(&mut self) -> &mut VecZnx<D> {
132        &mut self.mask
133    }
134
135    fn validate_shape(&self) -> std::io::Result<()> {
136        if self.base2k.as_u32() == 0 {
137            return Err(std::io::Error::new(
138                std::io::ErrorKind::InvalidData,
139                "LWE base2k must be non-zero",
140            ));
141        }
142        if self.body.n() != 1 {
143            return Err(std::io::Error::new(
144                std::io::ErrorKind::InvalidData,
145                format!("LWE body degree must be 1, got {}", self.body.n()),
146            ));
147        }
148        if self.body.cols() != 1 {
149            return Err(std::io::Error::new(
150                std::io::ErrorKind::InvalidData,
151                format!("LWE body cols must be 1, got {}", self.body.cols()),
152            ));
153        }
154        if self.mask.cols() != 1 {
155            return Err(std::io::Error::new(
156                std::io::ErrorKind::InvalidData,
157                format!("LWE mask cols must be 1, got {}", self.mask.cols()),
158            ));
159        }
160        if self.body.size() != self.mask.size() {
161            return Err(std::io::Error::new(
162                std::io::ErrorKind::InvalidData,
163                format!(
164                    "LWE body and mask sizes must match, got body.size={} mask.size={}",
165                    self.body.size(),
166                    self.mask.size()
167                ),
168            ));
169        }
170        if self.body.size() > self.body.max_size() {
171            return Err(std::io::Error::new(
172                std::io::ErrorKind::InvalidData,
173                format!(
174                    "LWE body size must not exceed max_size, got size={} max_size={}",
175                    self.body.size(),
176                    self.body.max_size()
177                ),
178            ));
179        }
180        if self.mask.size() > self.mask.max_size() {
181            return Err(std::io::Error::new(
182                std::io::ErrorKind::InvalidData,
183                format!(
184                    "LWE mask size must not exceed max_size, got size={} max_size={}",
185                    self.mask.size(),
186                    self.mask.max_size()
187                ),
188            ));
189        }
190        Ok(())
191    }
192}
193
194impl<D: HostDataRef> LWE<D> {
195    /// Copies this ciphertext's backing bytes into an owned buffer of
196    /// backend `To`, routing via host bytes.
197    pub fn to_backend<BE, To>(&self, dst: &Module<To>) -> LWE<To::OwnedBuf>
198    where
199        BE: Backend<OwnedBuf = D>,
200        To: Backend,
201        To: TransferFrom<BE>,
202    {
203        dst.upload_lwe(self)
204    }
205}
206
207impl<D: Data> LWE<D> {
208    /// Zero-cost rename when both backends share the same `OwnedBuf`.
209    pub fn reinterpret<To>(self) -> LWE<To::OwnedBuf>
210    where
211        To: Backend<OwnedBuf = D>,
212    {
213        let body_shape = self.body.shape();
214        let body_data = self.body.data;
215        let mask_shape = self.mask.shape();
216        let mask_data = self.mask.data;
217        LWE {
218            body: VecZnx::from_data_with_max_size(
219                body_data,
220                body_shape.n(),
221                body_shape.cols(),
222                body_shape.size(),
223                body_shape.max_size(),
224            ),
225            mask: VecZnx::from_data_with_max_size(
226                mask_data,
227                mask_shape.n(),
228                mask_shape.cols(),
229                mask_shape.size(),
230                mask_shape.max_size(),
231            ),
232            base2k: self.base2k,
233        }
234    }
235}
236
237impl<D: HostDataRef> fmt::Debug for LWE<D> {
238    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
239        write!(f, "{self}")
240    }
241}
242
243impl<D: HostDataRef> fmt::Display for LWE<D> {
244    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
245        write!(
246            f,
247            "LWE: base2k={} k={}: body={} mask={}",
248            self.base2k().0,
249            self.max_k().0,
250            self.body,
251            self.mask
252        )
253    }
254}
255
256impl<D: HostDataMut> FillUniform for LWE<D>
257where
258    VecZnx<D>: FillUniform,
259{
260    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
261        self.mask.fill_uniform(log_bound, source);
262    }
263}
264
265impl LWE<Vec<u8>> {
266    /// Allocates a new [`LWE`] with the given parameters.
267    pub(crate) fn alloc_from_infos<A>(infos: &A) -> Self
268    where
269        A: LWEInfos,
270    {
271        Self::alloc(infos.n(), infos.base2k(), infos.max_k())
272    }
273
274    /// Allocates a new [`LWE`] with the given parameters.
275    ///
276    /// * `n` -- LWE dimension (mask length).
277    /// * `base2k` -- base-2-log of the limb width.
278    /// * `k` -- torus precision.
279    pub(crate) fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision) -> Self {
280        let size: usize = k.0.div_ceil(base2k.0) as usize;
281        LWE {
282            body: VecZnx::from_data(
283                poulpy_hal::layouts::HostBytesBackend::alloc_bytes(VecZnx::<Vec<u8>>::bytes_of(1, 1, size)),
284                1,
285                1,
286                size,
287            ),
288            mask: VecZnx::from_data(
289                poulpy_hal::layouts::HostBytesBackend::alloc_bytes(VecZnx::<Vec<u8>>::bytes_of(n.as_usize(), 1, size)),
290                n.as_usize(),
291                1,
292                size,
293            ),
294            base2k,
295        }
296    }
297
298    /// Returns the byte count required for an [`LWE`] with the given parameters.
299    pub fn bytes_of_from_infos<A>(infos: &A) -> usize
300    where
301        A: LWEInfos,
302    {
303        Self::bytes_of(infos.n(), infos.base2k(), infos.max_k())
304    }
305
306    /// Returns the byte count required for an [`LWE`] with the given parameters.
307    ///
308    /// * `n` -- LWE dimension (mask length).
309    /// * `base2k` -- base-2-log of the limb width.
310    /// * `k` -- torus precision.
311    pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision) -> usize {
312        let size: usize = k.0.div_ceil(base2k.0) as usize;
313        VecZnx::<Vec<u8>>::bytes_of(1, 1, size) + VecZnx::<Vec<u8>>::bytes_of(n.as_usize(), 1, size)
314    }
315}
316
317pub trait LWEToBackendRef<BE: Backend> {
318    fn to_backend_ref(&self) -> LWEBackendRef<'_, BE>;
319}
320
321impl<BE: Backend, D: Data> LWEToBackendRef<BE> for LWE<D>
322where
323    VecZnx<D>: VecZnxToBackendRef<BE>,
324{
325    fn to_backend_ref(&self) -> LWEBackendRef<'_, BE> {
326        LWE {
327            base2k: self.base2k,
328            body: self.body.to_backend_ref(),
329            mask: self.mask.to_backend_ref(),
330        }
331    }
332}
333
334pub trait LWEToBackendMut<BE: Backend>: LWEToBackendRef<BE> {
335    fn to_backend_mut(&mut self) -> LWEBackendMut<'_, BE>;
336}
337
338impl<BE: Backend, D: Data> LWEToBackendMut<BE> for LWE<D>
339where
340    VecZnx<D>: VecZnxToBackendRef<BE> + VecZnxToBackendMut<BE>,
341{
342    fn to_backend_mut(&mut self) -> LWEBackendMut<'_, BE> {
343        LWE {
344            base2k: self.base2k,
345            body: self.body.to_backend_mut(),
346            mask: self.mask.to_backend_mut(),
347        }
348    }
349}
350
351impl<'b, BE: Backend + 'b> LWEToBackendRef<BE> for &mut LWE<BE::BufMut<'b>> {
352    fn to_backend_ref(&self) -> LWEBackendRef<'_, BE> {
353        LWE {
354            base2k: self.base2k,
355            body: poulpy_hal::layouts::vec_znx_backend_ref_from_mut::<BE>(&self.body),
356            mask: poulpy_hal::layouts::vec_znx_backend_ref_from_mut::<BE>(&self.mask),
357        }
358    }
359}
360
361impl<'b, BE: Backend + 'b> LWEToBackendMut<BE> for &mut LWE<BE::BufMut<'b>> {
362    fn to_backend_mut(&mut self) -> LWEBackendMut<'_, BE> {
363        LWE {
364            base2k: self.base2k,
365            body: poulpy_hal::layouts::vec_znx_backend_mut_from_mut::<BE>(&mut self.body),
366            mask: poulpy_hal::layouts::vec_znx_backend_mut_from_mut::<BE>(&mut self.mask),
367        }
368    }
369}
370
371impl<D: HostDataMut> ReaderFrom for LWE<D> {
372    /// Deserialises an [`LWE`] in little-endian binary format.
373    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
374        self.base2k = Base2K(reader.read_u32::<LittleEndian>()?);
375        self.body.read_from(reader)?;
376        self.mask.read_from(reader)?;
377        self.validate_shape()
378    }
379}
380
381impl<D: HostDataRef> WriterTo for LWE<D> {
382    /// Serialises the [`LWE`] in little-endian binary format.
383    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
384        writer.write_u32::<LittleEndian>(self.base2k.into())?;
385        self.body.write_to(writer)?;
386        self.mask.write_to(writer)
387    }
388}