1use poulpy_hal::{
2 api::{ScratchOwnedAlloc, ScratchOwnedBorrow, ZnAddNormal, ZnFillUniform, ZnNormalizeInplace},
3 layouts::{Backend, DataMut, Module, ScratchOwned, Zn, ZnxView, ZnxViewMut},
4 source::Source,
5};
6
7use crate::{
8 encryption::{SIGMA, SIGMA_BOUND},
9 layouts::{LWE, LWEInfos, LWEPlaintext, LWEPlaintextToRef, LWESecret, LWESecretToRef, LWEToMut},
10};
11
12impl<DataSelf: DataMut> LWE<DataSelf> {
13 pub fn encrypt_sk<P, S, M, BE: Backend>(&mut self, module: &M, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source)
14 where
15 P: LWEPlaintextToRef,
16 S: LWESecretToRef,
17 M: LWEEncryptSk<BE>,
18 {
19 module.lwe_encrypt_sk(self, pt, sk, source_xa, source_xe);
20 }
21}
22
23pub trait LWEEncryptSk<BE: Backend> {
24 fn lwe_encrypt_sk<R, P, S>(&self, res: &mut R, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source)
25 where
26 R: LWEToMut,
27 P: LWEPlaintextToRef,
28 S: LWESecretToRef;
29}
30
31impl<BE: Backend> LWEEncryptSk<BE> for Module<BE>
32where
33 Self: Sized + ZnFillUniform + ZnAddNormal + ZnNormalizeInplace<BE>,
34 ScratchOwned<BE>: ScratchOwnedAlloc<BE> + ScratchOwnedBorrow<BE>,
35{
36 fn lwe_encrypt_sk<R, P, S>(&self, res: &mut R, pt: &P, sk: &S, source_xa: &mut Source, source_xe: &mut Source)
37 where
38 R: LWEToMut,
39 P: LWEPlaintextToRef,
40 S: LWESecretToRef,
41 {
42 let res: &mut LWE<&mut [u8]> = &mut res.to_mut();
43 let pt: &LWEPlaintext<&[u8]> = &pt.to_ref();
44 let sk: &LWESecret<&[u8]> = &sk.to_ref();
45
46 #[cfg(debug_assertions)]
47 {
48 assert_eq!(res.n(), sk.n())
49 }
50
51 let base2k: usize = res.base2k().into();
52 let k: usize = res.k().into();
53
54 self.zn_fill_uniform((res.n() + 1).into(), base2k, &mut res.data, 0, source_xa);
55
56 let mut tmp_znx: Zn<Vec<u8>> = Zn::alloc(1, 1, res.size());
57
58 let min_size = res.size().min(pt.size());
59
60 (0..min_size).for_each(|i| {
61 tmp_znx.at_mut(0, i)[0] = pt.data.at(0, i)[0]
62 - res.data.at(0, i)[1..]
63 .iter()
64 .zip(sk.data.at(0, 0))
65 .map(|(x, y)| x * y)
66 .sum::<i64>();
67 });
68
69 (min_size..res.size()).for_each(|i| {
70 tmp_znx.at_mut(0, i)[0] -= res.data.at(0, i)[1..]
71 .iter()
72 .zip(sk.data.at(0, 0))
73 .map(|(x, y)| x * y)
74 .sum::<i64>();
75 });
76
77 self.zn_add_normal(
78 1,
79 base2k,
80 &mut res.data,
81 0,
82 k,
83 source_xe,
84 SIGMA,
85 SIGMA_BOUND,
86 );
87
88 self.zn_normalize_inplace(
89 1,
90 base2k,
91 &mut tmp_znx,
92 0,
93 ScratchOwned::alloc(size_of::<i64>()).borrow(),
94 );
95
96 (0..res.size()).for_each(|i| {
97 res.data.at_mut(0, i)[0] = tmp_znx.at(0, i)[0];
98 });
99 }
100}