1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
4 VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform,
5 VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace,
6 },
7 layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero},
8 source::Source,
9};
10
11use crate::{
12 TakeGLWEPt,
13 layouts::{GGLWECiphertext, GGLWELayoutInfos, GLWECiphertext, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared},
14};
15
16impl GGLWECiphertext<Vec<u8>> {
17 pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
18 where
19 A: GGLWELayoutInfos,
20 Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes,
21 {
22 GLWECiphertext::encrypt_sk_scratch_space(module, &infos.glwe_layout())
23 + (GLWEPlaintext::alloc_bytes(&infos.glwe_layout()) | module.vec_znx_normalize_tmp_bytes())
24 }
25
26 pub fn encrypt_pk_scratch_space<B: Backend, A>(_module: &Module<B>, _infos: &A) -> usize
27 where
28 A: GGLWELayoutInfos,
29 {
30 unimplemented!()
31 }
32}
33
34impl<DataSelf: DataMut> GGLWECiphertext<DataSelf> {
35 #[allow(clippy::too_many_arguments)]
36 pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
37 &mut self,
38 module: &Module<B>,
39 pt: &ScalarZnx<DataPt>,
40 sk: &GLWESecretPrepared<DataSk, B>,
41 source_xa: &mut Source,
42 source_xe: &mut Source,
43 scratch: &mut Scratch<B>,
44 ) where
45 Module<B>: VecZnxAddScalarInplace
46 + VecZnxDftAllocBytes
47 + VecZnxBigNormalize<B>
48 + VecZnxDftApply<B>
49 + SvpApplyDftToDftInplace<B>
50 + VecZnxIdftApplyConsume<B>
51 + VecZnxNormalizeTmpBytes
52 + VecZnxFillUniform
53 + VecZnxSubInplace
54 + VecZnxAddInplace
55 + VecZnxNormalizeInplace<B>
56 + VecZnxAddNormal
57 + VecZnxNormalize<B>
58 + VecZnxSub,
59 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
60 {
61 #[cfg(debug_assertions)]
62 {
63 use poulpy_hal::layouts::ZnxInfos;
64
65 assert_eq!(
66 self.rank_in(),
67 pt.cols() as u32,
68 "self.rank_in(): {} != pt.cols(): {}",
69 self.rank_in(),
70 pt.cols()
71 );
72 assert_eq!(
73 self.rank_out(),
74 sk.rank(),
75 "self.rank_out(): {} != sk.rank(): {}",
76 self.rank_out(),
77 sk.rank()
78 );
79 assert_eq!(self.n(), sk.n());
80 assert_eq!(pt.n() as u32, sk.n());
81 assert!(
82 scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, self),
83 "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}",
84 scratch.available(),
85 self.rank_out(),
86 self.size(),
87 GGLWECiphertext::encrypt_sk_scratch_space(module, self)
88 );
89 assert!(
90 self.rows().0 * self.digits().0 * self.base2k().0 <= self.k().0,
91 "self.rows() : {} * self.digits() : {} * self.base2k() : {} = {} >= self.k() = {}",
92 self.rows(),
93 self.digits(),
94 self.base2k(),
95 self.rows().0 * self.digits().0 * self.base2k().0,
96 self.k()
97 );
98 }
99
100 let rows: usize = self.rows().into();
101 let digits: usize = self.digits().into();
102 let base2k: usize = self.base2k().into();
103 let rank_in: usize = self.rank_in().into();
104
105 let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(self);
106 (0..rank_in).for_each(|col_i| {
118 (0..rows).for_each(|row_i| {
119 tmp_pt.data.zero(); module.vec_znx_add_scalar_inplace(
122 &mut tmp_pt.data,
123 0,
124 (digits - 1) + row_i * digits,
125 pt,
126 col_i,
127 );
128 module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1);
129
130 self.at_mut(row_i, col_i)
132 .encrypt_sk(module, &tmp_pt, sk, source_xa, source_xe, scrach_1);
133 });
134 });
135 }
136}