// poulpy_core/encryption/glwe_tensor_key.rs

1use poulpy_hal::{
2    api::{
3        ModuleN, ScratchTakeBasic, SvpApplyDftToDft, VecZnxBigBytesOf, VecZnxBigNormalize, VecZnxDftApply, VecZnxDftBytesOf,
4        VecZnxIdftApplyTmpA,
5    },
6    layouts::{Backend, DataMut, Module, Scratch},
7    source::Source,
8};
9
10use crate::{
11    GGLWEEncryptSk, GetDistribution, ScratchTakeCore,
12    layouts::{
13        GGLWE, GGLWEInfos, GLWEInfos, GLWESecret, GLWESecretToRef, GLWETensorKey, GLWETensorKeyToMut, LWEInfos, Rank,
14        prepared::{GLWESecretPrepared, GLWESecretPreparedFactory},
15    },
16};
17
18impl GLWETensorKey<Vec<u8>> {
19    pub fn encrypt_sk_tmp_bytes<M, A, BE: Backend>(module: &M, infos: &A) -> usize
20    where
21        A: GGLWEInfos,
22        M: GLWETensorKeyEncryptSk<BE>,
23    {
24        module.glwe_tensor_key_encrypt_sk_tmp_bytes(infos)
25    }
26}
27
28impl<DataSelf: DataMut> GLWETensorKey<DataSelf> {
29    pub fn encrypt_sk<M, S, BE: Backend>(
30        &mut self,
31        module: &M,
32        sk: &S,
33        source_xa: &mut Source,
34        source_xe: &mut Source,
35        scratch: &mut Scratch<BE>,
36    ) where
37        M: GLWETensorKeyEncryptSk<BE>,
38        S: GLWESecretToRef + GetDistribution + GLWEInfos,
39        Scratch<BE>: ScratchTakeCore<BE>,
40    {
41        module.glwe_tensor_key_encrypt_sk(self, sk, source_xa, source_xe, scratch);
42    }
43}
44
45pub trait GLWETensorKeyEncryptSk<BE: Backend> {
46    fn glwe_tensor_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
47    where
48        A: GGLWEInfos;
49
50    fn glwe_tensor_key_encrypt_sk<R, S>(
51        &self,
52        res: &mut R,
53        sk: &S,
54        source_xa: &mut Source,
55        source_xe: &mut Source,
56        scratch: &mut Scratch<BE>,
57    ) where
58        R: GLWETensorKeyToMut,
59        S: GLWESecretToRef + GetDistribution + GLWEInfos;
60}
61
62impl<BE: Backend> GLWETensorKeyEncryptSk<BE> for Module<BE>
63where
64    Self: ModuleN
65        + GGLWEEncryptSk<BE>
66        + VecZnxDftBytesOf
67        + VecZnxBigBytesOf
68        + GLWESecretPreparedFactory<BE>
69        + VecZnxDftApply<BE>
70        + SvpApplyDftToDft<BE>
71        + VecZnxIdftApplyTmpA<BE>
72        + VecZnxBigNormalize<BE>,
73    Scratch<BE>: ScratchTakeCore<BE>,
74{
75    fn glwe_tensor_key_encrypt_sk_tmp_bytes<A>(&self, infos: &A) -> usize
76    where
77        A: GGLWEInfos,
78    {
79        GLWESecretPrepared::bytes_of(self, infos.rank_out())
80            + self.bytes_of_vec_znx_dft(infos.rank_out().into(), 1)
81            + self.bytes_of_vec_znx_big(1, 1)
82            + self.bytes_of_vec_znx_dft(1, 1)
83            + GLWESecret::bytes_of(self.n().into(), Rank(1))
84            + GGLWE::encrypt_sk_tmp_bytes(self, infos)
85    }
86
87    fn glwe_tensor_key_encrypt_sk<R, S>(
88        &self,
89        res: &mut R,
90        sk: &S,
91        source_xa: &mut Source,
92        source_xe: &mut Source,
93        scratch: &mut Scratch<BE>,
94    ) where
95        R: GLWETensorKeyToMut,
96        S: GLWESecretToRef + GetDistribution + GLWEInfos,
97    {
98        let res: &mut GLWETensorKey<&mut [u8]> = &mut res.to_mut();
99
100        // let n: RingDegree = sk.n();
101        let rank: Rank = res.rank_out();
102
103        let (mut sk_prepared, scratch_1) = scratch.take_glwe_secret_prepared(self, sk.rank());
104        sk_prepared.prepare(self, sk);
105
106        let sk: &GLWESecret<&[u8]> = &sk.to_ref();
107
108        assert_eq!(res.rank_out(), sk.rank());
109        assert_eq!(res.n(), sk.n());
110
111        let (mut sk_dft, scratch_2) = scratch_1.take_vec_znx_dft(self, rank.into(), 1);
112
113        (0..rank.into()).for_each(|i| {
114            self.vec_znx_dft_apply(1, 0, &mut sk_dft, i, &sk.data.as_vec_znx(), i);
115        });
116
117        let (mut sk_ij_big, scratch_3) = scratch_2.take_vec_znx_big(self, 1, 1);
118        let (mut sk_ij, scratch_4) = scratch_3.take_glwe_secret(self.n().into(), Rank(1));
119        let (mut sk_ij_dft, scratch_5) = scratch_4.take_vec_znx_dft(self, 1, 1);
120
121        (0..rank.into()).for_each(|i| {
122            (i..rank.into()).for_each(|j| {
123                self.svp_apply_dft_to_dft(&mut sk_ij_dft, 0, &sk_prepared.data, j, &sk_dft, i);
124
125                self.vec_znx_idft_apply_tmpa(&mut sk_ij_big, 0, &mut sk_ij_dft, 0);
126                self.vec_znx_big_normalize(
127                    res.base2k().into(),
128                    &mut sk_ij.data.as_vec_znx_mut(),
129                    0,
130                    res.base2k().into(),
131                    &sk_ij_big,
132                    0,
133                    scratch_5,
134                );
135
136                res.at_mut(i, j).encrypt_sk(
137                    self,
138                    &sk_ij.data,
139                    &sk_prepared,
140                    source_xa,
141                    source_xe,
142                    scratch_5,
143                );
144            });
145        })
146    }
147}