poulpy_core/keyswitching/
glwe_ct.rs

1use poulpy_hal::{
2    api::{
3        DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
4        VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
5    },
6    layouts::{Backend, DataMut, DataRef, DataViewMut, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat, ZnxInfos},
7};
8
9use crate::layouts::{GLWECiphertext, Infos, prepared::GGLWESwitchingKeyPrepared};
10
11impl GLWECiphertext<Vec<u8>> {
12    #[allow(clippy::too_many_arguments)]
13    pub fn keyswitch_scratch_space<B: Backend>(
14        module: &Module<B>,
15        basek: usize,
16        k_out: usize,
17        k_in: usize,
18        k_ksk: usize,
19        digits: usize,
20        rank_in: usize,
21        rank_out: usize,
22    ) -> usize
23    where
24        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
25    {
26        let in_size: usize = k_in.div_ceil(basek).div_ceil(digits);
27        let out_size: usize = k_out.div_ceil(basek);
28        let ksk_size: usize = k_ksk.div_ceil(basek);
29        let res_dft: usize = module.vec_znx_dft_alloc_bytes(rank_out + 1, ksk_size); // TODO OPTIMIZE
30        let ai_dft: usize = module.vec_znx_dft_alloc_bytes(rank_in, in_size);
31        let vmp: usize = module.vmp_apply_dft_to_dft_tmp_bytes(out_size, in_size, in_size, rank_in, rank_out + 1, ksk_size)
32            + module.vec_znx_dft_alloc_bytes(rank_in, in_size);
33        let normalize: usize = module.vec_znx_big_normalize_tmp_bytes();
34        res_dft + ((ai_dft + vmp) | normalize)
35    }
36
37    pub fn keyswitch_inplace_scratch_space<B: Backend>(
38        module: &Module<B>,
39        basek: usize,
40        k_out: usize,
41        k_ksk: usize,
42        digits: usize,
43        rank: usize,
44    ) -> usize
45    where
46        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
47    {
48        Self::keyswitch_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank, rank)
49    }
50}
51
52impl<DataSelf: DataRef> GLWECiphertext<DataSelf> {
53    #[allow(dead_code)]
54    pub(crate) fn assert_keyswitch<B: Backend, DataLhs, DataRhs>(
55        &self,
56        module: &Module<B>,
57        lhs: &GLWECiphertext<DataLhs>,
58        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
59        scratch: &Scratch<B>,
60    ) where
61        DataLhs: DataRef,
62        DataRhs: DataRef,
63        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
64        Scratch<B>: ScratchAvailable,
65    {
66        let basek: usize = self.basek();
67        assert_eq!(
68            lhs.rank(),
69            rhs.rank_in(),
70            "lhs.rank(): {} != rhs.rank_in(): {}",
71            lhs.rank(),
72            rhs.rank_in()
73        );
74        assert_eq!(
75            self.rank(),
76            rhs.rank_out(),
77            "self.rank(): {} != rhs.rank_out(): {}",
78            self.rank(),
79            rhs.rank_out()
80        );
81        assert_eq!(self.basek(), basek);
82        assert_eq!(lhs.basek(), basek);
83        assert_eq!(rhs.n(), self.n());
84        assert_eq!(lhs.n(), self.n());
85        assert!(
86            scratch.available()
87                >= GLWECiphertext::keyswitch_scratch_space(
88                    module,
89                    self.basek(),
90                    self.k(),
91                    lhs.k(),
92                    rhs.k(),
93                    rhs.digits(),
94                    rhs.rank_in(),
95                    rhs.rank_out(),
96                ),
97            "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space(
98                    module,
99                    self.basek(),
100                    self.k(),
101                    lhs.k(),
102                    rhs.k(),
103                    rhs.digits(),
104                    rhs.rank_in(),
105                    rhs.rank_out(),
106                )={}",
107            scratch.available(),
108            GLWECiphertext::keyswitch_scratch_space(
109                module,
110                self.basek(),
111                self.k(),
112                lhs.k(),
113                rhs.k(),
114                rhs.digits(),
115                rhs.rank_in(),
116                rhs.rank_out(),
117            )
118        );
119    }
120}
121
122impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
123    pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
124        &mut self,
125        module: &Module<B>,
126        lhs: &GLWECiphertext<DataLhs>,
127        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
128        scratch: &mut Scratch<B>,
129    ) where
130        Module<B>: VecZnxDftAllocBytes
131            + VmpApplyDftToDftTmpBytes
132            + VecZnxBigNormalizeTmpBytes
133            + VmpApplyDftToDftTmpBytes
134            + VmpApplyDftToDft<B>
135            + VmpApplyDftToDftAdd<B>
136            + DFT<B>
137            + IDFTConsume<B>
138            + VecZnxBigAddSmallInplace<B>
139            + VecZnxBigNormalize<B>,
140        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
141    {
142        #[cfg(debug_assertions)]
143        {
144            self.assert_keyswitch(module, lhs, rhs, scratch);
145        }
146        let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise
147        let res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, rhs, scratch1);
148        (0..self.cols()).for_each(|i| {
149            module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
150        })
151    }
152
153    pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
154        &mut self,
155        module: &Module<B>,
156        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
157        scratch: &mut Scratch<B>,
158    ) where
159        Module<B>: VecZnxDftAllocBytes
160            + VmpApplyDftToDftTmpBytes
161            + VecZnxBigNormalizeTmpBytes
162            + VmpApplyDftToDftTmpBytes
163            + VmpApplyDftToDft<B>
164            + VmpApplyDftToDftAdd<B>
165            + DFT<B>
166            + IDFTConsume<B>
167            + VecZnxBigAddSmallInplace<B>
168            + VecZnxBigNormalize<B>,
169        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
170    {
171        unsafe {
172            let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
173            self.keyswitch(module, &*self_ptr, rhs, scratch);
174        }
175    }
176}
177
178impl<D: DataRef> GLWECiphertext<D> {
179    pub(crate) fn keyswitch_internal<B: Backend, DataRes, DataKey>(
180        &self,
181        module: &Module<B>,
182        res_dft: VecZnxDft<DataRes, B>,
183        rhs: &GGLWESwitchingKeyPrepared<DataKey, B>,
184        scratch: &mut Scratch<B>,
185    ) -> VecZnxBig<DataRes, B>
186    where
187        DataRes: DataMut,
188        DataKey: DataRef,
189        Module<B>: VecZnxDftAllocBytes
190            + VmpApplyDftToDftTmpBytes
191            + VecZnxBigNormalizeTmpBytes
192            + VmpApplyDftToDftTmpBytes
193            + VmpApplyDftToDft<B>
194            + VmpApplyDftToDftAdd<B>
195            + DFT<B>
196            + IDFTConsume<B>
197            + VecZnxBigAddSmallInplace<B>
198            + VecZnxBigNormalize<B>,
199        Scratch<B>: TakeVecZnxDft<B>,
200    {
201        if rhs.digits() == 1 {
202            return keyswitch_vmp_one_digit(module, res_dft, &self.data, &rhs.key.data, scratch);
203        }
204
205        keyswitch_vmp_multiple_digits(
206            module,
207            res_dft,
208            &self.data,
209            &rhs.key.data,
210            rhs.digits(),
211            scratch,
212        )
213    }
214}
215
216fn keyswitch_vmp_one_digit<B: Backend, DataRes, DataIn, DataVmp>(
217    module: &Module<B>,
218    mut res_dft: VecZnxDft<DataRes, B>,
219    a: &VecZnx<DataIn>,
220    mat: &VmpPMat<DataVmp, B>,
221    scratch: &mut Scratch<B>,
222) -> VecZnxBig<DataRes, B>
223where
224    DataRes: DataMut,
225    DataIn: DataRef,
226    DataVmp: DataRef,
227    Module<B>: VecZnxDftAllocBytes + DFT<B> + VmpApplyDftToDft<B> + IDFTConsume<B> + VecZnxBigAddSmallInplace<B>,
228    Scratch<B>: TakeVecZnxDft<B>,
229{
230    let cols: usize = a.cols();
231    let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size());
232    (0..cols - 1).for_each(|col_i| {
233        module.dft(1, 0, &mut ai_dft, col_i, a, col_i + 1);
234    });
235    module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch1);
236    let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_consume(res_dft);
237    module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
238    res_big
239}
240
241fn keyswitch_vmp_multiple_digits<B: Backend, DataRes, DataIn, DataVmp>(
242    module: &Module<B>,
243    mut res_dft: VecZnxDft<DataRes, B>,
244    a: &VecZnx<DataIn>,
245    mat: &VmpPMat<DataVmp, B>,
246    digits: usize,
247    scratch: &mut Scratch<B>,
248) -> VecZnxBig<DataRes, B>
249where
250    DataRes: DataMut,
251    DataIn: DataRef,
252    DataVmp: DataRef,
253    Module<B>: VecZnxDftAllocBytes
254        + DFT<B>
255        + VmpApplyDftToDft<B>
256        + VmpApplyDftToDftAdd<B>
257        + IDFTConsume<B>
258        + VecZnxBigAddSmallInplace<B>,
259    Scratch<B>: TakeVecZnxDft<B>,
260{
261    let cols: usize = a.cols();
262    let size: usize = a.size();
263    let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, size.div_ceil(digits));
264
265    ai_dft.data_mut().fill(0);
266
267    (0..digits).for_each(|di| {
268        ai_dft.set_size((size + di) / digits);
269
270        // Small optimization for digits > 2
271        // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
272        // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
273        // As such we can ignore the last digits-2 limbs safely of the sum of vmp products.
274        // It is possible to further ignore the last digits-1 limbs, but this introduce
275        // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
276        // noise is kept with respect to the ideal functionality.
277        res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize);
278
279        (0..cols - 1).for_each(|col_i| {
280            module.dft(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1);
281        });
282
283        if di == 0 {
284            module.vmp_apply_dft_to_dft(&mut res_dft, &ai_dft, mat, scratch1);
285        } else {
286            module.vmp_apply_dft_to_dft_add(&mut res_dft, &ai_dft, mat, di, scratch1);
287        }
288    });
289
290    res_dft.set_size(res_dft.max_size());
291    let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_idft_consume(res_dft);
292    module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
293    res_big
294}