poulpy_core/encryption/gglwe.rs

1use poulpy_hal::{
2    api::{ModuleN, ScratchAvailable, VecZnxAddScalarInplace, VecZnxDftBytesOf, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
3    layouts::{Backend, DataMut, Module, ScalarZnx, ScalarZnxToRef, Scratch, ZnxInfos, ZnxZero},
4    source::Source,
5};
6
7use crate::{
8    GLWEEncryptSk, ScratchTakeCore,
9    layouts::{
10        GGLWE, GGLWEInfos, GGLWEToMut, GLWEPlaintext, LWEInfos,
11        prepared::{GLWESecretPrepared, GLWESecretPreparedToRef},
12    },
13};
14
15impl GGLWE<Vec<u8>> {
16    pub fn encrypt_sk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
17    where
18        A: GGLWEInfos,
19        M: GGLWEEncryptSk<BE>,
20    {
21        module.gglwe_encrypt_sk_tmp_bytes(infos)
22    }
23
24    pub fn encrypt_pk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
25    where
26        A: GGLWEInfos,
27        M: GGLWEEncryptSk<BE>,
28    {
29        module.gglwe_encrypt_sk_tmp_bytes(infos)
30    }
31}
32
33impl<D: DataMut> GGLWE<D> {
34    #[allow(clippy::too_many_arguments)]
35    pub fn encrypt_sk<P, S, M, BE: Backend>(
36        &mut self,
37        module: &M,
38        pt: &P,
39        sk: &S,
40        source_xa: &mut Source,
41        source_xe: &mut Source,
42        scratch: &mut Scratch<BE>,
43    ) where
44        P: ScalarZnxToRef,
45        S: GLWESecretPreparedToRef<BE>,
46        M: GGLWEEncryptSk<BE>,
47        Scratch<BE>: ScratchTakeCore<BE>,
48    {
49        module.gglwe_encrypt_sk(self, pt, sk, source_xa, source_xe, scratch);
50    }
51}
52
53pub trait GGLWEEncryptSk<BE: Backend> {
54    fn gglwe_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
55    where
56        A: GGLWEInfos;
57
58    fn gglwe_encrypt_sk<R, P, S>(
59        &self,
60        res: &mut R,
61        pt: &P,
62        sk: &S,
63        source_xa: &mut Source,
64        source_xe: &mut Source,
65        scratch: &mut Scratch<BE>,
66    ) where
67        R: GGLWEToMut,
68        P: ScalarZnxToRef,
69        S: GLWESecretPreparedToRef<BE>;
70}
71
72impl<BE: Backend> GGLWEEncryptSk<BE> for Module<BE>
73where
74    Self: ModuleN
75        + GLWEEncryptSk<BE>
76        + VecZnxNormalizeTmpBytes
77        + VecZnxDftBytesOf
78        + VecZnxAddScalarInplace
79        + VecZnxNormalizeInplace<BE>,
80    Scratch<BE>: ScratchTakeCore<BE>,
81{
82    fn gglwe_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
83    where
84        A: GGLWEInfos,
85    {
86        self.glwe_encrypt_sk_tmp_bytes(infos) + GLWEPlaintext::bytes_of_from_infos(infos).max(self.vec_znx_normalize_tmp_bytes())
87    }
88
89    fn gglwe_encrypt_sk<R, P, S>(
90        &self,
91        res: &mut R,
92        pt: &P,
93        sk: &S,
94        source_xa: &mut Source,
95        source_xe: &mut Source,
96        scratch: &mut Scratch<BE>,
97    ) where
98        R: GGLWEToMut,
99        P: ScalarZnxToRef,
100        S: GLWESecretPreparedToRef<BE>,
101    {
102        let res: &mut GGLWE<&mut [u8]> = &mut res.to_mut();
103        let pt: &ScalarZnx<&[u8]> = &pt.to_ref();
104        let sk: &GLWESecretPrepared<&[u8], BE> = &sk.to_ref();
105
106        assert_eq!(
107            res.rank_in(),
108            pt.cols() as u32,
109            "res.rank_in(): {} != pt.cols(): {}",
110            res.rank_in(),
111            pt.cols()
112        );
113        assert_eq!(
114            res.rank_out(),
115            sk.rank(),
116            "res.rank_out(): {} != sk.rank(): {}",
117            res.rank_out(),
118            sk.rank()
119        );
120        assert_eq!(res.n(), sk.n());
121        assert_eq!(pt.n() as u32, sk.n());
122        assert!(
123            scratch.available() >= self.gglwe_encrypt_sk_tmp_bytes(res),
124            "scratch.available: {} < GGLWE::encrypt_sk_tmp_bytes(self, res.rank()={}, res.size()={}): {}",
125            scratch.available(),
126            res.rank_out(),
127            res.size(),
128            self.gglwe_encrypt_sk_tmp_bytes(res)
129        );
130        assert!(
131            res.dnum().0 * res.dsize().0 * res.base2k().0 <= res.k().0,
132            "res.dnum() : {} * res.dsize() : {} * res.base2k() : {} = {} >= res.k() = {}",
133            res.dnum(),
134            res.dsize(),
135            res.base2k(),
136            res.dnum().0 * res.dsize().0 * res.base2k().0,
137            res.k()
138        );
139
140        let dnum: usize = res.dnum().into();
141        let dsize: usize = res.dsize().into();
142        let base2k: usize = res.base2k().into();
143        let rank_in: usize = res.rank_in().into();
144
145        let (mut tmp_pt, scrach_1) = scratch.take_glwe_plaintext(res);
146        // For each input column (i.e. rank) produces a GGLWE of rank_out+1 columns
147        //
148        // Example for ksk rank 2 to rank 3:
149        //
150        // (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2)
151        // (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2)
152        //
153        // Example ksk rank 2 to rank 1
154        //
155        // (-(a*s) + s0, a)
156        // (-(b*s) + s1, b)
157        for col_i in 0..rank_in {
158            for row_i in 0..dnum {
159                // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
160                tmp_pt.data.zero(); // zeroes for next iteration
161                self.vec_znx_add_scalar_inplace(&mut tmp_pt.data, 0, (dsize - 1) + row_i * dsize, pt, col_i);
162                self.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1);
163                self.glwe_encrypt_sk(
164                    &mut res.at_mut(row_i, col_i),
165                    &tmp_pt,
166                    sk,
167                    source_xa,
168                    source_xe,
169                    scrach_1,
170                );
171            }
172        }
173    }
174}