poulpy_core/keyswitching/
gglwe_ct.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4 VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, ZnxZero,
5 },
6 layouts::{Backend, DataMut, DataRef, Module, Scratch},
7};
8
9use crate::layouts::{
10 GGLWEAutomorphismKey, GGLWESwitchingKey, GLWECiphertext, Infos,
11 prepared::{GGLWEAutomorphismKeyPrepared, GGLWESwitchingKeyPrepared},
12};
13
14impl GGLWEAutomorphismKey<Vec<u8>> {
15 #[allow(clippy::too_many_arguments)]
16 pub fn keyswitch_scratch_space<B: Backend>(
17 module: &Module<B>,
18 n: usize,
19 basek: usize,
20 k_out: usize,
21 k_in: usize,
22 k_ksk: usize,
23 digits: usize,
24 rank: usize,
25 ) -> usize
26 where
27 Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
28 {
29 GGLWESwitchingKey::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits, rank, rank)
30 }
31
32 pub fn keyswitch_inplace_scratch_space<B: Backend>(
33 module: &Module<B>,
34 n: usize,
35 basek: usize,
36 k_out: usize,
37 k_ksk: usize,
38 digits: usize,
39 rank: usize,
40 ) -> usize
41 where
42 Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
43 {
44 GGLWESwitchingKey::keyswitch_inplace_scratch_space(module, n, basek, k_out, k_ksk, digits, rank)
45 }
46}
47
48impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
49 pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
50 &mut self,
51 module: &Module<B>,
52 lhs: &GGLWEAutomorphismKey<DataLhs>,
53 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
54 scratch: &mut Scratch<B>,
55 ) where
56 Module<B>: VecZnxDftAllocBytes
57 + VmpApplyTmpBytes
58 + VecZnxBigNormalizeTmpBytes
59 + VmpApply<B>
60 + VmpApplyAdd<B>
61 + VecZnxDftFromVecZnx<B>
62 + VecZnxDftToVecZnxBigConsume<B>
63 + VecZnxBigAddSmallInplace<B>
64 + VecZnxBigNormalize<B>,
65 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
66 {
67 self.key.keyswitch(module, &lhs.key, rhs, scratch);
68 }
69
70 pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
71 &mut self,
72 module: &Module<B>,
73 rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
74 scratch: &mut Scratch<B>,
75 ) where
76 Module<B>: VecZnxDftAllocBytes
77 + VmpApplyTmpBytes
78 + VecZnxBigNormalizeTmpBytes
79 + VmpApply<B>
80 + VmpApplyAdd<B>
81 + VecZnxDftFromVecZnx<B>
82 + VecZnxDftToVecZnxBigConsume<B>
83 + VecZnxBigAddSmallInplace<B>
84 + VecZnxBigNormalize<B>,
85 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
86 {
87 self.key.keyswitch_inplace(module, &rhs.key, scratch);
88 }
89}
90
91impl GGLWESwitchingKey<Vec<u8>> {
92 #[allow(clippy::too_many_arguments)]
93 pub fn keyswitch_scratch_space<B: Backend>(
94 module: &Module<B>,
95 n: usize,
96 basek: usize,
97 k_out: usize,
98 k_in: usize,
99 k_ksk: usize,
100 digits: usize,
101 rank_in: usize,
102 rank_out: usize,
103 ) -> usize
104 where
105 Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
106 {
107 GLWECiphertext::keyswitch_scratch_space(
108 module, n, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out,
109 )
110 }
111
112 pub fn keyswitch_inplace_scratch_space<B: Backend>(
113 module: &Module<B>,
114 n: usize,
115 basek: usize,
116 k_out: usize,
117 k_ksk: usize,
118 digits: usize,
119 rank: usize,
120 ) -> usize
121 where
122 Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
123 {
124 GLWECiphertext::keyswitch_inplace_scratch_space(module, n, basek, k_out, k_ksk, digits, rank)
125 }
126}
127
128impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
129 pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
130 &mut self,
131 module: &Module<B>,
132 lhs: &GGLWESwitchingKey<DataLhs>,
133 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
134 scratch: &mut Scratch<B>,
135 ) where
136 Module<B>: VecZnxDftAllocBytes
137 + VmpApplyTmpBytes
138 + VecZnxBigNormalizeTmpBytes
139 + VmpApply<B>
140 + VmpApplyAdd<B>
141 + VecZnxDftFromVecZnx<B>
142 + VecZnxDftToVecZnxBigConsume<B>
143 + VecZnxBigAddSmallInplace<B>
144 + VecZnxBigNormalize<B>,
145 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
146 {
147 #[cfg(debug_assertions)]
148 {
149 assert_eq!(
150 self.rank_in(),
151 lhs.rank_in(),
152 "ksk_out input rank: {} != ksk_in input rank: {}",
153 self.rank_in(),
154 lhs.rank_in()
155 );
156 assert_eq!(
157 lhs.rank_out(),
158 rhs.rank_in(),
159 "ksk_in output rank: {} != ksk_apply input rank: {}",
160 self.rank_out(),
161 rhs.rank_in()
162 );
163 assert_eq!(
164 self.rank_out(),
165 rhs.rank_out(),
166 "ksk_out output rank: {} != ksk_apply output rank: {}",
167 self.rank_out(),
168 rhs.rank_out()
169 );
170 }
171
172 (0..self.rank_in()).for_each(|col_i| {
173 (0..self.rows()).for_each(|row_j| {
174 self.at_mut(row_j, col_i)
175 .keyswitch(module, &lhs.at(row_j, col_i), rhs, scratch);
176 });
177 });
178
179 (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
180 (0..self.rank_in()).for_each(|col_j| {
181 self.at_mut(row_i, col_j).data.zero();
182 });
183 });
184 }
185
186 pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
187 &mut self,
188 module: &Module<B>,
189 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
190 scratch: &mut Scratch<B>,
191 ) where
192 Module<B>: VecZnxDftAllocBytes
193 + VmpApplyTmpBytes
194 + VecZnxBigNormalizeTmpBytes
195 + VmpApply<B>
196 + VmpApplyAdd<B>
197 + VecZnxDftFromVecZnx<B>
198 + VecZnxDftToVecZnxBigConsume<B>
199 + VecZnxBigAddSmallInplace<B>
200 + VecZnxBigNormalize<B>,
201 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
202 {
203 #[cfg(debug_assertions)]
204 {
205 assert_eq!(
206 self.rank_out(),
207 rhs.rank_out(),
208 "ksk_out output rank: {} != ksk_apply output rank: {}",
209 self.rank_out(),
210 rhs.rank_out()
211 );
212 }
213
214 (0..self.rank_in()).for_each(|col_i| {
215 (0..self.rows()).for_each(|row_j| {
216 self.at_mut(row_j, col_i)
217 .keyswitch_inplace(module, rhs, scratch)
218 });
219 });
220 }
221}