poulpy_core/encryption/compressed/
gglwe.rs1use poulpy_hal::{
2 api::{ModuleN, ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
3 layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxInfos, ZnxZero},
4 source::Source,
5};
6
7use crate::{
8 ScratchTakeCore,
9 encryption::{GLWEEncryptSk, GLWEEncryptSkInternal, SIGMA},
10 layouts::{
11 GGLWECompressedSeedMut, GGLWEInfos, GLWEPlaintext, GLWESecretPrepared, LWEInfos,
12 compressed::{GGLWECompressed, GGLWECompressedToMut},
13 prepared::GLWESecretPreparedToRef,
14 },
15};
16
17impl<D: DataMut> GGLWECompressed<D> {
18 #[allow(clippy::too_many_arguments)]
19 pub fn encrypt_sk<M, P, S, BE: Backend>(
20 &mut self,
21 module: &M,
22 pt: &P,
23 sk: &S,
24 seed: [u8; 32],
25 source_xe: &mut Source,
26 scratch: &mut Scratch<BE>,
27 ) where
28 P: ScalarZnxToRef,
29 S: GLWESecretPreparedToRef<BE>,
30 M: GGLWECompressedEncryptSk<BE>,
31 {
32 module.gglwe_compressed_encrypt_sk(self, pt, sk, seed, source_xe, scratch);
33 }
34}
35
36impl GGLWECompressed<Vec<u8>> {
37 pub fn encrypt_sk_tmp_bytes<M, BE: Backend, A>(module: &M, infos: &A) -> usize
38 where
39 A: GGLWEInfos,
40 M: GGLWECompressedEncryptSk<BE>,
41 {
42 module.gglwe_compressed_encrypt_sk_tmp_bytes(infos)
43 }
44}
45
46pub trait GGLWECompressedEncryptSk<BE: Backend> {
47 fn gglwe_compressed_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
48 where
49 A: GGLWEInfos;
50
51 fn gglwe_compressed_encrypt_sk<R, P, S>(
52 &self,
53 res: &mut R,
54 pt: &P,
55 sk: &S,
56 seed: [u8; 32],
57 source_xe: &mut Source,
58 scratch: &mut Scratch<BE>,
59 ) where
60 R: GGLWECompressedToMut + GGLWECompressedSeedMut,
61 P: ScalarZnxToRef,
62 S: GLWESecretPreparedToRef<BE>;
63}
64
65impl<BE: Backend> GGLWECompressedEncryptSk<BE> for Module<BE>
66where
67 Self: ModuleN
68 + GLWEEncryptSkInternal<BE>
69 + GLWEEncryptSk<BE>
70 + VecZnxDftBytesOf
71 + VecZnxNormalizeInplace<BE>
72 + VecZnxAddScalarInplace
73 + VecZnxNormalizeTmpBytes,
74 Scratch<BE>: ScratchTakeCore<BE>,
75{
76 fn gglwe_compressed_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
77 where
78 A: GGLWEInfos,
79 {
80 self.glwe_encrypt_sk_tmp_bytes(infos)
81 .max(self.vec_znx_normalize_tmp_bytes())
82 + GLWEPlaintext::bytes_of_from_infos(infos)
83 }
84
85 fn gglwe_compressed_encrypt_sk<R, P, S>(
86 &self,
87 res: &mut R,
88 pt: &P,
89 sk: &S,
90 seed: [u8; 32],
91 source_xe: &mut Source,
92 scratch: &mut Scratch<BE>,
93 ) where
94 R: GGLWECompressedToMut + GGLWECompressedSeedMut,
95 P: ScalarZnxToRef,
96 S: GLWESecretPreparedToRef<BE>,
97 {
98 let mut seeds: Vec<[u8; 32]> = vec![[0u8; 32]; res.seed_mut().len()];
99
100 {
101 let res: &mut GGLWECompressed<&mut [u8]> = &mut res.to_mut();
102 let pt: &ScalarZnx<&[u8]> = &pt.to_ref();
103 let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref();
104
105 assert_eq!(
106 res.rank_in(),
107 pt.cols() as u32,
108 "res.rank_in(): {} != pt.cols(): {}",
109 res.rank_in(),
110 pt.cols()
111 );
112 assert_eq!(
113 res.rank_out(),
114 sk.rank(),
115 "res.rank_out(): {} != sk.rank(): {}",
116 res.rank_out(),
117 sk.rank()
118 );
119 assert_eq!(res.n(), sk.n());
120 assert_eq!(pt.n() as u32, sk.n());
121 assert!(
122 scratch.available() >= GGLWECompressed::encrypt_sk_tmp_bytes(self, res),
123 "scratch.available: {} < GGLWECiphertext::encrypt_sk_tmp_bytes: {}",
124 scratch.available(),
125 GGLWECompressed::encrypt_sk_tmp_bytes(self, res)
126 );
127 assert!(
128 res.dnum().0 * res.dsize().0 * res.base2k().0 <= res.k().0,
129 "res.dnum() : {} * res.dsize() : {} * res.base2k() : {} = {} >= res.k() = {}",
130 res.dnum(),
131 res.dsize(),
132 res.base2k(),
133 res.dnum().0 * res.dsize().0 * res.base2k().0,
134 res.k()
135 );
136
137 let dnum: usize = res.dnum().into();
138 let dsize: usize = res.dsize().into();
139 let base2k: usize = res.base2k().into();
140 let rank_in: usize = res.rank_in().into();
141 let cols: usize = (res.rank_out() + 1).into();
142
143 let mut source_xa = Source::new(seed);
144
145 let (mut tmp_pt, scrach_1) = scratch.take_glwe_plaintext(res);
146 for col_i in 0..rank_in {
147 for d_i in 0..dnum {
148 tmp_pt.data.zero(); self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + d_i * dsize, pt, col_i);
151 self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1);
152
153 let (seed, mut source_xa_tmp) = source_xa.branch();
154 seeds[col_i * dnum + d_i] = seed;
155
156 self.glwe_encrypt_sk_internal(
157 res.base2k().into(),
158 res.k().into(),
159 &mut res.at_mut(d_i, col_i).data,
160 cols,
161 true,
162 Some((&tmp_pt, 0)),
163 sk,
164 &mut source_xa_tmp,
165 source_xe,
166 SIGMA,
167 scrach_1,
168 );
169 }
170 }
171 }
172
173 res.seed_mut().copy_from_slice(&seeds);
174 }
175}