1use poulpy_hal::{
2 api::{
3 DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
4 VecZnxBigAutomorphismInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxBigSubSmallAInplace,
5 VecZnxBigSubSmallBInplace, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6 },
7 layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnxBig},
8};
9
10use crate::layouts::{GLWECiphertext, Infos, prepared::GGLWEAutomorphismKeyPrepared};
11
12impl GLWECiphertext<Vec<u8>> {
13 #[allow(clippy::too_many_arguments)]
14 pub fn automorphism_scratch_space<B: Backend>(
15 module: &Module<B>,
16 basek: usize,
17 k_out: usize,
18 k_in: usize,
19 k_ksk: usize,
20 digits: usize,
21 rank: usize,
22 ) -> usize
23 where
24 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
25 {
26 Self::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank)
27 }
28
29 pub fn automorphism_inplace_scratch_space<B: Backend>(
30 module: &Module<B>,
31 basek: usize,
32 k_out: usize,
33 k_ksk: usize,
34 digits: usize,
35 rank: usize,
36 ) -> usize
37 where
38 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
39 {
40 Self::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank)
41 }
42}
43
44impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
45 pub fn automorphism<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
46 &mut self,
47 module: &Module<B>,
48 lhs: &GLWECiphertext<DataLhs>,
49 rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
50 scratch: &mut Scratch<B>,
51 ) where
52 Module<B>: VecZnxDftAllocBytes
53 + VmpApplyDftToDftTmpBytes
54 + VecZnxBigNormalizeTmpBytes
55 + VmpApplyDftToDft<B>
56 + VmpApplyDftToDftAdd<B>
57 + DFT<B>
58 + IDFTConsume<B>
59 + VecZnxBigAddSmallInplace<B>
60 + VecZnxBigNormalize<B>
61 + VecZnxAutomorphismInplace,
62 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
63 {
64 self.keyswitch(module, lhs, &rhs.key, scratch);
65 (0..self.rank() + 1).for_each(|i| {
66 module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i);
67 })
68 }
69
70 pub fn automorphism_inplace<DataRhs: DataRef, B: Backend>(
71 &mut self,
72 module: &Module<B>,
73 rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
74 scratch: &mut Scratch<B>,
75 ) where
76 Module<B>: VecZnxDftAllocBytes
77 + VmpApplyDftToDftTmpBytes
78 + VecZnxBigNormalizeTmpBytes
79 + VmpApplyDftToDft<B>
80 + VmpApplyDftToDftAdd<B>
81 + DFT<B>
82 + IDFTConsume<B>
83 + VecZnxBigAddSmallInplace<B>
84 + VecZnxBigNormalize<B>
85 + VecZnxAutomorphismInplace,
86 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
87 {
88 self.keyswitch_inplace(module, &rhs.key, scratch);
89 (0..self.rank() + 1).for_each(|i| {
90 module.vec_znx_automorphism_inplace(rhs.p(), &mut self.data, i);
91 })
92 }
93
94 pub fn automorphism_add<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
95 &mut self,
96 module: &Module<B>,
97 lhs: &GLWECiphertext<DataLhs>,
98 rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
99 scratch: &mut Scratch<B>,
100 ) where
101 Module<B>: VecZnxDftAllocBytes
102 + VmpApplyDftToDftTmpBytes
103 + VecZnxBigNormalizeTmpBytes
104 + VmpApplyDftToDft<B>
105 + VmpApplyDftToDftAdd<B>
106 + DFT<B>
107 + IDFTConsume<B>
108 + VecZnxBigAddSmallInplace<B>
109 + VecZnxBigNormalize<B>
110 + VecZnxBigAutomorphismInplace<B>,
111 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
112 {
113 #[cfg(debug_assertions)]
114 {
115 self.assert_keyswitch(module, lhs, &rhs.key, scratch);
116 }
117 let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch1);
119 (0..self.cols()).for_each(|i| {
120 module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i);
121 module.vec_znx_big_add_small_inplace(&mut res_big, i, &lhs.data, i);
122 module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
123 })
124 }
125
126 pub fn automorphism_add_inplace<DataRhs: DataRef, B: Backend>(
127 &mut self,
128 module: &Module<B>,
129 rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
130 scratch: &mut Scratch<B>,
131 ) where
132 Module<B>: VecZnxDftAllocBytes
133 + VmpApplyDftToDftTmpBytes
134 + VecZnxBigNormalizeTmpBytes
135 + VmpApplyDftToDft<B>
136 + VmpApplyDftToDftAdd<B>
137 + DFT<B>
138 + IDFTConsume<B>
139 + VecZnxBigAddSmallInplace<B>
140 + VecZnxBigNormalize<B>
141 + VecZnxBigAutomorphismInplace<B>,
142 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
143 {
144 unsafe {
145 let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
146 self.automorphism_add(module, &*self_ptr, rhs, scratch);
147 }
148 }
149
150 pub fn automorphism_sub_ab<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
151 &mut self,
152 module: &Module<B>,
153 lhs: &GLWECiphertext<DataLhs>,
154 rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
155 scratch: &mut Scratch<B>,
156 ) where
157 Module<B>: VecZnxDftAllocBytes
158 + VmpApplyDftToDftTmpBytes
159 + VecZnxBigNormalizeTmpBytes
160 + VmpApplyDftToDft<B>
161 + VmpApplyDftToDftAdd<B>
162 + DFT<B>
163 + IDFTConsume<B>
164 + VecZnxBigAddSmallInplace<B>
165 + VecZnxBigNormalize<B>
166 + VecZnxBigAutomorphismInplace<B>
167 + VecZnxBigSubSmallAInplace<B>,
168 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
169 {
170 #[cfg(debug_assertions)]
171 {
172 self.assert_keyswitch(module, lhs, &rhs.key, scratch);
173 }
174 let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch1);
176 (0..self.cols()).for_each(|i| {
177 module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i);
178 module.vec_znx_big_sub_small_a_inplace(&mut res_big, i, &lhs.data, i);
179 module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
180 })
181 }
182
183 pub fn automorphism_sub_ab_inplace<DataRhs: DataRef, B: Backend>(
184 &mut self,
185 module: &Module<B>,
186 rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
187 scratch: &mut Scratch<B>,
188 ) where
189 Module<B>: VecZnxDftAllocBytes
190 + VmpApplyDftToDftTmpBytes
191 + VecZnxBigNormalizeTmpBytes
192 + VmpApplyDftToDft<B>
193 + VmpApplyDftToDftAdd<B>
194 + DFT<B>
195 + IDFTConsume<B>
196 + VecZnxBigAddSmallInplace<B>
197 + VecZnxBigNormalize<B>
198 + VecZnxBigAutomorphismInplace<B>
199 + VecZnxBigSubSmallAInplace<B>,
200 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
201 {
202 unsafe {
203 let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
204 self.automorphism_sub_ab(module, &*self_ptr, rhs, scratch);
205 }
206 }
207
208 pub fn automorphism_sub_ba<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
209 &mut self,
210 module: &Module<B>,
211 lhs: &GLWECiphertext<DataLhs>,
212 rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
213 scratch: &mut Scratch<B>,
214 ) where
215 Module<B>: VecZnxDftAllocBytes
216 + VmpApplyDftToDftTmpBytes
217 + VecZnxBigNormalizeTmpBytes
218 + VmpApplyDftToDft<B>
219 + VmpApplyDftToDftAdd<B>
220 + DFT<B>
221 + IDFTConsume<B>
222 + VecZnxBigAddSmallInplace<B>
223 + VecZnxBigNormalize<B>
224 + VecZnxBigAutomorphismInplace<B>
225 + VecZnxBigSubSmallBInplace<B>,
226 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
227 {
228 #[cfg(debug_assertions)]
229 {
230 self.assert_keyswitch(module, lhs, &rhs.key, scratch);
231 }
232 let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); let mut res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, &rhs.key, scratch1);
234 (0..self.cols()).for_each(|i| {
235 module.vec_znx_big_automorphism_inplace(rhs.p(), &mut res_big, i);
236 module.vec_znx_big_sub_small_b_inplace(&mut res_big, i, &lhs.data, i);
237 module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
238 })
239 }
240
241 pub fn automorphism_sub_ba_inplace<DataRhs: DataRef, B: Backend>(
242 &mut self,
243 module: &Module<B>,
244 rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
245 scratch: &mut Scratch<B>,
246 ) where
247 Module<B>: VecZnxDftAllocBytes
248 + VmpApplyDftToDftTmpBytes
249 + VecZnxBigNormalizeTmpBytes
250 + VmpApplyDftToDft<B>
251 + VmpApplyDftToDftAdd<B>
252 + DFT<B>
253 + IDFTConsume<B>
254 + VecZnxBigAddSmallInplace<B>
255 + VecZnxBigNormalize<B>
256 + VecZnxBigAutomorphismInplace<B>
257 + VecZnxBigSubSmallBInplace<B>,
258 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
259 {
260 unsafe {
261 let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
262 self.automorphism_sub_ba(module, &*self_ptr, rhs, scratch);
263 }
264 }
265}