poulpy_core/noise/gglwe_ct.rs

1use poulpy_hal::{
2    api::{
3        DFT, IDFTConsume, ScratchOwnedAlloc, ScratchOwnedBorrow, SvpApplyInplace, VecZnxBigAddInplace, VecZnxBigAddSmallInplace,
4        VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxNormalizeTmpBytes, VecZnxSubScalarInplace,
5    },
6    layouts::{Backend, DataRef, Module, ScalarZnx, ScratchOwned, ZnxZero},
7    oep::{ScratchOwnedAllocImpl, ScratchOwnedBorrowImpl, TakeVecZnxBigImpl, TakeVecZnxDftImpl},
8};
9
10use crate::layouts::{GGLWECiphertext, GLWECiphertext, GLWEPlaintext, Infos, prepared::GLWESecretPrepared};
11
12impl<D: DataRef> GGLWECiphertext<D> {
13    pub fn assert_noise<B, DataSk, DataWant>(
14        self,
15        module: &Module<B>,
16        sk: &GLWESecretPrepared<DataSk, B>,
17        pt_want: &ScalarZnx<DataWant>,
18        max_noise: f64,
19    ) where
20        DataSk: DataRef,
21        DataWant: DataRef,
22        Module<B>: VecZnxDftAllocBytes
23            + VecZnxBigAllocBytes
24            + DFT<B>
25            + SvpApplyInplace<B>
26            + IDFTConsume<B>
27            + VecZnxBigAddInplace<B>
28            + VecZnxBigAddSmallInplace<B>
29            + VecZnxBigNormalize<B>
30            + VecZnxNormalizeTmpBytes
31            + VecZnxSubScalarInplace,
32        B: Backend + TakeVecZnxDftImpl<B> + TakeVecZnxBigImpl<B> + ScratchOwnedAllocImpl<B> + ScratchOwnedBorrowImpl<B>,
33    {
34        let digits: usize = self.digits();
35        let basek: usize = self.basek();
36        let k: usize = self.k();
37
38        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(GLWECiphertext::decrypt_scratch_space(module, basek, k));
39        let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc(self.n(), basek, k);
40
41        (0..self.rank_in()).for_each(|col_i| {
42            (0..self.rows()).for_each(|row_i| {
43                self.at(row_i, col_i)
44                    .decrypt(module, &mut pt, sk, scratch.borrow());
45
46                module.vec_znx_sub_scalar_inplace(
47                    &mut pt.data,
48                    0,
49                    (digits - 1) + row_i * digits,
50                    pt_want,
51                    col_i,
52                );
53
54                let noise_have: f64 = pt.data.std(basek, 0).log2();
55
56                assert!(
57                    noise_have <= max_noise,
58                    "noise_have: {} > max_noise: {}",
59                    noise_have,
60                    max_noise
61                );
62
63                pt.data.zero();
64            });
65        });
66    }
67}