1use std::sync::Arc;
2
3use fhe_math::{
4 rq::{traits::TryConvertFrom, Poly, Representation},
5 zq::Modulus,
6};
7use itertools::Itertools;
8use rand::{CryptoRng, RngCore};
9use zeroize::Zeroizing;
10
11use crate::bfv::{BfvParameters, Ciphertext, Plaintext, SecretKey};
12use crate::{Error, Result};
13
14use super::Aggregate;
15
/// A single party's share in the collective secret-key switching protocol.
///
/// Created with [`SecretKeySwitchShare::new`]. Aggregating the shares of all
/// parties into a [`Ciphertext`] (via [`Aggregate::from_shares`]) yields the
/// key-switched ciphertext.
pub struct SecretKeySwitchShare {
    /// The BFV parameters the share was computed under.
    pub(crate) par: Arc<BfvParameters>,
    /// The ciphertext that is being key-switched.
    pub(crate) ct: Arc<Ciphertext>,
    /// This party's share `(s_in - s_out) * c1 + e`, in NTT representation.
    pub(crate) h_share: Poly,
}
32
33impl SecretKeySwitchShare {
34 pub fn new<R: RngCore + CryptoRng>(
41 sk_input_share: &SecretKey,
42 sk_output_share: &SecretKey,
43 ct: Arc<Ciphertext>,
44 rng: &mut R,
45 ) -> Result<Self> {
46 if sk_input_share.par != sk_output_share.par || sk_output_share.par != ct.par {
47 return Err(Error::DefaultError(
48 "Incompatible BFV parameters".to_string(),
49 ));
50 }
51 if ct.len() != 2 {
53 return Err(Error::TooManyValues {
54 actual: ct.len(),
55 limit: 2,
56 });
57 }
58
59 let par = sk_input_share.par.clone();
60 let mut s_in = Zeroizing::new(Poly::try_convert_from(
61 sk_input_share.coeffs.as_ref(),
62 ct[0].ctx(),
63 false,
64 Representation::PowerBasis,
65 )?);
66 s_in.change_representation(Representation::Ntt);
67 let mut s_out = Zeroizing::new(Poly::try_convert_from(
68 sk_output_share.coeffs.as_ref(),
69 ct[0].ctx(),
70 false,
71 Representation::PowerBasis,
72 )?);
73 s_out.change_representation(Representation::Ntt);
74
75 let e = Zeroizing::new(Poly::small(
78 ct[0].ctx(),
79 Representation::Ntt,
80 par.variance,
81 rng,
82 )?);
83
84 let mut h_share = s_in.as_ref() - s_out.as_ref();
86 h_share.disallow_variable_time_computations();
87 h_share *= &ct[1];
88 h_share += e.as_ref();
89
90 Ok(Self { par, ct, h_share })
91 }
92}
93
94impl Aggregate<SecretKeySwitchShare> for Ciphertext {
95 fn from_shares<T>(iter: T) -> Result<Self>
96 where
97 T: IntoIterator<Item = SecretKeySwitchShare>,
98 {
99 let mut shares = iter.into_iter();
100 let share = shares.next().ok_or(Error::TooFewValues {
101 actual: 0,
102 minimum: 1,
103 })?;
104 let mut h = share.h_share;
105 for sh in shares {
106 h += &sh.h_share;
107 }
108
109 let c0 = &share.ct[0] + &h;
110 let c1 = share.ct[1].clone();
111
112 Ciphertext::new(vec![c0, c1], &share.par)
113 }
114}
115
/// A single party's share in the collective decryption protocol.
///
/// Decryption is implemented as secret-key switching from the party's secret
/// key share to the all-zero key. Aggregating every party's share into a
/// [`Plaintext`] completes the decryption.
pub struct DecryptionShare {
    /// The underlying secret-key-switch share towards the all-zero key.
    pub(crate) sks_share: SecretKeySwitchShare,
}
125
126impl DecryptionShare {
127 pub fn new<R: RngCore + CryptoRng>(
133 sk_input_share: &SecretKey,
134 ct: &Arc<Ciphertext>,
135 rng: &mut R,
136 ) -> Result<Self> {
137 let par = &sk_input_share.par;
138 let zero = SecretKey::new(vec![0; par.degree()], par);
139 let sks_share = SecretKeySwitchShare::new(sk_input_share, &zero, ct.clone(), rng)?;
140 Ok(DecryptionShare { sks_share })
141 }
142}
143
144impl Aggregate<DecryptionShare> for Plaintext {
145 fn from_shares<T>(iter: T) -> Result<Self>
146 where
147 T: IntoIterator<Item = DecryptionShare>,
148 {
149 let sks_shares = iter.into_iter().map(|s| s.sks_share);
150 let ct = Ciphertext::from_shares(sks_shares)?;
151
152 let mut c = Zeroizing::new(ct[0].clone());
154 c.disallow_variable_time_computations();
155 c.change_representation(Representation::PowerBasis);
156
157 let ctx_lvl = ct.par.context_level_at(ct.level)?;
159 let d = Zeroizing::new(c.scale(&ctx_lvl.cipher_plain_context.scaler)?);
160 let v = Zeroizing::new(
161 Vec::<u64>::from(d.as_ref())
162 .iter_mut()
163 .map(|vi| *vi + *ct.par.plaintext)
164 .collect_vec(),
165 );
166 let mut w = v[..ct.par.degree()].to_vec();
167 let q = Modulus::new(ct.par.moduli[0]).map_err(Error::MathError)?;
168 q.reduce_vec(&mut w);
169 ct.par.plaintext.reduce_vec(&mut w);
170
171 let mut poly = Poly::try_convert_from(&w, ct[0].ctx(), false, Representation::PowerBasis)?;
172 poly.change_representation(Representation::Ntt);
173
174 let pt = Plaintext {
175 par: ct.par.clone(),
176 value: w.into_boxed_slice(),
177 encoding: None,
178 poly_ntt: poly,
179 level: ct.level,
180 };
181
182 Ok(pt)
183 }
184}
185
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use fhe_traits::{FheDecoder, FheEncoder, FheEncrypter};
    use rand::{rng, CryptoRng, RngCore};

    use crate::{
        bfv::{BfvParameters, Encoding, Plaintext, PublicKey, SecretKey},
        mbfv::{
            Aggregate, AggregateIter, CommonRandomPoly, DecryptionShare, PublicKeyShare,
            SecretKeySwitchShare,
        },
    };

    /// Number of parties in every simulated protocol run.
    const NUM_PARTIES: usize = 11;

    /// Number of randomized repetitions per parameter set and level.
    const REPETITIONS: usize = 20;

    /// A simulated participant: a secret-key share and its public-key share.
    struct Party {
        sk_share: SecretKey,
        pk_share: PublicKeyShare,
    }

    /// Creates `NUM_PARTIES` parties, each with a fresh secret-key share and a
    /// public-key share derived from the common random polynomial `crp`.
    fn generate_parties<R: RngCore + CryptoRng>(
        par: &Arc<BfvParameters>,
        crp: &CommonRandomPoly,
        rng: &mut R,
    ) -> Vec<Party> {
        let mut parties = Vec::with_capacity(NUM_PARTIES);
        for _ in 0..NUM_PARTIES {
            let sk_share = SecretKey::random(par, rng);
            let pk_share = PublicKeyShare::new(&sk_share, crp.clone(), rng).unwrap();
            parties.push(Party { sk_share, pk_share });
        }
        parties
    }

    #[test]
    fn encrypt_decrypt() {
        let mut rng = rng();
        for par in [
            BfvParameters::default_arc(1, 16),
            BfvParameters::default_arc(6, 32),
        ] {
            for level in 0..=par.max_level() {
                for _ in 0..REPETITIONS {
                    let crp = CommonRandomPoly::new(&par, &mut rng).unwrap();
                    let parties = generate_parties(&par, &crp, &mut rng);
                    let public_key: PublicKey = parties
                        .iter()
                        .map(|p| p.pk_share.clone())
                        .aggregate()
                        .unwrap();

                    let pt1 = Plaintext::try_encode(
                        &par.plaintext.random_vec(par.degree(), &mut rng),
                        Encoding::poly_at_level(level),
                        &par,
                    )
                    .unwrap();
                    let ct = Arc::new(public_key.try_encrypt(&pt1, &mut rng).unwrap());

                    // Every party contributes a decryption share.
                    let decryption_shares = parties
                        .iter()
                        .map(|p| DecryptionShare::new(&p.sk_share, &ct, &mut rng));
                    let pt2 = Plaintext::from_shares(decryption_shares).unwrap();

                    assert_eq!(pt1, pt2);
                }
            }
        }
    }

    #[test]
    fn encrypt_keyswitch_decrypt() {
        let mut rng = rng();
        for par in [
            BfvParameters::default_arc(1, 16),
            BfvParameters::default_arc(6, 32),
        ] {
            for level in 0..=par.max_level() {
                for _ in 0..REPETITIONS {
                    let crp = CommonRandomPoly::new(&par, &mut rng).unwrap();
                    let parties = generate_parties(&par, &crp, &mut rng);
                    let public_key =
                        PublicKey::from_shares(parties.iter().map(|p| p.pk_share.clone()))
                            .unwrap();

                    let pt1 = Plaintext::try_encode(
                        &par.plaintext.random_vec(par.degree(), &mut rng),
                        Encoding::poly_at_level(level),
                        &par,
                    )
                    .unwrap();
                    let ct1 = Arc::new(public_key.try_encrypt(&pt1, &mut rng).unwrap());

                    // A second, independent set of parties receives the ciphertext.
                    let out_parties = generate_parties(&par, &crp, &mut rng);
                    let ct2 = parties
                        .iter()
                        .zip(out_parties.iter())
                        .map(|(ip, op)| {
                            SecretKeySwitchShare::new(
                                &ip.sk_share,
                                &op.sk_share,
                                ct1.clone(),
                                &mut rng,
                            )
                        })
                        .aggregate()
                        .unwrap();
                    let ct2 = Arc::new(ct2);

                    // Only the output parties' keys are needed to decrypt now.
                    let pt2 = out_parties
                        .iter()
                        .map(|p| DecryptionShare::new(&p.sk_share, &ct2, &mut rng))
                        .aggregate()
                        .unwrap();

                    assert_eq!(pt1, pt2);
                }
            }
        }
    }

    #[test]
    fn collective_keys_enable_homomorphic_addition() {
        let mut rng = rng();
        for par in [
            BfvParameters::default_arc(1, 16),
            BfvParameters::default_arc(6, 32),
        ] {
            for level in 0..=par.max_level() {
                for _ in 0..REPETITIONS {
                    let crp = CommonRandomPoly::new(&par, &mut rng).unwrap();
                    let parties = generate_parties(&par, &crp, &mut rng);
                    let public_key: PublicKey = parties
                        .iter()
                        .map(|p| p.pk_share.clone())
                        .aggregate()
                        .unwrap();

                    let a = par.plaintext.random_vec(par.degree(), &mut rng);
                    let b = par.plaintext.random_vec(par.degree(), &mut rng);
                    let mut expected = a.clone();
                    par.plaintext.add_vec(&mut expected, &b);

                    let pt_a =
                        Plaintext::try_encode(&a, Encoding::poly_at_level(level), &par).unwrap();
                    let pt_b =
                        Plaintext::try_encode(&b, Encoding::poly_at_level(level), &par).unwrap();
                    let ct_a = public_key.try_encrypt(&pt_a, &mut rng).unwrap();
                    let ct_b = public_key.try_encrypt(&pt_b, &mut rng).unwrap();

                    let ct = Arc::new(&ct_a + &ct_b);

                    let pt = parties
                        .iter()
                        .map(|p| DecryptionShare::new(&p.sk_share, &ct, &mut rng))
                        .aggregate()
                        .unwrap();

                    assert_eq!(
                        Vec::<u64>::try_decode(&pt, Encoding::poly_at_level(level)).unwrap(),
                        expected
                    );
                }
            }
        }
    }
}