poulpy_core/noise/
ggsw_ct.rs

1use poulpy_hal::{
2    api::{
3        DFT, IDFTConsume, IDFTTmpA, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, VecZnxAddScalarInplace,
4        VecZnxBigAddInplace, VecZnxBigAddSmallInplace, VecZnxBigAlloc, VecZnxBigAllocBytes, VecZnxBigNormalize,
5        VecZnxBigNormalizeTmpBytes, VecZnxDftAlloc, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes, VecZnxSubABInplace,
6    },
7    layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned, VecZnxBig, VecZnxDft, ZnxZero},
8    oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl},
9};
10
11use crate::layouts::{GGSWCiphertext, GLWECiphertext, GLWEPlaintext, Infos, prepared::GLWESecretPrepared};
12
13impl<D: DataRef> GGSWCiphertext<D> {
14    pub fn assert_noise<B, DataSk, DataScalar, F>(
15        &self,
16        module: &Module<B>,
17        sk_prepared: &GLWESecretPrepared<DataSk, B>,
18        pt_want: &ScalarZnx<DataScalar>,
19        max_noise: F,
20    ) where
21        DataSk: DataRef,
22        DataScalar: DataRef,
23        Module<B>: VecZnxDftAllocBytes
24            + VecZnxBigAllocBytes
25            + DFT<B>
26            + SvpApplyInplace<B>
27            + IDFTConsume<B>
28            + VecZnxBigAddInplace<B>
29            + VecZnxBigAddSmallInplace<B>
30            + VecZnxBigNormalize<B>
31            + VecZnxNormalizeTmpBytes
32            + VecZnxBigAlloc<B>
33            + VecZnxDftAlloc<B>
34            + VecZnxBigNormalizeTmpBytes
35            + IDFTTmpA<B>
36            + VecZnxAddScalarInplace
37            + VecZnxSubABInplace,
38        B: Backend + TakeVecZnxDftImpl<B> + TakeVecZnxBigImpl<B> + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
39        F: Fn(usize) -> f64,
40    {
41        let basek: usize = self.basek();
42        let k: usize = self.k();
43        let digits: usize = self.digits();
44
45        let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
46        let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
47        let mut pt_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(1, self.size());
48        let mut pt_big: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(1, self.size());
49
50        let mut scratch: ScratchOwned<B> =
51            ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k) | module.vec_znx_normalize_tmp_bytes());
52
53        (0..self.rank() + 1).for_each(|col_j| {
54            (0..self.rows()).for_each(|row_i| {
55                module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (digits - 1) + row_i * digits, pt_want, 0);
56
57                // mul with sk[col_j-1]
58                if col_j > 0 {
59                    module.dft(1, 0, &mut pt_dft, 0, &pt.data, 0);
60                    module.svp_apply_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1);
61                    module.idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
62                    module.vec_znx_big_normalize(basek, &mut pt.data, 0, &pt_big, 0, scratch.borrow());
63                }
64
65                self.at(row_i, col_j)
66                    .decrypt(module, &mut pt_have, sk_prepared, scratch.borrow());
67
68                module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt.data, 0);
69
70                let std_pt: f64 = pt_have.data.std(basek, 0).log2();
71                let noise: f64 = max_noise(col_j);
72                println!("{} {}", std_pt, noise);
73                assert!(std_pt <= noise, "{} > {}", std_pt, noise);
74
75                pt.data.zero();
76            });
77        });
78    }
79}
80
81impl<D: DataRef> GGSWCiphertext<D> {
82    pub fn print_noise<B, DataSk, DataScalar>(
83        &self,
84        module: &Module<B>,
85        sk_prepared: &GLWESecretPrepared<DataSk, B>,
86        pt_want: &ScalarZnx<DataScalar>,
87    ) where
88        DataSk: DataRef,
89        DataScalar: DataRef,
90        Module<B>: VecZnxDftAllocBytes
91            + VecZnxBigAllocBytes
92            + DFT<B>
93            + SvpApplyInplace<B>
94            + IDFTConsume<B>
95            + VecZnxBigAddInplace<B>
96            + VecZnxBigAddSmallInplace<B>
97            + VecZnxBigNormalize<B>
98            + VecZnxNormalizeTmpBytes
99            + VecZnxBigAlloc<B>
100            + VecZnxDftAlloc<B>
101            + VecZnxBigNormalizeTmpBytes
102            + IDFTTmpA<B>
103            + VecZnxAddScalarInplace
104            + VecZnxSubABInplace,
105        B: Backend + TakeVecZnxDftImpl<B> + TakeVecZnxBigImpl<B> + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
106    {
107        let basek: usize = self.basek();
108        let k: usize = self.k();
109        let digits: usize = self.digits();
110
111        let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
112        let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
113        let mut pt_dft: VecZnxDft<Vec<u8>, B> = module.vec_znx_dft_alloc(1, self.size());
114        let mut pt_big: VecZnxBig<Vec<u8>, B> = module.vec_znx_big_alloc(1, self.size());
115
116        let mut scratch: ScratchOwned<B> =
117            ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k) | module.vec_znx_normalize_tmp_bytes());
118
119        (0..self.rank() + 1).for_each(|col_j| {
120            (0..self.rows()).for_each(|row_i| {
121                module.vec_znx_add_scalar_inplace(&mut pt.data, 0, (digits - 1) + row_i * digits, pt_want, 0);
122
123                // mul with sk[col_j-1]
124                if col_j > 0 {
125                    module.dft(1, 0, &mut pt_dft, 0, &pt.data, 0);
126                    module.svp_apply_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1);
127                    module.idft_tmp_a(&mut pt_big, 0, &mut pt_dft, 0);
128                    module.vec_znx_big_normalize(basek, &mut pt.data, 0, &pt_big, 0, scratch.borrow());
129                }
130
131                self.at(row_i, col_j)
132                    .decrypt(module, &mut pt_have, sk_prepared, scratch.borrow());
133
134                module.vec_znx_sub_ab_inplace(&mut pt_have.data, 0, &pt.data, 0);
135
136                let std_pt: f64 = pt_have.data.std(basek, 0).log2();
137                println!("col: {} row: {}: {}", col_j, row_i, std_pt);
138                pt.data.zero();
139            });
140        });
141    }
142}