// poulpy_core/automorphism/ggsw_ct.rs

1use poulpy_hal::{
2    api::{
3        DFT, IDFTConsume, IDFTTmpA, ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace,
4        VecZnxBigAddSmallInplace, VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace,
5        VecZnxDftAllocBytes, VecZnxDftCopy, VecZnxNormalizeTmpBytes, VmpApplyDftToDft, VmpApplyDftToDftAdd,
6        VmpApplyDftToDftTmpBytes,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, Scratch},
9};
10
11use crate::layouts::{
12    GGSWCiphertext, GLWECiphertext, Infos,
13    prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared},
14};
15
16impl GGSWCiphertext<Vec<u8>> {
17    #[allow(clippy::too_many_arguments)]
18    pub fn automorphism_scratch_space<B: Backend>(
19        module: &Module<B>,
20        basek: usize,
21        k_out: usize,
22        k_in: usize,
23        k_ksk: usize,
24        digits_ksk: usize,
25        k_tsk: usize,
26        digits_tsk: usize,
27        rank: usize,
28    ) -> usize
29    where
30        Module<B>: VecZnxDftAllocBytes
31            + VmpApplyDftToDftTmpBytes
32            + VecZnxBigAllocBytes
33            + VecZnxNormalizeTmpBytes
34            + VecZnxBigNormalizeTmpBytes,
35    {
36        let out_size: usize = k_out.div_ceil(basek);
37        let ci_dft: usize = module.vec_znx_dft_alloc_bytes(rank + 1, out_size);
38        let ks_internal: usize =
39            GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank);
40        let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, basek, k_out, k_tsk, digits_tsk, rank);
41        ci_dft + (ks_internal | expand)
42    }
43
44    #[allow(clippy::too_many_arguments)]
45    pub fn automorphism_inplace_scratch_space<B: Backend>(
46        module: &Module<B>,
47        basek: usize,
48        k_out: usize,
49        k_ksk: usize,
50        digits_ksk: usize,
51        k_tsk: usize,
52        digits_tsk: usize,
53        rank: usize,
54    ) -> usize
55    where
56        Module<B>: VecZnxDftAllocBytes
57            + VmpApplyDftToDftTmpBytes
58            + VecZnxBigAllocBytes
59            + VecZnxNormalizeTmpBytes
60            + VecZnxBigNormalizeTmpBytes,
61    {
62        GGSWCiphertext::automorphism_scratch_space(
63            module, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank,
64        )
65    }
66}
67
68impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
69    pub fn automorphism<DataLhs: DataRef, DataAk: DataRef, DataTsk: DataRef, B: Backend>(
70        &mut self,
71        module: &Module<B>,
72        lhs: &GGSWCiphertext<DataLhs>,
73        auto_key: &GGLWEAutomorphismKeyPrepared<DataAk, B>,
74        tensor_key: &GGLWETensorKeyPrepared<DataTsk, B>,
75        scratch: &mut Scratch<B>,
76    ) where
77        Module<B>: VecZnxDftAllocBytes
78            + VmpApplyDftToDftTmpBytes
79            + VecZnxBigNormalizeTmpBytes
80            + VmpApplyDftToDft<B>
81            + VmpApplyDftToDftAdd<B>
82            + DFT<B>
83            + IDFTConsume<B>
84            + VecZnxBigAddSmallInplace<B>
85            + VecZnxBigNormalize<B>
86            + VecZnxAutomorphismInplace
87            + VecZnxBigAllocBytes
88            + VecZnxNormalizeTmpBytes
89            + VecZnxDftCopy<B>
90            + VecZnxDftAddInplace<B>
91            + IDFTTmpA<B>,
92        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnxBig<B>,
93    {
94        #[cfg(debug_assertions)]
95        {
96            assert_eq!(self.n(), auto_key.n());
97            assert_eq!(lhs.n(), auto_key.n());
98
99            assert_eq!(
100                self.rank(),
101                lhs.rank(),
102                "ggsw_out rank: {} != ggsw_in rank: {}",
103                self.rank(),
104                lhs.rank()
105            );
106            assert_eq!(
107                self.rank(),
108                auto_key.rank(),
109                "ggsw_in rank: {} != auto_key rank: {}",
110                self.rank(),
111                auto_key.rank()
112            );
113            assert_eq!(
114                self.rank(),
115                tensor_key.rank(),
116                "ggsw_in rank: {} != tensor_key rank: {}",
117                self.rank(),
118                tensor_key.rank()
119            );
120            assert!(
121                scratch.available()
122                    >= GGSWCiphertext::automorphism_scratch_space(
123                        module,
124                        self.basek(),
125                        self.k(),
126                        lhs.k(),
127                        auto_key.k(),
128                        auto_key.digits(),
129                        tensor_key.k(),
130                        tensor_key.digits(),
131                        self.rank(),
132                    )
133            )
134        };
135
136        self.automorphism_internal(module, lhs, auto_key, scratch);
137        self.expand_row(module, tensor_key, scratch);
138    }
139
140    pub fn automorphism_inplace<DataKsk: DataRef, DataTsk: DataRef, B: Backend>(
141        &mut self,
142        module: &Module<B>,
143        auto_key: &GGLWEAutomorphismKeyPrepared<DataKsk, B>,
144        tensor_key: &GGLWETensorKeyPrepared<DataTsk, B>,
145        scratch: &mut Scratch<B>,
146    ) where
147        Module<B>: VecZnxDftAllocBytes
148            + VmpApplyDftToDftTmpBytes
149            + VecZnxBigNormalizeTmpBytes
150            + VmpApplyDftToDft<B>
151            + VmpApplyDftToDftAdd<B>
152            + DFT<B>
153            + IDFTConsume<B>
154            + VecZnxBigAddSmallInplace<B>
155            + VecZnxBigNormalize<B>
156            + VecZnxAutomorphismInplace
157            + VecZnxBigAllocBytes
158            + VecZnxNormalizeTmpBytes
159            + VecZnxDftCopy<B>
160            + VecZnxDftAddInplace<B>
161            + IDFTTmpA<B>,
162        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnxBig<B>,
163    {
164        unsafe {
165            let self_ptr: *mut GGSWCiphertext<DataSelf> = self as *mut GGSWCiphertext<DataSelf>;
166            self.automorphism(module, &*self_ptr, auto_key, tensor_key, scratch);
167        }
168    }
169
170    fn automorphism_internal<DataLhs: DataRef, DataAk: DataRef, B: Backend>(
171        &mut self,
172        module: &Module<B>,
173        lhs: &GGSWCiphertext<DataLhs>,
174        auto_key: &GGLWEAutomorphismKeyPrepared<DataAk, B>,
175        scratch: &mut Scratch<B>,
176    ) where
177        Module<B>: VecZnxDftAllocBytes
178            + VmpApplyDftToDftTmpBytes
179            + VecZnxBigNormalizeTmpBytes
180            + VmpApplyDftToDft<B>
181            + VmpApplyDftToDftAdd<B>
182            + DFT<B>
183            + IDFTConsume<B>
184            + VecZnxBigAddSmallInplace<B>
185            + VecZnxBigNormalize<B>
186            + VecZnxAutomorphismInplace,
187        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
188    {
189        // Keyswitch the j-th row of the col 0
190        (0..lhs.rows()).for_each(|row_i| {
191            // Key-switch column 0, i.e.
192            // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2)
193            self.at_mut(row_i, 0)
194                .automorphism(module, &lhs.at(row_i, 0), auto_key, scratch);
195        });
196    }
197}