poulpy_core/keyswitching/
gglwe_ct.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4 VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft,
5 VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6 },
7 layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
8};
9
10use crate::layouts::{
11 GGLWEAutomorphismKey, GGLWELayoutInfos, GGLWESwitchingKey, GLWECiphertext, GLWEInfos,
12 prepared::{GGLWEAutomorphismKeyPrepared, GGLWESwitchingKeyPrepared},
13};
14
15impl GGLWEAutomorphismKey<Vec<u8>> {
16 pub fn keyswitch_scratch_space<B: Backend, OUT, IN, KEY>(
17 module: &Module<B>,
18 out_infos: &OUT,
19 in_infos: &IN,
20 key_infos: &KEY,
21 ) -> usize
22 where
23 OUT: GGLWELayoutInfos,
24 IN: GGLWELayoutInfos,
25 KEY: GGLWELayoutInfos,
26 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
27 {
28 GGLWESwitchingKey::keyswitch_scratch_space(module, out_infos, in_infos, key_infos)
29 }
30
31 pub fn keyswitch_inplace_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
32 where
33 OUT: GGLWELayoutInfos,
34 KEY: GGLWELayoutInfos,
35 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
36 {
37 GGLWESwitchingKey::keyswitch_inplace_scratch_space(module, out_infos, key_infos)
38 }
39}
40
41impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
42 pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
43 &mut self,
44 module: &Module<B>,
45 lhs: &GGLWEAutomorphismKey<DataLhs>,
46 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
47 scratch: &mut Scratch<B>,
48 ) where
49 Module<B>: VecZnxDftAllocBytes
50 + VmpApplyDftToDftTmpBytes
51 + VecZnxBigNormalizeTmpBytes
52 + VmpApplyDftToDft<B>
53 + VmpApplyDftToDftAdd<B>
54 + VecZnxDftApply<B>
55 + VecZnxIdftApplyConsume<B>
56 + VecZnxBigAddSmallInplace<B>
57 + VecZnxBigNormalize<B>
58 + VecZnxNormalize<B>
59 + VecZnxNormalizeTmpBytes,
60 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
61 {
62 self.key.keyswitch(module, &lhs.key, rhs, scratch);
63 }
64
65 pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
66 &mut self,
67 module: &Module<B>,
68 rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
69 scratch: &mut Scratch<B>,
70 ) where
71 Module<B>: VecZnxDftAllocBytes
72 + VmpApplyDftToDftTmpBytes
73 + VecZnxBigNormalizeTmpBytes
74 + VmpApplyDftToDft<B>
75 + VmpApplyDftToDftAdd<B>
76 + VecZnxDftApply<B>
77 + VecZnxIdftApplyConsume<B>
78 + VecZnxBigAddSmallInplace<B>
79 + VecZnxBigNormalize<B>
80 + VecZnxNormalize<B>
81 + VecZnxNormalizeTmpBytes,
82 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
83 {
84 self.key.keyswitch_inplace(module, &rhs.key, scratch);
85 }
86}
87
88impl GGLWESwitchingKey<Vec<u8>> {
89 pub fn keyswitch_scratch_space<B: Backend, OUT, IN, KEY>(
90 module: &Module<B>,
91 out_infos: &OUT,
92 in_infos: &IN,
93 key_apply: &KEY,
94 ) -> usize
95 where
96 OUT: GGLWELayoutInfos,
97 IN: GGLWELayoutInfos,
98 KEY: GGLWELayoutInfos,
99 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
100 {
101 GLWECiphertext::keyswitch_scratch_space(module, out_infos, in_infos, key_apply)
102 }
103
104 pub fn keyswitch_inplace_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_apply: &KEY) -> usize
105 where
106 OUT: GGLWELayoutInfos + GLWEInfos,
107 KEY: GGLWELayoutInfos + GLWEInfos,
108 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
109 {
110 GLWECiphertext::keyswitch_inplace_scratch_space(module, out_infos, key_apply)
111 }
112}
113
114impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
115 pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
116 &mut self,
117 module: &Module<B>,
118 lhs: &GGLWESwitchingKey<DataLhs>,
119 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
120 scratch: &mut Scratch<B>,
121 ) where
122 Module<B>: VecZnxDftAllocBytes
123 + VmpApplyDftToDftTmpBytes
124 + VecZnxBigNormalizeTmpBytes
125 + VmpApplyDftToDft<B>
126 + VmpApplyDftToDftAdd<B>
127 + VecZnxDftApply<B>
128 + VecZnxIdftApplyConsume<B>
129 + VecZnxBigAddSmallInplace<B>
130 + VecZnxBigNormalize<B>
131 + VecZnxNormalize<B>
132 + VecZnxNormalizeTmpBytes,
133 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnx,
134 {
135 #[cfg(debug_assertions)]
136 {
137 assert_eq!(
138 self.rank_in(),
139 lhs.rank_in(),
140 "ksk_out input rank: {} != ksk_in input rank: {}",
141 self.rank_in(),
142 lhs.rank_in()
143 );
144 assert_eq!(
145 lhs.rank_out(),
146 rhs.rank_in(),
147 "ksk_in output rank: {} != ksk_apply input rank: {}",
148 self.rank_out(),
149 rhs.rank_in()
150 );
151 assert_eq!(
152 self.rank_out(),
153 rhs.rank_out(),
154 "ksk_out output rank: {} != ksk_apply output rank: {}",
155 self.rank_out(),
156 rhs.rank_out()
157 );
158 assert!(
159 self.rows() <= lhs.rows(),
160 "self.rows()={} > lhs.rows()={}",
161 self.rows(),
162 lhs.rows()
163 );
164 assert_eq!(
165 self.digits(),
166 lhs.digits(),
167 "ksk_out digits: {} != ksk_in digits: {}",
168 self.digits(),
169 lhs.digits()
170 )
171 }
172
173 (0..self.rank_in().into()).for_each(|col_i| {
174 (0..self.rows().into()).for_each(|row_j| {
175 self.at_mut(row_j, col_i)
176 .keyswitch(module, &lhs.at(row_j, col_i), rhs, scratch);
177 });
178 });
179
180 (self.rows().min(lhs.rows()).into()..self.rows().into()).for_each(|row_i| {
181 (0..self.rank_in().into()).for_each(|col_j| {
182 self.at_mut(row_i, col_j).data.zero();
183 });
184 });
185 }
186
187 pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
188 &mut self,
189 module: &Module<B>,
190 rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
191 scratch: &mut Scratch<B>,
192 ) where
193 Module<B>: VecZnxDftAllocBytes
194 + VmpApplyDftToDftTmpBytes
195 + VecZnxBigNormalizeTmpBytes
196 + VmpApplyDftToDft<B>
197 + VmpApplyDftToDftAdd<B>
198 + VecZnxDftApply<B>
199 + VecZnxIdftApplyConsume<B>
200 + VecZnxBigAddSmallInplace<B>
201 + VecZnxBigNormalize<B>
202 + VecZnxNormalize<B>
203 + VecZnxNormalizeTmpBytes,
204 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnx,
205 {
206 #[cfg(debug_assertions)]
207 {
208 assert_eq!(
209 self.rank_out(),
210 rhs.rank_out(),
211 "ksk_out output rank: {} != ksk_apply output rank: {}",
212 self.rank_out(),
213 rhs.rank_out()
214 );
215 }
216
217 (0..self.rank_in().into()).for_each(|col_i| {
218 (0..self.rows().into()).for_each(|row_j| {
219 self.at_mut(row_j, col_i)
220 .keyswitch_inplace(module, rhs, scratch)
221 });
222 });
223 }
224}