poulpy_core/
glwe_trace.rs

1use std::collections::HashMap;
2
3use poulpy_hal::{
4    api::ModuleLogN,
5    layouts::{Backend, DataMut, GaloisElement, Module, Scratch, VecZnx, galois_element},
6};
7
8use crate::{
9    GLWEAutomorphism, GLWECopy, GLWEShift, ScratchTakeCore,
10    layouts::{
11        Base2K, GGLWEInfos, GGLWEPreparedToRef, GLWE, GLWEInfos, GLWELayout, GLWEToMut, GLWEToRef, GetGaloisElement, LWEInfos,
12    },
13};
14
15impl GLWE<Vec<u8>> {
16    pub fn trace_galois_elements<M, BE: Backend>(module: &M) -> Vec<i64>
17    where
18        M: GLWETrace<BE>,
19    {
20        module.glwe_trace_galois_elements()
21    }
22
23    pub fn trace_tmp_bytes<R, A, K, M, BE: Backend>(module: &M, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
24    where
25        R: GLWEInfos,
26        A: GLWEInfos,
27        K: GGLWEInfos,
28        M: GLWETrace<BE>,
29    {
30        module.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos)
31    }
32}
33
34impl<D: DataMut> GLWE<D> {
35    pub fn trace<A, K, M, BE: Backend>(
36        &mut self,
37        module: &M,
38        start: usize,
39        end: usize,
40        a: &A,
41        keys: &HashMap<i64, K>,
42        scratch: &mut Scratch<BE>,
43    ) where
44        A: GLWEToRef,
45        K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
46        Scratch<BE>: ScratchTakeCore<BE>,
47        M: GLWETrace<BE>,
48    {
49        module.glwe_trace(self, start, end, a, keys, scratch);
50    }
51
52    pub fn trace_inplace<K, M, BE: Backend>(
53        &mut self,
54        module: &M,
55        start: usize,
56        end: usize,
57        keys: &HashMap<i64, K>,
58        scratch: &mut Scratch<BE>,
59    ) where
60        K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
61        Scratch<BE>: ScratchTakeCore<BE>,
62        M: GLWETrace<BE>,
63    {
64        module.glwe_trace_inplace(self, start, end, keys, scratch);
65    }
66}
67
68impl<BE: Backend> GLWETrace<BE> for Module<BE> where
69    Self: ModuleLogN + GaloisElement + GLWEAutomorphism<BE> + GLWEShift<BE> + GLWECopy
70{
71}
72
73#[inline(always)]
74pub fn trace_galois_elements(log_n: usize, cyclotomic_order: i64) -> Vec<i64> {
75    (0..log_n)
76        .map(|i| {
77            if i == 0 {
78                -1
79            } else {
80                galois_element(1 << (i - 1), cyclotomic_order)
81            }
82        })
83        .collect()
84}
85
86pub trait GLWETrace<BE: Backend>
87where
88    Self: ModuleLogN + GaloisElement + GLWEAutomorphism<BE> + GLWEShift<BE> + GLWECopy,
89{
90    fn glwe_trace_galois_elements(&self) -> Vec<i64> {
91        trace_galois_elements(self.log_n(), self.cyclotomic_order())
92    }
93
94    fn glwe_trace_tmp_bytes<R, A, K>(&self, res_infos: &R, a_infos: &A, key_infos: &K) -> usize
95    where
96        R: GLWEInfos,
97        A: GLWEInfos,
98        K: GGLWEInfos,
99    {
100        let trace: usize = self.glwe_automorphism_tmp_bytes(res_infos, a_infos, key_infos);
101        if a_infos.base2k() != key_infos.base2k() {
102            let glwe_conv: usize = VecZnx::bytes_of(
103                self.n(),
104                (key_infos.rank_out() + 1).into(),
105                res_infos.k().min(a_infos.k()).div_ceil(key_infos.base2k()) as usize,
106            ) + self.vec_znx_normalize_tmp_bytes();
107            return glwe_conv + trace;
108        }
109
110        trace
111    }
112
113    fn glwe_trace<R, A, K>(&self, res: &mut R, start: usize, end: usize, a: &A, keys: &HashMap<i64, K>, scratch: &mut Scratch<BE>)
114    where
115        R: GLWEToMut,
116        A: GLWEToRef,
117        K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
118        Scratch<BE>: ScratchTakeCore<BE>,
119    {
120        self.glwe_copy(res, a);
121        self.glwe_trace_inplace(res, start, end, keys, scratch);
122    }
123
124    fn glwe_trace_inplace<R, K>(&self, res: &mut R, start: usize, end: usize, keys: &HashMap<i64, K>, scratch: &mut Scratch<BE>)
125    where
126        R: GLWEToMut,
127        K: GGLWEPreparedToRef<BE> + GetGaloisElement + GGLWEInfos,
128        Scratch<BE>: ScratchTakeCore<BE>,
129    {
130        let res: &mut GLWE<&mut [u8]> = &mut res.to_mut();
131
132        let basek_ksk: Base2K = keys.get(keys.keys().next().unwrap()).unwrap().base2k();
133
134        #[cfg(debug_assertions)]
135        {
136            assert_eq!(res.n(), self.n() as u32);
137            assert!(start < end);
138            assert!(end <= self.log_n());
139            for key in keys.values() {
140                assert_eq!(key.n(), self.n() as u32);
141                assert_eq!(key.base2k(), basek_ksk);
142                assert_eq!(key.rank_in(), res.rank());
143                assert_eq!(key.rank_out(), res.rank());
144            }
145        }
146
147        if res.base2k() != basek_ksk {
148            let (mut self_conv, scratch_1) = scratch.take_glwe(&GLWELayout {
149                n: self.n().into(),
150                base2k: basek_ksk,
151                k: res.k(),
152                rank: res.rank(),
153            });
154
155            for j in 0..(res.rank() + 1).into() {
156                self.vec_znx_normalize(
157                    basek_ksk.into(),
158                    &mut self_conv.data,
159                    j,
160                    basek_ksk.into(),
161                    res.data(),
162                    j,
163                    scratch_1,
164                );
165            }
166
167            for i in start..end {
168                self.glwe_rsh(1, &mut self_conv, scratch_1);
169
170                let p: i64 = if i == 0 {
171                    -1
172                } else {
173                    self.galois_element(1 << (i - 1))
174                };
175
176                if let Some(key) = keys.get(&p) {
177                    self.glwe_automorphism_add_inplace(&mut self_conv, key, scratch_1);
178                } else {
179                    panic!("keys[{p}] is empty")
180                }
181            }
182
183            for j in 0..(res.rank() + 1).into() {
184                self.vec_znx_normalize(
185                    res.base2k().into(),
186                    res.data_mut(),
187                    j,
188                    basek_ksk.into(),
189                    &self_conv.data,
190                    j,
191                    scratch_1,
192                );
193            }
194        } else {
195            // println!("res: {}", res);
196
197            for i in start..end {
198                self.glwe_rsh(1, res, scratch);
199
200                let p: i64 = if i == 0 {
201                    -1
202                } else {
203                    self.galois_element(1 << (i - 1))
204                };
205
206                if let Some(key) = keys.get(&p) {
207                    self.glwe_automorphism_add_inplace(res, key, scratch);
208                } else {
209                    panic!("keys[{p}] is empty")
210                }
211            }
212        }
213    }
214}