// poulpy_core/default/noise/ggsw.rs

1use poulpy_hal::{
2    api::{
3        ScratchArenaTakeBasic, SvpApplyDftToDftAssign, VecZnxAddScalarAssignBackend, VecZnxBigAddAssign, VecZnxBigBytesOf,
4        VecZnxBigFromSmallBackend, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftApply, VecZnxDftBytesOf,
5        VecZnxIdftApplyTmpA, VecZnxSubAssignBackend,
6    },
7    layouts::{
8        Backend, HostBackend, HostDataMut, HostDataRef, Module, ScalarZnx, ScalarZnxToBackendRef, ScratchArena, Stats,
9        VecZnxBigToBackendMut, VecZnxBigToBackendRef, VecZnxDftToBackendMut, ZnxZero,
10    },
11};
12
13use crate::layouts::{GGSW, GGSWInfos, GGSWToBackendRef, GLWEToBackendMut, GLWEToBackendRef, GLWEViewRef, LWEInfos};
14use crate::noise::glwe::glwe_noise_backend_inner;
15use crate::{
16    GLWENormalize,
17    api::{GGSWNoise, GLWENoise},
18    decryption::GLWEDecrypt,
19    layouts::prepared::GLWESecretPreparedToBackendRef,
20};
21use crate::{ScratchArenaTakeCore, layouts::GLWEPlaintext};
22
23impl<D: HostDataRef> GGSW<D> {
24    pub fn noise<M, BE, S>(
25        &self,
26        module: &M,
27        row: usize,
28        col: usize,
29        pt_want: &ScalarZnx<&[u8]>,
30        sk_prepared: &S,
31        scratch: &mut ScratchArena<'_, BE>,
32    ) -> Stats
33    where
34        GGSW<D>: GGSWToBackendRef<BE>,
35        S: GLWESecretPreparedToBackendRef<BE>,
36        M: GGSWNoise<BE>,
37        BE: HostBackend,
38        for<'a> BE::BufRef<'a>: HostDataRef,
39        for<'a> BE::BufMut<'a>: HostDataMut,
40    {
41        module.ggsw_noise(self, row, col, pt_want, sk_prepared, scratch)
42    }
43}
44
45impl<BE: Backend + HostBackend> GGSWNoise<BE> for Module<BE>
46where
47    Module<BE>: VecZnxAddScalarAssignBackend<BE>
48        + VecZnxDftApply<BE>
49        + SvpApplyDftToDftAssign<BE>
50        + VecZnxIdftApplyTmpA<BE>
51        + VecZnxBigBytesOf
52        + VecZnxDftBytesOf
53        + VecZnxBigFromSmallBackend<BE>
54        + VecZnxBigAddAssign<BE>
55        + VecZnxBigNormalize<BE>
56        + VecZnxBigNormalizeTmpBytes
57        + VecZnxSubAssignBackend<BE>
58        + GLWENoise<BE>
59        + GLWEDecrypt<BE>
60        + GLWENormalize<BE>,
61    for<'a> BE::BufRef<'a>: HostDataRef,
62    for<'a> BE::BufMut<'a>: HostDataMut,
63{
64    fn ggsw_noise_tmp_bytes<A>(&self, infos: &A) -> usize
65    where
66        A: GGSWInfos,
67    {
68        assert_eq!(self.n() as u32, infos.n());
69
70        let lvl_0: usize = GLWEPlaintext::<Vec<u8>>::bytes_of_from_infos(infos);
71        let lvl_1_glwe_noise: usize = self.glwe_noise_tmp_bytes(infos);
72        let lvl_1_mul: usize = self.bytes_of_vec_znx_dft(1, infos.size())
73            + self.bytes_of_vec_znx_big(1, infos.size())
74            + self.vec_znx_big_normalize_tmp_bytes();
75        let lvl_1: usize = lvl_1_glwe_noise.max(lvl_1_mul);
76
77        lvl_0 + lvl_1
78    }
79
80    fn ggsw_noise<R, S>(
81        &self,
82        res: &R,
83        res_row: usize,
84        res_col: usize,
85        pt_want: &ScalarZnx<&[u8]>,
86        sk_prepared: &S,
87        scratch: &mut ScratchArena<'_, BE>,
88    ) -> Stats
89    where
90        R: GGSWToBackendRef<BE> + GGSWInfos,
91        S: GLWESecretPreparedToBackendRef<BE>,
92        BE: HostBackend,
93        for<'a> BE::BufRef<'a>: HostDataRef,
94        for<'a> BE::BufMut<'a>: HostDataMut,
95    {
96        let res_backend = res.to_backend_ref();
97        let sk_backend = sk_prepared.to_backend_ref();
98
99        let base2k: usize = res_backend.base2k().into();
100        let dsize: usize = res_backend.dsize().into();
101        assert!(
102            scratch.available() >= self.ggsw_noise_tmp_bytes(res),
103            "scratch.available(): {} < GGSWNoise::ggsw_noise_tmp_bytes: {}",
104            scratch.available(),
105            self.ggsw_noise_tmp_bytes(res)
106        );
107
108        let (mut pt, mut scratch_1) = scratch.borrow().take_glwe_plaintext_scratch(&res_backend);
109        pt.data_mut().zero();
110        let pt_want_backend: ScalarZnx<BE::OwnedBuf> =
111            ScalarZnx::from_data(BE::from_host_bytes(pt_want.data), pt_want.n(), pt_want.cols());
112        {
113            let mut pt_backend = pt.to_backend_mut();
114            self.vec_znx_add_scalar_assign_backend(
115                &mut pt_backend.data,
116                0,
117                (dsize - 1) + res_row * dsize,
118                &<ScalarZnx<BE::OwnedBuf> as ScalarZnxToBackendRef<BE>>::to_backend_ref(&pt_want_backend),
119                0,
120            );
121        }
122
123        // mul with sk[col_j-1]
124        if res_col > 0 {
125            let scratch_mul = scratch_1.borrow();
126            let (mut pt_dft, scratch_2) = scratch_mul.take_vec_znx_dft_scratch(self, 1, res_backend.size());
127            self.vec_znx_dft_apply(1, 0, &mut pt_dft, 0, &pt.to_backend_ref().data, 0);
128            {
129                let mut pt_dft_backend = pt_dft.to_backend_mut();
130                self.svp_apply_dft_to_dft_assign(&mut pt_dft_backend, 0, &sk_backend.data, res_col - 1);
131            }
132            let (mut pt_big, mut scratch_3) = scratch_2.take_vec_znx_big_scratch(self, 1, res_backend.size());
133            {
134                let mut pt_big_backend = pt_big.to_backend_mut();
135                let mut pt_dft_backend = pt_dft.to_backend_mut();
136                self.vec_znx_idft_apply_tmpa(&mut pt_big_backend, 0, &mut pt_dft_backend, 0);
137            }
138            {
139                let mut pt_backend = pt.to_backend_mut();
140                self.vec_znx_big_normalize(
141                    &mut pt_backend.data,
142                    base2k,
143                    0,
144                    0,
145                    &pt_big.to_backend_ref(),
146                    base2k,
147                    0,
148                    &mut scratch_3,
149                );
150            }
151        }
152
153        let res_at_backend: GLWEViewRef<'_, BE> = res_backend.at_view(res_row, res_col);
154        let pt_backend = pt.to_backend_ref();
155        glwe_noise_backend_inner(self, &res_at_backend, &pt_backend, &sk_backend, &mut scratch_1)
156    }
157}