poulpy_core/automorphism/
ggsw_ct.rs1use poulpy_hal::{
2 api::{
3 ScratchAvailable, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
4 VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAllocBytes,
5 VecZnxDftApply, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes,
6 VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
7 },
8 layouts::{Backend, DataMut, DataRef, Module, Scratch},
9};
10
11use crate::layouts::{
12 GGLWELayoutInfos, GGSWCiphertext, GGSWInfos, GLWECiphertext,
13 prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared},
14};
15
16impl GGSWCiphertext<Vec<u8>> {
17 pub fn automorphism_scratch_space<B: Backend, OUT, IN, KEY, TSK>(
18 module: &Module<B>,
19 out_infos: &OUT,
20 in_infos: &IN,
21 key_infos: &KEY,
22 tsk_infos: &TSK,
23 ) -> usize
24 where
25 OUT: GGSWInfos,
26 IN: GGSWInfos,
27 KEY: GGLWELayoutInfos,
28 TSK: GGLWELayoutInfos,
29 Module<B>: VecZnxDftAllocBytes
30 + VmpApplyDftToDftTmpBytes
31 + VecZnxBigAllocBytes
32 + VecZnxNormalizeTmpBytes
33 + VecZnxBigNormalizeTmpBytes,
34 {
35 let out_size: usize = out_infos.size();
36 let ci_dft: usize = module.vec_znx_dft_alloc_bytes((key_infos.rank_out() + 1).into(), out_size);
37 let ks_internal: usize = GLWECiphertext::keyswitch_scratch_space(
38 module,
39 &out_infos.glwe_layout(),
40 &in_infos.glwe_layout(),
41 key_infos,
42 );
43 let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, out_infos, tsk_infos);
44 ci_dft + (ks_internal | expand)
45 }
46
47 pub fn automorphism_inplace_scratch_space<B: Backend, OUT, KEY, TSK>(
48 module: &Module<B>,
49 out_infos: &OUT,
50 key_infos: &KEY,
51 tsk_infos: &TSK,
52 ) -> usize
53 where
54 OUT: GGSWInfos,
55 KEY: GGLWELayoutInfos,
56 TSK: GGLWELayoutInfos,
57 Module<B>: VecZnxDftAllocBytes
58 + VmpApplyDftToDftTmpBytes
59 + VecZnxBigAllocBytes
60 + VecZnxNormalizeTmpBytes
61 + VecZnxBigNormalizeTmpBytes,
62 {
63 GGSWCiphertext::automorphism_scratch_space(module, out_infos, out_infos, key_infos, tsk_infos)
64 }
65}
66
67impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
68 pub fn automorphism<DataLhs: DataRef, DataAk: DataRef, DataTsk: DataRef, B: Backend>(
69 &mut self,
70 module: &Module<B>,
71 lhs: &GGSWCiphertext<DataLhs>,
72 auto_key: &GGLWEAutomorphismKeyPrepared<DataAk, B>,
73 tensor_key: &GGLWETensorKeyPrepared<DataTsk, B>,
74 scratch: &mut Scratch<B>,
75 ) where
76 Module<B>: VecZnxDftAllocBytes
77 + VmpApplyDftToDftTmpBytes
78 + VecZnxBigNormalizeTmpBytes
79 + VmpApplyDftToDft<B>
80 + VmpApplyDftToDftAdd<B>
81 + VecZnxDftApply<B>
82 + VecZnxIdftApplyConsume<B>
83 + VecZnxBigAddSmallInplace<B>
84 + VecZnxBigNormalize<B>
85 + VecZnxAutomorphismInplace<B>
86 + VecZnxBigAllocBytes
87 + VecZnxNormalizeTmpBytes
88 + VecZnxDftCopy<B>
89 + VecZnxDftAddInplace<B>
90 + VecZnxIdftApplyTmpA<B>
91 + VecZnxNormalize<B>,
92 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnxBig<B> + TakeVecZnx,
93 {
94 #[cfg(debug_assertions)]
95 {
96 use crate::layouts::{GLWEInfos, LWEInfos};
97
98 assert_eq!(self.n(), module.n() as u32);
99 assert_eq!(lhs.n(), module.n() as u32);
100 assert_eq!(auto_key.n(), module.n() as u32);
101 assert_eq!(tensor_key.n(), module.n() as u32);
102
103 assert_eq!(
104 self.rank(),
105 lhs.rank(),
106 "ggsw_out rank: {} != ggsw_in rank: {}",
107 self.rank(),
108 lhs.rank()
109 );
110 assert_eq!(
111 self.rank(),
112 auto_key.rank_out(),
113 "ggsw_in rank: {} != auto_key rank: {}",
114 self.rank(),
115 auto_key.rank_out()
116 );
117 assert_eq!(
118 self.rank(),
119 tensor_key.rank_out(),
120 "ggsw_in rank: {} != tensor_key rank: {}",
121 self.rank(),
122 tensor_key.rank_out()
123 );
124 assert!(scratch.available() >= GGSWCiphertext::automorphism_scratch_space(module, self, lhs, auto_key, tensor_key))
125 };
126
127 (0..lhs.rows().into()).for_each(|row_i| {
129 self.at_mut(row_i, 0)
132 .automorphism(module, &lhs.at(row_i, 0), auto_key, scratch);
133 });
134 self.expand_row(module, tensor_key, scratch);
135 }
136
137 pub fn automorphism_inplace<DataKsk: DataRef, DataTsk: DataRef, B: Backend>(
138 &mut self,
139 module: &Module<B>,
140 auto_key: &GGLWEAutomorphismKeyPrepared<DataKsk, B>,
141 tensor_key: &GGLWETensorKeyPrepared<DataTsk, B>,
142 scratch: &mut Scratch<B>,
143 ) where
144 Module<B>: VecZnxDftAllocBytes
145 + VmpApplyDftToDftTmpBytes
146 + VecZnxBigNormalizeTmpBytes
147 + VmpApplyDftToDft<B>
148 + VmpApplyDftToDftAdd<B>
149 + VecZnxDftApply<B>
150 + VecZnxIdftApplyConsume<B>
151 + VecZnxBigAddSmallInplace<B>
152 + VecZnxBigNormalize<B>
153 + VecZnxAutomorphismInplace<B>
154 + VecZnxBigAllocBytes
155 + VecZnxNormalizeTmpBytes
156 + VecZnxDftCopy<B>
157 + VecZnxDftAddInplace<B>
158 + VecZnxIdftApplyTmpA<B>
159 + VecZnxNormalize<B>,
160 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnxBig<B> + TakeVecZnx,
161 {
162 (0..self.rows().into()).for_each(|row_i| {
164 self.at_mut(row_i, 0)
167 .automorphism_inplace(module, auto_key, scratch);
168 });
169 self.expand_row(module, tensor_key, scratch);
170 }
171}