// poulpy_core/automorphism/glwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
4        VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallInplace,
5        VecZnxBigSubSmallNegateInplace, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize,
6        VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnxBig},
9};
10
11use crate::layouts::{GGLWELayoutInfos, GLWECiphertext, GLWEInfos, LWEInfos, prepared::GGLWEAutomorphismKeyPrepared};
12
13impl GLWECiphertext<Vec<u8>> {
14    pub fn automorphism_scratch_space<B: Backend, OUT, IN, KEY>(
15        module: &Module<B>,
16        out_infos: &OUT,
17        in_infos: &IN,
18        key_infos: &KEY,
19    ) -> usize
20    where
21        OUT: GLWEInfos,
22        IN: GLWEInfos,
23        KEY: GGLWELayoutInfos,
24        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
25    {
26        Self::keyswitch_scratch_space(module, out_infos, in_infos, key_infos)
27    }
28
29    pub fn automorphism_inplace_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
30    where
31        OUT: GLWEInfos,
32        KEY: GGLWELayoutInfos,
33        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
34    {
35        Self::keyswitch_inplace_scratch_space(module, out_infos, key_infos)
36    }
37}
38
39impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
40    pub fn automorphism<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
41        &mut self,
42        module: &Module<B>,
43        lhs: &GLWECiphertext<DataLhs>,
44        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
45        scratch: &mut Scratch<B>,
46    ) where
47        Module<B>: VecZnxDftAllocBytes
48            + VmpApplyDftToDftTmpBytes
49            + VecZnxBigNormalizeTmpBytes
50            + VmpApplyDftToDft<B>
51            + VmpApplyDftToDftAdd<B>
52            + VecZnxDftApply<B>
53            + VecZnxIdftApplyConsume<B>
54            + VecZnxBigAddSmallInplace<B>
55            + VecZnxBigNormalize<B>
56            + VecZnxAutomorphismInplace<B>
57            + VecZnxNormalize<B>
58            + VecZnxNormalizeTmpBytes,
59        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
60    {
61        self.keyswitch(module, lhs, &rhs.key, scratch);
62        (0..(self.rank() + 1).into()).for_each(|i| {
63            module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i, scratch);
64        })
65    }
66
67    pub fn automorphism_inplace<DataRhs: DataRef, B: Backend>(
68        &mut self,
69        module: &Module<B>,
70        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
71        scratch: &mut Scratch<B>,
72    ) where
73        Module<B>: VecZnxDftAllocBytes
74            + VmpApplyDftToDftTmpBytes
75            + VecZnxBigNormalizeTmpBytes
76            + VmpApplyDftToDft<B>
77            + VmpApplyDftToDftAdd<B>
78            + VecZnxDftApply<B>
79            + VecZnxIdftApplyConsume<B>
80            + VecZnxBigAddSmallInplace<B>
81            + VecZnxBigNormalize<B>
82            + VecZnxAutomorphismInplace<B>
83            + VecZnxNormalize<B>
84            + VecZnxNormalizeTmpBytes,
85        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
86    {
87        self.keyswitch_inplace(module, &rhs.key, scratch);
88        (0..(self.rank() + 1).into()).for_each(|i| {
89            module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i, scratch);
90        })
91    }
92
93    pub fn automorphism_add<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
94        &mut self,
95        module: &Module<B>,
96        lhs: &GLWECiphertext<DataLhs>,
97        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
98        scratch: &mut Scratch<B>,
99    ) where
100        Module<B>: VecZnxDftAllocBytes
101            + VmpApplyDftToDftTmpBytes
102            + VecZnxBigNormalizeTmpBytes
103            + VmpApplyDftToDft<B>
104            + VmpApplyDftToDftAdd<B>
105            + VecZnxDftApply<B>
106            + VecZnxIdftApplyConsume<B>
107            + VecZnxBigAddSmallInplace<B>
108            + VecZnxBigNormalize<B>
109            + VecZnxBigAutomorphismInplace<B>
110            + VecZnxNormalizeTmpBytes
111            + VecZnxNormalize<B>,
112        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
113    {
114        #[cfg(debug_assertions)]
115        {
116            self.assert_keyswitch(module, lhs, &rhs.key, scratch);
117        }
118        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size
119        let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
120        (0..(self.rank() + 1).into()).for_each(|i| {
121            module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
122            module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i);
123            module.vec_znx_big_normalize(
124                self.base2k().into(),
125                &mut self.data,
126                i,
127                rhs.base2k().into(),
128                &res_big,
129                i,
130                scratch_1,
131            );
132        })
133    }
134
135    pub fn automorphism_add_inplace<DataRhs: DataRef, B: Backend>(
136        &mut self,
137        module: &Module<B>,
138        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
139        scratch: &mut Scratch<B>,
140    ) where
141        Module<B>: VecZnxDftAllocBytes
142            + VmpApplyDftToDftTmpBytes
143            + VecZnxBigNormalizeTmpBytes
144            + VmpApplyDftToDft<B>
145            + VmpApplyDftToDftAdd<B>
146            + VecZnxDftApply<B>
147            + VecZnxIdftApplyConsume<B>
148            + VecZnxBigAddSmallInplace<B>
149            + VecZnxBigNormalize<B>
150            + VecZnxBigAutomorphismInplace<B>
151            + VecZnxNormalizeTmpBytes
152            + VecZnxNormalize<B>,
153        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
154    {
155        #[cfg(debug_assertions)]
156        {
157            self.assert_keyswitch_inplace(module, &rhs.key, scratch);
158        }
159        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size
160        let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
161        (0..(self.rank() + 1).into()).for_each(|i| {
162            module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
163            module.vec_znx_big_add_small_inplace(&mut res_big, i, &self.data, i);
164            module.vec_znx_big_normalize(
165                self.base2k().into(),
166                &mut self.data,
167                i,
168                rhs.base2k().into(),
169                &res_big,
170                i,
171                scratch_1,
172            );
173        })
174    }
175
176    pub fn automorphism_sub_ab<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
177        &mut self,
178        module: &Module<B>,
179        lhs: &GLWECiphertext<DataLhs>,
180        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
181        scratch: &mut Scratch<B>,
182    ) where
183        Module<B>: VecZnxDftAllocBytes
184            + VmpApplyDftToDftTmpBytes
185            + VecZnxBigNormalizeTmpBytes
186            + VmpApplyDftToDft<B>
187            + VmpApplyDftToDftAdd<B>
188            + VecZnxDftApply<B>
189            + VecZnxIdftApplyConsume<B>
190            + VecZnxBigAddSmallInplace<B>
191            + VecZnxBigNormalize<B>
192            + VecZnxBigAutomorphismInplace<B>
193            + VecZnxBigSubSmallInplace<B>
194            + VecZnxNormalizeTmpBytes
195            + VecZnxNormalize<B>,
196        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
197    {
198        #[cfg(debug_assertions)]
199        {
200            self.assert_keyswitch(module, lhs, &rhs.key, scratch);
201        }
202        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size
203        let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
204        (0..(self.rank() + 1).into()).for_each(|i| {
205            module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
206            module.vec_znx_big_sub_small_inplace(&mut res_big, i, &lhs.data, i);
207            module.vec_znx_big_normalize(
208                self.base2k().into(),
209                &mut self.data,
210                i,
211                rhs.base2k().into(),
212                &res_big,
213                i,
214                scratch_1,
215            );
216        })
217    }
218
219    pub fn automorphism_sub_inplace<DataRhs: DataRef, B: Backend>(
220        &mut self,
221        module: &Module<B>,
222        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
223        scratch: &mut Scratch<B>,
224    ) where
225        Module<B>: VecZnxDftAllocBytes
226            + VmpApplyDftToDftTmpBytes
227            + VecZnxBigNormalizeTmpBytes
228            + VmpApplyDftToDft<B>
229            + VmpApplyDftToDftAdd<B>
230            + VecZnxDftApply<B>
231            + VecZnxIdftApplyConsume<B>
232            + VecZnxBigAddSmallInplace<B>
233            + VecZnxBigNormalize<B>
234            + VecZnxBigAutomorphismInplace<B>
235            + VecZnxBigSubSmallInplace<B>
236            + VecZnxNormalizeTmpBytes
237            + VecZnxNormalize<B>,
238        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
239    {
240        #[cfg(debug_assertions)]
241        {
242            self.assert_keyswitch_inplace(module, &rhs.key, scratch);
243        }
244        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size
245        let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
246        (0..(self.rank() + 1).into()).for_each(|i| {
247            module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
248            module.vec_znx_big_sub_small_inplace(&mut res_big, i, &self.data, i);
249            module.vec_znx_big_normalize(
250                self.base2k().into(),
251                &mut self.data,
252                i,
253                rhs.base2k().into(),
254                &res_big,
255                i,
256                scratch_1,
257            );
258        })
259    }
260
261    pub fn automorphism_sub_negate<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
262        &mut self,
263        module: &Module<B>,
264        lhs: &GLWECiphertext<DataLhs>,
265        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
266        scratch: &mut Scratch<B>,
267    ) where
268        Module<B>: VecZnxDftAllocBytes
269            + VmpApplyDftToDftTmpBytes
270            + VecZnxBigNormalizeTmpBytes
271            + VmpApplyDftToDft<B>
272            + VmpApplyDftToDftAdd<B>
273            + VecZnxDftApply<B>
274            + VecZnxIdftApplyConsume<B>
275            + VecZnxBigAddSmallInplace<B>
276            + VecZnxBigNormalize<B>
277            + VecZnxBigAutomorphismInplace<B>
278            + VecZnxBigSubSmallNegateInplace<B>
279            + VecZnxNormalizeTmpBytes
280            + VecZnxNormalize<B>,
281        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
282    {
283        #[cfg(debug_assertions)]
284        {
285            self.assert_keyswitch(module, lhs, &rhs.key, scratch);
286        }
287        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size
288        let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
289        (0..(self.rank() + 1).into()).for_each(|i| {
290            module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
291            module.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, &lhs.data, i);
292            module.vec_znx_big_normalize(
293                self.base2k().into(),
294                &mut self.data,
295                i,
296                rhs.base2k().into(),
297                &res_big,
298                i,
299                scratch_1,
300            );
301        })
302    }
303
304    pub fn automorphism_sub_negate_inplace<DataRhs: DataRef, B: Backend>(
305        &mut self,
306        module: &Module<B>,
307        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
308        scratch: &mut Scratch<B>,
309    ) where
310        Module<B>: VecZnxDftAllocBytes
311            + VmpApplyDftToDftTmpBytes
312            + VecZnxBigNormalizeTmpBytes
313            + VmpApplyDftToDft<B>
314            + VmpApplyDftToDftAdd<B>
315            + VecZnxDftApply<B>
316            + VecZnxIdftApplyConsume<B>
317            + VecZnxBigAddSmallInplace<B>
318            + VecZnxBigNormalize<B>
319            + VecZnxBigAutomorphismInplace<B>
320            + VecZnxBigSubSmallNegateInplace<B>
321            + VecZnxNormalizeTmpBytes
322            + VecZnxNormalize<B>,
323        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
324    {
325        #[cfg(debug_assertions)]
326        {
327            self.assert_keyswitch_inplace(module, &rhs.key, scratch);
328        }
329        let (res_dft, scratch_1) = scratch.take_vec_znx_dft(self.n().into(), (self.rank() + 1).into(), rhs.size()); // TODO: optimise size
330        let mut res_big: VecZnxBig<_, B> = self.keyswitch_internal(module, res_dft, &rhs.key, scratch_1);
331        (0..(self.rank() + 1).into()).for_each(|i| {
332            module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i, scratch_1);
333            module.vec_znx_big_sub_small_negate_inplace(&mut res_big, i, &self.data, i);
334            module.vec_znx_big_normalize(
335                self.base2k().into(),
336                &mut self.data,
337                i,
338                rhs.base2k().into(),
339                &res_big,
340                i,
341                scratch_1,
342            );
343        })
344    }
345}