concrete_integer/wopbs/mod.rs
//! Module with the definition of the WopbsKey (WithOut padding PBS Key).
//!
//! This module implements the generation of another server public key, which makes it possible
//! to compute an alternative version of the programmable bootstrapping (PBS) that does not
//! require a bit of padding.
6#[cfg(test)]
7mod test;
8
9use crate::client_key::utils::i_crt;
10use crate::{ClientKey, CrtCiphertext, IntegerCiphertext, RadixCiphertext, ServerKey};
11use concrete_core::prelude::*;
12use concrete_shortint::ciphertext::Degree;
13use rayon::prelude::*;
14
15use concrete_shortint::Parameters;
16use serde::{Deserialize, Serialize};
17
18#[derive(Clone, Serialize, Deserialize)]
19pub struct WopbsKey {
20 wopbs_key: concrete_shortint::wopbs::WopbsKey,
21}
22
/// Decomposes `val` into `nb_block` little-endian digits in base `basis`.
///
/// Digit `i` of the output is `(val / basis^i) % basis`. Digits of `val` that do not fit in
/// `nb_block` digits are dropped, and digit positions beyond the magnitude of `val` are zero.
/// Any non-zero `basis` is supported (not only powers of two). No power of `basis` is ever
/// materialized, so a large `nb_block` cannot overflow.
///
/// # Panics
///
/// Panics if `basis` is zero.
///
/// ```rust
/// use concrete_integer::wopbs::{decode_radix, encode_radix};
///
/// let val = 11;
/// let basis = 2;
/// let nb_block = 5;
/// let radix = encode_radix(val, basis, nb_block);
///
/// assert_eq!(val, decode_radix(radix, basis));
/// ```
pub fn encode_radix(val: u64, basis: u64, nb_block: u64) -> Vec<u64> {
    assert!(basis != 0, "encode_radix: the basis must be non-zero");
    // Repeated division yields the digits least significant first, for any basis.
    let mut remaining = val;
    (0..nb_block)
        .map(|_| {
            let digit = remaining % basis;
            remaining /= basis;
            digit
        })
        .collect()
}
52
/// Returns the CRT decomposition of `val`: its residue modulo each entry of `basis`, in the
/// same order as `basis`.
pub fn encode_crt(val: u64, basis: &[u64]) -> Vec<u64> {
    basis.iter().map(|modulus| val % modulus).collect()
}
61
62//Concatenate two ciphertexts in one
63//Used to compute bivariate wopbs
64fn ciphertext_concatenation<T>(ct1: &T, ct2: &T) -> T
65where
66 T: IntegerCiphertext,
67{
68 let mut new_blocks = ct1.blocks().to_vec();
69 new_blocks.extend_from_slice(ct2.blocks());
70 T::from_blocks(new_blocks)
71}
72
/// Re-encodes a bit-packed value into base-`modulus` digits, one digit per entry of `basis`.
///
/// `val` is read as consecutive bit fields, least significant first, field `i` being
/// `basis[i]` bits wide. A field may hold a value of at least `modulus` (i.e. it carries). The
/// part of the field above `modulus` is added to the next field, and the output holds, for
/// every field, the remaining digit modulo `modulus`.
///
/// The carry is obtained by shifting right by `log2(modulus)`, so `modulus` is expected to be a
/// power of two.
pub fn encode_mix_radix(mut val: u64, basis: &[u64], modulus: u64) -> Vec<u64> {
    let log_modulus = f64::log2(modulus as f64) as u64;
    basis
        .iter()
        .map(|&field_bits| {
            let digit = val % modulus;
            val -= digit;
            // Bits of the current field lying above the digit, moved down to unit weight.
            let carry = (val % (1_u64 << field_bits)) >> log_modulus;
            val >>= field_bits;
            val += carry;
            digit
        })
        .collect()
}
84
/// Splits `value` into consecutive bit fields whose widths are given by `basis`.
///
/// Fields are taken least significant first: field `i` is the next `basis[i]` bits of
/// `value`, returned right-aligned. A width of 64 or more takes every remaining bit, and
/// fields past the top of `value` are zero.
///
/// Example: `value = 5 = 0b101` and `basis = [1, 2]` give `[0b1, 0b10] = [1, 2]`.
///
/// ```rust
/// use concrete_integer::wopbs::split_value_according_to_bit_basis;
///
/// let val = 5;
/// let basis = vec![1, 2];
/// assert_eq!(vec![1, 2], split_value_according_to_bit_basis(val, &basis));
/// ```
pub fn split_value_according_to_bit_basis(value: u64, basis: &[u64]) -> Vec<u64> {
    let mut remaining = value;
    basis
        .iter()
        .map(|&width| {
            if width >= 64 {
                // The field spans every bit still available; shifting by 64 would overflow.
                std::mem::take(&mut remaining)
            } else {
                let field = remaining & ((1_u64 << width) - 1);
                remaining >>= width;
                field
            }
        })
        .collect()
}
109
/// Recombines little-endian digits in base `basis` into a single value.
///
/// Computes `sum(val[i] * basis^i)` with wrapping (modulo 2^64) arithmetic, so digits or
/// bases whose weights exceed `u64` silently wrap instead of panicking.
///
/// ```rust
/// use concrete_integer::wopbs::{decode_radix, encode_radix};
///
/// let val = 11;
/// let basis = 2;
/// let nb_block = 5;
/// assert_eq!(val, decode_radix(encode_radix(val, basis, nb_block), basis));
/// ```
pub fn decode_radix(val: Vec<u64>, basis: u64) -> u64 {
    let (decoded, _weight) = val
        .iter()
        .fold((0_u64, 1_u64), |(acc, weight), &digit| {
            (
                acc.wrapping_add(digit.wrapping_mul(weight)),
                weight.wrapping_mul(basis),
            )
        });
    decoded
}
133
134impl From<concrete_shortint::wopbs::WopbsKey> for WopbsKey {
135 fn from(wopbs_key: concrete_shortint::wopbs::WopbsKey) -> Self {
136 Self { wopbs_key }
137 }
138}
139
140impl WopbsKey {
141 /// Generates the server key required to compute a WoPBS from the client and the server keys.
142 /// # Example
143 /// ```rust
144 /// use concrete_integer::gen_keys;
145 /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_1_CARRY_1;
146 /// use concrete_integer::wopbs::*;
147 /// use concrete_shortint::parameters::PARAM_MESSAGE_1_CARRY_1;
148 ///
149 /// // Generate the client key and the server key:
150 /// let (mut cks, mut sks) = gen_keys(&PARAM_MESSAGE_1_CARRY_1);
151 /// let mut wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_1_CARRY_1);
152 /// ```
153 pub fn new_wopbs_key(cks: &ClientKey, sks: &ServerKey, parameters: &Parameters) -> WopbsKey {
154 WopbsKey {
155 wopbs_key: concrete_shortint::wopbs::WopbsKey::new_wopbs_key(
156 &cks.key, &sks.key, parameters,
157 ),
158 }
159 }
160
161 pub fn new_from_shortint(wopbskey: &concrete_shortint::wopbs::WopbsKey) -> WopbsKey {
162 let key = wopbskey.clone();
163 WopbsKey { wopbs_key: key }
164 }
165
166 pub fn new_wopbs_key_only_for_wopbs(cks: &ClientKey, sks: &ServerKey) -> WopbsKey {
167 WopbsKey {
168 wopbs_key: concrete_shortint::wopbs::WopbsKey::new_wopbs_key_only_for_wopbs(
169 &cks.key, &sks.key,
170 ),
171 }
172 }
173
174 /// Computes the WoP-PBS given the luts.
175 ///
176 /// This works for both RadixCiphertext and CrtCiphertext.
177 ///
178 /// # Example
179 ///
180 /// ```rust
181 /// use concrete_integer::gen_keys;
182 /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2;
183 /// use concrete_integer::wopbs::*;
184 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
185 ///
186 /// let nb_block = 3;
187 /// //Generate the client key and the server key:
188 /// let (mut cks, mut sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2);
189 /// let mut wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2);
190 /// let mut moduli = 1_u64;
191 /// for _ in 0..nb_block{
192 /// moduli *= cks.parameters().message_modulus.0 as u64;
193 /// }
194 /// let clear = 42 % moduli;
195 /// let ct = cks.encrypt_radix(clear as u64, nb_block);
196 /// let ct = wopbs_key.keyswitch_to_wopbs_params(&sks,&ct);
197 /// let lut = wopbs_key.generate_lut_radix(&ct, |x|x);
198 /// let ct_res = wopbs_key.wopbs(&ct, &lut);
199 /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res);
200 /// let res = cks.decrypt_radix(&ct_res);
201 ///
202 /// assert_eq!(res, clear);
203 /// ```
204 pub fn wopbs<T>(&self, ct_in: &T, lut: &[Vec<u64>]) -> T
205 where
206 T: IntegerCiphertext,
207 {
208 let mut extracted_bits_blocks = vec![];
209 // Extraction of each bit for each block
210 for block in ct_in.blocks().iter() {
211 let delta = (1_usize << 63)
212 / (self.wopbs_key.param.message_modulus.0 * self.wopbs_key.param.carry_modulus.0);
213 let delta_log = DeltaLog(f64::log2(delta as f64) as usize);
214 let nb_bit_to_extract = f64::log2((block.degree.0 + 1) as f64).ceil() as usize;
215
216 let extracted_bits = self
217 .wopbs_key
218 .extract_bits(delta_log, block, nb_bit_to_extract);
219
220 extracted_bits_blocks.push(extracted_bits);
221 }
222
223 extracted_bits_blocks.reverse();
224 let vec_ct_out = self
225 .wopbs_key
226 .circuit_bootstrapping_vertical_packing(lut.to_vec(), extracted_bits_blocks);
227
228 let mut ct_vec_out = vec![];
229 for (block, block_out) in ct_in.blocks().iter().zip(vec_ct_out.into_iter()) {
230 ct_vec_out.push(concrete_shortint::Ciphertext {
231 ct: block_out,
232 degree: Degree(block.message_modulus.0 - 1),
233 message_modulus: block.message_modulus,
234 carry_modulus: block.carry_modulus,
235 });
236 }
237 T::from_blocks(ct_vec_out)
238 }
239
240 /// # Example
241 /// ```rust
242 /// use concrete_integer::gen_keys;
243 /// use concrete_integer::wopbs::WopbsKey;
244 /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2;
245 ///
246 /// let nb_block = 3;
247 /// //Generate the client key and the server key:
248 /// let (mut cks, mut sks) = gen_keys(&WOPBS_PARAM_MESSAGE_2_CARRY_2);
249 /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
250 /// let mut moduli = 1_u64;
251 /// for _ in 0..nb_block{
252 /// moduli *= cks.parameters().message_modulus.0 as u64;
253 /// }
254 /// let clear = 15 % moduli;
255 /// let ct = cks.encrypt_radix_without_padding(clear as u64, nb_block);
256 /// let lut = wopbs_key.generate_lut_radix_without_padding(&ct, |x| 2 * x);
257 /// let ct_res = wopbs_key.wopbs_without_padding(&ct, &lut);
258 /// let res = cks.decrypt_radix_without_padding(&ct_res);
259 ///
260 /// assert_eq!(res, (clear * 2) % moduli)
261 /// ```
262 pub fn wopbs_without_padding<T>(&self, ct_in: &T, lut: &[Vec<u64>]) -> T
263 where
264 T: IntegerCiphertext,
265 {
266 let mut extracted_bits_blocks = vec![];
267 let mut ct_in = ct_in.clone();
268 // Extraction of each bit for each block
269 for block in ct_in.blocks_mut().iter_mut() {
270 let delta = (1_usize << 63) / (block.message_modulus.0 * block.carry_modulus.0 / 2);
271 let delta_log = DeltaLog(f64::log2(delta as f64) as usize);
272 let nb_bit_to_extract =
273 f64::log2((block.message_modulus.0 * block.carry_modulus.0) as f64) as usize;
274
275 let extracted_bits = self
276 .wopbs_key
277 .extract_bits(delta_log, block, nb_bit_to_extract);
278 extracted_bits_blocks.push(extracted_bits);
279 }
280
281 extracted_bits_blocks.reverse();
282
283 let vec_ct_out = self
284 .wopbs_key
285 .circuit_bootstrapping_vertical_packing(lut.to_vec(), extracted_bits_blocks);
286
287 let mut ct_vec_out = vec![];
288 for (block, block_out) in ct_in.blocks().iter().zip(vec_ct_out.into_iter()) {
289 ct_vec_out.push(concrete_shortint::Ciphertext {
290 ct: block_out,
291 degree: Degree(block.message_modulus.0 - 1),
292 message_modulus: block.message_modulus,
293 carry_modulus: block.carry_modulus,
294 });
295 }
296 T::from_blocks(ct_vec_out)
297 }
298
299 /// WOPBS for native CRT
300 /// # Example
301 /// ```rust
302 /// use concrete_integer::gen_keys;
303 /// use concrete_integer::parameters::PARAM_4_BITS_5_BLOCKS;
304 /// use concrete_integer::wopbs::WopbsKey;
305 ///
306 /// let basis: Vec<u64> = vec![9, 11];
307 ///
308 /// let param = PARAM_4_BITS_5_BLOCKS;
309 /// //Generate the client key and the server key:
310 /// let (cks, sks) = gen_keys(¶m);
311 /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
312 ///
313 /// let mut msg_space = 1;
314 /// for modulus in basis.iter() {
315 /// msg_space *= modulus;
316 /// }
317 /// let clear = 42 % msg_space; // Encrypt the integers
318 /// let mut ct = cks.encrypt_native_crt(clear, basis.clone());
319 /// let lut = wopbs_key.generate_lut_native_crt(&ct, |x| x);
320 /// let ct_res = wopbs_key.wopbs_native_crt(&mut ct, &lut);
321 /// let res = cks.decrypt_native_crt(&ct_res);
322 /// assert_eq!(res, clear);
323 /// ```
324 pub fn wopbs_native_crt(&self, ct1: &CrtCiphertext, lut: &[Vec<u64>]) -> CrtCiphertext {
325 self.circuit_bootstrap_vertical_packing_native_crt(&[ct1.clone()], lut)
326 }
327
328 /// # Example
329 /// ```rust
330 /// use concrete_integer::gen_keys;
331 /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2;
332 /// use concrete_integer::wopbs::*;
333 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
334 ///
335 /// let nb_block = 3;
336 /// //Generate the client key and the server key:
337 /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2);
338 ///
339 /// //Generate wopbs_v0 key ///
340 /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2);
341 /// let mut moduli = 1_u64;
342 /// for _ in 0..nb_block{
343 /// moduli *= cks.parameters().message_modulus.0 as u64;
344 /// }
345 /// let clear1 = 42 % moduli;
346 /// let clear2 = 24 % moduli;
347 /// let ct1 = cks.encrypt_radix(clear1 as u64, nb_block);
348 /// let ct2 = cks.encrypt_radix(clear2 as u64, nb_block);
349 ///
350 /// let ct1 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct1);
351 /// let ct2 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct2);
352 /// let lut = wopbs_key.generate_lut_bivariate_radix(&ct1, &ct2, |x,y| 2 * x * y);
353 /// let ct_res = wopbs_key.bivariate_wopbs_with_degree(& ct1, & ct2, &lut);
354 /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res);
355 /// let res = cks.decrypt_radix(&ct_res);
356 ///
357 /// assert_eq!(res, (2 * clear1 * clear2) % moduli);
358 /// ```
359 pub fn bivariate_wopbs_with_degree<T>(&self, ct1: &T, ct2: &T, lut: &[Vec<u64>]) -> T
360 where
361 T: IntegerCiphertext,
362 {
363 let ct = ciphertext_concatenation(ct1, ct2);
364 self.wopbs(&ct, lut)
365 }
366
367 /// # Example
368 ///
369 /// ```rust
370 /// use concrete_integer::gen_keys;
371 /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2;
372 /// use concrete_integer::wopbs::*;
373 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
374 ///
375 /// let nb_block = 3;
376 /// //Generate the client key and the server key:
377 /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2);
378 ///
379 /// //Generate wopbs_v0 key ///
380 /// let mut wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2);
381 /// let mut moduli = 1_u64;
382 /// for _ in 0..nb_block{
383 /// moduli *= cks.parameters().message_modulus.0 as u64;
384 /// }
385 /// let clear = 42 % moduli;
386 /// let ct = cks.encrypt_radix(clear as u64, nb_block);
387 /// let ct = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct);
388 /// let lut = wopbs_key.generate_lut_radix(&ct, |x| 2 * x);
389 /// let ct_res = wopbs_key.wopbs(&ct, &lut);
390 /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res);
391 /// let res = cks.decrypt_radix(&ct_res);
392 ///
393 /// assert_eq!(res, (2 * clear) % moduli);
394 /// ```
395 pub fn generate_lut_radix<F, T>(&self, ct: &T, f: F) -> Vec<Vec<u64>>
396 where
397 F: Fn(u64) -> u64,
398 T: IntegerCiphertext,
399 {
400 let mut total_bit = 0;
401 let block_nb = ct.blocks().len();
402 let mut modulus = 1;
403
404 //This contains the basis of each block depending on the degree
405 let mut vec_deg_basis = vec![];
406
407 for (i, deg) in ct.moduli().iter().zip(ct.blocks().iter()) {
408 modulus *= i;
409 let b = f64::log2((deg.degree.0 + 1) as f64).ceil() as u64;
410 vec_deg_basis.push(b);
411 total_bit += b;
412 }
413
414 let mut lut_size = 1 << total_bit;
415 if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 {
416 lut_size = self.wopbs_key.param.polynomial_size.0;
417 }
418 let mut vec_lut = vec![vec![0; lut_size]; ct.blocks().len()];
419
420 let basis = ct.moduli()[0];
421 let delta: u64 = (1 << 63)
422 / (self.wopbs_key.param.message_modulus.0 * self.wopbs_key.param.carry_modulus.0)
423 as u64;
424
425 for lut_index_val in 0..(1 << total_bit) {
426 let encoded_with_deg_val = encode_mix_radix(lut_index_val, &vec_deg_basis, basis);
427 let decoded_val = decode_radix(encoded_with_deg_val.clone(), basis);
428 let f_val = f(decoded_val % modulus) % modulus;
429 let encoded_f_val = encode_radix(f_val, basis, block_nb as u64);
430 for lut_number in 0..block_nb {
431 vec_lut[lut_number as usize][lut_index_val as usize] =
432 encoded_f_val[lut_number] * delta;
433 }
434 }
435 vec_lut
436 }
437
438 /// # Example
439 /// ```rust
440 /// use concrete_integer::gen_keys;
441 /// use concrete_integer::wopbs::WopbsKey;
442 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
443 /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2;
444 ///
445 /// let nb_block = 3;
446 /// //Generate the client key and the server key:
447 /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2);
448 /// //Generate wopbs_v0 key
449 /// let mut wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2);
450 /// let mut moduli = 1_u64;
451 /// for _ in 0..nb_block{
452 /// moduli *= cks.parameters().message_modulus.0 as u64;
453 /// }
454 /// let clear = 15 % moduli;
455 /// let ct = cks.encrypt_radix_without_padding(clear as u64, nb_block);
456 /// let ct = wopbs_key.keyswitch_to_wopbs_params(&sks,&ct);
457 /// let lut = wopbs_key.generate_lut_radix_without_padding(&ct, |x| 2 * x);
458 /// let ct_res = wopbs_key.wopbs_without_padding(&ct, &lut);
459 /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res);
460 /// let res = cks.decrypt_radix_without_padding(&ct_res);
461 ///
462 /// assert_eq!(res, (clear * 2) % moduli)
463 /// ```
464 pub fn generate_lut_radix_without_padding<F, T>(&self, ct: &T, f: F) -> Vec<Vec<u64>>
465 where
466 F: Fn(u64) -> u64,
467 T: IntegerCiphertext,
468 {
469 let log_message_modulus = f64::log2((self.wopbs_key.param.message_modulus.0) as f64) as u64;
470 let log_carry_modulus = f64::log2((self.wopbs_key.param.carry_modulus.0) as f64) as u64;
471 let log_basis = log_message_modulus + log_carry_modulus;
472 let delta = 64 - log_basis;
473 let nb_block = ct.blocks().len();
474 let poly_size = self.wopbs_key.param.polynomial_size.0;
475 let mut lut_size = 1 << (nb_block * log_basis as usize);
476 if lut_size < poly_size {
477 lut_size = poly_size;
478 }
479 let mut vec_lut = vec![vec![0; lut_size]; nb_block];
480
481 for index in 0..lut_size {
482 // find the value represented by the index
483 let mut value = 0;
484 let mut tmp_index = index;
485 for i in 0..nb_block as u64 {
486 let tmp = tmp_index % (1 << (log_basis * (i + 1)));
487 tmp_index -= tmp;
488 value += tmp >> (log_carry_modulus * i);
489 }
490
491 // fill the LUTs
492 for (block_index, lut_block) in vec_lut.iter_mut().enumerate().take(nb_block) {
493 lut_block[index] = ((f(value as u64) >> (log_carry_modulus * block_index as u64))
494 % (1 << log_message_modulus))
495 << delta
496 }
497 }
498 vec_lut
499 }
500
501 /// generate lut for native CRT
502 /// # Example
503 ///
504 /// ```rust
505 /// use concrete_integer::gen_keys;
506 /// use concrete_integer::parameters::PARAM_4_BITS_5_BLOCKS;
507 /// use concrete_integer::wopbs::WopbsKey;
508 ///
509 /// let basis: Vec<u64> = vec![9, 11];
510 ///
511 /// let param = PARAM_4_BITS_5_BLOCKS;
512 /// //Generate the client key and the server key:
513 /// let (cks, sks) = gen_keys(¶m);
514 /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
515 ///
516 /// let mut msg_space = 1;
517 /// for modulus in basis.iter() {
518 /// msg_space *= modulus;
519 /// }
520 /// let clear = 42 % msg_space; // Encrypt the integers
521 /// let mut ct = cks.encrypt_native_crt(clear, basis.clone());
522 /// let lut = wopbs_key.generate_lut_native_crt(&ct, |x| x);
523 /// let ct_res = wopbs_key.wopbs_native_crt(&mut ct, &lut);
524 /// let res = cks.decrypt_native_crt(&ct_res);
525 /// assert_eq!(res, clear);
526 /// ```
527 pub fn generate_lut_native_crt<F>(&self, ct: &CrtCiphertext, f: F) -> Vec<Vec<u64>>
528 where
529 F: Fn(u64) -> u64,
530 {
531 let mut bit = vec![];
532 let mut total_bit = 0;
533 let mut modulus = 1;
534 let basis: Vec<_> = ct.moduli();
535
536 for i in basis.iter() {
537 modulus *= i;
538 let b = f64::log2(*i as f64).ceil() as u64;
539 total_bit += b;
540 bit.push(b);
541 }
542 let mut lut_size = 1 << total_bit;
543 if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 {
544 lut_size = self.wopbs_key.param.polynomial_size.0;
545 }
546 let mut vec_lut = vec![vec![0; lut_size]; basis.len()];
547
548 for value in 0..modulus {
549 let mut index_lut = 0;
550 let mut tmp = 1;
551 for (base, bit) in basis.iter().zip(bit.iter()) {
552 index_lut += (((value % base) << bit) / base) * tmp;
553 tmp <<= bit;
554 }
555 for (j, b) in basis.iter().enumerate() {
556 vec_lut[j][index_lut as usize] =
557 (((f(value) % b) as u128 * (1 << 64)) / *b as u128) as u64
558 }
559 }
560 vec_lut
561 }
562
563 /// generate LUt for crt
564 /// # Example
565 /// ```rust
566 ///
567 /// use concrete_integer::gen_keys;
568 /// use concrete_integer::wopbs::*;
569 /// use concrete_shortint::parameters::PARAM_MESSAGE_3_CARRY_3;
570 /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_3_CARRY_3;
571 ///
572 /// let basis : Vec<u64> = vec![5,7];
573 /// let nb_block = basis.len();
574 ///
575 /// //Generate the client key and the server key:
576 /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_3_CARRY_3);
577 /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_3_CARRY_3);
578 ///
579 /// let mut msg_space = 1;
580 /// for modulus in basis.iter() {
581 /// msg_space *= modulus;
582 /// }
583 /// let clear = 42 % msg_space;
584 /// let ct = cks.encrypt_crt(clear, basis.clone());
585 /// let ct = wopbs_key.keyswitch_to_wopbs_params(&sks,&ct);
586 /// let lut = wopbs_key.generate_lut_crt(&ct, |x| x);
587 /// let ct_res = wopbs_key.wopbs(&ct, &lut);
588 /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res);
589 /// let res = cks.decrypt_crt(&ct_res);
590 /// assert_eq!(res, clear);
591 /// ```
592 pub fn generate_lut_crt<F>(&self, ct: &CrtCiphertext, f: F) -> Vec<Vec<u64>>
593 where
594 F: Fn(u64) -> u64,
595 {
596 let mut bit = vec![];
597 let mut total_bit = 0;
598 let mut modulus = 1;
599 let basis = ct.moduli();
600
601 for (i, deg) in basis.iter().zip(ct.blocks.iter()) {
602 modulus *= i;
603 let b = f64::log2((deg.degree.0 + 1) as f64).ceil() as u64;
604 total_bit += b;
605 bit.push(b);
606 }
607 let mut lut_size = 1 << total_bit;
608 if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 {
609 lut_size = self.wopbs_key.param.polynomial_size.0;
610 }
611 let mut vec_lut = vec![vec![0; lut_size]; basis.len()];
612
613 for i in 0..(1 << total_bit) {
614 let mut value = i;
615 for (j, block) in ct.blocks.iter().enumerate() {
616 let deg = f64::log2((block.degree.0 + 1) as f64).ceil() as u64;
617 let delta: u64 = (1 << 63)
618 / (self.wopbs_key.param.message_modulus.0
619 * self.wopbs_key.param.carry_modulus.0) as u64;
620 vec_lut[j][i as usize] =
621 ((f((value % (1 << deg)) % block.message_modulus.0 as u64))
622 % block.message_modulus.0 as u64)
623 * delta;
624 value >>= deg;
625 }
626 }
627 vec_lut
628 }
629
630 /// # Example
631 ///
632 /// ```rust
633 /// use concrete_integer::gen_keys;
634 /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2;
635 /// use concrete_integer::wopbs::*;
636 /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
637 ///
638 /// let nb_block = 3;
639 /// //Generate the client key and the server key:
640 /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2);
641 ///
642 /// //Generate wopbs_v0 key ///
643 /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2);
644 /// let mut moduli = 1_u64;
645 /// for _ in 0..nb_block{
646 /// moduli *= cks.parameters().message_modulus.0 as u64;
647 /// }
648 /// let clear1 = 42 % moduli;
649 /// let clear2 = 24 % moduli;
650 /// let ct1 = cks.encrypt_radix(clear1 as u64, nb_block);
651 /// let ct2 = cks.encrypt_radix(clear2 as u64, nb_block);
652 ///
653 /// let ct1 = wopbs_key.keyswitch_to_wopbs_params(&sks,&ct1);
654 /// let ct2 = wopbs_key.keyswitch_to_wopbs_params(&sks,&ct2);
655 /// let lut = wopbs_key.generate_lut_bivariate_radix(&ct1, &ct2, |x,y| 2 * x * y);
656 /// let ct_res = wopbs_key.bivariate_wopbs_with_degree(&ct1, &ct2, &lut);
657 /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res);
658 /// let res = cks.decrypt_radix(&ct_res);
659 ///
660 /// assert_eq!(res, (2 * clear1 * clear2) % moduli);
661 /// ```
662 pub fn generate_lut_bivariate_radix<F>(
663 &self,
664 ct1: &RadixCiphertext,
665 ct2: &RadixCiphertext,
666 f: F,
667 ) -> Vec<Vec<u64>>
668 where
669 F: Fn(u64, u64) -> u64,
670 {
671 let mut nb_bit_to_extract = vec![0; 2];
672 let block_nb = ct1.blocks.len();
673 //ct2 & ct1 should have the same basis
674 let basis = ct1.moduli();
675
676 //This contains the basis of each block depending on the degree
677 let mut vec_deg_basis = vec![vec![]; 2];
678
679 let mut modulus = 1;
680 for (ct_num, ct) in [ct1, ct2].iter().enumerate() {
681 modulus = 1;
682 for deg in ct.blocks.iter() {
683 modulus *= self.wopbs_key.param.message_modulus.0 as u64;
684 let b = f64::log2((deg.degree.0 + 1) as f64).ceil() as u64;
685 vec_deg_basis[ct_num].push(b);
686 nb_bit_to_extract[ct_num] += b;
687 }
688 }
689
690 let total_bit: u64 = nb_bit_to_extract.iter().sum();
691
692 let mut lut_size = 1 << total_bit;
693 if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 {
694 lut_size = self.wopbs_key.param.polynomial_size.0;
695 }
696 let mut vec_lut = vec![vec![0; lut_size]; basis.len()];
697 let basis = ct1.moduli()[0];
698
699 let delta: u64 = (1 << 63)
700 / (self.wopbs_key.param.message_modulus.0 * self.wopbs_key.param.carry_modulus.0)
701 as u64;
702
703 for lut_index_val in 0..(1 << total_bit) {
704 let split = vec![
705 lut_index_val % (1 << nb_bit_to_extract[0]),
706 lut_index_val >> nb_bit_to_extract[0],
707 ];
708 let mut decoded_val = vec![0; 2];
709 for i in 0..2 {
710 let encoded_with_deg_val = encode_mix_radix(split[i], &vec_deg_basis[i], basis);
711 decoded_val[i] = decode_radix(encoded_with_deg_val.clone(), basis);
712 }
713 let f_val = f(decoded_val[0] % modulus, decoded_val[1] % modulus) % modulus;
714 let encoded_f_val = encode_radix(f_val, basis, block_nb as u64);
715 for lut_number in 0..block_nb {
716 vec_lut[lut_number as usize][lut_index_val as usize] =
717 encoded_f_val[lut_number] * delta;
718 }
719 }
720 vec_lut
721 }
722
723 /// generate bivariate LUT for 'fake' CRT
724 ///
725 /// # Example
726 ///
727 /// ```rust
728 ///
729 /// use concrete_integer::gen_keys;
730 /// use concrete_integer::wopbs::*;
731 /// use concrete_shortint::parameters::PARAM_MESSAGE_3_CARRY_3;
732 /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_3_CARRY_3;
733 ///
734 /// let basis : Vec<u64> = vec![5,7];
735 /// //Generate the client key and the server key:
736 /// let ( cks, sks) = gen_keys(&PARAM_MESSAGE_3_CARRY_3);
737 /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_3_CARRY_3);
738 ///
739 /// let mut msg_space = 1;
740 /// for modulus in basis.iter() {
741 /// msg_space *= modulus;
742 /// }
743 /// let clear1 = 42 % msg_space; // Encrypt the integers
744 /// let clear2 = 24 % msg_space; // Encrypt the integers
745 /// let ct1 = cks.encrypt_crt(clear1, basis.clone());
746 /// let ct2 = cks.encrypt_crt(clear2, basis.clone());
747 ///
748 /// let ct1 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct1);
749 /// let ct2 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct2);
750 ///
751 ///
752 /// let lut = wopbs_key.generate_lut_bivariate_crt(&ct1, &ct2, |x,y| x * y * 2);
753 /// let ct_res = wopbs_key.bivariate_wopbs_with_degree(&ct1, &ct2, &lut);
754 /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res);
755 /// let res = cks.decrypt_crt(&ct_res);
756 /// assert_eq!(res, (clear1 * clear2 * 2) % msg_space );
757 /// ```
758 pub fn generate_lut_bivariate_crt<F>(
759 &self,
760 ct1: &CrtCiphertext,
761 ct2: &CrtCiphertext,
762 f: F,
763 ) -> Vec<Vec<u64>>
764 where
765 F: Fn(u64, u64) -> u64,
766 {
767 let mut bit = vec![];
768 let mut nb_bit_to_extract = vec![0; 2];
769 let mut modulus = 1;
770
771 //ct2 & ct1 should have the same basis
772 let basis = ct1.moduli();
773
774 for (ct_num, ct) in [ct1, ct2].iter().enumerate() {
775 for (i, deg) in basis.iter().zip(ct.blocks.iter()) {
776 modulus *= i;
777 let b = f64::log2((deg.degree.0 + 1) as f64).ceil() as u64;
778 nb_bit_to_extract[ct_num] += b;
779 bit.push(b);
780 }
781 }
782
783 let total_bit: u64 = nb_bit_to_extract.iter().sum();
784
785 let mut lut_size = 1 << total_bit;
786 if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 {
787 lut_size = self.wopbs_key.param.polynomial_size.0;
788 }
789 let mut vec_lut = vec![vec![0; lut_size]; basis.len()];
790
791 let delta: u64 = (1 << 63)
792 / (self.wopbs_key.param.message_modulus.0 * self.wopbs_key.param.carry_modulus.0)
793 as u64;
794
795 for index in 0..(1 << total_bit) {
796 let mut split = encode_radix(index, 1 << nb_bit_to_extract[0], 2);
797 let mut crt_value = vec![vec![0; ct1.blocks.len()]; 2];
798 for (j, base) in basis.iter().enumerate().take(ct1.blocks.len()) {
799 let deg_1 = f64::log2((ct1.blocks[j].degree.0 + 1) as f64).ceil() as u64;
800 let deg_2 = f64::log2((ct2.blocks[j].degree.0 + 1) as f64).ceil() as u64;
801 crt_value[0][j] = (split[0] % (1 << deg_1)) % base;
802 crt_value[1][j] = (split[1] % (1 << deg_2)) % base;
803 split[0] >>= deg_1;
804 split[1] >>= deg_2;
805 }
806 let value_1 = i_crt(&ct1.moduli(), &crt_value[0]);
807 let value_2 = i_crt(&ct2.moduli(), &crt_value[1]);
808 for (j, current_mod) in basis.iter().enumerate() {
809 let value = f(value_1, value_2) % current_mod;
810 vec_lut[j][index as usize] = (value % current_mod) * delta;
811 }
812 }
813
814 vec_lut
815 }
816
817 /// generate bivariate LUT for 'true' CRT
818 /// # Example
819 ///
820 /// ```rust
821 /// use concrete_integer::gen_keys;
822 /// use concrete_integer::parameters::PARAM_4_BITS_5_BLOCKS;
823 /// use concrete_integer::wopbs::WopbsKey;
824 ///
825 /// let basis: Vec<u64> = vec![9, 11];
826 ///
827 /// let param = PARAM_4_BITS_5_BLOCKS;
828 /// //Generate the client key and the server key:
829 /// let (cks, sks) = gen_keys(¶m);
830 /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
831 ///
832 /// let mut msg_space = 1;
833 /// for modulus in basis.iter() {
834 /// msg_space *= modulus;
835 /// }
836 /// let clear1 = 42 % msg_space;
837 /// let clear2 = 24 % msg_space;
838 /// let mut ct1 = cks.encrypt_native_crt(clear1, basis.clone());
839 /// let mut ct2 = cks.encrypt_native_crt(clear2, basis.clone());
840 /// let lut = wopbs_key.generate_lut_bivariate_native_crt(&ct1, |x, y| x * y * 2);
841 /// let ct_res = wopbs_key.bivariate_wopbs_native_crt(&mut ct1, &mut ct2, &lut);
842 /// let res = cks.decrypt_native_crt(&ct_res);
843 /// assert_eq!(res, (clear1 * clear2 * 2) % msg_space);
844 /// ```
845 pub fn generate_lut_bivariate_native_crt<F>(&self, ct_1: &CrtCiphertext, f: F) -> Vec<Vec<u64>>
846 where
847 F: Fn(u64, u64) -> u64,
848 {
849 let mut bit = vec![];
850 let mut total_bit = 0;
851 let mut modulus = 1;
852 let basis = ct_1.moduli();
853 for i in basis.iter() {
854 modulus *= i;
855 let b = f64::log2(*i as f64).ceil() as u64;
856 total_bit += b;
857 bit.push(b);
858 }
859 let mut lut_size = 1 << (2 * total_bit);
860 if 1 << (2 * total_bit) < self.wopbs_key.param.polynomial_size.0 as u64 {
861 lut_size = self.wopbs_key.param.polynomial_size.0;
862 }
863 let mut vec_lut = vec![vec![0; lut_size]; basis.len()];
864
865 for value in 0..1 << (2 * total_bit) {
866 let value_1 = value % (1 << total_bit);
867 let value_2 = value >> total_bit;
868 let mut index_lut_1 = 0;
869 let mut index_lut_2 = 0;
870 let mut tmp = 1;
871 for (base, bit) in basis.iter().zip(bit.iter()) {
872 index_lut_1 += (((value_1 % base) << bit) / base) * tmp;
873 index_lut_2 += (((value_2 % base) << bit) / base) * tmp;
874 tmp <<= bit;
875 }
876 let index = (index_lut_2 << total_bit) + (index_lut_1);
877 for (j, b) in basis.iter().enumerate() {
878 vec_lut[j][index as usize] =
879 (((f(value_1, value_2) % b) as u128 * (1 << 64)) / *b as u128) as u64
880 }
881 }
882 vec_lut
883 }
884
885 /// bivariate WOPBS for native CRT
886 /// # Example
887 ///
888 /// ```rust
889 /// use concrete_integer::gen_keys;
890 /// use concrete_integer::parameters::PARAM_4_BITS_5_BLOCKS;
891 /// use concrete_integer::wopbs::WopbsKey;
892 ///
893 /// let basis: Vec<u64> = vec![9, 11];
894 ///
895 /// let param = PARAM_4_BITS_5_BLOCKS;
896 /// //Generate the client key and the server key:
897 /// let (cks, sks) = gen_keys(¶m);
898 /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
899 ///
900 /// let mut msg_space = 1;
901 /// for modulus in basis.iter() {
902 /// msg_space *= modulus;
903 /// }
904 /// let clear1 = 42 % msg_space;
905 /// let clear2 = 24 % msg_space;
906 /// let mut ct1 = cks.encrypt_native_crt(clear1, basis.clone());
907 /// let mut ct2 = cks.encrypt_native_crt(clear2, basis.clone());
908 /// let lut = wopbs_key.generate_lut_bivariate_native_crt(&ct1, |x, y| x * y * 2);
909 /// let ct_res = wopbs_key.bivariate_wopbs_native_crt(&mut ct1, &mut ct2, &lut);
910 /// let res = cks.decrypt_native_crt(&ct_res);
911 /// assert_eq!(res, (clear1 * clear2 * 2) % msg_space);
912 /// ```
913 pub fn bivariate_wopbs_native_crt(
914 &self,
915 ct1: &CrtCiphertext,
916 ct2: &CrtCiphertext,
917 lut: &[Vec<u64>],
918 ) -> CrtCiphertext {
919 self.circuit_bootstrap_vertical_packing_native_crt(&[ct1.clone(), ct2.clone()], lut)
920 }
921
922 fn circuit_bootstrap_vertical_packing_native_crt<T>(
923 &self,
924 vec_ct_in: &[T],
925 lut: &[Vec<u64>],
926 ) -> T
927 where
928 T: IntegerCiphertext,
929 {
930 let mut extracted_bits_blocks = vec![];
931 for ct_in in vec_ct_in.iter() {
932 let mut ct_in = ct_in.clone();
933 // Extraction of each bit for each block
934 for block in ct_in.blocks_mut().iter_mut() {
935 let nb_bit_to_extract =
936 f64::log2((block.message_modulus.0 * block.carry_modulus.0) as f64).ceil()
937 as usize;
938 let delta_log = DeltaLog(64 - nb_bit_to_extract);
939
940 // trick ( ct - delta/2 + delta/2^4 )
941 let lwe_size = block.ct.lwe_dimension().to_lwe_size().0;
942 let mut cont = vec![0u64; lwe_size];
943 cont[lwe_size - 1] =
944 (1 << (64 - nb_bit_to_extract - 1)) - (1 << (64 - nb_bit_to_extract - 5));
945 let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(0))).unwrap();
946 let tmp = engine.create_lwe_ciphertext_from(cont).unwrap();
947 engine.fuse_sub_lwe_ciphertext(&mut block.ct, &tmp).unwrap();
948
949 let extracted_bits =
950 self.wopbs_key
951 .extract_bits(delta_log, block, nb_bit_to_extract);
952 extracted_bits_blocks.push(extracted_bits);
953 }
954 }
955
956 extracted_bits_blocks.reverse();
957
958 let vec_ct_out = self
959 .wopbs_key
960 .circuit_bootstrapping_vertical_packing(lut.to_vec(), extracted_bits_blocks);
961
962 let mut ct_vec_out: Vec<concrete_shortint::Ciphertext> = vec![];
963 for (block, block_out) in vec_ct_in[0].blocks().iter().zip(vec_ct_out.into_iter()) {
964 ct_vec_out.push(concrete_shortint::Ciphertext {
965 ct: block_out,
966 degree: Degree(block.message_modulus.0 - 1),
967 message_modulus: block.message_modulus,
968 carry_modulus: block.carry_modulus,
969 });
970 }
971 T::from_blocks(ct_vec_out)
972 }
973
974 pub fn keyswitch_to_wopbs_params<T>(&self, sks: &ServerKey, ct_in: &T) -> T
975 where
976 T: IntegerCiphertext,
977 {
978 let blocks: Vec<_> = ct_in
979 .blocks()
980 .par_iter()
981 .map(|block| self.wopbs_key.keyswitch_to_wopbs_params(&sks.key, block))
982 .collect();
983 T::from_blocks(blocks)
984 }
985
986 pub fn keyswitch_to_pbs_params<T>(&self, ct_in: &T) -> T
987 where
988 T: IntegerCiphertext,
989 {
990 let blocks: Vec<_> = ct_in
991 .blocks()
992 .par_iter()
993 .map(|block| self.wopbs_key.keyswitch_to_pbs_params(block))
994 .collect();
995 T::from_blocks(blocks)
996 }
997}