poulpy_core/keyswitching/
gglwe_ct.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4 VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
5 VmpApplyDftToDftTmpBytes,
6 },
7 layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
8};
9
10use crate::layouts::{
11 GGLWEAutomorphismKey, GGLWESwitchingKey, GLWECiphertext, Infos,
12 prepared::{GGLWEAutomorphismKeyPrepared, GGLWESwitchingKeyPrepared},
13};
14
15impl GGLWEAutomorphismKey<Vec<u8>> {
16 #[allow(clippy::too_many_arguments)]
17 pub fn keyswitch_scratch_space<B: Backend>(
18 module: &Module<B>,
19 basek: usize,
20 k_out: usize,
21 k_in: usize,
22 k_ksk: usize,
23 digits: usize,
24 rank: usize,
25 ) -> usize
26 where
27 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
28 {
29 GGLWESwitchingKey::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank)
30 }
31
32 pub fn keyswitch_inplace_scratch_space<B: Backend>(
33 module: &Module<B>,
34 basek: usize,
35 k_out: usize,
36 k_ksk: usize,
37 digits: usize,
38 rank: usize,
39 ) -> usize
40 where
41 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
42 {
43 GGLWESwitchingKey::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank)
44 }
45}
46
47impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
48 pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
49 &mut self,
50 module: &Module<B>,
51 lhs: &GGLWEAutomorphismKey<DataLhs>,
52 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
53 scratch: &mut Scratch<B>,
54 ) where
55 Module<B>: VecZnxDftAllocBytes
56 + VmpApplyDftToDftTmpBytes
57 + VecZnxBigNormalizeTmpBytes
58 + VmpApplyDftToDft<B>
59 + VmpApplyDftToDftAdd<B>
60 + VecZnxDftApply<B>
61 + VecZnxIdftApplyConsume<B>
62 + VecZnxBigAddSmallInplace<B>
63 + VecZnxBigNormalize<B>,
64 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
65 {
66 self.key.keyswitch(module, &lhs.key, rhs, scratch);
67 }
68
69 pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
70 &mut self,
71 module: &Module<B>,
72 rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
73 scratch: &mut Scratch<B>,
74 ) where
75 Module<B>: VecZnxDftAllocBytes
76 + VmpApplyDftToDftTmpBytes
77 + VecZnxBigNormalizeTmpBytes
78 + VmpApplyDftToDft<B>
79 + VmpApplyDftToDftAdd<B>
80 + VecZnxDftApply<B>
81 + VecZnxIdftApplyConsume<B>
82 + VecZnxBigAddSmallInplace<B>
83 + VecZnxBigNormalize<B>,
84 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
85 {
86 self.key.keyswitch_inplace(module, &rhs.key, scratch);
87 }
88}
89
90impl GGLWESwitchingKey<Vec<u8>> {
91 #[allow(clippy::too_many_arguments)]
92 pub fn keyswitch_scratch_space<B: Backend>(
93 module: &Module<B>,
94 basek: usize,
95 k_out: usize,
96 k_in: usize,
97 k_ksk: usize,
98 digits: usize,
99 rank_in: usize,
100 rank_out: usize,
101 ) -> usize
102 where
103 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
104 {
105 GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out)
106 }
107
108 pub fn keyswitch_inplace_scratch_space<B: Backend>(
109 module: &Module<B>,
110 basek: usize,
111 k_out: usize,
112 k_ksk: usize,
113 digits: usize,
114 rank: usize,
115 ) -> usize
116 where
117 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
118 {
119 GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank)
120 }
121}
122
123impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
124 pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
125 &mut self,
126 module: &Module<B>,
127 lhs: &GGLWESwitchingKey<DataLhs>,
128 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
129 scratch: &mut Scratch<B>,
130 ) where
131 Module<B>: VecZnxDftAllocBytes
132 + VmpApplyDftToDftTmpBytes
133 + VecZnxBigNormalizeTmpBytes
134 + VmpApplyDftToDft<B>
135 + VmpApplyDftToDftAdd<B>
136 + VecZnxDftApply<B>
137 + VecZnxIdftApplyConsume<B>
138 + VecZnxBigAddSmallInplace<B>
139 + VecZnxBigNormalize<B>,
140 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
141 {
142 #[cfg(debug_assertions)]
143 {
144 assert_eq!(
145 self.rank_in(),
146 lhs.rank_in(),
147 "ksk_out input rank: {} != ksk_in input rank: {}",
148 self.rank_in(),
149 lhs.rank_in()
150 );
151 assert_eq!(
152 lhs.rank_out(),
153 rhs.rank_in(),
154 "ksk_in output rank: {} != ksk_apply input rank: {}",
155 self.rank_out(),
156 rhs.rank_in()
157 );
158 assert_eq!(
159 self.rank_out(),
160 rhs.rank_out(),
161 "ksk_out output rank: {} != ksk_apply output rank: {}",
162 self.rank_out(),
163 rhs.rank_out()
164 );
165 assert!(
166 self.rows() <= lhs.rows(),
167 "self.rows()={} > lhs.rows()={}",
168 self.rows(),
169 lhs.rows()
170 );
171 }
172
173 (0..self.rank_in()).for_each(|col_i| {
174 (0..self.rows()).for_each(|row_j| {
175 self.at_mut(row_j, col_i)
176 .keyswitch(module, &lhs.at(row_j, col_i), rhs, scratch);
177 });
178 });
179
180 (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
181 (0..self.rank_in()).for_each(|col_j| {
182 self.at_mut(row_i, col_j).data.zero();
183 });
184 });
185 }
186
187 pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
188 &mut self,
189 module: &Module<B>,
190 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
191 scratch: &mut Scratch<B>,
192 ) where
193 Module<B>: VecZnxDftAllocBytes
194 + VmpApplyDftToDftTmpBytes
195 + VecZnxBigNormalizeTmpBytes
196 + VmpApplyDftToDft<B>
197 + VmpApplyDftToDftAdd<B>
198 + VecZnxDftApply<B>
199 + VecZnxIdftApplyConsume<B>
200 + VecZnxBigAddSmallInplace<B>
201 + VecZnxBigNormalize<B>,
202 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
203 {
204 #[cfg(debug_assertions)]
205 {
206 assert_eq!(
207 self.rank_out(),
208 rhs.rank_out(),
209 "ksk_out output rank: {} != ksk_apply output rank: {}",
210 self.rank_out(),
211 rhs.rank_out()
212 );
213 }
214
215 (0..self.rank_in()).for_each(|col_i| {
216 (0..self.rows()).for_each(|row_j| {
217 self.at_mut(row_j, col_i)
218 .keyswitch_inplace(module, rhs, scratch)
219 });
220 });
221 }
222}