poulpy_core/keyswitching/
lwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4        VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes,
5        VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxView, ZnxViewMut, ZnxZero},
8};
9
10use crate::{
11    TakeGLWECt,
12    layouts::{
13        GGLWELayoutInfos, GLWECiphertext, GLWECiphertextLayout, LWECiphertext, LWEInfos, Rank, TorusPrecision,
14        prepared::LWESwitchingKeyPrepared,
15    },
16};
17
18impl LWECiphertext<Vec<u8>> {
19    pub fn keyswitch_scratch_space<B: Backend, OUT, IN, KEY>(
20        module: &Module<B>,
21        out_infos: &OUT,
22        in_infos: &IN,
23        key_infos: &KEY,
24    ) -> usize
25    where
26        OUT: LWEInfos,
27        IN: LWEInfos,
28        KEY: GGLWELayoutInfos,
29        Module<B>: VecZnxDftAllocBytes
30            + VmpApplyDftToDftTmpBytes
31            + VecZnxBigNormalizeTmpBytes
32            + VmpApplyDftToDftTmpBytes
33            + VmpApplyDftToDft<B>
34            + VmpApplyDftToDftAdd<B>
35            + VecZnxDftApply<B>
36            + VecZnxIdftApplyConsume<B>
37            + VecZnxBigAddSmallInplace<B>
38            + VecZnxBigNormalize<B>
39            + VecZnxNormalizeTmpBytes,
40    {
41        let max_k: TorusPrecision = in_infos.k().max(out_infos.k());
42
43        let glwe_in_infos: GLWECiphertextLayout = GLWECiphertextLayout {
44            n: module.n().into(),
45            base2k: in_infos.base2k(),
46            k: max_k,
47            rank: Rank(1),
48        };
49
50        let glwe_out_infos: GLWECiphertextLayout = GLWECiphertextLayout {
51            n: module.n().into(),
52            base2k: out_infos.base2k(),
53            k: max_k,
54            rank: Rank(1),
55        };
56
57        let glwe_in: usize = GLWECiphertext::alloc_bytes(&glwe_in_infos);
58        let glwe_out: usize = GLWECiphertext::alloc_bytes(&glwe_out_infos);
59        let ks: usize = GLWECiphertext::keyswitch_scratch_space(module, &glwe_out_infos, &glwe_in_infos, key_infos);
60
61        glwe_in + glwe_out + ks
62    }
63}
64
65impl<DLwe: DataMut> LWECiphertext<DLwe> {
66    pub fn keyswitch<A, DKs, B: Backend>(
67        &mut self,
68        module: &Module<B>,
69        a: &LWECiphertext<A>,
70        ksk: &LWESwitchingKeyPrepared<DKs, B>,
71        scratch: &mut Scratch<B>,
72    ) where
73        A: DataRef,
74        DKs: DataRef,
75        Module<B>: VecZnxDftAllocBytes
76            + VmpApplyDftToDftTmpBytes
77            + VecZnxBigNormalizeTmpBytes
78            + VmpApplyDftToDft<B>
79            + VmpApplyDftToDftAdd<B>
80            + VecZnxDftApply<B>
81            + VecZnxIdftApplyConsume<B>
82            + VecZnxBigAddSmallInplace<B>
83            + VecZnxBigNormalize<B>
84            + VecZnxNormalize<B>
85            + VecZnxNormalizeTmpBytes
86            + VecZnxCopy,
87        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
88    {
89        #[cfg(debug_assertions)]
90        {
91            assert!(self.n() <= module.n() as u32);
92            assert!(a.n() <= module.n() as u32);
93            assert!(scratch.available() >= LWECiphertext::keyswitch_scratch_space(module, self, a, ksk));
94        }
95
96        let max_k: TorusPrecision = self.k().max(a.k());
97
98        let a_size: usize = a.k().div_ceil(ksk.base2k()) as usize;
99
100        let (mut glwe_in, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout {
101            n: ksk.n(),
102            base2k: a.base2k(),
103            k: max_k,
104            rank: Rank(1),
105        });
106        glwe_in.data.zero();
107
108        let (mut glwe_out, scratch_1) = scratch_1.take_glwe_ct(&GLWECiphertextLayout {
109            n: ksk.n(),
110            base2k: self.base2k(),
111            k: max_k,
112            rank: Rank(1),
113        });
114
115        let n_lwe: usize = a.n().into();
116
117        for i in 0..a_size {
118            let data_lwe: &[i64] = a.data.at(0, i);
119            glwe_in.data.at_mut(0, i)[0] = data_lwe[0];
120            glwe_in.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
121        }
122
123        glwe_out.keyswitch(module, &glwe_in, &ksk.0, scratch_1);
124        self.sample_extract(&glwe_out);
125    }
126}