poulpy_core/automorphism/
gglwe_atk.rs1use poulpy_hal::{
2 api::{
3 DFT, IDFTConsume, ScratchAvailable, TakeVecZnxDft, VecZnxAutomorphism, VecZnxAutomorphismInplace,
4 VecZnxBigAddSmallInplace, VecZnxBigNormalize, VecZnxBigNormalizeTmpBytes, VecZnxDftAllocBytes, VmpApplyDftToDft,
5 VmpApplyDftToDftAdd, VmpApplyDftToDftTmpBytes,
6 },
7 layouts::{Backend, DataMut, DataRef, Module, Scratch, ZnxZero},
8};
9
10use crate::layouts::{GGLWEAutomorphismKey, GLWECiphertext, Infos, prepared::GGLWEAutomorphismKeyPrepared};
11
12impl GGLWEAutomorphismKey<Vec<u8>> {
13 #[allow(clippy::too_many_arguments)]
14 pub fn automorphism_scratch_space<B: Backend>(
15 module: &Module<B>,
16 basek: usize,
17 k_out: usize,
18 k_in: usize,
19 k_ksk: usize,
20 digits: usize,
21 rank: usize,
22 ) -> usize
23 where
24 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
25 {
26 GLWECiphertext::keyswitch_scratch_space(module, basek, k_out, k_in, k_ksk, digits, rank, rank)
27 }
28
29 pub fn automorphism_inplace_scratch_space<B: Backend>(
30 module: &Module<B>,
31 basek: usize,
32 k_out: usize,
33 k_ksk: usize,
34 digits: usize,
35 rank: usize,
36 ) -> usize
37 where
38 Module<B>: VecZnxDftAllocBytes + VmpApplyDftToDftTmpBytes + VecZnxBigNormalizeTmpBytes,
39 {
40 GGLWEAutomorphismKey::automorphism_scratch_space(module, basek, k_out, k_out, k_ksk, digits, rank)
41 }
42}
43
44impl<DataSelf: DataMut> GGLWEAutomorphismKey<DataSelf> {
45 pub fn automorphism<DataLhs: DataRef, DataRhs: DataRef, B: Backend>(
46 &mut self,
47 module: &Module<B>,
48 lhs: &GGLWEAutomorphismKey<DataLhs>,
49 rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
50 scratch: &mut Scratch<B>,
51 ) where
52 Module<B>: VecZnxDftAllocBytes
53 + VmpApplyDftToDftTmpBytes
54 + VecZnxBigNormalizeTmpBytes
55 + VmpApplyDftToDft<B>
56 + VmpApplyDftToDftAdd<B>
57 + DFT<B>
58 + IDFTConsume<B>
59 + VecZnxBigAddSmallInplace<B>
60 + VecZnxBigNormalize<B>
61 + VecZnxAutomorphism
62 + VecZnxAutomorphismInplace,
63 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
64 {
65 #[cfg(debug_assertions)]
66 {
67 assert_eq!(
68 self.rank_in(),
69 lhs.rank_in(),
70 "ksk_out input rank: {} != ksk_in input rank: {}",
71 self.rank_in(),
72 lhs.rank_in()
73 );
74 assert_eq!(
75 lhs.rank_out(),
76 rhs.rank_in(),
77 "ksk_in output rank: {} != ksk_apply input rank: {}",
78 self.rank_out(),
79 rhs.rank_in()
80 );
81 assert_eq!(
82 self.rank_out(),
83 rhs.rank_out(),
84 "ksk_out output rank: {} != ksk_apply output rank: {}",
85 self.rank_out(),
86 rhs.rank_out()
87 );
88 assert!(
89 self.k() <= lhs.k(),
90 "output k={} cannot be greater than input k={}",
91 self.k(),
92 lhs.k()
93 )
94 }
95
96 let cols_out: usize = rhs.rank_out() + 1;
97
98 let p: i64 = lhs.p();
99 let p_inv = module.galois_element_inv(p);
100
101 (0..self.rank_in()).for_each(|col_i| {
102 (0..self.rows()).for_each(|row_j| {
103 let mut res_ct: GLWECiphertext<&mut [u8]> = self.at_mut(row_j, col_i);
104 let lhs_ct: GLWECiphertext<&[u8]> = lhs.at(row_j, col_i);
105
106 (0..cols_out).for_each(|i| {
108 module.vec_znx_automorphism(lhs.p(), &mut res_ct.data, i, &lhs_ct.data, i);
109 });
110
111 res_ct.keyswitch_inplace(module, &rhs.key, scratch);
113
114 (0..cols_out).for_each(|i| {
116 module.vec_znx_automorphism_inplace(p_inv, &mut res_ct.data, i);
117 });
118 });
119 });
120
121 (self.rows().min(lhs.rows())..self.rows()).for_each(|row_i| {
122 (0..self.rank_in()).for_each(|col_j| {
123 self.at_mut(row_i, col_j).data.zero();
124 });
125 });
126
127 self.p = (lhs.p * rhs.p) % (module.cyclotomic_order() as i64);
128 }
129
130 pub fn automorphism_inplace<DataRhs: DataRef, B: Backend>(
131 &mut self,
132 module: &Module<B>,
133 rhs: &GGLWEAutomorphismKeyPrepared<DataRhs, B>,
134 scratch: &mut Scratch<B>,
135 ) where
136 Module<B>: VecZnxDftAllocBytes
137 + VmpApplyDftToDftTmpBytes
138 + VecZnxBigNormalizeTmpBytes
139 + VmpApplyDftToDft<B>
140 + VmpApplyDftToDftAdd<B>
141 + DFT<B>
142 + IDFTConsume<B>
143 + VecZnxBigAddSmallInplace<B>
144 + VecZnxBigNormalize<B>
145 + VecZnxAutomorphism
146 + VecZnxAutomorphismInplace,
147 Scratch<B>: ScratchAvailable + TakeVecZnxDft<B>,
148 {
149 unsafe {
150 let self_ptr: *mut GGLWEAutomorphismKey<DataSelf> = self as *mut GGLWEAutomorphismKey<DataSelf>;
151 self.automorphism(module, &*self_ptr, rhs, scratch);
152 }
153 }
154}