1use poulpy_hal::{
2 api::{VmpPMatAlloc, VmpPMatBytesOf, VmpPrepare, VmpPrepareTmpBytes},
3 layouts::{Backend, Data, DataMut, DataRef, Module, Scratch, VmpPMat, VmpPMatToMut, VmpPMatToRef, ZnxInfos},
4};
5
6use crate::layouts::{
7 Base2K, Degree, Dnum, Dsize, GGLWE, GGLWEInfos, GGLWEToRef, GLWEInfos, GetDegree, LWEInfos, Rank, TorusPrecision,
8};
9
10#[derive(PartialEq, Eq)]
11pub struct GGLWEPrepared<D: Data, B: Backend> {
12 pub(crate) data: VmpPMat<D, B>,
13 pub(crate) k: TorusPrecision,
14 pub(crate) base2k: Base2K,
15 pub(crate) dsize: Dsize,
16}
17
18impl<D: Data, B: Backend> LWEInfos for GGLWEPrepared<D, B> {
19 fn n(&self) -> Degree {
20 Degree(self.data.n() as u32)
21 }
22
23 fn base2k(&self) -> Base2K {
24 self.base2k
25 }
26
27 fn k(&self) -> TorusPrecision {
28 self.k
29 }
30
31 fn size(&self) -> usize {
32 self.data.size()
33 }
34}
35
36impl<D: Data, B: Backend> GLWEInfos for GGLWEPrepared<D, B> {
37 fn rank(&self) -> Rank {
38 self.rank_out()
39 }
40}
41
42impl<D: Data, B: Backend> GGLWEInfos for GGLWEPrepared<D, B> {
43 fn rank_in(&self) -> Rank {
44 Rank(self.data.cols_in() as u32)
45 }
46
47 fn rank_out(&self) -> Rank {
48 Rank(self.data.cols_out() as u32 - 1)
49 }
50
51 fn dsize(&self) -> Dsize {
52 self.dsize
53 }
54
55 fn dnum(&self) -> Dnum {
56 Dnum(self.data.rows() as u32)
57 }
58}
59
60pub trait GGLWEPreparedFactory<BE: Backend>
61where
62 Self: GetDegree + VmpPMatAlloc<BE> + VmpPMatBytesOf + VmpPrepare<BE> + VmpPrepareTmpBytes,
63{
64 fn alloc_gglwe_prepared(
65 &self,
66 base2k: Base2K,
67 k: TorusPrecision,
68 rank_in: Rank,
69 rank_out: Rank,
70 dnum: Dnum,
71 dsize: Dsize,
72 ) -> GGLWEPrepared<Vec<u8>, BE> {
73 let size: usize = k.0.div_ceil(base2k.0) as usize;
74 debug_assert!(
75 size as u32 > dsize.0,
76 "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}",
77 dsize.0
78 );
79
80 assert!(
81 dnum.0 * dsize.0 <= size as u32,
82 "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
83 dnum.0,
84 dsize.0,
85 );
86
87 GGLWEPrepared {
88 data: self.vmp_pmat_alloc(dnum.into(), rank_in.into(), (rank_out + 1).into(), size),
89 k,
90 base2k,
91 dsize,
92 }
93 }
94
95 fn alloc_gglwe_prepared_from_infos<A>(&self, infos: &A) -> GGLWEPrepared<Vec<u8>, BE>
96 where
97 A: GGLWEInfos,
98 {
99 assert_eq!(self.ring_degree(), infos.n());
100 self.alloc_gglwe_prepared(
101 infos.base2k(),
102 infos.k(),
103 infos.rank_in(),
104 infos.rank_out(),
105 infos.dnum(),
106 infos.dsize(),
107 )
108 }
109
110 fn bytes_of_gglwe_prepared(
111 &self,
112 base2k: Base2K,
113 k: TorusPrecision,
114 rank_in: Rank,
115 rank_out: Rank,
116 dnum: Dnum,
117 dsize: Dsize,
118 ) -> usize {
119 let size: usize = k.0.div_ceil(base2k.0) as usize;
120 debug_assert!(
121 size as u32 > dsize.0,
122 "invalid gglwe: ceil(k/base2k): {size} <= dsize: {}",
123 dsize.0
124 );
125
126 assert!(
127 dnum.0 * dsize.0 <= size as u32,
128 "invalid gglwe: dnum: {} * dsize:{} > ceil(k/base2k): {size}",
129 dnum.0,
130 dsize.0,
131 );
132
133 self.bytes_of_vmp_pmat(dnum.into(), rank_in.into(), (rank_out + 1).into(), size)
134 }
135
136 fn bytes_of_gglwe_prepared_from_infos<A>(&self, infos: &A) -> usize
137 where
138 A: GGLWEInfos,
139 {
140 assert_eq!(self.ring_degree(), infos.n());
141 self.bytes_of_gglwe_prepared(
142 infos.base2k(),
143 infos.k(),
144 infos.rank_in(),
145 infos.rank_out(),
146 infos.dnum(),
147 infos.dsize(),
148 )
149 }
150
151 fn prepare_gglwe_tmp_bytes<A>(&self, infos: &A) -> usize
152 where
153 A: GGLWEInfos,
154 {
155 self.vmp_prepare_tmp_bytes(
156 infos.dnum().into(),
157 infos.rank_in().into(),
158 (infos.rank() + 1).into(),
159 infos.size(),
160 )
161 }
162
163 fn prepare_gglwe<R, O>(&self, res: &mut R, other: &O, scratch: &mut Scratch<BE>)
164 where
165 R: GGLWEPreparedToMut<BE>,
166 O: GGLWEToRef,
167 {
168 let mut res: GGLWEPrepared<&mut [u8], BE> = res.to_mut();
169 let other: GGLWE<&[u8]> = other.to_ref();
170
171 assert_eq!(res.n(), self.ring_degree());
172 assert_eq!(other.n(), self.ring_degree());
173 assert_eq!(res.base2k, other.base2k);
174 assert_eq!(res.k, other.k);
175 assert_eq!(res.dsize, other.dsize);
176
177 self.vmp_prepare(&mut res.data, &other.data, scratch);
178 }
179}
180
181impl<BE: Backend> GGLWEPreparedFactory<BE> for Module<BE> where
182 Module<BE>: GetDegree + VmpPMatAlloc<BE> + VmpPMatBytesOf + VmpPrepare<BE> + VmpPrepareTmpBytes
183{
184}
185
186impl<B: Backend> GGLWEPrepared<Vec<u8>, B> {
187 pub fn alloc_from_infos<A, M>(module: &M, infos: &A) -> Self
188 where
189 A: GGLWEInfos,
190 M: GGLWEPreparedFactory<B>,
191 {
192 module.alloc_gglwe_prepared_from_infos(infos)
193 }
194
195 pub fn alloc<M>(
196 module: &M,
197 base2k: Base2K,
198 k: TorusPrecision,
199 rank_in: Rank,
200 rank_out: Rank,
201 dnum: Dnum,
202 dsize: Dsize,
203 ) -> Self
204 where
205 M: GGLWEPreparedFactory<B>,
206 {
207 module.alloc_gglwe_prepared(base2k, k, rank_in, rank_out, dnum, dsize)
208 }
209
210 pub fn bytes_of_from_infos<A, M>(module: &M, infos: &A) -> usize
211 where
212 A: GGLWEInfos,
213 M: GGLWEPreparedFactory<B>,
214 {
215 module.bytes_of_gglwe_prepared_from_infos(infos)
216 }
217
218 pub fn bytes_of<M>(
219 module: &M,
220 base2k: Base2K,
221 k: TorusPrecision,
222 rank_in: Rank,
223 rank_out: Rank,
224 dnum: Dnum,
225 dsize: Dsize,
226 ) -> usize
227 where
228 M: GGLWEPreparedFactory<B>,
229 {
230 module.bytes_of_gglwe_prepared(base2k, k, rank_in, rank_out, dnum, dsize)
231 }
232}
233
234impl<D: DataMut, B: Backend> GGLWEPrepared<D, B> {
235 pub fn prepare<O, M>(&mut self, module: &M, other: &O, scratch: &mut Scratch<B>)
236 where
237 O: GGLWEToRef,
238 M: GGLWEPreparedFactory<B>,
239 {
240 module.prepare_gglwe(self, other, scratch);
241 }
242}
243
244impl<B: Backend> GGLWEPrepared<Vec<u8>, B> {
245 pub fn prepare_tmp_bytes<M>(&self, module: &M) -> usize
246 where
247 M: GGLWEPreparedFactory<B>,
248 {
249 module.prepare_gglwe_tmp_bytes(self)
250 }
251}
252
253pub trait GGLWEPreparedToMut<B: Backend> {
254 fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B>;
255}
256
257impl<D: DataMut, B: Backend> GGLWEPreparedToMut<B> for GGLWEPrepared<D, B> {
258 fn to_mut(&mut self) -> GGLWEPrepared<&mut [u8], B> {
259 GGLWEPrepared {
260 k: self.k,
261 base2k: self.base2k,
262 dsize: self.dsize,
263 data: self.data.to_mut(),
264 }
265 }
266}
267
268pub trait GGLWEPreparedToRef<B: Backend> {
269 fn to_ref(&self) -> GGLWEPrepared<&[u8], B>;
270}
271
272impl<D: DataRef, B: Backend> GGLWEPreparedToRef<B> for GGLWEPrepared<D, B> {
273 fn to_ref(&self) -> GGLWEPrepared<&[u8], B> {
274 GGLWEPrepared {
275 k: self.k,
276 base2k: self.base2k,
277 dsize: self.dsize,
278 data: self.data.to_ref(),
279 }
280 }
281}