poulpy_core/noise/
ggsw_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyDftToDftInplace, VecZnxAddScalarInplace, VecZnxBigAddInplace,
4        VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes,
5        VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA,
6        VecZnxNormalizeTmpBytes, VecZnxSubInplace,
7    },
8    layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned, VecZnxBig, VecZnxDft, ZnxZero},
9    oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl},
10};
11
12use crate::layouts::{
13    GGSWCiphertext, GGSWInfos, GLWECiphertext, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared,
14};
15
16impl<D: DataRef> GGSWCiphertext<D> {
17    pub fn assert_noise<B, DataSk, DataScalar, F>(
18        &self,
19        module: &Module<B>,
20        sk_prepared: &GLWESecretPrepared<DataSk, B>,
21        pt_want: &ScalarZnx<DataScalar>,
22        max_noise: F,
23    ) where
24        DataSk: DataRef,
25        DataScalar: DataRef,
26        Module<B>: VecZnxDftAllocBytes
27            + VecZnxBigAllocBytes
28            + VecZnxDftApply<B>
29            + SvpApplyDftToDftInplace<B>
30            + VecZnxIdftApplyConsume<B>
31            + VecZnxBigAddInplace<B>
32            + VecZnxBigAddSmallInplace<B>
33            + VecZnxBigNormalize<B>
34            + VecZnxNormalizeTmpBytes
35            + VecZnxBigAlloc<B>
36            + VecZnxDftAlloc<B>
37            + VecZnxBigNormalizeTmpBytes
38            + VecZnxIdftApplyTmpA<B>
39            + VecZnxAddScalarInplace
40            + VecZnxSubInplace,
41        B: Backend + TakeVecZnxDftImpl<B> + TakeVecZnxBigImpl<B> + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
42        F: Fn(usize) -> f64,
43    {
44        let base2k: usize = self.base2k().into();
45        let digits: usize = self.digits().into();
46
47        let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self);
48        let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self);
49        let mut pt_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(1, self.size());
50        let mut pt_big: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(1, self.size());
51
52        let mut scratch: ScratchOwned<B> =
53            ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self) | module.vec_znx_normalize_tmp_bytes());
54
55        (0..(self.rank() + 1).into()).for_each(|col_j| {
56            (0..self.rows().into()).for_each(|row_i| {
57                module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (digits - 1) + row_i * digits, pt_want, 0);
58
59                // mul with sk[col_j-1]
60                if col_j > 0 {
61                    module.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0);
62                    module.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1);
63                    module.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0);
64                    module.vec_znx_big_normalize(
65                        base2k,
66                        &mut pt.data,
67                        0,
68                        base2k,
69                        &pt_big,
70                        0,
71                        scratch.borrow(),
72                    );
73                }
74
75                self.at(row_i, col_j)
76                    .decrypt(module, &mut pt_have, sk_prepared, scratch.borrow());
77
78                module.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0);
79
80                let std_pt: f64 = pt_have.data.std(base2k, 0).log2();
81                let noise: f64 = max_noise(col_j);
82                assert!(std_pt <= noise, "{std_pt} > {noise}");
83
84                pt.data.zero();
85            });
86        });
87    }
88}
89
90impl<D: DataRef> GGSWCiphertext<D> {
91    pub fn print_noise<B, DataSk, DataScalar>(
92        &self,
93        module: &Module<B>,
94        sk_prepared: &GLWESecretPrepared<DataSk, B>,
95        pt_want: &ScalarZnx<DataScalar>,
96    ) where
97        DataSk: DataRef,
98        DataScalar: DataRef,
99        Module<B>: VecZnxDftAllocBytes
100            + VecZnxBigAllocBytes
101            + VecZnxDftApply<B>
102            + SvpApplyDftToDftInplace<B>
103            + VecZnxIdftApplyConsume<B>
104            + VecZnxBigAddInplace<B>
105            + VecZnxBigAddSmallInplace<B>
106            + VecZnxBigNormalize<B>
107            + VecZnxNormalizeTmpBytes
108            + VecZnxBigAlloc<B>
109            + VecZnxDftAlloc<B>
110            + VecZnxBigNormalizeTmpBytes
111            + VecZnxIdftApplyTmpA<B>
112            + VecZnxAddScalarInplace
113            + VecZnxSubInplace,
114        B: Backend + TakeVecZnxDftImpl<B> + TakeVecZnxBigImpl<B> + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
115    {
116        let base2k: usize = self.base2k().into();
117        let digits: usize = self.digits().into();
118
119        let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self);
120        let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self);
121        let mut pt_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(1, self.size());
122        let mut pt_big: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(1, self.size());
123
124        let mut scratch: ScratchOwned<B> =
125            ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, self) | module.vec_znx_normalize_tmp_bytes());
126
127        (0..(self.rank() + 1).into()).for_each(|col_j| {
128            (0..self.rows().into()).for_each(|row_i| {
129                module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (digits - 1) + row_i * digits, pt_want, 0);
130
131                // mul with sk[col_j-1]
132                if col_j > 0 {
133                    module.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0);
134                    module.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1);
135                    module.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0);
136                    module.vec_znx_big_normalize(
137                        base2k,
138                        &mut pt.data,
139                        0,
140                        base2k,
141                        &pt_big,
142                        0,
143                        scratch.borrow(),
144                    );
145                }
146
147                self.at(row_i, col_j)
148                    .decrypt(module, &mut pt_have, sk_prepared, scratch.borrow());
149
150                module.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0);
151
152                let std_pt: f64 = pt_have.data.std(base2k, 0).log2();
153                println!("col: {col_j} row: {row_i}: {std_pt}");
154                pt.data.zero();
155            });
156        });
157    }
158}