poulpy_core/keyswitching/
gglwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4        VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, ZnxZero,
5    },
6    layouts::{Backend, DataMut, DataRef, Module, Scratch},
7};
8
9use crate::layouts::{
10    GGLWEAutomorphismKey, GGLWESwitchingKey, GLWECiphertext, Infos,
11    prepared::{GGLWEAutomorphismKeyPrepared, GGLWESwitchingKeyPrepared},
12};
13
14impl GGLWEAutomorphismKey<Vec<u8>> {
15    #[allow(clippy::too_many_arguments)]
16    pub fn keyswitch_scratch_space<B: Backend>(
17        module: &Module<B>,
18        n: usize,
19        basek: usize,
20        k_out: usize,
21        k_in: usize,
22        k_ksk: usize,
23        digits: usize,
24        rank: usize,
25    ) -> usize
26    where
27        Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
28    {
29        GGLWESwitchingKey::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits, rank, rank)
30    }
31
32    pub fn keyswitch_inplace_scratch_space<B: Backend>(
33        module: &Module<B>,
34        n: usize,
35        basek: usize,
36        k_out: usize,
37        k_ksk: usize,
38        digits: usize,
39        rank: usize,
40    ) -> usize
41    where
42        Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
43    {
44        GGLWESwitchingKey::keyswitch_inplace_scratch_space(module, n, basek, k_out, k_ksk, digits, rank)
45    }
46}
47
48impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
49    pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
50        &mut self,
51        module: &Module<B>,
52        lhs: &GGLWEAutomorphismKey<DataLhs>,
53        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
54        scratch: &mut Scratch<B>,
55    ) where
56        Module<B>: VecZnxDftAllocBytes
57            + VmpApplyTmpBytes
58            + VecZnxBigNormalizeTmpBytes
59            + VmpApply<B>
60            + VmpApplyAdd<B>
61            + VecZnxDftFromVecZnx<B>
62            + VecZnxDftToVecZnxBigConsume<B>
63            + VecZnxBigAddSmallInplace<B>
64            + VecZnxBigNormalize<B>,
65        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
66    {
67        self.key.keyswitch(module, &lhs.key, rhs, scratch);
68    }
69
70    pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
71        &mut self,
72        module: &Module<B>,
73        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
74        scratch: &mut Scratch<B>,
75    ) where
76        Module<B>: VecZnxDftAllocBytes
77            + VmpApplyTmpBytes
78            + VecZnxBigNormalizeTmpBytes
79            + VmpApply<B>
80            + VmpApplyAdd<B>
81            + VecZnxDftFromVecZnx<B>
82            + VecZnxDftToVecZnxBigConsume<B>
83            + VecZnxBigAddSmallInplace<B>
84            + VecZnxBigNormalize<B>,
85        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
86    {
87        self.key.keyswitch_inplace(module, &rhs.key, scratch);
88    }
89}
90
91impl GGLWESwitchingKey<Vec<u8>> {
92    #[allow(clippy::too_many_arguments)]
93    pub fn keyswitch_scratch_space<B: Backend>(
94        module: &Module<B>,
95        n: usize,
96        basek: usize,
97        k_out: usize,
98        k_in: usize,
99        k_ksk: usize,
100        digits: usize,
101        rank_in: usize,
102        rank_out: usize,
103    ) -> usize
104    where
105        Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
106    {
107        GLWECiphertext::keyswitch_scratch_space(
108            module, n, basek, k_out, k_in, k_ksk, digits, rank_in, rank_out,
109        )
110    }
111
112    pub fn keyswitch_inplace_scratch_space<B: Backend>(
113        module: &Module<B>,
114        n: usize,
115        basek: usize,
116        k_out: usize,
117        k_ksk: usize,
118        digits: usize,
119        rank: usize,
120    ) -> usize
121    where
122        Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
123    {
124        GLWECiphertext::keyswitch_inplace_scratch_space(module, n, basek, k_out, k_ksk, digits, rank)
125    }
126}
127
128impl<DataSelf: DataMut> GGLWESwitchingKey<DataSelf> {
129    pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
130        &mut self,
131        module: &Module<B>,
132        lhs: &GGLWESwitchingKey<DataLhs>,
133        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
134        scratch: &mut Scratch<B>,
135    ) where
136        Module<B>: VecZnxDftAllocBytes
137            + VmpApplyTmpBytes
138            + VecZnxBigNormalizeTmpBytes
139            + VmpApply<B>
140            + VmpApplyAdd<B>
141            + VecZnxDftFromVecZnx<B>
142            + VecZnxDftToVecZnxBigConsume<B>
143            + VecZnxBigAddSmallInplace<B>
144            + VecZnxBigNormalize<B>,
145        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
146    {
147        #[cfg(debug_assertions)]
148        {
149            assert_eq!(
150                self.rank_in(),
151                lhs.rank_in(),
152                "ksk_out input rank: {} != ksk_in input rank: {}",
153                self.rank_in(),
154                lhs.rank_in()
155            );
156            assert_eq!(
157                lhs.rank_out(),
158                rhs.rank_in(),
159                "ksk_in output rank: {} != ksk_apply input rank: {}",
160                self.rank_out(),
161                rhs.rank_in()
162            );
163            assert_eq!(
164                self.rank_out(),
165                rhs.rank_out(),
166                "ksk_out output rank: {} != ksk_apply output rank: {}",
167                self.rank_out(),
168                rhs.rank_out()
169            );
170        }
171
172        (0..self.rank_in()).for_each(|col_i| {
173            (0..self.rows()).for_each(|row_j| {
174                self.at_mut(row_j, col_i)
175                    .keyswitch(module, &lhs.at(row_j, col_i), rhs, scratch);
176            });
177        });
178
179        (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
180            (0..self.rank_in()).for_each(|col_j| {
181                self.at_mut(row_i, col_j).data.zero();
182            });
183        });
184    }
185
186    pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
187        &mut self,
188        module: &Module<B>,
189        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
190        scratch: &mut Scratch<B>,
191    ) where
192        Module<B>: VecZnxDftAllocBytes
193            + VmpApplyTmpBytes
194            + VecZnxBigNormalizeTmpBytes
195            + VmpApply<B>
196            + VmpApplyAdd<B>
197            + VecZnxDftFromVecZnx<B>
198            + VecZnxDftToVecZnxBigConsume<B>
199            + VecZnxBigAddSmallInplace<B>
200            + VecZnxBigNormalize<B>,
201        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
202    {
203        #[cfg(debug_assertions)]
204        {
205            assert_eq!(
206                self.rank_out(),
207                rhs.rank_out(),
208                "ksk_out output rank: {} != ksk_apply output rank: {}",
209                self.rank_out(),
210                rhs.rank_out()
211            );
212        }
213
214        (0..self.rank_in()).for_each(|col_i| {
215            (0..self.rows()).for_each(|row_j| {
216                self.at_mut(row_j, col_i)
217                    .keyswitch_inplace(module, rhs, scratch)
218            });
219        });
220    }
221}