poulpy_core/conversion/lwe_to_glwe.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4        VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft,
5        VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, ZnxView, ZnxViewMut, ZnxZero},
8};
9
10use crate::{
11    TakeGLWECt,
12    layouts::{
13        GGLWELayoutInfos, GLWECiphertext, GLWECiphertextLayout, GLWEInfos, LWECiphertext, LWEInfos,
14        prepared::LWEToGLWESwitchingKeyPrepared,
15    },
16};
17
18impl GLWECiphertext<Vec<u8>> {
19    pub fn from_lwe_scratch_space<B: Backend, OUT, IN, KEY>(
20        module: &Module<B>,
21        glwe_infos: &OUT,
22        lwe_infos: &IN,
23        key_infos: &KEY,
24    ) -> usize
25    where
26        OUT: GLWEInfos,
27        IN: LWEInfos,
28        KEY: GGLWELayoutInfos,
29        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
30    {
31        let ct: usize = GLWECiphertext::alloc_bytes_with(
32            module.n().into(),
33            key_infos.base2k(),
34            lwe_infos.k().max(glwe_infos.k()),
35            1u32.into(),
36        );
37        let ks: usize = GLWECiphertext::keyswitch_inplace_scratch_space(module, glwe_infos, key_infos);
38        if lwe_infos.base2k() == key_infos.base2k() {
39            ct + ks
40        } else {
41            let a_conv = VecZnx::alloc_bytes(module.n(), 1, lwe_infos.size()) + module.vec_znx_normalize_tmp_bytes();
42            ct + a_conv + ks
43        }
44    }
45}
46
47impl<D: DataMut> GLWECiphertext<D> {
48    pub fn from_lwe<DLwe, DKsk, B: Backend>(
49        &mut self,
50        module: &Module<B>,
51        lwe: &LWECiphertext<DLwe>,
52        ksk: &LWEToGLWESwitchingKeyPrepared<DKsk, B>,
53        scratch: &mut Scratch<B>,
54    ) where
55        DLwe: DataRef,
56        DKsk: DataRef,
57        Module<B>: VecZnxDftAllocBytes
58            + VmpApplyDftToDftTmpBytes
59            + VecZnxBigNormalizeTmpBytes
60            + VmpApplyDftToDft<B>
61            + VmpApplyDftToDftAdd<B>
62            + VecZnxDftApply<B>
63            + VecZnxIdftApplyConsume<B>
64            + VecZnxBigAddSmallInplace<B>
65            + VecZnxBigNormalize<B>
66            + VecZnxNormalize<B>
67            + VecZnxNormalizeTmpBytes,
68        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeGLWECt + TakeVecZnx,
69    {
70        #[cfg(debug_assertions)]
71        {
72            assert_eq!(self.n(), module.n() as u32);
73            assert_eq!(ksk.n(), module.n() as u32);
74            assert!(lwe.n() <= module.n() as u32);
75        }
76
77        let (mut glwe, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout {
78            n: ksk.n(),
79            base2k: ksk.base2k(),
80            k: lwe.k(),
81            rank: 1u32.into(),
82        });
83        glwe.data.zero();
84
85        let n_lwe: usize = lwe.n().into();
86
87        if lwe.base2k() == ksk.base2k() {
88            for i in 0..lwe.size() {
89                let data_lwe: &[i64] = lwe.data.at(0, i);
90                glwe.data.at_mut(0, i)[0] = data_lwe[0];
91                glwe.data.at_mut(1, i)[..n_lwe].copy_from_slice(&data_lwe[1..]);
92            }
93        } else {
94            let (mut a_conv, scratch_2) = scratch_1.take_vec_znx(module.n(), 1, lwe.size());
95            a_conv.zero();
96            for j in 0..lwe.size() {
97                let data_lwe: &[i64] = lwe.data.at(0, j);
98                a_conv.at_mut(0, j)[0] = data_lwe[0]
99            }
100
101            module.vec_znx_normalize(
102                ksk.base2k().into(),
103                &mut glwe.data,
104                0,
105                lwe.base2k().into(),
106                &a_conv,
107                0,
108                scratch_2,
109            );
110
111            a_conv.zero();
112            for j in 0..lwe.size() {
113                let data_lwe: &[i64] = lwe.data.at(0, j);
114                a_conv.at_mut(0, j)[..n_lwe].copy_from_slice(&data_lwe[1..]);
115            }
116
117            module.vec_znx_normalize(
118                ksk.base2k().into(),
119                &mut glwe.data,
120                1,
121                lwe.base2k().into(),
122                &a_conv,
123                0,
124                scratch_2,
125            );
126        }
127
128        self.keyswitch(module, &glwe, &ksk.0, scratch_1);
129    }
130}