1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace,
4 VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxFillUniform,
5 VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, ZnxZero,
6 },
7 layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch},
8 source::Source,
9};
10
11use crate::{
12 TakeGLWEPt,
13 layouts::{GGLWECiphertext, GLWECiphertext, GLWEPlaintext, Infos, prepared::GLWESecretPrepared},
14};
15
16impl GGLWECiphertext<Vec<u8>> {
17 pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize) -> usize
18 where
19 Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes,
20 {
21 GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k)
22 + (GLWEPlaintext::byte_of(n, basek, k) | module.vec_znx_normalize_tmp_bytes(n))
23 }
24
25 pub fn encrypt_pk_scratch_space<B: Backend>(_module: &Module<B>, _n: usize, _basek: usize, _k: usize, _rank: usize) -> usize {
26 unimplemented!()
27 }
28}
29
30impl<DataSelf: DataMut> GGLWECiphertext<DataSelf> {
31 #[allow(clippy::too_many_arguments)]
32 pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
33 &mut self,
34 module: &Module<B>,
35 pt: &ScalarZnx<DataPt>,
36 sk: &GLWESecretPrepared<DataSk, B>,
37 source_xa: &mut Source,
38 source_xe: &mut Source,
39 sigma: f64,
40 scratch: &mut Scratch<B>,
41 ) where
42 Module<B>: VecZnxAddScalarInplace
43 + VecZnxDftAllocBytes
44 + VecZnxBigNormalize<B>
45 + VecZnxDftFromVecZnx<B>
46 + SvpApplyInplace<B>
47 + VecZnxDftToVecZnxBigConsume<B>
48 + VecZnxNormalizeTmpBytes
49 + VecZnxFillUniform
50 + VecZnxSubABInplace
51 + VecZnxAddInplace
52 + VecZnxNormalizeInplace<B>
53 + VecZnxAddNormal
54 + VecZnxNormalize<B>
55 + VecZnxSub,
56 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
57 {
58 #[cfg(debug_assertions)]
59 {
60 use poulpy_hal::api::ZnxInfos;
61
62 assert_eq!(
63 self.rank_in(),
64 pt.cols(),
65 "self.rank_in(): {} != pt.cols(): {}",
66 self.rank_in(),
67 pt.cols()
68 );
69 assert_eq!(
70 self.rank_out(),
71 sk.rank(),
72 "self.rank_out(): {} != sk.rank(): {}",
73 self.rank_out(),
74 sk.rank()
75 );
76 assert_eq!(self.n(), sk.n());
77 assert_eq!(pt.n(), sk.n());
78 assert!(
79 scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k()),
80 "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}",
81 scratch.available(),
82 self.rank(),
83 self.size(),
84 GGLWECiphertext::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k())
85 );
86 assert!(
87 self.rows() * self.digits() * self.basek() <= self.k(),
88 "self.rows() : {} * self.digits() : {} * self.basek() : {} = {} >= self.k() = {}",
89 self.rows(),
90 self.digits(),
91 self.basek(),
92 self.rows() * self.digits() * self.basek(),
93 self.k()
94 );
95 }
96
97 let rows: usize = self.rows();
98 let digits: usize = self.digits();
99 let basek: usize = self.basek();
100 let k: usize = self.k();
101 let rank_in: usize = self.rank_in();
102
103 let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(sk.n(), basek, k);
104 (0..rank_in).for_each(|col_i| {
116 (0..rows).for_each(|row_i| {
117 tmp_pt.data.zero(); module.vec_znx_add_scalar_inplace(
120 &mut tmp_pt.data,
121 0,
122 (digits - 1) + row_i * digits,
123 pt,
124 col_i,
125 );
126 module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scrach_1);
127
128 self.at_mut(row_i, col_i)
130 .encrypt_sk(module, &tmp_pt, sk, source_xa, source_xe, sigma, scrach_1);
131 });
132 });
133 }
134}