1use crate::core_crypto::algorithms::slice_algorithms::*;
5use crate::core_crypto::commons::ciphertext_modulus::CiphertextModulusKind;
6use crate::core_crypto::commons::math::decomposition::{
7 SignedDecomposer, SignedDecomposerNonNative,
8};
9use crate::core_crypto::commons::parameters::{
10 DecompositionBaseLog, DecompositionLevelCount, ThreadCount,
11};
12use crate::core_crypto::commons::traits::*;
13use crate::core_crypto::entities::*;
14use rayon::prelude::*;
15
16pub fn keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
104 lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
105 input_lwe_ciphertext: &LweCiphertext<InputCont>,
106 output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
107) where
108 Scalar: UnsignedInteger,
109 KSKCont: Container<Element = Scalar>,
110 InputCont: Container<Element = Scalar>,
111 OutputCont: ContainerMut<Element = Scalar>,
112{
113 let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
114
115 if ciphertext_modulus.is_compatible_with_native_modulus() {
116 keyswitch_lwe_ciphertext_native_mod_compatible(
117 lwe_keyswitch_key,
118 input_lwe_ciphertext,
119 output_lwe_ciphertext,
120 )
121 } else {
122 keyswitch_lwe_ciphertext_other_mod(
123 lwe_keyswitch_key,
124 input_lwe_ciphertext,
125 output_lwe_ciphertext,
126 )
127 }
128}
129
130pub fn keyswitch_lwe_ciphertext_native_mod_compatible<Scalar, KSKCont, InputCont, OutputCont>(
138 lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
139 input_lwe_ciphertext: &LweCiphertext<InputCont>,
140 output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
141) where
142 Scalar: UnsignedInteger,
143 KSKCont: Container<Element = Scalar>,
144 InputCont: Container<Element = Scalar>,
145 OutputCont: ContainerMut<Element = Scalar>,
146{
147 assert!(
148 lwe_keyswitch_key.input_key_lwe_dimension()
149 == input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
150 "Mismatched input LweDimension. \
151 LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
152 lwe_keyswitch_key.input_key_lwe_dimension(),
153 input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
154 );
155 assert!(
156 lwe_keyswitch_key.output_key_lwe_dimension()
157 == output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
158 "Mismatched output LweDimension. \
159 LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
160 lwe_keyswitch_key.output_key_lwe_dimension(),
161 output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
162 );
163
164 let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus();
165
166 assert_eq!(
167 lwe_keyswitch_key.ciphertext_modulus(),
168 output_ciphertext_modulus,
169 "Mismatched CiphertextModulus. \
170 LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
171 lwe_keyswitch_key.ciphertext_modulus(),
172 output_ciphertext_modulus
173 );
174 assert!(
175 output_ciphertext_modulus.is_compatible_with_native_modulus(),
176 "This operation currently only supports power of 2 moduli"
177 );
178
179 let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus();
180
181 assert!(
182 input_ciphertext_modulus.is_compatible_with_native_modulus(),
183 "This operation currently only supports power of 2 moduli"
184 );
185
186 output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
188
189 *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;
191
192 if output_ciphertext_modulus != input_ciphertext_modulus
194 && !output_ciphertext_modulus.is_native_modulus()
195 {
196 let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize;
197 let output_decomposer = SignedDecomposer::new(
198 DecompositionBaseLog(modulus_bits),
199 DecompositionLevelCount(1),
200 );
201
202 *output_lwe_ciphertext.get_mut_body().data =
203 output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data);
204 }
205
206 let decomposer = SignedDecomposer::new(
208 lwe_keyswitch_key.decomposition_base_log(),
209 lwe_keyswitch_key.decomposition_level_count(),
210 );
211
212 for (keyswitch_key_block, &input_mask_element) in lwe_keyswitch_key
213 .iter()
214 .zip(input_lwe_ciphertext.get_mask().as_ref())
215 {
216 let decomposition_iter = decomposer.decompose(input_mask_element);
217 for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
219 {
220 slice_wrapping_sub_scalar_mul_assign(
221 output_lwe_ciphertext.as_mut(),
222 level_key_ciphertext.as_ref(),
223 decomposed.value(),
224 );
225 }
226 }
227}
228
229pub fn keyswitch_lwe_ciphertext_other_mod<Scalar, KSKCont, InputCont, OutputCont>(
236 lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
237 input_lwe_ciphertext: &LweCiphertext<InputCont>,
238 output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
239) where
240 Scalar: UnsignedInteger,
241 KSKCont: Container<Element = Scalar>,
242 InputCont: Container<Element = Scalar>,
243 OutputCont: ContainerMut<Element = Scalar>,
244{
245 assert!(
246 lwe_keyswitch_key.input_key_lwe_dimension()
247 == input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
248 "Mismatched input LweDimension. \
249 LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
250 lwe_keyswitch_key.input_key_lwe_dimension(),
251 input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
252 );
253 assert!(
254 lwe_keyswitch_key.output_key_lwe_dimension()
255 == output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
256 "Mismatched output LweDimension. \
257 LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
258 lwe_keyswitch_key.output_key_lwe_dimension(),
259 output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
260 );
261
262 assert_eq!(
263 lwe_keyswitch_key.ciphertext_modulus(),
264 output_lwe_ciphertext.ciphertext_modulus(),
265 "Mismatched CiphertextModulus. \
266 LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
267 lwe_keyswitch_key.ciphertext_modulus(),
268 output_lwe_ciphertext.ciphertext_modulus()
269 );
270
271 assert_eq!(
272 lwe_keyswitch_key.ciphertext_modulus(),
273 input_lwe_ciphertext.ciphertext_modulus(),
274 "Mismatched CiphertextModulus. \
275 LweKeyswitchKey CiphertextModulus: {:?}, input LweCiphertext CiphertextModulus {:?}.",
276 lwe_keyswitch_key.ciphertext_modulus(),
277 input_lwe_ciphertext.ciphertext_modulus()
278 );
279
280 let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
281
282 assert!(
283 !ciphertext_modulus.is_compatible_with_native_modulus(),
284 "This operation currently only supports non power of 2 moduli"
285 );
286
287 output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);
289
290 *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;
292
293 let decomposer = SignedDecomposerNonNative::new(
295 lwe_keyswitch_key.decomposition_base_log(),
296 lwe_keyswitch_key.decomposition_level_count(),
297 ciphertext_modulus,
298 );
299
300 for (keyswitch_key_block, &input_mask_element) in lwe_keyswitch_key
301 .iter()
302 .zip(input_lwe_ciphertext.get_mask().as_ref())
303 {
304 let decomposition_iter = decomposer.decompose(input_mask_element);
305 for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
307 {
308 slice_wrapping_sub_scalar_mul_assign_custom_modulus(
309 output_lwe_ciphertext.as_mut(),
310 level_key_ciphertext.as_ref(),
311 decomposed.modular_value(),
312 ciphertext_modulus.get_custom_modulus().cast_into(),
313 );
314 }
315 }
316}
317
318pub fn keyswitch_lwe_ciphertext_with_scalar_change<
332 InputScalar,
333 OutputScalar,
334 KSKCont,
335 InputCont,
336 OutputCont,
337>(
338 lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
339 input_lwe_ciphertext: &LweCiphertext<InputCont>,
340 output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
341) where
342 InputScalar: UnsignedInteger,
343 OutputScalar: UnsignedInteger + CastFrom<InputScalar>,
344 KSKCont: Container<Element = OutputScalar>,
345 InputCont: Container<Element = InputScalar>,
346 OutputCont: ContainerMut<Element = OutputScalar>,
347{
348 assert!(
349 InputScalar::BITS > OutputScalar::BITS,
350 "This operation only supports going from a large InputScalar type \
351 to a strictly smaller OutputScalar type."
352 );
353 assert!(
354 lwe_keyswitch_key.decomposition_base_log().0
355 * lwe_keyswitch_key.decomposition_level_count().0
356 <= OutputScalar::BITS,
357 "This operation only supports a DecompositionBaseLog and DecompositionLevelCount product \
358 smaller than the OutputScalar bit count."
359 );
360
361 assert!(
362 lwe_keyswitch_key.input_key_lwe_dimension()
363 == input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
364 "Mismatched input LweDimension. \
365 LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
366 lwe_keyswitch_key.input_key_lwe_dimension(),
367 input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
368 );
369 assert!(
370 lwe_keyswitch_key.output_key_lwe_dimension()
371 == output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
372 "Mismatched output LweDimension. \
373 LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
374 lwe_keyswitch_key.output_key_lwe_dimension(),
375 output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
376 );
377
378 let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus();
379
380 assert_eq!(
381 lwe_keyswitch_key.ciphertext_modulus(),
382 output_ciphertext_modulus,
383 "Mismatched CiphertextModulus. \
384 LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
385 lwe_keyswitch_key.ciphertext_modulus(),
386 output_ciphertext_modulus
387 );
388 assert!(
389 output_ciphertext_modulus.is_compatible_with_native_modulus(),
390 "This operation currently only supports power of 2 moduli"
391 );
392
393 let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus();
394
395 assert!(
396 input_ciphertext_modulus.is_compatible_with_native_modulus(),
397 "This operation currently only supports power of 2 moduli"
398 );
399
400 output_lwe_ciphertext.as_mut().fill(OutputScalar::ZERO);
402
403 let output_modulus_bits = match output_ciphertext_modulus.kind() {
404 CiphertextModulusKind::Native => OutputScalar::BITS,
405 CiphertextModulusKind::NonNativePowerOfTwo => {
406 output_ciphertext_modulus.get_custom_modulus().ilog2() as usize
407 }
408 CiphertextModulusKind::Other => unreachable!(),
409 };
410
411 let input_body_decomposer = SignedDecomposer::new(
412 DecompositionBaseLog(output_modulus_bits),
413 DecompositionLevelCount(1),
414 );
415
416 let input_to_output_scaling_factor = InputScalar::BITS - OutputScalar::BITS;
419
420 let rounded_downscaled_body = input_body_decomposer
421 .closest_representable(*input_lwe_ciphertext.get_body().data)
422 >> input_to_output_scaling_factor;
423
424 *output_lwe_ciphertext.get_mut_body().data = rounded_downscaled_body.cast_into();
425
426 let input_decomposer = SignedDecomposer::<InputScalar>::new(
428 lwe_keyswitch_key.decomposition_base_log(),
429 lwe_keyswitch_key.decomposition_level_count(),
430 );
431
432 for (keyswitch_key_block, &input_mask_element) in lwe_keyswitch_key
433 .iter()
434 .zip(input_lwe_ciphertext.get_mask().as_ref())
435 {
436 let decomposition_iter = input_decomposer.decompose(input_mask_element);
437 for (level_key_ciphertext, decomposed) in keyswitch_key_block.iter().zip(decomposition_iter)
439 {
440 slice_wrapping_sub_scalar_mul_assign(
441 output_lwe_ciphertext.as_mut(),
442 level_key_ciphertext.as_ref(),
443 decomposed.value().cast_into(),
444 );
445 }
446 }
447}
448
449pub fn par_keyswitch_lwe_ciphertext<Scalar, KSKCont, InputCont, OutputCont>(
536 lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
537 input_lwe_ciphertext: &LweCiphertext<InputCont>,
538 output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
539) where
540 Scalar: UnsignedInteger + Send + Sync,
541 KSKCont: Container<Element = Scalar>,
542 InputCont: Container<Element = Scalar>,
543 OutputCont: ContainerMut<Element = Scalar>,
544{
545 let thread_count = ThreadCount(rayon::current_num_threads());
546 par_keyswitch_lwe_ciphertext_with_thread_count(
547 lwe_keyswitch_key,
548 input_lwe_ciphertext,
549 output_lwe_ciphertext,
550 thread_count,
551 );
552}
553
554pub fn par_keyswitch_lwe_ciphertext_with_thread_count<Scalar, KSKCont, InputCont, OutputCont>(
649 lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
650 input_lwe_ciphertext: &LweCiphertext<InputCont>,
651 output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
652 thread_count: ThreadCount,
653) where
654 Scalar: UnsignedInteger + Send + Sync,
655 KSKCont: Container<Element = Scalar>,
656 InputCont: Container<Element = Scalar>,
657 OutputCont: ContainerMut<Element = Scalar>,
658{
659 let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();
660
661 if ciphertext_modulus.is_compatible_with_native_modulus() {
662 par_keyswitch_lwe_ciphertext_with_thread_count_native_mod_compatible(
663 lwe_keyswitch_key,
664 input_lwe_ciphertext,
665 output_lwe_ciphertext,
666 thread_count,
667 )
668 } else {
669 par_keyswitch_lwe_ciphertext_with_thread_count_other_mod(
670 lwe_keyswitch_key,
671 input_lwe_ciphertext,
672 output_lwe_ciphertext,
673 thread_count,
674 )
675 }
676}
677
/// Parallel keyswitch of an [`LweCiphertext`] for ciphertext moduli compatible with the native
/// modulus.
///
/// Uses at most `thread_count` threads, further capped by the current rayon pool size. The
/// pipeline is:
/// 1. Split the input mask elements, and the matching keyswitch key blocks, into chunks.
/// 2. Keyswitch each chunk in parallel into its own zero-initialized buffer.
/// 3. Sum the buffers with wrapping addition.
/// 4. Copy the summed mask into the output and add the summed body to the (possibly rounded)
///    input body.
///
/// # Panics
///
/// - if the keyswitch key input/output LWE dimensions do not match the input/output ciphertexts;
/// - if the keyswitch key and output ciphertext moduli differ;
/// - if the output or input modulus is not compatible with the native modulus;
/// - if `thread_count` is 0.
pub fn par_keyswitch_lwe_ciphertext_with_thread_count_native_mod_compatible<
    Scalar,
    KSKCont,
    InputCont,
    OutputCont,
>(
    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
    input_lwe_ciphertext: &LweCiphertext<InputCont>,
    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
    thread_count: ThreadCount,
) where
    Scalar: UnsignedInteger + Send + Sync,
    KSKCont: Container<Element = Scalar>,
    InputCont: Container<Element = Scalar>,
    OutputCont: ContainerMut<Element = Scalar>,
{
    assert!(
        lwe_keyswitch_key.input_key_lwe_dimension()
            == input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
        "Mismatched input LweDimension. \
        LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
        lwe_keyswitch_key.input_key_lwe_dimension(),
        input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
    );
    assert!(
        lwe_keyswitch_key.output_key_lwe_dimension()
            == output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
        "Mismatched output LweDimension. \
        LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
        lwe_keyswitch_key.output_key_lwe_dimension(),
        output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
    );

    let output_ciphertext_modulus = output_lwe_ciphertext.ciphertext_modulus();

    assert_eq!(
        lwe_keyswitch_key.ciphertext_modulus(),
        output_ciphertext_modulus,
        "Mismatched CiphertextModulus. \
        LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
        lwe_keyswitch_key.ciphertext_modulus(),
        output_ciphertext_modulus
    );
    assert!(
        output_ciphertext_modulus.is_compatible_with_native_modulus(),
        "This operation currently only supports power of 2 moduli"
    );

    let input_ciphertext_modulus = input_lwe_ciphertext.ciphertext_modulus();

    assert!(
        input_ciphertext_modulus.is_compatible_with_native_modulus(),
        "This operation currently only supports power of 2 moduli"
    );

    assert!(
        thread_count.0 != 0,
        "Got thread_count == 0, this is not supported"
    );

    // Clear the output: its mask is overwritten by the reduced accumulator at the end, and its
    // body is set from the input body just below.
    output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);

    // Size of the per-chunk accumulator ciphertexts allocated in the parallel map below.
    let output_lwe_size = output_lwe_ciphertext.lwe_size();

    // Copy the input body to the output ciphertext.
    *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;

    // If the moduli differ and the output modulus is not native, round the copied body with a
    // single-level decomposer whose base log is ilog2 of the output modulus.
    if output_ciphertext_modulus != input_ciphertext_modulus
        && !output_ciphertext_modulus.is_native_modulus()
    {
        let modulus_bits = output_ciphertext_modulus.get_custom_modulus().ilog2() as usize;
        let output_decomposer = SignedDecomposer::new(
            DecompositionBaseLog(modulus_bits),
            DecompositionLevelCount(1),
        );

        *output_lwe_ciphertext.get_mut_body().data =
            output_decomposer.closest_representable(*output_lwe_ciphertext.get_mut_body().data);
    }

    // Decomposer for the input mask elements, parameterized by the keyswitch key.
    let decomposer = SignedDecomposer::new(
        lwe_keyswitch_key.decomposition_base_log(),
        lwe_keyswitch_key.decomposition_level_count(),
    );

    // Never use more threads than the rayon pool reports. With the zero check above, this keeps
    // the `div_ceil` below from dividing by zero (assuming rayon reports at least one thread).
    let thread_count = thread_count.0.min(rayon::current_num_threads());
    // Receives one accumulator ciphertext per chunk from `collect_into_vec` below.
    let mut intermediate_accumulators = Vec::with_capacity(thread_count);

    // Smallest chunk size such that `thread_count` chunks cover the input LWE size. Only the
    // mask is chunked, so the last chunk may be shorter and there may be fewer than
    // `thread_count` chunks.
    let chunk_size = input_lwe_ciphertext.lwe_size().0.div_ceil(thread_count);

    lwe_keyswitch_key
        .par_chunks(chunk_size)
        .zip(
            input_lwe_ciphertext
                .get_mask()
                .as_ref()
                .par_chunks(chunk_size),
        )
        .map(|(keyswitch_key_block_chunk, input_mask_element_chunk)| {
            // Each chunk accumulates into its own zero-initialized ciphertext, so chunks are
            // processed independently of each other.
            let mut buffer =
                LweCiphertext::new(Scalar::ZERO, output_lwe_size, output_ciphertext_modulus);

            for (keyswitch_key_block, &input_mask_element) in keyswitch_key_block_chunk
                .iter()
                .zip(input_mask_element_chunk.iter())
            {
                let decomposition_iter = decomposer.decompose(input_mask_element);
                // Pair each decomposition term with its level key ciphertext.
                for (level_key_ciphertext, decomposed) in
                    keyswitch_key_block.iter().zip(decomposition_iter)
                {
                    slice_wrapping_sub_scalar_mul_assign(
                        buffer.as_mut(),
                        level_key_ciphertext.as_ref(),
                        decomposed.value(),
                    );
                }
            }
            buffer
        })
        .collect_into_vec(&mut intermediate_accumulators);

    // Sum all per-chunk accumulators element-wise (mask and body) with wrapping addition.
    // NOTE(review): `unwrap` panics if no chunk was produced (empty input mask). Presumably a
    // keyswitch key never has a zero input dimension — confirm.
    let reduced = intermediate_accumulators
        .par_iter_mut()
        .reduce_with(|lhs, rhs| {
            lhs.as_mut()
                .iter_mut()
                .zip(rhs.as_ref().iter())
                .for_each(|(dst, &src)| *dst = (*dst).wrapping_add(src));

            lhs
        })
        .unwrap();

    // The output mask was zeroed, so the reduced mask is the final mask.
    output_lwe_ciphertext
        .get_mut_mask()
        .as_mut()
        .copy_from_slice(reduced.get_mask().as_ref());
    let reduced_ksed_body = *reduced.get_body().data;

    // Add the keyswitched body contribution to the copied input body to complete the keyswitch.
    *output_lwe_ciphertext.get_mut_body().data =
        (*output_lwe_ciphertext.get_mut_body().data).wrapping_add(reduced_ksed_body);
}
834
/// Parallel keyswitch of an [`LweCiphertext`] for ciphertext moduli that are not compatible
/// with the native modulus. All arithmetic is done modulo the custom modulus.
///
/// Uses at most `thread_count` threads, further capped by the current rayon pool size. The
/// keyswitch key, input ciphertext and output ciphertext must share the same ciphertext
/// modulus. The pipeline is:
/// 1. Split the input mask elements into chunks.
/// 2. Keyswitch each chunk in parallel into its own zero-initialized buffer.
/// 3. Sum the buffers modulo the custom modulus.
/// 4. Combine the result with the copied input body.
///
/// # Panics
///
/// - if the keyswitch key input/output LWE dimensions do not match the input/output ciphertexts;
/// - if the output or input ciphertext modulus differs from the keyswitch key modulus;
/// - if the ciphertext modulus is compatible with the native modulus;
/// - if `thread_count` is 0.
pub fn par_keyswitch_lwe_ciphertext_with_thread_count_other_mod<
    Scalar,
    KSKCont,
    InputCont,
    OutputCont,
>(
    lwe_keyswitch_key: &LweKeyswitchKey<KSKCont>,
    input_lwe_ciphertext: &LweCiphertext<InputCont>,
    output_lwe_ciphertext: &mut LweCiphertext<OutputCont>,
    thread_count: ThreadCount,
) where
    Scalar: UnsignedInteger + Send + Sync,
    KSKCont: Container<Element = Scalar>,
    InputCont: Container<Element = Scalar>,
    OutputCont: ContainerMut<Element = Scalar>,
{
    assert!(
        lwe_keyswitch_key.input_key_lwe_dimension()
            == input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
        "Mismatched input LweDimension. \
        LweKeyswitchKey input LweDimension: {:?}, input LweCiphertext LweDimension {:?}.",
        lwe_keyswitch_key.input_key_lwe_dimension(),
        input_lwe_ciphertext.lwe_size().to_lwe_dimension(),
    );
    assert!(
        lwe_keyswitch_key.output_key_lwe_dimension()
            == output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
        "Mismatched output LweDimension. \
        LweKeyswitchKey output LweDimension: {:?}, output LweCiphertext LweDimension {:?}.",
        lwe_keyswitch_key.output_key_lwe_dimension(),
        output_lwe_ciphertext.lwe_size().to_lwe_dimension(),
    );

    assert_eq!(
        lwe_keyswitch_key.ciphertext_modulus(),
        output_lwe_ciphertext.ciphertext_modulus(),
        "Mismatched CiphertextModulus. \
        LweKeyswitchKey CiphertextModulus: {:?}, output LweCiphertext CiphertextModulus {:?}.",
        lwe_keyswitch_key.ciphertext_modulus(),
        output_lwe_ciphertext.ciphertext_modulus()
    );

    assert_eq!(
        lwe_keyswitch_key.ciphertext_modulus(),
        input_lwe_ciphertext.ciphertext_modulus(),
        "Mismatched CiphertextModulus. \
        LweKeyswitchKey CiphertextModulus: {:?}, input LweCiphertext CiphertextModulus {:?}.",
        lwe_keyswitch_key.ciphertext_modulus(),
        input_lwe_ciphertext.ciphertext_modulus()
    );

    let ciphertext_modulus = lwe_keyswitch_key.ciphertext_modulus();

    assert!(
        !ciphertext_modulus.is_compatible_with_native_modulus(),
        "This operation currently only supports non power of 2 moduli"
    );

    // Custom modulus as a `Scalar`, computed once and shared by all chunks and the reduction.
    let ciphertext_modulus_as_scalar: Scalar = ciphertext_modulus.get_custom_modulus().cast_into();

    assert!(
        thread_count.0 != 0,
        "Got thread_count == 0, this is not supported"
    );

    // Clear the output: its mask is overwritten by the reduced accumulator at the end, and its
    // body is set from the input body just below.
    output_lwe_ciphertext.as_mut().fill(Scalar::ZERO);

    // Size of the per-chunk accumulator ciphertexts allocated in the parallel map below.
    let output_lwe_size = output_lwe_ciphertext.lwe_size();

    // Copy the input body to the output ciphertext (all moduli are equal, so no rounding).
    *output_lwe_ciphertext.get_mut_body().data = *input_lwe_ciphertext.get_body().data;

    // Decomposer working over the custom modulus, parameterized by the keyswitch key.
    let decomposer = SignedDecomposerNonNative::new(
        lwe_keyswitch_key.decomposition_base_log(),
        lwe_keyswitch_key.decomposition_level_count(),
        ciphertext_modulus,
    );

    // Never use more threads than the rayon pool reports. With the zero check above, this keeps
    // the `div_ceil` below from dividing by zero (assuming rayon reports at least one thread).
    let thread_count = thread_count.0.min(rayon::current_num_threads());
    // Receives one accumulator ciphertext per chunk from `collect_into_vec` below.
    let mut intermediate_accumulators = Vec::with_capacity(thread_count);

    // Smallest chunk size such that `thread_count` chunks cover the input LWE size. Only the
    // mask is chunked, so the last chunk may be shorter and there may be fewer than
    // `thread_count` chunks.
    let chunk_size = input_lwe_ciphertext.lwe_size().0.div_ceil(thread_count);

    lwe_keyswitch_key
        .par_chunks(chunk_size)
        .zip(
            input_lwe_ciphertext
                .get_mask()
                .as_ref()
                .par_chunks(chunk_size),
        )
        .map(|(keyswitch_key_block_chunk, input_mask_element_chunk)| {
            // Each chunk accumulates into its own zero-initialized ciphertext, so chunks are
            // processed independently of each other.
            let mut buffer = LweCiphertext::new(Scalar::ZERO, output_lwe_size, ciphertext_modulus);

            for (keyswitch_key_block, &input_mask_element) in keyswitch_key_block_chunk
                .iter()
                .zip(input_mask_element_chunk.iter())
            {
                let decomposition_iter = decomposer.decompose(input_mask_element);
                // Pair each decomposition term with its level key ciphertext.
                for (level_key_ciphertext, decomposed) in
                    keyswitch_key_block.iter().zip(decomposition_iter)
                {
                    slice_wrapping_sub_scalar_mul_assign_custom_modulus(
                        buffer.as_mut(),
                        level_key_ciphertext.as_ref(),
                        decomposed.modular_value(),
                        ciphertext_modulus_as_scalar,
                    );
                }
            }
            buffer
        })
        .collect_into_vec(&mut intermediate_accumulators);

    // Sum all per-chunk accumulators element-wise (mask and body) modulo the custom modulus.
    // NOTE(review): `unwrap` panics if no chunk was produced (empty input mask). Presumably a
    // keyswitch key never has a zero input dimension — confirm.
    let reduced = intermediate_accumulators
        .par_iter_mut()
        .reduce_with(|lhs, rhs| {
            lhs.as_mut()
                .iter_mut()
                .zip(rhs.as_ref().iter())
                .for_each(|(dst, &src)| {
                    *dst = (*dst).wrapping_add_custom_mod(src, ciphertext_modulus_as_scalar)
                });

            lhs
        })
        .unwrap();

    // The output mask was zeroed, so the reduced mask is the final mask.
    output_lwe_ciphertext
        .get_mut_mask()
        .as_mut()
        .copy_from_slice(reduced.get_mask().as_ref());
    let reduced_ksed_body = *reduced.get_body().data;

    // Add the keyswitched body contribution to the copied input body, modulo the custom modulus.
    *output_lwe_ciphertext.get_mut_body().data = (*output_lwe_ciphertext.get_mut_body().data)
        .wrapping_add_custom_mod(reduced_ksed_body, ciphertext_modulus_as_scalar);
}
984
985use crate::core_crypto::commons::noise_formulas::noise_simulation::traits::{
987 AllocateLweKeyswitchResult, LweKeyswitch,
988};
989use crate::core_crypto::fft_impl::fft64::math::fft::id;
990use std::any::TypeId;
991
992impl<Scalar: UnsignedInteger, KeyCont: Container<Element = Scalar>> AllocateLweKeyswitchResult
993 for LweKeyswitchKey<KeyCont>
994{
995 type Output = LweCiphertextOwned<Scalar>;
996 type SideResources = ();
997
998 fn allocate_lwe_keyswitch_result(
999 &self,
1000 _side_resources: &mut Self::SideResources,
1001 ) -> Self::Output {
1002 Self::Output::new(
1003 Scalar::ZERO,
1004 self.output_lwe_size(),
1005 self.ciphertext_modulus(),
1006 )
1007 }
1008}
1009
/// Noise-simulation [`LweKeyswitch`] implementation backed by the concrete keyswitch
/// algorithms in this module.
impl<
        InputScalar: UnsignedInteger,
        OutputScalar: UnsignedInteger + CastFrom<InputScalar>,
        KeyCont: Container<Element = OutputScalar>,
        InputCont: Container<Element = InputScalar>,
        OutputCont: ContainerMut<Element = OutputScalar>,
    > LweKeyswitch<LweCiphertext<InputCont>, LweCiphertext<OutputCont>>
    for LweKeyswitchKey<KeyCont>
{
    type SideResources = ();

    /// Keyswitch `input` into `output` with this key.
    ///
    /// - When `InputScalar` and `OutputScalar` are the same type (checked via `TypeId`), the
    ///   input is re-wrapped as a ciphertext over `OutputScalar` and passed to
    ///   `keyswitch_lwe_ciphertext`.
    /// - Otherwise `keyswitch_lwe_ciphertext_with_scalar_change` is used. That function panics
    ///   unless `InputScalar` is strictly wider than `OutputScalar`.
    fn lwe_keyswitch(
        &self,
        input: &LweCiphertext<InputCont>,
        output: &mut LweCiphertext<OutputCont>,
        _side_resources: &mut Self::SideResources,
    ) {
        if TypeId::of::<InputScalar>() == TypeId::of::<OutputScalar>() {
            // Same scalar type: reinterpret the input data instead of converting it.
            let input_content = input.as_ref();
            // NOTE(review): `id` (fft64::math::fft) presumably views `&[InputScalar]` as
            // `&[OutputScalar]` and is only sound because of the TypeId equality above —
            // confirm its contract. `try_to().unwrap()` converts the modulus to the output
            // scalar type and panics if that conversion fails.
            let input_as_output_scalar = LweCiphertext::from_container(
                id(input_content),
                input.ciphertext_modulus().try_to().unwrap(),
            );
            keyswitch_lwe_ciphertext(self, &input_as_output_scalar, output);
        } else {
            keyswitch_lwe_ciphertext_with_scalar_change(self, input, output);
        }
    }
}