// poulpy_core/automorphism/ggsw_ct.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, TakeVecZnx, TakeVecZnxBig, TakeVecZnxDft, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
4        VecZnxBigAllocBytes, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAddInplace, VecZnxDftAllocBytes,
5        VecZnxDftApply, VecZnxDftCopy, VecZnxIdftApplyConsume, VecZnxIdftApplyTmpA, VecZnxNormalize, VecZnxNormalizeTmpBytes,
6        VmpApplyDftToDft, VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
7    },
8    layouts::{Backend, DataMut, DataRef, Module, Scratch},
9};
10
11use crate::layouts::{
12    GGLWELayoutInfos, GGSWCiphertext, GGSWInfos, GLWECiphertext,
13    prepared::{GGLWEAutomorphismKeyPrepared, GGLWETensorKeyPrepared},
14};
15
16impl GGSWCiphertext<Vec<u8>> {
17    pub fn automorphism_scratch_space<B: Backend, OUT, IN, KEY, TSK>(
18        module: &Module<B>,
19        out_infos: &OUT,
20        in_infos: &IN,
21        key_infos: &KEY,
22        tsk_infos: &TSK,
23    ) -> usize
24    where
25        OUT: GGSWInfos,
26        IN: GGSWInfos,
27        KEY: GGLWELayoutInfos,
28        TSK: GGLWELayoutInfos,
29        Module<B>: VecZnxDftAllocBytes
30            + VmpApplyDftToDftTmpBytes
31            + VecZnxBigAllocBytes
32            + VecZnxNormalizeTmpBytes
33            + VecZnxBigNormalizeTmpBytes,
34    {
35        let out_size: usize = out_infos.size();
36        let ci_dft: usize = module.vec_znx_dft_alloc_bytes((key_infos.rank_out() + 1).into(), out_size);
37        let ks_internal: usize = GLWECiphertext::keyswitch_scratch_space(
38            module,
39            &out_infos.glwe_layout(),
40            &in_infos.glwe_layout(),
41            key_infos,
42        );
43        let expand: usize = GGSWCiphertext::expand_row_scratch_space(module, out_infos, tsk_infos);
44        ci_dft + (ks_internal | expand)
45    }
46
47    pub fn automorphism_inplace_scratch_space<B: Backend, OUT, KEY, TSK>(
48        module: &Module<B>,
49        out_infos: &OUT,
50        key_infos: &KEY,
51        tsk_infos: &TSK,
52    ) -> usize
53    where
54        OUT: GGSWInfos,
55        KEY: GGLWELayoutInfos,
56        TSK: GGLWELayoutInfos,
57        Module<B>: VecZnxDftAllocBytes
58            + VmpApplyDftToDftTmpBytes
59            + VecZnxBigAllocBytes
60            + VecZnxNormalizeTmpBytes
61            + VecZnxBigNormalizeTmpBytes,
62    {
63        GGSWCiphertext::automorphism_scratch_space(module, out_infos, out_infos, key_infos, tsk_infos)
64    }
65}
66
67impl<DataSelf: DataMut> GGSWCiphertext<DataSelf> {
68    pub fn automorphism<DataLhs: DataRef, DataAk: DataRef, DataTsk: DataRef, B: Backend>(
69        &mut self,
70        module: &Module<B>,
71        lhs: &GGSWCiphertext<DataLhs>,
72        auto_key: &GGLWEAutomorphismKeyPrepared<DataAk, B>,
73        tensor_key: &GGLWETensorKeyPrepared<DataTsk, B>,
74        scratch: &mut Scratch<B>,
75    ) where
76        Module<B>: VecZnxDftAllocBytes
77            + VmpApplyDftToDftTmpBytes
78            + VecZnxBigNormalizeTmpBytes
79            + VmpApplyDftToDft<B>
80            + VmpApplyDftToDftAdd<B>
81            + VecZnxDftApply<B>
82            + VecZnxIdftApplyConsume<B>
83            + VecZnxBigAddSmallInplace<B>
84            + VecZnxBigNormalize<B>
85            + VecZnxAutomorphismInplace<B>
86            + VecZnxBigAllocBytes
87            + VecZnxNormalizeTmpBytes
88            + VecZnxDftCopy<B>
89            + VecZnxDftAddInplace<B>
90            + VecZnxIdftApplyTmpA<B>
91            + VecZnxNormalize<B>,
92        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnxBig<B> + TakeVecZnx,
93    {
94        #[cfg(debug_assertions)]
95        {
96            use crate::layouts::{GLWEInfos, LWEInfos};
97
98            assert_eq!(self.n(), module.n() as u32);
99            assert_eq!(lhs.n(), module.n() as u32);
100            assert_eq!(auto_key.n(), module.n() as u32);
101            assert_eq!(tensor_key.n(), module.n() as u32);
102
103            assert_eq!(
104                self.rank(),
105                lhs.rank(),
106                "ggsw_out rank: {} != ggsw_in rank: {}",
107                self.rank(),
108                lhs.rank()
109            );
110            assert_eq!(
111                self.rank(),
112                auto_key.rank_out(),
113                "ggsw_in rank: {} != auto_key rank: {}",
114                self.rank(),
115                auto_key.rank_out()
116            );
117            assert_eq!(
118                self.rank(),
119                tensor_key.rank_out(),
120                "ggsw_in rank: {} != tensor_key rank: {}",
121                self.rank(),
122                tensor_key.rank_out()
123            );
124            assert!(scratch.available() >= GGSWCiphertext::automorphism_scratch_space(module, self, lhs, auto_key, tensor_key))
125        };
126
127        // Keyswitch the j-th row of the col 0
128        (0..lhs.rows().into()).for_each(|row_i| {
129            // Key-switch column 0, i.e.
130            // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2)
131            self.at_mut(row_i, 0)
132                .automorphism(module, &lhs.at(row_i, 0), auto_key, scratch);
133        });
134        self.expand_row(module, tensor_key, scratch);
135    }
136
137    pub fn automorphism_inplace<DataKsk: DataRef, DataTsk: DataRef, B: Backend>(
138        &mut self,
139        module: &Module<B>,
140        auto_key: &GGLWEAutomorphismKeyPrepared<DataKsk, B>,
141        tensor_key: &GGLWETensorKeyPrepared<DataTsk, B>,
142        scratch: &mut Scratch<B>,
143    ) where
144        Module<B>: VecZnxDftAllocBytes
145            + VmpApplyDftToDftTmpBytes
146            + VecZnxBigNormalizeTmpBytes
147            + VmpApplyDftToDft<B>
148            + VmpApplyDftToDftAdd<B>
149            + VecZnxDftApply<B>
150            + VecZnxIdftApplyConsume<B>
151            + VecZnxBigAddSmallInplace<B>
152            + VecZnxBigNormalize<B>
153            + VecZnxAutomorphismInplace<B>
154            + VecZnxBigAllocBytes
155            + VecZnxNormalizeTmpBytes
156            + VecZnxDftCopy<B>
157            + VecZnxDftAddInplace<B>
158            + VecZnxIdftApplyTmpA<B>
159            + VecZnxNormalize<B>,
160        Scratch<B>: TakeVecZnxDft<B> + ScratchAvailable + TakeVecZnxBig<B> + TakeVecZnx,
161    {
162        // Keyswitch the j-th row of the col 0
163        (0..self.rows().into()).for_each(|row_i| {
164            // Key-switch column 0, i.e.
165            // col 0: (-(a0s0 + a1s1 + a2s2) + M[i], a0, a1, a2) -> (-(a0pi^-1(s0) + a1pi^-1(s1) + a2pi^-1(s2)) + M[i], a0, a1, a2)
166            self.at_mut(row_i, 0)
167                .automorphism_inplace(module, auto_key, scratch);
168        });
169        self.expand_row(module, tensor_key, scratch);
170    }
171}