// poulpy_core/keyswitching/glwe_ct.rs

1use poulpy_hal::{
2    api::{
3        DataViewMut, ScratchAvailable, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
4        VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VmpApply, VmpApplyAdd, VmpApplyTmpBytes, ZnxInfos,
5    },
6    layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx, VecZnxBig, VecZnxDft, VmpPMat},
7};
8
9use crate::layouts::{GLWECiphertext, Infos, prepared::GGLWESwitchingKeyPrepared};
10
11impl GLWECiphertext<Vec<u8>> {
12    #[allow(clippy::too_many_arguments)]
13    pub fn keyswitch_scratch_space<B: Backend>(
14        module: &Module<B>,
15        n: usize,
16        basek: usize,
17        k_out: usize,
18        k_in: usize,
19        k_ksk: usize,
20        digits: usize,
21        rank_in: usize,
22        rank_out: usize,
23    ) -> usize
24    where
25        Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
26    {
27        let in_size: usize = k_in.div_ceil(basek).div_ceil(digits);
28        let out_size: usize = k_out.div_ceil(basek);
29        let ksk_size: usize = k_ksk.div_ceil(basek);
30        let res_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank_out + 1, ksk_size); // TODO OPTIMIZE
31        let ai_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank_in, in_size);
32        let vmp: usize = module.vmp_apply_tmp_bytes(
33            n,
34            out_size,
35            in_size,
36            in_size,
37            rank_in,
38            rank_out + 1,
39            ksk_size,
40        ) + module.vec_znx_dft_alloc_bytes(n, rank_in, in_size);
41        let normalize: usize = module.vec_znx_big_normalize_tmp_bytes(n);
42        res_dft + ((ai_dft + vmp) | normalize)
43    }
44
45    pub fn keyswitch_inplace_scratch_space<B: Backend>(
46        module: &Module<B>,
47        n: usize,
48        basek: usize,
49        k_out: usize,
50        k_ksk: usize,
51        digits: usize,
52        rank: usize,
53    ) -> usize
54    where
55        Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
56    {
57        Self::keyswitch_scratch_space(module, n, basek, k_out, k_out, k_ksk, digits, rank, rank)
58    }
59}
60
61impl<DataSelf: DataRef> GLWECiphertext<DataSelf> {
62    #[allow(dead_code)]
63    pub(crate) fn assert_keyswitch<B: Backend, DataLhs, DataRhs>(
64        &self,
65        module: &Module<B>,
66        lhs: &GLWECiphertext<DataLhs>,
67        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
68        scratch: &Scratch<B>,
69    ) where
70        DataLhs: DataRef,
71        DataRhs: DataRef,
72        Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
73        Scratch<B>: ScratchAvailable,
74    {
75        let basek: usize = self.basek();
76        assert_eq!(
77            lhs.rank(),
78            rhs.rank_in(),
79            "lhs.rank(): {} != rhs.rank_in(): {}",
80            lhs.rank(),
81            rhs.rank_in()
82        );
83        assert_eq!(
84            self.rank(),
85            rhs.rank_out(),
86            "self.rank(): {} != rhs.rank_out(): {}",
87            self.rank(),
88            rhs.rank_out()
89        );
90        assert_eq!(self.basek(), basek);
91        assert_eq!(lhs.basek(), basek);
92        assert_eq!(rhs.n(), self.n());
93        assert_eq!(lhs.n(), self.n());
94        assert!(
95            scratch.available()
96                >= GLWECiphertext::keyswitch_scratch_space(
97                    module,
98                    self.n(),
99                    self.basek(),
100                    self.k(),
101                    lhs.k(),
102                    rhs.k(),
103                    rhs.digits(),
104                    rhs.rank_in(),
105                    rhs.rank_out(),
106                ),
107            "scratch.available()={} < GLWECiphertext::keyswitch_scratch_space(
108                    module,
109                    self.basek(),
110                    self.k(),
111                    lhs.k(),
112                    rhs.k(),
113                    rhs.digits(),
114                    rhs.rank_in(),
115                    rhs.rank_out(),
116                )={}",
117            scratch.available(),
118            GLWECiphertext::keyswitch_scratch_space(
119                module,
120                self.n(),
121                self.basek(),
122                self.k(),
123                lhs.k(),
124                rhs.k(),
125                rhs.digits(),
126                rhs.rank_in(),
127                rhs.rank_out(),
128            )
129        );
130    }
131}
132
133impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
134    pub fn keyswitch<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
135        &mut self,
136        module: &Module<B>,
137        lhs: &GLWECiphertext<DataLhs>,
138        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
139        scratch: &mut Scratch<B>,
140    ) where
141        Module<B>: VecZnxDftAllocBytes
142            + VmpApplyTmpBytes
143            + VecZnxBigNormalizeTmpBytes
144            + VmpApplyTmpBytes
145            + VmpApply<B>
146            + VmpApplyAdd<B>
147            + VecZnxDftFromVecZnx<B>
148            + VecZnxDftToVecZnxBigConsume<B>
149            + VecZnxBigAddSmallInplace<B>
150            + VecZnxBigNormalize<B>,
151        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
152    {
153        #[cfg(debug_assertions)]
154        {
155            self.assert_keyswitch(module, lhs, rhs, scratch);
156        }
157        let (res_dft, scratch1) = scratch.take_vec_znx_dft(self.n(), self.cols(), rhs.size()); // Todo optimise
158        let res_big: VecZnxBig<_, B> = lhs.keyswitch_internal(module, res_dft, rhs, scratch1);
159        (0..self.cols()).for_each(|i| {
160            module.vec_znx_big_normalize(self.basek(), &mut self.data, i, &res_big, i, scratch1);
161        })
162    }
163
164    pub fn keyswitch_inplace<DataRhs: DataRef, B: Backend>(
165        &mut self,
166        module: &Module<B>,
167        rhs: &GGLWESwitchingKeyPrepared<DataRhs, B>,
168        scratch: &mut Scratch<B>,
169    ) where
170        Module<B>: VecZnxDftAllocBytes
171            + VmpApplyTmpBytes
172            + VecZnxBigNormalizeTmpBytes
173            + VmpApplyTmpBytes
174            + VmpApply<B>
175            + VmpApplyAdd<B>
176            + VecZnxDftFromVecZnx<B>
177            + VecZnxDftToVecZnxBigConsume<B>
178            + VecZnxBigAddSmallInplace<B>
179            + VecZnxBigNormalize<B>,
180        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
181    {
182        unsafe {
183            let self_ptr: *mut GLWECiphertext<DataSelf> = self as *mut GLWECiphertext<DataSelf>;
184            self.keyswitch(module, &*self_ptr, rhs, scratch);
185        }
186    }
187}
188
189impl<D: DataRef> GLWECiphertext<D> {
190    pub(crate) fn keyswitch_internal<B: Backend, DataRes, DataKey>(
191        &self,
192        module: &Module<B>,
193        res_dft: VecZnxDft<DataRes, B>,
194        rhs: &GGLWESwitchingKeyPrepared<DataKey, B>,
195        scratch: &mut Scratch<B>,
196    ) -> VecZnxBig<DataRes, B>
197    where
198        DataRes: DataMut,
199        DataKey: DataRef,
200        Module<B>: VecZnxDftAllocBytes
201            + VmpApplyTmpBytes
202            + VecZnxBigNormalizeTmpBytes
203            + VmpApplyTmpBytes
204            + VmpApply<B>
205            + VmpApplyAdd<B>
206            + VecZnxDftFromVecZnx<B>
207            + VecZnxDftToVecZnxBigConsume<B>
208            + VecZnxBigAddSmallInplace<B>
209            + VecZnxBigNormalize<B>,
210        Scratch<B>: TakeVecZnxDft<B>,
211    {
212        if rhs.digits() == 1 {
213            return keyswitch_vmp_one_digit(module, res_dft, &self.data, &rhs.key.data, scratch);
214        }
215
216        keyswitch_vmp_multiple_digits(
217            module,
218            res_dft,
219            &self.data,
220            &rhs.key.data,
221            rhs.digits(),
222            scratch,
223        )
224    }
225}
226
227fn keyswitch_vmp_one_digit<B: Backend, DataRes, DataIn, DataVmp>(
228    module: &Module<B>,
229    mut res_dft: VecZnxDft<DataRes, B>,
230    a: &VecZnx<DataIn>,
231    mat: &VmpPMat<DataVmp, B>,
232    scratch: &mut Scratch<B>,
233) -> VecZnxBig<DataRes, B>
234where
235    DataRes: DataMut,
236    DataIn: DataRef,
237    DataVmp: DataRef,
238    Module<B>:
239        VecZnxDftAllocBytes + VecZnxDftFromVecZnx<B> + VmpApply<B> + VecZnxDftToVecZnxBigConsume<B> + VecZnxBigAddSmallInplace<B>,
240    Scratch<B>: TakeVecZnxDft<B>,
241{
242    let cols: usize = a.cols();
243    let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, a.size());
244    (0..cols - 1).for_each(|col_i| {
245        module.vec_znx_dft_from_vec_znx(1, 0, &mut ai_dft, col_i, a, col_i + 1);
246    });
247    module.vmp_apply(&mut res_dft, &ai_dft, mat, scratch1);
248    let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_dft_to_vec_znx_big_consume(res_dft);
249    module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
250    res_big
251}
252
253fn keyswitch_vmp_multiple_digits<B: Backend, DataRes, DataIn, DataVmp>(
254    module: &Module<B>,
255    mut res_dft: VecZnxDft<DataRes, B>,
256    a: &VecZnx<DataIn>,
257    mat: &VmpPMat<DataVmp, B>,
258    digits: usize,
259    scratch: &mut Scratch<B>,
260) -> VecZnxBig<DataRes, B>
261where
262    DataRes: DataMut,
263    DataIn: DataRef,
264    DataVmp: DataRef,
265    Module<B>: VecZnxDftAllocBytes
266        + VecZnxDftFromVecZnx<B>
267        + VmpApply<B>
268        + VmpApplyAdd<B>
269        + VecZnxDftToVecZnxBigConsume<B>
270        + VecZnxBigAddSmallInplace<B>,
271    Scratch<B>: TakeVecZnxDft<B>,
272{
273    let cols: usize = a.cols();
274    let size: usize = a.size();
275    let (mut ai_dft, scratch1) = scratch.take_vec_znx_dft(a.n(), cols - 1, size.div_ceil(digits));
276
277    ai_dft.data_mut().fill(0);
278
279    (0..digits).for_each(|di| {
280        ai_dft.set_size((size + di) / digits);
281
282        // Small optimization for digits > 2
283        // VMP produce some error e, and since we aggregate vmp * 2^{di * B}, then
284        // we also aggregate ei * 2^{di * B}, with the largest error being ei * 2^{(digits-1) * B}.
285        // As such we can ignore the last digits-2 limbs safely of the sum of vmp products.
286        // It is possible to further ignore the last digits-1 limbs, but this introduce
287        // ~0.5 to 1 bit of additional noise, and thus not chosen here to ensure that the same
288        // noise is kept with respect to the ideal functionality.
289        res_dft.set_size(mat.size() - ((digits - di) as isize - 2).max(0) as usize);
290
291        (0..cols - 1).for_each(|col_i| {
292            module.vec_znx_dft_from_vec_znx(digits, digits - di - 1, &mut ai_dft, col_i, a, col_i + 1);
293        });
294
295        if di == 0 {
296            module.vmp_apply(&mut res_dft, &ai_dft, mat, scratch1);
297        } else {
298            module.vmp_apply_add(&mut res_dft, &ai_dft, mat, di, scratch1);
299        }
300    });
301
302    res_dft.set_size(res_dft.max_size());
303    let mut res_big: VecZnxBig<DataRes, B> = module.vec_znx_dft_to_vec_znx_big_consume(res_dft);
304    module.vec_znx_big_add_small_inplace(&mut res_big, 0, a, 0);
305    res_big
306}