poulpy_core/automorphism/
gglwe_atk.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
4        VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume,
5        VecZnxNormalize, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
8};
9
10use crate::layouts::{GGLWEAutomorphismKey, GGLWELayoutInfos, GLWECiphertext, prepared::GGLWEAutomorphismKeyPrepared};
11
12impl GGLWEAutomorphismKey<Vec<u8>> {
13    pub fn automorphism_scratch_space<B: Backend, OUT, IN, KEY>(
14        module: &Module<B>,
15        out_infos: &OUT,
16        in_infos: &IN,
17        key_infos: &KEY,
18    ) -> usize
19    where
20        OUT: GGLWELayoutInfos,
21        IN: GGLWELayoutInfos,
22        KEY: GGLWELayoutInfos,
23        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
24    {
25        GLWECiphertext::keyswitch_scratch_space(
26            module,
27            &out_infos.glwe_layout(),
28            &in_infos.glwe_layout(),
29            key_infos,
30        )
31    }
32
33    pub fn automorphism_inplace_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
34    where
35        OUT: GGLWELayoutInfos,
36        KEY: GGLWELayoutInfos,
37        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
38    {
39        GGLWEAutomorphismKey::automorphism_scratch_space(module, out_infos, out_infos, key_infos)
40    }
41}
42
43impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
44    pub fn automorphism<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
45        &mut self,
46        module: &Module<B>,
47        lhs: &GGLWEAutomorphismKey<DataLhs>,
48        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
49        scratch: &mut Scratch<B>,
50    ) where
51        Module<B>: VecZnxDftAllocBytes
52            + VmpApplyDftToDftTmpBytes
53            + VecZnxBigNormalizeTmpBytes
54            + VmpApplyDftToDft<B>
55            + VmpApplyDftToDftAdd<B>
56            + VecZnxDftApply<B>
57            + VecZnxIdftApplyConsume<B>
58            + VecZnxBigAddSmallInplace<B>
59            + VecZnxBigNormalize<B>
60            + VecZnxAutomorphism
61            + VecZnxAutomorphismInplace<B>
62            + VecZnxNormalize<B>
63            + VecZnxNormalizeTmpBytes,
64        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnx,
65    {
66        #[cfg(debug_assertions)]
67        {
68            use crate::layouts::LWEInfos;
69
70            assert_eq!(
71                self.rank_in(),
72                lhs.rank_in(),
73                "ksk_out input rank: {} != ksk_in input rank: {}",
74                self.rank_in(),
75                lhs.rank_in()
76            );
77            assert_eq!(
78                self.rank_out(),
79                rhs.rank_in(),
80                "ksk_in output rank: {} != ksk_apply input rank: {}",
81                self.rank_out(),
82                rhs.rank_in()
83            );
84            assert_eq!(
85                self.rank_out(),
86                rhs.rank_out(),
87                "ksk_out output rank: {} != ksk_apply output rank: {}",
88                self.rank_out(),
89                rhs.rank_out()
90            );
91            assert!(
92                self.k() <= lhs.k(),
93                "output k={} cannot be greater than input k={}",
94                self.k(),
95                lhs.k()
96            )
97        }
98
99        let cols_out: usize = (rhs.rank_out() + 1).into();
100
101        let p: i64 = lhs.p();
102        let p_inv: i64 = module.galois_element_inv(p);
103
104        (0..self.rank_in().into()).for_each(|col_i| {
105            (0..self.rows().into()).for_each(|row_j| {
106                let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i);
107                let lhs_ct: GLWECiphertext<&[u8]> = lhs.at(row_j, col_i);
108
109                // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
110                (0..cols_out).for_each(|i| {
111                    module.vec_znx_automorphism(lhs.p(), &mut res_ct.data, i, &lhs_ct.data, i);
112                });
113
114                // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
115                res_ct.keyswitch_inplace(module, &rhs.key, scratch);
116
117                // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a)
118                (0..cols_out).for_each(|i| {
119                    module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch);
120                });
121            });
122        });
123
124        (self.rows().min(lhs.rows()).into()..self.rows().into()).for_each(|row_i| {
125            (0..self.rank_in().into()).for_each(|col_j| {
126                self.at_mut(row_i, col_j).data.zero();
127            });
128        });
129
130        self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64);
131    }
132
133    pub fn automorphism_inplace<DataRhs: DataRef, B: Backend>(
134        &mut self,
135        module: &Module<B>,
136        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
137        scratch: &mut Scratch<B>,
138    ) where
139        Module<B>: VecZnxDftAllocBytes
140            + VmpApplyDftToDftTmpBytes
141            + VecZnxBigNormalizeTmpBytes
142            + VmpApplyDftToDft<B>
143            + VmpApplyDftToDftAdd<B>
144            + VecZnxDftApply<B>
145            + VecZnxIdftApplyConsume<B>
146            + VecZnxBigAddSmallInplace<B>
147            + VecZnxBigNormalize<B>
148            + VecZnxAutomorphism
149            + VecZnxAutomorphismInplace<B>
150            + VecZnxNormalize<B>
151            + VecZnxNormalizeTmpBytes,
152        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B> + TakeVecZnx,
153    {
154        #[cfg(debug_assertions)]
155        {
156            assert_eq!(
157                self.rank_out(),
158                rhs.rank_in(),
159                "ksk_in output rank: {} != ksk_apply input rank: {}",
160                self.rank_out(),
161                rhs.rank_in()
162            );
163            assert_eq!(
164                self.rank_out(),
165                rhs.rank_out(),
166                "ksk_out output rank: {} != ksk_apply output rank: {}",
167                self.rank_out(),
168                rhs.rank_out()
169            );
170        }
171
172        let cols_out: usize = (rhs.rank_out() + 1).into();
173
174        let p: i64 = self.p();
175        let p_inv = module.galois_element_inv(p);
176
177        (0..self.rank_in().into()).for_each(|col_i| {
178            (0..self.rows().into()).for_each(|row_j| {
179                let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i);
180
181                // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
182                (0..cols_out).for_each(|i| {
183                    module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch);
184                });
185
186                // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
187                res_ct.keyswitch_inplace(module, &rhs.key, scratch);
188
189                // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a)
190                (0..cols_out).for_each(|i| {
191                    module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch);
192                });
193            });
194        });
195
196        self.p = (self.p * rhs.p) % (module.cyclotomic_order() as i64);
197    }
198}