poulpy_core/encryption/
gglwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
4        VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform,
5        VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubABInplace,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero},
8    source::Source,
9};
10
11use crate::{
12    TakeGLWEPt,
13    layouts::{GGLWECiphertext, GLWECiphertext, GLWEPlaintext, Infos, prepared::GLWESecretPrepared},
14};
15
16impl GGLWECiphertext<Vec<u8>> {
17    pub fn encrypt_sk_scratch_space<B: Backend>(module: &Module<B>, basek: usize, k: usize) -> usize
18    where
19        Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes,
20    {
21        GLWECiphertext::encrypt_sk_scratch_space(module, basek, k)
22            + (GLWEPlaintext::byte_of(module.n(), basek, k) | module.vec_znx_normalize_tmp_bytes())
23    }
24
25    pub fn encrypt_pk_scratch_space<B: Backend>(_module: &Module<B>, _basek: usize, _k: usize, _rank: usize) -> usize {
26        unimplemented!()
27    }
28}
29
30impl<DataSelf: DataMut> GGLWECiphertext<DataSelf> {
31    #[allow(clippy::too_many_arguments)]
32    pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
33        &mut self,
34        module: &Module<B>,
35        pt: &ScalarZnx<DataPt>,
36        sk: &GLWESecretPrepared<DataSk, B>,
37        source_xa: &mut Source,
38        source_xe: &mut Source,
39        scratch: &mut Scratch<B>,
40    ) where
41        Module<B>: VecZnxAddScalarInplace
42            + VecZnxDftAllocBytes
43            + VecZnxBigNormalize<B>
44            + VecZnxDftApply<B>
45            + SvpApplyDftToDftInplace<B>
46            + VecZnxIdftApplyConsume<B>
47            + VecZnxNormalizeTmpBytes
48            + VecZnxFillUniform
49            + VecZnxSubABInplace
50            + VecZnxAddInplace
51            + VecZnxNormalizeInplace<B>
52            + VecZnxAddNormal
53            + VecZnxNormalize<B>
54            + VecZnxSub,
55        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
56    {
57        #[cfg(debug_assertions)]
58        {
59            use poulpy_hal::layouts::ZnxInfos;
60
61            assert_eq!(
62                self.rank_in(),
63                pt.cols(),
64                "self.rank_in(): {} != pt.cols(): {}",
65                self.rank_in(),
66                pt.cols()
67            );
68            assert_eq!(
69                self.rank_out(),
70                sk.rank(),
71                "self.rank_out(): {} != sk.rank(): {}",
72                self.rank_out(),
73                sk.rank()
74            );
75            assert_eq!(self.n(), sk.n());
76            assert_eq!(pt.n(), sk.n());
77            assert!(
78                scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k()),
79                "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}",
80                scratch.available(),
81                self.rank(),
82                self.size(),
83                GGLWECiphertext::encrypt_sk_scratch_space(module, self.basek(), self.k())
84            );
85            assert!(
86                self.rows() * self.digits() * self.basek() <= self.k(),
87                "self.rows() : {} * self.digits() : {} * self.basek() : {} = {} >= self.k() = {}",
88                self.rows(),
89                self.digits(),
90                self.basek(),
91                self.rows() * self.digits() * self.basek(),
92                self.k()
93            );
94        }
95
96        let rows: usize = self.rows();
97        let digits: usize = self.digits();
98        let basek: usize = self.basek();
99        let k: usize = self.k();
100        let rank_in: usize = self.rank_in();
101
102        let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(sk.n(), basek, k);
103        // For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns
104        //
105        // Example for ksk rank 2 to rank 3:
106        //
107        // (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2)
108        // (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2)
109        //
110        // Example ksk rank 2 to rank 1
111        //
112        // (-(a*s) + s0, a)
113        // (-(b*s) + s1, b)
114        (0..rank_in).for_each(|col_i| {
115            (0..rows).for_each(|row_i| {
116                // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
117                tmp_pt.data.zero(); // zeroes for next iteration
118                module.vec_znx_add_scalar_inplace(
119                    &mut tmp_pt.data,
120                    0,
121                    (digits - 1) + row_i * digits,
122                    pt,
123                    col_i,
124                );
125                module.vec_znx_normalize_inplace(basek, &mut tmp_pt.data, 0, scrach_1);
126
127                // rlwe encrypt of vec_znx_pt into vec_znx_ct
128                self.at_mut(row_i, col_i)
129                    .encrypt_sk(module, &tmp_pt, sk, source_xa, source_xe, scrach_1);
130            });
131        });
132    }
133}