Skip to main content

poulpy_core/layouts/
glwe_secret_tensor.rs

1use poulpy_hal::{
2    api::{
3        ModuleN, SvpApplyDftToDft, SvpPrepare, VecZnxBigAlloc, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4        VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyTmpA,
5    },
6    layouts::{
7        Backend, Data, HostDataMut, HostDataRef, Module, ScalarZnx, ScalarZnxToBackendRef, ScratchArena, ScratchOwned,
8        SvpPPolReborrowBackendMut, SvpPPolReborrowBackendRef, VecZnxBigToBackendMut, VecZnxBigToBackendRef, VecZnxDft,
9        VecZnxDftToBackendMut, VecZnxDftToBackendRef, ZnxView, ZnxViewMut, scalar_znx_as_vec_znx_backend_mut_from_mut,
10        scalar_znx_as_vec_znx_backend_ref_from_ref,
11    },
12};
13
14use crate::{
15    GetDistribution, GetDistributionMut, ScratchArenaTakeCore,
16    dist::Distribution,
17    layouts::{
18        Base2K, Degree, GLWEInfos, GLWESecret, GLWESecretBackendMut, GLWESecretBackendRef, GLWESecretPreparedFactory,
19        GLWESecretToBackendMut, GLWESecretToBackendRef, LWEInfos, Rank,
20    },
21};
22
23pub struct GLWESecretTensor<D: Data> {
24    pub(crate) data: ScalarZnx<D>,
25    pub(crate) rank: Rank,
26    pub(crate) dist: Distribution,
27}
28
29impl GLWESecretTensor<Vec<u8>> {
30    pub(crate) fn pairs(rank: usize) -> usize {
31        (((rank + 1) * rank) >> 1).max(1)
32    }
33}
34
35impl<D: Data> GetDistribution for GLWESecretTensor<D> {
36    fn dist(&self) -> &Distribution {
37        &self.dist
38    }
39}
40
41impl<D: Data> GetDistributionMut for GLWESecretTensor<D> {
42    fn dist_mut(&mut self) -> &mut Distribution {
43        &mut self.dist
44    }
45}
46
47impl<D: Data> LWEInfos for GLWESecretTensor<D> {
48    fn base2k(&self) -> Base2K {
49        Base2K(0)
50    }
51
52    fn n(&self) -> Degree {
53        Degree(self.data.n() as u32)
54    }
55
56    fn size(&self) -> usize {
57        1
58    }
59}
60
61impl<D: Data> LWEInfos for &mut GLWESecretTensor<D> {
62    fn base2k(&self) -> Base2K {
63        (**self).base2k()
64    }
65
66    fn n(&self) -> Degree {
67        (**self).n()
68    }
69
70    fn size(&self) -> usize {
71        (**self).size()
72    }
73}
74
75impl<D: HostDataRef> GLWESecretTensor<D> {
76    pub fn at(&self, mut i: usize, mut j: usize) -> ScalarZnx<&[u8]> {
77        if i > j {
78            std::mem::swap(&mut i, &mut j);
79        };
80        let rank: usize = self.rank().into();
81        ScalarZnx::from_data(
82            bytemuck::cast_slice(self.data.at(i * rank + j - (i * (i + 1) / 2), 0)),
83            self.n().into(),
84            1,
85        )
86    }
87}
88
89impl<D: HostDataMut> GLWESecretTensor<D> {
90    pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> ScalarZnx<&mut [u8]> {
91        if i > j {
92            std::mem::swap(&mut i, &mut j);
93        };
94        let rank: usize = self.rank().into();
95        let n = self.n().into();
96        ScalarZnx::from_data(
97            bytemuck::cast_slice_mut(self.data.at_mut(i * rank + j - (i * (i + 1) / 2), 0)),
98            n,
99            1,
100        )
101    }
102}
103
104impl<D: Data> GLWEInfos for GLWESecretTensor<D> {
105    fn rank(&self) -> Rank {
106        self.rank
107    }
108}
109
110impl<D: Data> GLWEInfos for &mut GLWESecretTensor<D> {
111    fn rank(&self) -> Rank {
112        (**self).rank()
113    }
114}
115
116impl<BE: Backend> GLWESecretToBackendRef<BE> for GLWESecretTensor<BE::OwnedBuf> {
117    fn to_backend_ref(&self) -> GLWESecretBackendRef<'_, BE> {
118        GLWESecret {
119            data: <ScalarZnx<BE::OwnedBuf> as ScalarZnxToBackendRef<BE>>::to_backend_ref(&self.data),
120            dist: self.dist,
121        }
122    }
123}
124
125impl<'b, BE: Backend + 'b> GLWESecretToBackendRef<BE> for &mut GLWESecretTensor<BE::BufMut<'b>> {
126    fn to_backend_ref(&self) -> GLWESecretBackendRef<'_, BE> {
127        GLWESecret {
128            data: ScalarZnx::from_data(BE::view_ref_mut(&self.data.data), self.data.n(), self.data.cols()),
129            dist: self.dist,
130        }
131    }
132}
133
134impl<BE: Backend> GLWESecretToBackendMut<BE> for GLWESecretTensor<BE::OwnedBuf> {
135    fn to_backend_mut(&mut self) -> GLWESecretBackendMut<'_, BE> {
136        GLWESecret {
137            data: <ScalarZnx<BE::OwnedBuf> as poulpy_hal::layouts::ScalarZnxToBackendMut<BE>>::to_backend_mut(&mut self.data),
138            dist: self.dist,
139        }
140    }
141}
142
143impl<'b, BE: Backend + 'b> GLWESecretToBackendMut<BE> for &mut GLWESecretTensor<BE::BufMut<'b>> {
144    fn to_backend_mut(&mut self) -> GLWESecretBackendMut<'_, BE> {
145        let n = self.data.n();
146        let cols = self.data.cols();
147        GLWESecret {
148            data: ScalarZnx::from_data(BE::view_mut_ref(&mut self.data.data), n, cols),
149            dist: self.dist,
150        }
151    }
152}
153
154#[expect(
155    dead_code,
156    reason = "host-owned constructors are kept for serialization and host-only staging"
157)]
158impl GLWESecretTensor<Vec<u8>> {
159    pub(crate) fn alloc_from_infos<A>(infos: &A) -> Self
160    where
161        A: GLWEInfos,
162    {
163        Self::alloc(infos.n(), infos.rank())
164    }
165
166    pub(crate) fn alloc(n: Degree, rank: Rank) -> Self {
167        GLWESecretTensor {
168            data: ScalarZnx::from_data(
169                poulpy_hal::layouts::HostBytesBackend::alloc_bytes(ScalarZnx::<Vec<u8>>::bytes_of(
170                    n.into(),
171                    Self::pairs(rank.into()),
172                )),
173                n.into(),
174                Self::pairs(rank.into()),
175            ),
176            rank,
177            dist: Distribution::NONE,
178        }
179    }
180
181    pub fn bytes_of_from_infos<A>(infos: &A) -> usize
182    where
183        A: GLWEInfos,
184    {
185        Self::bytes_of(infos.n(), Self::pairs(infos.rank().into()).into())
186    }
187
188    pub fn bytes_of(n: Degree, rank: Rank) -> usize {
189        ScalarZnx::bytes_of(n.into(), Self::pairs(rank.into()))
190    }
191}
192
193// module-only API: secret tensor preparation is provided by `GLWESecretTensorFactory` on `Module`.
194
195pub trait GLWESecretTensorFactory<BE: Backend> {
196    fn glwe_secret_tensor_prepare_tmp_bytes(&self, rank: Rank) -> usize;
197
198    fn glwe_secret_tensor_prepare<R, O>(&self, res: &mut R, other: &O, scratch: &mut ScratchArena<'_, BE>)
199    where
200        R: GLWESecretToBackendMut<BE> + GetDistributionMut + GLWEInfos,
201        O: GLWESecretToBackendRef<BE> + GetDistribution + GLWEInfos;
202}
203
204impl<BE: Backend> GLWESecretTensorFactory<BE> for Module<BE>
205where
206    Self: ModuleN
207        + GLWESecretPreparedFactory<BE>
208        + VecZnxBigNormalize<BE>
209        + VecZnxDftApply<BE>
210        + SvpApplyDftToDft<BE>
211        + VecZnxIdftApplyTmpA<BE>
212        + VecZnxBigNormalize<BE>
213        + VecZnxDftBytesOf
214        + VecZnxBigBytesOf
215        + VecZnxBigNormalizeTmpBytes,
216{
217    fn glwe_secret_tensor_prepare_tmp_bytes(&self, rank: Rank) -> usize {
218        self.glwe_secret_prepared_bytes_of(rank)
219    }
220
221    fn glwe_secret_tensor_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut ScratchArena<'_, BE>)
222    where
223        R: GLWESecretToBackendMut<BE> + GetDistributionMut + GLWEInfos,
224        A: GLWESecretToBackendRef<BE> + GetDistribution + GLWEInfos,
225    {
226        let res = &mut res.to_backend_mut();
227        let a = a.to_backend_ref();
228
229        assert_eq!(res.rank(), GLWESecretTensor::pairs(a.rank().into()) as u32);
230        assert_eq!(res.n(), self.n() as u32);
231        assert_eq!(a.n(), self.n() as u32);
232        assert!(
233            scratch.available() >= self.glwe_secret_tensor_prepare_tmp_bytes(a.rank()),
234            "scratch.available(): {} < GLWESecretTensorFactory::glwe_secret_tensor_prepare_tmp_bytes: {}",
235            scratch.available(),
236            self.glwe_secret_tensor_prepare_tmp_bytes(a.rank())
237        );
238
239        let rank: usize = a.rank().into();
240
241        let scratch = scratch.borrow();
242        let (mut a_prepared, _scratch_1) = scratch.take_glwe_secret_prepared_scratch(self, rank.into());
243        {
244            let mut a_prepared_data = a_prepared.data.reborrow_backend_mut();
245            for i in 0..rank {
246                self.svp_prepare(&mut a_prepared_data, i, &a.data, i);
247            }
248        }
249        a_prepared.dist = *a.dist();
250
251        let base2k: usize = 17;
252
253        let mut a_dft = VecZnxDft::<BE::OwnedBuf, BE>::alloc(self.n(), rank, 1);
254        let a_backend_vec = scalar_znx_as_vec_znx_backend_ref_from_ref::<BE>(&a.data);
255        for i in 0..rank {
256            let mut a_dft_backend = a_dft.to_backend_mut();
257            self.vec_znx_dft_apply(1, 0, &mut a_dft_backend, i, &a_backend_vec, i);
258        }
259
260        let mut a_ij_dft = VecZnxDft::<BE::OwnedBuf, BE>::alloc(self.n(), 1, 1);
261        let a_prepared_backend_ref = a_prepared.data.reborrow_backend_ref();
262        let mut a_ij_big_backend = self.vec_znx_big_alloc(1, 1);
263        let mut norm_scratch = ScratchOwned {
264            data: BE::alloc_bytes(self.vec_znx_big_normalize_tmp_bytes()),
265            _phantom: std::marker::PhantomData,
266        };
267        let mut res_backend = scalar_znx_as_vec_znx_backend_mut_from_mut::<BE>(&mut res.data);
268
269        // sk_tensor = sk (x) sk
270        // For example: (s0, s1) (x) (s0, s1) = (s0^2, s0s1, s1^2)
271        for i in 0..rank {
272            for j in i..rank {
273                let idx: usize = i * rank + j - (i * (i + 1) / 2);
274                let a_dft_ref = a_dft.to_backend_ref();
275                {
276                    let mut a_ij_dft_backend = a_ij_dft.to_backend_mut();
277                    self.svp_apply_dft_to_dft(&mut a_ij_dft_backend, 0, &a_prepared_backend_ref, j, &a_dft_ref, i);
278                }
279                {
280                    let mut a_ij_big = a_ij_big_backend.to_backend_mut();
281                    let mut a_ij_dft = a_ij_dft.to_backend_mut();
282                    self.vec_znx_idft_apply_tmpa(&mut a_ij_big, 0, &mut a_ij_dft, 0);
283                }
284                {
285                    let a_ij_big = a_ij_big_backend.to_backend_ref();
286                    self.vec_znx_big_normalize(
287                        &mut res_backend,
288                        base2k,
289                        0,
290                        idx,
291                        &a_ij_big,
292                        base2k,
293                        0,
294                        &mut norm_scratch.arena(),
295                    );
296                }
297            }
298        }
299
300        res.dist = *a.dist();
301    }
302}