fhe/mbfv/secret_key_switch.rs

1use std::sync::Arc;
2
3use fhe_math::{
4    rq::{traits::TryConvertFrom, Poly, Representation},
5    zq::Modulus,
6};
7use itertools::Itertools;
8use rand::{CryptoRng, RngCore};
9use zeroize::Zeroizing;
10
11use crate::bfv::{BfvParameters, Ciphertext, Plaintext, SecretKey};
12use crate::{Error, Result};
13
14use super::Aggregate;
15
16/// A party's share in the secret key switch protocol.
17///
18/// Each party uses the `SecretKeySwitchShare` to generate their share of the
19/// new ciphertext and participate in the "Protocol 3: KeySwitch" protocol
20/// detailed in [Multiparty BFV](https://eprint.iacr.org/2020/304.pdf) (p7). Use the [`Aggregate`] impl to combine the
21/// shares into a [`Ciphertext`].
22///
23/// Note: this protocol assumes the output key is split into the same number of
24/// parties as the input key, and is likely only useful for niche scenarios.
25pub struct SecretKeySwitchShare {
26    pub(crate) par: Arc<BfvParameters>,
27    /// The original input ciphertext
28    // Probably doesn't need to be Arc in real usage but w/e
29    pub(crate) ct: Arc<Ciphertext>,
30    pub(crate) h_share: Poly,
31}
32
33impl SecretKeySwitchShare {
34    /// Participate in a new KeySwitch protocol
35    ///
36    /// 1. *Private input*: BFV input secret key share
37    /// 2. *Private input*: BFV output secret key share
38    /// 3. *Public input*: Input ciphertext to keyswitch
39    // 4. *Public input*: TODO: variance of the ciphertext noise
40    pub fn new<R: RngCore + CryptoRng>(
41        sk_input_share: &SecretKey,
42        sk_output_share: &SecretKey,
43        ct: Arc<Ciphertext>,
44        rng: &mut R,
45    ) -> Result<Self> {
46        if sk_input_share.par != sk_output_share.par || sk_output_share.par != ct.par {
47            return Err(Error::DefaultError(
48                "Incompatible BFV parameters".to_string(),
49            ));
50        }
51        // Note: M-BFV implementation only supports ciphertext of length 2
52        if ct.c.len() != 2 {
53            return Err(Error::TooManyValues(ct.c.len(), 2));
54        }
55
56        let par = sk_input_share.par.clone();
57        let mut s_in = Zeroizing::new(Poly::try_convert_from(
58            sk_input_share.coeffs.as_ref(),
59            ct.c[0].ctx(),
60            false,
61            Representation::PowerBasis,
62        )?);
63        s_in.change_representation(Representation::Ntt);
64        let mut s_out = Zeroizing::new(Poly::try_convert_from(
65            sk_output_share.coeffs.as_ref(),
66            ct.c[0].ctx(),
67            false,
68            Representation::PowerBasis,
69        )?);
70        s_out.change_representation(Representation::Ntt);
71
72        // Sample error
73        // TODO this should be exponential in ciphertext noise!
74        let e = Zeroizing::new(Poly::small(
75            ct.c[0].ctx(),
76            Representation::Ntt,
77            par.variance,
78            rng,
79        )?);
80
81        // Create h_i share
82        let mut h_share = s_in.as_ref() - s_out.as_ref();
83        h_share.disallow_variable_time_computations();
84        h_share *= &ct.c[1];
85        h_share += e.as_ref();
86
87        Ok(Self { par, ct, h_share })
88    }
89}
90
91impl Aggregate<SecretKeySwitchShare> for Ciphertext {
92    fn from_shares<T>(iter: T) -> Result<Self>
93    where
94        T: IntoIterator<Item = SecretKeySwitchShare>,
95    {
96        let mut shares = iter.into_iter();
97        let share = shares.next().ok_or(Error::TooFewValues(0, 1))?;
98        let mut h = share.h_share;
99        for sh in shares {
100            h += &sh.h_share;
101        }
102
103        let c0 = &share.ct.c[0] + &h;
104        let c1 = share.ct.c[1].clone();
105
106        Ciphertext::new(vec![c0, c1], &share.par)
107    }
108}
109
110/// A party's share in the decryption protocol.
111///
112/// Each party uses the `DecryptionShare` to generate their share of the
113/// plaintext output. Note that this is a special case of the "Protocol 3:
114/// KeySwitch" protocol detailed in [Multiparty BFV](https://eprint.iacr.org/2020/304.pdf) (p7), using an output key of zero. Use the
115/// [`Aggregate`] impl to combine the shares into a [`Plaintext`].
116pub struct DecryptionShare {
117    pub(crate) sks_share: SecretKeySwitchShare,
118}
119
120impl DecryptionShare {
121    /// Participate in a new Decryption protocol.
122    ///
123    /// 1. *Private input*: BFV input secret key share
124    /// 3. *Public input*: Ciphertext to decrypt
125    // 4. *Public input*: TODO: variance of the ciphertext noise
126    pub fn new<R: RngCore + CryptoRng>(
127        sk_input_share: &SecretKey,
128        ct: &Arc<Ciphertext>,
129        rng: &mut R,
130    ) -> Result<Self> {
131        let par = &sk_input_share.par;
132        let zero = SecretKey::new(vec![0; par.degree()], par);
133        let sks_share = SecretKeySwitchShare::new(sk_input_share, &zero, ct.clone(), rng)?;
134        Ok(DecryptionShare { sks_share })
135    }
136}
137
138impl Aggregate<DecryptionShare> for Plaintext {
139    fn from_shares<T>(iter: T) -> Result<Self>
140    where
141        T: IntoIterator<Item = DecryptionShare>,
142    {
143        let sks_shares = iter.into_iter().map(|s| s.sks_share);
144        let ct = Ciphertext::from_shares(sks_shares)?;
145        let par = ct.par;
146
147        // Note: during SKS, c[1]*sk has already been added to c[0].
148        let mut c = Zeroizing::new(ct.c[0].clone());
149        c.disallow_variable_time_computations();
150        c.change_representation(Representation::PowerBasis);
151
152        // The true decryption part is done during SKS; all that is left is to scale
153        let d = Zeroizing::new(c.scale(&par.scalers[ct.level])?);
154        let v = Zeroizing::new(
155            Vec::<u64>::from(d.as_ref())
156                .iter_mut()
157                .map(|vi| *vi + par.plaintext.modulus())
158                .collect_vec(),
159        );
160        let mut w = v[..par.degree()].to_vec();
161        let q = Modulus::new(par.moduli[0]).map_err(Error::MathError)?;
162        q.reduce_vec(&mut w);
163        par.plaintext.reduce_vec(&mut w);
164
165        let mut poly =
166            Poly::try_convert_from(&w, ct.c[0].ctx(), false, Representation::PowerBasis)?;
167        poly.change_representation(Representation::Ntt);
168
169        let pt = Plaintext {
170            par: par.clone(),
171            value: w.into_boxed_slice(),
172            encoding: None,
173            poly_ntt: poly,
174            level: ct.level,
175        };
176
177        Ok(pt)
178    }
179}
180
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use fhe_traits::{FheDecoder, FheEncoder, FheEncrypter};
    use rand::{rngs::ThreadRng, thread_rng};

    use crate::{
        bfv::{BfvParameters, Ciphertext, Encoding, Plaintext, PublicKey, SecretKey},
        mbfv::{Aggregate, AggregateIter, CommonRandomPoly, PublicKeyShare},
    };

    use super::*;

    const NUM_PARTIES: usize = 11;

    struct Party {
        sk_share: SecretKey,
        pk_share: PublicKeyShare,
    }

    /// Creates `NUM_PARTIES` parties, each holding a fresh secret key share and
    /// the public key share it derives from the common random polynomial `crp`.
    fn generate_parties(
        par: &Arc<BfvParameters>,
        crp: &CommonRandomPoly,
        rng: &mut ThreadRng,
    ) -> Vec<Party> {
        let mut parties = Vec::with_capacity(NUM_PARTIES);
        for _ in 0..NUM_PARTIES {
            let sk_share = SecretKey::random(par, rng);
            let pk_share = PublicKeyShare::new(&sk_share, crp.clone(), rng).unwrap();
            parties.push(Party { sk_share, pk_share });
        }
        parties
    }

    #[test]
    fn encrypt_decrypt() {
        let mut rng = thread_rng();
        for par in [
            BfvParameters::default_arc(1, 16),
            BfvParameters::default_arc(6, 32),
        ] {
            for level in 0..=par.max_level() {
                for _ in 0..20 {
                    let crp = CommonRandomPoly::new(&par, &mut rng).unwrap();

                    // Parties collectively generate public key
                    let parties = generate_parties(&par, &crp, &mut rng);
                    let public_key: PublicKey = parties
                        .iter()
                        .map(|p| p.pk_share.clone())
                        .aggregate()
                        .unwrap();

                    // Use it to encrypt a random polynomial
                    let pt1 = Plaintext::try_encode(
                        &par.plaintext.random_vec(par.degree(), &mut rng),
                        Encoding::poly_at_level(level),
                        &par,
                    )
                    .unwrap();
                    let ct = Arc::new(public_key.try_encrypt(&pt1, &mut rng).unwrap());

                    // Parties perform a collective decryption
                    let decryption_shares = parties
                        .iter()
                        .map(|p| DecryptionShare::new(&p.sk_share, &ct, &mut rng));
                    let pt2 = Plaintext::from_shares(decryption_shares).unwrap();

                    assert_eq!(pt1, pt2);
                }
            }
        }
    }

    #[test]
    fn encrypt_keyswitch_decrypt() {
        let mut rng = thread_rng();
        for par in [
            BfvParameters::default_arc(1, 16),
            BfvParameters::default_arc(6, 32),
        ] {
            for level in 0..=par.max_level() {
                for _ in 0..20 {
                    let crp = CommonRandomPoly::new(&par, &mut rng).unwrap();

                    // Parties collectively generate public key
                    let parties = generate_parties(&par, &crp, &mut rng);
                    let public_key =
                        PublicKey::from_shares(parties.iter().map(|p| p.pk_share.clone())).unwrap();

                    // Use it to encrypt a random polynomial ct1
                    let pt1 = Plaintext::try_encode(
                        &par.plaintext.random_vec(par.degree(), &mut rng),
                        Encoding::poly_at_level(level),
                        &par,
                    )
                    .unwrap();
                    let ct1 = Arc::new(public_key.try_encrypt(&pt1, &mut rng).unwrap());

                    // Key switch ct1 to a different set of parties
                    let out_parties = generate_parties(&par, &crp, &mut rng);
                    let ct2 = parties
                        .iter()
                        .zip(out_parties.iter())
                        .map(|(ip, op)| {
                            SecretKeySwitchShare::new(
                                &ip.sk_share,
                                &op.sk_share,
                                ct1.clone(),
                                &mut rng,
                            )
                        })
                        .aggregate()
                        .unwrap();
                    let ct2 = Arc::new(ct2);

                    // The second set of parties then does a collective decryption
                    let pt2 = out_parties
                        .iter()
                        .map(|p| DecryptionShare::new(&p.sk_share, &ct2, &mut rng))
                        .aggregate()
                        .unwrap();

                    assert_eq!(pt1, pt2);
                }
            }
        }
    }

    #[test]
    fn collective_keys_enable_homomorphic_addition() {
        let mut rng = thread_rng();
        for par in [
            BfvParameters::default_arc(1, 16),
            BfvParameters::default_arc(6, 32),
        ] {
            for level in 0..=par.max_level() {
                for _ in 0..20 {
                    let crp = CommonRandomPoly::new(&par, &mut rng).unwrap();

                    // Parties collectively generate public key
                    let parties = generate_parties(&par, &crp, &mut rng);
                    let public_key: PublicKey = parties
                        .iter()
                        .map(|p| p.pk_share.clone())
                        .aggregate()
                        .unwrap();

                    // Parties encrypt two plaintexts
                    let a = par.plaintext.random_vec(par.degree(), &mut rng);
                    let b = par.plaintext.random_vec(par.degree(), &mut rng);
                    let mut expected = a.clone();
                    par.plaintext.add_vec(&mut expected, &b);

                    let pt_a =
                        Plaintext::try_encode(&a, Encoding::poly_at_level(level), &par).unwrap();
                    let pt_b =
                        Plaintext::try_encode(&b, Encoding::poly_at_level(level), &par).unwrap();
                    let ct_a = public_key.try_encrypt(&pt_a, &mut rng).unwrap();
                    let ct_b = public_key.try_encrypt(&pt_b, &mut rng).unwrap();

                    // and add them together
                    let ct = Arc::new(&ct_a + &ct_b);

                    // Parties perform a collective decryption
                    let pt = parties
                        .iter()
                        .map(|p| DecryptionShare::new(&p.sk_share, &ct, &mut rng))
                        .aggregate()
                        .unwrap();

                    assert_eq!(
                        Vec::<u64>::try_decode(&pt, Encoding::poly_at_level(level)).unwrap(),
                        expected
                    );
                }
            }
        }
    }

    #[test]
    fn aggregating_no_shares_fails() {
        assert!(<Ciphertext as Aggregate<SecretKeySwitchShare>>::from_shares(
            Vec::<SecretKeySwitchShare>::new()
        )
        .is_err());
        assert!(<Plaintext as Aggregate<DecryptionShare>>::from_shares(
            Vec::<DecryptionShare>::new()
        )
        .is_err());
    }

    #[test]
    fn shares_reject_incompatible_parameters() {
        let mut rng = thread_rng();
        let par = BfvParameters::default_arc(1, 16);
        let other_par = BfvParameters::default_arc(6, 32);
        let sk = SecretKey::random(&par, &mut rng);
        let other_sk = SecretKey::random(&other_par, &mut rng);

        let pt = Plaintext::try_encode(
            &par.plaintext.random_vec(par.degree(), &mut rng),
            Encoding::poly(),
            &par,
        )
        .unwrap();
        let ct = Arc::new(sk.try_encrypt(&pt, &mut rng).unwrap());

        // Output key share whose parameters differ from the input key and ciphertext.
        assert!(SecretKeySwitchShare::new(&sk, &other_sk, ct.clone(), &mut rng).is_err());
        // Decryption key share whose parameters differ from the ciphertext.
        assert!(DecryptionShare::new(&other_sk, &ct, &mut rng).is_err());
    }
}