poulpy_core/layouts/compressed/
ggsw_ct.rs

1use poulpy_hal::{
2    api::{FillUniform, Reset, VecZnxCopy, VecZnxFillUniform},
3    layouts::{Backend, Data, DataMut, DataRef, MatZnx, Module, ReaderFrom, WriterTo},
4    source::Source,
5};
6
7use crate::layouts::{
8    GGSWCiphertext, Infos,
9    compressed::{Decompress, GLWECiphertextCompressed},
10};
11use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
12use std::fmt;
13
14#[derive(PartialEq, Eq, Clone)]
15pub struct GGSWCiphertextCompressed<D: Data> {
16    pub(crate) data: MatZnx<D>,
17    pub(crate) basek: usize,
18    pub(crate) k: usize,
19    pub(crate) digits: usize,
20    pub(crate) rank: usize,
21    pub(crate) seed: Vec<[u8; 32]>,
22}
23
24impl<D: DataRef> fmt::Debug for GGSWCiphertextCompressed<D> {
25    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
26        write!(f, "{}", self.data)
27    }
28}
29
30impl<D: DataRef> fmt::Display for GGSWCiphertextCompressed<D> {
31    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
32        write!(
33            f,
34            "(GGSWCiphertextCompressed: basek={} k={} digits={}) {}",
35            self.basek, self.k, self.digits, self.data
36        )
37    }
38}
39
40impl<D: DataMut> Reset for GGSWCiphertextCompressed<D> {
41    fn reset(&mut self) {
42        self.data.reset();
43        self.basek = 0;
44        self.k = 0;
45        self.digits = 0;
46        self.rank = 0;
47        self.seed = Vec::new();
48    }
49}
50
51impl<D: DataMut> FillUniform for GGSWCiphertextCompressed<D> {
52    fn fill_uniform(&mut self, source: &mut Source) {
53        self.data.fill_uniform(source);
54    }
55}
56
57impl GGSWCiphertextCompressed<Vec<u8>> {
58    pub fn alloc(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> Self {
59        let size: usize = k.div_ceil(basek);
60        debug_assert!(digits > 0, "invalid ggsw: `digits` == 0");
61
62        debug_assert!(
63            size > digits,
64            "invalid ggsw: ceil(k/basek): {} <= digits: {}",
65            size,
66            digits
67        );
68
69        assert!(
70            rows * digits <= size,
71            "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}",
72            rows,
73            digits,
74            size
75        );
76
77        Self {
78            data: MatZnx::alloc(n, rows, rank + 1, 1, k.div_ceil(basek)),
79            basek,
80            k,
81            digits,
82            rank,
83            seed: Vec::new(),
84        }
85    }
86
87    pub fn bytes_of(n: usize, basek: usize, k: usize, rows: usize, digits: usize, rank: usize) -> usize {
88        let size: usize = k.div_ceil(basek);
89        debug_assert!(
90            size > digits,
91            "invalid ggsw: ceil(k/basek): {} <= digits: {}",
92            size,
93            digits
94        );
95
96        assert!(
97            rows * digits <= size,
98            "invalid ggsw: rows: {} * digits:{} > ceil(k/basek): {}",
99            rows,
100            digits,
101            size
102        );
103
104        MatZnx::alloc_bytes(n, rows, rank + 1, 1, size)
105    }
106}
107
108impl<D: DataRef> GGSWCiphertextCompressed<D> {
109    pub fn at(&self, row: usize, col: usize) -> GLWECiphertextCompressed<&[u8]> {
110        GLWECiphertextCompressed {
111            data: self.data.at(row, col),
112            basek: self.basek,
113            k: self.k,
114            rank: self.rank(),
115            seed: self.seed[row * (self.rank() + 1) + col],
116        }
117    }
118}
119
120impl<D: DataMut> GGSWCiphertextCompressed<D> {
121    pub fn at_mut(&mut self, row: usize, col: usize) -> GLWECiphertextCompressed<&mut [u8]> {
122        let rank: usize = self.rank();
123        GLWECiphertextCompressed {
124            data: self.data.at_mut(row, col),
125            basek: self.basek,
126            k: self.k,
127            rank,
128            seed: self.seed[row * (rank + 1) + col],
129        }
130    }
131}
132
133impl<D: Data> Infos for GGSWCiphertextCompressed<D> {
134    type Inner = MatZnx<D>;
135
136    fn inner(&self) -> &Self::Inner {
137        &self.data
138    }
139
140    fn basek(&self) -> usize {
141        self.basek
142    }
143
144    fn k(&self) -> usize {
145        self.k
146    }
147}
148
149impl<D: Data> GGSWCiphertextCompressed<D> {
150    pub fn rank(&self) -> usize {
151        self.rank
152    }
153
154    pub fn digits(&self) -> usize {
155        self.digits
156    }
157}
158
159impl<D: DataMut> ReaderFrom for GGSWCiphertextCompressed<D> {
160    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
161        self.k = reader.read_u64::<LittleEndian>()? as usize;
162        self.basek = reader.read_u64::<LittleEndian>()? as usize;
163        self.digits = reader.read_u64::<LittleEndian>()? as usize;
164        self.rank = reader.read_u64::<LittleEndian>()? as usize;
165        let seed_len = reader.read_u64::<LittleEndian>()? as usize;
166        self.seed = vec![[0u8; 32]; seed_len];
167        for s in &mut self.seed {
168            reader.read_exact(s)?;
169        }
170        self.data.read_from(reader)
171    }
172}
173
174impl<D: DataRef> WriterTo for GGSWCiphertextCompressed<D> {
175    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
176        writer.write_u64::<LittleEndian>(self.k as u64)?;
177        writer.write_u64::<LittleEndian>(self.basek as u64)?;
178        writer.write_u64::<LittleEndian>(self.digits as u64)?;
179        writer.write_u64::<LittleEndian>(self.rank as u64)?;
180        writer.write_u64::<LittleEndian>(self.seed.len() as u64)?;
181        for s in &self.seed {
182            writer.write_all(s)?;
183        }
184        self.data.write_to(writer)
185    }
186}
187
188impl<D: DataMut, B: Backend, DR: DataRef> Decompress<B, GGSWCiphertextCompressed<DR>> for GGSWCiphertext<D> {
189    fn decompress(&mut self, module: &Module<B>, other: &GGSWCiphertextCompressed<DR>)
190    where
191        Module<B>: VecZnxFillUniform + VecZnxCopy,
192    {
193        #[cfg(debug_assertions)]
194        {
195            assert_eq!(self.rank(), other.rank())
196        }
197
198        let rows: usize = self.rows();
199        let rank: usize = self.rank();
200        (0..rows).for_each(|row_i| {
201            (0..rank + 1).for_each(|col_j| {
202                self.at_mut(row_i, col_j)
203                    .decompress(module, &other.at(row_i, col_j));
204            });
205        });
206    }
207}