poulpy_core/encryption/lwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxAddNormal, VecZnxFillUniform, VecZnxNormalizeInplace, ZnxView, ZnxViewMut,
4    },
5    layouts::{Backend, DataMut, DataRef, Module, ScratchOwned, VecZnx},
6    oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl},
7    source::Source,
8};
9
10use crate::{
11    SIX_SIGMA,
12    layouts::{Infos, LWECiphertext, LWEPlaintext, LWESecret},
13};
14
15impl<DataSelf: DataMut> LWECiphertext<DataSelf> {
16    pub fn encrypt_sk<DataPt, DataSk, B>(
17        &mut self,
18        module: &Module<B>,
19        pt: &LWEPlaintext<DataPt>,
20        sk: &LWESecret<DataSk>,
21        source_xa: &mut Source,
22        source_xe: &mut Source,
23        sigma: f64,
24    ) where
25        DataPt: DataRef,
26        DataSk: DataRef,
27        Module<B>: VecZnxFillUniform + VecZnxAddNormal + VecZnxNormalizeInplace<B>,
28        B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
29    {
30        #[cfg(debug_assertions)]
31        {
32            assert_eq!(self.n(), sk.n())
33        }
34
35        let basek: usize = self.basek();
36        let k: usize = self.k();
37
38        module.vec_znx_fill_uniform(basek, &mut self.data, 0, k, source_xa);
39
40        let mut tmp_znx: VecZnx<Vec<u8>> = VecZnx::alloc(1, 1, self.size());
41
42        let min_size = self.size().min(pt.size());
43
44        (0..min_size).for_each(|i| {
45            tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0]
46                - self.data.at(0, i)[1..]
47                    .iter()
48                    .zip(sk.data.at(0, 0))
49                    .map(|(x, y)| x * y)
50                    .sum::<i64>();
51        });
52
53        (min_size..self.size()).for_each(|i| {
54            tmp_znx.at_mut(0, i)[0] -= self.data.at(0, i)[1..]
55                .iter()
56                .zip(sk.data.at(0, 0))
57                .map(|(x, y)| x * y)
58                .sum::<i64>();
59        });
60
61        module.vec_znx_add_normal(
62            basek,
63            &mut self.data,
64            0,
65            k,
66            source_xe,
67            sigma,
68            sigma * SIX_SIGMA,
69        );
70
71        module.vec_znx_normalize_inplace(
72            basek,
73            &mut tmp_znx,
74            0,
75            ScratchOwned::alloc(size_of::<i64>()).borrow(),
76        );
77
78        (0..self.size()).for_each(|i| {
79            self.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0];
80        });
81    }
82}