// poulpy_core/encryption/lwe.rs

1use poulpy_hal::{
2    api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace},
3    layouts::{Backend, DataMut, Module, ScratchOwned, Zn, ZnxView, ZnxViewMut},
4    source::Source,
5};
6
7use crate::{
8    encryption::{SIGMA, SIGMA_BOUND},
9    layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToRef, LWESecret, LWESecretToRef, LWEToMut},
10};
11
12impl<DataSelf: DataMut> LWE<DataSelf> {
13    pub fn encrypt_sk<P, S, M, BE: Backend>(&mut self, module: &M, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source)
14    where
15        P: LWEPlaintextToRef,
16        S: LWESecretToRef,
17        M: LWEEncryptSk<BE>,
18    {
19        module.lwe_encrypt_sk(self, pt, sk, source_xa, source_xe);
20    }
21}
22
23pub trait LWEEncryptSk<BE: Backend> {
24    fn lwe_encrypt_sk<R, P, S>(&self, res: &mut R, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source)
25    where
26        R: LWEToMut,
27        P: LWEPlaintextToRef,
28        S: LWESecretToRef;
29}
30
31impl<BE: Backend> LWEEncryptSk<BE> for Module<BE>
32where
33    Self: Sized + ZnFillUniform + ZnAddNormal + ZnNormalizeInplace<BE>,
34    ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
35{
36    fn lwe_encrypt_sk<R, P, S>(&self, res: &mut R, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source)
37    where
38        R: LWEToMut,
39        P: LWEPlaintextToRef,
40        S: LWESecretToRef,
41    {
42        let res: &mut LWE<&mut [u8]> = &mut res.to_mut();
43        let pt: &LWEPlaintext<&[u8]> = &pt.to_ref();
44        let sk: &LWESecret<&[u8]> = &sk.to_ref();
45
46        #[cfg(debug_assertions)]
47        {
48            assert_eq!(res.n(), sk.n())
49        }
50
51        let base2k: usize = res.base2k().into();
52        let k: usize = res.k().into();
53
54        self.zn_fill_uniform((res.n() + 1).into(), base2k, &mut res.data, 0, source_xa);
55
56        let mut tmp_znx: Zn<Vec<u8>> = Zn::alloc(1, 1, res.size());
57
58        let min_size = res.size().min(pt.size());
59
60        (0..min_size).for_each(|i| {
61            tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0]
62                - res.data.at(0, i)[1..]
63                    .iter()
64                    .zip(sk.data.at(0, 0))
65                    .map(|(x, y)| x * y)
66                    .sum::<i64>();
67        });
68
69        (min_size..res.size()).for_each(|i| {
70            tmp_znx.at_mut(0, i)[0] -= res.data.at(0, i)[1..]
71                .iter()
72                .zip(sk.data.at(0, 0))
73                .map(|(x, y)| x * y)
74                .sum::<i64>();
75        });
76
77        self.zn_add_normal(
78            1,
79            base2k,
80            &mut res.data,
81            0,
82            k,
83            source_xe,
84            SIGMA,
85            SIGMA_BOUND,
86        );
87
88        self.zn_normalize_inplace(
89            1,
90            base2k,
91            &mut tmp_znx,
92            0,
93            ScratchOwned::alloc(size_of::<i64>()).borrow(),
94        );
95
96        (0..res.size()).for_each(|i| {
97            res.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0];
98        });
99    }
100}