poulpy_core/encryption/
lwe_ct.rs

1use poulpy_hal::{
2    api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace},
3    layouts::{Backend, DataMut, DataRef, Module, ScratchOwned, Zn, ZnxView, ZnxViewMut},
4    oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl},
5    source::Source,
6};
7
8use crate::{
9    encryption::{SIGMA, SIGMA_BOUND},
10    layouts::{LWECiphertext, LWEInfos, LWEPlaintext, LWESecret},
11};
12
13impl<DataSelf: DataMut> LWECiphertext<DataSelf> {
14    pub fn encrypt_sk<DataPt, DataSk, B>(
15        &mut self,
16        module: &Module<B>,
17        pt: &LWEPlaintext<DataPt>,
18        sk: &LWESecret<DataSk>,
19        source_xa: &mut Source,
20        source_xe: &mut Source,
21    ) where
22        DataPt: DataRef,
23        DataSk: DataRef,
24        Module<B>: ZnFillUniform + ZnAddNormal + ZnNormalizeInplace<B>,
25        B: Backend + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
26    {
27        #[cfg(debug_assertions)]
28        {
29            assert_eq!(self.n(), sk.n())
30        }
31
32        let base2k: usize = self.base2k().into();
33        let k: usize = self.k().into();
34
35        module.zn_fill_uniform((self.n() + 1).into(), base2k, &mut self.data, 0, source_xa);
36
37        let mut tmp_znx: Zn<Vec<u8>> = Zn::alloc(1, 1, self.size());
38
39        let min_size = self.size().min(pt.size());
40
41        (0..min_size).for_each(|i| {
42            tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0]
43                - self.data.at(0, i)[1..]
44                    .iter()
45                    .zip(sk.data.at(0, 0))
46                    .map(|(x, y)| x * y)
47                    .sum::<i64>();
48        });
49
50        (min_size..self.size()).for_each(|i| {
51            tmp_znx.at_mut(0, i)[0] -= self.data.at(0, i)[1..]
52                .iter()
53                .zip(sk.data.at(0, 0))
54                .map(|(x, y)| x * y)
55                .sum::<i64>();
56        });
57
58        module.zn_add_normal(
59            1,
60            base2k,
61            &mut self.data,
62            0,
63            k,
64            source_xe,
65            SIGMA,
66            SIGMA_BOUND,
67        );
68
69        module.zn_normalize_inplace(
70            1,
71            base2k,
72            &mut tmp_znx,
73            0,
74            ScratchOwned::alloc(size_of::<i64>()).borrow(),
75        );
76
77        (0..self.size()).for_each(|i| {
78            self.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0];
79        });
80    }
81}