// poulpy_core/layouts/compressed/ggsw_ct.rs

1use poulpy_hal::{
2    api::{VecZnxCopy, VecZnxFillUniform},
3    layouts::{Backend, Data, DataMut, DataRef, FillUniform, MatZnx, Module, ReaderFrom, WriterTo, ZnxInfos},
4    source::Source,
5};
6
7use crate::layouts::{
8    Base2K, Degree, Digits, GGSWCiphertext, GGSWInfos, GLWEInfos, LWEInfos, Rank, Rows, TorusPrecision,
9    compressed::{Decompress, GLWECiphertextCompressed},
10};
11use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
12use std::fmt;
13
14#[derive(PartialEq, Eq, Clone)]
15pub struct GGSWCiphertextCompressed<D: Data> {
16    pub(crate) data: MatZnx<D>,
17    pub(crate) k: TorusPrecision,
18    pub(crate) base2k: Base2K,
19    pub(crate) digits: Digits,
20    pub(crate) rank: Rank,
21    pub(crate) seed: Vec<[u8; 32]>,
22}
23
24impl<D: Data> LWEInfos for GGSWCiphertextCompressed<D> {
25    fn n(&self) -> Degree {
26        Degree(self.data.n() as u32)
27    }
28
29    fn base2k(&self) -> Base2K {
30        self.base2k
31    }
32
33    fn k(&self) -> TorusPrecision {
34        self.k
35    }
36    fn size(&self) -> usize {
37        self.data.size()
38    }
39}
40impl<D: Data> GLWEInfos for GGSWCiphertextCompressed<D> {
41    fn rank(&self) -> Rank {
42        self.rank
43    }
44}
45
46impl<D: Data> GGSWInfos for GGSWCiphertextCompressed<D> {
47    fn digits(&self) -> Digits {
48        self.digits
49    }
50
51    fn rows(&self) -> Rows {
52        Rows(self.data.rows() as u32)
53    }
54}
55
56impl<D: DataRef> fmt::Debug for GGSWCiphertextCompressed<D> {
57    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
58        write!(f, "{}", self.data)
59    }
60}
61
62impl<D: DataRef> fmt::Display for GGSWCiphertextCompressed<D> {
63    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
64        write!(
65            f,
66            "(GGSWCiphertextCompressed: base2k={} k={} digits={}) {}",
67            self.base2k, self.k, self.digits, self.data
68        )
69    }
70}
71
72impl<D: DataMut> FillUniform for GGSWCiphertextCompressed<D> {
73    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
74        self.data.fill_uniform(log_bound, source);
75    }
76}
77
78impl GGSWCiphertextCompressed<Vec<u8>> {
79    pub fn alloc<A>(infos: &A) -> Self
80    where
81        A: GGSWInfos,
82    {
83        Self::alloc_with(
84            infos.n(),
85            infos.base2k(),
86            infos.k(),
87            infos.rows(),
88            infos.digits(),
89            infos.rank(),
90        )
91    }
92
93    pub fn alloc_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> Self {
94        let size: usize = k.0.div_ceil(base2k.0) as usize;
95        debug_assert!(
96            size as u32 > digits.0,
97            "invalid ggsw: ceil(k/base2k): {size} <= digits: {}",
98            digits.0
99        );
100
101        assert!(
102            rows.0 * digits.0 <= size as u32,
103            "invalid ggsw: rows: {} * digits:{} > ceil(k/base2k): {size}",
104            rows.0,
105            digits.0,
106        );
107
108        Self {
109            data: MatZnx::alloc(
110                n.into(),
111                rows.into(),
112                (rank + 1).into(),
113                1,
114                k.0.div_ceil(base2k.0) as usize,
115            ),
116            k,
117            base2k,
118            digits,
119            rank,
120            seed: Vec::new(),
121        }
122    }
123
124    pub fn alloc_bytes<A>(infos: &A) -> usize
125    where
126        A: GGSWInfos,
127    {
128        Self::alloc_bytes_with(
129            infos.n(),
130            infos.base2k(),
131            infos.k(),
132            infos.rows(),
133            infos.digits(),
134            infos.rank(),
135        )
136    }
137
138    pub fn alloc_bytes_with(n: Degree, base2k: Base2K, k: TorusPrecision, rows: Rows, digits: Digits, rank: Rank) -> usize {
139        let size: usize = k.0.div_ceil(base2k.0) as usize;
140        debug_assert!(
141            size as u32 > digits.0,
142            "invalid ggsw: ceil(k/base2k): {size} <= digits: {}",
143            digits.0
144        );
145
146        assert!(
147            rows.0 * digits.0 <= size as u32,
148            "invalid ggsw: rows: {} * digits:{} > ceil(k/base2k): {size}",
149            rows.0,
150            digits.0,
151        );
152
153        MatZnx::alloc_bytes(
154            n.into(),
155            rows.into(),
156            (rank + 1).into(),
157            1,
158            k.0.div_ceil(base2k.0) as usize,
159        )
160    }
161}
162
163impl<D: DataRef> GGSWCiphertextCompressed<D> {
164    pub fn at(&self, row: usize, col: usize) -> GLWECiphertextCompressed<&[u8]> {
165        let rank: usize = self.rank().into();
166        GLWECiphertextCompressed {
167            data: self.data.at(row, col),
168            k: self.k,
169            base2k: self.base2k,
170            rank: self.rank,
171            seed: self.seed[row * (rank + 1) + col],
172        }
173    }
174}
175
176impl<D: DataMut> GGSWCiphertextCompressed<D> {
177    pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertextCompressed<&mut [u8]> {
178        let rank: usize = self.rank().into();
179        GLWECiphertextCompressed {
180            data: self.data.at_mut(row, col),
181            k: self.k,
182            base2k: self.base2k,
183            rank: self.rank,
184            seed: self.seed[row * (rank + 1) + col],
185        }
186    }
187}
188
189impl<D: DataMut> ReaderFrom for GGSWCiphertextCompressed<D> {
190    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
191        self.k = TorusPrecision(reader.read_u32::<LittleEndian>()?);
192        self.base2k = Base2K(reader.read_u32::<LittleEndian>()?);
193        self.digits = Digits(reader.read_u32::<LittleEndian>()?);
194        self.rank = Rank(reader.read_u32::<LittleEndian>()?);
195        let seed_len: usize = reader.read_u32::<LittleEndian>()? as usize;
196        self.seed = vec![[0u8; 32]; seed_len];
197        for s in &mut self.seed {
198            reader.read_exact(s)?;
199        }
200        self.data.read_from(reader)
201    }
202}
203
204impl<D: DataRef> WriterTo for GGSWCiphertextCompressed<D> {
205    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
206        writer.write_u32::<LittleEndian>(self.k.into())?;
207        writer.write_u32::<LittleEndian>(self.base2k.into())?;
208        writer.write_u32::<LittleEndian>(self.digits.into())?;
209        writer.write_u32::<LittleEndian>(self.rank.into())?;
210        writer.write_u32::<LittleEndian>(self.seed.len() as u32)?;
211        for s in &self.seed {
212            writer.write_all(s)?;
213        }
214        self.data.write_to(writer)
215    }
216}
217
218impl<D: DataMut, B: Backend, DR: DataRef> Decompress<B, GGSWCiphertextCompressed<DR>> for GGSWCiphertext<D>
219where
220    Module<B>: VecZnxFillUniform + VecZnxCopy,
221{
222    fn decompress(&mut self, module: &Module<B>, other: &GGSWCiphertextCompressed<DR>) {
223        #[cfg(debug_assertions)]
224        {
225            assert_eq!(self.rank(), other.rank())
226        }
227
228        let rows: usize = self.rows().into();
229        let rank: usize = self.rank().into();
230        (0..rows).for_each(|row_i| {
231            (0..rank + 1).for_each(|col_j| {
232                self.at_mut(row_i, col_j)
233                    .decompress(module, &other.at(row_i, col_j));
234            });
235        });
236    }
237}