Skip to main content

tfhe/core_crypto/algorithms/
lwe_keyswitch.rs

//! Module containing primitives pertaining to [`LWE ciphertext
//! keyswitch`](`LweKeyswitchKey#lwe-keyswitch`).
3
4use crate::core_crypto::algorithms::slice_algorithms::*;
5use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
6use crate::core_crypto::commons::math::decomposition::{
7    SignedDecomposer, SignedDecomposerNonNative,
8};
9use crate::core_crypto::commons::parameters::{
10    DecompositionBaseLog, DecompositionLevelCount, ThreadCount,
11};
12use crate::core_crypto::commons::traits::*;
13use crate::core_crypto::entities::*;
14use rayon::prelude::*;
15
16/// Keyswitch an [`LWE ciphertext`](`LweCiphertext`) encrypted under an
17/// [`LWE secret key`](`LweSecretKey`) to another [`LWE secret key`](`LweSecretKey`).
18///
19/// Automatically dispatches to [`keyswitch_lwe_ciphertext_native_mod_compatible`] or
20/// [`keyswitch_lwe_ciphertext_other_mod`] depending on the ciphertext modulus of the input
21/// `lwe_keyswitch_key`.
22///
23/// # Formal Definition
24///
25/// See [`LWE keyswitch key`](`LweKeyswitchKey#lwe-keyswitch`).
26///
27/// # Example
28///
29/// ```rust
30/// use tfhe::core_crypto::prelude::*;
31///
32/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
33/// // computations
34/// // Define parameters for LweKeyswitchKey creation
35/// let input_lwe_dimension = LweDimension(742);
36/// let lwe_noise_distribution =
37///     Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0);
38/// let output_lwe_dimension = LweDimension(2048);
39/// let decomp_base_log = DecompositionBaseLog(3);
40/// let decomp_level_count = DecompositionLevelCount(5);
41/// let ciphertext_modulus = CiphertextModulus::new_native();
42///
43/// // Create the PRNG
44/// let mut seeder = new_seeder();
45/// let seeder = seeder.as_mut();
46/// let mut encryption_generator =
47///     EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder);
48/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
49///
50/// // Create the LweSecretKey
51/// let input_lwe_secret_key =
52///     allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
53/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
54///     output_lwe_dimension,
55///     &mut secret_generator,
56/// );
57///
58/// let ksk = allocate_and_generate_new_lwe_keyswitch_key(
59///     &input_lwe_secret_key,
60///     &output_lwe_secret_key,
61///     decomp_base_log,
62///     decomp_level_count,
63///     lwe_noise_distribution,
64///     ciphertext_modulus,
65///     &mut encryption_generator,
66/// );
67///
68/// // Create the plaintext
69/// let msg = 3u64;
70/// let plaintext = Plaintext(msg << 60);
71///
72/// // Create a new LweCiphertext
73/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext(
74///     &input_lwe_secret_key,
75///     plaintext,
76///     lwe_noise_distribution,
77///     ciphertext_modulus,
78///     &mut encryption_generator,
79/// );
80///
81/// let mut output_lwe = LweCiphertext::new(
82///     0,
83///     output_lwe_secret_key.lwe_dimension().to_lwe_size(),
84///     ciphertext_modulus,
85/// );
86///
87/// keyswitch_lwe_ciphertext(&ksk, &input_lwe, &mut output_lwe);
88///
89/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe);
90///
91/// // Round and remove encoding
92/// // First create a decomposer working on the high 4 bits corresponding to our encoding.
93/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
94///
95/// let rounded = decomposer.closest_representable(decrypted_plaintext.0);
96///
97/// // Remove the encoding
98/// let cleartext = rounded >> 60;
99///
100/// // Check we recovered the original message
101/// assert_eq!(cleartext, msg);
102/// ```
103pub fn keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
104    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
105    input_lwe_ciphertext: &LweCiphertext<InputCont>,
106    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
107) where
108    Scalar: UnsignedInteger,
109    KSKCont: Container<Element = Scalar>,
110    InputCont: Container<Element = Scalar>,
111    OutputCont: ContainerMut<Element = Scalar>,
112{
113    let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
114
115    if ciphertext_modulus.is_compatible_with_native_modulus() {
116        keyswitch_lwe_ciphertext_native_mod_compatible(
117            lwe_keyswitch_key,
118            input_lwe_ciphertext,
119            output_lwe_ciphertext,
120        )
121    } else {
122        keyswitch_lwe_ciphertext_other_mod(
123            lwe_keyswitch_key,
124            input_lwe_ciphertext,
125            output_lwe_ciphertext,
126        )
127    }
128}
129
130/// Specialized implementation of an LWE keyswitch when inputs have power of two moduli.
131///
132/// # Panics
133///
134/// Panics if the modulus of the inputs are not power of twos.
135/// Panics if the output `output_lwe_ciphertext` modulus is not equal to the `lwe_keyswitch_key`
136/// modulus.
137pub fn keyswitch_lwe_ciphertext_native_mod_compatible<Scalar, KSKCont, InputCont, OutputCont>(
138    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
139    input_lwe_ciphertext: &LweCiphertext<InputCont>,
140    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
141) where
142    Scalar: UnsignedInteger,
143    KSKCont: Container<Element = Scalar>,
144    InputCont: Container<Element = Scalar>,
145    OutputCont: ContainerMut<Element = Scalar>,
146{
147    assert!(
148        lwe_keyswitch_key.input_key_lwe_dimension()
149            == input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
150        "Mismatched input LweDimension. \
151        LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
152        lwe_keyswitch_key.input_key_lwe_dimension(),
153        input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
154    );
155    assert!(
156        lwe_keyswitch_key.output_key_lwe_dimension()
157            == output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
158        "Mismatched output LweDimension. \
159        LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
160        lwe_keyswitch_key.output_key_lwe_dimension(),
161        output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
162    );
163
164    let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus();
165
166    assert_eq!(
167        lwe_keyswitch_key.ciphertext_modulus(),
168        output_ciphertext_modulus,
169        "Mismatched CiphertextModulus. \
170        LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
171        lwe_keyswitch_key.ciphertext_modulus(),
172        output_ciphertext_modulus
173    );
174    assert!(
175        output_ciphertext_modulus.is_compatible_with_native_modulus(),
176        "This operation currently only supports power of 2 moduli"
177    );
178
179    let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus();
180
181    assert!(
182        input_ciphertext_modulus.is_compatible_with_native_modulus(),
183        "This operation currently only supports power of 2 moduli"
184    );
185
186    // Clear the output ciphertext, as it will get updated gradually
187    output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
188
189    // Copy the input body to the output ciphertext
190    *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;
191
192    // If the moduli are not the same, we need to round the body in the output ciphertext
193    if output_ciphertext_modulus != input_ciphertext_modulus
194        && !output_ciphertext_modulus.is_native_modulus()
195    {
196        let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize;
197        let output_decomposer = SignedDecomposer::new(
198            DecompositionBaseLog(modulus_bits),
199            DecompositionLevelCount(1),
200        );
201
202        *output_lwe_ciphertext.get_mut_body().data =
203            output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data);
204    }
205
206    // We instantiate a decomposer
207    let decomposer = SignedDecomposer::new(
208        lwe_keyswitch_key.decomposition_base_log(),
209        lwe_keyswitch_key.decomposition_level_count(),
210    );
211
212    for (keyswitch_key_block, &input_mask_element) in lwe_keyswitch_key
213        .iter()
214        .zip(input_lwe_ciphertext.get_mask().as_ref())
215    {
216        let decomposition_iter = decomposer.decompose(input_mask_element);
217        // Loop over the levels
218        for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
219        {
220            slice_wrapping_sub_scalar_mul_assign(
221                output_lwe_ciphertext.as_mut(),
222                level_key_ciphertext.as_ref(),
223                decomposed.value(),
224            );
225        }
226    }
227}
228
229/// Specialized implementation of an LWE keyswitch when inputs have non power of two moduli.
230///
231/// # Panics
232///
233/// Panics if the modulus of the inputs are power of twos.
234/// Panics if the modulus of the inputs are not all equal.
235pub fn keyswitch_lwe_ciphertext_other_mod<Scalar, KSKCont, InputCont, OutputCont>(
236    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
237    input_lwe_ciphertext: &LweCiphertext<InputCont>,
238    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
239) where
240    Scalar: UnsignedInteger,
241    KSKCont: Container<Element = Scalar>,
242    InputCont: Container<Element = Scalar>,
243    OutputCont: ContainerMut<Element = Scalar>,
244{
245    assert!(
246        lwe_keyswitch_key.input_key_lwe_dimension()
247            == input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
248        "Mismatched input LweDimension. \
249        LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
250        lwe_keyswitch_key.input_key_lwe_dimension(),
251        input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
252    );
253    assert!(
254        lwe_keyswitch_key.output_key_lwe_dimension()
255            == output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
256        "Mismatched output LweDimension. \
257        LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
258        lwe_keyswitch_key.output_key_lwe_dimension(),
259        output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
260    );
261
262    assert_eq!(
263        lwe_keyswitch_key.ciphertext_modulus(),
264        output_lwe_ciphertext.ciphertext_modulus(),
265        "Mismatched CiphertextModulus. \
266        LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
267        lwe_keyswitch_key.ciphertext_modulus(),
268        output_lwe_ciphertext.ciphertext_modulus()
269    );
270
271    assert_eq!(
272        lwe_keyswitch_key.ciphertext_modulus(),
273        input_lwe_ciphertext.ciphertext_modulus(),
274        "Mismatched CiphertextModulus. \
275        LweKeyswitchKey CiphertextModulus: {:?}, input LweCiphertext CiphertextModulus {:?}.",
276        lwe_keyswitch_key.ciphertext_modulus(),
277        input_lwe_ciphertext.ciphertext_modulus()
278    );
279
280    let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
281
282    assert!(
283        !ciphertext_modulus.is_compatible_with_native_modulus(),
284        "This operation currently only supports non power of 2 moduli"
285    );
286
287    // Clear the output ciphertext, as it will get updated gradually
288    output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
289
290    // Copy the input body to the output ciphertext
291    *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;
292
293    // We instantiate a decomposer
294    let decomposer = SignedDecomposerNonNative::new(
295        lwe_keyswitch_key.decomposition_base_log(),
296        lwe_keyswitch_key.decomposition_level_count(),
297        ciphertext_modulus,
298    );
299
300    for (keyswitch_key_block, &input_mask_element) in lwe_keyswitch_key
301        .iter()
302        .zip(input_lwe_ciphertext.get_mask().as_ref())
303    {
304        let decomposition_iter = decomposer.decompose(input_mask_element);
305        // Loop over the levels
306        for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
307        {
308            slice_wrapping_sub_scalar_mul_assign_custom_modulus(
309                output_lwe_ciphertext.as_mut(),
310                level_key_ciphertext.as_ref(),
311                decomposed.modular_value(),
312                ciphertext_modulus.get_custom_modulus().cast_into(),
313            );
314        }
315    }
316}
317
318/// Keyswitch an [`LWE ciphertext`](`LweCiphertext`) with a certain InputScalar type to represent
319/// data encrypted under an [`LWE secret key`](`LweSecretKey`) to another [`LWE secret
320/// key`](`LweSecretKey`) using a different OutputScalar type to represent data.
321///
322/// # Notes
323///
324/// This function only supports power of 2 moduli and going from a large InputScalar with
325/// `input_bits` to a a smaller OutputScalar with `output_bits` and `output_bits` < `input_bits`.
326///
327/// The product of the `lwe_keyswitch_key`'s
328/// [`DecompositionBaseLog`](`crate::core_crypto::commons::parameters::DecompositionBaseLog`) and
329/// [`DecompositionLevelCount`](`crate::core_crypto::commons::parameters::DecompositionLevelCount`)
330/// needs to be smaller than `output_bits`.
331pub fn keyswitch_lwe_ciphertext_with_scalar_change<
332    InputScalar,
333    OutputScalar,
334    KSKCont,
335    InputCont,
336    OutputCont,
337>(
338    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
339    input_lwe_ciphertext: &LweCiphertext<InputCont>,
340    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
341) where
342    InputScalar: UnsignedInteger,
343    OutputScalar: UnsignedInteger + CastFrom<InputScalar>,
344    KSKCont: Container<Element = OutputScalar>,
345    InputCont: Container<Element = InputScalar>,
346    OutputCont: ContainerMut<Element = OutputScalar>,
347{
348    assert!(
349        InputScalar::BITS > OutputScalar::BITS,
350        "This operation only supports going from a large InputScalar type \
351        to a strictly smaller OutputScalar type."
352    );
353    assert!(
354        lwe_keyswitch_key.decomposition_base_log().0
355            * lwe_keyswitch_key.decomposition_level_count().0
356            <= OutputScalar::BITS,
357        "This operation only supports a DecompositionBaseLog and DecompositionLevelCount product \
358        smaller than the OutputScalar bit count."
359    );
360
361    assert!(
362        lwe_keyswitch_key.input_key_lwe_dimension()
363            == input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
364        "Mismatched input LweDimension. \
365        LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
366        lwe_keyswitch_key.input_key_lwe_dimension(),
367        input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
368    );
369    assert!(
370        lwe_keyswitch_key.output_key_lwe_dimension()
371            == output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
372        "Mismatched output LweDimension. \
373        LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
374        lwe_keyswitch_key.output_key_lwe_dimension(),
375        output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
376    );
377
378    let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus();
379
380    assert_eq!(
381        lwe_keyswitch_key.ciphertext_modulus(),
382        output_ciphertext_modulus,
383        "Mismatched CiphertextModulus. \
384        LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
385        lwe_keyswitch_key.ciphertext_modulus(),
386        output_ciphertext_modulus
387    );
388    assert!(
389        output_ciphertext_modulus.is_compatible_with_native_modulus(),
390        "This operation currently only supports power of 2 moduli"
391    );
392
393    let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus();
394
395    assert!(
396        input_ciphertext_modulus.is_compatible_with_native_modulus(),
397        "This operation currently only supports power of 2 moduli"
398    );
399
400    // Clear the output ciphertext, as it will get updated gradually
401    output_lwe_ciphertext.as_mut().fill(OutputScalar::ZERO);
402
403    let output_modulus_bits = match output_ciphertext_modulus.kind() {
404        CiphertextModulusKind::Native => OutputScalar::BITS,
405        CiphertextModulusKind::NonNativePowerOfTwo => {
406            output_ciphertext_modulus.get_custom_modulus().ilog2() as usize
407        }
408        CiphertextModulusKind::Other => unreachable!(),
409    };
410
411    let input_body_decomposer = SignedDecomposer::new(
412        DecompositionBaseLog(output_modulus_bits),
413        DecompositionLevelCount(1),
414    );
415
416    // Power of two are encoded in the MSBs of the types so we need to scale the type to the other
417    // one without having to worry about the moduli
418    let input_to_output_scaling_factor = InputScalar::BITS - OutputScalar::BITS;
419
420    let rounded_downscaled_body = input_body_decomposer
421        .closest_representable(*input_lwe_ciphertext.get_body().data)
422        >> input_to_output_scaling_factor;
423
424    *output_lwe_ciphertext.get_mut_body().data = rounded_downscaled_body.cast_into();
425
426    // We instantiate a decomposer
427    let input_decomposer = SignedDecomposer::<InputScalar>::new(
428        lwe_keyswitch_key.decomposition_base_log(),
429        lwe_keyswitch_key.decomposition_level_count(),
430    );
431
432    for (keyswitch_key_block, &input_mask_element) in lwe_keyswitch_key
433        .iter()
434        .zip(input_lwe_ciphertext.get_mask().as_ref())
435    {
436        let decomposition_iter = input_decomposer.decompose(input_mask_element);
437        // Loop over the levels
438        for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
439        {
440            slice_wrapping_sub_scalar_mul_assign(
441                output_lwe_ciphertext.as_mut(),
442                level_key_ciphertext.as_ref(),
443                decomposed.value().cast_into(),
444            );
445        }
446    }
447}
448
449/// Parallel variant of [`keyswitch_lwe_ciphertext`].
450///
451/// This will use all threads available in the current rayon thread pool.
452///
453/// Automatically dispatches to
454/// [`par_keyswitch_lwe_ciphertext_with_thread_count_native_mod_compatible`] or
455/// [`par_keyswitch_lwe_ciphertext_with_thread_count_other_mod`] depending on the ciphertext modulus
456/// of the input `lwe_keyswitch_key`.
457///
458/// # Example
459///
460/// ```rust
461/// use tfhe::core_crypto::prelude::*;
462///
463/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
464/// // computations
465/// // Define parameters for LweKeyswitchKey creation
466/// let input_lwe_dimension = LweDimension(742);
467/// let lwe_noise_distribution =
468///     Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0);
469/// let output_lwe_dimension = LweDimension(2048);
470/// let decomp_base_log = DecompositionBaseLog(3);
471/// let decomp_level_count = DecompositionLevelCount(5);
472/// let ciphertext_modulus = CiphertextModulus::new_native();
473///
474/// // Create the PRNG
475/// let mut seeder = new_seeder();
476/// let seeder = seeder.as_mut();
477/// let mut encryption_generator =
478///     EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder);
479/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
480///
481/// // Create the LweSecretKey
482/// let input_lwe_secret_key =
483///     allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
484/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
485///     output_lwe_dimension,
486///     &mut secret_generator,
487/// );
488///
489/// let ksk = allocate_and_generate_new_lwe_keyswitch_key(
490///     &input_lwe_secret_key,
491///     &output_lwe_secret_key,
492///     decomp_base_log,
493///     decomp_level_count,
494///     lwe_noise_distribution,
495///     ciphertext_modulus,
496///     &mut encryption_generator,
497/// );
498///
499/// // Create the plaintext
500/// let msg = 3u64;
501/// let plaintext = Plaintext(msg << 60);
502///
503/// // Create a new LweCiphertext
504/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext(
505///     &input_lwe_secret_key,
506///     plaintext,
507///     lwe_noise_distribution,
508///     ciphertext_modulus,
509///     &mut encryption_generator,
510/// );
511///
512/// let mut output_lwe = LweCiphertext::new(
513///     0,
514///     output_lwe_secret_key.lwe_dimension().to_lwe_size(),
515///     ciphertext_modulus,
516/// );
517///
518/// // Use all threads available in the current rayon thread pool
519/// par_keyswitch_lwe_ciphertext(&ksk, &input_lwe, &mut output_lwe);
520///
521/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe);
522///
523/// // Round and remove encoding
524/// // First create a decomposer working on the high 4 bits corresponding to our encoding.
525/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
526///
527/// let rounded = decomposer.closest_representable(decrypted_plaintext.0);
528///
529/// // Remove the encoding
530/// let cleartext = rounded >> 60;
531///
532/// // Check we recovered the original message
533/// assert_eq!(cleartext, msg);
534/// ```
535pub fn par_keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
536    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
537    input_lwe_ciphertext: &LweCiphertext<InputCont>,
538    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
539) where
540    Scalar: UnsignedInteger + Send + Sync,
541    KSKCont: Container<Element = Scalar>,
542    InputCont: Container<Element = Scalar>,
543    OutputCont: ContainerMut<Element = Scalar>,
544{
545    let thread_count = ThreadCount(rayon::current_num_threads());
546    par_keyswitch_lwe_ciphertext_with_thread_count(
547        lwe_keyswitch_key,
548        input_lwe_ciphertext,
549        output_lwe_ciphertext,
550        thread_count,
551    );
552}
553
554/// Parallel variant of [`keyswitch_lwe_ciphertext`].
555///
556/// This will try to use `thread_count` threads for the computation, if this number is bigger than
557/// the available number of threads in the current rayon thread pool then only the number of
558/// available threads will be used. Note that `thread_count` cannot be 0.
559///
560/// Automatically dispatches to
561/// [`par_keyswitch_lwe_ciphertext_with_thread_count_native_mod_compatible`] or
562/// [`par_keyswitch_lwe_ciphertext_with_thread_count_other_mod`] depending on the ciphertext modulus
563/// of the input `lwe_keyswitch_key`.
564///
565/// # Example
566///
567/// ```rust
568/// use tfhe::core_crypto::prelude::*;
569///
570/// // DISCLAIMER: these toy example parameters are not guaranteed to be secure or yield correct
571/// // computations
572/// // Define parameters for LweKeyswitchKey creation
573/// let input_lwe_dimension = LweDimension(742);
574/// let lwe_noise_distribution =
575///     Gaussian::from_dispersion_parameter(StandardDev(0.000007069849454709433), 0.0);
576/// let output_lwe_dimension = LweDimension(2048);
577/// let decomp_base_log = DecompositionBaseLog(3);
578/// let decomp_level_count = DecompositionLevelCount(5);
579/// let ciphertext_modulus = CiphertextModulus::new_native();
580///
581/// // Create the PRNG
582/// let mut seeder = new_seeder();
583/// let seeder = seeder.as_mut();
584/// let mut encryption_generator =
585///     EncryptionRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed(), seeder);
586/// let mut secret_generator = SecretRandomGenerator::<DefaultRandomGenerator>::new(seeder.seed());
587///
588/// // Create the LweSecretKey
589/// let input_lwe_secret_key =
590///     allocate_and_generate_new_binary_lwe_secret_key(input_lwe_dimension, &mut secret_generator);
591/// let output_lwe_secret_key = allocate_and_generate_new_binary_lwe_secret_key(
592///     output_lwe_dimension,
593///     &mut secret_generator,
594/// );
595///
596/// let ksk = allocate_and_generate_new_lwe_keyswitch_key(
597///     &input_lwe_secret_key,
598///     &output_lwe_secret_key,
599///     decomp_base_log,
600///     decomp_level_count,
601///     lwe_noise_distribution,
602///     ciphertext_modulus,
603///     &mut encryption_generator,
604/// );
605///
606/// // Create the plaintext
607/// let msg = 3u64;
608/// let plaintext = Plaintext(msg << 60);
609///
610/// // Create a new LweCiphertext
611/// let input_lwe = allocate_and_encrypt_new_lwe_ciphertext(
612///     &input_lwe_secret_key,
613///     plaintext,
614///     lwe_noise_distribution,
615///     ciphertext_modulus,
616///     &mut encryption_generator,
617/// );
618///
619/// let mut output_lwe = LweCiphertext::new(
620///     0,
621///     output_lwe_secret_key.lwe_dimension().to_lwe_size(),
622///     ciphertext_modulus,
623/// );
624///
625/// // Try to use 4 threads for the keyswitch if enough are available
626/// // in the current rayon thread pool
627/// par_keyswitch_lwe_ciphertext_with_thread_count(
628///     &ksk,
629///     &input_lwe,
630///     &mut output_lwe,
631///     ThreadCount(4),
632/// );
633///
634/// let decrypted_plaintext = decrypt_lwe_ciphertext(&output_lwe_secret_key, &output_lwe);
635///
636/// // Round and remove encoding
637/// // First create a decomposer working on the high 4 bits corresponding to our encoding.
638/// let decomposer = SignedDecomposer::new(DecompositionBaseLog(4), DecompositionLevelCount(1));
639///
640/// let rounded = decomposer.closest_representable(decrypted_plaintext.0);
641///
642/// // Remove the encoding
643/// let cleartext = rounded >> 60;
644///
645/// // Check we recovered the original message
646/// assert_eq!(cleartext, msg);
647/// ```
648pub fn par_keyswitch_lwe_ciphertext_with_thread_count<Scalar, KSKCont, InputCont, OutputCont>(
649    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
650    input_lwe_ciphertext: &LweCiphertext<InputCont>,
651    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
652    thread_count: ThreadCount,
653) where
654    Scalar: UnsignedInteger + Send + Sync,
655    KSKCont: Container<Element = Scalar>,
656    InputCont: Container<Element = Scalar>,
657    OutputCont: ContainerMut<Element = Scalar>,
658{
659    let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
660
661    if ciphertext_modulus.is_compatible_with_native_modulus() {
662        par_keyswitch_lwe_ciphertext_with_thread_count_native_mod_compatible(
663            lwe_keyswitch_key,
664            input_lwe_ciphertext,
665            output_lwe_ciphertext,
666            thread_count,
667        )
668    } else {
669        par_keyswitch_lwe_ciphertext_with_thread_count_other_mod(
670            lwe_keyswitch_key,
671            input_lwe_ciphertext,
672            output_lwe_ciphertext,
673            thread_count,
674        )
675    }
676}
677
678/// Specialized implementation of a parallel LWE keyswitch when inputs have power of two moduli.
679///
680/// # Panics
681///
682/// Panics if the modulus of the inputs are not power of twos.
683/// Panics if the output `output_lwe_ciphertext` modulus is not equal to the `lwe_keyswitch_key`
684/// modulus.
685pub fn par_keyswitch_lwe_ciphertext_with_thread_count_native_mod_compatible<
686    Scalar,
687    KSKCont,
688    InputCont,
689    OutputCont,
690>(
691    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
692    input_lwe_ciphertext: &LweCiphertext<InputCont>,
693    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
694    thread_count: ThreadCount,
695) where
696    Scalar: UnsignedInteger + Send + Sync,
697    KSKCont: Container<Element = Scalar>,
698    InputCont: Container<Element = Scalar>,
699    OutputCont: ContainerMut<Element = Scalar>,
700{
701    assert!(
702        lwe_keyswitch_key.input_key_lwe_dimension()
703            == input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
704        "Mismatched input LweDimension. \
705        LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
706        lwe_keyswitch_key.input_key_lwe_dimension(),
707        input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
708    );
709    assert!(
710        lwe_keyswitch_key.output_key_lwe_dimension()
711            == output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
712        "Mismatched output LweDimension. \
713        LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
714        lwe_keyswitch_key.output_key_lwe_dimension(),
715        output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
716    );
717
718    let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus();
719
720    assert_eq!(
721        lwe_keyswitch_key.ciphertext_modulus(),
722        output_ciphertext_modulus,
723        "Mismatched CiphertextModulus. \
724        LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
725        lwe_keyswitch_key.ciphertext_modulus(),
726        output_ciphertext_modulus
727    );
728    assert!(
729        output_ciphertext_modulus.is_compatible_with_native_modulus(),
730        "This operation currently only supports power of 2 moduli"
731    );
732
733    let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus();
734
735    assert!(
736        input_ciphertext_modulus.is_compatible_with_native_modulus(),
737        "This operation currently only supports power of 2 moduli"
738    );
739
740    assert!(
741        thread_count.0 != 0,
742        "Got thread_count == 0, this is not supported"
743    );
744
745    // Clear the output ciphertext, as it will get updated gradually
746    output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
747
748    let output_lwe_size = output_lwe_ciphertext.lwe_size();
749
750    // Copy the input body to the output ciphertext
751    *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;
752
753    // If the moduli are not the same, we need to round the body in the output ciphertext
754    if output_ciphertext_modulus != input_ciphertext_modulus
755        && !output_ciphertext_modulus.is_native_modulus()
756    {
757        let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize;
758        let output_decomposer = SignedDecomposer::new(
759            DecompositionBaseLog(modulus_bits),
760            DecompositionLevelCount(1),
761        );
762
763        *output_lwe_ciphertext.get_mut_body().data =
764            output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data);
765    }
766
767    // We instantiate a decomposer
768    let decomposer = SignedDecomposer::new(
769        lwe_keyswitch_key.decomposition_base_log(),
770        lwe_keyswitch_key.decomposition_level_count(),
771    );
772
773    // Don't go above the current number of threads
774    let thread_count = thread_count.0.min(rayon::current_num_threads());
775    let mut intermediate_accumulators = Vec::with_capacity(thread_count);
776
777    // Smallest chunk_size such that thread_count * chunk_size >= input_lwe_size
778    let chunk_size = input_lwe_ciphertext.lwe_size().0.div_ceil(thread_count);
779
780    lwe_keyswitch_key
781        .par_chunks(chunk_size)
782        .zip(
783            input_lwe_ciphertext
784                .get_mask()
785                .as_ref()
786                .par_chunks(chunk_size),
787        )
788        .map(|(keyswitch_key_block_chunk, input_mask_element_chunk)| {
789            let mut buffer =
790                LweCiphertext::new(Scalar::ZERO, output_lwe_size, output_ciphertext_modulus);
791
792            for (keyswitch_key_block, &input_mask_element) in keyswitch_key_block_chunk
793                .iter()
794                .zip(input_mask_element_chunk.iter())
795            {
796                let decomposition_iter = decomposer.decompose(input_mask_element);
797                // Loop over the levels
798                for (level_key_ciphertext, decomposed) in
799                    keyswitch_key_block.iter().zip(decomposition_iter)
800                {
801                    slice_wrapping_sub_scalar_mul_assign(
802                        buffer.as_mut(),
803                        level_key_ciphertext.as_ref(),
804                        decomposed.value(),
805                    );
806                }
807            }
808            buffer
809        })
810        .collect_into_vec(&mut intermediate_accumulators);
811
812    let reduced = intermediate_accumulators
813        .par_iter_mut()
814        .reduce_with(|lhs, rhs| {
815            lhs.as_mut()
816                .iter_mut()
817                .zip(rhs.as_ref().iter())
818                .for_each(|(dst, &src)| *dst = (*dst).wrapping_add(src));
819
820            lhs
821        })
822        .unwrap();
823
824    output_lwe_ciphertext
825        .get_mut_mask()
826        .as_mut()
827        .copy_from_slice(reduced.get_mask().as_ref());
828    let reduced_ksed_body = *reduced.get_body().data;
829
830    // Add the reduced body of the keyswitch to the output body to complete the keyswitch
831    *output_lwe_ciphertext.get_mut_body().data =
832        (*output_lwe_ciphertext.get_mut_body().data).wrapping_add(reduced_ksed_body);
833}
834
835/// Specialized implementation of a parallel LWE keyswitch when inputs have non power of two moduli.
836///
837/// # Panics
838///
839/// Panics if the modulus of the inputs are power of twos.
840/// Panics if the modulus of the inputs are not all equal.
841pub fn par_keyswitch_lwe_ciphertext_with_thread_count_other_mod<
842    Scalar,
843    KSKCont,
844    InputCont,
845    OutputCont,
846>(
847    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
848    input_lwe_ciphertext: &LweCiphertext<InputCont>,
849    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
850    thread_count: ThreadCount,
851) where
852    Scalar: UnsignedInteger + Send + Sync,
853    KSKCont: Container<Element = Scalar>,
854    InputCont: Container<Element = Scalar>,
855    OutputCont: ContainerMut<Element = Scalar>,
856{
857    assert!(
858        lwe_keyswitch_key.input_key_lwe_dimension()
859            == input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
860        "Mismatched input LweDimension. \
861        LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
862        lwe_keyswitch_key.input_key_lwe_dimension(),
863        input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
864    );
865    assert!(
866        lwe_keyswitch_key.output_key_lwe_dimension()
867            == output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
868        "Mismatched output LweDimension. \
869        LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
870        lwe_keyswitch_key.output_key_lwe_dimension(),
871        output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
872    );
873
874    assert_eq!(
875        lwe_keyswitch_key.ciphertext_modulus(),
876        output_lwe_ciphertext.ciphertext_modulus(),
877        "Mismatched CiphertextModulus. \
878        LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
879        lwe_keyswitch_key.ciphertext_modulus(),
880        output_lwe_ciphertext.ciphertext_modulus()
881    );
882
883    assert_eq!(
884        lwe_keyswitch_key.ciphertext_modulus(),
885        input_lwe_ciphertext.ciphertext_modulus(),
886        "Mismatched CiphertextModulus. \
887        LweKeyswitchKey CiphertextModulus: {:?}, input LweCiphertext CiphertextModulus {:?}.",
888        lwe_keyswitch_key.ciphertext_modulus(),
889        input_lwe_ciphertext.ciphertext_modulus()
890    );
891
892    let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
893
894    assert!(
895        !ciphertext_modulus.is_compatible_with_native_modulus(),
896        "This operation currently only supports non power of 2 moduli"
897    );
898
899    let ciphertext_modulus_as_scalar: Scalar = ciphertext_modulus.get_custom_modulus().cast_into();
900
901    assert!(
902        thread_count.0 != 0,
903        "Got thread_count == 0, this is not supported"
904    );
905
906    // Clear the output ciphertext, as it will get updated gradually
907    output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
908
909    let output_lwe_size = output_lwe_ciphertext.lwe_size();
910
911    // Copy the input body to the output ciphertext
912    *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;
913
914    // We instantiate a decomposer
915    let decomposer = SignedDecomposerNonNative::new(
916        lwe_keyswitch_key.decomposition_base_log(),
917        lwe_keyswitch_key.decomposition_level_count(),
918        ciphertext_modulus,
919    );
920
921    // Don't go above the current number of threads
922    let thread_count = thread_count.0.min(rayon::current_num_threads());
923    let mut intermediate_accumulators = Vec::with_capacity(thread_count);
924
925    // Smallest chunk_size such that thread_count * chunk_size >= input_lwe_size
926    let chunk_size = input_lwe_ciphertext.lwe_size().0.div_ceil(thread_count);
927
928    lwe_keyswitch_key
929        .par_chunks(chunk_size)
930        .zip(
931            input_lwe_ciphertext
932                .get_mask()
933                .as_ref()
934                .par_chunks(chunk_size),
935        )
936        .map(|(keyswitch_key_block_chunk, input_mask_element_chunk)| {
937            let mut buffer = LweCiphertext::new(Scalar::ZERO, output_lwe_size, ciphertext_modulus);
938
939            for (keyswitch_key_block, &input_mask_element) in keyswitch_key_block_chunk
940                .iter()
941                .zip(input_mask_element_chunk.iter())
942            {
943                let decomposition_iter = decomposer.decompose(input_mask_element);
944                // Loop over the levels
945                for (level_key_ciphertext, decomposed) in
946                    keyswitch_key_block.iter().zip(decomposition_iter)
947                {
948                    slice_wrapping_sub_scalar_mul_assign_custom_modulus(
949                        buffer.as_mut(),
950                        level_key_ciphertext.as_ref(),
951                        decomposed.modular_value(),
952                        ciphertext_modulus_as_scalar,
953                    );
954                }
955            }
956            buffer
957        })
958        .collect_into_vec(&mut intermediate_accumulators);
959
960    let reduced = intermediate_accumulators
961        .par_iter_mut()
962        .reduce_with(|lhs, rhs| {
963            lhs.as_mut()
964                .iter_mut()
965                .zip(rhs.as_ref().iter())
966                .for_each(|(dst, &src)| {
967                    *dst = (*dst).wrapping_add_custom_mod(src, ciphertext_modulus_as_scalar)
968                });
969
970            lhs
971        })
972        .unwrap();
973
974    output_lwe_ciphertext
975        .get_mut_mask()
976        .as_mut()
977        .copy_from_slice(reduced.get_mask().as_ref());
978    let reduced_ksed_body = *reduced.get_body().data;
979
980    // Add the reduced body of the keyswitch to the output body to complete the keyswitch
981    *output_lwe_ciphertext.get_mut_body().data = (*output_lwe_ciphertext.get_mut_body().data)
982        .wrapping_add_custom_mod(reduced_ksed_body, ciphertext_modulus_as_scalar);
983}
984
985// ============== Noise measurement trait implementations ============== //
986use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::{
987    AllocateLweKeyswitchResult, LweKeyswitch,
988};
989use crate::core_crypto::fft_impl::fft64::math::fft::id;
990use std::any::TypeId;
991
992impl<Scalar: UnsignedInteger, KeyCont: Container<Element = Scalar>> AllocateLweKeyswitchResult
993    for LweKeyswitchKey<KeyCont>
994{
995    type Output = LweCiphertextOwned<Scalar>;
996    type SideResources = ();
997
998    fn allocate_lwe_keyswitch_result(
999        &self,
1000        _side_resources: &mut Self::SideResources,
1001    ) -> Self::Output {
1002        Self::Output::new(
1003            Scalar::ZERO,
1004            self.output_lwe_size(),
1005            self.ciphertext_modulus(),
1006        )
1007    }
1008}
1009
/// Keyswitch implementation for the noise simulation `LweKeyswitch` trait, supporting input and
/// output ciphertexts that may use different scalar types.
impl<
        InputScalar: UnsignedInteger,
        OutputScalar: UnsignedInteger + CastFrom<InputScalar>,
        KeyCont: Container<Element = OutputScalar>,
        InputCont: Container<Element = InputScalar>,
        OutputCont: ContainerMut<Element = OutputScalar>,
    > LweKeyswitch<LweCiphertext<InputCont>, LweCiphertext<OutputCont>>
    for LweKeyswitchKey<KeyCont>
{
    type SideResources = ();

    /// Keyswitch `input` into `output` using `self`.
    ///
    /// When `InputScalar` and `OutputScalar` are the same type, the input is viewed as a
    /// ciphertext over `OutputScalar` and [`keyswitch_lwe_ciphertext`] is called. Otherwise
    /// [`keyswitch_lwe_ciphertext_with_scalar_change`] is called. No side resources are used.
    fn lwe_keyswitch(
        &self,
        input: &LweCiphertext<InputCont>,
        output: &mut LweCiphertext<OutputCont>,
        _side_resources: &mut Self::SideResources,
    ) {
        // A single generic impl has to cover both the same-scalar and the scalar-change cases.
        // Two separate impls would be rejected by rustc as conflicting, even though for concrete
        // types the cases are mutually exclusive, so the choice is made at runtime on the TypeId.
        if TypeId::of::<InputScalar>() == TypeId::of::<OutputScalar>() {
            // `Any` cannot be used here: it requires `'static` types, because lifetime
            // information only exists for the borrow checker and is erased at runtime.
            // Instead, operate on views. The scalar types are known to be identical at this
            // point, so the input slice is converted with `id` (we already have the primitive
            // slice) and the modulus is converted with `try_to`, which should be a no-op in
            // practice.
            let input_content = input.as_ref();
            let input_as_output_scalar = LweCiphertext::from_container(
                id(input_content),
                input.ciphertext_modulus().try_to().unwrap(),
            );
            keyswitch_lwe_ciphertext(self, &input_as_output_scalar, output);
        } else {
            keyswitch_lwe_ciphertext_with_scalar_change(self, input, output);
        }
    }
}