poulpy_core/
glwe_trace.rs

1use std::collections::HashMap;
2
3use poulpy_hal::{
4    api::{
5        ScratchAvailable, TakeVecZnx, TakeVecZnxDft, VecZnxBigAddSmallInplace, VecZnxBigAutomorphismInplace, VecZnxBigNormalize,
6        VecZnxBigNormalizeTmpBytes, VecZnxCopy, VecZnxDftAllocBytes, VecZnxDftApply, VecZnxIdftApplyConsume, VecZnxNormalize,
7        VecZnxNormalizeTmpBytes, VecZnxRshInplace, VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
8    },
9    layouts::{Backend, DataMut, DataRef, Module, Scratch, VecZnx},
10};
11
12use crate::{
13    TakeGLWECt,
14    layouts::{
15        Base2K, GGLWELayoutInfos, GLWECiphertext, GLWECiphertextLayout, GLWEInfos, LWEInfos,
16        prepared::GGLWEAutomorphismKeyPrepared,
17    },
18    operations::GLWEOperations,
19};
20
21impl GLWECiphertext<Vec<u8>> {
22    pub fn trace_galois_elements<B: Backend>(module: &Module<B>) -> Vec<i64> {
23        let mut gal_els: Vec<i64> = Vec::new();
24        (0..module.log_n()).for_each(|i| {
25            if i == 0 {
26                gal_els.push(-1);
27            } else {
28                gal_els.push(module.galois_element(1 << (i - 1)));
29            }
30        });
31        gal_els
32    }
33
34    pub fn trace_scratch_space<B: Backend, OUT, IN, KEY>(
35        module: &Module<B>,
36        out_infos: &OUT,
37        in_infos: &IN,
38        key_infos: &KEY,
39    ) -> usize
40    where
41        OUT: GLWEInfos,
42        IN: GLWEInfos,
43        KEY: GGLWELayoutInfos,
44        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
45    {
46        let trace: usize = Self::automorphism_inplace_scratch_space(module, out_infos, key_infos);
47        if in_infos.base2k() != key_infos.base2k() {
48            let glwe_conv: usize = VecZnx::alloc_bytes(
49                module.n(),
50                (key_infos.rank_out() + 1).into(),
51                out_infos.k().min(in_infos.k()).div_ceil(key_infos.base2k()) as usize,
52            ) + module.vec_znx_normalize_tmp_bytes();
53            return glwe_conv + trace;
54        }
55
56        trace
57    }
58
59    pub fn trace_inplace_scratch_space<B: Backend, OUT, KEY>(module: &Module<B>, out_infos: &OUT, key_infos: &KEY) -> usize
60    where
61        OUT: GLWEInfos,
62        KEY: GGLWELayoutInfos,
63        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes + VecZnxNormalizeTmpBytes,
64    {
65        Self::trace_scratch_space(module, out_infos, out_infos, key_infos)
66    }
67}
68
69impl<DataSelf: DataMut> GLWECiphertext<DataSelf> {
70    pub fn trace<DataLhs: DataRef, DataAK: DataRef, B: Backend>(
71        &mut self,
72        module: &Module<B>,
73        start: usize,
74        end: usize,
75        lhs: &GLWECiphertext<DataLhs>,
76        auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<DataAK, B>>,
77        scratch: &mut Scratch<B>,
78    ) where
79        Module<B>: VecZnxDftAllocBytes
80            + VmpApplyDftToDftTmpBytes
81            + VecZnxBigNormalizeTmpBytes
82            + VmpApplyDftToDft<B>
83            + VmpApplyDftToDftAdd<B>
84            + VecZnxDftApply<B>
85            + VecZnxIdftApplyConsume<B>
86            + VecZnxBigAddSmallInplace<B>
87            + VecZnxBigNormalize<B>
88            + VecZnxBigAutomorphismInplace<B>
89            + VecZnxRshInplace<B>
90            + VecZnxCopy
91            + VecZnxNormalizeTmpBytes
92            + VecZnxNormalize<B>,
93        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
94    {
95        self.copy(module, lhs);
96        self.trace_inplace(module, start, end, auto_keys, scratch);
97    }
98
99    pub fn trace_inplace<DataAK: DataRef, B: Backend>(
100        &mut self,
101        module: &Module<B>,
102        start: usize,
103        end: usize,
104        auto_keys: &HashMap<i64, GGLWEAutomorphismKeyPrepared<DataAK, B>>,
105        scratch: &mut Scratch<B>,
106    ) where
107        Module<B>: VecZnxDftAllocBytes
108            + VmpApplyDftToDftTmpBytes
109            + VecZnxBigNormalizeTmpBytes
110            + VmpApplyDftToDft<B>
111            + VmpApplyDftToDftAdd<B>
112            + VecZnxDftApply<B>
113            + VecZnxIdftApplyConsume<B>
114            + VecZnxBigAddSmallInplace<B>
115            + VecZnxBigNormalize<B>
116            + VecZnxBigAutomorphismInplace<B>
117            + VecZnxRshInplace<B>
118            + VecZnxNormalizeTmpBytes
119            + VecZnxNormalize<B>,
120        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnx,
121    {
122        let basek_ksk: Base2K = auto_keys
123            .get(auto_keys.keys().next().unwrap())
124            .unwrap()
125            .base2k();
126
127        #[cfg(debug_assertions)]
128        {
129            assert_eq!(self.n(), module.n() as u32);
130            assert!(start < end);
131            assert!(end <= module.log_n());
132            for key in auto_keys.values() {
133                assert_eq!(key.n(), module.n() as u32);
134                assert_eq!(key.base2k(), basek_ksk);
135                assert_eq!(key.rank_in(), self.rank());
136                assert_eq!(key.rank_out(), self.rank());
137            }
138        }
139
140        if self.base2k() != basek_ksk {
141            let (mut self_conv, scratch_1) = scratch.take_glwe_ct(&GLWECiphertextLayout {
142                n: module.n().into(),
143                base2k: basek_ksk,
144                k: self.k(),
145                rank: self.rank(),
146            });
147
148            for j in 0..(self.rank() + 1).into() {
149                module.vec_znx_normalize(
150                    basek_ksk.into(),
151                    &mut self_conv.data,
152                    j,
153                    basek_ksk.into(),
154                    &self.data,
155                    j,
156                    scratch_1,
157                );
158            }
159
160            for i in start..end {
161                self_conv.rsh(module, 1, scratch_1);
162
163                let p: i64 = if i == 0 {
164                    -1
165                } else {
166                    module.galois_element(1 << (i - 1))
167                };
168
169                if let Some(key) = auto_keys.get(&p) {
170                    self_conv.automorphism_add_inplace(module, key, scratch_1);
171                } else {
172                    panic!("auto_keys[{p}] is empty")
173                }
174            }
175
176            for j in 0..(self.rank() + 1).into() {
177                module.vec_znx_normalize(
178                    self.base2k().into(),
179                    &mut self.data,
180                    j,
181                    basek_ksk.into(),
182                    &self_conv.data,
183                    j,
184                    scratch_1,
185                );
186            }
187        } else {
188            for i in start..end {
189                self.rsh(module, 1, scratch);
190
191                let p: i64 = if i == 0 {
192                    -1
193                } else {
194                    module.galois_element(1 << (i - 1))
195                };
196
197                if let Some(key) = auto_keys.get(&p) {
198                    self.automorphism_add_inplace(module, key, scratch);
199                } else {
200                    panic!("auto_keys[{p}] is empty")
201                }
202            }
203        }
204    }
205}