poulpy_core/encryption/
gglwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal, VecZnxAddScalarInplace,
4        VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxFillUniform,
5        VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace, ZnxZero,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch},
8    source::Source,
9};
10
11use crate::{
12    TakeGLWEPt,
13    layouts::{GGLWECiphertext, GLWECiphertext, GLWEPlaintext, Infos, prepared::GLWESecretPrepared},
14};
15
16impl GGLWECiphertext<Vec<u8>> {
17    pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, n: usize, basek: usize, k: usize) -> usize
18    where
19        Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes,
20    {
21        GLWECiphertext::encrypt_sk_scratch_space(module, n, basek, k)
22            + (GLWEPlaintext::byte_of(n, basek, k) | module.vec_znx_normalize_tmp_bytes(n))
23    }
24
25    pub fn encrypt_pk_scratch_space<B: Backend>(_module: &Module<B>, _n: usize, _basek: usize, _k: usize, _rank: usize) -> usize {
26        unimplemented!()
27    }
28}
29
30impl<DataSelf: DataMut> GGLWECiphertext<DataSelf> {
31    #[allow(clippy::too_many_arguments)]
32    pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
33        &mut self,
34        module: &Module<B>,
35        pt: &ScalarZnx<DataPt>,
36        sk: &GLWESecretPrepared<DataSk, B>,
37        source_xa: &mut Source,
38        source_xe: &mut Source,
39        sigma: f64,
40        scratch: &mut Scratch<B>,
41    ) where
42        Module<B>: VecZnxAddScalarInplace
43            + VecZnxDftAllocBytes
44            + VecZnxBigNormalize<B>
45            + VecZnxDftFromVecZnx<B>
46            + SvpApplyInplace<B>
47            + VecZnxDftToVecZnxBigConsume<B>
48            + VecZnxNormalizeTmpBytes
49            + VecZnxFillUniform
50            + VecZnxSubABInplace
51            + VecZnxAddInplace
52            + VecZnxNormalizeInplace<B>
53            + VecZnxAddNormal
54            + VecZnxNormalize<B>
55            + VecZnxSub,
56        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
57    {
58        #[cfg(debug_assertions)]
59        {
60            use poulpy_hal::api::ZnxInfos;
61
62            assert_eq!(
63                self.rank_in(),
64                pt.cols(),
65                "self.rank_in(): {} != pt.cols(): {}",
66                self.rank_in(),
67                pt.cols()
68            );
69            assert_eq!(
70                self.rank_out(),
71                sk.rank(),
72                "self.rank_out(): {} != sk.rank(): {}",
73                self.rank_out(),
74                sk.rank()
75            );
76            assert_eq!(self.n(), sk.n());
77            assert_eq!(pt.n(), sk.n());
78            assert!(
79                scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k()),
80                "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}",
81                scratch.available(),
82                self.rank(),
83                self.size(),
84                GGLWECiphertext::encrypt_sk_scratch_space(module, sk.n(), self.basek(), self.k())
85            );
86            assert!(
87                self.rows() * self.digits() * self.basek() <= self.k(),
88                "self.rows() : {} * self.digits() : {} * self.basek() : {} = {} >= self.k() = {}",
89                self.rows(),
90                self.digits(),
91                self.basek(),
92                self.rows() * self.digits() * self.basek(),
93                self.k()
94            );
95        }
96
97        let rows: usize = self.rows();
98        let digits: usize = self.digits();
99        let basek: usize = self.basek();
100        let k: usize = self.k();
101        let rank_in: usize = self.rank_in();
102
103        let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(sk.n(), basek, k);
104        // For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns
105        //
106        // Example for ksk rank 2 to rank 3:
107        //
108        // (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2)
109        // (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2)
110        //
111        // Example ksk rank 2 to rank 1
112        //
113        // (-(a*s) + s0, a)
114        // (-(b*s) + s1, b)
115        (0..rank_in).for_each(|col_i| {
116            (0..rows).for_each(|row_i| {
117                // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
118                tmp_pt.data.zero(); // zeroes for next iteration
119                module.vec_znx_add_scalar_inplace(
120                    &mut tmp_pt.data,
121                    0,
122                    (digits - 1) + row_i * digits,
123                    pt,
124                    col_i,
125                );
126                module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scrach_1);
127
128                // rlwe encrypt of vec_znx_pt into vec_znx_ct
129                self.at_mut(row_i, col_i)
130                    .encrypt_sk(module, &tmp_pt, sk, source_xa, source_xe, sigma, scrach_1);
131            });
132        });
133    }
134}