// concrete_shortint/server_key/mod.rs

//! Module with the definition of the ServerKey.
//!
//! This module implements the generation of the server public key, together with all the
//! available homomorphic short integer operations.
5mod add;
6mod bitwise_op;
7mod comp_op;
8mod div_mod;
9mod mul;
10mod neg;
11mod scalar_add;
12mod scalar_mul;
13mod scalar_sub;
14mod shift;
15mod sub;
16
17#[cfg(test)]
18mod tests;
19
20use crate::ciphertext::Ciphertext;
21use crate::client_key::ClientKey;
22use crate::engine::ShortintEngine;
23use crate::parameters::{CarryModulus, MessageModulus};
24use concrete_core::prelude::*;
25use serde::{Deserialize, Deserializer, Serialize, Serializer};
26use std::fmt::{Debug, Display, Formatter};
27
28/// Maximum value that the degree can reach.
29#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
30pub struct MaxDegree(pub usize);
31
/// Error returned when the carry buffer is full.
///
/// The enum is a plain tag, so it is cheap to copy and can be compared directly
/// (e.g. `assert_eq!(err, CheckError::CarryFull)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CheckError {
    /// The carry buffer is full.
    CarryFull,
}

impl Display for CheckError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self {
            CheckError::CarryFull => {
                write!(f, "The carry buffer is full")
            }
        }
    }
}

impl std::error::Error for CheckError {}
49
50/// A structure containing the server public key.
51///
52/// The server key is generated by the client and is meant to be published: the client
53/// sends it to the server so it can compute homomorphic circuits.
54#[derive(Clone, Debug, PartialEq)]
55pub struct ServerKey {
56    pub key_switching_key: LweKeyswitchKey64,
57    pub bootstrapping_key: FftFourierLweBootstrapKey64,
58    // Size of the message buffer
59    pub message_modulus: MessageModulus,
60    // Size of the carry buffer
61    pub carry_modulus: CarryModulus,
62    // Maximum number of operations that can be done before emptying the operation buffer
63    pub max_degree: MaxDegree,
64}
65
66impl ServerKey {
67    /// Generates a server key.
68    ///
69    /// # Example
70    ///
71    /// ```rust
72    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
73    /// use concrete_shortint::{gen_keys, ServerKey};
74    ///
75    /// // Generate the client key and the server key:
76    /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2);
77    ///
78    /// // Generate the server key:
79    /// let sks = ServerKey::new(&cks);
80    /// ```
81    pub fn new(cks: &ClientKey) -> ServerKey {
82        ShortintEngine::with_thread_local_mut(|engine| engine.new_server_key(cks).unwrap())
83    }
84
85    /// Generates a server key with a chosen maximum degree
86    pub fn new_with_max_degree(cks: &ClientKey, max_degree: MaxDegree) -> ServerKey {
87        ShortintEngine::with_thread_local_mut(|engine| {
88            engine
89                .new_server_key_with_max_degree(cks, max_degree)
90                .unwrap()
91        })
92    }
93
94    /// Constructs the accumulator given a function as input.
95    ///
96    /// # Example
97    ///
98    /// ```rust
99    /// use concrete_shortint::gen_keys;
100    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
101    ///
102    /// // Generate the client key and the server key:
103    /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2);
104    ///
105    /// let msg = 3;
106    ///
107    /// let ct = cks.encrypt(msg);
108    ///
109    /// // Generate the accumulator for the function f: x -> x^2 mod 2^2
110    /// let f = |x| x ^ 2 % 4;
111    ///
112    /// let acc = sks.generate_accumulator(f);
113    /// let ct_res = sks.keyswitch_programmable_bootstrap(&ct, &acc);
114    ///
115    /// let dec = cks.decrypt(&ct_res);
116    /// // 3^2 mod 4 = 1
117    /// assert_eq!(dec, f(msg));
118    /// ```
119    pub fn generate_accumulator<F>(&self, f: F) -> GlweCiphertext64
120    where
121        F: Fn(u64) -> u64,
122    {
123        ShortintEngine::with_thread_local_mut(|engine| {
124            engine.generate_accumulator(self, f).unwrap()
125        })
126    }
127
128    /// Computes a keyswitch and a bootstrap, returning a new ciphertext with empty
129    /// carry bits.
130    ///
131    /// # Example
132    ///
133    /// ```rust
134    /// use concrete_shortint::gen_keys;
135    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
136    ///
137    /// // Generate the client key and the server key:
138    /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2);
139    ///
140    /// let mut ct1 = cks.encrypt(3);
141    /// // |      ct1        |
142    /// // | carry | message |
143    /// // |-------|---------|
144    /// // |  0 0  |   1 1   |
145    /// let mut ct2 = cks.encrypt(2);
146    /// // |      ct2        |
147    /// // | carry | message |
148    /// // |-------|---------|
149    /// // |  0 0  |   1 0   |
150    ///
151    /// let ct_res = sks.smart_add(&mut ct1, &mut ct2);
152    /// // |     ct_res      |
153    /// // | carry | message |
154    /// // |-------|---------|
155    /// // |  0 1  |   0 1   |
156    ///
157    /// // Get the carry
158    /// let ct_carry = sks.carry_extract(&ct_res);
159    /// let carry = cks.decrypt(&ct_carry);
160    /// assert_eq!(carry, 1);
161    ///
162    /// let ct_res = sks.keyswitch_bootstrap(&ct_res);
163    ///
164    /// let ct_carry = sks.carry_extract(&ct_res);
165    /// let carry = cks.decrypt(&ct_carry);
166    /// assert_eq!(carry, 0);
167    ///
168    /// let clear = cks.decrypt(&ct_res);
169    ///
170    /// assert_eq!(clear, (3 + 2) % 4);
171    /// ```
172    pub fn keyswitch_bootstrap(&self, ct_in: &Ciphertext) -> Ciphertext {
173        ShortintEngine::with_thread_local_mut(|engine| {
174            engine.keyswitch_bootstrap(self, ct_in).unwrap()
175        })
176    }
177
178    pub fn keyswitch_bootstrap_assign(&self, ct_in: &mut Ciphertext) {
179        ShortintEngine::with_thread_local_mut(|engine| {
180            engine.keyswitch_bootstrap_assign(self, ct_in).unwrap()
181        })
182    }
183
184    /// Computes a keyswitch and programmable bootstrap.
185    ///
186    /// # Example
187    ///
188    /// ```rust
189    /// use concrete_shortint::gen_keys;
190    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
191    ///
192    /// // Generate the client key and the server key:
193    /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2);
194    ///
195    /// let msg: u64 = 3;
196    /// let ct = cks.encrypt(msg);
197    /// let modulus = cks.parameters.message_modulus.0 as u64;
198    ///
199    /// // Generate the accumulator for the function f: x -> x^3 mod 2^2
200    /// let acc = sks.generate_accumulator(|x| x * x * x % modulus);
201    /// let ct_res = sks.keyswitch_programmable_bootstrap(&ct, &acc);
202    ///
203    /// let dec = cks.decrypt(&ct_res);
204    /// // 3^3 mod 4 = 3
205    /// assert_eq!(dec, (msg * msg * msg) % modulus);
206    /// ```
207    pub fn keyswitch_programmable_bootstrap(
208        &self,
209        ct_in: &Ciphertext,
210        acc: &GlweCiphertext64,
211    ) -> Ciphertext {
212        ShortintEngine::with_thread_local_mut(|engine| {
213            engine
214                .programmable_bootstrap_keyswitch(self, ct_in, acc)
215                .unwrap()
216        })
217    }
218
219    pub fn keyswitch_programmable_bootstrap_assign(
220        &self,
221        ct_in: &mut Ciphertext,
222        acc: &GlweCiphertext64,
223    ) {
224        ShortintEngine::with_thread_local_mut(|engine| {
225            engine
226                .programmable_bootstrap_keyswitch_assign(self, ct_in, acc)
227                .unwrap()
228        })
229    }
230
231    /// Generic programmable bootstrap where messages are concatenated
232    /// into one ciphertext to compute bivariate functions.
233    /// This is used to apply many binary operations (comparisons, multiplications, division).
234    pub fn unchecked_functional_bivariate_pbs<F>(
235        &self,
236        ct_left: &Ciphertext,
237        ct_right: &Ciphertext,
238        f: F,
239    ) -> Ciphertext
240    where
241        F: Fn(u64) -> u64,
242    {
243        ShortintEngine::with_thread_local_mut(|engine| {
244            engine
245                .unchecked_functional_bivariate_pbs(self, ct_left, ct_right, f)
246                .unwrap()
247        })
248    }
249
250    pub fn unchecked_functional_bivariate_pbs_assign<F>(
251        &self,
252        ct_left: &mut Ciphertext,
253        ct_right: &Ciphertext,
254        f: F,
255    ) where
256        F: Fn(u64) -> u64,
257    {
258        ShortintEngine::with_thread_local_mut(|engine| {
259            engine
260                .unchecked_functional_bivariate_pbs_assign(self, ct_left, ct_right, f)
261                .unwrap()
262        })
263    }
264
265    /// Verifies if a bivariate functional pbs can be applied on ct_left and ct_right.
266    pub fn is_functional_bivariate_pbs_possible(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> bool {
267        //product of the degree
268        let final_degree = ct1.degree.0 * (ct2.degree.0 + 1) + ct2.degree.0;
269        final_degree < ct1.carry_modulus.0 * ct1.message_modulus.0
270    }
271
272    /// Replace the input encrypted message by the value of its carry buffer.
273    ///
274    /// # Example
275    ///
276    ///```rust
277    /// use concrete_shortint::gen_keys;
278    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
279    ///
280    /// // Generate the client key and the server key:
281    /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2);
282    ///
283    /// let clear = 9;
284    ///
285    /// // Encrypt a message
286    /// let mut ct = cks.unchecked_encrypt(clear);
287    ///
288    /// // |       ct        |
289    /// // | carry | message |
290    /// // |-------|---------|
291    /// // |  1 0  |   0 1   |
292    ///
293    /// // Compute homomorphically carry extraction
294    /// sks.carry_extract_assign(&mut ct);
295    ///
296    /// // |       ct        |
297    /// // | carry | message |
298    /// // |-------|---------|
299    /// // |  0 0  |   1 0   |
300    ///
301    /// // Decrypt:
302    /// let res = cks.decrypt_message_and_carry(&ct);
303    /// assert_eq!(2, res);
304    /// ```
305    pub fn carry_extract_assign(&self, ct: &mut Ciphertext) {
306        ShortintEngine::with_thread_local_mut(|engine| {
307            engine.carry_extract_assign(self, ct).unwrap()
308        })
309    }
310
311    /// Extracts a new ciphertext encrypting the input carry buffer.
312    ///
313    /// # Example
314    ///
315    ///```rust
316    /// use concrete_shortint::gen_keys;
317    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
318    ///
319    /// // Generate the client key and the server key:
320    /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2);
321    ///
322    /// let clear = 9;
323    ///
324    /// // Encrypt a message
325    /// let ct = cks.unchecked_encrypt(clear);
326    ///
327    /// // |       ct        |
328    /// // | carry | message |
329    /// // |-------|---------|
330    /// // |  1 0  |   0 1   |
331    ///
332    /// // Compute homomorphically carry extraction
333    /// let ct_res = sks.carry_extract(&ct);
334    ///
335    /// // |     ct_res      |
336    /// // | carry | message |
337    /// // |-------|---------|
338    /// // |  0 0  |   1 0   |
339    ///
340    /// // Decrypt:
341    /// let res = cks.decrypt(&ct_res);
342    /// assert_eq!(2, res);
343    /// ```
344    pub fn carry_extract(&self, ct: &Ciphertext) -> Ciphertext {
345        ShortintEngine::with_thread_local_mut(|engine| engine.carry_extract(self, ct).unwrap())
346    }
347
348    /// Clears the carry buffer of the input ciphertext.
349    ///
350    /// # Example
351    ///
352    ///```rust
353    /// use concrete_shortint::gen_keys;
354    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
355    ///
356    /// // Generate the client key and the server key:
357    /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2);
358    ///
359    /// let clear = 9;
360    ///
361    /// // Encrypt a message
362    /// let mut ct = cks.unchecked_encrypt(clear);
363    ///
364    /// // |       ct        |
365    /// // | carry | message |
366    /// // |-------|---------|
367    /// // |  1 0  |   0 1   |
368    ///
369    /// // Compute homomorphically the message extraction
370    /// sks.message_extract_assign(&mut ct);
371    ///
372    /// // |       ct        |
373    /// // | carry | message |
374    /// // |-------|---------|
375    /// // |  0 0  |   0 1   |
376    ///
377    /// // Decrypt:
378    /// let res = cks.decrypt(&ct);
379    /// assert_eq!(1, res);
380    /// ```
381    pub fn message_extract_assign(&self, ct: &mut Ciphertext) {
382        ShortintEngine::with_thread_local_mut(|engine| {
383            engine.message_extract_assign(self, ct).unwrap()
384        })
385    }
386
387    /// Extracts a new ciphertext containing only the message i.e., with a cleared carry buffer.
388    ///
389    /// # Example
390    ///
391    ///```rust
392    /// use concrete_shortint::gen_keys;
393    /// use concrete_shortint::parameters::PARAM_MESSAGE_1_CARRY_1;
394    ///
395    /// // Generate the client key and the server key:
396    /// let (cks, sks) = gen_keys(PARAM_MESSAGE_1_CARRY_1);
397    ///
398    /// let clear = 9;
399    ///
400    /// // Encrypt a message
401    /// let ct = cks.unchecked_encrypt(clear);
402    ///
403    /// // |       ct        |
404    /// // | carry | message |
405    /// // |-------|---------|
406    /// // |  1 0  |   0 1   |
407    ///
408    /// // Compute homomorphically the message extraction
409    /// let ct_res = sks.message_extract(&ct);
410    ///
411    /// // |     ct_res      |
412    /// // | carry | message |
413    /// // |-------|---------|
414    /// // |  0 0  |   0 1   |
415    ///
416    /// // Decrypt:
417    /// let res = cks.decrypt(&ct_res);
418    /// assert_eq!(1, res);
419    /// ```
420    pub fn message_extract(&self, ct: &Ciphertext) -> Ciphertext {
421        ShortintEngine::with_thread_local_mut(|engine| engine.message_extract(self, ct).unwrap())
422    }
423
424    /// Computes a trivial shortint from a given value.
425    ///
426    /// # Example
427    ///
428    /// ```rust
429    /// use concrete_shortint::gen_keys;
430    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
431    ///
432    /// // Generate the client key and the server key:
433    /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2);
434    ///
435    /// let msg = 1;
436    ///
437    /// // Trivial encryption
438    /// let ct1 = sks.create_trivial(msg);
439    ///
440    /// let ct_res = cks.decrypt(&ct1);
441    /// assert_eq!(1, ct_res);
442    /// ```
443    pub fn create_trivial(&self, value: u8) -> Ciphertext {
444        ShortintEngine::with_thread_local_mut(|engine| engine.create_trivial(self, value).unwrap())
445    }
446
447    pub fn create_trivial_assign(&self, ct: &mut Ciphertext, value: u8) {
448        ShortintEngine::with_thread_local_mut(|engine| {
449            engine.create_trivial_assign(self, ct, value).unwrap()
450        })
451    }
452}
453
454#[derive(Serialize, Deserialize)]
455pub(super) struct SerializableServerKey {
456    pub key_switching_key: Vec<u8>,
457    pub bootstrapping_key: Vec<u8>,
458    // Size of the message buffer
459    pub message_modulus: MessageModulus,
460    // Size of the carry buffer
461    pub carry_modulus: CarryModulus,
462    // Maximum number of operations that can be done before emptying the operation buffer
463    pub max_degree: MaxDegree,
464}
465
466impl Serialize for ServerKey {
467    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
468    where
469        S: Serializer,
470    {
471        let mut ser_eng = DefaultSerializationEngine::new(()).map_err(serde::ser::Error::custom)?;
472        let mut fft_ser_eng = FftSerializationEngine::new(()).map_err(serde::ser::Error::custom)?;
473
474        let key_switching_key = ser_eng
475            .serialize(&self.key_switching_key)
476            .map_err(serde::ser::Error::custom)?;
477        let bootstrapping_key = fft_ser_eng
478            .serialize(&self.bootstrapping_key)
479            .map_err(serde::ser::Error::custom)?;
480
481        SerializableServerKey {
482            key_switching_key,
483            bootstrapping_key,
484            message_modulus: self.message_modulus,
485            carry_modulus: self.carry_modulus,
486            max_degree: self.max_degree,
487        }
488        .serialize(serializer)
489    }
490}
491
492impl<'de> Deserialize<'de> for ServerKey {
493    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
494    where
495        D: Deserializer<'de>,
496    {
497        let thing =
498            SerializableServerKey::deserialize(deserializer).map_err(serde::de::Error::custom)?;
499        let mut ser_eng = DefaultSerializationEngine::new(()).map_err(serde::de::Error::custom)?;
500        let mut fft_ser_eng = FftSerializationEngine::new(()).map_err(serde::de::Error::custom)?;
501
502        Ok(Self {
503            key_switching_key: ser_eng
504                .deserialize(thing.key_switching_key.as_slice())
505                .map_err(serde::de::Error::custom)?,
506            bootstrapping_key: fft_ser_eng
507                .deserialize(thing.bootstrapping_key.as_slice())
508                .map_err(serde::de::Error::custom)?,
509            message_modulus: thing.message_modulus,
510            carry_modulus: thing.carry_modulus,
511            max_degree: thing.max_degree,
512        })
513    }
514}