1use poulpy_hal::{
2 api::{
3 ModuleN, SvpApplyDftToDft, SvpPrepare, VecZnxBigAlloc, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4 VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyTmpA,
5 },
6 layouts::{
7 Backend, Data, HostDataMut, HostDataRef, Module, ScalarZnx, ScalarZnxToBackendRef, ScratchArena, ScratchOwned,
8 SvpPPolReborrowBackendMut, SvpPPolReborrowBackendRef, VecZnxBigToBackendMut, VecZnxBigToBackendRef, VecZnxDft,
9 VecZnxDftToBackendMut, VecZnxDftToBackendRef, ZnxView, ZnxViewMut, scalar_znx_as_vec_znx_backend_mut_from_mut,
10 scalar_znx_as_vec_znx_backend_ref_from_ref,
11 },
12};
13
14use crate::{
15 GetDistribution, GetDistributionMut, ScratchArenaTakeCore,
16 dist::Distribution,
17 layouts::{
18 Base2K, Degree, GLWEInfos, GLWESecret, GLWESecretBackendMut, GLWESecretBackendRef, GLWESecretPreparedFactory,
19 GLWESecretToBackendMut, GLWESecretToBackendRef, LWEInfos, Rank,
20 },
21};
22
23pub struct GLWESecretTensor<D: Data> {
24 pub(crate) data: ScalarZnx<D>,
25 pub(crate) rank: Rank,
26 pub(crate) dist: Distribution,
27}
28
29impl GLWESecretTensor<Vec<u8>> {
30 pub(crate) fn pairs(rank: usize) -> usize {
31 (((rank + 1) * rank) >> 1).max(1)
32 }
33}
34
35impl<D: Data> GetDistribution for GLWESecretTensor<D> {
36 fn dist(&self) -> &Distribution {
37 &self.dist
38 }
39}
40
41impl<D: Data> GetDistributionMut for GLWESecretTensor<D> {
42 fn dist_mut(&mut self) -> &mut Distribution {
43 &mut self.dist
44 }
45}
46
47impl<D: Data> LWEInfos for GLWESecretTensor<D> {
48 fn base2k(&self) -> Base2K {
49 Base2K(0)
50 }
51
52 fn n(&self) -> Degree {
53 Degree(self.data.n() as u32)
54 }
55
56 fn size(&self) -> usize {
57 1
58 }
59}
60
61impl<D: Data> LWEInfos for &mut GLWESecretTensor<D> {
62 fn base2k(&self) -> Base2K {
63 (**self).base2k()
64 }
65
66 fn n(&self) -> Degree {
67 (**self).n()
68 }
69
70 fn size(&self) -> usize {
71 (**self).size()
72 }
73}
74
75impl<D: HostDataRef> GLWESecretTensor<D> {
76 pub fn at(&self, mut i: usize, mut j: usize) -> ScalarZnx<&[u8]> {
77 if i > j {
78 std::mem::swap(&mut i, &mut j);
79 };
80 let rank: usize = self.rank().into();
81 ScalarZnx::from_data(
82 bytemuck::cast_slice(self.data.at(i * rank + j - (i * (i + 1) / 2), 0)),
83 self.n().into(),
84 1,
85 )
86 }
87}
88
89impl<D: HostDataMut> GLWESecretTensor<D> {
90 pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> ScalarZnx<&mut [u8]> {
91 if i > j {
92 std::mem::swap(&mut i, &mut j);
93 };
94 let rank: usize = self.rank().into();
95 let n = self.n().into();
96 ScalarZnx::from_data(
97 bytemuck::cast_slice_mut(self.data.at_mut(i * rank + j - (i * (i + 1) / 2), 0)),
98 n,
99 1,
100 )
101 }
102}
103
104impl<D: Data> GLWEInfos for GLWESecretTensor<D> {
105 fn rank(&self) -> Rank {
106 self.rank
107 }
108}
109
110impl<D: Data> GLWEInfos for &mut GLWESecretTensor<D> {
111 fn rank(&self) -> Rank {
112 (**self).rank()
113 }
114}
115
116impl<BE: Backend> GLWESecretToBackendRef<BE> for GLWESecretTensor<BE::OwnedBuf> {
117 fn to_backend_ref(&self) -> GLWESecretBackendRef<'_, BE> {
118 GLWESecret {
119 data: <ScalarZnx<BE::OwnedBuf> as ScalarZnxToBackendRef<BE>>::to_backend_ref(&self.data),
120 dist: self.dist,
121 }
122 }
123}
124
125impl<'b, BE: Backend + 'b> GLWESecretToBackendRef<BE> for &mut GLWESecretTensor<BE::BufMut<'b>> {
126 fn to_backend_ref(&self) -> GLWESecretBackendRef<'_, BE> {
127 GLWESecret {
128 data: ScalarZnx::from_data(BE::view_ref_mut(&self.data.data), self.data.n(), self.data.cols()),
129 dist: self.dist,
130 }
131 }
132}
133
134impl<BE: Backend> GLWESecretToBackendMut<BE> for GLWESecretTensor<BE::OwnedBuf> {
135 fn to_backend_mut(&mut self) -> GLWESecretBackendMut<'_, BE> {
136 GLWESecret {
137 data: <ScalarZnx<BE::OwnedBuf> as poulpy_hal::layouts::ScalarZnxToBackendMut<BE>>::to_backend_mut(&mut self.data),
138 dist: self.dist,
139 }
140 }
141}
142
143impl<'b, BE: Backend + 'b> GLWESecretToBackendMut<BE> for &mut GLWESecretTensor<BE::BufMut<'b>> {
144 fn to_backend_mut(&mut self) -> GLWESecretBackendMut<'_, BE> {
145 let n = self.data.n();
146 let cols = self.data.cols();
147 GLWESecret {
148 data: ScalarZnx::from_data(BE::view_mut_ref(&mut self.data.data), n, cols),
149 dist: self.dist,
150 }
151 }
152}
153
154#[expect(
155 dead_code,
156 reason = "host-owned constructors are kept for serialization and host-only staging"
157)]
158impl GLWESecretTensor<Vec<u8>> {
159 pub(crate) fn alloc_from_infos<A>(infos: &A) -> Self
160 where
161 A: GLWEInfos,
162 {
163 Self::alloc(infos.n(), infos.rank())
164 }
165
166 pub(crate) fn alloc(n: Degree, rank: Rank) -> Self {
167 GLWESecretTensor {
168 data: ScalarZnx::from_data(
169 poulpy_hal::layouts::HostBytesBackend::alloc_bytes(ScalarZnx::<Vec<u8>>::bytes_of(
170 n.into(),
171 Self::pairs(rank.into()),
172 )),
173 n.into(),
174 Self::pairs(rank.into()),
175 ),
176 rank,
177 dist: Distribution::NONE,
178 }
179 }
180
181 pub fn bytes_of_from_infos<A>(infos: &A) -> usize
182 where
183 A: GLWEInfos,
184 {
185 Self::bytes_of(infos.n(), Self::pairs(infos.rank().into()).into())
186 }
187
188 pub fn bytes_of(n: Degree, rank: Rank) -> usize {
189 ScalarZnx::bytes_of(n.into(), Self::pairs(rank.into()))
190 }
191}
192
193pub trait GLWESecretTensorFactory<BE: Backend> {
196 fn glwe_secret_tensor_prepare_tmp_bytes(&self, rank: Rank) -> usize;
197
198 fn glwe_secret_tensor_prepare<R, O>(&self, res: &mut R, other: &O, scratch: &mut ScratchArena<'_, BE>)
199 where
200 R: GLWESecretToBackendMut<BE> + GetDistributionMut + GLWEInfos,
201 O: GLWESecretToBackendRef<BE> + GetDistribution + GLWEInfos;
202}
203
204impl<BE: Backend> GLWESecretTensorFactory<BE> for Module<BE>
205where
206 Self: ModuleN
207 + GLWESecretPreparedFactory<BE>
208 + VecZnxBigNormalize<BE>
209 + VecZnxDftApply<BE>
210 + SvpApplyDftToDft<BE>
211 + VecZnxIdftApplyTmpA<BE>
212 + VecZnxBigNormalize<BE>
213 + VecZnxDftBytesOf
214 + VecZnxBigBytesOf
215 + VecZnxBigNormalizeTmpBytes,
216{
217 fn glwe_secret_tensor_prepare_tmp_bytes(&self, rank: Rank) -> usize {
218 self.glwe_secret_prepared_bytes_of(rank)
219 }
220
221 fn glwe_secret_tensor_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut ScratchArena<'_, BE>)
222 where
223 R: GLWESecretToBackendMut<BE> + GetDistributionMut + GLWEInfos,
224 A: GLWESecretToBackendRef<BE> + GetDistribution + GLWEInfos,
225 {
226 let res = &mut res.to_backend_mut();
227 let a = a.to_backend_ref();
228
229 assert_eq!(res.rank(), GLWESecretTensor::pairs(a.rank().into()) as u32);
230 assert_eq!(res.n(), self.n() as u32);
231 assert_eq!(a.n(), self.n() as u32);
232 assert!(
233 scratch.available() >= self.glwe_secret_tensor_prepare_tmp_bytes(a.rank()),
234 "scratch.available(): {} < GLWESecretTensorFactory::glwe_secret_tensor_prepare_tmp_bytes: {}",
235 scratch.available(),
236 self.glwe_secret_tensor_prepare_tmp_bytes(a.rank())
237 );
238
239 let rank: usize = a.rank().into();
240
241 let scratch = scratch.borrow();
242 let (mut a_prepared, _scratch_1) = scratch.take_glwe_secret_prepared_scratch(self, rank.into());
243 {
244 let mut a_prepared_data = a_prepared.data.reborrow_backend_mut();
245 for i in 0..rank {
246 self.svp_prepare(&mut a_prepared_data, i, &a.data, i);
247 }
248 }
249 a_prepared.dist = *a.dist();
250
251 let base2k: usize = 17;
252
253 let mut a_dft = VecZnxDft::<BE::OwnedBuf, BE>::alloc(self.n(), rank, 1);
254 let a_backend_vec = scalar_znx_as_vec_znx_backend_ref_from_ref::<BE>(&a.data);
255 for i in 0..rank {
256 let mut a_dft_backend = a_dft.to_backend_mut();
257 self.vec_znx_dft_apply(1, 0, &mut a_dft_backend, i, &a_backend_vec, i);
258 }
259
260 let mut a_ij_dft = VecZnxDft::<BE::OwnedBuf, BE>::alloc(self.n(), 1, 1);
261 let a_prepared_backend_ref = a_prepared.data.reborrow_backend_ref();
262 let mut a_ij_big_backend = self.vec_znx_big_alloc(1, 1);
263 let mut norm_scratch = ScratchOwned {
264 data: BE::alloc_bytes(self.vec_znx_big_normalize_tmp_bytes()),
265 _phantom: std::marker::PhantomData,
266 };
267 let mut res_backend = scalar_znx_as_vec_znx_backend_mut_from_mut::<BE>(&mut res.data);
268
269 for i in 0..rank {
272 for j in i..rank {
273 let idx: usize = i * rank + j - (i * (i + 1) / 2);
274 let a_dft_ref = a_dft.to_backend_ref();
275 {
276 let mut a_ij_dft_backend = a_ij_dft.to_backend_mut();
277 self.svp_apply_dft_to_dft(&mut a_ij_dft_backend, 0, &a_prepared_backend_ref, j, &a_dft_ref, i);
278 }
279 {
280 let mut a_ij_big = a_ij_big_backend.to_backend_mut();
281 let mut a_ij_dft = a_ij_dft.to_backend_mut();
282 self.vec_znx_idft_apply_tmpa(&mut a_ij_big, 0, &mut a_ij_dft, 0);
283 }
284 {
285 let a_ij_big = a_ij_big_backend.to_backend_ref();
286 self.vec_znx_big_normalize(
287 &mut res_backend,
288 base2k,
289 0,
290 idx,
291 &a_ij_big,
292 base2k,
293 0,
294 &mut norm_scratch.arena(),
295 );
296 }
297 }
298 }
299
300 res.dist = *a.dist();
301 }
302}