1use poulpy_hal::{
2 api::{
3 DFT, IDFTConsume, IDFTTmpA, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, VecZnxAddScalarInplace,
4 VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize,
5 VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes, VecZnxSubABInplace,
6 },
7 layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned, VecZnxBig, VecZnxDft, ZnxZero},
8 oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl},
9};
10
11use crate::layouts::{GGSWCiphertext, GLWECiphertext, GLWEPlaintext, Infos, prepared::GLWESecretPrepared};
12
13impl<D: DataRef> GGSWCiphertext<D> {
14 pub fn assert_noise<B, DataSk, DataScalar, F>(
15 &self,
16 module: &Module<B>,
17 sk_prepared: &GLWESecretPrepared<DataSk, B>,
18 pt_want: &ScalarZnx<DataScalar>,
19 max_noise: F,
20 ) where
21 DataSk: DataRef,
22 DataScalar: DataRef,
23 Module<B>: VecZnxDftAllocBytes
24 + VecZnxBigAllocBytes
25 + DFT<B>
26 + SvpApplyInplace<B>
27 + IDFTConsume<B>
28 + VecZnxBigAddInplace<B>
29 + VecZnxBigAddSmallInplace<B>
30 + VecZnxBigNormalize<B>
31 + VecZnxNormalizeTmpBytes
32 + VecZnxBigAlloc<B>
33 + VecZnxDftAlloc<B>
34 + VecZnxBigNormalizeTmpBytes
35 + IDFTTmpA<B>
36 + VecZnxAddScalarInplace
37 + VecZnxSubABInplace,
38 B: Backend + TakeVecZnxDftImpl<B> + TakeVecZnxBigImpl<B> + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
39 F: Fn(usize) -> f64,
40 {
41 let basek: usize = self.basek();
42 let k: usize = self.k();
43 let digits: usize = self.digits();
44
45 let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
46 let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
47 let mut pt_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(1, self.size());
48 let mut pt_big: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(1, self.size());
49
50 let mut scratch: ScratchOwned<B> =
51 ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k) | module.vec_znx_normalize_tmp_bytes());
52
53 (0..self.rank() + 1).for_each(|col_j| {
54 (0..self.rows()).for_each(|row_i| {
55 module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (digits - 1) + row_i * digits, pt_want, 0);
56
57 if col_j > 0 {
59 module.dft(1, 0, &mut pt_dft, 0, &pt.data, 0);
60 module.svp_apply_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1);
61 module.idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
62 module.vec_znx_big_normalize(basek, &mut pt.data, 0, &pt_big, 0, scratch.borrow());
63 }
64
65 self.at(row_i, col_j)
66 .decrypt(module, &mut pt_have, sk_prepared, scratch.borrow());
67
68 module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt.data, 0);
69
70 let std_pt: f64 = pt_have.data.std(basek, 0).log2();
71 let noise: f64 = max_noise(col_j);
72 println!("{} {}", std_pt, noise);
73 assert!(std_pt <= noise, "{} > {}", std_pt, noise);
74
75 pt.data.zero();
76 });
77 });
78 }
79}
80
81impl<D: DataRef> GGSWCiphertext<D> {
82 pub fn print_noise<B, DataSk, DataScalar>(
83 &self,
84 module: &Module<B>,
85 sk_prepared: &GLWESecretPrepared<DataSk, B>,
86 pt_want: &ScalarZnx<DataScalar>,
87 ) where
88 DataSk: DataRef,
89 DataScalar: DataRef,
90 Module<B>: VecZnxDftAllocBytes
91 + VecZnxBigAllocBytes
92 + DFT<B>
93 + SvpApplyInplace<B>
94 + IDFTConsume<B>
95 + VecZnxBigAddInplace<B>
96 + VecZnxBigAddSmallInplace<B>
97 + VecZnxBigNormalize<B>
98 + VecZnxNormalizeTmpBytes
99 + VecZnxBigAlloc<B>
100 + VecZnxDftAlloc<B>
101 + VecZnxBigNormalizeTmpBytes
102 + IDFTTmpA<B>
103 + VecZnxAddScalarInplace
104 + VecZnxSubABInplace,
105 B: Backend + TakeVecZnxDftImpl<B> + TakeVecZnxBigImpl<B> + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
106 {
107 let basek: usize = self.basek();
108 let k: usize = self.k();
109 let digits: usize = self.digits();
110
111 let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
112 let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
113 let mut pt_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(1, self.size());
114 let mut pt_big: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(1, self.size());
115
116 let mut scratch: ScratchOwned<B> =
117 ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k) | module.vec_znx_normalize_tmp_bytes());
118
119 (0..self.rank() + 1).for_each(|col_j| {
120 (0..self.rows()).for_each(|row_i| {
121 module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (digits - 1) + row_i * digits, pt_want, 0);
122
123 if col_j > 0 {
125 module.dft(1, 0, &mut pt_dft, 0, &pt.data, 0);
126 module.svp_apply_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1);
127 module.idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
128 module.vec_znx_big_normalize(basek, &mut pt.data, 0, &pt_big, 0, scratch.borrow());
129 }
130
131 self.at(row_i, col_j)
132 .decrypt(module, &mut pt_have, sk_prepared, scratch.borrow());
133
134 module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt.data, 0);
135
136 let std_pt: f64 = pt_have.data.std(basek, 0).log2();
137 println!("col: {} row: {}: {}", col_j, row_i, std_pt);
138 pt.data.zero();
139 });
140 });
141 }
142}