poulpy_core/encryption/compressed/
gglwe.rs

1use poulpy_hal::{
2    api::{ModuleN, ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
3    layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxInfos, ZnxZero},
4    source::Source,
5};
6
7use crate::{
8    ScratchTakeCore,
9    encryption::{GLWEEncryptSk, GLWEEncryptSkInternal, SIGMA},
10    layouts::{
11        GGLWECompressedSeedMut, GGLWEInfos, GLWEPlaintext, GLWESecretPrepared, LWEInfos,
12        compressed::{GGLWECompressed, GGLWECompressedToMut},
13        prepared::GLWESecretPreparedToRef,
14    },
15};
16
17impl<D: DataMut> GGLWECompressed<D> {
18    #[allow(clippy::too_many_arguments)]
19    pub fn encrypt_sk<M, P, S, BE: Backend>(
20        &mut self,
21        module: &M,
22        pt: &P,
23        sk: &S,
24        seed: [u8; 32],
25        source_xe: &mut Source,
26        scratch: &mut Scratch<BE>,
27    ) where
28        P: ScalarZnxToRef,
29        S: GLWESecretPreparedToRef<BE>,
30        M: GGLWECompressedEncryptSk<BE>,
31    {
32        module.gglwe_compressed_encrypt_sk(self, pt, sk, seed, source_xe, scratch);
33    }
34}
35
36impl GGLWECompressed<Vec<u8>> {
37    pub fn encrypt_sk_tmp_bytes<M, BE: Backend, A>(module: &M, infos: &A) -> usize
38    where
39        A: GGLWEInfos,
40        M: GGLWECompressedEncryptSk<BE>,
41    {
42        module.gglwe_compressed_encrypt_sk_tmp_bytes(infos)
43    }
44}
45
/// Secret-key encryption of a scalar plaintext into a compressed GGLWE
/// ciphertext, for a given backend `BE`.
pub trait GGLWECompressedEncryptSk<BE: Backend> {
    /// Returns the number of scratch bytes required by
    /// [`Self::gglwe_compressed_encrypt_sk`] for a ciphertext described by `infos`.
    fn gglwe_compressed_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
    where
        A: GGLWEInfos;

    /// Encrypts `pt` under `sk` into the compressed ciphertext `res`.
    ///
    /// * `res` - destination; must expose both its ciphertext data
    ///   ([`GGLWECompressedToMut`]) and its seed storage ([`GGLWECompressedSeedMut`]).
    /// * `pt` - scalar plaintext.
    /// * `sk` - prepared GLWE secret key.
    /// * `seed` - 32-byte seed supplied by the caller.
    /// * `source_xe` - caller-supplied randomness source
    ///   (presumably for error sampling - confirm against the implementation).
    /// * `scratch` - temporary working memory.
    fn gglwe_compressed_encrypt_sk<R, P, S>(
        &self,
        res: &mut R,
        pt: &P,
        sk: &S,
        seed: [u8; 32],
        source_xe: &mut Source,
        scratch: &mut Scratch<BE>,
    ) where
        R: GGLWECompressedToMut + GGLWECompressedSeedMut,
        P: ScalarZnxToRef,
        S: GLWESecretPreparedToRef<BE>;
}
64
65impl<BE: Backend> GGLWECompressedEncryptSk<BE> for Module<BE>
66where
67    Self: ModuleN
68        + GLWEEncryptSkInternal<BE>
69        + GLWEEncryptSk<BE>
70        + VecZnxDftBytesOf
71        + VecZnxNormalizeInplace<BE>
72        + VecZnxAddScalarInplace
73        + VecZnxNormalizeTmpBytes,
74    Scratch<BE>: ScratchTakeCore<BE>,
75{
76    fn gglwe_compressed_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
77    where
78        A: GGLWEInfos,
79    {
80        self.glwe_encrypt_sk_tmp_bytes(infos)
81            .max(self.vec_znx_normalize_tmp_bytes())
82            + GLWEPlaintext::bytes_of_from_infos(infos)
83    }
84
85    fn gglwe_compressed_encrypt_sk<R, P, S>(
86        &self,
87        res: &mut R,
88        pt: &P,
89        sk: &S,
90        seed: [u8; 32],
91        source_xe: &mut Source,
92        scratch: &mut Scratch<BE>,
93    ) where
94        R: GGLWECompressedToMut + GGLWECompressedSeedMut,
95        P: ScalarZnxToRef,
96        S: GLWESecretPreparedToRef<BE>,
97    {
98        let mut seeds: Vec<[u8; 32]> = vec![[0u8; 32]; res.seed_mut().len()];
99
100        {
101            let res: &mut GGLWECompressed<&mut [u8]> = &mut res.to_mut();
102            let pt: &ScalarZnx<&[u8]> = &pt.to_ref();
103            let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref();
104
105            assert_eq!(
106                res.rank_in(),
107                pt.cols() as u32,
108                "res.rank_in(): {} != pt.cols(): {}",
109                res.rank_in(),
110                pt.cols()
111            );
112            assert_eq!(
113                res.rank_out(),
114                sk.rank(),
115                "res.rank_out(): {} != sk.rank(): {}",
116                res.rank_out(),
117                sk.rank()
118            );
119            assert_eq!(res.n(), sk.n());
120            assert_eq!(pt.n() as u32, sk.n());
121            assert!(
122                scratch.available() >= GGLWECompressed::encrypt_sk_tmp_bytes(self, res),
123                "scratch.available: {} < GGLWECiphertext::encrypt_sk_tmp_bytes: {}",
124                scratch.available(),
125                GGLWECompressed::encrypt_sk_tmp_bytes(self, res)
126            );
127            assert!(
128                res.dnum().0 * res.dsize().0 * res.base2k().0 <= res.k().0,
129                "res.dnum() : {} * res.dsize() : {} * res.base2k() : {} = {} >= res.k() = {}",
130                res.dnum(),
131                res.dsize(),
132                res.base2k(),
133                res.dnum().0 * res.dsize().0 * res.base2k().0,
134                res.k()
135            );
136
137            let dnum: usize = res.dnum().into();
138            let dsize: usize = res.dsize().into();
139            let base2k: usize = res.base2k().into();
140            let rank_in: usize = res.rank_in().into();
141            let cols: usize = (res.rank_out() + 1).into();
142
143            let mut source_xa = Source::new(seed);
144
145            let (mut tmp_pt, scrach_1) = scratch.take_glwe_plaintext(res);
146            for col_i in 0..rank_in {
147                for d_i in 0..dnum {
148                    // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
149                    tmp_pt.data.zero(); // zeroes for next iteration
150                    self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + d_i * dsize, pt, col_i);
151                    self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1);
152
153                    let (seed, mut source_xa_tmp) = source_xa.branch();
154                    seeds[col_i * dnum + d_i] = seed;
155
156                    self.glwe_encrypt_sk_internal(
157                        res.base2k().into(),
158                        res.k().into(),
159                        &mut res.at_mut(d_i, col_i).data,
160                        cols,
161                        true,
162                        Some((&tmp_pt, 0)),
163                        sk,
164                        &mut source_xa_tmp,
165                        source_xe,
166                        SIGMA,
167                        scrach_1,
168                    );
169                }
170            }
171        }
172
173        res.seed_mut().copy_from_slice(&seeds);
174    }
175}