1use poulpy_hal::{
2 api::{ModuleN, ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
3 layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxInfos, ZnxZero},
4 source::Source,
5};
6
7use crate::{
8 GLWEEncryptSk, ScratchTakeCore,
9 layouts::{
10 GGLWE, GGLWEInfos, GGLWEToMut, GLWEPlaintext, LWEInfos,
11 prepared::{GLWESecretPrepared, GLWESecretPreparedToRef},
12 },
13};
14
15impl GGLWE<Vec<u8>> {
16 pub fn encrypt_sk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
17 where
18 A: GGLWEInfos,
19 M: GGLWEEncryptSk<BE>,
20 {
21 module.gglwe_encrypt_sk_tmp_bytes(infos)
22 }
23
24 pub fn encrypt_pk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
25 where
26 A: GGLWEInfos,
27 M: GGLWEEncryptSk<BE>,
28 {
29 module.gglwe_encrypt_sk_tmp_bytes(infos)
30 }
31}
32
33impl<D: DataMut> GGLWE<D> {
34 #[allow(clippy::too_many_arguments)]
35 pub fn encrypt_sk<P, S, M, BE: Backend>(
36 &mut self,
37 module: &M,
38 pt: &P,
39 sk: &S,
40 source_xa: &mut Source,
41 source_xe: &mut Source,
42 scratch: &mut Scratch<BE>,
43 ) where
44 P: ScalarZnxToRef,
45 S: GLWESecretPreparedToRef<BE>,
46 M: GGLWEEncryptSk<BE>,
47 Scratch<BE>: ScratchTakeCore<BE>,
48 {
49 module.gglwe_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch);
50 }
51}
52
53pub trait GGLWEEncryptSk<BE: Backend> {
54 fn gglwe_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
55 where
56 A: GGLWEInfos;
57
58 fn gglwe_encrypt_sk<R, P, S>(
59 &self,
60 res: &mut R,
61 pt: &P,
62 sk: &S,
63 source_xa: &mut Source,
64 source_xe: &mut Source,
65 scratch: &mut Scratch<BE>,
66 ) where
67 R: GGLWEToMut,
68 P: ScalarZnxToRef,
69 S: GLWESecretPreparedToRef<BE>;
70}
71
72impl<BE: Backend> GGLWEEncryptSk<BE> for Module<BE>
73where
74 Self: ModuleN
75 + GLWEEncryptSk<BE>
76 + VecZnxNormalizeTmpBytes
77 + VecZnxDftBytesOf
78 + VecZnxAddScalarInplace
79 + VecZnxNormalizeInplace<BE>,
80 Scratch<BE>: ScratchTakeCore<BE>,
81{
82 fn gglwe_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
83 where
84 A: GGLWEInfos,
85 {
86 self.glwe_encrypt_sk_tmp_bytes(infos) + GLWEPlaintext::bytes_of_from_infos(infos).max(self.vec_znx_normalize_tmp_bytes())
87 }
88
89 fn gglwe_encrypt_sk<R, P, S>(
90 &self,
91 res: &mut R,
92 pt: &P,
93 sk: &S,
94 source_xa: &mut Source,
95 source_xe: &mut Source,
96 scratch: &mut Scratch<BE>,
97 ) where
98 R: GGLWEToMut,
99 P: ScalarZnxToRef,
100 S: GLWESecretPreparedToRef<BE>,
101 {
102 let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut();
103 let pt: &ScalarZnx<&[u8]> = &pt.to_ref();
104 let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref();
105
106 assert_eq!(
107 res.rank_in(),
108 pt.cols() as u32,
109 "res.rank_in(): {} != pt.cols(): {}",
110 res.rank_in(),
111 pt.cols()
112 );
113 assert_eq!(
114 res.rank_out(),
115 sk.rank(),
116 "res.rank_out(): {} != sk.rank(): {}",
117 res.rank_out(),
118 sk.rank()
119 );
120 assert_eq!(res.n(), sk.n());
121 assert_eq!(pt.n() as u32, sk.n());
122 assert!(
123 scratch.available() >= self.gglwe_encrypt_sk_tmp_bytes(res),
124 "scratch.available: {} < GGLWE::encrypt_sk_tmp_bytes(self, res.rank()={}, res.size()={}): {}",
125 scratch.available(),
126 res.rank_out(),
127 res.size(),
128 self.gglwe_encrypt_sk_tmp_bytes(res)
129 );
130 assert!(
131 res.dnum().0 * res.dsize().0 * res.base2k().0 <= res.k().0,
132 "res.dnum() : {} * res.dsize() : {} * res.base2k() : {} = {} >= res.k() = {}",
133 res.dnum(),
134 res.dsize(),
135 res.base2k(),
136 res.dnum().0 * res.dsize().0 * res.base2k().0,
137 res.k()
138 );
139
140 let dnum: usize = res.dnum().into();
141 let dsize: usize = res.dsize().into();
142 let base2k: usize = res.base2k().into();
143 let rank_in: usize = res.rank_in().into();
144
145 let (mut tmp_pt, scrach_1) = scratch.take_glwe_plaintext(res);
146 for col_i in 0..rank_in {
158 for row_i in 0..dnum {
159 tmp_pt.data.zero(); self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, col_i);
162 self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1);
163 self.glwe_encrypt_sk(
164 &mut res.at_mut(row_i, col_i),
165 &tmp_pt,
166 sk,
167 source_xa,
168 source_xe,
169 scrach_1,
170 );
171 }
172 }
173 }
174}