poulpy_core/keyswitching/lwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4        VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VmpApplyDftToDft, VmpApplyDftToDftAdd,
5        VmpApplyDftToDftTmpBytes,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero},
8};
9
10use crate::{
11    TakeGLWECt,
12    layouts::{GLWECiphertext, Infos, LWECiphertext, prepared::LWESwitchingKeyPrepared},
13};
14
15impl LWECiphertext<Vec<u8>> {
16    pub fn keyswitch_scratch_space<B: Backend>(
17        module: &Module<B>,
18        basek: usize,
19        k_lwe_out: usize,
20        k_lwe_in: usize,
21        k_ksk: usize,
22    ) -> usize
23    where
24        Module<B>: VecZnxDftAllocBytes
25            + VmpApplyDftToDftTmpBytes
26            + VecZnxBigNormalizeTmpBytes
27            + VmpApplyDftToDftTmpBytes
28            + VmpApplyDftToDft<B>
29            + VmpApplyDftToDftAdd<B>
30            + VecZnxDftApply<B>
31            + VecZnxIdftApplyConsume<B>
32            + VecZnxBigAddSmallInplace<B>
33            + VecZnxBigNormalize<B>,
34    {
35        GLWECiphertext::bytes_of(module.n(), basek, k_lwe_out.max(k_lwe_in), 1)
36            + GLWECiphertext::keyswitch_inplace_scratch_space(module, basek, k_lwe_out, k_ksk, 1, 1)
37    }
38}
39
40impl<DLwe: DataMut> LWECiphertext<DLwe> {
41    pub fn keyswitch<A, DKs, B: Backend>(
42        &mut self,
43        module: &Module<B>,
44        a: &LWECiphertext<A>,
45        ksk: &LWESwitchingKeyPrepared<DKs, B>,
46        scratch: &mut Scratch<B>,
47    ) where
48        A: DataRef,
49        DKs: DataRef,
50        Module<B>: VecZnxDftAllocBytes
51            + VmpApplyDftToDftTmpBytes
52            + VecZnxBigNormalizeTmpBytes
53            + VmpApplyDftToDft<B>
54            + VmpApplyDftToDftAdd<B>
55            + VecZnxDftApply<B>
56            + VecZnxIdftApplyConsume<B>
57            + VecZnxBigAddSmallInplace<B>
58            + VecZnxBigNormalize<B>,
59        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
60    {
61        #[cfg(debug_assertions)]
62        {
63            assert!(self.n() <= module.n());
64            assert!(a.n() <= module.n());
65            assert_eq!(self.basek(), a.basek());
66        }
67
68        let max_k: usize = self.k().max(a.k());
69        let basek: usize = self.basek();
70
71        let (mut glwe, scratch_1) = scratch.take_glwe_ct(ksk.n(), basek, max_k, 1);
72        glwe.data.zero();
73
74        let n_lwe: usize = a.n();
75
76        (0..a.size()).for_each(|i| {
77            let data_lwe: &[i64] = a.data.at(0, i);
78            glwe.data.at_mut(0, i)[0] = data_lwe[0];
79            glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
80        });
81
82        glwe.keyswitch_inplace(module, &ksk.0, scratch_1);
83
84        self.sample_extract(&glwe);
85    }
86}