poulpy_core/encryption/compressed/
ggsw.rs1use poulpy_hal::{
2 api::{ModuleN, VecZnxAddScalarInplace, VecZnxNormalizeInplace},
3 layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxInfos, ZnxZero},
4 source::Source,
5};
6
7use crate::{
8 ScratchTakeCore,
9 encryption::{GGSWEncryptSk, GLWEEncryptSkInternal, SIGMA},
10 layouts::{
11 GGSWCompressedSeedMut, GGSWInfos, LWEInfos,
12 compressed::{GGSWCompressed, GGSWCompressedToMut},
13 prepared::{GLWESecretPrepared, GLWESecretPreparedToRef},
14 },
15};
16
17impl GGSWCompressed<Vec<u8>> {
18 pub fn encrypt_sk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
19 where
20 A: GGSWInfos,
21 M: GGSWCompressedEncryptSk<BE>,
22 {
23 module.ggsw_compressed_encrypt_sk_tmp_bytes(infos)
24 }
25}
26
27impl<DataSelf: DataMut> GGSWCompressed<DataSelf> {
28 #[allow(clippy::too_many_arguments)]
29 pub fn encrypt_sk<P, S, M, BE: Backend>(
30 &mut self,
31 module: &M,
32 pt: &P,
33 sk: &S,
34 seed_xa: [u8; 32],
35 source_xe: &mut Source,
36 scratch: &mut Scratch<BE>,
37 ) where
38 P: ScalarZnxToRef,
39 S: GLWESecretPreparedToRef<BE>,
40 M: GGSWCompressedEncryptSk<BE>,
41 {
42 module.ggsw_compressed_encrypt_sk(self, pt, sk, seed_xa, source_xe, scratch);
43 }
44}
45
46pub trait GGSWCompressedEncryptSk<BE: Backend> {
47 fn ggsw_compressed_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
48 where
49 A: GGSWInfos;
50
51 fn ggsw_compressed_encrypt_sk<R, P, S>(
52 &self,
53 res: &mut R,
54 pt: &P,
55 sk: &S,
56 seed_xa: [u8; 32],
57 source_xe: &mut Source,
58 scratch: &mut Scratch<BE>,
59 ) where
60 R: GGSWCompressedToMut + GGSWCompressedSeedMut + GGSWInfos,
61 P: ScalarZnxToRef,
62 S: GLWESecretPreparedToRef<BE>;
63}
64
65impl<BE: Backend> GGSWCompressedEncryptSk<BE> for Module<BE>
66where
67 Self: ModuleN + GLWEEncryptSkInternal<BE> + GGSWEncryptSk<BE> + VecZnxAddScalarInplace + VecZnxNormalizeInplace<BE>,
68 Scratch<BE>: ScratchTakeCore<BE>,
69{
70 fn ggsw_compressed_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
71 where
72 A: GGSWInfos,
73 {
74 self.ggsw_encrypt_sk_tmp_bytes(infos)
75 }
76
77 fn ggsw_compressed_encrypt_sk<R, P, S>(
78 &self,
79 res: &mut R,
80 pt: &P,
81 sk: &S,
82 seed_xa: [u8; 32],
83 source_xe: &mut Source,
84 scratch: &mut Scratch<BE>,
85 ) where
86 R: GGSWCompressedToMut + GGSWCompressedSeedMut + GGSWInfos,
87 P: ScalarZnxToRef,
88 S: GLWESecretPreparedToRef<BE>,
89 {
90 let base2k: usize = res.base2k().into();
91 let rank: usize = res.rank().into();
92 let cols: usize = rank + 1;
93 let dsize: usize = res.dsize().into();
94
95 let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref();
96 let pt: &ScalarZnx<&[u8]> = &pt.to_ref();
97
98 assert_eq!(res.rank(), sk.rank());
99 assert_eq!(pt.n(), self.n());
100 assert_eq!(res.n(), self.n() as u32);
101 assert_eq!(sk.n(), self.n() as u32);
102
103 let mut seeds: Vec<[u8; 32]> = vec![[0u8; 32]; res.dnum().as_usize() * (res.rank().as_usize() + 1)];
104
105 {
106 let res: &mut GGSWCompressed<&mut [u8]> = &mut res.to_mut();
107
108 println!("res.seed: {:?}", res.seed);
109
110 let (mut tmp_pt, scratch_1) = scratch.take_glwe_plaintext(res);
111
112 let mut source = Source::new(seed_xa);
113
114 for row_i in 0..res.dnum().into() {
115 tmp_pt.data.zero();
116
117 self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, 0);
119 self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1);
120
121 for col_j in 0..rank + 1 {
122 let (seed, mut source_xa_tmp) = source.branch();
125
126 seeds[row_i * cols + col_j] = seed;
127
128 self.glwe_encrypt_sk_internal(
129 res.base2k().into(),
130 res.k().into(),
131 &mut res.at_mut(row_i, col_j).data,
132 cols,
133 true,
134 Some((&tmp_pt, col_j)),
135 sk,
136 &mut source_xa_tmp,
137 source_xe,
138 SIGMA,
139 scratch_1,
140 );
141 }
142 }
143 }
144
145 res.seed_mut().copy_from_slice(&seeds);
146 }
147}