1use poulpy_hal::{
2 api::{
3 ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
4 VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform,
5 VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
6 },
7 layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero},
8 source::Source,
9};
10
11use crate::{
12 TakeGLWEPt,
13 layouts::{GGLWECiphertext, GLWECiphertext, GLWEPlaintext, Infos, prepared::GLWESecretPrepared},
14};
15
16impl GGLWECiphertext<Vec<u8>> {
17 pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> usize
18 where
19 Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes,
20 {
21 GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
22 + (GLWEPlaintext::byte_of(module.n(), basek, k) | module.vec_znx_normalize_tmp_bytes())
23 }
24
25 pub fn encrypt_pk_scratch_space<B: Backend>(_module: &Module<B>, _basek: usize, _k: usize, _rank: usize) -> usize {
26 unimplemented!()
27 }
28}
29
30impl<DataSelf: DataMut> GGLWECiphertext<DataSelf> {
31 #[allow(clippy::too_many_arguments)]
32 pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
33 &mut self,
34 module: &Module<B>,
35 pt: &ScalarZnx<DataPt>,
36 sk: &GLWESecretPrepared<DataSk, B>,
37 source_xa: &mut Source,
38 source_xe: &mut Source,
39 scratch: &mut Scratch<B>,
40 ) where
41 Module<B>: VecZnxAddScalarInplace
42 + VecZnxDftAllocBytes
43 + VecZnxBigNormalize<B>
44 + VecZnxDftApply<B>
45 + SvpApplyDftToDftInplace<B>
46 + VecZnxIdftApplyConsume<B>
47 + VecZnxNormalizeTmpBytes
48 + VecZnxFillUniform
49 + VecZnxSubABInplace
50 + VecZnxAddInplace
51 + VecZnxNormalizeInplace<B>
52 + VecZnxAddNormal
53 + VecZnxNormalize<B>
54 + VecZnxSub,
55 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
56 {
57 #[cfg(debug_assertions)]
58 {
59 use poulpy_hal::layouts::ZnxInfos;
60
61 assert_eq!(
62 self.rank_in(),
63 pt.cols(),
64 "self.rank_in(): {} != pt.cols(): {}",
65 self.rank_in(),
66 pt.cols()
67 );
68 assert_eq!(
69 self.rank_out(),
70 sk.rank(),
71 "self.rank_out(): {} != sk.rank(): {}",
72 self.rank_out(),
73 sk.rank()
74 );
75 assert_eq!(self.n(), sk.n());
76 assert_eq!(pt.n(), sk.n());
77 assert!(
78 scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()),
79 "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}",
80 scratch.available(),
81 self.rank(),
82 self.size(),
83 GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k())
84 );
85 assert!(
86 self.rows() * self.digits() * self.basek() <= self.k(),
87 "self.rows() : {} * self.digits() : {} * self.basek() : {} = {} >= self.k() = {}",
88 self.rows(),
89 self.digits(),
90 self.basek(),
91 self.rows() * self.digits() * self.basek(),
92 self.k()
93 );
94 }
95
96 let rows: usize = self.rows();
97 let digits: usize = self.digits();
98 let basek: usize = self.basek();
99 let k: usize = self.k();
100 let rank_in: usize = self.rank_in();
101
102 let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(sk.n(), basek, k);
103 (0..rank_in).for_each(|col_i| {
115 (0..rows).for_each(|row_i| {
116 tmp_pt.data.zero(); module.vec_znx_add_scalar_inplace(
119 &mut tmp_pt.data,
120 0,
121 (digits - 1) + row_i * digits,
122 pt,
123 col_i,
124 );
125 module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scrach_1);
126
127 self.at_mut(row_i, col_i)
129 .encrypt_sk(module, &tmp_pt, sk, source_xa, source_xe, scrach_1);
130 });
131 });
132 }
133}