// poulpy_core/keyswitching/gglwe_ct.rs

1use poulpy_hal::{
2    api::{
3        DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
4        VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
5    },
6    layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
7};
8
9use crate::layouts::{
10    GGLWEAutomorphismKey, GGLWESwitchingKey, GLWECiphertext, Infos,
11    prepared::{GGLWEAutomorphismKeyPrepared, GGLWESwitchingKeyPrepared},
12};
13
14impl GGLWEAutomorphismKey<Vec<u8>> {
15    #[allow(clippy::too_many_arguments)]
16    pub fn keyswitch_scratch_space<B: Backend>(
17        module: &Module<B>,
18        basek: usize,
19        k_out: usize,
20        k_in: usize,
21        k_ksk: usize,
22        digits: usize,
23        rank: usize,
24    ) -> usize
25    where
26        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
27    {
28        GGLWESwitchingKey::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank)
29    }
30
31    pub fn keyswitch_inplace_scratch_space<B: Backend>(
32        module: &Module<B>,
33        basek: usize,
34        k_out: usize,
35        k_ksk: usize,
36        digits: usize,
37        rank: usize,
38    ) -> usize
39    where
40        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
41    {
42        GGLWESwitchingKey::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank)
43    }
44}
45
46impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
47    pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
48        &mut self,
49        module: &Module<B>,
50        lhs: &GGLWEAutomorphismKey<DataLhs>,
51        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
52        scratch: &mut Scratch<B>,
53    ) where
54        Module<B>: VecZnxDftAllocBytes
55            + VmpApplyDftToDftTmpBytes
56            + VecZnxBigNormalizeTmpBytes
57            + VmpApplyDftToDft<B>
58            + VmpApplyDftToDftAdd<B>
59            + DFT<B>
60            + IDFTConsume<B>
61            + VecZnxBigAddSmallInplace<B>
62            + VecZnxBigNormalize<B>,
63        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
64    {
65        self.key.keyswitch(module, &lhs.key, rhs, scratch);
66    }
67
68    pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
69        &mut self,
70        module: &Module<B>,
71        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
72        scratch: &mut Scratch<B>,
73    ) where
74        Module<B>: VecZnxDftAllocBytes
75            + VmpApplyDftToDftTmpBytes
76            + VecZnxBigNormalizeTmpBytes
77            + VmpApplyDftToDft<B>
78            + VmpApplyDftToDftAdd<B>
79            + DFT<B>
80            + IDFTConsume<B>
81            + VecZnxBigAddSmallInplace<B>
82            + VecZnxBigNormalize<B>,
83        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
84    {
85        self.key.keyswitch_inplace(module, &rhs.key, scratch);
86    }
87}
88
89impl GGLWESwitchingKey<Vec<u8>> {
90    #[allow(clippy::too_many_arguments)]
91    pub fn keyswitch_scratch_space<B: Backend>(
92        module: &Module<B>,
93        basek: usize,
94        k_out: usize,
95        k_in: usize,
96        k_ksk: usize,
97        digits: usize,
98        rank_in: usize,
99        rank_out: usize,
100    ) -> usize
101    where
102        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
103    {
104        GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out)
105    }
106
107    pub fn keyswitch_inplace_scratch_space<B: Backend>(
108        module: &Module<B>,
109        basek: usize,
110        k_out: usize,
111        k_ksk: usize,
112        digits: usize,
113        rank: usize,
114    ) -> usize
115    where
116        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
117    {
118        GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_out, k_ksk, digits, rank)
119    }
120}
121
122impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
123    pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
124        &mut self,
125        module: &Module<B>,
126        lhs: &GGLWESwitchingKey<DataLhs>,
127        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
128        scratch: &mut Scratch<B>,
129    ) where
130        Module<B>: VecZnxDftAllocBytes
131            + VmpApplyDftToDftTmpBytes
132            + VecZnxBigNormalizeTmpBytes
133            + VmpApplyDftToDft<B>
134            + VmpApplyDftToDftAdd<B>
135            + DFT<B>
136            + IDFTConsume<B>
137            + VecZnxBigAddSmallInplace<B>
138            + VecZnxBigNormalize<B>,
139        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
140    {
141        #[cfg(debug_assertions)]
142        {
143            assert_eq!(
144                self.rank_in(),
145                lhs.rank_in(),
146                "ksk_out input rank: {} != ksk_in input rank: {}",
147                self.rank_in(),
148                lhs.rank_in()
149            );
150            assert_eq!(
151                lhs.rank_out(),
152                rhs.rank_in(),
153                "ksk_in output rank: {} != ksk_apply input rank: {}",
154                self.rank_out(),
155                rhs.rank_in()
156            );
157            assert_eq!(
158                self.rank_out(),
159                rhs.rank_out(),
160                "ksk_out output rank: {} != ksk_apply output rank: {}",
161                self.rank_out(),
162                rhs.rank_out()
163            );
164        }
165
166        (0..self.rank_in()).for_each(|col_i| {
167            (0..self.rows()).for_each(|row_j| {
168                self.at_mut(row_j, col_i)
169                    .keyswitch(module, &lhs.at(row_j, col_i), rhs, scratch);
170            });
171        });
172
173        (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
174            (0..self.rank_in()).for_each(|col_j| {
175                self.at_mut(row_i, col_j).data.zero();
176            });
177        });
178    }
179
180    pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
181        &mut self,
182        module: &Module<B>,
183        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
184        scratch: &mut Scratch<B>,
185    ) where
186        Module<B>: VecZnxDftAllocBytes
187            + VmpApplyDftToDftTmpBytes
188            + VecZnxBigNormalizeTmpBytes
189            + VmpApplyDftToDft<B>
190            + VmpApplyDftToDftAdd<B>
191            + DFT<B>
192            + IDFTConsume<B>
193            + VecZnxBigAddSmallInplace<B>
194            + VecZnxBigNormalize<B>,
195        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
196    {
197        #[cfg(debug_assertions)]
198        {
199            assert_eq!(
200                self.rank_out(),
201                rhs.rank_out(),
202                "ksk_out output rank: {} != ksk_apply output rank: {}",
203                self.rank_out(),
204                rhs.rank_out()
205            );
206        }
207
208        (0..self.rank_in()).for_each(|col_i| {
209            (0..self.rows()).for_each(|row_j| {
210                self.at_mut(row_j, col_i)
211                    .keyswitch_inplace(module, rhs, scratch)
212            });
213        });
214    }
215}