concrete_integer/wopbs/
mod.rs

1//! Module with the definition of the WopbsKey (WithOut padding PBS Key).
2//!
//! This module implements the generation of an additional server public key, which allows
//! computing an alternative version of the programmable bootstrapping that does not require
//! a bit of padding.
6#[cfg(test)]
7mod test;
8
9use crate::client_key::utils::i_crt;
10use crate::{ClientKey, CrtCiphertext, IntegerCiphertext, RadixCiphertext, ServerKey};
11use concrete_core::prelude::*;
12use concrete_shortint::ciphertext::Degree;
13use rayon::prelude::*;
14
15use concrete_shortint::Parameters;
16use serde::{Deserialize, Serialize};
17
18#[derive(Clone, Serialize, Deserialize)]
19pub struct WopbsKey {
20    wopbs_key: concrete_shortint::wopbs::WopbsKey,
21}
22
/// Decomposes `val` into `nb_block` digits in base `basis`, least significant digit first.
///
/// Works for any `basis >= 1` (not only powers of two). If `val` does not fit in `nb_block`
/// digits, the higher digits are dropped, i.e. the result encodes `val mod basis^nb_block`.
///
/// # Panics
///
/// Panics if `basis` is `0` and `nb_block > 0`.
///
/// ```rust
/// use concrete_integer::wopbs::{decode_radix, encode_radix};
///
/// let val = 11;
/// let basis = 2;
/// let nb_block = 5;
/// let radix = encode_radix(val, basis, nb_block);
///
/// assert_eq!(val, decode_radix(radix, basis));
/// assert_eq!(encode_radix(11, 3, 3), vec![2, 0, 1]);
/// ```
pub fn encode_radix(val: u64, basis: u64, nb_block: u64) -> Vec<u64> {
    // Repeated division is correct for every basis and, unlike accumulating `basis^i`
    // (and `(basis - 1) * basis^i`), it can never overflow.
    let mut remaining = val;
    (0..nb_block)
        .map(|_| {
            let digit = remaining % basis;
            remaining /= basis;
            digit
        })
        .collect()
}
52
/// Decomposes `val` into its residues modulo each element of `basis` (CRT encoding),
/// in the same order as `basis`.
///
/// ```rust
/// use concrete_integer::wopbs::encode_crt;
///
/// assert_eq!(encode_crt(42, &[5, 7]), vec![2, 0]);
/// ```
pub fn encode_crt(val: u64, basis: &[u64]) -> Vec<u64> {
    basis.iter().map(|modulus| val % modulus).collect()
}
61
62//Concatenate two ciphertexts in one
63//Used to compute bivariate wopbs
64fn ciphertext_concatenation<T>(ct1: &T, ct2: &T) -> T
65where
66    T: IntegerCiphertext,
67{
68    let mut new_blocks = ct1.blocks().to_vec();
69    new_blocks.extend_from_slice(ct2.blocks());
70    T::from_blocks(new_blocks)
71}
72
/// Converts a LUT index made of blocks with heterogeneous bit widths into `basis.len()`
/// digits in base `modulus`, least significant digit first.
///
/// Block `i` occupies the next `basis[i]` bits of `val` (block 0 in the least significant
/// bits) and has weight `modulus^i`. A block may hold a value larger than `modulus - 1`
/// (e.g. when its bit width comes from a degree above `modulus - 1`); the excess is
/// propagated as a carry into the following digit. A carry out of the last digit is dropped.
///
/// # Panics
///
/// Panics if `modulus` is `0` and `basis` is not empty.
///
/// ```rust
/// use concrete_integer::wopbs::encode_mix_radix;
///
/// // Three 3-bit blocks [7, 7, 0] -> 7 + 7 * 4 = 35 = 3 + 0 * 4 + 2 * 16
/// assert_eq!(encode_mix_radix(0b000_111_111, &[3, 3, 3], 4), vec![3, 0, 2]);
/// ```
pub fn encode_mix_radix(mut val: u64, basis: &[u64], modulus: u64) -> Vec<u64> {
    let mut output = Vec::with_capacity(basis.len());
    let mut carry = 0_u64;
    for &nb_bits in basis {
        // Peel the next block (its `nb_bits` low bits) off `val`; guard the full-width
        // case since shifting a u64 by 64 is not allowed.
        let block = if nb_bits >= 64 {
            let block = val;
            val = 0;
            block
        } else {
            let block = val & ((1_u64 << nb_bits) - 1);
            val >>= nb_bits;
            block
        };
        // Add the incoming carry, then split into this digit and the outgoing carry.
        let total = block + carry;
        output.push(total % modulus);
        carry = total / modulus;
    }
    output
}
84
/// Splits `value` into consecutive bit chunks, least significant bits first, where chunk `i`
/// is `basis[i]` bits wide. Bits beyond the sum of `basis` are ignored.
///
/// Example: `5 = 0b101` with `basis = [1, 2]` gives `[0b1, 0b10] = [1, 2]`.
///
/// ```rust
/// use concrete_integer::wopbs::split_value_according_to_bit_basis;
///
/// let val = 5;
/// let basis = vec![1, 2];
/// assert_eq!(vec![1, 2], split_value_according_to_bit_basis(val, &basis));
/// ```
pub fn split_value_according_to_bit_basis(value: u64, basis: &[u64]) -> Vec<u64> {
    let mut remaining = value;
    basis
        .iter()
        .map(|&chunk_width| {
            let mut chunk = 0;
            for bit_position in 0..chunk_width {
                chunk += (remaining & 1) << bit_position;
                remaining >>= 1;
            }
            chunk
        })
        .collect()
}
109
/// Recombines radix digits (least significant first) in base `basis` into an integer.
///
/// All arithmetic wraps on overflow, so the result is computed modulo `2^64`.
///
/// ```rust
/// use concrete_integer::wopbs::{decode_radix, encode_radix};
///
/// let val = 11;
/// let basis = 2;
/// let nb_block = 5;
/// assert_eq!(val, decode_radix(encode_radix(val, basis, nb_block), basis));
/// ```
pub fn decode_radix(val: Vec<u64>, basis: u64) -> u64 {
    let (decoded, _) = val.iter().fold((0_u64, 1_u64), |(acc, weight), &digit| {
        (
            acc.wrapping_add(digit.wrapping_mul(weight)),
            weight.wrapping_mul(basis),
        )
    });
    decoded
}
133
134impl From<concrete_shortint::wopbs::WopbsKey> for WopbsKey {
135    fn from(wopbs_key: concrete_shortint::wopbs::WopbsKey) -> Self {
136        Self { wopbs_key }
137    }
138}
139
140impl WopbsKey {
141    /// Generates the server key required to compute a WoPBS from the client and the server keys.
142    /// # Example
143    /// ```rust
144    /// use concrete_integer::gen_keys;
145    /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_1_CARRY_1;
146    /// use concrete_integer::wopbs::*;
147    /// use concrete_shortint::parameters::PARAM_MESSAGE_1_CARRY_1;
148    ///
149    /// // Generate the client key and the server key:
150    /// let (mut cks, mut sks) = gen_keys(&PARAM_MESSAGE_1_CARRY_1);
151    /// let mut wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_1_CARRY_1);
152    /// ```
153    pub fn new_wopbs_key(cks: &ClientKey, sks: &ServerKey, parameters: &Parameters) -> WopbsKey {
154        WopbsKey {
155            wopbs_key: concrete_shortint::wopbs::WopbsKey::new_wopbs_key(
156                &cks.key, &sks.key, parameters,
157            ),
158        }
159    }
160
161    pub fn new_from_shortint(wopbskey: &concrete_shortint::wopbs::WopbsKey) -> WopbsKey {
162        let key = wopbskey.clone();
163        WopbsKey { wopbs_key: key }
164    }
165
166    pub fn new_wopbs_key_only_for_wopbs(cks: &ClientKey, sks: &ServerKey) -> WopbsKey {
167        WopbsKey {
168            wopbs_key: concrete_shortint::wopbs::WopbsKey::new_wopbs_key_only_for_wopbs(
169                &cks.key, &sks.key,
170            ),
171        }
172    }
173
174    /// Computes the WoP-PBS given the luts.
175    ///
176    /// This works for both RadixCiphertext and CrtCiphertext.
177    ///
178    /// # Example
179    ///
180    /// ```rust
181    /// use concrete_integer::gen_keys;
182    /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2;
183    /// use concrete_integer::wopbs::*;
184    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
185    ///
186    /// let nb_block = 3;
187    /// //Generate the client key and the server key:
188    /// let (mut cks, mut sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2);
189    /// let mut wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2);
190    /// let mut moduli = 1_u64;
191    /// for _ in 0..nb_block{
192    ///     moduli *= cks.parameters().message_modulus.0 as u64;
193    /// }
194    /// let clear = 42 % moduli;
195    /// let ct = cks.encrypt_radix(clear as u64, nb_block);
196    /// let ct = wopbs_key.keyswitch_to_wopbs_params(&sks,&ct);
197    /// let lut = wopbs_key.generate_lut_radix(&ct, |x|x);
198    /// let ct_res = wopbs_key.wopbs(&ct, &lut);
199    /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res);
200    /// let res = cks.decrypt_radix(&ct_res);
201    ///
202    ///  assert_eq!(res, clear);
203    /// ```
204    pub fn wopbs<T>(&self, ct_in: &T, lut: &[Vec<u64>]) -> T
205    where
206        T: IntegerCiphertext,
207    {
208        let mut extracted_bits_blocks = vec![];
209        // Extraction of each bit for each block
210        for block in ct_in.blocks().iter() {
211            let delta = (1_usize << 63)
212                / (self.wopbs_key.param.message_modulus.0 * self.wopbs_key.param.carry_modulus.0);
213            let delta_log = DeltaLog(f64::log2(delta as f64) as usize);
214            let nb_bit_to_extract = f64::log2((block.degree.0 + 1) as f64).ceil() as usize;
215
216            let extracted_bits = self
217                .wopbs_key
218                .extract_bits(delta_log, block, nb_bit_to_extract);
219
220            extracted_bits_blocks.push(extracted_bits);
221        }
222
223        extracted_bits_blocks.reverse();
224        let vec_ct_out = self
225            .wopbs_key
226            .circuit_bootstrapping_vertical_packing(lut.to_vec(), extracted_bits_blocks);
227
228        let mut ct_vec_out = vec![];
229        for (block, block_out) in ct_in.blocks().iter().zip(vec_ct_out.into_iter()) {
230            ct_vec_out.push(concrete_shortint::Ciphertext {
231                ct: block_out,
232                degree: Degree(block.message_modulus.0 - 1),
233                message_modulus: block.message_modulus,
234                carry_modulus: block.carry_modulus,
235            });
236        }
237        T::from_blocks(ct_vec_out)
238    }
239
240    /// # Example
241    /// ```rust
242    /// use concrete_integer::gen_keys;
243    /// use concrete_integer::wopbs::WopbsKey;
244    /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2;
245    ///
246    /// let nb_block = 3;
247    /// //Generate the client key and the server key:
248    /// let (mut cks, mut sks) = gen_keys(&WOPBS_PARAM_MESSAGE_2_CARRY_2);
249    /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
250    /// let mut moduli = 1_u64;
251    /// for _ in 0..nb_block{
252    ///     moduli *= cks.parameters().message_modulus.0 as u64;
253    /// }
254    /// let clear = 15 % moduli;
255    /// let ct = cks.encrypt_radix_without_padding(clear as u64, nb_block);
256    /// let lut = wopbs_key.generate_lut_radix_without_padding(&ct, |x| 2 * x);
257    /// let ct_res = wopbs_key.wopbs_without_padding(&ct, &lut);
258    /// let res = cks.decrypt_radix_without_padding(&ct_res);
259    ///
260    /// assert_eq!(res, (clear * 2) % moduli)
261    /// ```
262    pub fn wopbs_without_padding<T>(&self, ct_in: &T, lut: &[Vec<u64>]) -> T
263    where
264        T: IntegerCiphertext,
265    {
266        let mut extracted_bits_blocks = vec![];
267        let mut ct_in = ct_in.clone();
268        // Extraction of each bit for each block
269        for block in ct_in.blocks_mut().iter_mut() {
270            let delta = (1_usize << 63) / (block.message_modulus.0 * block.carry_modulus.0 / 2);
271            let delta_log = DeltaLog(f64::log2(delta as f64) as usize);
272            let nb_bit_to_extract =
273                f64::log2((block.message_modulus.0 * block.carry_modulus.0) as f64) as usize;
274
275            let extracted_bits = self
276                .wopbs_key
277                .extract_bits(delta_log, block, nb_bit_to_extract);
278            extracted_bits_blocks.push(extracted_bits);
279        }
280
281        extracted_bits_blocks.reverse();
282
283        let vec_ct_out = self
284            .wopbs_key
285            .circuit_bootstrapping_vertical_packing(lut.to_vec(), extracted_bits_blocks);
286
287        let mut ct_vec_out = vec![];
288        for (block, block_out) in ct_in.blocks().iter().zip(vec_ct_out.into_iter()) {
289            ct_vec_out.push(concrete_shortint::Ciphertext {
290                ct: block_out,
291                degree: Degree(block.message_modulus.0 - 1),
292                message_modulus: block.message_modulus,
293                carry_modulus: block.carry_modulus,
294            });
295        }
296        T::from_blocks(ct_vec_out)
297    }
298
299    /// WOPBS for native CRT
300    /// # Example
301    /// ```rust
302    /// use concrete_integer::gen_keys;
303    /// use concrete_integer::parameters::PARAM_4_BITS_5_BLOCKS;
304    /// use concrete_integer::wopbs::WopbsKey;
305    ///
306    /// let basis: Vec<u64> = vec![9, 11];
307    ///
308    /// let param = PARAM_4_BITS_5_BLOCKS;
309    /// //Generate the client key and the server key:
310    /// let (cks, sks) = gen_keys(&param);
311    /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
312    ///
313    /// let mut msg_space = 1;
314    /// for modulus in basis.iter() {
315    ///     msg_space *= modulus;
316    /// }
317    /// let clear = 42 % msg_space; // Encrypt the integers
318    /// let mut ct = cks.encrypt_native_crt(clear, basis.clone());
319    /// let lut = wopbs_key.generate_lut_native_crt(&ct, |x| x);
320    /// let ct_res = wopbs_key.wopbs_native_crt(&mut ct, &lut);
321    /// let res = cks.decrypt_native_crt(&ct_res);
322    /// assert_eq!(res, clear);
323    /// ```
324    pub fn wopbs_native_crt(&self, ct1: &CrtCiphertext, lut: &[Vec<u64>]) -> CrtCiphertext {
325        self.circuit_bootstrap_vertical_packing_native_crt(&[ct1.clone()], lut)
326    }
327
328    /// # Example
329    /// ```rust
330    /// use concrete_integer::gen_keys;
331    /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2;
332    /// use concrete_integer::wopbs::*;
333    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
334    ///
335    /// let nb_block = 3;
336    /// //Generate the client key and the server key:
337    /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2);
338    ///
339    /// //Generate wopbs_v0 key    ///
340    /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2);
341    /// let mut moduli = 1_u64;
342    /// for _ in 0..nb_block{
343    ///     moduli *= cks.parameters().message_modulus.0 as u64;
344    /// }
345    /// let clear1 = 42 % moduli;
346    /// let clear2 = 24 % moduli;
347    /// let ct1 = cks.encrypt_radix(clear1 as u64, nb_block);
348    /// let ct2 = cks.encrypt_radix(clear2 as u64, nb_block);
349    ///
350    /// let ct1 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct1);
351    /// let ct2 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct2);
352    /// let lut = wopbs_key.generate_lut_bivariate_radix(&ct1, &ct2, |x,y| 2 * x * y);
353    /// let ct_res = wopbs_key.bivariate_wopbs_with_degree(& ct1, & ct2, &lut);
354    /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res);
355    /// let res = cks.decrypt_radix(&ct_res);
356    ///
357    ///  assert_eq!(res, (2 * clear1 * clear2) % moduli);
358    /// ```
359    pub fn bivariate_wopbs_with_degree<T>(&self, ct1: &T, ct2: &T, lut: &[Vec<u64>]) -> T
360    where
361        T: IntegerCiphertext,
362    {
363        let ct = ciphertext_concatenation(ct1, ct2);
364        self.wopbs(&ct, lut)
365    }
366
367    /// # Example
368    ///
369    /// ```rust
370    /// use concrete_integer::gen_keys;
371    /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2;
372    /// use concrete_integer::wopbs::*;
373    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
374    ///
375    /// let nb_block = 3;
376    /// //Generate the client key and the server key:
377    /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2);
378    ///
379    /// //Generate wopbs_v0 key    ///
380    /// let mut wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2);
381    /// let mut moduli = 1_u64;
382    /// for _ in 0..nb_block{
383    ///     moduli *= cks.parameters().message_modulus.0 as u64;
384    /// }
385    /// let clear = 42 % moduli;
386    /// let ct = cks.encrypt_radix(clear as u64, nb_block);
387    /// let ct = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct);
388    /// let lut = wopbs_key.generate_lut_radix(&ct, |x| 2 * x);
389    /// let ct_res = wopbs_key.wopbs(&ct, &lut);
390    /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res);
391    /// let res = cks.decrypt_radix(&ct_res);
392    ///
393    ///  assert_eq!(res, (2 * clear) % moduli);
394    /// ```
395    pub fn generate_lut_radix<F, T>(&self, ct: &T, f: F) -> Vec<Vec<u64>>
396    where
397        F: Fn(u64) -> u64,
398        T: IntegerCiphertext,
399    {
400        let mut total_bit = 0;
401        let block_nb = ct.blocks().len();
402        let mut modulus = 1;
403
404        //This contains the basis of each block depending on the degree
405        let mut vec_deg_basis = vec![];
406
407        for (i, deg) in ct.moduli().iter().zip(ct.blocks().iter()) {
408            modulus *= i;
409            let b = f64::log2((deg.degree.0 + 1) as f64).ceil() as u64;
410            vec_deg_basis.push(b);
411            total_bit += b;
412        }
413
414        let mut lut_size = 1 << total_bit;
415        if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 {
416            lut_size = self.wopbs_key.param.polynomial_size.0;
417        }
418        let mut vec_lut = vec![vec![0; lut_size]; ct.blocks().len()];
419
420        let basis = ct.moduli()[0];
421        let delta: u64 = (1 << 63)
422            / (self.wopbs_key.param.message_modulus.0 * self.wopbs_key.param.carry_modulus.0)
423                as u64;
424
425        for lut_index_val in 0..(1 << total_bit) {
426            let encoded_with_deg_val = encode_mix_radix(lut_index_val, &vec_deg_basis, basis);
427            let decoded_val = decode_radix(encoded_with_deg_val.clone(), basis);
428            let f_val = f(decoded_val % modulus) % modulus;
429            let encoded_f_val = encode_radix(f_val, basis, block_nb as u64);
430            for lut_number in 0..block_nb {
431                vec_lut[lut_number as usize][lut_index_val as usize] =
432                    encoded_f_val[lut_number] * delta;
433            }
434        }
435        vec_lut
436    }
437
438    /// # Example
439    /// ```rust
440    /// use concrete_integer::gen_keys;
441    /// use concrete_integer::wopbs::WopbsKey;
442    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
443    /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2;
444    ///
445    /// let nb_block = 3;
446    /// //Generate the client key and the server key:
447    /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2);
448    /// //Generate wopbs_v0 key
449    /// let mut wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2);
450    /// let mut moduli = 1_u64;
451    /// for _ in 0..nb_block{
452    ///     moduli *= cks.parameters().message_modulus.0 as u64;
453    /// }
454    /// let clear = 15 % moduli;
455    /// let ct = cks.encrypt_radix_without_padding(clear as u64, nb_block);
456    /// let ct = wopbs_key.keyswitch_to_wopbs_params(&sks,&ct);
457    /// let lut = wopbs_key.generate_lut_radix_without_padding(&ct, |x| 2 * x);
458    /// let ct_res = wopbs_key.wopbs_without_padding(&ct, &lut);
459    /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res);
460    /// let res = cks.decrypt_radix_without_padding(&ct_res);
461    ///
462    /// assert_eq!(res, (clear * 2) % moduli)
463    /// ```
464    pub fn generate_lut_radix_without_padding<F, T>(&self, ct: &T, f: F) -> Vec<Vec<u64>>
465    where
466        F: Fn(u64) -> u64,
467        T: IntegerCiphertext,
468    {
469        let log_message_modulus = f64::log2((self.wopbs_key.param.message_modulus.0) as f64) as u64;
470        let log_carry_modulus = f64::log2((self.wopbs_key.param.carry_modulus.0) as f64) as u64;
471        let log_basis = log_message_modulus + log_carry_modulus;
472        let delta = 64 - log_basis;
473        let nb_block = ct.blocks().len();
474        let poly_size = self.wopbs_key.param.polynomial_size.0;
475        let mut lut_size = 1 << (nb_block * log_basis as usize);
476        if lut_size < poly_size {
477            lut_size = poly_size;
478        }
479        let mut vec_lut = vec![vec![0; lut_size]; nb_block];
480
481        for index in 0..lut_size {
482            // find the value represented by the index
483            let mut value = 0;
484            let mut tmp_index = index;
485            for i in 0..nb_block as u64 {
486                let tmp = tmp_index % (1 << (log_basis * (i + 1)));
487                tmp_index -= tmp;
488                value += tmp >> (log_carry_modulus * i);
489            }
490
491            // fill the LUTs
492            for (block_index, lut_block) in vec_lut.iter_mut().enumerate().take(nb_block) {
493                lut_block[index] = ((f(value as u64) >> (log_carry_modulus * block_index as u64))
494                    % (1 << log_message_modulus))
495                    << delta
496            }
497        }
498        vec_lut
499    }
500
501    /// generate lut for native CRT
502    /// # Example
503    ///
504    /// ```rust
505    /// use concrete_integer::gen_keys;
506    /// use concrete_integer::parameters::PARAM_4_BITS_5_BLOCKS;
507    /// use concrete_integer::wopbs::WopbsKey;
508    ///
509    /// let basis: Vec<u64> = vec![9, 11];
510    ///
511    /// let param = PARAM_4_BITS_5_BLOCKS;
512    /// //Generate the client key and the server key:
513    /// let (cks, sks) = gen_keys(&param);
514    /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
515    ///
516    /// let mut msg_space = 1;
517    /// for modulus in basis.iter() {
518    ///     msg_space *= modulus;
519    /// }
520    /// let clear = 42 % msg_space; // Encrypt the integers
521    /// let mut ct = cks.encrypt_native_crt(clear, basis.clone());
522    /// let lut = wopbs_key.generate_lut_native_crt(&ct, |x| x);
523    /// let ct_res = wopbs_key.wopbs_native_crt(&mut ct, &lut);
524    /// let res = cks.decrypt_native_crt(&ct_res);
525    /// assert_eq!(res, clear);
526    /// ```
527    pub fn generate_lut_native_crt<F>(&self, ct: &CrtCiphertext, f: F) -> Vec<Vec<u64>>
528    where
529        F: Fn(u64) -> u64,
530    {
531        let mut bit = vec![];
532        let mut total_bit = 0;
533        let mut modulus = 1;
534        let basis: Vec<_> = ct.moduli();
535
536        for i in basis.iter() {
537            modulus *= i;
538            let b = f64::log2(*i as f64).ceil() as u64;
539            total_bit += b;
540            bit.push(b);
541        }
542        let mut lut_size = 1 << total_bit;
543        if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 {
544            lut_size = self.wopbs_key.param.polynomial_size.0;
545        }
546        let mut vec_lut = vec![vec![0; lut_size]; basis.len()];
547
548        for value in 0..modulus {
549            let mut index_lut = 0;
550            let mut tmp = 1;
551            for (base, bit) in basis.iter().zip(bit.iter()) {
552                index_lut += (((value % base) << bit) / base) * tmp;
553                tmp <<= bit;
554            }
555            for (j, b) in basis.iter().enumerate() {
556                vec_lut[j][index_lut as usize] =
557                    (((f(value) % b) as u128 * (1 << 64)) / *b as u128) as u64
558            }
559        }
560        vec_lut
561    }
562
563    /// generate LUt for crt
564    /// # Example
565    /// ```rust
566    /// 
567    /// use concrete_integer::gen_keys;
568    /// use concrete_integer::wopbs::*;
569    /// use concrete_shortint::parameters::PARAM_MESSAGE_3_CARRY_3;
570    /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_3_CARRY_3;
571    ///
572    /// let basis : Vec<u64> = vec![5,7];
573    /// let nb_block = basis.len();
574    ///
575    /// //Generate the client key and the server key:
576    /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_3_CARRY_3);
577    /// let wopbs_key =  WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_3_CARRY_3);
578    ///
579    /// let mut msg_space = 1;
580    /// for modulus in basis.iter() {
581    ///     msg_space *= modulus;
582    /// }
583    /// let clear = 42 % msg_space;
584    /// let ct = cks.encrypt_crt(clear, basis.clone());
585    /// let ct = wopbs_key.keyswitch_to_wopbs_params(&sks,&ct);
586    /// let lut = wopbs_key.generate_lut_crt(&ct, |x| x);
587    /// let ct_res = wopbs_key.wopbs(&ct, &lut);
588    /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res);
589    /// let res = cks.decrypt_crt(&ct_res);
590    /// assert_eq!(res, clear);
591    /// ```
592    pub fn generate_lut_crt<F>(&self, ct: &CrtCiphertext, f: F) -> Vec<Vec<u64>>
593    where
594        F: Fn(u64) -> u64,
595    {
596        let mut bit = vec![];
597        let mut total_bit = 0;
598        let mut modulus = 1;
599        let basis = ct.moduli();
600
601        for (i, deg) in basis.iter().zip(ct.blocks.iter()) {
602            modulus *= i;
603            let b = f64::log2((deg.degree.0 + 1) as f64).ceil() as u64;
604            total_bit += b;
605            bit.push(b);
606        }
607        let mut lut_size = 1 << total_bit;
608        if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 {
609            lut_size = self.wopbs_key.param.polynomial_size.0;
610        }
611        let mut vec_lut = vec![vec![0; lut_size]; basis.len()];
612
613        for i in 0..(1 << total_bit) {
614            let mut value = i;
615            for (j, block) in ct.blocks.iter().enumerate() {
616                let deg = f64::log2((block.degree.0 + 1) as f64).ceil() as u64;
617                let delta: u64 = (1 << 63)
618                    / (self.wopbs_key.param.message_modulus.0
619                        * self.wopbs_key.param.carry_modulus.0) as u64;
620                vec_lut[j][i as usize] =
621                    ((f((value % (1 << deg)) % block.message_modulus.0 as u64))
622                        % block.message_modulus.0 as u64)
623                        * delta;
624                value >>= deg;
625            }
626        }
627        vec_lut
628    }
629
630    /// # Example
631    ///
632    /// ```rust
633    /// use concrete_integer::gen_keys;
634    /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_2_CARRY_2;
635    /// use concrete_integer::wopbs::*;
636    /// use concrete_shortint::parameters::PARAM_MESSAGE_2_CARRY_2;
637    ///
638    /// let nb_block = 3;
639    /// //Generate the client key and the server key:
640    /// let (cks, sks) = gen_keys(&PARAM_MESSAGE_2_CARRY_2);
641    ///
642    /// //Generate wopbs_v0 key    ///
643    /// let wopbs_key = WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_2_CARRY_2);
644    /// let mut moduli = 1_u64;
645    /// for _ in 0..nb_block{
646    ///     moduli *= cks.parameters().message_modulus.0 as u64;
647    /// }
648    /// let clear1 = 42 % moduli;
649    /// let clear2 = 24 % moduli;
650    /// let ct1 = cks.encrypt_radix(clear1 as u64, nb_block);
651    /// let ct2 = cks.encrypt_radix(clear2 as u64, nb_block);
652    ///
653    /// let ct1 = wopbs_key.keyswitch_to_wopbs_params(&sks,&ct1);
654    /// let ct2 = wopbs_key.keyswitch_to_wopbs_params(&sks,&ct2);
655    /// let lut = wopbs_key.generate_lut_bivariate_radix(&ct1, &ct2, |x,y| 2 * x * y);
656    /// let ct_res = wopbs_key.bivariate_wopbs_with_degree(&ct1, &ct2, &lut);
657    /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res);
658    /// let res = cks.decrypt_radix(&ct_res);
659    ///
660    ///  assert_eq!(res, (2 * clear1 * clear2) % moduli);
661    /// ```
662    pub fn generate_lut_bivariate_radix<F>(
663        &self,
664        ct1: &RadixCiphertext,
665        ct2: &RadixCiphertext,
666        f: F,
667    ) -> Vec<Vec<u64>>
668    where
669        F: Fn(u64, u64) -> u64,
670    {
671        let mut nb_bit_to_extract = vec![0; 2];
672        let block_nb = ct1.blocks.len();
673        //ct2 & ct1 should have the same basis
674        let basis = ct1.moduli();
675
676        //This contains the basis of each block depending on the degree
677        let mut vec_deg_basis = vec![vec![]; 2];
678
679        let mut modulus = 1;
680        for (ct_num, ct) in [ct1, ct2].iter().enumerate() {
681            modulus = 1;
682            for deg in ct.blocks.iter() {
683                modulus *= self.wopbs_key.param.message_modulus.0 as u64;
684                let b = f64::log2((deg.degree.0 + 1) as f64).ceil() as u64;
685                vec_deg_basis[ct_num].push(b);
686                nb_bit_to_extract[ct_num] += b;
687            }
688        }
689
690        let total_bit: u64 = nb_bit_to_extract.iter().sum();
691
692        let mut lut_size = 1 << total_bit;
693        if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 {
694            lut_size = self.wopbs_key.param.polynomial_size.0;
695        }
696        let mut vec_lut = vec![vec![0; lut_size]; basis.len()];
697        let basis = ct1.moduli()[0];
698
699        let delta: u64 = (1 << 63)
700            / (self.wopbs_key.param.message_modulus.0 * self.wopbs_key.param.carry_modulus.0)
701                as u64;
702
703        for lut_index_val in 0..(1 << total_bit) {
704            let split = vec![
705                lut_index_val % (1 << nb_bit_to_extract[0]),
706                lut_index_val >> nb_bit_to_extract[0],
707            ];
708            let mut decoded_val = vec![0; 2];
709            for i in 0..2 {
710                let encoded_with_deg_val = encode_mix_radix(split[i], &vec_deg_basis[i], basis);
711                decoded_val[i] = decode_radix(encoded_with_deg_val.clone(), basis);
712            }
713            let f_val = f(decoded_val[0] % modulus, decoded_val[1] % modulus) % modulus;
714            let encoded_f_val = encode_radix(f_val, basis, block_nb as u64);
715            for lut_number in 0..block_nb {
716                vec_lut[lut_number as usize][lut_index_val as usize] =
717                    encoded_f_val[lut_number] * delta;
718            }
719        }
720        vec_lut
721    }
722
723    /// generate bivariate LUT for 'fake' CRT
724    ///
725    /// # Example
726    ///
727    /// ```rust
728    /// 
729    /// use concrete_integer::gen_keys;
730    /// use concrete_integer::wopbs::*;
731    /// use concrete_shortint::parameters::PARAM_MESSAGE_3_CARRY_3;
732    /// use concrete_shortint::parameters::parameters_wopbs_message_carry::WOPBS_PARAM_MESSAGE_3_CARRY_3;
733    ///
734    /// let basis : Vec<u64> = vec![5,7];
735    /// //Generate the client key and the server key:
736    /// let ( cks, sks) = gen_keys(&PARAM_MESSAGE_3_CARRY_3);
737    /// let wopbs_key =  WopbsKey::new_wopbs_key(&cks, &sks, &WOPBS_PARAM_MESSAGE_3_CARRY_3);
738    ///
739    /// let mut msg_space = 1;
740    /// for modulus in basis.iter() {
741    ///     msg_space *= modulus;
742    /// }
743    /// let clear1 = 42 % msg_space;    // Encrypt the integers
744    /// let clear2 = 24 % msg_space;    // Encrypt the integers
745    /// let ct1 = cks.encrypt_crt(clear1, basis.clone());
746    /// let ct2 = cks.encrypt_crt(clear2, basis.clone());
747    ///
748    /// let ct1 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct1);
749    /// let ct2 = wopbs_key.keyswitch_to_wopbs_params(&sks, &ct2);
750    ///
751    ///
752    /// let lut = wopbs_key.generate_lut_bivariate_crt(&ct1, &ct2, |x,y| x * y * 2);
753    /// let ct_res = wopbs_key.bivariate_wopbs_with_degree(&ct1, &ct2, &lut);
754    /// let ct_res = wopbs_key.keyswitch_to_pbs_params(&ct_res);
755    /// let res = cks.decrypt_crt(&ct_res);
756    /// assert_eq!(res, (clear1 * clear2 * 2) % msg_space );
757    /// ```
758    pub fn generate_lut_bivariate_crt<F>(
759        &self,
760        ct1: &CrtCiphertext,
761        ct2: &CrtCiphertext,
762        f: F,
763    ) -> Vec<Vec<u64>>
764    where
765        F: Fn(u64, u64) -> u64,
766    {
767        let mut bit = vec![];
768        let mut nb_bit_to_extract = vec![0; 2];
769        let mut modulus = 1;
770
771        //ct2 & ct1 should have the same basis
772        let basis = ct1.moduli();
773
774        for (ct_num, ct) in [ct1, ct2].iter().enumerate() {
775            for (i, deg) in basis.iter().zip(ct.blocks.iter()) {
776                modulus *= i;
777                let b = f64::log2((deg.degree.0 + 1) as f64).ceil() as u64;
778                nb_bit_to_extract[ct_num] += b;
779                bit.push(b);
780            }
781        }
782
783        let total_bit: u64 = nb_bit_to_extract.iter().sum();
784
785        let mut lut_size = 1 << total_bit;
786        if 1 << total_bit < self.wopbs_key.param.polynomial_size.0 as u64 {
787            lut_size = self.wopbs_key.param.polynomial_size.0;
788        }
789        let mut vec_lut = vec![vec![0; lut_size]; basis.len()];
790
791        let delta: u64 = (1 << 63)
792            / (self.wopbs_key.param.message_modulus.0 * self.wopbs_key.param.carry_modulus.0)
793                as u64;
794
795        for index in 0..(1 << total_bit) {
796            let mut split = encode_radix(index, 1 << nb_bit_to_extract[0], 2);
797            let mut crt_value = vec![vec![0; ct1.blocks.len()]; 2];
798            for (j, base) in basis.iter().enumerate().take(ct1.blocks.len()) {
799                let deg_1 = f64::log2((ct1.blocks[j].degree.0 + 1) as f64).ceil() as u64;
800                let deg_2 = f64::log2((ct2.blocks[j].degree.0 + 1) as f64).ceil() as u64;
801                crt_value[0][j] = (split[0] % (1 << deg_1)) % base;
802                crt_value[1][j] = (split[1] % (1 << deg_2)) % base;
803                split[0] >>= deg_1;
804                split[1] >>= deg_2;
805            }
806            let value_1 = i_crt(&ct1.moduli(), &crt_value[0]);
807            let value_2 = i_crt(&ct2.moduli(), &crt_value[1]);
808            for (j, current_mod) in basis.iter().enumerate() {
809                let value = f(value_1, value_2) % current_mod;
810                vec_lut[j][index as usize] = (value % current_mod) * delta;
811            }
812        }
813
814        vec_lut
815    }
816
817    /// generate bivariate LUT for 'true' CRT
818    /// # Example
819    ///
820    /// ```rust
821    /// use concrete_integer::gen_keys;
822    /// use concrete_integer::parameters::PARAM_4_BITS_5_BLOCKS;
823    /// use concrete_integer::wopbs::WopbsKey;
824    ///
825    /// let basis: Vec<u64> = vec![9, 11];
826    ///
827    /// let param = PARAM_4_BITS_5_BLOCKS;
828    /// //Generate the client key and the server key:
829    /// let (cks, sks) = gen_keys(&param);
830    /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
831    ///
832    /// let mut msg_space = 1;
833    /// for modulus in basis.iter() {
834    ///     msg_space *= modulus;
835    /// }
836    /// let clear1 = 42 % msg_space;
837    /// let clear2 = 24 % msg_space;
838    /// let mut ct1 = cks.encrypt_native_crt(clear1, basis.clone());
839    /// let mut ct2 = cks.encrypt_native_crt(clear2, basis.clone());
840    /// let lut = wopbs_key.generate_lut_bivariate_native_crt(&ct1, |x, y| x * y * 2);
841    /// let ct_res = wopbs_key.bivariate_wopbs_native_crt(&mut ct1, &mut ct2, &lut);
842    /// let res = cks.decrypt_native_crt(&ct_res);
843    /// assert_eq!(res, (clear1 * clear2 * 2) % msg_space);
844    /// ```
845    pub fn generate_lut_bivariate_native_crt<F>(&self, ct_1: &CrtCiphertext, f: F) -> Vec<Vec<u64>>
846    where
847        F: Fn(u64, u64) -> u64,
848    {
849        let mut bit = vec![];
850        let mut total_bit = 0;
851        let mut modulus = 1;
852        let basis = ct_1.moduli();
853        for i in basis.iter() {
854            modulus *= i;
855            let b = f64::log2(*i as f64).ceil() as u64;
856            total_bit += b;
857            bit.push(b);
858        }
859        let mut lut_size = 1 << (2 * total_bit);
860        if 1 << (2 * total_bit) < self.wopbs_key.param.polynomial_size.0 as u64 {
861            lut_size = self.wopbs_key.param.polynomial_size.0;
862        }
863        let mut vec_lut = vec![vec![0; lut_size]; basis.len()];
864
865        for value in 0..1 << (2 * total_bit) {
866            let value_1 = value % (1 << total_bit);
867            let value_2 = value >> total_bit;
868            let mut index_lut_1 = 0;
869            let mut index_lut_2 = 0;
870            let mut tmp = 1;
871            for (base, bit) in basis.iter().zip(bit.iter()) {
872                index_lut_1 += (((value_1 % base) << bit) / base) * tmp;
873                index_lut_2 += (((value_2 % base) << bit) / base) * tmp;
874                tmp <<= bit;
875            }
876            let index = (index_lut_2 << total_bit) + (index_lut_1);
877            for (j, b) in basis.iter().enumerate() {
878                vec_lut[j][index as usize] =
879                    (((f(value_1, value_2) % b) as u128 * (1 << 64)) / *b as u128) as u64
880            }
881        }
882        vec_lut
883    }
884
885    /// bivariate WOPBS for native CRT
886    /// # Example
887    ///
888    /// ```rust
889    /// use concrete_integer::gen_keys;
890    /// use concrete_integer::parameters::PARAM_4_BITS_5_BLOCKS;
891    /// use concrete_integer::wopbs::WopbsKey;
892    ///
893    /// let basis: Vec<u64> = vec![9, 11];
894    ///
895    /// let param = PARAM_4_BITS_5_BLOCKS;
896    /// //Generate the client key and the server key:
897    /// let (cks, sks) = gen_keys(&param);
898    /// let mut wopbs_key = WopbsKey::new_wopbs_key_only_for_wopbs(&cks, &sks);
899    ///
900    /// let mut msg_space = 1;
901    /// for modulus in basis.iter() {
902    ///     msg_space *= modulus;
903    /// }
904    /// let clear1 = 42 % msg_space;
905    /// let clear2 = 24 % msg_space;
906    /// let mut ct1 = cks.encrypt_native_crt(clear1, basis.clone());
907    /// let mut ct2 = cks.encrypt_native_crt(clear2, basis.clone());
908    /// let lut = wopbs_key.generate_lut_bivariate_native_crt(&ct1, |x, y| x * y * 2);
909    /// let ct_res = wopbs_key.bivariate_wopbs_native_crt(&mut ct1, &mut ct2, &lut);
910    /// let res = cks.decrypt_native_crt(&ct_res);
911    /// assert_eq!(res, (clear1 * clear2 * 2) % msg_space);
912    /// ```
913    pub fn bivariate_wopbs_native_crt(
914        &self,
915        ct1: &CrtCiphertext,
916        ct2: &CrtCiphertext,
917        lut: &[Vec<u64>],
918    ) -> CrtCiphertext {
919        self.circuit_bootstrap_vertical_packing_native_crt(&[ct1.clone(), ct2.clone()], lut)
920    }
921
922    fn circuit_bootstrap_vertical_packing_native_crt<T>(
923        &self,
924        vec_ct_in: &[T],
925        lut: &[Vec<u64>],
926    ) -> T
927    where
928        T: IntegerCiphertext,
929    {
930        let mut extracted_bits_blocks = vec![];
931        for ct_in in vec_ct_in.iter() {
932            let mut ct_in = ct_in.clone();
933            // Extraction of each bit for each block
934            for block in ct_in.blocks_mut().iter_mut() {
935                let nb_bit_to_extract =
936                    f64::log2((block.message_modulus.0 * block.carry_modulus.0) as f64).ceil()
937                        as usize;
938                let delta_log = DeltaLog(64 - nb_bit_to_extract);
939
940                // trick ( ct - delta/2 + delta/2^4  )
941                let lwe_size = block.ct.lwe_dimension().to_lwe_size().0;
942                let mut cont = vec![0u64; lwe_size];
943                cont[lwe_size - 1] =
944                    (1 << (64 - nb_bit_to_extract - 1)) - (1 << (64 - nb_bit_to_extract - 5));
945                let mut engine = DefaultEngine::new(Box::new(UnixSeeder::new(0))).unwrap();
946                let tmp = engine.create_lwe_ciphertext_from(cont).unwrap();
947                engine.fuse_sub_lwe_ciphertext(&mut block.ct, &tmp).unwrap();
948
949                let extracted_bits =
950                    self.wopbs_key
951                        .extract_bits(delta_log, block, nb_bit_to_extract);
952                extracted_bits_blocks.push(extracted_bits);
953            }
954        }
955
956        extracted_bits_blocks.reverse();
957
958        let vec_ct_out = self
959            .wopbs_key
960            .circuit_bootstrapping_vertical_packing(lut.to_vec(), extracted_bits_blocks);
961
962        let mut ct_vec_out: Vec<concrete_shortint::Ciphertext> = vec![];
963        for (block, block_out) in vec_ct_in[0].blocks().iter().zip(vec_ct_out.into_iter()) {
964            ct_vec_out.push(concrete_shortint::Ciphertext {
965                ct: block_out,
966                degree: Degree(block.message_modulus.0 - 1),
967                message_modulus: block.message_modulus,
968                carry_modulus: block.carry_modulus,
969            });
970        }
971        T::from_blocks(ct_vec_out)
972    }
973
974    pub fn keyswitch_to_wopbs_params<T>(&self, sks: &ServerKey, ct_in: &T) -> T
975    where
976        T: IntegerCiphertext,
977    {
978        let blocks: Vec<_> = ct_in
979            .blocks()
980            .par_iter()
981            .map(|block| self.wopbs_key.keyswitch_to_wopbs_params(&sks.key, block))
982            .collect();
983        T::from_blocks(blocks)
984    }
985
986    pub fn keyswitch_to_pbs_params<T>(&self, ct_in: &T) -> T
987    where
988        T: IntegerCiphertext,
989    {
990        let blocks: Vec<_> = ct_in
991            .blocks()
992            .par_iter()
993            .map(|block| self.wopbs_key.keyswitch_to_pbs_params(block))
994            .collect();
995        T::from_blocks(blocks)
996    }
997}