1use std::collections::HashMap;
2
3use poulpy_hal::{
4 api::{
5 ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize,
6 VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize,
7 VecZnxNormalizeTmpBytes, VecZnxRshInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
8 },
9 layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx},
10};
11
12use crate::{
13 TakeGLWECt,
14 layouts::{
15 Base2K, GGLWELayoutInfos, GLWECiphertext, GLWECiphertextLayout, GLWEInfos, LWEInfos,
16 prepared::GGLWEAutomorphismKeyPrepared,
17 },
18 operations::GLWEOperations,
19};
20
21impl GLWECiphertext<Vec<u8>> {
22 pub fn trace_galois_elements<B: Backend>(module: &Module<B>) -> Vec<i64> {
23 let mut gal_els: Vec<i64> = Vec::new();
24 (0..module.log_n()).for_each(|i| {
25 if i == 0 {
26 gal_els.push(-1);
27 } else {
28 gal_els.push(module.galois_element(1 << (i - 1)));
29 }
30 });
31 gal_els
32 }
33
34 pub fn trace_scratch_space<B: Backend, OUT, IN, KEY>(
35 module: &Module<B>,
36 out_infos: &OUT,
37 in_infos: &IN,
38 key_infos: &KEY,
39 ) -> usize
40 where
41 OUT: GLWEInfos,
42 IN: GLWEInfos,
43 KEY: GGLWELayoutInfos,
44 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
45 {
46 let trace: usize = Self::automorphism_inplace_scratch_space(module, out_infos, key_infos);
47 if in_infos.base2k() != key_infos.base2k() {
48 let glwe_conv: usize = VecZnx::alloc_bytes(
49 module.n(),
50 (key_infos.rank_out() + 1).into(),
51 out_infos.k().min(in_infos.k()).div_ceil(key_infos.base2k()) as usize,
52 ) + module.vec_znx_normalize_tmp_bytes();
53 return glwe_conv + trace;
54 }
55
56 trace
57 }
58
59 pub fn trace_inplace_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
60 where
61 OUT: GLWEInfos,
62 KEY: GGLWELayoutInfos,
63 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
64 {
65 Self::trace_scratch_space(module, out_infos, out_infos, key_infos)
66 }
67}
68
69impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
70 pub fn trace<DataLhs: DataRef, DataAK: DataRef, B: Backend>(
71 &mut self,
72 module: &Module<B>,
73 start: usize,
74 end: usize,
75 lhs: &GLWECiphertext<DataLhs>,
76 auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<DataAK, B>>,
77 scratch: &mut Scratch<B>,
78 ) where
79 Module<B>: VecZnxDftAllocBytes
80 + VmpApplyDftToDftTmpBytes
81 + VecZnxBigNormalizeTmpBytes
82 + VmpApplyDftToDft<B>
83 + VmpApplyDftToDftAdd<B>
84 + VecZnxDftApply<B>
85 + VecZnxIdftApplyConsume<B>
86 + VecZnxBigAddSmallInplace<B>
87 + VecZnxBigNormalize<B>
88 + VecZnxBigAutomorphismInplace<B>
89 + VecZnxRshInplace<B>
90 + VecZnxCopy
91 + VecZnxNormalizeTmpBytes
92 + VecZnxNormalize<B>,
93 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
94 {
95 self.copy(module, lhs);
96 self.trace_inplace(module, start, end, auto_keys, scratch);
97 }
98
99 pub fn trace_inplace<DataAK: DataRef, B: Backend>(
100 &mut self,
101 module: &Module<B>,
102 start: usize,
103 end: usize,
104 auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<DataAK, B>>,
105 scratch: &mut Scratch<B>,
106 ) where
107 Module<B>: VecZnxDftAllocBytes
108 + VmpApplyDftToDftTmpBytes
109 + VecZnxBigNormalizeTmpBytes
110 + VmpApplyDftToDft<B>
111 + VmpApplyDftToDftAdd<B>
112 + VecZnxDftApply<B>
113 + VecZnxIdftApplyConsume<B>
114 + VecZnxBigAddSmallInplace<B>
115 + VecZnxBigNormalize<B>
116 + VecZnxBigAutomorphismInplace<B>
117 + VecZnxRshInplace<B>
118 + VecZnxNormalizeTmpBytes
119 + VecZnxNormalize<B>,
120 Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
121 {
122 let basek_ksk: Base2K = auto_keys
123 .get(auto_keys.keys().next().unwrap())
124 .unwrap()
125 .base2k();
126
127 #[cfg(debug_assertions)]
128 {
129 assert_eq!(self.n(), module.n() as u32);
130 assert!(start < end);
131 assert!(end <= module.log_n());
132 for key in auto_keys.values() {
133 assert_eq!(key.n(), module.n() as u32);
134 assert_eq!(key.base2k(), basek_ksk);
135 assert_eq!(key.rank_in(), self.rank());
136 assert_eq!(key.rank_out(), self.rank());
137 }
138 }
139
140 if self.base2k() != basek_ksk {
141 let (mut self_conv, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout {
142 n: module.n().into(),
143 base2k: basek_ksk,
144 k: self.k(),
145 rank: self.rank(),
146 });
147
148 for j in 0..(self.rank() + 1).into() {
149 module.vec_znx_normalize(
150 basek_ksk.into(),
151 &mut self_conv.data,
152 j,
153 basek_ksk.into(),
154 &self.data,
155 j,
156 scratch_1,
157 );
158 }
159
160 for i in start..end {
161 self_conv.rsh(module, 1, scratch_1);
162
163 let p: i64 = if i == 0 {
164 -1
165 } else {
166 module.galois_element(1 << (i - 1))
167 };
168
169 if let Some(key) = auto_keys.get(&p) {
170 self_conv.automorphism_add_inplace(module, key, scratch_1);
171 } else {
172 panic!("auto_keys[{p}] is empty")
173 }
174 }
175
176 for j in 0..(self.rank() + 1).into() {
177 module.vec_znx_normalize(
178 self.base2k().into(),
179 &mut self.data,
180 j,
181 basek_ksk.into(),
182 &self_conv.data,
183 j,
184 scratch_1,
185 );
186 }
187 } else {
188 for i in start..end {
189 self.rsh(module, 1, scratch);
190
191 let p: i64 = if i == 0 {
192 -1
193 } else {
194 module.galois_element(1 << (i - 1))
195 };
196
197 if let Some(key) = auto_keys.get(&p) {
198 self.automorphism_add_inplace(module, key, scratch);
199 } else {
200 panic!("auto_keys[{p}] is empty")
201 }
202 }
203 }
204 }
205}