poulpy_core/automorphism/gglwe_atk.rs

1use poulpy_hal::{
2    api::{
3        DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace,
4        VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft,
5        VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
8};
9
10use crate::layouts::{GGLWEAutomorphismKey, GLWECiphertext, Infos, prepared::GGLWEAutomorphismKeyPrepared};
11
12impl GGLWEAutomorphismKey<Vec<u8>> {
13    #[allow(clippy::too_many_arguments)]
14    pub fn automorphism_scratch_space<B: Backend>(
15        module: &Module<B>,
16        basek: usize,
17        k_out: usize,
18        k_in: usize,
19        k_ksk: usize,
20        digits: usize,
21        rank: usize,
22    ) -> usize
23    where
24        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
25    {
26        GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank)
27    }
28
29    pub fn automorphism_inplace_scratch_space<B: Backend>(
30        module: &Module<B>,
31        basek: usize,
32        k_out: usize,
33        k_ksk: usize,
34        digits: usize,
35        rank: usize,
36    ) -> usize
37    where
38        Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
39    {
40        GGLWEAutomorphismKey::automorphism_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank)
41    }
42}
43
44impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
45    pub fn automorphism<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
46        &mut self,
47        module: &Module<B>,
48        lhs: &GGLWEAutomorphismKey<DataLhs>,
49        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
50        scratch: &mut Scratch<B>,
51    ) where
52        Module<B>: VecZnxDftAllocBytes
53            + VmpApplyDftToDftTmpBytes
54            + VecZnxBigNormalizeTmpBytes
55            + VmpApplyDftToDft<B>
56            + VmpApplyDftToDftAdd<B>
57            + DFT<B>
58            + IDFTConsume<B>
59            + VecZnxBigAddSmallInplace<B>
60            + VecZnxBigNormalize<B>
61            + VecZnxAutomorphism
62            + VecZnxAutomorphismInplace,
63        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
64    {
65        #[cfg(debug_assertions)]
66        {
67            assert_eq!(
68                self.rank_in(),
69                lhs.rank_in(),
70                "ksk_out input rank: {} != ksk_in input rank: {}",
71                self.rank_in(),
72                lhs.rank_in()
73            );
74            assert_eq!(
75                lhs.rank_out(),
76                rhs.rank_in(),
77                "ksk_in output rank: {} != ksk_apply input rank: {}",
78                self.rank_out(),
79                rhs.rank_in()
80            );
81            assert_eq!(
82                self.rank_out(),
83                rhs.rank_out(),
84                "ksk_out output rank: {} != ksk_apply output rank: {}",
85                self.rank_out(),
86                rhs.rank_out()
87            );
88            assert!(
89                self.k() <= lhs.k(),
90                "output k={} cannot be greater than input k={}",
91                self.k(),
92                lhs.k()
93            )
94        }
95
96        let cols_out: usize = rhs.rank_out() + 1;
97
98        let p: i64 = lhs.p();
99        let p_inv = module.galois_element_inv(p);
100
101        (0..self.rank_in()).for_each(|col_i| {
102            (0..self.rows()).for_each(|row_j| {
103                let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i);
104                let lhs_ct: GLWECiphertext<&[u8]> = lhs.at(row_j, col_i);
105
106                // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
107                (0..cols_out).for_each(|i| {
108                    module.vec_znx_automorphism(lhs.p(), &mut res_ct.data, i, &lhs_ct.data, i);
109                });
110
111                // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
112                res_ct.keyswitch_inplace(module, &rhs.key, scratch);
113
114                // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a)
115                (0..cols_out).for_each(|i| {
116                    module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i);
117                });
118            });
119        });
120
121        (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
122            (0..self.rank_in()).for_each(|col_j| {
123                self.at_mut(row_i, col_j).data.zero();
124            });
125        });
126
127        self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64);
128    }
129
130    pub fn automorphism_inplace<DataRhs: DataRef, B: Backend>(
131        &mut self,
132        module: &Module<B>,
133        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
134        scratch: &mut Scratch<B>,
135    ) where
136        Module<B>: VecZnxDftAllocBytes
137            + VmpApplyDftToDftTmpBytes
138            + VecZnxBigNormalizeTmpBytes
139            + VmpApplyDftToDft<B>
140            + VmpApplyDftToDftAdd<B>
141            + DFT<B>
142            + IDFTConsume<B>
143            + VecZnxBigAddSmallInplace<B>
144            + VecZnxBigNormalize<B>
145            + VecZnxAutomorphism
146            + VecZnxAutomorphismInplace,
147        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
148    {
149        unsafe {
150            let self_ptr: *mut GGLWEAutomorphismKey<DataSelf> = self as *mut GGLWEAutomorphismKey<DataSelf>;
151            self.automorphism(module, &*self_ptr, rhs, scratch);
152        }
153    }
154}