// poulpy_core/encryption/compressed/ggsw.rs

1use poulpy_hal::{
2    api::{ModuleN, VecZnxAddScalarInplace, VecZnxNormalizeInplace},
3    layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxInfos, ZnxZero},
4    source::Source,
5};
6
7use crate::{
8    ScratchTakeCore,
9    encryption::{GGSWEncryptSk, GLWEEncryptSkInternal, SIGMA},
10    layouts::{
11        GGSWCompressedSeedMut, GGSWInfos, LWEInfos,
12        compressed::{GGSWCompressed, GGSWCompressedToMut},
13        prepared::{GLWESecretPrepared, GLWESecretPreparedToRef},
14    },
15};
16
17impl GGSWCompressed<Vec<u8>> {
18    pub fn encrypt_sk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
19    where
20        A: GGSWInfos,
21        M: GGSWCompressedEncryptSk<BE>,
22    {
23        module.ggsw_compressed_encrypt_sk_tmp_bytes(infos)
24    }
25}
26
27impl<DataSelf: DataMut> GGSWCompressed<DataSelf> {
28    #[allow(clippy::too_many_arguments)]
29    pub fn encrypt_sk<P, S, M, BE: Backend>(
30        &mut self,
31        module: &M,
32        pt: &P,
33        sk: &S,
34        seed_xa: [u8; 32],
35        source_xe: &mut Source,
36        scratch: &mut Scratch<BE>,
37    ) where
38        P: ScalarZnxToRef,
39        S: GLWESecretPreparedToRef<BE>,
40        M: GGSWCompressedEncryptSk<BE>,
41    {
42        module.ggsw_compressed_encrypt_sk(self, pt, sk, seed_xa, source_xe, scratch);
43    }
44}
45
46pub trait GGSWCompressedEncryptSk<BE: Backend> {
47    fn ggsw_compressed_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
48    where
49        A: GGSWInfos;
50
51    fn ggsw_compressed_encrypt_sk<R, P, S>(
52        &self,
53        res: &mut R,
54        pt: &P,
55        sk: &S,
56        seed_xa: [u8; 32],
57        source_xe: &mut Source,
58        scratch: &mut Scratch<BE>,
59    ) where
60        R: GGSWCompressedToMut + GGSWCompressedSeedMut + GGSWInfos,
61        P: ScalarZnxToRef,
62        S: GLWESecretPreparedToRef<BE>;
63}
64
65impl<BE: Backend> GGSWCompressedEncryptSk<BE> for Module<BE>
66where
67    Self: ModuleN + GLWEEncryptSkInternal<BE> + GGSWEncryptSk<BE> + VecZnxAddScalarInplace + VecZnxNormalizeInplace<BE>,
68    Scratch<BE>: ScratchTakeCore<BE>,
69{
70    fn ggsw_compressed_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
71    where
72        A: GGSWInfos,
73    {
74        self.ggsw_encrypt_sk_tmp_bytes(infos)
75    }
76
77    fn ggsw_compressed_encrypt_sk<R, P, S>(
78        &self,
79        res: &mut R,
80        pt: &P,
81        sk: &S,
82        seed_xa: [u8; 32],
83        source_xe: &mut Source,
84        scratch: &mut Scratch<BE>,
85    ) where
86        R: GGSWCompressedToMut + GGSWCompressedSeedMut + GGSWInfos,
87        P: ScalarZnxToRef,
88        S: GLWESecretPreparedToRef<BE>,
89    {
90        let base2k: usize = res.base2k().into();
91        let rank: usize = res.rank().into();
92        let cols: usize = rank + 1;
93        let dsize: usize = res.dsize().into();
94
95        let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref();
96        let pt: &ScalarZnx<&[u8]> = &pt.to_ref();
97
98        assert_eq!(res.rank(), sk.rank());
99        assert_eq!(pt.n(), self.n());
100        assert_eq!(res.n(), self.n() as u32);
101        assert_eq!(sk.n(), self.n() as u32);
102
103        let mut seeds: Vec<[u8; 32]> = vec![[0u8; 32]; res.dnum().as_usize() * (res.rank().as_usize() + 1)];
104
105        {
106            let res: &mut GGSWCompressed<&mut [u8]> = &mut res.to_mut();
107
108            println!("res.seed: {:?}", res.seed);
109
110            let (mut tmp_pt, scratch_1) = scratch.take_glwe_plaintext(res);
111
112            let mut source = Source::new(seed_xa);
113
114            for row_i in 0..res.dnum().into() {
115                tmp_pt.data.zero();
116
117                // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
118                self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, 0);
119                self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scratch_1);
120
121                for col_j in 0..rank + 1 {
122                    // rlwe encrypt of vec_znx_pt into vec_znx_ct
123
124                    let (seed, mut source_xa_tmp) = source.branch();
125
126                    seeds[row_i * cols + col_j] = seed;
127
128                    self.glwe_encrypt_sk_internal(
129                        res.base2k().into(),
130                        res.k().into(),
131                        &mut res.at_mut(row_i, col_j).data,
132                        cols,
133                        true,
134                        Some((&tmp_pt, col_j)),
135                        sk,
136                        &mut source_xa_tmp,
137                        source_xe,
138                        SIGMA,
139                        scratch_1,
140                    );
141                }
142            }
143        }
144
145        res.seed_mut().copy_from_slice(&seeds);
146    }
147}