1use poulpy_hal::{
2 api::{
3 ModuleN, ScratchTakeBasic, SvpApplyDftToDftInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigBytesOf,
4 VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf, VecZnxIdftApplyConsume, VecZnxNormalizeTmpBytes,
5 },
6 layouts::{Backend, DataRef, DataViewMut, Module, Scratch},
7};
8
9use crate::layouts::{
10 GLWE, GLWEInfos, GLWEPlaintext, GLWEPlaintextToMut, GLWEToRef, LWEInfos,
11 prepared::{GLWESecretPrepared, GLWESecretPreparedToRef},
12};
13
14impl GLWE<Vec<u8>> {
15 pub fn decrypt_tmp_bytes<A, M, BE: Backend>(module: &M, a_infos: &A) -> usize
16 where
17 A: GLWEInfos,
18 M: GLWEDecrypt<BE>,
19 {
20 module.glwe_decrypt_tmp_bytes(a_infos)
21 }
22}
23
24impl<DataSelf: DataRef> GLWE<DataSelf> {
25 pub fn decrypt<P, S, M, BE: Backend>(&self, module: &M, pt: &mut P, sk: &S, scratch: &mut Scratch<BE>)
26 where
27 P: GLWEPlaintextToMut,
28 S: GLWESecretPreparedToRef<BE>,
29 M: GLWEDecrypt<BE>,
30 Scratch<BE>: ScratchTakeBasic,
31 {
32 module.glwe_decrypt(self, pt, sk, scratch);
33 }
34}
35
36pub trait GLWEDecrypt<BE: Backend>
37where
38 Self: Sized
39 + ModuleN
40 + VecZnxDftBytesOf
41 + VecZnxNormalizeTmpBytes
42 + VecZnxBigBytesOf
43 + VecZnxDftApply<BE>
44 + SvpApplyDftToDftInplace<BE>
45 + VecZnxIdftApplyConsume<BE>
46 + VecZnxBigAddInplace<BE>
47 + VecZnxBigAddSmallInplace<BE>
48 + VecZnxBigNormalize<BE>,
49{
50 fn glwe_decrypt_tmp_bytes<A>(&self, infos: &A) -> usize
51 where
52 A: GLWEInfos,
53 {
54 let size: usize = infos.size();
55 (self.vec_znx_normalize_tmp_bytes() | self.bytes_of_vec_znx_dft(1, size)) + self.bytes_of_vec_znx_dft(1, size)
56 }
57
58 fn glwe_decrypt<R, P, S>(&self, res: &R, pt: &mut P, sk: &S, scratch: &mut Scratch<BE>)
59 where
60 R: GLWEToRef,
61 P: GLWEPlaintextToMut,
62 S: GLWESecretPreparedToRef<BE>,
63 Scratch<BE>: ScratchTakeBasic,
64 {
65 let res: &GLWE<&[u8]> = &res.to_ref();
66 let pt: &mut GLWEPlaintext<&mut [u8]> = &mut pt.to_ref();
67 let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref();
68
69 #[cfg(debug_assertions)]
70 {
71 assert_eq!(res.rank(), sk.rank());
72 assert_eq!(res.n(), sk.n());
73 assert_eq!(pt.n(), sk.n());
74 }
75
76 let cols: usize = (res.rank() + 1).into();
77
78 let (mut c0_big, scratch_1) = scratch.take_vec_znx_big(self, 1, res.size()); c0_big.data_mut().fill(0);
80
81 {
82 (1..cols).for_each(|i| {
83 let (mut ci_dft, _) = scratch_1.take_vec_znx_dft(self, 1, res.size()); self.vec_znx_dft_apply(1, 0, &mut ci_dft, 0, &res.data, i);
86 self.svp_apply_dft_to_dft_inplace(&mut ci_dft, 0, &sk.data, i - 1);
87 let ci_big = self.vec_znx_idft_apply_consume(ci_dft);
88
89 self.vec_znx_big_add_inplace(&mut c0_big, 0, &ci_big, 0);
91 });
92 }
93
94 self.vec_znx_big_add_small_inplace(&mut c0_big, 0, &res.data, 0);
96
97 self.vec_znx_big_normalize(
99 res.base2k().into(),
100 &mut pt.data,
101 0,
102 res.base2k().into(),
103 &c0_big,
104 0,
105 scratch_1,
106 );
107
108 pt.base2k = res.base2k();
109 pt.k = pt.k().min(res.k());
110 }
111}
112
113impl<BE: Backend> GLWEDecrypt<BE> for Module<BE> where
114 Self: ModuleN
115 + VecZnxDftBytesOf
116 + VecZnxNormalizeTmpBytes
117 + VecZnxBigBytesOf
118 + VecZnxDftApply<BE>
119 + SvpApplyDftToDftInplace<BE>
120 + VecZnxIdftApplyConsume<BE>
121 + VecZnxBigAddInplace<BE>
122 + VecZnxBigAddSmallInplace<BE>
123 + VecZnxBigNormalize<BE>
124{
125}