poulpy_core/encryption/
gglwe_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, SvpApplyDftToDftInplace, TakeVecZnx, TakeVecZnxDft, VecZnxAddInplace, VecZnxAddNormal,
4        VecZnxAddScalarInplace, VecZnxBigNormalize, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxFillUniform,
5        VecZnxIdftApplyConsume, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes, VecZnxSub, VecZnxSubInplace,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, ScalarZnx, Scratch, ZnxZero},
8    source::Source,
9};
10
11use crate::{
12    TakeGLWEPt,
13    layouts::{GGLWECiphertext, GGLWELayoutInfos, GLWECiphertext, GLWEPlaintext, LWEInfos, prepared::GLWESecretPrepared},
14};
15
16impl GGLWECiphertext<Vec<u8>> {
17    pub fn encrypt_sk_scratch_space<B: Backend, A>(module: &Module<B>, infos: &A) -> usize
18    where
19        A: GGLWELayoutInfos,
20        Module<B>: VecZnxNormalizeTmpBytes + VecZnxDftAllocBytes + VecZnxNormalizeTmpBytes,
21    {
22        GLWECiphertext::encrypt_sk_scratch_space(module, &infos.glwe_layout())
23            + (GLWEPlaintext::alloc_bytes(&infos.glwe_layout()) | module.vec_znx_normalize_tmp_bytes())
24    }
25
26    pub fn encrypt_pk_scratch_space<B: Backend, A>(_module: &Module<B>, _infos: &A) -> usize
27    where
28        A: GGLWELayoutInfos,
29    {
30        unimplemented!()
31    }
32}
33
34impl<DataSelf: DataMut> GGLWECiphertext<DataSelf> {
35    #[allow(clippy::too_many_arguments)]
36    pub fn encrypt_sk<DataPt: DataRef, DataSk: DataRef, B: Backend>(
37        &mut self,
38        module: &Module<B>,
39        pt: &ScalarZnx<DataPt>,
40        sk: &GLWESecretPrepared<DataSk, B>,
41        source_xa: &mut Source,
42        source_xe: &mut Source,
43        scratch: &mut Scratch<B>,
44    ) where
45        Module<B>: VecZnxAddScalarInplace
46            + VecZnxDftAllocBytes
47            + VecZnxBigNormalize<B>
48            + VecZnxDftApply<B>
49            + SvpApplyDftToDftInplace<B>
50            + VecZnxIdftApplyConsume<B>
51            + VecZnxNormalizeTmpBytes
52            + VecZnxFillUniform
53            + VecZnxSubInplace
54            + VecZnxAddInplace
55            + VecZnxNormalizeInplace<B>
56            + VecZnxAddNormal
57            + VecZnxNormalize<B>
58            + VecZnxSub,
59        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
60    {
61        #[cfg(debug_assertions)]
62        {
63            use poulpy_hal::layouts::ZnxInfos;
64
65            assert_eq!(
66                self.rank_in(),
67                pt.cols() as u32,
68                "self.rank_in(): {} != pt.cols(): {}",
69                self.rank_in(),
70                pt.cols()
71            );
72            assert_eq!(
73                self.rank_out(),
74                sk.rank(),
75                "self.rank_out(): {} != sk.rank(): {}",
76                self.rank_out(),
77                sk.rank()
78            );
79            assert_eq!(self.n(), sk.n());
80            assert_eq!(pt.n() as u32, sk.n());
81            assert!(
82                scratch.available() >= GGLWECiphertext::encrypt_sk_scratch_space(module, self),
83                "scratch.available: {} < GGLWECiphertext::encrypt_sk_scratch_space(module, self.rank()={}, self.size()={}): {}",
84                scratch.available(),
85                self.rank_out(),
86                self.size(),
87                GGLWECiphertext::encrypt_sk_scratch_space(module, self)
88            );
89            assert!(
90                self.rows().0 * self.digits().0 * self.base2k().0 <= self.k().0,
91                "self.rows() : {} * self.digits() : {} * self.base2k() : {} = {} >= self.k() = {}",
92                self.rows(),
93                self.digits(),
94                self.base2k(),
95                self.rows().0 * self.digits().0 * self.base2k().0,
96                self.k()
97            );
98        }
99
100        let rows: usize = self.rows().into();
101        let digits: usize = self.digits().into();
102        let base2k: usize = self.base2k().into();
103        let rank_in: usize = self.rank_in().into();
104
105        let (mut tmp_pt, scrach_1) = scratch.take_glwe_pt(self);
106        // For each input column (i.e. rank) produces a GGLWE ciphertext of rank_out+1 columns
107        //
108        // Example for ksk rank 2 to rank 3:
109        //
110        // (-(a0*s0 + a1*s1 + a2*s2) + s0', a0, a1, a2)
111        // (-(b0*s0 + b1*s1 + b2*s2) + s0', b0, b1, b2)
112        //
113        // Example ksk rank 2 to rank 1
114        //
115        // (-(a*s) + s0, a)
116        // (-(b*s) + s1, b)
117        (0..rank_in).for_each(|col_i| {
118            (0..rows).for_each(|row_i| {
119                // Adds the scalar_znx_pt to the i-th limb of the vec_znx_pt
120                tmp_pt.data.zero(); // zeroes for next iteration
121                module.vec_znx_add_scalar_inplace(
122                    &mut tmp_pt.data,
123                    0,
124                    (digits - 1) + row_i * digits,
125                    pt,
126                    col_i,
127                );
128                module.vec_znx_normalize_inplace(base2k, &mut tmp_pt.data, 0, scrach_1);
129
130                // rlwe encrypt of vec_znx_pt into vec_znx_ct
131                self.at_mut(row_i, col_i)
132                    .encrypt_sk(module, &tmp_pt, sk, source_xa, source_xe, scrach_1);
133            });
134        });
135    }
136}