poulpy_core/keyswitching/
gglwe_ct.rs1use poulpy_hal::{
2 api::{
3 DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
4 VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
5 },
6 layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
7};
8
9use crate::layouts::{
10 GGLWEAutomorphismKey, GGLWESwitchingKey, GLWECiphertext, Infos,
11 prepared::{GGLWEAutomorphismKeyPrepared, GGLWESwitchingKeyPrepared},
12};
13
14impl GGLWEAutomorphismKey<Vec<u8>> {
15 #[allow(clippy::too_many_arguments)]
16 pub fn keyswitch_scratch_space<B: Backend>(
17 module: &Module<B>,
18 basek: usize,
19 k_out: usize,
20 k_in: usize,
21 k_ksk: usize,
22 digits: usize,
23 rank: usize,
24 ) -> usize
25 where
26 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
27 {
28 GGLWESwitchingKey::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank)
29 }
30
31 pub fn keyswitch_inplace_scratch_space<B: Backend>(
32 module: &Module<B>,
33 basek: usize,
34 k_out: usize,
35 k_ksk: usize,
36 digits: usize,
37 rank: usize,
38 ) -> usize
39 where
40 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
41 {
42 GGLWESwitchingKey::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank)
43 }
44}
45
46impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
47 pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
48 &mut self,
49 module: &Module<B>,
50 lhs: &GGLWEAutomorphismKey<DataLhs>,
51 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
52 scratch: &mut Scratch<B>,
53 ) where
54 Module<B>: VecZnxDftAllocBytes
55 + VmpApplyDftToDftTmpBytes
56 + VecZnxBigNormalizeTmpBytes
57 + VmpApplyDftToDft<B>
58 + VmpApplyDftToDftAdd<B>
59 + DFT<B>
60 + IDFTConsume<B>
61 + VecZnxBigAddSmallInplace<B>
62 + VecZnxBigNormalize<B>,
63 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
64 {
65 self.key.keyswitch(module, &lhs.key, rhs, scratch);
66 }
67
68 pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
69 &mut self,
70 module: &Module<B>,
71 rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
72 scratch: &mut Scratch<B>,
73 ) where
74 Module<B>: VecZnxDftAllocBytes
75 + VmpApplyDftToDftTmpBytes
76 + VecZnxBigNormalizeTmpBytes
77 + VmpApplyDftToDft<B>
78 + VmpApplyDftToDftAdd<B>
79 + DFT<B>
80 + IDFTConsume<B>
81 + VecZnxBigAddSmallInplace<B>
82 + VecZnxBigNormalize<B>,
83 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
84 {
85 self.key.keyswitch_inplace(module, &rhs.key, scratch);
86 }
87}
88
89impl GGLWESwitchingKey<Vec<u8>> {
90 #[allow(clippy::too_many_arguments)]
91 pub fn keyswitch_scratch_space<B: Backend>(
92 module: &Module<B>,
93 basek: usize,
94 k_out: usize,
95 k_in: usize,
96 k_ksk: usize,
97 digits: usize,
98 rank_in: usize,
99 rank_out: usize,
100 ) -> usize
101 where
102 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
103 {
104 GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out)
105 }
106
107 pub fn keyswitch_inplace_scratch_space<B: Backend>(
108 module: &Module<B>,
109 basek: usize,
110 k_out: usize,
111 k_ksk: usize,
112 digits: usize,
113 rank: usize,
114 ) -> usize
115 where
116 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
117 {
118 GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank)
119 }
120}
121
122impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
123 pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
124 &mut self,
125 module: &Module<B>,
126 lhs: &GGLWESwitchingKey<DataLhs>,
127 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
128 scratch: &mut Scratch<B>,
129 ) where
130 Module<B>: VecZnxDftAllocBytes
131 + VmpApplyDftToDftTmpBytes
132 + VecZnxBigNormalizeTmpBytes
133 + VmpApplyDftToDft<B>
134 + VmpApplyDftToDftAdd<B>
135 + DFT<B>
136 + IDFTConsume<B>
137 + VecZnxBigAddSmallInplace<B>
138 + VecZnxBigNormalize<B>,
139 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
140 {
141 #[cfg(debug_assertions)]
142 {
143 assert_eq!(
144 self.rank_in(),
145 lhs.rank_in(),
146 "ksk_out input rank: {} != ksk_in input rank: {}",
147 self.rank_in(),
148 lhs.rank_in()
149 );
150 assert_eq!(
151 lhs.rank_out(),
152 rhs.rank_in(),
153 "ksk_in output rank: {} != ksk_apply input rank: {}",
154 self.rank_out(),
155 rhs.rank_in()
156 );
157 assert_eq!(
158 self.rank_out(),
159 rhs.rank_out(),
160 "ksk_out output rank: {} != ksk_apply output rank: {}",
161 self.rank_out(),
162 rhs.rank_out()
163 );
164 }
165
166 (0..self.rank_in()).for_each(|col_i| {
167 (0..self.rows()).for_each(|row_j| {
168 self.at_mut(row_j, col_i)
169 .keyswitch(module, &lhs.at(row_j, col_i), rhs, scratch);
170 });
171 });
172
173 (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
174 (0..self.rank_in()).for_each(|col_j| {
175 self.at_mut(row_i, col_j).data.zero();
176 });
177 });
178 }
179
180 pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
181 &mut self,
182 module: &Module<B>,
183 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
184 scratch: &mut Scratch<B>,
185 ) where
186 Module<B>: VecZnxDftAllocBytes
187 + VmpApplyDftToDftTmpBytes
188 + VecZnxBigNormalizeTmpBytes
189 + VmpApplyDftToDft<B>
190 + VmpApplyDftToDftAdd<B>
191 + DFT<B>
192 + IDFTConsume<B>
193 + VecZnxBigAddSmallInplace<B>
194 + VecZnxBigNormalize<B>,
195 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
196 {
197 #[cfg(debug_assertions)]
198 {
199 assert_eq!(
200 self.rank_out(),
201 rhs.rank_out(),
202 "ksk_out output rank: {} != ksk_apply output rank: {}",
203 self.rank_out(),
204 rhs.rank_out()
205 );
206 }
207
208 (0..self.rank_in()).for_each(|col_i| {
209 (0..self.rows()).for_each(|row_j| {
210 self.at_mut(row_j, col_i)
211 .keyswitch_inplace(module, rhs, scratch)
212 });
213 });
214 }
215}