poulpy_core/keyswitching/
gglwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4        VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
5        VmpApplyDftToDftTmpBytes,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
8};
9
10use crate::layouts::{
11    GGLWEAutomorphismKey, GGLWESwitchingKey, GLWECiphertext, Infos,
12    prepared::{GGLWEAutomorphismKeyPrepared, GGLWESwitchingKeyPrepared},
13};
14
15impl GGLWEAutomorphismKey<Vec<u8>> {
16    #[allow(clippy::too_many_arguments)]
17    pub fn keyswitch_scratch_space<B: Backend>(
18        module: &Module<B>,
19        basek: usize,
20        k_out: usize,
21        k_in: usize,
22        k_ksk: usize,
23        digits: usize,
24        rank: usize,
25    ) -> usize
26    where
27        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
28    {
29        GGLWESwitchingKey::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank)
30    }
31
32    pub fn keyswitch_inplace_scratch_space<B: Backend>(
33        module: &Module<B>,
34        basek: usize,
35        k_out: usize,
36        k_ksk: usize,
37        digits: usize,
38        rank: usize,
39    ) -> usize
40    where
41        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
42    {
43        GGLWESwitchingKey::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank)
44    }
45}
46
47impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
48    pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
49        &mut self,
50        module: &Module<B>,
51        lhs: &GGLWEAutomorphismKey<DataLhs>,
52        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
53        scratch: &mut Scratch<B>,
54    ) where
55        Module<B>: VecZnxDftAllocBytes
56            + VmpApplyDftToDftTmpBytes
57            + VecZnxBigNormalizeTmpBytes
58            + VmpApplyDftToDft<B>
59            + VmpApplyDftToDftAdd<B>
60            + VecZnxDftApply<B>
61            + VecZnxIdftApplyConsume<B>
62            + VecZnxBigAddSmallInplace<B>
63            + VecZnxBigNormalize<B>,
64        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
65    {
66        self.key.keyswitch(module, &lhs.key, rhs, scratch);
67    }
68
69    pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
70        &mut self,
71        module: &Module<B>,
72        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
73        scratch: &mut Scratch<B>,
74    ) where
75        Module<B>: VecZnxDftAllocBytes
76            + VmpApplyDftToDftTmpBytes
77            + VecZnxBigNormalizeTmpBytes
78            + VmpApplyDftToDft<B>
79            + VmpApplyDftToDftAdd<B>
80            + VecZnxDftApply<B>
81            + VecZnxIdftApplyConsume<B>
82            + VecZnxBigAddSmallInplace<B>
83            + VecZnxBigNormalize<B>,
84        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
85    {
86        self.key.keyswitch_inplace(module, &rhs.key, scratch);
87    }
88}
89
90impl GGLWESwitchingKey<Vec<u8>> {
91    #[allow(clippy::too_many_arguments)]
92    pub fn keyswitch_scratch_space<B: Backend>(
93        module: &Module<B>,
94        basek: usize,
95        k_out: usize,
96        k_in: usize,
97        k_ksk: usize,
98        digits: usize,
99        rank_in: usize,
100        rank_out: usize,
101    ) -> usize
102    where
103        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
104    {
105        GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out)
106    }
107
108    pub fn keyswitch_inplace_scratch_space<B: Backend>(
109        module: &Module<B>,
110        basek: usize,
111        k_out: usize,
112        k_ksk: usize,
113        digits: usize,
114        rank: usize,
115    ) -> usize
116    where
117        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
118    {
119        GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank)
120    }
121}
122
123impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
124    pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
125        &mut self,
126        module: &Module<B>,
127        lhs: &GGLWESwitchingKey<DataLhs>,
128        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
129        scratch: &mut Scratch<B>,
130    ) where
131        Module<B>: VecZnxDftAllocBytes
132            + VmpApplyDftToDftTmpBytes
133            + VecZnxBigNormalizeTmpBytes
134            + VmpApplyDftToDft<B>
135            + VmpApplyDftToDftAdd<B>
136            + VecZnxDftApply<B>
137            + VecZnxIdftApplyConsume<B>
138            + VecZnxBigAddSmallInplace<B>
139            + VecZnxBigNormalize<B>,
140        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
141    {
142        #[cfg(debug_assertions)]
143        {
144            assert_eq!(
145                self.rank_in(),
146                lhs.rank_in(),
147                "ksk_out input rank: {} != ksk_in input rank: {}",
148                self.rank_in(),
149                lhs.rank_in()
150            );
151            assert_eq!(
152                lhs.rank_out(),
153                rhs.rank_in(),
154                "ksk_in output rank: {} != ksk_apply input rank: {}",
155                self.rank_out(),
156                rhs.rank_in()
157            );
158            assert_eq!(
159                self.rank_out(),
160                rhs.rank_out(),
161                "ksk_out output rank: {} != ksk_apply output rank: {}",
162                self.rank_out(),
163                rhs.rank_out()
164            );
165            assert!(
166                self.rows() <= lhs.rows(),
167                "self.rows()={} > lhs.rows()={}",
168                self.rows(),
169                lhs.rows()
170            );
171        }
172
173        (0..self.rank_in()).for_each(|col_i| {
174            (0..self.rows()).for_each(|row_j| {
175                self.at_mut(row_j, col_i)
176                    .keyswitch(module, &lhs.at(row_j, col_i), rhs, scratch);
177            });
178        });
179
180        (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
181            (0..self.rank_in()).for_each(|col_j| {
182                self.at_mut(row_i, col_j).data.zero();
183            });
184        });
185    }
186
187    pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
188        &mut self,
189        module: &Module<B>,
190        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
191        scratch: &mut Scratch<B>,
192    ) where
193        Module<B>: VecZnxDftAllocBytes
194            + VmpApplyDftToDftTmpBytes
195            + VecZnxBigNormalizeTmpBytes
196            + VmpApplyDftToDft<B>
197            + VmpApplyDftToDftAdd<B>
198            + VecZnxDftApply<B>
199            + VecZnxIdftApplyConsume<B>
200            + VecZnxBigAddSmallInplace<B>
201            + VecZnxBigNormalize<B>,
202        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
203    {
204        #[cfg(debug_assertions)]
205        {
206            assert_eq!(
207                self.rank_out(),
208                rhs.rank_out(),
209                "ksk_out output rank: {} != ksk_apply output rank: {}",
210                self.rank_out(),
211                rhs.rank_out()
212            );
213        }
214
215        (0..self.rank_in()).for_each(|col_i| {
216            (0..self.rows()).for_each(|row_j| {
217                self.at_mut(row_j, col_i)
218                    .keyswitch_inplace(module, rhs, scratch)
219            });
220        });
221    }
222}