poulpy_core/automorphism/glwe_ct.rs

1use poulpy_hal::{
2    api::{
3        DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
4        VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallAInplace,
5        VecZnxBigSubSmallBInplace, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnxBig},
8};
9
10use crate::layouts::{GLWECiphertext, Infos, prepared::GGLWEAutomorphismKeyPrepared};
11
12impl GLWECiphertext<Vec<u8>> {
13    #[allow(clippy::too_many_arguments)]
14    pub fn automorphism_scratch_space<B: Backend>(
15        module: &Module<B>,
16        basek: usize,
17        k_out: usize,
18        k_in: usize,
19        k_ksk: usize,
20        digits: usize,
21        rank: usize,
22    ) -> usize
23    where
24        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
25    {
26        Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank)
27    }
28
29    pub fn automorphism_inplace_scratch_space<B: Backend>(
30        module: &Module<B>,
31        basek: usize,
32        k_out: usize,
33        k_ksk: usize,
34        digits: usize,
35        rank: usize,
36    ) -> usize
37    where
38        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
39    {
40        Self::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank)
41    }
42}
43
44impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
45    pub fn automorphism<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
46        &mut self,
47        module: &Module<B>,
48        lhs: &GLWECiphertext<DataLhs>,
49        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
50        scratch: &mut Scratch<B>,
51    ) where
52        Module<B>: VecZnxDftAllocBytes
53            + VmpApplyDftToDftTmpBytes
54            + VecZnxBigNormalizeTmpBytes
55            + VmpApplyDftToDft<B>
56            + VmpApplyDftToDftAdd<B>
57            + DFT<B>
58            + IDFTConsume<B>
59            + VecZnxBigAddSmallInplace<B>
60            + VecZnxBigNormalize<B>
61            + VecZnxAutomorphismInplace,
62        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
63    {
64        self.keyswitch(module, lhs, &rhs.key, scratch);
65        (0..self.rank() + 1).for_each(|i| {
66            module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i);
67        })
68    }
69
70    pub fn automorphism_inplace<DataRhs: DataRef, B: Backend>(
71        &mut self,
72        module: &Module<B>,
73        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
74        scratch: &mut Scratch<B>,
75    ) where
76        Module<B>: VecZnxDftAllocBytes
77            + VmpApplyDftToDftTmpBytes
78            + VecZnxBigNormalizeTmpBytes
79            + VmpApplyDftToDft<B>
80            + VmpApplyDftToDftAdd<B>
81            + DFT<B>
82            + IDFTConsume<B>
83            + VecZnxBigAddSmallInplace<B>
84            + VecZnxBigNormalize<B>
85            + VecZnxAutomorphismInplace,
86        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
87    {
88        self.keyswitch_inplace(module, &rhs.key, scratch);
89        (0..self.rank() + 1).for_each(|i| {
90            module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i);
91        })
92    }
93
94    pub fn automorphism_add<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
95        &mut self,
96        module: &Module<B>,
97        lhs: &GLWECiphertext<DataLhs>,
98        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
99        scratch: &mut Scratch<B>,
100    ) where
101        Module<B>: VecZnxDftAllocBytes
102            + VmpApplyDftToDftTmpBytes
103            + VecZnxBigNormalizeTmpBytes
104            + VmpApplyDftToDft<B>
105            + VmpApplyDftToDftAdd<B>
106            + DFT<B>
107            + IDFTConsume<B>
108            + VecZnxBigAddSmallInplace<B>
109            + VecZnxBigNormalize<B>
110            + VecZnxBigAutomorphismInplace<B>,
111        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
112    {
113        #[cfg(debug_assertions)]
114        {
115            self.assert_keyswitch(module, lhs, &rhs.key, scratch);
116        }
117        let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
118        let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch1);
119        (0..self.cols()).for_each(|i| {
120            module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i);
121            module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i);
122            module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
123        })
124    }
125
126    pub fn automorphism_add_inplace<DataRhs: DataRef, B: Backend>(
127        &mut self,
128        module: &Module<B>,
129        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
130        scratch: &mut Scratch<B>,
131    ) where
132        Module<B>: VecZnxDftAllocBytes
133            + VmpApplyDftToDftTmpBytes
134            + VecZnxBigNormalizeTmpBytes
135            + VmpApplyDftToDft<B>
136            + VmpApplyDftToDftAdd<B>
137            + DFT<B>
138            + IDFTConsume<B>
139            + VecZnxBigAddSmallInplace<B>
140            + VecZnxBigNormalize<B>
141            + VecZnxBigAutomorphismInplace<B>,
142        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
143    {
144        unsafe {
145            let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
146            self.automorphism_add(module, &*self_ptr, rhs, scratch);
147        }
148    }
149
150    pub fn automorphism_sub_ab<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
151        &mut self,
152        module: &Module<B>,
153        lhs: &GLWECiphertext<DataLhs>,
154        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
155        scratch: &mut Scratch<B>,
156    ) where
157        Module<B>: VecZnxDftAllocBytes
158            + VmpApplyDftToDftTmpBytes
159            + VecZnxBigNormalizeTmpBytes
160            + VmpApplyDftToDft<B>
161            + VmpApplyDftToDftAdd<B>
162            + DFT<B>
163            + IDFTConsume<B>
164            + VecZnxBigAddSmallInplace<B>
165            + VecZnxBigNormalize<B>
166            + VecZnxBigAutomorphismInplace<B>
167            + VecZnxBigSubSmallAInplace<B>,
168        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
169    {
170        #[cfg(debug_assertions)]
171        {
172            self.assert_keyswitch(module, lhs, &rhs.key, scratch);
173        }
174        let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
175        let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch1);
176        (0..self.cols()).for_each(|i| {
177            module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i);
178            module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &lhs.data, i);
179            module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
180        })
181    }
182
183    pub fn automorphism_sub_ab_inplace<DataRhs: DataRef, B: Backend>(
184        &mut self,
185        module: &Module<B>,
186        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
187        scratch: &mut Scratch<B>,
188    ) where
189        Module<B>: VecZnxDftAllocBytes
190            + VmpApplyDftToDftTmpBytes
191            + VecZnxBigNormalizeTmpBytes
192            + VmpApplyDftToDft<B>
193            + VmpApplyDftToDftAdd<B>
194            + DFT<B>
195            + IDFTConsume<B>
196            + VecZnxBigAddSmallInplace<B>
197            + VecZnxBigNormalize<B>
198            + VecZnxBigAutomorphismInplace<B>
199            + VecZnxBigSubSmallAInplace<B>,
200        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
201    {
202        unsafe {
203            let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
204            self.automorphism_sub_ab(module, &*self_ptr, rhs, scratch);
205        }
206    }
207
208    pub fn automorphism_sub_ba<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
209        &mut self,
210        module: &Module<B>,
211        lhs: &GLWECiphertext<DataLhs>,
212        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
213        scratch: &mut Scratch<B>,
214    ) where
215        Module<B>: VecZnxDftAllocBytes
216            + VmpApplyDftToDftTmpBytes
217            + VecZnxBigNormalizeTmpBytes
218            + VmpApplyDftToDft<B>
219            + VmpApplyDftToDftAdd<B>
220            + DFT<B>
221            + IDFTConsume<B>
222            + VecZnxBigAddSmallInplace<B>
223            + VecZnxBigNormalize<B>
224            + VecZnxBigAutomorphismInplace<B>
225            + VecZnxBigSubSmallBInplace<B>,
226        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
227    {
228        #[cfg(debug_assertions)]
229        {
230            self.assert_keyswitch(module, lhs, &rhs.key, scratch);
231        }
232        let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // TODO: optimise size
233        let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch1);
234        (0..self.cols()).for_each(|i| {
235            module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i);
236            module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &lhs.data, i);
237            module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
238        })
239    }
240
241    pub fn automorphism_sub_ba_inplace<DataRhs: DataRef, B: Backend>(
242        &mut self,
243        module: &Module<B>,
244        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
245        scratch: &mut Scratch<B>,
246    ) where
247        Module<B>: VecZnxDftAllocBytes
248            + VmpApplyDftToDftTmpBytes
249            + VecZnxBigNormalizeTmpBytes
250            + VmpApplyDftToDft<B>
251            + VmpApplyDftToDftAdd<B>
252            + DFT<B>
253            + IDFTConsume<B>
254            + VecZnxBigAddSmallInplace<B>
255            + VecZnxBigNormalize<B>
256            + VecZnxBigAutomorphismInplace<B>
257            + VecZnxBigSubSmallBInplace<B>,
258        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
259    {
260        unsafe {
261            let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
262            self.automorphism_sub_ba(module, &*self_ptr, rhs, scratch);
263        }
264    }
265}