poulpy_core/keyswitching/
gglwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4        VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft,
5        VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
8};
9
10use crate::layouts::{
11    GGLWEAutomorphismKey, GGLWELayoutInfos, GGLWESwitchingKey, GLWECiphertext, GLWEInfos,
12    prepared::{GGLWEAutomorphismKeyPrepared, GGLWESwitchingKeyPrepared},
13};
14
15impl GGLWEAutomorphismKey<Vec<u8>> {
16    pub fn keyswitch_scratch_space<B: Backend, OUT, IN, KEY>(
17        module: &Module<B>,
18        out_infos: &OUT,
19        in_infos: &IN,
20        key_infos: &KEY,
21    ) -> usize
22    where
23        OUT: GGLWELayoutInfos,
24        IN: GGLWELayoutInfos,
25        KEY: GGLWELayoutInfos,
26        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
27    {
28        GGLWESwitchingKey::keyswitch_scratch_space(module, out_infos, in_infos, key_infos)
29    }
30
31    pub fn keyswitch_inplace_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
32    where
33        OUT: GGLWELayoutInfos,
34        KEY: GGLWELayoutInfos,
35        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
36    {
37        GGLWESwitchingKey::keyswitch_inplace_scratch_space(module, out_infos, key_infos)
38    }
39}
40
41impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
42    pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
43        &mut self,
44        module: &Module<B>,
45        lhs: &GGLWEAutomorphismKey<DataLhs>,
46        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
47        scratch: &mut Scratch<B>,
48    ) where
49        Module<B>: VecZnxDftAllocBytes
50            + VmpApplyDftToDftTmpBytes
51            + VecZnxBigNormalizeTmpBytes
52            + VmpApplyDftToDft<B>
53            + VmpApplyDftToDftAdd<B>
54            + VecZnxDftApply<B>
55            + VecZnxIdftApplyConsume<B>
56            + VecZnxBigAddSmallInplace<B>
57            + VecZnxBigNormalize<B>
58            + VecZnxNormalize<B>
59            + VecZnxNormalizeTmpBytes,
60        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
61    {
62        self.key.keyswitch(module, &lhs.key, rhs, scratch);
63    }
64
65    pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
66        &mut self,
67        module: &Module<B>,
68        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
69        scratch: &mut Scratch<B>,
70    ) where
71        Module<B>: VecZnxDftAllocBytes
72            + VmpApplyDftToDftTmpBytes
73            + VecZnxBigNormalizeTmpBytes
74            + VmpApplyDftToDft<B>
75            + VmpApplyDftToDftAdd<B>
76            + VecZnxDftApply<B>
77            + VecZnxIdftApplyConsume<B>
78            + VecZnxBigAddSmallInplace<B>
79            + VecZnxBigNormalize<B>
80            + VecZnxNormalize<B>
81            + VecZnxNormalizeTmpBytes,
82        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
83    {
84        self.key.keyswitch_inplace(module, &rhs.key, scratch);
85    }
86}
87
88impl GGLWESwitchingKey<Vec<u8>> {
89    pub fn keyswitch_scratch_space<B: Backend, OUT, IN, KEY>(
90        module: &Module<B>,
91        out_infos: &OUT,
92        in_infos: &IN,
93        key_apply: &KEY,
94    ) -> usize
95    where
96        OUT: GGLWELayoutInfos,
97        IN: GGLWELayoutInfos,
98        KEY: GGLWELayoutInfos,
99        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
100    {
101        GLWECiphertext::keyswitch_scratch_space(module, out_infos, in_infos, key_apply)
102    }
103
104    pub fn keyswitch_inplace_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_apply: &KEY) -> usize
105    where
106        OUT: GGLWELayoutInfos + GLWEInfos,
107        KEY: GGLWELayoutInfos + GLWEInfos,
108        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
109    {
110        GLWECiphertext::keyswitch_inplace_scratch_space(module, out_infos, key_apply)
111    }
112}
113
114impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
115    pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
116        &mut self,
117        module: &Module<B>,
118        lhs: &GGLWESwitchingKey<DataLhs>,
119        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
120        scratch: &mut Scratch<B>,
121    ) where
122        Module<B>: VecZnxDftAllocBytes
123            + VmpApplyDftToDftTmpBytes
124            + VecZnxBigNormalizeTmpBytes
125            + VmpApplyDftToDft<B>
126            + VmpApplyDftToDftAdd<B>
127            + VecZnxDftApply<B>
128            + VecZnxIdftApplyConsume<B>
129            + VecZnxBigAddSmallInplace<B>
130            + VecZnxBigNormalize<B>
131            + VecZnxNormalize<B>
132            + VecZnxNormalizeTmpBytes,
133        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnx,
134    {
135        #[cfg(debug_assertions)]
136        {
137            assert_eq!(
138                self.rank_in(),
139                lhs.rank_in(),
140                "ksk_out input rank: {} != ksk_in input rank: {}",
141                self.rank_in(),
142                lhs.rank_in()
143            );
144            assert_eq!(
145                lhs.rank_out(),
146                rhs.rank_in(),
147                "ksk_in output rank: {} != ksk_apply input rank: {}",
148                self.rank_out(),
149                rhs.rank_in()
150            );
151            assert_eq!(
152                self.rank_out(),
153                rhs.rank_out(),
154                "ksk_out output rank: {} != ksk_apply output rank: {}",
155                self.rank_out(),
156                rhs.rank_out()
157            );
158            assert!(
159                self.rows() <= lhs.rows(),
160                "self.rows()={} > lhs.rows()={}",
161                self.rows(),
162                lhs.rows()
163            );
164            assert_eq!(
165                self.digits(),
166                lhs.digits(),
167                "ksk_out digits: {} != ksk_in digits: {}",
168                self.digits(),
169                lhs.digits()
170            )
171        }
172
173        (0..self.rank_in().into()).for_each(|col_i| {
174            (0..self.rows().into()).for_each(|row_j| {
175                self.at_mut(row_j, col_i)
176                    .keyswitch(module, &lhs.at(row_j, col_i), rhs, scratch);
177            });
178        });
179
180        (self.rows().min(lhs.rows()).into()..self.rows().into()).for_each(|row_i| {
181            (0..self.rank_in().into()).for_each(|col_j| {
182                self.at_mut(row_i, col_j).data.zero();
183            });
184        });
185    }
186
187    pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
188        &mut self,
189        module: &Module<B>,
190        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
191        scratch: &mut Scratch<B>,
192    ) where
193        Module<B>: VecZnxDftAllocBytes
194            + VmpApplyDftToDftTmpBytes
195            + VecZnxBigNormalizeTmpBytes
196            + VmpApplyDftToDft<B>
197            + VmpApplyDftToDftAdd<B>
198            + VecZnxDftApply<B>
199            + VecZnxIdftApplyConsume<B>
200            + VecZnxBigAddSmallInplace<B>
201            + VecZnxBigNormalize<B>
202            + VecZnxNormalize<B>
203            + VecZnxNormalizeTmpBytes,
204        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnx,
205    {
206        #[cfg(debug_assertions)]
207        {
208            assert_eq!(
209                self.rank_out(),
210                rhs.rank_out(),
211                "ksk_out output rank: {} != ksk_apply output rank: {}",
212                self.rank_out(),
213                rhs.rank_out()
214            );
215        }
216
217        (0..self.rank_in().into()).for_each(|col_i| {
218            (0..self.rows().into()).for_each(|row_j| {
219                self.at_mut(row_j, col_i)
220                    .keyswitch_inplace(module, rhs, scratch)
221            });
222        });
223    }
224}