// poulpy_core/automorphism/gglwe_atk.rs

1use poulpy_hal::{
2    api::{
3        ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace, VecZnxBigAddSmallInplace,
4        VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VecZnxDftFromVecZnx, VecZnxDftToVecZnxBigConsume,
5        VmpApply, VmpApplyAdd, VmpApplyTmpBytes, ZnxZero,
6    },
7    layouts::{Backend, DataMut, DataRef, Module, Scratch},
8};
9
10use crate::layouts::{GGLWEAutomorphismKey, GLWECiphertext, Infos, prepared::GGLWEAutomorphismKeyPrepared};
11
12impl GGLWEAutomorphismKey<Vec<u8>> {
13    #[allow(clippy::too_many_arguments)]
14    pub fn automorphism_scratch_space<B: Backend>(
15        module: &Module<B>,
16        n: usize,
17        basek: usize,
18        k_out: usize,
19        k_in: usize,
20        k_ksk: usize,
21        digits: usize,
22        rank: usize,
23    ) -> usize
24    where
25        Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
26    {
27        GLWECiphertext::keyswitch_scratch_space(module, n, basek, k_out, k_in, k_ksk, digits, rank, rank)
28    }
29
30    pub fn automorphism_inplace_scratch_space<B: Backend>(
31        module: &Module<B>,
32        n: usize,
33        basek: usize,
34        k_out: usize,
35        k_ksk: usize,
36        digits: usize,
37        rank: usize,
38    ) -> usize
39    where
40        Module<B>: VecZnxDftAllocBytes + VmpApplyTmpBytes + VecZnxBigNormalizeTmpBytes,
41    {
42        GGLWEAutomorphismKey::automorphism_scratch_space(module, n, basek, k_out, k_out, k_ksk, digits, rank)
43    }
44}
45
46impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
47    pub fn automorphism<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
48        &mut self,
49        module: &Module<B>,
50        lhs: &GGLWEAutomorphismKey<DataLhs>,
51        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
52        scratch: &mut Scratch<B>,
53    ) where
54        Module<B>: VecZnxDftAllocBytes
55            + VmpApplyTmpBytes
56            + VecZnxBigNormalizeTmpBytes
57            + VmpApply<B>
58            + VmpApplyAdd<B>
59            + VecZnxDftFromVecZnx<B>
60            + VecZnxDftToVecZnxBigConsume<B>
61            + VecZnxBigAddSmallInplace<B>
62            + VecZnxBigNormalize<B>
63            + VecZnxAutomorphism
64            + VecZnxAutomorphismInplace,
65        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
66    {
67        #[cfg(debug_assertions)]
68        {
69            assert_eq!(
70                self.rank_in(),
71                lhs.rank_in(),
72                "ksk_out input rank: {} != ksk_in input rank: {}",
73                self.rank_in(),
74                lhs.rank_in()
75            );
76            assert_eq!(
77                lhs.rank_out(),
78                rhs.rank_in(),
79                "ksk_in output rank: {} != ksk_apply input rank: {}",
80                self.rank_out(),
81                rhs.rank_in()
82            );
83            assert_eq!(
84                self.rank_out(),
85                rhs.rank_out(),
86                "ksk_out output rank: {} != ksk_apply output rank: {}",
87                self.rank_out(),
88                rhs.rank_out()
89            );
90            assert!(
91                self.k() <= lhs.k(),
92                "output k={} cannot be greater than input k={}",
93                self.k(),
94                lhs.k()
95            )
96        }
97
98        let cols_out: usize = rhs.rank_out() + 1;
99
100        let p: i64 = lhs.p();
101        let p_inv = module.galois_element_inv(p);
102
103        (0..self.rank_in()).for_each(|col_i| {
104            (0..self.rows()).for_each(|row_j| {
105                let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i);
106                let lhs_ct: GLWECiphertext<&[u8]> = lhs.at(row_j, col_i);
107
108                // Reverts the automorphism X^{-k}: (-pi^{-1}_{k}(s)a + s, a) to (-sa + pi_{k}(s), a)
109                (0..cols_out).for_each(|i| {
110                    module.vec_znx_automorphism(lhs.p(), &mut res_ct.data, i, &lhs_ct.data, i);
111                });
112
113                // Key-switch (-sa + pi_{k}(s), a) to (-pi^{-1}_{k'}(s)a + pi_{k}(s), a)
114                res_ct.keyswitch_inplace(module, &rhs.key, scratch);
115
116                // Applies back the automorphism X^{-k}: (-pi^{-1}_{k'}(s)a + pi_{k}(s), a) to (-pi^{-1}_{k'+k}(s)a + s, a)
117                (0..cols_out).for_each(|i| {
118                    module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i);
119                });
120            });
121        });
122
123        (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
124            (0..self.rank_in()).for_each(|col_j| {
125                self.at_mut(row_i, col_j).data.zero();
126            });
127        });
128
129        self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64);
130    }
131
132    pub fn automorphism_inplace<DataRhs: DataRef, B: Backend>(
133        &mut self,
134        module: &Module<B>,
135        rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
136        scratch: &mut Scratch<B>,
137    ) where
138        Module<B>: VecZnxDftAllocBytes
139            + VmpApplyTmpBytes
140            + VecZnxBigNormalizeTmpBytes
141            + VmpApply<B>
142            + VmpApplyAdd<B>
143            + VecZnxDftFromVecZnx<B>
144            + VecZnxDftToVecZnxBigConsume<B>
145            + VecZnxBigAddSmallInplace<B>
146            + VecZnxBigNormalize<B>
147            + VecZnxAutomorphism
148            + VecZnxAutomorphismInplace,
149        Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
150    {
151        unsafe {
152            let self_ptr: *mut GGLWEAutomorphismKey<DataSelf> = self as *mut GGLWEAutomorphismKey<DataSelf>;
153            self.automorphism(module, &*self_ptr, rhs, scratch);
154        }
155    }
156}