// poulpy_core/layouts/ggsw.rs

1use poulpy_hal::{
2    layouts::{Data, DataMut, DataRef, FillUniform, MatZnx, MatZnxToMut, MatZnxToRef, ReaderFrom, WriterTo, ZnxInfos},
3    source::Source,
4};
5use std::fmt;
6
7use crate::layouts::{Base2K, Degree, Dnum, Dsize, GLWE, GLWEInfos, LWEInfos, Rank, TorusPrecision};
8
9pub trait GGSWInfos
10where
11    Self: GLWEInfos,
12{
13    fn dnum(&self) -> Dnum;
14    fn dsize(&self) -> Dsize;
15    fn ggsw_layout(&self) -> GGSWLayout {
16        GGSWLayout {
17            n: self.n(),
18            base2k: self.base2k(),
19            k: self.k(),
20            rank: self.rank(),
21            dnum: self.dnum(),
22            dsize: self.dsize(),
23        }
24    }
25}
26
27#[derive(PartialEq, Eq, Copy, Clone, Debug)]
28pub struct GGSWLayout {
29    pub n: Degree,
30    pub base2k: Base2K,
31    pub k: TorusPrecision,
32    pub rank: Rank,
33    pub dnum: Dnum,
34    pub dsize: Dsize,
35}
36
37impl LWEInfos for GGSWLayout {
38    fn base2k(&self) -> Base2K {
39        self.base2k
40    }
41
42    fn k(&self) -> TorusPrecision {
43        self.k
44    }
45
46    fn n(&self) -> Degree {
47        self.n
48    }
49}
50impl GLWEInfos for GGSWLayout {
51    fn rank(&self) -> Rank {
52        self.rank
53    }
54}
55
56impl GGSWInfos for GGSWLayout {
57    fn dsize(&self) -> Dsize {
58        self.dsize
59    }
60
61    fn dnum(&self) -> Dnum {
62        self.dnum
63    }
64}
65
66#[derive(PartialEq, Eq, Clone)]
67pub struct GGSW<D: Data> {
68    pub(crate) data: MatZnx<D>,
69    pub(crate) k: TorusPrecision,
70    pub(crate) base2k: Base2K,
71    pub(crate) dsize: Dsize,
72}
73
74impl<D: Data> LWEInfos for GGSW<D> {
75    fn n(&self) -> Degree {
76        Degree(self.data.n() as u32)
77    }
78
79    fn base2k(&self) -> Base2K {
80        self.base2k
81    }
82
83    fn k(&self) -> TorusPrecision {
84        self.k
85    }
86
87    fn size(&self) -> usize {
88        self.data.size()
89    }
90}
91
92impl<D: Data> GLWEInfos for GGSW<D> {
93    fn rank(&self) -> Rank {
94        Rank(self.data.cols_out() as u32 - 1)
95    }
96}
97
98impl<D: Data> GGSWInfos for GGSW<D> {
99    fn dsize(&self) -> Dsize {
100        self.dsize
101    }
102
103    fn dnum(&self) -> Dnum {
104        Dnum(self.data.rows() as u32)
105    }
106}
107
108impl<D: DataRef> fmt::Debug for GGSW<D> {
109    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
110        write!(f, "{}", self.data)
111    }
112}
113
114impl<D: DataRef> fmt::Display for GGSW<D> {
115    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
116        write!(
117            f,
118            "(GGSW: k: {} base2k: {} dsize: {}) {}",
119            self.k().0,
120            self.base2k().0,
121            self.dsize().0,
122            self.data
123        )
124    }
125}
126
127impl<D: DataMut> FillUniform for GGSW<D> {
128    fn fill_uniform(&mut self, log_bound: usize, source: &mut Source) {
129        self.data.fill_uniform(log_bound, source);
130    }
131}
132
133impl<D: DataRef> GGSW<D> {
134    pub fn at(&self, row: usize, col: usize) -> GLWE<&[u8]> {
135        GLWE {
136            k: self.k,
137            base2k: self.base2k,
138            data: self.data.at(row, col),
139        }
140    }
141}
142
143impl<D: DataMut> GGSW<D> {
144    pub fn at_mut(&mut self, row: usize, col: usize) -> GLWE<&mut [u8]> {
145        GLWE {
146            k: self.k,
147            base2k: self.base2k,
148            data: self.data.at_mut(row, col),
149        }
150    }
151}
152
153impl GGSW<Vec<u8>> {
154    pub fn alloc_from_infos<A>(infos: &A) -> Self
155    where
156        A: GGSWInfos,
157    {
158        Self::alloc(
159            infos.n(),
160            infos.base2k(),
161            infos.k(),
162            infos.rank(),
163            infos.dnum(),
164            infos.dsize(),
165        )
166    }
167
168    pub fn alloc(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> Self {
169        let size: usize = k.0.div_ceil(base2k.0) as usize;
170        debug_assert!(
171            size as u32 > dsize.0,
172            "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
173            dsize.0
174        );
175
176        assert!(
177            dnum.0 * dsize.0 <= size as u32,
178            "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
179            dnum.0,
180            dsize.0,
181        );
182
183        GGSW {
184            data: MatZnx::alloc(
185                n.into(),
186                dnum.into(),
187                (rank + 1).into(),
188                (rank + 1).into(),
189                k.0.div_ceil(base2k.0) as usize,
190            ),
191            k,
192            base2k,
193            dsize,
194        }
195    }
196
197    pub fn bytes_of_from_infos<A>(infos: &A) -> usize
198    where
199        A: GGSWInfos,
200    {
201        Self::bytes_of(
202            infos.n(),
203            infos.base2k(),
204            infos.k(),
205            infos.rank(),
206            infos.dnum(),
207            infos.dsize(),
208        )
209    }
210
211    pub fn bytes_of(n: Degree, base2k: Base2K, k: TorusPrecision, rank: Rank, dnum: Dnum, dsize: Dsize) -> usize {
212        let size: usize = k.0.div_ceil(base2k.0) as usize;
213        debug_assert!(
214            size as u32 > dsize.0,
215            "invalid ggsw: ceil(k/base2k): {size} <= dsize: {}",
216            dsize.0
217        );
218
219        assert!(
220            dnum.0 * dsize.0 <= size as u32,
221            "invalid ggsw: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
222            dnum.0,
223            dsize.0,
224        );
225
226        MatZnx::bytes_of(
227            n.into(),
228            dnum.into(),
229            (rank + 1).into(),
230            (rank + 1).into(),
231            k.0.div_ceil(base2k.0) as usize,
232        )
233    }
234}
235
236use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
237
238impl<D: DataMut> ReaderFrom for GGSW<D> {
239    fn read_from<R: std::io::Read>(&mut self, reader: &mut R) -> std::io::Result<()> {
240        self.k = TorusPrecision(reader.read_u32::<LittleEndian>()?);
241        self.base2k = Base2K(reader.read_u32::<LittleEndian>()?);
242        self.dsize = Dsize(reader.read_u32::<LittleEndian>()?);
243        self.data.read_from(reader)
244    }
245}
246
247impl<D: DataRef> WriterTo for GGSW<D> {
248    fn write_to<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
249        writer.write_u32::<LittleEndian>(self.k.into())?;
250        writer.write_u32::<LittleEndian>(self.base2k.into())?;
251        writer.write_u32::<LittleEndian>(self.dsize.into())?;
252        self.data.write_to(writer)
253    }
254}
255
256pub trait GGSWToMut {
257    fn to_mut(&mut self) -> GGSW<&mut [u8]>;
258}
259
260impl<D: DataMut> GGSWToMut for GGSW<D> {
261    fn to_mut(&mut self) -> GGSW<&mut [u8]> {
262        GGSW {
263            dsize: self.dsize,
264            k: self.k,
265            base2k: self.base2k,
266            data: self.data.to_mut(),
267        }
268    }
269}
270
271pub trait GGSWToRef {
272    fn to_ref(&self) -> GGSW<&[u8]>;
273}
274
275impl<D: DataRef> GGSWToRef for GGSW<D> {
276    fn to_ref(&self) -> GGSW<&[u8]> {
277        GGSW {
278            dsize: self.dsize,
279            k: self.k,
280            base2k: self.base2k,
281            data: self.data.to_ref(),
282        }
283    }
284}