// poulpy_core/keyswitching/lwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4        VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, ZnxView,
5        ZnxViewMut, ZnxZero,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, Scratch},
8};
9
10use crate::{
11    TakeGLWECt,
12    layouts::{GLWECiphertext, Infos, LWECiphertext, prepared::LWESwitchingKeyPrepared},
13};
14
15impl LWECiphertext<Vec<u8>> {
16    pub fn keyswitch_scratch_space<B: Backend>(
17        module: &Module<B>,
18        n: usize,
19        basek: usize,
20        k_lwe_out: usize,
21        k_lwe_in: usize,
22        k_ksk: usize,
23    ) -> usize
24    where
25        Module<B>: VecZnxDftAllocBytes
26            + VmpApplyTmpBytes
27            + VecZnxBigNormalizeTmpBytes
28            + VmpApplyTmpBytes
29            + VmpApply<B>
30            + VmpApplyAdd<B>
31            + VecZnxDftFromVecZnx<B>
32            + VecZnxDftToVecZnxBigConsume<B>
33            + VecZnxBigAddSmallInplace<B>
34            + VecZnxBigNormalize<B>,
35    {
36        GLWECiphertext::bytes_of(n, basek, k_lwe_out.max(k_lwe_in), 1)
37            + GLWECiphertext::keyswitch_inplace_scratch_space(module, n, basek, k_lwe_out, k_ksk, 1, 1)
38    }
39}
40
41impl<DLwe: DataMut> LWECiphertext<DLwe> {
42    pub fn keyswitch<A, DKs, B: Backend>(
43        &mut self,
44        module: &Module<B>,
45        a: &LWECiphertext<A>,
46        ksk: &LWESwitchingKeyPrepared<DKs, B>,
47        scratch: &mut Scratch<B>,
48    ) where
49        A: DataRef,
50        DKs: DataRef,
51        Module<B>: VecZnxDftAllocBytes
52            + VmpApplyTmpBytes
53            + VecZnxBigNormalizeTmpBytes
54            + VmpApply<B>
55            + VmpApplyAdd<B>
56            + VecZnxDftFromVecZnx<B>
57            + VecZnxDftToVecZnxBigConsume<B>
58            + VecZnxBigAddSmallInplace<B>
59            + VecZnxBigNormalize<B>,
60        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
61    {
62        #[cfg(debug_assertions)]
63        {
64            assert!(self.n() <= module.n());
65            assert!(a.n() <= module.n());
66            assert_eq!(self.basek(), a.basek());
67        }
68
69        let max_k: usize = self.k().max(a.k());
70        let basek: usize = self.basek();
71
72        let (mut glwe, scratch1) = scratch.take_glwe_ct(ksk.n(), basek, max_k, 1);
73        glwe.data.zero();
74
75        let n_lwe: usize = a.n();
76
77        (0..a.size()).for_each(|i| {
78            let data_lwe: &[i64] = a.data.at(0, i);
79            glwe.data.at_mut(0, i)[0] = data_lwe[0];
80            glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
81        });
82
83        glwe.keyswitch_inplace(module, &ksk.0, scratch1);
84
85        self.sample_extract(&glwe);
86    }
87}