poulpy_core/automorphism/
ggsw_ct.rs1use poulpy_hal::{
2 api::{
3 DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace,
4 VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace,
5 VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
6 VmpApplyDftToDftTmpBytes,
7 },
8 layouts::{Backend, DataMut, DataRef, Module, Scratch},
9};
10
11use crate::layouts::{
12 GGSWCiphertext, GLWECiphertext, Infos,
13 prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared},
14};
15
16impl GGSWCiphertext<Vec<u8>> {
17 #[allow(clippy::too_many_arguments)]
18 pub fn automorphism_scratch_space<B: Backend>(
19 module: &Module<B>,
20 basek: usize,
21 k_out: usize,
22 k_in: usize,
23 k_ksk: usize,
24 digits_ksk: usize,
25 k_tsk: usize,
26 digits_tsk: usize,
27 rank: usize,
28 ) -> usize
29 where
30 Module<B>: VecZnxDftAllocBytes
31 + VmpApplyDftToDftTmpBytes
32 + VecZnxBigAllocBytes
33 + VecZnxNormalizeTmpBytes
34 + VecZnxBigNormalizeTmpBytes,
35 {
36 let out_size: usize = k_out.div_ceil(basek);
37 let ci_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, out_size);
38 let ks_internal: usize =
39 GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank);
40 let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank);
41 ci_dft + (ks_internal | expand)
42 }
43
44 #[allow(clippy::too_many_arguments)]
45 pub fn automorphism_inplace_scratch_space<B: Backend>(
46 module: &Module<B>,
47 basek: usize,
48 k_out: usize,
49 k_ksk: usize,
50 digits_ksk: usize,
51 k_tsk: usize,
52 digits_tsk: usize,
53 rank: usize,
54 ) -> usize
55 where
56 Module<B>: VecZnxDftAllocBytes
57 + VmpApplyDftToDftTmpBytes
58 + VecZnxBigAllocBytes
59 + VecZnxNormalizeTmpBytes
60 + VecZnxBigNormalizeTmpBytes,
61 {
62 GGSWCiphertext::automorphism_scratch_space(
63 module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank,
64 )
65 }
66}
67
68impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
69 pub fn automorphism<DataLhs: DataRef, DataAk: DataRef, DataTsk: DataRef, B: Backend>(
70 &mut self,
71 module: &Module<B>,
72 lhs: &GGSWCiphertext<DataLhs>,
73 auto_key: &GGLWEAutomorphismKeyPrepared<DataAk, B>,
74 tensor_key: &GGLWETensorKeyPrepared<DataTsk, B>,
75 scratch: &mut Scratch<B>,
76 ) where
77 Module<B>: VecZnxDftAllocBytes
78 + VmpApplyDftToDftTmpBytes
79 + VecZnxBigNormalizeTmpBytes
80 + VmpApplyDftToDft<B>
81 + VmpApplyDftToDftAdd<B>
82 + DFT<B>
83 + IDFTConsume<B>
84 + VecZnxBigAddSmallInplace<B>
85 + VecZnxBigNormalize<B>
86 + VecZnxAutomorphismInplace
87 + VecZnxBigAllocBytes
88 + VecZnxNormalizeTmpBytes
89 + VecZnxDftCopy<B>
90 + VecZnxDftAddInplace<B>
91 + IDFTTmpA<B>,
92 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnxBig<B>,
93 {
94 #[cfg(debug_assertions)]
95 {
96 assert_eq!(self.n(), auto_key.n());
97 assert_eq!(lhs.n(), auto_key.n());
98
99 assert_eq!(
100 self.rank(),
101 lhs.rank(),
102 "ggsw_out rank: {} != ggsw_in rank: {}",
103 self.rank(),
104 lhs.rank()
105 );
106 assert_eq!(
107 self.rank(),
108 auto_key.rank(),
109 "ggsw_in rank: {} != auto_key rank: {}",
110 self.rank(),
111 auto_key.rank()
112 );
113 assert_eq!(
114 self.rank(),
115 tensor_key.rank(),
116 "ggsw_in rank: {} != tensor_key rank: {}",
117 self.rank(),
118 tensor_key.rank()
119 );
120 assert!(
121 scratch.available()
122 >= GGSWCiphertext::automorphism_scratch_space(
123 module,
124 self.basek(),
125 self.k(),
126 lhs.k(),
127 auto_key.k(),
128 auto_key.digits(),
129 tensor_key.k(),
130 tensor_key.digits(),
131 self.rank(),
132 )
133 )
134 };
135
136 self.automorphism_internal(module, lhs, auto_key, scratch);
137 self.expand_row(module, tensor_key, scratch);
138 }
139
140 pub fn automorphism_inplace<DataKsk: DataRef, DataTsk: DataRef, B: Backend>(
141 &mut self,
142 module: &Module<B>,
143 auto_key: &GGLWEAutomorphismKeyPrepared<DataKsk, B>,
144 tensor_key: &GGLWETensorKeyPrepared<DataTsk, B>,
145 scratch: &mut Scratch<B>,
146 ) where
147 Module<B>: VecZnxDftAllocBytes
148 + VmpApplyDftToDftTmpBytes
149 + VecZnxBigNormalizeTmpBytes
150 + VmpApplyDftToDft<B>
151 + VmpApplyDftToDftAdd<B>
152 + DFT<B>
153 + IDFTConsume<B>
154 + VecZnxBigAddSmallInplace<B>
155 + VecZnxBigNormalize<B>
156 + VecZnxAutomorphismInplace
157 + VecZnxBigAllocBytes
158 + VecZnxNormalizeTmpBytes
159 + VecZnxDftCopy<B>
160 + VecZnxDftAddInplace<B>
161 + IDFTTmpA<B>,
162 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnxBig<B>,
163 {
164 unsafe {
165 let self_ptr: *mut GGSWCiphertext<DataSelf> = self as *mut GGSWCiphertext<DataSelf>;
166 self.automorphism(module, &*self_ptr, auto_key, tensor_key, scratch);
167 }
168 }
169
170 fn automorphism_internal<DataLhs: DataRef, DataAk: DataRef, B: Backend>(
171 &mut self,
172 module: &Module<B>,
173 lhs: &GGSWCiphertext<DataLhs>,
174 auto_key: &GGLWEAutomorphismKeyPrepared<DataAk, B>,
175 scratch: &mut Scratch<B>,
176 ) where
177 Module<B>: VecZnxDftAllocBytes
178 + VmpApplyDftToDftTmpBytes
179 + VecZnxBigNormalizeTmpBytes
180 + VmpApplyDftToDft<B>
181 + VmpApplyDftToDftAdd<B>
182 + DFT<B>
183 + IDFTConsume<B>
184 + VecZnxBigAddSmallInplace<B>
185 + VecZnxBigNormalize<B>
186 + VecZnxAutomorphismInplace,
187 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
188 {
189 (0..lhs.rows()).for_each(|row_i| {
191 self.at_mut(row_i, 0)
194 .automorphism(module, &lhs.at(row_i, 0), auto_key, scratch);
195 });
196 }
197}