// poulpy_core/layouts/glwe_secret_tensor.rs

1use poulpy_hal::{
2    api::{
3        ModuleN, ScratchTakeBasic, SvpApplyDftToDft, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4        VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyTmpA,
5    },
6    layouts::{
7        Backend, Data, DataMut, DataRef, Module, ScalarZnx, ScalarZnxToMut, ScalarZnxToRef, Scratch, ZnxInfos, ZnxView,
8        ZnxViewMut,
9    },
10};
11
12use crate::{
13    ScratchTakeCore,
14    dist::Distribution,
15    layouts::{
16        Base2K, Degree, GLWEInfos, GLWESecret, GLWESecretPreparedFactory, GLWESecretToMut, GLWESecretToRef, LWEInfos, Rank,
17        TorusPrecision,
18    },
19};
20
21pub struct GLWESecretTensor<D: Data> {
22    pub(crate) data: ScalarZnx<D>,
23    pub(crate) rank: Rank,
24    pub(crate) dist: Distribution,
25}
26
27impl GLWESecretTensor<Vec<u8>> {
28    pub(crate) fn pairs(rank: usize) -> usize {
29        (((rank + 1) * rank) >> 1).max(1)
30    }
31}
32
33impl<D: Data> LWEInfos for GLWESecretTensor<D> {
34    fn base2k(&self) -> Base2K {
35        Base2K(0)
36    }
37
38    fn k(&self) -> TorusPrecision {
39        TorusPrecision(0)
40    }
41
42    fn n(&self) -> Degree {
43        Degree(self.data.n() as u32)
44    }
45
46    fn size(&self) -> usize {
47        1
48    }
49}
50
51impl<D: DataRef> GLWESecretTensor<D> {
52    pub fn at(&self, mut i: usize, mut j: usize) -> ScalarZnx<&[u8]> {
53        if i > j {
54            std::mem::swap(&mut i, &mut j);
55        };
56        let rank: usize = self.rank().into();
57        ScalarZnx {
58            data: bytemuck::cast_slice(self.data.at(i * rank + j - (i * (i + 1) / 2), 0)),
59            n: self.n().into(),
60            cols: 1,
61        }
62    }
63}
64
65impl<D: DataMut> GLWESecretTensor<D> {
66    pub fn at_mut(&mut self, mut i: usize, mut j: usize) -> ScalarZnx<&mut [u8]> {
67        if i > j {
68            std::mem::swap(&mut i, &mut j);
69        };
70        let rank: usize = self.rank().into();
71        ScalarZnx {
72            n: self.n().into(),
73            data: bytemuck::cast_slice_mut(self.data.at_mut(i * rank + j - (i * (i + 1) / 2), 0)),
74            cols: 1,
75        }
76    }
77}
78
79impl<D: Data> GLWEInfos for GLWESecretTensor<D> {
80    fn rank(&self) -> Rank {
81        self.rank
82    }
83}
84
85impl<D: DataRef> GLWESecretToRef for GLWESecretTensor<D> {
86    fn to_ref(&self) -> GLWESecret<&[u8]> {
87        GLWESecret {
88            data: self.data.to_ref(),
89            dist: self.dist,
90        }
91    }
92}
93
94impl<D: DataMut> GLWESecretToMut for GLWESecretTensor<D> {
95    fn to_mut(&mut self) -> GLWESecret<&mut [u8]> {
96        GLWESecret {
97            dist: self.dist,
98            data: self.data.to_mut(),
99        }
100    }
101}
102
103impl GLWESecretTensor<Vec<u8>> {
104    pub fn alloc_from_infos<A>(infos: &A) -> Self
105    where
106        A: GLWEInfos,
107    {
108        Self::alloc(infos.n(), infos.rank())
109    }
110
111    pub fn alloc(n: Degree, rank: Rank) -> Self {
112        GLWESecretTensor {
113            data: ScalarZnx::alloc(n.into(), Self::pairs(rank.into())),
114            rank,
115            dist: Distribution::NONE,
116        }
117    }
118
119    pub fn bytes_of_from_infos<A>(infos: &A) -> usize
120    where
121        A: GLWEInfos,
122    {
123        Self::bytes_of(infos.n(), Self::pairs(infos.rank().into()).into())
124    }
125
126    pub fn bytes_of(n: Degree, rank: Rank) -> usize {
127        ScalarZnx::bytes_of(n.into(), Self::pairs(rank.into()))
128    }
129}
130
131impl<D: DataMut> GLWESecretTensor<D> {
132    pub fn prepare<M, S, BE: Backend>(&mut self, module: &M, other: &S, scratch: &mut Scratch<BE>)
133    where
134        M: GLWESecretTensorFactory<BE>,
135        S: GLWESecretToRef + GLWEInfos,
136        Scratch<BE>: ScratchTakeCore<BE>,
137    {
138        module.glwe_secret_tensor_prepare(self, other, scratch);
139    }
140}
141
142pub trait GLWESecretTensorFactory<BE: Backend> {
143    fn glwe_secret_tensor_prepare_tmp_bytes(&self, rank: Rank) -> usize;
144
145    fn glwe_secret_tensor_prepare<R, O>(&self, res: &mut R, other: &O, scratch: &mut Scratch<BE>)
146    where
147        R: GLWESecretToMut + GLWEInfos,
148        O: GLWESecretToRef + GLWEInfos;
149}
150
151impl<BE: Backend> GLWESecretTensorFactory<BE> for Module<BE>
152where
153    Self: ModuleN
154        + GLWESecretPreparedFactory<BE>
155        + VecZnxBigNormalize<BE>
156        + VecZnxDftApply<BE>
157        + SvpApplyDftToDft<BE>
158        + VecZnxIdftApplyTmpA<BE>
159        + VecZnxBigNormalize<BE>
160        + VecZnxDftBytesOf
161        + VecZnxBigBytesOf
162        + VecZnxBigNormalizeTmpBytes,
163    Scratch<BE>: ScratchTakeCore<BE>,
164{
165    fn glwe_secret_tensor_prepare_tmp_bytes(&self, rank: Rank) -> usize {
166        self.bytes_of_glwe_secret_prepared(rank)
167            + self.bytes_of_vec_znx_dft(rank.into(), 1)
168            + self.bytes_of_vec_znx_dft(1, 1)
169            + self.bytes_of_vec_znx_big(1, 1)
170            + self.vec_znx_big_normalize_tmp_bytes()
171    }
172
173    fn glwe_secret_tensor_prepare<R, A>(&self, res: &mut R, a: &A, scratch: &mut Scratch<BE>)
174    where
175        R: GLWESecretToMut + GLWEInfos,
176        A: GLWESecretToRef + GLWEInfos,
177    {
178        let res: &mut GLWESecret<&mut [u8]> = &mut res.to_mut();
179        let a: &GLWESecret<&[u8]> = &a.to_ref();
180
181        assert_eq!(res.rank(), GLWESecretTensor::pairs(a.rank().into()) as u32);
182        assert_eq!(res.n(), self.n() as u32);
183        assert_eq!(a.n(), self.n() as u32);
184
185        let rank: usize = a.rank().into();
186
187        let (mut a_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, rank.into());
188        a_prepared.prepare(self, a);
189
190        let base2k: usize = 17;
191
192        let (mut a_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, rank, 1);
193        for i in 0..rank {
194            self.vec_znx_dft_apply(1, 0, &mut a_dft, i, &a.data.as_vec_znx(), i);
195        }
196
197        let (mut a_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1);
198        let (mut a_ij_dft, scratch_4) = scratch_3.take_vec_znx_dft(self, 1, 1);
199
200        // sk_tensor = sk (x) sk
201        // For example: (s0, s1) (x) (s0, s1) = (s0^2, s0s1, s1^2)
202        for i in 0..rank {
203            for j in i..rank {
204                let idx: usize = i * rank + j - (i * (i + 1) / 2);
205                self.svp_apply_dft_to_dft(&mut a_ij_dft, 0, &a_prepared.data, j, &a_dft, i);
206                self.vec_znx_idft_apply_tmpa(&mut a_ij_big, 0, &mut a_ij_dft, 0);
207                self.vec_znx_big_normalize(
208                    base2k,
209                    &mut res.data.as_vec_znx_mut(),
210                    idx,
211                    base2k,
212                    &a_ij_big,
213                    0,
214                    scratch_4,
215                );
216            }
217        }
218    }
219}