poulpy_core/automorphism/ggsw_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace, VecZnxBigAllocBytes,
4        VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAllocBytes, VecZnxDftCopy,
5        VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume, VecZnxDftToVecZnxBigTmpA, VecZnxNormalizeTmpBytes, VmpApply,
6        VmpApplyAdd, VmpApplyTmpBytes,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, Scratch},
9};
10
11use crate::layouts::{
12    GGSWCiphertext, GLWECiphertext, Infos,
13    prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared},
14};
15
16impl GGSWCiphertext<Vec<u8>> {
17    #[allow(clippy::too_many_arguments)]
18    pub fn automorphism_scratch_space<B: Backend>(
19        module: &Module<B>,
20        n: usize,
21        basek: usize,
22        k_out: usize,
23        k_in: usize,
24        k_ksk: usize,
25        digits_ksk: usize,
26        k_tsk: usize,
27        digits_tsk: usize,
28        rank: usize,
29    ) -> usize
30    where
31        Module<B>:
32            VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes,
33    {
34        let out_size: usize = k_out.div_ceil(basek);
35        let ci_dft: usize = module.vec_znx_dft_alloc_bytes(n, rank + 1, out_size);
36        let ks_internal: usize =
37            GLWECiphertext::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits_ksk, rank, rank);
38        let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, n, basek, k_out, k_tsk, digits_tsk, rank);
39        ci_dft + (ks_internal | expand)
40    }
41
42    #[allow(clippy::too_many_arguments)]
43    pub fn automorphism_inplace_scratch_space<B: Backend>(
44        module: &Module<B>,
45        n: usize,
46        basek: usize,
47        k_out: usize,
48        k_ksk: usize,
49        digits_ksk: usize,
50        k_tsk: usize,
51        digits_tsk: usize,
52        rank: usize,
53    ) -> usize
54    where
55        Module<B>:
56            VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigAllocBytes + VecZnxNormalizeTmpBytes + VecZnxBigNormalizeTmpBytes,
57    {
58        GGSWCiphertext::automorphism_scratch_space(
59            module, n, basek, k_out, k_out, k_ksk, digits_ksk, k_tsk, digits_tsk, rank,
60        )
61    }
62}
63
64impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
65    pub fn automorphism<DataLhs: DataRef, DataAk: DataRef, DataTsk: DataRef, B: Backend>(
66        &mut self,
67        module: &Module<B>,
68        lhs: &GGSWCiphertext<DataLhs>,
69        auto_key: &GGLWEAutomorphismKeyPrepared<DataAk, B>,
70        tensor_key: &GGLWETensorKeyPrepared<DataTsk, B>,
71        scratch: &mut Scratch<B>,
72    ) where
73        Module<B>: VecZnxDftAllocBytes
74            + VmpApplyTmpBytes
75            + VecZnxBigNormalizeTmpBytes
76            + VmpApply<B>
77            + VmpApplyAdd<B>
78            + VecZnxDftFromVecZnx<B>
79            + VecZnxDftToVecZnxBigConsume<B>
80            + VecZnxBigAddSmallInplace<B>
81            + VecZnxBigNormalize<B>
82            + VecZnxAutomorphismInplace
83            + VecZnxBigAllocBytes
84            + VecZnxNormalizeTmpBytes
85            + VecZnxDftCopy<B>
86            + VecZnxDftAddInplace<B>
87            + VecZnxDftToVecZnxBigTmpA<B>,
88        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnxBig<B>,
89    {
90        #[cfg(debug_assertions)]
91        {
92            assert_eq!(self.n(), auto_key.n());
93            assert_eq!(lhs.n(), auto_key.n());
94
95            assert_eq!(
96                self.rank(),
97                lhs.rank(),
98                "ggsw_out rank: {} != ggsw_in rank: {}",
99                self.rank(),
100                lhs.rank()
101            );
102            assert_eq!(
103                self.rank(),
104                auto_key.rank(),
105                "ggsw_in rank: {} != auto_key rank: {}",
106                self.rank(),
107                auto_key.rank()
108            );
109            assert_eq!(
110                self.rank(),
111                tensor_key.rank(),
112                "ggsw_in rank: {} != tensor_key rank: {}",
113                self.rank(),
114                tensor_key.rank()
115            );
116            assert!(
117                scratch.available()
118                    >= GGSWCiphertext::automorphism_scratch_space(
119                        module,
120                        self.n(),
121                        self.basek(),
122                        self.k(),
123                        lhs.k(),
124                        auto_key.k(),
125                        auto_key.digits(),
126                        tensor_key.k(),
127                        tensor_key.digits(),
128                        self.rank(),
129                    )
130            )
131        };
132
133        self.automorphism_internal(module, lhs, auto_key, scratch);
134        self.expand_row(module, tensor_key, scratch);
135    }
136
137    pub fn automorphism_inplace<DataKsk: DataRef, DataTsk: DataRef, B: Backend>(
138        &mut self,
139        module: &Module<B>,
140        auto_key: &GGLWEAutomorphismKeyPrepared<DataKsk, B>,
141        tensor_key: &GGLWETensorKeyPrepared<DataTsk, B>,
142        scratch: &mut Scratch<B>,
143    ) where
144        Module<B>: VecZnxDftAllocBytes
145            + VmpApplyTmpBytes
146            + VecZnxBigNormalizeTmpBytes
147            + VmpApply<B>
148            + VmpApplyAdd<B>
149            + VecZnxDftFromVecZnx<B>
150            + VecZnxDftToVecZnxBigConsume<B>
151            + VecZnxBigAddSmallInplace<B>
152            + VecZnxBigNormalize<B>
153            + VecZnxAutomorphismInplace
154            + VecZnxBigAllocBytes
155            + VecZnxNormalizeTmpBytes
156            + VecZnxDftCopy<B>
157            + VecZnxDftAddInplace<B>
158            + VecZnxDftToVecZnxBigTmpA<B>,
159        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnxBig<B>,
160    {
161        unsafe {
162            let self_ptr: *mut GGSWCiphertext<DataSelf> = self as *mut GGSWCiphertext<DataSelf>;
163            self.automorphism(module, &*self_ptr, auto_key, tensor_key, scratch);
164        }
165    }
166
167    fn automorphism_internal<DataLhs: DataRef, DataAk: DataRef, B: Backend>(
168        &mut self,
169        module: &Module<B>,
170        lhs: &GGSWCiphertext<DataLhs>,
171        auto_key: &GGLWEAutomorphismKeyPrepared<DataAk, B>,
172        scratch: &mut Scratch<B>,
173    ) where
174        Module<B>: VecZnxDftAllocBytes
175            + VmpApplyTmpBytes
176            + VecZnxBigNormalizeTmpBytes
177            + VmpApply<B>
178            + VmpApplyAdd<B>
179            + VecZnxDftFromVecZnx<B>
180            + VecZnxDftToVecZnxBigConsume<B>
181            + VecZnxBigAddSmallInplace<B>
182            + VecZnxBigNormalize<B>
183            + VecZnxAutomorphismInplace,
184        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable,
185    {
186        // Keyswitch the j-th row of the col 0
187        (0..lhs.rows()).for_each(|row_i| {
188            // Key-switch column 0, i.e.
189            // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2)
190            self.at_mut(row_i, 0)
191                .automorphism(module, &lhs.at(row_i, 0), auto_key, scratch);
192        });
193    }
194}