// poulpy_core/noise/ggsw.rs

1use poulpy_hal::{
2    api::{
3        ScratchOwnedAlloc, ScratchOwnedBorrow, ScratchTakeBasic, SvpApplyDftToDftInplace, VecZnxAddScalarInplace, VecZnxBigAlloc,
4        VecZnxBigNormalize, VecZnxDftAlloc, VecZnxDftApply, VecZnxIdftApplyTmpA, VecZnxNormalizeTmpBytes, VecZnxSubInplace,
5    },
6    layouts::{Backend, DataRef, Module, ScalarZnxToRef, Scratch, ScratchOwned, VecZnxBig, VecZnxDft, ZnxZero},
7};
8
9use crate::decryption::GLWEDecrypt;
10use crate::layouts::prepared::GLWESecretPreparedToRef;
11use crate::layouts::{GGSW, GGSWInfos, GGSWToRef, GLWEInfos, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared};
12
13impl<D: DataRef> GGSW<D> {
14    pub fn assert_noise<M, BE: Backend, P, S, F>(&self, module: &M, sk_prepared: &S, pt_want: &P, max_noise: &F)
15    where
16        S: GLWESecretPreparedToRef<BE>,
17        P: ScalarZnxToRef,
18        M: GGSWNoise<BE>,
19        F: Fn(usize) -> f64,
20    {
21        module.ggsw_assert_noise(self, sk_prepared, pt_want, max_noise);
22    }
23
24    pub fn print_noise<M, BE: Backend, P, S>(&self, module: &M, sk_prepared: &S, pt_want: &P)
25    where
26        S: GLWESecretPreparedToRef<BE>,
27        P: ScalarZnxToRef,
28        M: GGSWNoise<BE>,
29    {
30        module.ggsw_print_noise(self, sk_prepared, pt_want);
31    }
32}
33
34pub trait GGSWNoise<BE: Backend> {
35    fn ggsw_assert_noise<R, S, P, F>(&self, res: &R, sk_prepared: &S, pt_want: &P, max_noise: &F)
36    where
37        R: GGSWToRef,
38        S: GLWESecretPreparedToRef<BE>,
39        P: ScalarZnxToRef,
40        F: Fn(usize) -> f64;
41
42    fn ggsw_print_noise<R, S, P>(&self, res: &R, sk_prepared: &S, pt_want: &P)
43    where
44        R: GGSWToRef,
45        S: GLWESecretPreparedToRef<BE>,
46        P: ScalarZnxToRef;
47}
48
49impl<BE: Backend> GGSWNoise<BE> for Module<BE>
50where
51    Module<BE>: GLWEDecrypt<BE>
52        + VecZnxDftAlloc<BE>
53        + VecZnxBigAlloc<BE>
54        + VecZnxAddScalarInplace
55        + VecZnxIdftApplyTmpA<BE>
56        + VecZnxSubInplace,
57    Scratch<BE>: ScratchTakeBasic,
58    ScratchOwned<BE>: ScratchOwnedBorrow<BE> + ScratchOwnedAlloc<BE>,
59{
60    fn ggsw_assert_noise<R, S, P, F>(&self, res: &R, sk_prepared: &S, pt_want: &P, max_noise: &F)
61    where
62        R: GGSWToRef,
63        S: GLWESecretPreparedToRef<BE>,
64        P: ScalarZnxToRef,
65        F: Fn(usize) -> f64,
66    {
67        let res: &GGSW<&[u8]> = &res.to_ref();
68        let sk_prepared: &GLWESecretPrepared<&[u8], BE> = &sk_prepared.to_ref();
69
70        let base2k: usize = res.base2k().into();
71        let dsize: usize = res.dsize().into();
72
73        let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(res);
74        let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(res);
75        let mut pt_dft: VecZnxDft<Vec<u8>, BE> = self.vec_znx_dft_alloc(1, res.size());
76        let mut pt_big: VecZnxBig<Vec<u8>, BE> = self.vec_znx_big_alloc(1, res.size());
77
78        let mut scratch: ScratchOwned<BE> =
79            ScratchOwned::alloc(self.glwe_decrypt_tmp_bytes(res) | self.vec_znx_normalize_tmp_bytes());
80
81        (0..(res.rank() + 1).into()).for_each(|col_j| {
82            (0..res.dnum().into()).for_each(|row_i| {
83                self.vec_znx_add_scalar_inplace(&mut pt.data, 0, (dsize - 1) + row_i * dsize, pt_want, 0);
84
85                // mul with sk[col_j-1]
86                if col_j > 0 {
87                    self.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0);
88                    self.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1);
89                    self.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0);
90                    self.vec_znx_big_normalize(
91                        base2k,
92                        &mut pt.data,
93                        0,
94                        base2k,
95                        &pt_big,
96                        0,
97                        scratch.borrow(),
98                    );
99                }
100
101                self.glwe_decrypt(
102                    &res.at(row_i, col_j),
103                    &mut pt_have,
104                    sk_prepared,
105                    scratch.borrow(),
106                );
107
108                self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0);
109
110                let std_pt: f64 = pt_have.data.std(base2k, 0).log2();
111                let noise: f64 = max_noise(col_j);
112                assert!(std_pt <= noise, "{std_pt} > {noise}");
113
114                pt.data.zero();
115            });
116        });
117    }
118
119    fn ggsw_print_noise<R, S, P>(&self, res: &R, sk_prepared: &S, pt_want: &P)
120    where
121        R: GGSWToRef,
122        S: GLWESecretPreparedToRef<BE>,
123        P: ScalarZnxToRef,
124    {
125        let res: &GGSW<&[u8]> = &res.to_ref();
126        let sk_prepared: &GLWESecretPrepared<&[u8], BE> = &sk_prepared.to_ref();
127
128        let base2k: usize = res.base2k().into();
129        let dsize: usize = res.dsize().into();
130
131        let mut pt: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(res);
132        let mut pt_have: GLWEPlaintext<Vec<u8>> = GLWEPlaintext::alloc_from_infos(res);
133        let mut pt_dft: VecZnxDft<Vec<u8>, BE> = self.vec_znx_dft_alloc(1, res.size());
134        let mut pt_big: VecZnxBig<Vec<u8>, BE> = self.vec_znx_big_alloc(1, res.size());
135
136        let mut scratch: ScratchOwned<BE> =
137            ScratchOwned::alloc(self.glwe_decrypt_tmp_bytes(res) | self.vec_znx_normalize_tmp_bytes());
138
139        for col_j in 0..(res.rank() + 1).into() {
140            for row_i in 0..res.dnum().into() {
141                self.vec_znx_add_scalar_inplace(&mut pt.data, 0, (dsize - 1) + row_i * dsize, pt_want, 0);
142
143                // mul with sk[col_j-1]
144                if col_j > 0 {
145                    self.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.data, 0);
146                    self.svp_apply_dft_to_dft_inplace(&mut pt_dft, 0, &sk_prepared.data, col_j - 1);
147                    self.vec_znx_idft_apply_tmpa(&mut pt_big, 0, &mut pt_dft, 0);
148                    self.vec_znx_big_normalize(
149                        base2k,
150                        &mut pt.data,
151                        0,
152                        base2k,
153                        &pt_big,
154                        0,
155                        scratch.borrow(),
156                    );
157                }
158
159                self.glwe_decrypt(
160                    &res.at(row_i, col_j),
161                    &mut pt_have,
162                    sk_prepared,
163                    scratch.borrow(),
164                );
165                self.vec_znx_sub_inplace(&mut pt_have.data, 0, &pt.data, 0);
166
167                let std_pt: f64 = pt_have.data.std(base2k, 0).log2();
168                println!("col: {col_j} row: {row_i}: {std_pt}");
169                pt.data.zero();
170            }
171        }
172    }
173}