1use std::sync::Arc;
2
3use fhe_math::{
4 rq::{traits::TryConvertFrom, Poly, Representation},
5 zq::Modulus,
6};
7use itertools::Itertools;
8use rand::{CryptoRng, RngCore};
9use zeroize::Zeroizing;
10
11use crate::bfv::{BfvParameters, Ciphertext, Plaintext, SecretKey};
12use crate::{Error, Result};
13
14use super::Aggregate;
15
/// A party's share in the multiparty secret key switching (SKS) protocol.
///
/// Created with [`SecretKeySwitchShare::new`]; the shares of all parties are combined into
/// the key-switched [`Ciphertext`] through [`Aggregate`].
pub struct SecretKeySwitchShare {
    /// BFV parameters shared by both secret key shares and the ciphertext.
    pub(crate) par: Arc<BfvParameters>,
    /// The ciphertext being key-switched; aggregation reuses its `c0` and `c1`.
    pub(crate) ct: Arc<Ciphertext>,
    /// This party's share `h = (s_in - s_out) * c1 + e`, in NTT representation.
    pub(crate) h_share: Poly,
}
32
33impl SecretKeySwitchShare {
34 pub fn new<R: RngCore + CryptoRng>(
41 sk_input_share: &SecretKey,
42 sk_output_share: &SecretKey,
43 ct: Arc<Ciphertext>,
44 rng: &mut R,
45 ) -> Result<Self> {
46 if sk_input_share.par != sk_output_share.par || sk_output_share.par != ct.par {
47 return Err(Error::DefaultError(
48 "Incompatible BFV parameters".to_string(),
49 ));
50 }
51 if ct.c.len() != 2 {
53 return Err(Error::TooManyValues(ct.c.len(), 2));
54 }
55
56 let par = sk_input_share.par.clone();
57 let mut s_in = Zeroizing::new(Poly::try_convert_from(
58 sk_input_share.coeffs.as_ref(),
59 ct.c[0].ctx(),
60 false,
61 Representation::PowerBasis,
62 )?);
63 s_in.change_representation(Representation::Ntt);
64 let mut s_out = Zeroizing::new(Poly::try_convert_from(
65 sk_output_share.coeffs.as_ref(),
66 ct.c[0].ctx(),
67 false,
68 Representation::PowerBasis,
69 )?);
70 s_out.change_representation(Representation::Ntt);
71
72 let e = Zeroizing::new(Poly::small(
75 ct.c[0].ctx(),
76 Representation::Ntt,
77 par.variance,
78 rng,
79 )?);
80
81 let mut h_share = s_in.as_ref() - s_out.as_ref();
83 h_share.disallow_variable_time_computations();
84 h_share *= &ct.c[1];
85 h_share += e.as_ref();
86
87 Ok(Self { par, ct, h_share })
88 }
89}
90
91impl Aggregate<SecretKeySwitchShare> for Ciphertext {
92 fn from_shares<T>(iter: T) -> Result<Self>
93 where
94 T: IntoIterator<Item = SecretKeySwitchShare>,
95 {
96 let mut shares = iter.into_iter();
97 let share = shares.next().ok_or(Error::TooFewValues(0, 1))?;
98 let mut h = share.h_share;
99 for sh in shares {
100 h += &sh.h_share;
101 }
102
103 let c0 = &share.ct.c[0] + &h;
104 let c1 = share.ct.c[1].clone();
105
106 Ciphertext::new(vec![c0, c1], &share.par)
107 }
108}
109
/// A party's share in the multiparty decryption protocol.
///
/// Internally this is a secret-key-switch share whose output key is the all-zero key; the
/// shares of all parties are combined into a [`Plaintext`] through [`Aggregate`].
pub struct DecryptionShare {
    /// The underlying secret-key-switch share (output key = zero).
    pub(crate) sks_share: SecretKeySwitchShare,
}
119
120impl DecryptionShare {
121 pub fn new<R: RngCore + CryptoRng>(
127 sk_input_share: &SecretKey,
128 ct: &Arc<Ciphertext>,
129 rng: &mut R,
130 ) -> Result<Self> {
131 let par = &sk_input_share.par;
132 let zero = SecretKey::new(vec![0; par.degree()], par);
133 let sks_share = SecretKeySwitchShare::new(sk_input_share, &zero, ct.clone(), rng)?;
134 Ok(DecryptionShare { sks_share })
135 }
136}
137
138impl Aggregate<DecryptionShare> for Plaintext {
139 fn from_shares<T>(iter: T) -> Result<Self>
140 where
141 T: IntoIterator<Item = DecryptionShare>,
142 {
143 let sks_shares = iter.into_iter().map(|s| s.sks_share);
144 let ct = Ciphertext::from_shares(sks_shares)?;
145 let par = ct.par;
146
147 let mut c = Zeroizing::new(ct.c[0].clone());
149 c.disallow_variable_time_computations();
150 c.change_representation(Representation::PowerBasis);
151
152 let d = Zeroizing::new(c.scale(&par.scalers[ct.level])?);
154 let v = Zeroizing::new(
155 Vec::<u64>::from(d.as_ref())
156 .iter_mut()
157 .map(|vi| *vi + par.plaintext.modulus())
158 .collect_vec(),
159 );
160 let mut w = v[..par.degree()].to_vec();
161 let q = Modulus::new(par.moduli[0]).map_err(Error::MathError)?;
162 q.reduce_vec(&mut w);
163 par.plaintext.reduce_vec(&mut w);
164
165 let mut poly =
166 Poly::try_convert_from(&w, ct.c[0].ctx(), false, Representation::PowerBasis)?;
167 poly.change_representation(Representation::Ntt);
168
169 let pt = Plaintext {
170 par: par.clone(),
171 value: w.into_boxed_slice(),
172 encoding: None,
173 poly_ntt: poly,
174 level: ct.level,
175 };
176
177 Ok(pt)
178 }
179}
180
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use fhe_traits::{FheDecoder, FheEncoder, FheEncrypter};
    use rand::thread_rng;

    use crate::{
        bfv::{BfvParameters, Encoding, Plaintext, PublicKey, SecretKey},
        mbfv::{Aggregate, AggregateIter, CommonRandomPoly, PublicKeyShare},
    };

    use super::*;

    const NUM_PARTIES: usize = 11;

    /// Number of random trials per (parameter set, level) combination.
    const TRIALS: usize = 20;

    struct Party {
        sk_share: SecretKey,
        pk_share: PublicKeyShare,
    }

    /// Parameter sets exercised by every test in this module.
    fn test_parameters() -> [Arc<BfvParameters>; 2] {
        [
            BfvParameters::default_arc(1, 16),
            BfvParameters::default_arc(6, 32),
        ]
    }

    /// Create `NUM_PARTIES` parties, each with a fresh secret key share and the matching
    /// public key share for `crp`.
    fn generate_parties<R: RngCore + CryptoRng>(
        par: &Arc<BfvParameters>,
        crp: &CommonRandomPoly,
        rng: &mut R,
    ) -> Vec<Party> {
        (0..NUM_PARTIES)
            .map(|_| {
                let sk_share = SecretKey::random(par, &mut *rng);
                let pk_share = PublicKeyShare::new(&sk_share, crp.clone(), &mut *rng).unwrap();
                Party { sk_share, pk_share }
            })
            .collect()
    }

    /// Aggregate the parties' public key shares into the collective public key.
    fn collective_public_key(parties: &[Party]) -> PublicKey {
        parties
            .iter()
            .map(|p| p.pk_share.clone())
            .aggregate()
            .unwrap()
    }

    #[test]
    fn encrypt_decrypt() {
        let mut rng = thread_rng();
        for par in test_parameters() {
            for level in 0..=par.max_level() {
                for _ in 0..TRIALS {
                    let crp = CommonRandomPoly::new(&par, &mut rng).unwrap();
                    let parties = generate_parties(&par, &crp, &mut rng);
                    let public_key = collective_public_key(&parties);

                    let pt1 = Plaintext::try_encode(
                        &par.plaintext.random_vec(par.degree(), &mut rng),
                        Encoding::poly_at_level(level),
                        &par,
                    )
                    .unwrap();
                    let ct = Arc::new(public_key.try_encrypt(&pt1, &mut rng).unwrap());

                    // Exercise `from_shares` directly (the other tests use `aggregate`).
                    let decryption_shares = parties
                        .iter()
                        .map(|p| DecryptionShare::new(&p.sk_share, &ct, &mut rng));
                    let pt2 = Plaintext::from_shares(decryption_shares).unwrap();

                    assert_eq!(pt1, pt2);
                }
            }
        }
    }

    #[test]
    fn encrypt_keyswitch_decrypt() {
        let mut rng = thread_rng();
        for par in test_parameters() {
            for level in 0..=par.max_level() {
                for _ in 0..TRIALS {
                    let crp = CommonRandomPoly::new(&par, &mut rng).unwrap();
                    let parties = generate_parties(&par, &crp, &mut rng);
                    let public_key =
                        PublicKey::from_shares(parties.iter().map(|p| p.pk_share.clone())).unwrap();

                    let pt1 = Plaintext::try_encode(
                        &par.plaintext.random_vec(par.degree(), &mut rng),
                        Encoding::poly_at_level(level),
                        &par,
                    )
                    .unwrap();
                    let ct1 = Arc::new(public_key.try_encrypt(&pt1, &mut rng).unwrap());

                    // A second, independent group of parties receives the ciphertext.
                    let out_parties = generate_parties(&par, &crp, &mut rng);
                    let ct2: Ciphertext = parties
                        .iter()
                        .zip(out_parties.iter())
                        .map(|(ip, op)| {
                            SecretKeySwitchShare::new(
                                &ip.sk_share,
                                &op.sk_share,
                                ct1.clone(),
                                &mut rng,
                            )
                        })
                        .aggregate()
                        .unwrap();
                    let ct2 = Arc::new(ct2);

                    let pt2: Plaintext = out_parties
                        .iter()
                        .map(|p| DecryptionShare::new(&p.sk_share, &ct2, &mut rng))
                        .aggregate()
                        .unwrap();

                    assert_eq!(pt1, pt2);
                }
            }
        }
    }

    #[test]
    fn collective_keys_enable_homomorphic_addition() {
        let mut rng = thread_rng();
        for par in test_parameters() {
            for level in 0..=par.max_level() {
                for _ in 0..TRIALS {
                    let crp = CommonRandomPoly::new(&par, &mut rng).unwrap();
                    let parties = generate_parties(&par, &crp, &mut rng);
                    let public_key = collective_public_key(&parties);

                    let a = par.plaintext.random_vec(par.degree(), &mut rng);
                    let b = par.plaintext.random_vec(par.degree(), &mut rng);
                    let mut expected = a.clone();
                    par.plaintext.add_vec(&mut expected, &b);

                    let pt_a =
                        Plaintext::try_encode(&a, Encoding::poly_at_level(level), &par).unwrap();
                    let pt_b =
                        Plaintext::try_encode(&b, Encoding::poly_at_level(level), &par).unwrap();
                    let ct_a = public_key.try_encrypt(&pt_a, &mut rng).unwrap();
                    let ct_b = public_key.try_encrypt(&pt_b, &mut rng).unwrap();

                    let ct = Arc::new(&ct_a + &ct_b);

                    let pt: Plaintext = parties
                        .iter()
                        .map(|p| DecryptionShare::new(&p.sk_share, &ct, &mut rng))
                        .aggregate()
                        .unwrap();

                    assert_eq!(
                        Vec::<u64>::try_decode(&pt, Encoding::poly_at_level(level)).unwrap(),
                        expected
                    );
                }
            }
        }
    }
}