poulpy_core/layouts/compressed/gglwe_ct.rs

1use poulpy_hal::{
2    api::{VecZnxCopy, VecZnxFillUniform},
3    layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, WriterTo, ZnxInfos},
4    source::Source,
5};
6
7use crate::layouts::{
8    Base2K, Degree, Digits, GGLWECiphertext, GGLWELayoutInfos, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision,
9    compressed::{Decompress, GLWECiphertextCompressed},
10};
11use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
12use std::fmt;
13
14#[derive(PartialEq, Eq, Clone)]
15pub struct GGLWECiphertextCompressed<D: Data> {
16    pub(crate) data: MatZnx<D>,
17    pub(crate) base2k: Base2K,
18    pub(crate) k: TorusPrecision,
19    pub(crate) rank_out: Rank,
20    pub(crate) digits: Digits,
21    pub(crate) seed: Vec<[u8; 32]>,
22}
23
24impl<D: Data> LWEInfos for GGLWECiphertextCompressed<D> {
25    fn n(&self) -> Degree {
26        Degree(self.data.n() as u32)
27    }
28
29    fn base2k(&self) -> Base2K {
30        self.base2k
31    }
32
33    fn k(&self) -> TorusPrecision {
34        self.k
35    }
36
37    fn size(&self) -> usize {
38        self.data.size()
39    }
40}
41impl<D: Data> GLWEInfos for GGLWECiphertextCompressed<D> {
42    fn rank(&self) -> Rank {
43        self.rank_out()
44    }
45}
46
47impl<D: Data> GGLWELayoutInfos for GGLWECiphertextCompressed<D> {
48    fn rank_in(&self) -> Rank {
49        Rank(self.data.cols_in() as u32)
50    }
51
52    fn rank_out(&self) -> Rank {
53        self.rank_out
54    }
55
56    fn digits(&self) -> Digits {
57        self.digits
58    }
59
60    fn rows(&self) -> Rows {
61        Rows(self.data.rows() as u32)
62    }
63}
64
65impl<D: DataRef> fmt::Debug for GGLWECiphertextCompressed<D> {
66    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
67        write!(f, "{self}")
68    }
69}
70
71impl<D: DataMut> FillUniform for GGLWECiphertextCompressed<D> {
72    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
73        self.data.fill_uniform(log_bound, source);
74    }
75}
76
77impl<D: DataRef> fmt::Display for GGLWECiphertextCompressed<D> {
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        write!(
80            f,
81            "(GGLWECiphertextCompressed: base2k={} k={} digits={}) {}",
82            self.base2k.0, self.k.0, self.digits.0, self.data
83        )
84    }
85}
86
87impl GGLWECiphertextCompressed<Vec<u8>> {
88    pub fn alloc<A>(infos: &A) -> Self
89    where
90        A: GGLWELayoutInfos,
91    {
92        Self::alloc_with(
93            infos.n(),
94            infos.base2k(),
95            infos.k(),
96            infos.rows(),
97            infos.digits(),
98            infos.rank_in(),
99            infos.rank_out(),
100        )
101    }
102
103    pub fn alloc_with(
104        n: Degree,
105        base2k: Base2K,
106        k: TorusPrecision,
107        rows: Rows,
108        digits: Digits,
109        rank_in: Rank,
110        rank_out: Rank,
111    ) -> Self {
112        let size: usize = k.0.div_ceil(base2k.0) as usize;
113        debug_assert!(
114            size as u32 > digits.0,
115            "invalid gglwe: ceil(k/base2k): {size} <= digits: {}",
116            digits.0
117        );
118
119        assert!(
120            rows.0 * digits.0 <= size as u32,
121            "invalid gglwe: rows: {} * digits:{} > ceil(k/base2k): {size}",
122            rows.0,
123            digits.0,
124        );
125
126        Self {
127            data: MatZnx::alloc(
128                n.into(),
129                rows.into(),
130                rank_in.into(),
131                1,
132                k.0.div_ceil(base2k.0) as usize,
133            ),
134            k,
135            base2k,
136            digits,
137            rank_out,
138            seed: vec![[0u8; 32]; (rows.0 * rank_in.0) as usize],
139        }
140    }
141
142    pub fn alloc_bytes<A>(infos: &A) -> usize
143    where
144        A: GGLWELayoutInfos,
145    {
146        Self::alloc_bytes_with(
147            infos.n(),
148            infos.base2k(),
149            infos.k(),
150            infos.rows(),
151            infos.digits(),
152            infos.rank_in(),
153            infos.rank_out(),
154        )
155    }
156
157    pub fn alloc_bytes_with(
158        n: Degree,
159        base2k: Base2K,
160        k: TorusPrecision,
161        rows: Rows,
162        digits: Digits,
163        rank_in: Rank,
164        _rank_out: Rank,
165    ) -> usize {
166        let size: usize = k.0.div_ceil(base2k.0) as usize;
167        debug_assert!(
168            size as u32 > digits.0,
169            "invalid gglwe: ceil(k/base2k): {size} <= digits: {}",
170            digits.0
171        );
172
173        assert!(
174            rows.0 * digits.0 <= size as u32,
175            "invalid gglwe: rows: {} * digits:{} > ceil(k/base2k): {size}",
176            rows.0,
177            digits.0,
178        );
179
180        MatZnx::alloc_bytes(
181            n.into(),
182            rows.into(),
183            rank_in.into(),
184            1,
185            k.0.div_ceil(base2k.0) as usize,
186        )
187    }
188}
189
190impl<D: DataRef> GGLWECiphertextCompressed<D> {
191    pub(crate) fn at(&self, row: usize, col: usize) -> GLWECiphertextCompressed<&[u8]> {
192        let rank_in: usize = self.rank_in().into();
193        GLWECiphertextCompressed {
194            data: self.data.at(row, col),
195            k: self.k,
196            base2k: self.base2k,
197            rank: self.rank_out,
198            seed: self.seed[rank_in * row + col],
199        }
200    }
201}
202
203impl<D: DataMut> GGLWECiphertextCompressed<D> {
204    pub(crate) fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertextCompressed<&mut [u8]> {
205        let rank_in: usize = self.rank_in().into();
206        GLWECiphertextCompressed {
207            k: self.k,
208            base2k: self.base2k,
209            rank: self.rank_out,
210            data: self.data.at_mut(row, col),
211            seed: self.seed[rank_in * row + col], // Warning: value is copied and not borrow mut
212        }
213    }
214}
215
216impl<D: DataMut> ReaderFrom for GGLWECiphertextCompressed<D> {
217    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
218        self.k = TorusPrecision(reader.read_u32::<LittleEndian>()?);
219        self.base2k = Base2K(reader.read_u32::<LittleEndian>()?);
220        self.digits = Digits(reader.read_u32::<LittleEndian>()?);
221        self.rank_out = Rank(reader.read_u32::<LittleEndian>()?);
222        let seed_len: u32 = reader.read_u32::<LittleEndian>()?;
223        self.seed = vec![[0u8; 32]; seed_len as usize];
224        for s in &mut self.seed {
225            reader.read_exact(s)?;
226        }
227        self.data.read_from(reader)
228    }
229}
230
231impl<D: DataRef> WriterTo for GGLWECiphertextCompressed<D> {
232    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
233        writer.write_u32::<LittleEndian>(self.k.into())?;
234        writer.write_u32::<LittleEndian>(self.base2k.into())?;
235        writer.write_u32::<LittleEndian>(self.digits.into())?;
236        writer.write_u32::<LittleEndian>(self.rank_out.into())?;
237        writer.write_u32::<LittleEndian>(self.seed.len() as u32)?;
238        for s in &self.seed {
239            writer.write_all(s)?;
240        }
241        self.data.write_to(writer)
242    }
243}
244
245impl<D: DataMut, B: Backend, DR: DataRef> Decompress<B, GGLWECiphertextCompressed<DR>> for GGLWECiphertext<D>
246where
247    Module<B>: VecZnxFillUniform + VecZnxCopy,
248{
249    fn decompress(&mut self, module: &Module<B>, other: &GGLWECiphertextCompressed<DR>) {
250        #[cfg(debug_assertions)]
251        {
252            assert_eq!(
253                self.n(),
254                other.n(),
255                "invalid receiver: self.n()={} != other.n()={}",
256                self.n(),
257                other.n()
258            );
259            assert_eq!(
260                self.size(),
261                other.size(),
262                "invalid receiver: self.size()={} != other.size()={}",
263                self.size(),
264                other.size()
265            );
266            assert_eq!(
267                self.rank_in(),
268                other.rank_in(),
269                "invalid receiver: self.rank_in()={} != other.rank_in()={}",
270                self.rank_in(),
271                other.rank_in()
272            );
273            assert_eq!(
274                self.rank_out(),
275                other.rank_out(),
276                "invalid receiver: self.rank_out()={} != other.rank_out()={}",
277                self.rank_out(),
278                other.rank_out()
279            );
280
281            assert_eq!(
282                self.rows(),
283                other.rows(),
284                "invalid receiver: self.rows()={} != other.rows()={}",
285                self.rows(),
286                other.rows()
287            );
288        }
289
290        let rank_in: usize = self.rank_in().into();
291        let rows: usize = self.rows().into();
292
293        (0..rank_in).for_each(|col_i| {
294            (0..rows).for_each(|row_i| {
295                self.at_mut(row_i, col_i)
296                    .decompress(module, &other.at(row_i, col_i));
297            });
298        });
299    }
300}