fhe/mbfv/
secret_key_switch.rs

1use std::sync::Arc;
2
3use fhe_math::{
4    rq::{traits::TryConvertFrom, Poly, Representation},
5    zq::Modulus,
6};
7use itertools::Itertools;
8use rand::{CryptoRng, RngCore};
9use zeroize::Zeroizing;
10
11use crate::bfv::{BfvParameters, Ciphertext, Plaintext, SecretKey};
12use crate::{Error, Result};
13
14use super::Aggregate;
15
16/// A party's share in the secret key switch protocol.
17///
18/// Each party uses the `SecretKeySwitchShare` to generate their share of the
19/// new ciphertext and participate in the "Protocol 3: KeySwitch" protocol
20/// detailed in [Multiparty BFV](https://eprint.iacr.org/2020/304.pdf) (p7). Use the [`Aggregate`] impl to combine the
21/// shares into a [`Ciphertext`].
22///
23/// Note: this protocol assumes the output key is split into the same number of
24/// parties as the input key, and is likely only useful for niche scenarios.
25pub struct SecretKeySwitchShare {
26    pub(crate) par: Arc<BfvParameters>,
27    /// The original input ciphertext
28    // Probably doesn't need to be Arc in real usage but w/e
29    pub(crate) ct: Arc<Ciphertext>,
30    pub(crate) h_share: Poly,
31}
32
33impl SecretKeySwitchShare {
34    /// Participate in a new KeySwitch protocol
35    ///
36    /// 1. *Private input*: BFV input secret key share
37    /// 2. *Private input*: BFV output secret key share
38    /// 3. *Public input*: Input ciphertext to keyswitch
39    // 4. *Public input*: TODO: variance of the ciphertext noise
40    pub fn new<R: RngCore + CryptoRng>(
41        sk_input_share: &SecretKey,
42        sk_output_share: &SecretKey,
43        ct: Arc<Ciphertext>,
44        rng: &mut R,
45    ) -> Result<Self> {
46        if sk_input_share.par != sk_output_share.par || sk_output_share.par != ct.par {
47            return Err(Error::DefaultError(
48                "Incompatible BFV parameters".to_string(),
49            ));
50        }
51        // Note: M-BFV implementation only supports ciphertext of length 2
52        if ct.len() != 2 {
53            return Err(Error::TooManyValues {
54                actual: ct.len(),
55                limit: 2,
56            });
57        }
58
59        let par = sk_input_share.par.clone();
60        let mut s_in = Zeroizing::new(Poly::try_convert_from(
61            sk_input_share.coeffs.as_ref(),
62            ct[0].ctx(),
63            false,
64            Representation::PowerBasis,
65        )?);
66        s_in.change_representation(Representation::Ntt);
67        let mut s_out = Zeroizing::new(Poly::try_convert_from(
68            sk_output_share.coeffs.as_ref(),
69            ct[0].ctx(),
70            false,
71            Representation::PowerBasis,
72        )?);
73        s_out.change_representation(Representation::Ntt);
74
75        // Sample error
76        // TODO this should be exponential in ciphertext noise!
77        let e = Zeroizing::new(Poly::small(
78            ct[0].ctx(),
79            Representation::Ntt,
80            par.variance,
81            rng,
82        )?);
83
84        // Create h_i share
85        let mut h_share = s_in.as_ref() - s_out.as_ref();
86        h_share.disallow_variable_time_computations();
87        h_share *= &ct[1];
88        h_share += e.as_ref();
89
90        Ok(Self { par, ct, h_share })
91    }
92}
93
94impl Aggregate<SecretKeySwitchShare> for Ciphertext {
95    fn from_shares<T>(iter: T) -> Result<Self>
96    where
97        T: IntoIterator<Item = SecretKeySwitchShare>,
98    {
99        let mut shares = iter.into_iter();
100        let share = shares.next().ok_or(Error::TooFewValues {
101            actual: 0,
102            minimum: 1,
103        })?;
104        let mut h = share.h_share;
105        for sh in shares {
106            h += &sh.h_share;
107        }
108
109        let c0 = &share.ct[0] + &h;
110        let c1 = share.ct[1].clone();
111
112        Ciphertext::new(vec![c0, c1], &share.par)
113    }
114}
115
116/// A party's share in the decryption protocol.
117///
118/// Each party uses the `DecryptionShare` to generate their share of the
119/// plaintext output. Note that this is a special case of the "Protocol 3:
120/// KeySwitch" protocol detailed in [Multiparty BFV](https://eprint.iacr.org/2020/304.pdf) (p7), using an output key of zero. Use the
121/// [`Aggregate`] impl to combine the shares into a [`Plaintext`].
122pub struct DecryptionShare {
123    pub(crate) sks_share: SecretKeySwitchShare,
124}
125
126impl DecryptionShare {
127    /// Participate in a new Decryption protocol.
128    ///
129    /// 1. *Private input*: BFV input secret key share
130    /// 3. *Public input*: Ciphertext to decrypt
131    // 4. *Public input*: TODO: variance of the ciphertext noise
132    pub fn new<R: RngCore + CryptoRng>(
133        sk_input_share: &SecretKey,
134        ct: &Arc<Ciphertext>,
135        rng: &mut R,
136    ) -> Result<Self> {
137        let par = &sk_input_share.par;
138        let zero = SecretKey::new(vec![0; par.degree()], par);
139        let sks_share = SecretKeySwitchShare::new(sk_input_share, &zero, ct.clone(), rng)?;
140        Ok(DecryptionShare { sks_share })
141    }
142}
143
144impl Aggregate<DecryptionShare> for Plaintext {
145    fn from_shares<T>(iter: T) -> Result<Self>
146    where
147        T: IntoIterator<Item = DecryptionShare>,
148    {
149        let sks_shares = iter.into_iter().map(|s| s.sks_share);
150        let ct = Ciphertext::from_shares(sks_shares)?;
151
152        // Note: during SKS, c[1]*sk has already been added to c[0].
153        let mut c = Zeroizing::new(ct[0].clone());
154        c.disallow_variable_time_computations();
155        c.change_representation(Representation::PowerBasis);
156
157        // The true decryption part is done during SKS; all that is left is to scale
158        let ctx_lvl = ct.par.context_level_at(ct.level)?;
159        let d = Zeroizing::new(c.scale(&ctx_lvl.cipher_plain_context.scaler)?);
160        let v = Zeroizing::new(
161            Vec::<u64>::from(d.as_ref())
162                .iter_mut()
163                .map(|vi| *vi + *ct.par.plaintext)
164                .collect_vec(),
165        );
166        let mut w = v[..ct.par.degree()].to_vec();
167        let q = Modulus::new(ct.par.moduli[0]).map_err(Error::MathError)?;
168        q.reduce_vec(&mut w);
169        ct.par.plaintext.reduce_vec(&mut w);
170
171        let mut poly = Poly::try_convert_from(&w, ct[0].ctx(), false, Representation::PowerBasis)?;
172        poly.change_representation(Representation::Ntt);
173
174        let pt = Plaintext {
175            par: ct.par.clone(),
176            value: w.into_boxed_slice(),
177            encoding: None,
178            poly_ntt: poly,
179            level: ct.level,
180        };
181
182        Ok(pt)
183    }
184}
185
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use fhe_traits::{FheDecoder, FheEncoder, FheEncrypter};
    use rand::{rng, CryptoRng, RngCore};

    use crate::{
        bfv::{BfvParameters, Encoding, Plaintext, PublicKey, SecretKey},
        mbfv::{
            Aggregate, AggregateIter, CommonRandomPoly, DecryptionShare, PublicKeyShare,
            SecretKeySwitchShare,
        },
    };

    /// Number of parties taking part in each collective protocol.
    const NUM_PARTIES: usize = 11;
    /// Number of randomized repetitions per parameter set and level.
    const TRIALS: usize = 20;

    /// A protocol participant: its secret key share and matching public key share.
    struct Party {
        sk_share: SecretKey,
        pk_share: PublicKeyShare,
    }

    /// Generate `NUM_PARTIES` parties with fresh secret key shares and the
    /// public key shares derived from them and the common random polynomial.
    fn spawn_parties<R: RngCore + CryptoRng>(
        par: &Arc<BfvParameters>,
        crp: &CommonRandomPoly,
        rng: &mut R,
    ) -> Vec<Party> {
        (0..NUM_PARTIES)
            .map(|_| {
                let sk_share = SecretKey::random(par, &mut *rng);
                let pk_share = PublicKeyShare::new(&sk_share, crp.clone(), &mut *rng).unwrap();
                Party { sk_share, pk_share }
            })
            .collect()
    }

    /// Encrypting under the collective public key and decrypting collectively
    /// yields the original plaintext.
    #[test]
    fn encrypt_decrypt() {
        let mut rng = rng();
        let parameter_sets = [
            BfvParameters::default_arc(1, 16),
            BfvParameters::default_arc(6, 32),
        ];
        for par in parameter_sets {
            for level in 0..=par.max_level() {
                for _ in 0..TRIALS {
                    let crp = CommonRandomPoly::new(&par, &mut rng).unwrap();

                    // Parties collectively generate public key
                    let parties = spawn_parties(&par, &crp, &mut rng);
                    let public_key: PublicKey = parties
                        .iter()
                        .map(|p| p.pk_share.clone())
                        .aggregate()
                        .unwrap();

                    // Use it to encrypt a random polynomial
                    let message = par.plaintext.random_vec(par.degree(), &mut rng);
                    let pt1 =
                        Plaintext::try_encode(&message, Encoding::poly_at_level(level), &par)
                            .unwrap();
                    let ct = Arc::new(public_key.try_encrypt(&pt1, &mut rng).unwrap());

                    // Parties perform a collective decryption
                    let decryption_shares = parties
                        .iter()
                        .map(|p| DecryptionShare::new(&p.sk_share, &ct, &mut rng));
                    let pt2 = Plaintext::from_shares(decryption_shares).unwrap();

                    assert_eq!(pt1, pt2);
                }
            }
        }
    }

    /// A ciphertext key-switched to a second set of parties decrypts, under
    /// that second set, to the original plaintext.
    #[test]
    fn encrypt_keyswitch_decrypt() {
        let mut rng = rng();
        let parameter_sets = [
            BfvParameters::default_arc(1, 16),
            BfvParameters::default_arc(6, 32),
        ];
        for par in parameter_sets {
            for level in 0..=par.max_level() {
                for _ in 0..TRIALS {
                    let crp = CommonRandomPoly::new(&par, &mut rng).unwrap();

                    // Parties collectively generate public key
                    let parties = spawn_parties(&par, &crp, &mut rng);
                    let pk_shares = parties.iter().map(|p| p.pk_share.clone());
                    let public_key = PublicKey::from_shares(pk_shares).unwrap();

                    // Use it to encrypt a random polynomial ct1
                    let message = par.plaintext.random_vec(par.degree(), &mut rng);
                    let pt1 =
                        Plaintext::try_encode(&message, Encoding::poly_at_level(level), &par)
                            .unwrap();
                    let ct1 = Arc::new(public_key.try_encrypt(&pt1, &mut rng).unwrap());

                    // Key switch ct1 to a different set of parties
                    let out_parties = spawn_parties(&par, &crp, &mut rng);
                    let switch_shares =
                        parties
                            .iter()
                            .zip(out_parties.iter())
                            .map(|(ip, op)| {
                                SecretKeySwitchShare::new(
                                    &ip.sk_share,
                                    &op.sk_share,
                                    ct1.clone(),
                                    &mut rng,
                                )
                            });
                    let ct2 = Arc::new(switch_shares.aggregate().unwrap());

                    // The second set of parties then does a collective decryption
                    let pt2: Plaintext = out_parties
                        .iter()
                        .map(|p| DecryptionShare::new(&p.sk_share, &ct2, &mut rng))
                        .aggregate()
                        .unwrap();

                    assert_eq!(pt1, pt2);
                }
            }
        }
    }

    /// The sum of two ciphertexts encrypted under the collective public key
    /// decrypts collectively to the sum of the two plaintexts.
    #[test]
    fn collective_keys_enable_homomorphic_addition() {
        let mut rng = rng();
        let parameter_sets = [
            BfvParameters::default_arc(1, 16),
            BfvParameters::default_arc(6, 32),
        ];
        for par in parameter_sets {
            for level in 0..=par.max_level() {
                for _ in 0..TRIALS {
                    let crp = CommonRandomPoly::new(&par, &mut rng).unwrap();

                    // Parties collectively generate public key
                    let parties = spawn_parties(&par, &crp, &mut rng);
                    let public_key: PublicKey = parties
                        .iter()
                        .map(|p| p.pk_share.clone())
                        .aggregate()
                        .unwrap();

                    // Parties encrypt two plaintexts
                    let a = par.plaintext.random_vec(par.degree(), &mut rng);
                    let b = par.plaintext.random_vec(par.degree(), &mut rng);
                    let mut expected = a.clone();
                    par.plaintext.add_vec(&mut expected, &b);

                    let encoding = || Encoding::poly_at_level(level);
                    let pt_a = Plaintext::try_encode(&a, encoding(), &par).unwrap();
                    let pt_b = Plaintext::try_encode(&b, encoding(), &par).unwrap();
                    let ct_a = public_key.try_encrypt(&pt_a, &mut rng).unwrap();
                    let ct_b = public_key.try_encrypt(&pt_b, &mut rng).unwrap();

                    // and add them together
                    let ct = Arc::new(&ct_a + &ct_b);

                    // Parties perform a collective decryption
                    let pt: Plaintext = parties
                        .iter()
                        .map(|p| DecryptionShare::new(&p.sk_share, &ct, &mut rng))
                        .aggregate()
                        .unwrap();

                    let decoded = Vec::<u64>::try_decode(&pt, encoding()).unwrap();
                    assert_eq!(decoded, expected);
                }
            }
        }
    }
}