poulpy_core/automorphism/
gglwe_atk.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
4 VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume,
5 VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6 },
7 layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
8};
9
10use crate::layouts::{GGLWEAutomorphismKey, GLWECiphertext, Infos, prepared::GGLWEAutomorphismKeyPrepared};
11
12impl GGLWEAutomorphismKey<Vec<u8>> {
13 #[allow(clippy::too_many_arguments)]
14 pub fn automorphism_scratch_space<B: Backend>(
15 module: &Module<B>,
16 basek: usize,
17 k_out: usize,
18 k_in: usize,
19 k_ksk: usize,
20 digits: usize,
21 rank: usize,
22 ) -> usize
23 where
24 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
25 {
26 GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank)
27 }
28
29 pub fn automorphism_inplace_scratch_space<B: Backend>(
30 module: &Module<B>,
31 basek: usize,
32 k_out: usize,
33 k_ksk: usize,
34 digits: usize,
35 rank: usize,
36 ) -> usize
37 where
38 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
39 {
40 GGLWEAutomorphismKey::automorphism_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank)
41 }
42}
43
44impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
45 pub fn automorphism<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
46 &mut self,
47 module: &Module<B>,
48 lhs: &GGLWEAutomorphismKey<DataLhs>,
49 rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
50 scratch: &mut Scratch<B>,
51 ) where
52 Module<B>: VecZnxDftAllocBytes
53 + VmpApplyDftToDftTmpBytes
54 + VecZnxBigNormalizeTmpBytes
55 + VmpApplyDftToDft<B>
56 + VmpApplyDftToDftAdd<B>
57 + VecZnxDftApply<B>
58 + VecZnxIdftApplyConsume<B>
59 + VecZnxBigAddSmallInplace<B>
60 + VecZnxBigNormalize<B>
61 + VecZnxAutomorphism
62 + VecZnxAutomorphismInplace<B>,
63 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
64 {
65 #[cfg(debug_assertions)]
66 {
67 assert_eq!(
68 self.rank_in(),
69 lhs.rank_in(),
70 "ksk_out input rank: {} != ksk_in input rank: {}",
71 self.rank_in(),
72 lhs.rank_in()
73 );
74 assert_eq!(
75 self.rank_out(),
76 rhs.rank_in(),
77 "ksk_in output rank: {} != ksk_apply input rank: {}",
78 self.rank_out(),
79 rhs.rank_in()
80 );
81 assert_eq!(
82 self.rank_out(),
83 rhs.rank_out(),
84 "ksk_out output rank: {} != ksk_apply output rank: {}",
85 self.rank_out(),
86 rhs.rank_out()
87 );
88 assert!(
89 self.k() <= lhs.k(),
90 "output k={} cannot be greater than input k={}",
91 self.k(),
92 lhs.k()
93 )
94 }
95
96 let cols_out: usize = rhs.rank_out() + 1;
97
98 let p: i64 = lhs.p();
99 let p_inv = module.galois_element_inv(p);
100
101 (0..self.rank_in()).for_each(|col_i| {
102 (0..self.rows()).for_each(|row_j| {
103 let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i);
104 let lhs_ct: GLWECiphertext<&[u8]> = lhs.at(row_j, col_i);
105
106 (0..cols_out).for_each(|i| {
108 module.vec_znx_automorphism(lhs.p(), &mut res_ct.data, i, &lhs_ct.data, i);
109 });
110
111 res_ct.keyswitch_inplace(module, &rhs.key, scratch);
113
114 (0..cols_out).for_each(|i| {
116 module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch);
117 });
118 });
119 });
120
121 (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
122 (0..self.rank_in()).for_each(|col_j| {
123 self.at_mut(row_i, col_j).data.zero();
124 });
125 });
126
127 self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64);
128 }
129
130 pub fn automorphism_inplace<DataRhs: DataRef, B: Backend>(
131 &mut self,
132 module: &Module<B>,
133 rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
134 scratch: &mut Scratch<B>,
135 ) where
136 Module<B>: VecZnxDftAllocBytes
137 + VmpApplyDftToDftTmpBytes
138 + VecZnxBigNormalizeTmpBytes
139 + VmpApplyDftToDft<B>
140 + VmpApplyDftToDftAdd<B>
141 + VecZnxDftApply<B>
142 + VecZnxIdftApplyConsume<B>
143 + VecZnxBigAddSmallInplace<B>
144 + VecZnxBigNormalize<B>
145 + VecZnxAutomorphism
146 + VecZnxAutomorphismInplace<B>,
147 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
148 {
149 #[cfg(debug_assertions)]
150 {
151 assert_eq!(
152 self.rank_out(),
153 rhs.rank_in(),
154 "ksk_in output rank: {} != ksk_apply input rank: {}",
155 self.rank_out(),
156 rhs.rank_in()
157 );
158 assert_eq!(
159 self.rank_out(),
160 rhs.rank_out(),
161 "ksk_out output rank: {} != ksk_apply output rank: {}",
162 self.rank_out(),
163 rhs.rank_out()
164 );
165 }
166
167 let cols_out: usize = rhs.rank_out() + 1;
168
169 let p: i64 = self.p();
170 let p_inv = module.galois_element_inv(p);
171
172 (0..self.rank_in()).for_each(|col_i| {
173 (0..self.rows()).for_each(|row_j| {
174 let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i);
175
176 (0..cols_out).for_each(|i| {
178 module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch);
179 });
180
181 res_ct.keyswitch_inplace(module, &rhs.key, scratch);
183
184 (0..cols_out).for_each(|i| {
186 module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i, scratch);
187 });
188 });
189 });
190
191 self.p = (self.p * rhs.p) % (module.cyclotomic_order() as i64);
192 }
193}