1use std::ops::{Deref, DerefMut};
2
3use poulpy_hal::layouts::{
4 Backend, ScalarZnx, SvpPPolReborrowBackendMut, SvpPPolReborrowBackendRef, VmpPMatReborrowBackendMut,
5 VmpPMatReborrowBackendRef, mat_znx_backend_mut_from_mut, mat_znx_backend_ref_from_mut, vec_znx_backend_mut_from_mut,
6 vec_znx_backend_ref_from_mut, vec_znx_backend_ref_from_ref,
7};
8
9use crate::{
10 GetDistribution, GetDistributionMut,
11 dist::Distribution,
12 layouts::{
13 Base2K, GGLWE, GGLWEBackendMut, GGLWEBackendRef, GGLWEInfos, GGLWEPrepared, GGLWEPreparedBackendMut,
14 GGLWEPreparedBackendRef, GGLWEPreparedToBackendMut, GGLWEPreparedToBackendRef, GGLWEToBackendMut, GGLWEToBackendRef,
15 GGSW, GGSWBackendMut, GGSWBackendRef, GGSWInfos, GGSWPrepared, GGSWPreparedBackendMut, GGSWPreparedBackendRef,
16 GGSWPreparedToBackendMut, GGSWPreparedToBackendRef, GGSWToBackendMut, GGSWToBackendRef, GLWE, GLWEBackendMut,
17 GLWEBackendRef, GLWEInfos, GLWEPlaintext, GLWESecret, GLWESecretBackendMut, GLWESecretBackendRef, GLWESecretPrepared,
18 GLWESecretPreparedBackendMut, GLWESecretPreparedBackendRef, GLWESecretPreparedToBackendMut,
19 GLWESecretPreparedToBackendRef, GLWESecretTensor, GLWESecretToBackendMut, GLWESecretToBackendRef, GLWETensor,
20 GLWEToBackendMut, GLWEToBackendRef, LWE, LWEBackendMut, LWEBackendRef, LWEInfos, LWEPlaintext, LWEPlaintextBackendMut,
21 LWEPlaintextBackendRef, LWEPlaintextToBackendMut, LWEPlaintextToBackendRef, LWEToBackendMut, LWEToBackendRef, Rank,
22 SetGGLWEInfos, SetLWEInfos,
23 },
24};
25
// Declares a thin owning wrapper ("view") around a backend-parameterised layout.
//
// `view_wrapper!(Wrapper, Target)` expands to:
// - `pub struct Wrapper<'a, BE>` holding a single private `inner: Target` field;
// - `from_inner` / `into_inner` to move a `Target` into and out of the wrapper;
// - `Deref` / `DerefMut` impls whose target is `Target`;
// - an `LWEInfos` impl whose `n`, `base2k` and `size` forward to `Target`.
//
// The `Target` argument refers to the `'a` / `BE` generics of the generated items
// (every invocation in this file passes e.g. `LWE<BE::BufMut<'a>>`).
macro_rules! view_wrapper {
    ($wrapper:ident, $target:ty) => {
        pub struct $wrapper<'a, BE: Backend + 'a> {
            inner: $target,
        }

        impl<'a, BE: Backend + 'a> $wrapper<'a, BE> {
            /// Wraps `inner` as-is, without copying or reborrowing it.
            pub fn from_inner(inner: $target) -> Self {
                Self { inner }
            }

            /// Consumes the wrapper and returns the wrapped value.
            pub fn into_inner(self) -> $target {
                let Self { inner } = self;
                inner
            }
        }

        impl<'a, BE: Backend + 'a> Deref for $wrapper<'a, BE> {
            type Target = $target;

            fn deref(&self) -> &$target {
                &self.inner
            }
        }

        impl<'a, BE: Backend + 'a> DerefMut for $wrapper<'a, BE> {
            fn deref_mut(&mut self) -> &mut $target {
                &mut self.inner
            }
        }

        impl<'a, BE: Backend + 'a> LWEInfos for $wrapper<'a, BE> {
            fn n(&self) -> crate::layouts::Degree {
                self.inner.n()
            }

            fn base2k(&self) -> Base2K {
                self.inner.base2k()
            }

            fn size(&self) -> usize {
                self.inner.size()
            }
        }
    };
}
71
72view_wrapper!(LWEViewMut, LWE<BE::BufMut<'a>>);
73view_wrapper!(LWEPlaintextViewMut, LWEPlaintext<BE::BufMut<'a>>);
74view_wrapper!(GLWEViewRef, GLWE<BE::BufRef<'a>>);
75view_wrapper!(GLWEViewMut, GLWE<BE::BufMut<'a>>);
76view_wrapper!(GLWEPlaintextViewMut, GLWEPlaintext<BE::BufMut<'a>>);
77view_wrapper!(GLWETensorViewMut, GLWETensor<BE::BufMut<'a>>);
78view_wrapper!(GLWESecretViewMut, GLWESecret<BE::BufMut<'a>>);
79view_wrapper!(GLWESecretTensorViewMut, GLWESecretTensor<BE::BufMut<'a>>);
80view_wrapper!(GLWESecretPreparedViewMut, GLWESecretPrepared<BE::BufMut<'a>, BE>);
81view_wrapper!(GGLWEViewMut, GGLWE<BE::BufMut<'a>>);
82view_wrapper!(GGLWEPreparedViewMut, GGLWEPrepared<BE::BufMut<'a>, BE>);
83view_wrapper!(GGSWViewMut, GGSW<BE::BufMut<'a>>);
84view_wrapper!(GGSWPreparedViewMut, GGSWPrepared<BE::BufMut<'a>, BE>);
85
86impl<'a, BE: Backend + 'a> GGLWEViewMut<'a, BE> {
87 pub fn at_view(&self, row: usize, col: usize) -> GLWEViewRef<'_, BE> {
88 GLWEViewRef::from_inner(crate::layouts::gglwe_at_backend_ref_from_mut::<BE>(&self.inner, row, col))
89 }
90
91 pub fn at_view_mut(&mut self, row: usize, col: usize) -> GLWEViewMut<'_, BE> {
92 GLWEViewMut::from_inner(crate::layouts::gglwe_at_backend_mut_from_mut::<BE>(&mut self.inner, row, col))
93 }
94}
95
// Implements `SetLWEInfos` for a view wrapper by forwarding `set_base2k` to the
// wrapped layout's own `set_base2k`.
macro_rules! impl_set_lwe_infos {
    ($wrapper:ident) => {
        impl<'a, BE: Backend + 'a> SetLWEInfos for $wrapper<'a, BE> {
            fn set_base2k(&mut self, base2k: Base2K) {
                let target = &mut self.inner;
                target.set_base2k(base2k);
            }
        }
    };
}
105
106impl_set_lwe_infos!(LWEViewMut);
107impl_set_lwe_infos!(GLWEViewMut);
108impl_set_lwe_infos!(GLWEPlaintextViewMut);
109
110impl<'a, BE: Backend + 'a> SetLWEInfos for LWEPlaintextViewMut<'a, BE> {
111 fn set_base2k(&mut self, base2k: Base2K) {
112 self.inner.base2k = base2k;
113 }
114}
115
// Implements `GLWEInfos` for a view wrapper by forwarding `rank` to the wrapped layout.
macro_rules! impl_glwe_infos {
    ($wrapper:ident) => {
        impl<'a, BE: Backend + 'a> GLWEInfos for $wrapper<'a, BE> {
            fn rank(&self) -> Rank {
                let target = &self.inner;
                target.rank()
            }
        }
    };
}
125
126impl_glwe_infos!(GLWEViewMut);
127impl_glwe_infos!(GLWEViewRef);
128impl_glwe_infos!(GLWEPlaintextViewMut);
129impl_glwe_infos!(GLWETensorViewMut);
130impl_glwe_infos!(GLWESecretViewMut);
131impl_glwe_infos!(GLWESecretTensorViewMut);
132impl_glwe_infos!(GLWESecretPreparedViewMut);
133impl_glwe_infos!(GGLWEViewMut);
134impl_glwe_infos!(GGLWEPreparedViewMut);
135impl_glwe_infos!(GGSWViewMut);
136impl_glwe_infos!(GGSWPreparedViewMut);
137
// Implements both `GetDistribution` and `GetDistributionMut` for a view wrapper by
// forwarding `dist` / `dist_mut` to the wrapped layout.
macro_rules! impl_dist {
    ($wrapper:ident) => {
        impl<'a, BE: Backend + 'a> GetDistributionMut for $wrapper<'a, BE> {
            fn dist_mut(&mut self) -> &mut Distribution {
                let target = &mut self.inner;
                target.dist_mut()
            }
        }

        impl<'a, BE: Backend + 'a> GetDistribution for $wrapper<'a, BE> {
            fn dist(&self) -> &Distribution {
                let target = &self.inner;
                target.dist()
            }
        }
    };
}
153
154impl_dist!(GLWESecretTensorViewMut);
155impl_dist!(GLWESecretPreparedViewMut);
156
157impl<'a, BE: Backend + 'a> GetDistribution for GLWESecretViewMut<'a, BE> {
158 fn dist(&self) -> &Distribution {
159 self.inner.dist()
160 }
161}
162
163impl<'a, BE: Backend + 'a> GGLWEInfos for GGLWEViewMut<'a, BE> {
164 fn dnum(&self) -> crate::layouts::Dnum {
165 self.inner.dnum()
166 }
167
168 fn dsize(&self) -> crate::layouts::Dsize {
169 self.inner.dsize()
170 }
171
172 fn rank_in(&self) -> Rank {
173 self.inner.rank_in()
174 }
175
176 fn rank_out(&self) -> Rank {
177 self.inner.rank_out()
178 }
179}
180
181impl<'a, BE: Backend + 'a> GGLWEInfos for GGLWEPreparedViewMut<'a, BE> {
182 fn dnum(&self) -> crate::layouts::Dnum {
183 self.inner.dnum()
184 }
185
186 fn dsize(&self) -> crate::layouts::Dsize {
187 self.inner.dsize()
188 }
189
190 fn rank_in(&self) -> Rank {
191 self.inner.rank_in()
192 }
193
194 fn rank_out(&self) -> Rank {
195 self.inner.rank_out()
196 }
197}
198
199impl<'a, BE: Backend + 'a> SetGGLWEInfos for GGLWEViewMut<'a, BE> {
200 fn set_dsize(&mut self, dsize: usize) {
201 self.inner.dsize = dsize.into();
202 }
203}
204
205impl<'a, BE: Backend + 'a> GGSWInfos for GGSWViewMut<'a, BE> {
206 fn dnum(&self) -> crate::layouts::Dnum {
207 self.inner.dnum()
208 }
209
210 fn dsize(&self) -> crate::layouts::Dsize {
211 self.inner.dsize()
212 }
213}
214
215impl<'a, BE: Backend + 'a> GGSWInfos for GGSWPreparedViewMut<'a, BE> {
216 fn dnum(&self) -> crate::layouts::Dnum {
217 self.inner.dnum()
218 }
219
220 fn dsize(&self) -> crate::layouts::Dsize {
221 self.inner.dsize()
222 }
223}
224
225impl<'a, BE: Backend + 'a> LWEToBackendRef<BE> for LWEViewMut<'a, BE> {
226 fn to_backend_ref(&self) -> LWEBackendRef<'_, BE> {
227 LWE {
228 base2k: self.inner.base2k,
229 body: vec_znx_backend_ref_from_mut::<BE>(&self.inner.body),
230 mask: vec_znx_backend_ref_from_mut::<BE>(&self.inner.mask),
231 }
232 }
233}
234
235impl<'a, BE: Backend + 'a> LWEToBackendMut<BE> for LWEViewMut<'a, BE> {
236 fn to_backend_mut(&mut self) -> LWEBackendMut<'_, BE> {
237 let base2k = self.inner.base2k;
238 let body = vec_znx_backend_mut_from_mut::<BE>(&mut self.inner.body);
239 let mask = vec_znx_backend_mut_from_mut::<BE>(&mut self.inner.mask);
240 LWE { base2k, body, mask }
241 }
242}
243
244impl<'a, BE: Backend + 'a> LWEPlaintextToBackendRef<BE> for LWEPlaintextViewMut<'a, BE> {
245 fn to_backend_ref(&self) -> LWEPlaintextBackendRef<'_, BE> {
246 LWEPlaintext {
247 base2k: self.inner.base2k,
248 data: vec_znx_backend_ref_from_mut::<BE>(&self.inner.data),
249 }
250 }
251}
252
253impl<'a, BE: Backend + 'a> LWEPlaintextToBackendMut<BE> for LWEPlaintextViewMut<'a, BE> {
254 fn to_backend_mut(&mut self) -> LWEPlaintextBackendMut<'_, BE> {
255 LWEPlaintext {
256 base2k: self.inner.base2k,
257 data: vec_znx_backend_mut_from_mut::<BE>(&mut self.inner.data),
258 }
259 }
260}
261
// Implements `GLWEToBackendRef` and `GLWEToBackendMut` for a mutable view wrapper.
// Both produce a `GLWE` that copies the inner `base2k` and (re)borrows the inner
// `data` through `vec_znx_backend_ref_from_mut` / `vec_znx_backend_mut_from_mut`.
macro_rules! impl_glwe_to_backend {
    ($wrapper:ident) => {
        impl<'a, BE: Backend + 'a> GLWEToBackendMut<BE> for $wrapper<'a, BE> {
            fn to_backend_mut(&mut self) -> GLWEBackendMut<'_, BE> {
                let base2k = self.inner.base2k;
                let data = vec_znx_backend_mut_from_mut::<BE>(&mut self.inner.data);
                GLWE { base2k, data }
            }
        }

        impl<'a, BE: Backend + 'a> GLWEToBackendRef<BE> for $wrapper<'a, BE> {
            fn to_backend_ref(&self) -> GLWEBackendRef<'_, BE> {
                let base2k = self.inner.base2k;
                let data = vec_znx_backend_ref_from_mut::<BE>(&self.inner.data);
                GLWE { base2k, data }
            }
        }
    };
}
283
284impl_glwe_to_backend!(GLWEViewMut);
285impl_glwe_to_backend!(GLWEPlaintextViewMut);
286impl_glwe_to_backend!(GLWETensorViewMut);
287
288impl<'a, BE: Backend + 'a> GLWEToBackendRef<BE> for GLWEViewRef<'a, BE> {
289 fn to_backend_ref(&self) -> GLWEBackendRef<'_, BE> {
290 GLWE {
291 base2k: self.inner.base2k,
292 data: vec_znx_backend_ref_from_ref::<BE>(&self.inner.data),
293 }
294 }
295}
296
297impl<'a, BE: Backend + 'a> GLWESecretToBackendRef<BE> for GLWESecretViewMut<'a, BE> {
298 fn to_backend_ref(&self) -> GLWESecretBackendRef<'_, BE> {
299 GLWESecret {
300 dist: self.inner.dist,
301 data: ScalarZnx::from_data(
302 BE::view_ref_mut(&self.inner.data.data),
303 self.inner.data.n(),
304 self.inner.data.cols(),
305 ),
306 }
307 }
308}
309
310impl<'a, BE: Backend + 'a> GLWESecretToBackendMut<BE> for GLWESecretViewMut<'a, BE> {
311 fn to_backend_mut(&mut self) -> GLWESecretBackendMut<'_, BE> {
312 let n = self.inner.data.n();
313 let cols = self.inner.data.cols();
314 GLWESecret {
315 dist: self.inner.dist,
316 data: ScalarZnx::from_data(BE::view_mut_ref(&mut self.inner.data.data), n, cols),
317 }
318 }
319}
320
321impl<'a, BE: Backend + 'a> GLWESecretToBackendRef<BE> for GLWESecretTensorViewMut<'a, BE> {
322 fn to_backend_ref(&self) -> GLWESecretBackendRef<'_, BE> {
323 GLWESecret {
324 dist: self.inner.dist,
325 data: ScalarZnx::from_data(
326 BE::view_ref_mut(&self.inner.data.data),
327 self.inner.data.n(),
328 self.inner.data.cols(),
329 ),
330 }
331 }
332}
333
334impl<'a, BE: Backend + 'a> GLWESecretToBackendMut<BE> for GLWESecretTensorViewMut<'a, BE> {
335 fn to_backend_mut(&mut self) -> GLWESecretBackendMut<'_, BE> {
336 let n = self.inner.data.n();
337 let cols = self.inner.data.cols();
338 GLWESecret {
339 dist: self.inner.dist,
340 data: ScalarZnx::from_data(BE::view_mut_ref(&mut self.inner.data.data), n, cols),
341 }
342 }
343}
344
345impl<'a, BE: Backend + 'a> GLWESecretPreparedToBackendRef<BE> for GLWESecretPreparedViewMut<'a, BE> {
346 fn to_backend_ref(&self) -> GLWESecretPreparedBackendRef<'_, BE> {
347 GLWESecretPrepared {
348 dist: self.inner.dist,
349 data: self.inner.data.reborrow_backend_ref(),
350 }
351 }
352}
353
354impl<'a, BE: Backend + 'a> GLWESecretPreparedToBackendMut<BE> for GLWESecretPreparedViewMut<'a, BE> {
355 fn to_backend_mut(&mut self) -> GLWESecretPreparedBackendMut<'_, BE> {
356 GLWESecretPrepared {
357 dist: self.inner.dist,
358 data: self.inner.data.reborrow_backend_mut(),
359 }
360 }
361}
362
363impl<'a, BE: Backend + 'a> GGLWEToBackendRef<BE> for GGLWEViewMut<'a, BE> {
364 fn to_backend_ref(&self) -> GGLWEBackendRef<'_, BE> {
365 GGLWEBackendRef::from_inner(GGLWE {
366 base2k: self.inner.base2k,
367 dsize: self.inner.dsize,
368 data: mat_znx_backend_ref_from_mut::<BE>(&self.inner.data),
369 })
370 }
371}
372
373impl<'a, BE: Backend + 'a> GGLWEToBackendMut<BE> for GGLWEViewMut<'a, BE> {
374 fn to_backend_mut(&mut self) -> GGLWEBackendMut<'_, BE> {
375 GGLWEBackendMut::from_inner(GGLWE {
376 base2k: self.inner.base2k,
377 dsize: self.inner.dsize,
378 data: mat_znx_backend_mut_from_mut::<BE>(&mut self.inner.data),
379 })
380 }
381}
382
383impl<'a, BE: Backend + 'a> GGLWEPreparedToBackendRef<BE> for GGLWEPreparedViewMut<'a, BE> {
384 fn to_backend_ref(&self) -> GGLWEPreparedBackendRef<'_, BE> {
385 GGLWEPrepared {
386 base2k: self.inner.base2k,
387 dsize: self.inner.dsize,
388 data: self.inner.data.reborrow_backend_ref(),
389 }
390 }
391}
392
393impl<'a, BE: Backend + 'a> GGLWEPreparedToBackendMut<BE> for GGLWEPreparedViewMut<'a, BE> {
394 fn to_backend_mut(&mut self) -> GGLWEPreparedBackendMut<'_, BE> {
395 GGLWEPrepared {
396 base2k: self.inner.base2k,
397 dsize: self.inner.dsize,
398 data: self.inner.data.reborrow_backend_mut(),
399 }
400 }
401}
402
403impl<'a, BE: Backend + 'a> GGSWToBackendRef<BE> for GGSWViewMut<'a, BE> {
404 fn to_backend_ref(&self) -> GGSWBackendRef<'_, BE> {
405 GGSWBackendRef::from_inner(GGSW {
406 base2k: self.inner.base2k,
407 dsize: self.inner.dsize,
408 data: mat_znx_backend_ref_from_mut::<BE>(&self.inner.data),
409 })
410 }
411}
412
413impl<'a, BE: Backend + 'a> GGSWToBackendMut<BE> for GGSWViewMut<'a, BE> {
414 fn to_backend_mut(&mut self) -> GGSWBackendMut<'_, BE> {
415 GGSWBackendMut::from_inner(GGSW {
416 base2k: self.inner.base2k,
417 dsize: self.inner.dsize,
418 data: mat_znx_backend_mut_from_mut::<BE>(&mut self.inner.data),
419 })
420 }
421}
422
423impl<'a, BE: Backend + 'a> GGSWPreparedToBackendRef<BE> for GGSWPreparedViewMut<'a, BE> {
424 fn to_backend_ref(&self) -> GGSWPreparedBackendRef<'_, BE> {
425 GGSWPrepared {
426 base2k: self.inner.base2k,
427 dsize: self.inner.dsize,
428 data: self.inner.data.reborrow_backend_ref(),
429 }
430 }
431}
432
433impl<'a, BE: Backend + 'a> GGSWPreparedToBackendMut<BE> for GGSWPreparedViewMut<'a, BE> {
434 fn to_backend_mut(&mut self) -> GGSWPreparedBackendMut<'_, BE> {
435 GGSWPrepared {
436 base2k: self.inner.base2k,
437 dsize: self.inner.dsize,
438 data: self.inner.data.reborrow_backend_mut(),
439 }
440 }
441}