poulpy_core/keyswitching/
lwe_ct.rs

1use poulpy_hal::{
2    api::{
3        DFT, IDFTConsume, ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize,
4        VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
5    },
6    layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero},
7};
8
9use crate::{
10    TakeGLWECt,
11    layouts::{GLWECiphertext, Infos, LWECiphertext, prepared::LWESwitchingKeyPrepared},
12};
13
14impl LWECiphertext<Vec<u8>> {
15    pub fn keyswitch_scratch_space<B: Backend>(
16        module: &Module<B>,
17        basek: usize,
18        k_lwe_out: usize,
19        k_lwe_in: usize,
20        k_ksk: usize,
21    ) -> usize
22    where
23        Module<B>: VecZnxDftAllocBytes
24            + VmpApplyDftToDftTmpBytes
25            + VecZnxBigNormalizeTmpBytes
26            + VmpApplyDftToDftTmpBytes
27            + VmpApplyDftToDft<B>
28            + VmpApplyDftToDftAdd<B>
29            + DFT<B>
30            + IDFTConsume<B>
31            + VecZnxBigAddSmallInplace<B>
32            + VecZnxBigNormalize<B>,
33    {
34        GLWECiphertext::bytes_of(module.n(), basek, k_lwe_out.max(k_lwe_in), 1)
35            + GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_lwe_out, k_ksk, 1, 1)
36    }
37}
38
39impl<DLwe: DataMut> LWECiphertext<DLwe> {
40    pub fn keyswitch<A, DKs, B: Backend>(
41        &mut self,
42        module: &Module<B>,
43        a: &LWECiphertext<A>,
44        ksk: &LWESwitchingKeyPrepared<DKs, B>,
45        scratch: &mut Scratch<B>,
46    ) where
47        A: DataRef,
48        DKs: DataRef,
49        Module<B>: VecZnxDftAllocBytes
50            + VmpApplyDftToDftTmpBytes
51            + VecZnxBigNormalizeTmpBytes
52            + VmpApplyDftToDft<B>
53            + VmpApplyDftToDftAdd<B>
54            + DFT<B>
55            + IDFTConsume<B>
56            + VecZnxBigAddSmallInplace<B>
57            + VecZnxBigNormalize<B>,
58        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
59    {
60        #[cfg(debug_assertions)]
61        {
62            assert!(self.n() <= module.n());
63            assert!(a.n() <= module.n());
64            assert_eq!(self.basek(), a.basek());
65        }
66
67        let max_k: usize = self.k().max(a.k());
68        let basek: usize = self.basek();
69
70        let (mut glwe, scratch1) = scratch.take_glwe_ct(ksk.n(), basek, max_k, 1);
71        glwe.data.zero();
72
73        let n_lwe: usize = a.n();
74
75        (0..a.size()).for_each(|i| {
76            let data_lwe: &[i64] = a.data.at(0, i);
77            glwe.data.at_mut(0, i)[0] = data_lwe[0];
78            glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
79        });
80
81        glwe.keyswitch_inplace(module, &ksk.0, scratch1);
82
83        self.sample_extract(&glwe);
84    }
85}