1use poulpy_hal::{
2 api::{
3 ModuleN, ScratchTakeBasic, SvpApplyDftToDft, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4 VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyTmpA,
5 },
6 layouts::{
7 Backend, Data, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToMut, ScalarZnxToRef, Scratch, ZnxInfos, ZnxView,
8 ZnxViewMut,
9 },
10};
11
12use crate::{
13 ScratchTakeCore,
14 dist::Distribution,
15 layouts::{
16 Base2K, Degree, GLWEInfos, GLWESecret, GLWESecretPreparedFactory, GLWESecretToMut, GLWESecretToRef, LWEInfos, Rank,
17 TorusPrecision,
18 },
19};
20
21pub struct GLWESecretTensor<D: Data> {
22 pub(crate) data: ScalarZnx<D>,
23 pub(crate) rank: Rank,
24 pub(crate) dist: Distribution,
25}
26
27impl GLWESecretTensor<Vec<u8>> {
28 pub(crate) fn pairs(rank: usize) -> usize {
29 (((rank + 1) * rank) >> 1).max(1)
30 }
31}
32
33impl<D: Data> LWEInfos for GLWESecretTensor<D> {
34 fn base2k(&self) -> Base2K {
35 Base2K(0)
36 }
37
38 fn k(&self) -> TorusPrecision {
39 TorusPrecision(0)
40 }
41
42 fn n(&self) -> Degree {
43 Degree(self.data.n() as u32)
44 }
45
46 fn size(&self) -> usize {
47 1
48 }
49}
50
51impl<D: DataRef> GLWESecretTensor<D> {
52 pub fn at(&self, mut i: usize, mut j: usize) -> ScalarZnx<&[u8]> {
53 if i > j {
54 std::mem::swap(&mut i, &mut j);
55 };
56 let rank: usize = self.rank().into();
57 ScalarZnx {
58 data: bytemuck::cast_slice(self.data.at(i * rank + j - (i * (i + 1) / 2), 0)),
59 n: self.n().into(),
60 cols: 1,
61 }
62 }
63}
64
65impl<D: DataMut> GLWESecretTensor<D> {
66 pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> ScalarZnx<&mut [u8]> {
67 if i > j {
68 std::mem::swap(&mut i, &mut j);
69 };
70 let rank: usize = self.rank().into();
71 ScalarZnx {
72 n: self.n().into(),
73 data: bytemuck::cast_slice_mut(self.data.at_mut(i * rank + j - (i * (i + 1) / 2), 0)),
74 cols: 1,
75 }
76 }
77}
78
79impl<D: Data> GLWEInfos for GLWESecretTensor<D> {
80 fn rank(&self) -> Rank {
81 self.rank
82 }
83}
84
85impl<D: DataRef> GLWESecretToRef for GLWESecretTensor<D> {
86 fn to_ref(&self) -> GLWESecret<&[u8]> {
87 GLWESecret {
88 data: self.data.to_ref(),
89 dist: self.dist,
90 }
91 }
92}
93
94impl<D: DataMut> GLWESecretToMut for GLWESecretTensor<D> {
95 fn to_mut(&mut self) -> GLWESecret<&mut [u8]> {
96 GLWESecret {
97 dist: self.dist,
98 data: self.data.to_mut(),
99 }
100 }
101}
102
103impl GLWESecretTensor<Vec<u8>> {
104 pub fn alloc_from_infos<A>(infos: &A) -> Self
105 where
106 A: GLWEInfos,
107 {
108 Self::alloc(infos.n(), infos.rank())
109 }
110
111 pub fn alloc(n: Degree, rank: Rank) -> Self {
112 GLWESecretTensor {
113 data: ScalarZnx::alloc(n.into(), Self::pairs(rank.into())),
114 rank,
115 dist: Distribution::NONE,
116 }
117 }
118
119 pub fn bytes_of_from_infos<A>(infos: &A) -> usize
120 where
121 A: GLWEInfos,
122 {
123 Self::bytes_of(infos.n(), Self::pairs(infos.rank().into()).into())
124 }
125
126 pub fn bytes_of(n: Degree, rank: Rank) -> usize {
127 ScalarZnx::bytes_of(n.into(), Self::pairs(rank.into()))
128 }
129}
130
131impl<D: DataMut> GLWESecretTensor<D> {
132 pub fn prepare<M, S, BE: Backend>(&mut self, module: &M, other: &S, scratch: &mut Scratch<BE>)
133 where
134 M: GLWESecretTensorFactory<BE>,
135 S: GLWESecretToRef + GLWEInfos,
136 Scratch<BE>: ScratchTakeCore<BE>,
137 {
138 module.glwe_secret_tensor_prepare(self, other, scratch);
139 }
140}
141
142pub trait GLWESecretTensorFactory<BE: Backend> {
143 fn glwe_secret_tensor_prepare_tmp_bytes(&self, rank: Rank) -> usize;
144
145 fn glwe_secret_tensor_prepare<R, O>(&self, res: &mut R, other: &O, scratch: &mut Scratch<BE>)
146 where
147 R: GLWESecretToMut + GLWEInfos,
148 O: GLWESecretToRef + GLWEInfos;
149}
150
151impl<BE: Backend> GLWESecretTensorFactory<BE> for Module<BE>
152where
153 Self: ModuleN
154 + GLWESecretPreparedFactory<BE>
155 + VecZnxBigNormalize<BE>
156 + VecZnxDftApply<BE>
157 + SvpApplyDftToDft<BE>
158 + VecZnxIdftApplyTmpA<BE>
159 + VecZnxBigNormalize<BE>
160 + VecZnxDftBytesOf
161 + VecZnxBigBytesOf
162 + VecZnxBigNormalizeTmpBytes,
163 Scratch<BE>: ScratchTakeCore<BE>,
164{
165 fn glwe_secret_tensor_prepare_tmp_bytes(&self, rank: Rank) -> usize {
166 self.bytes_of_glwe_secret_prepared(rank)
167 + self.bytes_of_vec_znx_dft(rank.into(), 1)
168 + self.bytes_of_vec_znx_dft(1, 1)
169 + self.bytes_of_vec_znx_big(1, 1)
170 + self.vec_znx_big_normalize_tmp_bytes()
171 }
172
173 fn glwe_secret_tensor_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
174 where
175 R: GLWESecretToMut + GLWEInfos,
176 A: GLWESecretToRef + GLWEInfos,
177 {
178 let res: &mut GLWESecret<&mut [u8]> = &mut res.to_mut();
179 let a: &GLWESecret<&[u8]> = &a.to_ref();
180
181 assert_eq!(res.rank(), GLWESecretTensor::pairs(a.rank().into()) as u32);
182 assert_eq!(res.n(), self.n() as u32);
183 assert_eq!(a.n(), self.n() as u32);
184
185 let rank: usize = a.rank().into();
186
187 let (mut a_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, rank.into());
188 a_prepared.prepare(self, a);
189
190 let base2k: usize = 17;
191
192 let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, rank, 1);
193 for i in 0..rank {
194 self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a.data.as_vec_znx(), i);
195 }
196
197 let (mut a_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1);
198 let (mut a_ij_dft, scratch_4) = scratch_3.take_vec_znx_dft(self, 1, 1);
199
200 for i in 0..rank {
203 for j in i..rank {
204 let idx: usize = i * rank + j - (i * (i + 1) / 2);
205 self.svp_apply_dft_to_dft(&mut a_ij_dft, 0, &a_prepared.data, j, &a_dft, i);
206 self.vec_znx_idft_apply_tmpa(&mut a_ij_big, 0, &mut a_ij_dft, 0);
207 self.vec_znx_big_normalize(
208 base2k,
209 &mut res.data.as_vec_znx_mut(),
210 idx,
211 base2k,
212 &a_ij_big,
213 0,
214 scratch_4,
215 );
216 }
217 }
218 }
219}