concrete_shortint/server_key/mod.rs
1//! Module with the definition of the ServerKey.
2//!
3//! This module implements the generation of the server public key, together with all the
4//! available homomorphic integer operations.
5mod add;
6mod bitwise_op;
7mod comp_op;
8mod div_mod;
9mod mul;
10mod neg;
11mod scalar_add;
12mod scalar_mul;
13mod scalar_sub;
14mod shift;
15mod sub;
16
17#[cfg(test)]
18mod tests;
19
20use crate::ciphertext::Ciphertext;
21use crate::client_key::ClientKey;
22use crate::engine::ShortintEngine;
23use crate::parameters::{CarryModulus, MessageModulus};
24use concrete_core::prelude::*;
25use serde::{Deserialize, Deserializer, Serialize, Serializer};
26use std::fmt::{Debug, Display, Formatter};
27
28/// Maximum value that the degree can reach.
29#[derive(Debug, PartialEq, Eq, Copy, Clone, Serialize, Deserialize)]
30pub struct MaxDegree(pub usize);
31
/// Error returned when a check fails.
///
/// The only failure currently represented is a full carry buffer.
#[derive(Debug)]
pub enum CheckError {
    /// The carry buffer is full.
    CarryFull,
}

impl Display for CheckError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Pick the human-readable message for the variant, then emit it verbatim.
        let message = match *self {
            Self::CarryFull => "The carry buffer is full",
        };
        f.write_str(message)
    }
}

impl std::error::Error for CheckError {}
49
50/// A structure containing the server public key.
51///
52/// The server key is generated by the client and is meant to be published: the client
53/// sends it to the server so it can compute homomorphic circuits.
54#[derive(Clone, Debug, PartialEq)]
55pub struct ServerKey {
56 pub key_switching_key: LweKeyswitchKey64,
57 pub bootstrapping_key: FftFourierLweBootstrapKey64,
58 // Size of the message buffer
59 pub message_modulus: MessageModulus,
60 // Size of the carry buffer
61 pub carry_modulus: CarryModulus,
62 // Maximum number of operations that can be done before emptying the operation buffer
63 pub max_degree: MaxDegree,
64}
65
66impl ServerKey {
67 /// Generates a server key.
68 ///
69 /// # Example
70 ///
71 /// ```rust
72 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
73 /// use concrete_shortint::{gen_keys, ServerKey};
74 ///
75 /// // Generate the client key and the server key:
76 /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2);
77 ///
78 /// // Generate the server key:
79 /// let sks = ServerKey::new(&cks);
80 /// ```
81 pub fn new(cks: &ClientKey) -> ServerKey {
82 ShortintEngine::with_thread_local_mut(|engine| engine.new_server_key(cks).unwrap())
83 }
84
85 /// Generates a server key with a chosen maximum degree
86 pub fn new_with_max_degree(cks: &ClientKey, max_degree: MaxDegree) -> ServerKey {
87 ShortintEngine::with_thread_local_mut(|engine| {
88 engine
89 .new_server_key_with_max_degree(cks, max_degree)
90 .unwrap()
91 })
92 }
93
94 /// Constructs the accumulator given a function as input.
95 ///
96 /// # Example
97 ///
98 /// ```rust
99 /// use concrete_shortint::gen_keys;
100 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
101 ///
102 /// // Generate the client key and the server key:
103 /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2);
104 ///
105 /// let msg = 3;
106 ///
107 /// let ct = cks.encrypt(msg);
108 ///
109 /// // Generate the accumulator for the function f: x -> x^2 mod 2^2
110 /// let f = |x| x ^ 2 % 4;
111 ///
112 /// let acc = sks.generate_accumulator(f);
113 /// let ct_res = sks.keyswitch_programmable_bootstrap(&ct, &acc);
114 ///
115 /// let dec = cks.decrypt(&ct_res);
116 /// // 3^2 mod 4 = 1
117 /// assert_eq!(dec, f(msg));
118 /// ```
119 pub fn generate_accumulator<F>(&self, f: F) -> GlweCiphertext64
120 where
121 F: Fn(u64) -> u64,
122 {
123 ShortintEngine::with_thread_local_mut(|engine| {
124 engine.generate_accumulator(self, f).unwrap()
125 })
126 }
127
128 /// Computes a keyswitch and a bootstrap, returning a new ciphertext with empty
129 /// carry bits.
130 ///
131 /// # Example
132 ///
133 /// ```rust
134 /// use concrete_shortint::gen_keys;
135 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
136 ///
137 /// // Generate the client key and the server key:
138 /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2);
139 ///
140 /// let mut ct1 = cks.encrypt(3);
141 /// // | ct1 |
142 /// // | carry | message |
143 /// // |-------|---------|
144 /// // | 0 0 | 1 1 |
145 /// let mut ct2 = cks.encrypt(2);
146 /// // | ct2 |
147 /// // | carry | message |
148 /// // |-------|---------|
149 /// // | 0 0 | 1 0 |
150 ///
151 /// let ct_res = sks.smart_add(&mut ct1, &mut ct2);
152 /// // | ct_res |
153 /// // | carry | message |
154 /// // |-------|---------|
155 /// // | 0 1 | 0 1 |
156 ///
157 /// // Get the carry
158 /// let ct_carry = sks.carry_extract(&ct_res);
159 /// let carry = cks.decrypt(&ct_carry);
160 /// assert_eq!(carry, 1);
161 ///
162 /// let ct_res = sks.keyswitch_bootstrap(&ct_res);
163 ///
164 /// let ct_carry = sks.carry_extract(&ct_res);
165 /// let carry = cks.decrypt(&ct_carry);
166 /// assert_eq!(carry, 0);
167 ///
168 /// let clear = cks.decrypt(&ct_res);
169 ///
170 /// assert_eq!(clear, (3 + 2) % 4);
171 /// ```
172 pub fn keyswitch_bootstrap(&self, ct_in: &Ciphertext) -> Ciphertext {
173 ShortintEngine::with_thread_local_mut(|engine| {
174 engine.keyswitch_bootstrap(self, ct_in).unwrap()
175 })
176 }
177
178 pub fn keyswitch_bootstrap_assign(&self, ct_in: &mut Ciphertext) {
179 ShortintEngine::with_thread_local_mut(|engine| {
180 engine.keyswitch_bootstrap_assign(self, ct_in).unwrap()
181 })
182 }
183
184 /// Computes a keyswitch and programmable bootstrap.
185 ///
186 /// # Example
187 ///
188 /// ```rust
189 /// use concrete_shortint::gen_keys;
190 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
191 ///
192 /// // Generate the client key and the server key:
193 /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2);
194 ///
195 /// let msg: u64 = 3;
196 /// let ct = cks.encrypt(msg);
197 /// let modulus = cks.parameters.message_modulus.0 as u64;
198 ///
199 /// // Generate the accumulator for the function f: x -> x^3 mod 2^2
200 /// let acc = sks.generate_accumulator(|x| x * x * x % modulus);
201 /// let ct_res = sks.keyswitch_programmable_bootstrap(&ct, &acc);
202 ///
203 /// let dec = cks.decrypt(&ct_res);
204 /// // 3^3 mod 4 = 3
205 /// assert_eq!(dec, (msg * msg * msg) % modulus);
206 /// ```
207 pub fn keyswitch_programmable_bootstrap(
208 &self,
209 ct_in: &Ciphertext,
210 acc: &GlweCiphertext64,
211 ) -> Ciphertext {
212 ShortintEngine::with_thread_local_mut(|engine| {
213 engine
214 .programmable_bootstrap_keyswitch(self, ct_in, acc)
215 .unwrap()
216 })
217 }
218
219 pub fn keyswitch_programmable_bootstrap_assign(
220 &self,
221 ct_in: &mut Ciphertext,
222 acc: &GlweCiphertext64,
223 ) {
224 ShortintEngine::with_thread_local_mut(|engine| {
225 engine
226 .programmable_bootstrap_keyswitch_assign(self, ct_in, acc)
227 .unwrap()
228 })
229 }
230
231 /// Generic programmable bootstrap where messages are concatenated
232 /// into one ciphertext to compute bivariate functions.
233 /// This is used to apply many binary operations (comparisons, multiplications, division).
234 pub fn unchecked_functional_bivariate_pbs<F>(
235 &self,
236 ct_left: &Ciphertext,
237 ct_right: &Ciphertext,
238 f: F,
239 ) -> Ciphertext
240 where
241 F: Fn(u64) -> u64,
242 {
243 ShortintEngine::with_thread_local_mut(|engine| {
244 engine
245 .unchecked_functional_bivariate_pbs(self, ct_left, ct_right, f)
246 .unwrap()
247 })
248 }
249
250 pub fn unchecked_functional_bivariate_pbs_assign<F>(
251 &self,
252 ct_left: &mut Ciphertext,
253 ct_right: &Ciphertext,
254 f: F,
255 ) where
256 F: Fn(u64) -> u64,
257 {
258 ShortintEngine::with_thread_local_mut(|engine| {
259 engine
260 .unchecked_functional_bivariate_pbs_assign(self, ct_left, ct_right, f)
261 .unwrap()
262 })
263 }
264
265 /// Verifies if a bivariate functional pbs can be applied on ct_left and ct_right.
266 pub fn is_functional_bivariate_pbs_possible(&self, ct1: &Ciphertext, ct2: &Ciphertext) -> bool {
267 //product of the degree
268 let final_degree = ct1.degree.0 * (ct2.degree.0 + 1) + ct2.degree.0;
269 final_degree < ct1.carry_modulus.0 * ct1.message_modulus.0
270 }
271
272 /// Replace the input encrypted message by the value of its carry buffer.
273 ///
274 /// # Example
275 ///
276 ///```rust
277 /// use concrete_shortint::gen_keys;
278 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
279 ///
280 /// // Generate the client key and the server key:
281 /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2);
282 ///
283 /// let clear = 9;
284 ///
285 /// // Encrypt a message
286 /// let mut ct = cks.unchecked_encrypt(clear);
287 ///
288 /// // | ct |
289 /// // | carry | message |
290 /// // |-------|---------|
291 /// // | 1 0 | 0 1 |
292 ///
293 /// // Compute homomorphically carry extraction
294 /// sks.carry_extract_assign(&mut ct);
295 ///
296 /// // | ct |
297 /// // | carry | message |
298 /// // |-------|---------|
299 /// // | 0 0 | 1 0 |
300 ///
301 /// // Decrypt:
302 /// let res = cks.decrypt_message_and_carry(&ct);
303 /// assert_eq!(2, res);
304 /// ```
305 pub fn carry_extract_assign(&self, ct: &mut Ciphertext) {
306 ShortintEngine::with_thread_local_mut(|engine| {
307 engine.carry_extract_assign(self, ct).unwrap()
308 })
309 }
310
311 /// Extracts a new ciphertext encrypting the input carry buffer.
312 ///
313 /// # Example
314 ///
315 ///```rust
316 /// use concrete_shortint::gen_keys;
317 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
318 ///
319 /// // Generate the client key and the server key:
320 /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2);
321 ///
322 /// let clear = 9;
323 ///
324 /// // Encrypt a message
325 /// let ct = cks.unchecked_encrypt(clear);
326 ///
327 /// // | ct |
328 /// // | carry | message |
329 /// // |-------|---------|
330 /// // | 1 0 | 0 1 |
331 ///
332 /// // Compute homomorphically carry extraction
333 /// let ct_res = sks.carry_extract(&ct);
334 ///
335 /// // | ct_res |
336 /// // | carry | message |
337 /// // |-------|---------|
338 /// // | 0 0 | 1 0 |
339 ///
340 /// // Decrypt:
341 /// let res = cks.decrypt(&ct_res);
342 /// assert_eq!(2, res);
343 /// ```
344 pub fn carry_extract(&self, ct: &Ciphertext) -> Ciphertext {
345 ShortintEngine::with_thread_local_mut(|engine| engine.carry_extract(self, ct).unwrap())
346 }
347
348 /// Clears the carry buffer of the input ciphertext.
349 ///
350 /// # Example
351 ///
352 ///```rust
353 /// use concrete_shortint::gen_keys;
354 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
355 ///
356 /// // Generate the client key and the server key:
357 /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2);
358 ///
359 /// let clear = 9;
360 ///
361 /// // Encrypt a message
362 /// let mut ct = cks.unchecked_encrypt(clear);
363 ///
364 /// // | ct |
365 /// // | carry | message |
366 /// // |-------|---------|
367 /// // | 1 0 | 0 1 |
368 ///
369 /// // Compute homomorphically the message extraction
370 /// sks.message_extract_assign(&mut ct);
371 ///
372 /// // | ct |
373 /// // | carry | message |
374 /// // |-------|---------|
375 /// // | 0 0 | 0 1 |
376 ///
377 /// // Decrypt:
378 /// let res = cks.decrypt(&ct);
379 /// assert_eq!(1, res);
380 /// ```
381 pub fn message_extract_assign(&self, ct: &mut Ciphertext) {
382 ShortintEngine::with_thread_local_mut(|engine| {
383 engine.message_extract_assign(self, ct).unwrap()
384 })
385 }
386
387 /// Extracts a new ciphertext containing only the message i.e., with a cleared carry buffer.
388 ///
389 /// # Example
390 ///
391 ///```rust
392 /// use concrete_shortint::gen_keys;
393 /// use concrete_shortint::parameters::PARAM_MESSAGE_1_CARRY_1;
394 ///
395 /// // Generate the client key and the server key:
396 /// let (cks, sks) = gen_keys(PARAM_MESSAGE_1_CARRY_1);
397 ///
398 /// let clear = 9;
399 ///
400 /// // Encrypt a message
401 /// let ct = cks.unchecked_encrypt(clear);
402 ///
403 /// // | ct |
404 /// // | carry | message |
405 /// // |-------|---------|
406 /// // | 1 0 | 0 1 |
407 ///
408 /// // Compute homomorphically the message extraction
409 /// let ct_res = sks.message_extract(&ct);
410 ///
411 /// // | ct_res |
412 /// // | carry | message |
413 /// // |-------|---------|
414 /// // | 0 0 | 0 1 |
415 ///
416 /// // Decrypt:
417 /// let res = cks.decrypt(&ct_res);
418 /// assert_eq!(1, res);
419 /// ```
420 pub fn message_extract(&self, ct: &Ciphertext) -> Ciphertext {
421 ShortintEngine::with_thread_local_mut(|engine| engine.message_extract(self, ct).unwrap())
422 }
423
424 /// Computes a trivial shortint from a given value.
425 ///
426 /// # Example
427 ///
428 /// ```rust
429 /// use concrete_shortint::gen_keys;
430 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
431 ///
432 /// // Generate the client key and the server key:
433 /// let (cks, sks) = gen_keys(PARAM_MESSAGE_2_CARRY_2);
434 ///
435 /// let msg = 1;
436 ///
437 /// // Trivial encryption
438 /// let ct1 = sks.create_trivial(msg);
439 ///
440 /// let ct_res = cks.decrypt(&ct1);
441 /// assert_eq!(1, ct_res);
442 /// ```
443 pub fn create_trivial(&self, value: u8) -> Ciphertext {
444 ShortintEngine::with_thread_local_mut(|engine| engine.create_trivial(self, value).unwrap())
445 }
446
447 pub fn create_trivial_assign(&self, ct: &mut Ciphertext, value: u8) {
448 ShortintEngine::with_thread_local_mut(|engine| {
449 engine.create_trivial_assign(self, ct, value).unwrap()
450 })
451 }
452}
453
454#[derive(Serialize, Deserialize)]
455pub(super) struct SerializableServerKey {
456 pub key_switching_key: Vec<u8>,
457 pub bootstrapping_key: Vec<u8>,
458 // Size of the message buffer
459 pub message_modulus: MessageModulus,
460 // Size of the carry buffer
461 pub carry_modulus: CarryModulus,
462 // Maximum number of operations that can be done before emptying the operation buffer
463 pub max_degree: MaxDegree,
464}
465
466impl Serialize for ServerKey {
467 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
468 where
469 S: Serializer,
470 {
471 let mut ser_eng = DefaultSerializationEngine::new(()).map_err(serde::ser::Error::custom)?;
472 let mut fft_ser_eng = FftSerializationEngine::new(()).map_err(serde::ser::Error::custom)?;
473
474 let key_switching_key = ser_eng
475 .serialize(&self.key_switching_key)
476 .map_err(serde::ser::Error::custom)?;
477 let bootstrapping_key = fft_ser_eng
478 .serialize(&self.bootstrapping_key)
479 .map_err(serde::ser::Error::custom)?;
480
481 SerializableServerKey {
482 key_switching_key,
483 bootstrapping_key,
484 message_modulus: self.message_modulus,
485 carry_modulus: self.carry_modulus,
486 max_degree: self.max_degree,
487 }
488 .serialize(serializer)
489 }
490}
491
492impl<'de> Deserialize<'de> for ServerKey {
493 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
494 where
495 D: Deserializer<'de>,
496 {
497 let thing =
498 SerializableServerKey::deserialize(deserializer).map_err(serde::de::Error::custom)?;
499 let mut ser_eng = DefaultSerializationEngine::new(()).map_err(serde::de::Error::custom)?;
500 let mut fft_ser_eng = FftSerializationEngine::new(()).map_err(serde::de::Error::custom)?;
501
502 Ok(Self {
503 key_switching_key: ser_eng
504 .deserialize(thing.key_switching_key.as_slice())
505 .map_err(serde::de::Error::custom)?,
506 bootstrapping_key: fft_ser_eng
507 .deserialize(thing.bootstrapping_key.as_slice())
508 .map_err(serde::de::Error::custom)?,
509 message_modulus: thing.message_modulus,
510 carry_modulus: thing.carry_modulus,
511 max_degree: thing.max_degree,
512 })
513 }
514}