1use poulpy_hal::{
2 api::{ModuleN, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
3 layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxInfos, ZnxZero},
4 source::Source,
5};
6
7use crate::{
8 GLWEEncryptSk, GLWEEncryptSkInternal, SIGMA, ScratchTakeCore,
9 layouts::{
10 GGSW, GGSWInfos, GGSWToMut, GLWEInfos, GLWEPlaintext, LWEInfos,
11 prepared::{GLWESecretPrepared, GLWESecretPreparedToRef},
12 },
13};
14
15impl GGSW<Vec<u8>> {
16 pub fn encrypt_sk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
17 where
18 A: GGSWInfos,
19 M: GGSWEncryptSk<BE>,
20 {
21 module.ggsw_encrypt_sk_tmp_bytes(infos)
22 }
23}
24
25impl<D: DataMut> GGSW<D> {
26 #[allow(clippy::too_many_arguments)]
27 pub fn encrypt_sk<P, S, M, BE: Backend>(
28 &mut self,
29 module: &M,
30 pt: &P,
31 sk: &S,
32 source_xa: &mut Source,
33 source_xe: &mut Source,
34 scratch: &mut Scratch<BE>,
35 ) where
36 P: ScalarZnxToRef,
37 S: GLWESecretPreparedToRef<BE>,
38 M: GGSWEncryptSk<BE>,
39 Scratch<BE>: ScratchTakeCore<BE>,
40 {
41 module.ggsw_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch);
42 }
43}
44
45pub trait GGSWEncryptSk<BE: Backend> {
46 fn ggsw_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
47 where
48 A: GGSWInfos;
49
50 fn ggsw_encrypt_sk<R, P, S>(
51 &self,
52 res: &mut R,
53 pt: &P,
54 sk: &S,
55 source_xa: &mut Source,
56 source_xe: &mut Source,
57 scratch: &mut Scratch<BE>,
58 ) where
59 R: GGSWToMut,
60 P: ScalarZnxToRef,
61 S: GLWESecretPreparedToRef<BE>;
62}
63
64impl<BE: Backend> GGSWEncryptSk<BE> for Module<BE>
65where
66 Self: ModuleN
67 + GLWEEncryptSkInternal<BE>
68 + GLWEEncryptSk<BE>
69 + VecZnxDftBytesOf
70 + VecZnxNormalizeInplace<BE>
71 + VecZnxAddScalarInplace
72 + VecZnxNormalizeTmpBytes,
73 Scratch<BE>: ScratchTakeCore<BE>,
74{
75 fn ggsw_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
76 where
77 A: GGSWInfos,
78 {
79 self.glwe_encrypt_sk_tmp_bytes(infos)
80 .max(self.vec_znx_normalize_tmp_bytes())
81 + GLWEPlaintext::bytes_of_from_infos(infos)
82 }
83
84 fn ggsw_encrypt_sk<R, P, S>(
85 &self,
86 res: &mut R,
87 pt: &P,
88 sk: &S,
89 source_xa: &mut Source,
90 source_xe: &mut Source,
91 scratch: &mut Scratch<BE>,
92 ) where
93 R: GGSWToMut,
94 P: ScalarZnxToRef,
95 S: GLWESecretPreparedToRef<BE>,
96 {
97 let res: &mut GGSW<&mut [u8]> = &mut res.to_mut();
98 let pt: &ScalarZnx<&[u8]> = &pt.to_ref();
99 let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref();
100
101 assert_eq!(res.rank(), sk.rank());
102 assert_eq!(res.n(), self.n() as u32);
103 assert_eq!(pt.n(), self.n());
104 assert_eq!(sk.n(), self.n() as u32);
105
106 let k: usize = res.k().into();
107 let base2k: usize = res.base2k().into();
108 let rank: usize = res.rank().into();
109 let dsize: usize = res.dsize().into();
110 let cols: usize = rank + 1;
111
112 let (mut tmp_pt, scratch_1) = scratch.take_glwe_plaintext(res);
113
114 for row_i in 0..res.dnum().into() {
115 tmp_pt.data.zero();
116 self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, 0);
118 self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1);
119 for col_j in 0..rank + 1 {
120 self.glwe_encrypt_sk_internal(
121 base2k,
122 k,
123 res.at_mut(row_i, col_j).data_mut(),
124 cols,
125 false,
126 Some((&tmp_pt, col_j)),
127 sk,
128 source_xa,
129 source_xe,
130 SIGMA,
131 scratch_1,
132 );
133 }
134 }
135 }
136}