poulpy_core/automorphism/
ggsw_ct.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes,
4 VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftCopy,
5 VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpA, VecZnxNormalizeTmpBytes, VmpApply,
6 VmpApplyAdd, VmpApplyTmpBytes,
7 },
8 layouts::{Backend, DataMut, DataRef, Module, Scratch},
9};
10
11use crate::layouts::{
12 GGSWCiphertext, GLWECiphertext, Infos,
13 prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared},
14};
15
16impl GGSWCiphertext<Vec<u8>> {
17 #[allow(clippy::too_many_arguments)]
18 pub fn automorphism_scratch_space<B: Backend>(
19 module: &Module<B>,
20 n: usize,
21 basek: usize,
22 k_out: usize,
23 k_in: usize,
24 k_ksk: usize,
25 digits_ksk: usize,
26 k_tsk: usize,
27 digits_tsk: usize,
28 rank: usize,
29 ) -> usize
30 where
31 Module<B>:
32 VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes,
33 {
34 let out_size: usize = k_out.div_ceil(basek);
35 let ci_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, out_size);
36 let ks_internal: usize =
37 GLWECiphertext::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank);
38 let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, n, basek, k_out, k_tsk, digits_tsk, rank);
39 ci_dft + (ks_internal | expand)
40 }
41
42 #[allow(clippy::too_many_arguments)]
43 pub fn automorphism_inplace_scratch_space<B: Backend>(
44 module: &Module<B>,
45 n: usize,
46 basek: usize,
47 k_out: usize,
48 k_ksk: usize,
49 digits_ksk: usize,
50 k_tsk: usize,
51 digits_tsk: usize,
52 rank: usize,
53 ) -> usize
54 where
55 Module<B>:
56 VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes,
57 {
58 GGSWCiphertext::automorphism_scratch_space(
59 module, n, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank,
60 )
61 }
62}
63
64impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
65 pub fn automorphism<DataLhs: DataRef, DataAk: DataRef, DataTsk: DataRef, B: Backend>(
66 &mut self,
67 module: &Module<B>,
68 lhs: &GGSWCiphertext<DataLhs>,
69 auto_key: &GGLWEAutomorphismKeyPrepared<DataAk, B>,
70 tensor_key: &GGLWETensorKeyPrepared<DataTsk, B>,
71 scratch: &mut Scratch<B>,
72 ) where
73 Module<B>: VecZnxDftAllocBytes
74 + VmpApplyTmpBytes
75 + VecZnxBigNormalizeTmpBytes
76 + VmpApply<B>
77 + VmpApplyAdd<B>
78 + VecZnxDftFromVecZnx<B>
79 + VecZnxDftToVecZnxBigConsume<B>
80 + VecZnxBigAddSmallInplace<B>
81 + VecZnxBigNormalize<B>
82 + VecZnxAutomorphismInplace
83 + VecZnxBigAllocBytes
84 + VecZnxNormalizeTmpBytes
85 + VecZnxDftCopy<B>
86 + VecZnxDftAddInplace<B>
87 + VecZnxDftToVecZnxBigTmpA<B>,
88 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnxBig<B>,
89 {
90 #[cfg(debug_assertions)]
91 {
92 assert_eq!(self.n(), auto_key.n());
93 assert_eq!(lhs.n(), auto_key.n());
94
95 assert_eq!(
96 self.rank(),
97 lhs.rank(),
98 "ggsw_out rank: {} != ggsw_in rank: {}",
99 self.rank(),
100 lhs.rank()
101 );
102 assert_eq!(
103 self.rank(),
104 auto_key.rank(),
105 "ggsw_in rank: {} != auto_key rank: {}",
106 self.rank(),
107 auto_key.rank()
108 );
109 assert_eq!(
110 self.rank(),
111 tensor_key.rank(),
112 "ggsw_in rank: {} != tensor_key rank: {}",
113 self.rank(),
114 tensor_key.rank()
115 );
116 assert!(
117 scratch.available()
118 >= GGSWCiphertext::automorphism_scratch_space(
119 module,
120 self.n(),
121 self.basek(),
122 self.k(),
123 lhs.k(),
124 auto_key.k(),
125 auto_key.digits(),
126 tensor_key.k(),
127 tensor_key.digits(),
128 self.rank(),
129 )
130 )
131 };
132
133 self.automorphism_internal(module, lhs, auto_key, scratch);
134 self.expand_row(module, tensor_key, scratch);
135 }
136
137 pub fn automorphism_inplace<DataKsk: DataRef, DataTsk: DataRef, B: Backend>(
138 &mut self,
139 module: &Module<B>,
140 auto_key: &GGLWEAutomorphismKeyPrepared<DataKsk, B>,
141 tensor_key: &GGLWETensorKeyPrepared<DataTsk, B>,
142 scratch: &mut Scratch<B>,
143 ) where
144 Module<B>: VecZnxDftAllocBytes
145 + VmpApplyTmpBytes
146 + VecZnxBigNormalizeTmpBytes
147 + VmpApply<B>
148 + VmpApplyAdd<B>
149 + VecZnxDftFromVecZnx<B>
150 + VecZnxDftToVecZnxBigConsume<B>
151 + VecZnxBigAddSmallInplace<B>
152 + VecZnxBigNormalize<B>
153 + VecZnxAutomorphismInplace
154 + VecZnxBigAllocBytes
155 + VecZnxNormalizeTmpBytes
156 + VecZnxDftCopy<B>
157 + VecZnxDftAddInplace<B>
158 + VecZnxDftToVecZnxBigTmpA<B>,
159 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnxBig<B>,
160 {
161 unsafe {
162 let self_ptr: *mut GGSWCiphertext<DataSelf> = self as *mut GGSWCiphertext<DataSelf>;
163 self.automorphism(module, &*self_ptr, auto_key, tensor_key, scratch);
164 }
165 }
166
167 fn automorphism_internal<DataLhs: DataRef, DataAk: DataRef, B: Backend>(
168 &mut self,
169 module: &Module<B>,
170 lhs: &GGSWCiphertext<DataLhs>,
171 auto_key: &GGLWEAutomorphismKeyPrepared<DataAk, B>,
172 scratch: &mut Scratch<B>,
173 ) where
174 Module<B>: VecZnxDftAllocBytes
175 + VmpApplyTmpBytes
176 + VecZnxBigNormalizeTmpBytes
177 + VmpApply<B>
178 + VmpApplyAdd<B>
179 + VecZnxDftFromVecZnx<B>
180 + VecZnxDftToVecZnxBigConsume<B>
181 + VecZnxBigAddSmallInplace<B>
182 + VecZnxBigNormalize<B>
183 + VecZnxAutomorphismInplace,
184 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
185 {
186 (0..lhs.rows()).for_each(|row_i| {
188 self.at_mut(row_i, 0)
191 .automorphism(module, &lhs.at(row_i, 0), auto_key, scratch);
192 });
193 }
194}