tfhe/core_crypto/gpu/algorithms/
lwe_multi_bit_programmable_bootstrapping.rs1use crate::core_crypto::gpu::entities::glwe_ciphertext_list::CudaGlweCiphertextList;
2use crate::core_crypto::gpu::entities::lwe_ciphertext_list::CudaLweCiphertextList;
3use crate::core_crypto::gpu::entities::lwe_multi_bit_bootstrap_key::CudaLweMultiBitBootstrapKey;
4use crate::core_crypto::gpu::vec::CudaVec;
5use crate::core_crypto::gpu::{programmable_bootstrap_multi_bit_async, CudaStreams};
6use crate::core_crypto::prelude::{CastInto, UnsignedTorus};
7
8#[allow(clippy::too_many_arguments)]
13pub unsafe fn cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_async<Scalar>(
14 input: &CudaLweCiphertextList<Scalar>,
15 output: &mut CudaLweCiphertextList<Scalar>,
16 accumulator: &CudaGlweCiphertextList<Scalar>,
17 lut_indexes: &CudaVec<Scalar>,
18 output_indexes: &CudaVec<Scalar>,
19 input_indexes: &CudaVec<Scalar>,
20 multi_bit_bsk: &CudaLweMultiBitBootstrapKey,
21 streams: &CudaStreams,
22) where
23 Scalar: UnsignedTorus + CastInto<usize>,
25{
26 assert_eq!(
27 input.lwe_dimension(),
28 multi_bit_bsk.input_lwe_dimension(),
29 "Mismatched input LweDimension. LweCiphertext input LweDimension {:?}. \
30 FourierLweMultiBitBootstrapKey input LweDimension {:?}.",
31 input.lwe_dimension(),
32 multi_bit_bsk.input_lwe_dimension(),
33 );
34
35 assert_eq!(
36 output.lwe_dimension(),
37 multi_bit_bsk.output_lwe_dimension(),
38 "Mismatched output LweDimension. LweCiphertext output LweDimension {:?}. \
39 FourierLweMultiBitBootstrapKey output LweDimension {:?}.",
40 output.lwe_dimension(),
41 multi_bit_bsk.output_lwe_dimension(),
42 );
43
44 assert_eq!(
45 accumulator.glwe_dimension(),
46 multi_bit_bsk.glwe_dimension(),
47 "Mismatched GlweSize. Accumulator GlweSize {:?}. \
48 FourierLweMultiBitBootstrapKey GlweSize {:?}.",
49 accumulator.glwe_dimension(),
50 multi_bit_bsk.glwe_dimension(),
51 );
52
53 assert_eq!(
54 accumulator.polynomial_size(),
55 multi_bit_bsk.polynomial_size(),
56 "Mismatched PolynomialSize. Accumulator PolynomialSize {:?}. \
57 FourierLweMultiBitBootstrapKey PolynomialSize {:?}.",
58 accumulator.polynomial_size(),
59 multi_bit_bsk.polynomial_size(),
60 );
61
62 assert_eq!(
63 input.ciphertext_modulus(),
64 output.ciphertext_modulus(),
65 "Mismatched CiphertextModulus between input ({:?}) and output ({:?})",
66 input.ciphertext_modulus(),
67 output.ciphertext_modulus(),
68 );
69
70 assert_eq!(
71 input.ciphertext_modulus(),
72 accumulator.ciphertext_modulus(),
73 "Mismatched CiphertextModulus between input ({:?}) and accumulator ({:?})",
74 input.ciphertext_modulus(),
75 accumulator.ciphertext_modulus(),
76 );
77 assert_eq!(
78 streams.gpu_indexes[0],
79 multi_bit_bsk.d_vec.gpu_index(0),
80 "GPU error: first stream is on GPU {}, first bsk pointer is on GPU {}",
81 streams.gpu_indexes[0].get(),
82 multi_bit_bsk.d_vec.gpu_index(0).get(),
83 );
84 assert_eq!(
85 streams.gpu_indexes[0],
86 input.0.d_vec.gpu_index(0),
87 "GPU error: first stream is on GPU {}, first input pointer is on GPU {}",
88 streams.gpu_indexes[0].get(),
89 input.0.d_vec.gpu_index(0).get(),
90 );
91 assert_eq!(
92 streams.gpu_indexes[0],
93 output.0.d_vec.gpu_index(0),
94 "GPU error: first stream is on GPU {}, first output pointer is on GPU {}",
95 streams.gpu_indexes[0].get(),
96 output.0.d_vec.gpu_index(0).get(),
97 );
98 assert_eq!(
99 streams.gpu_indexes[0],
100 accumulator.0.d_vec.gpu_index(0),
101 "GPU error: first stream is on GPU {}, first accumulator pointer is on GPU {}",
102 streams.gpu_indexes[0].get(),
103 accumulator.0.d_vec.gpu_index(0).get(),
104 );
105 assert_eq!(
106 streams.gpu_indexes[0],
107 input_indexes.gpu_index(0),
108 "GPU error: first stream is on GPU {}, first input indexes pointer is on GPU {}",
109 streams.gpu_indexes[0].get(),
110 input_indexes.gpu_index(0).get(),
111 );
112 assert_eq!(
113 streams.gpu_indexes[0],
114 output_indexes.gpu_index(0),
115 "GPU error: first stream is on GPU {}, first output indexes pointer is on GPU {}",
116 streams.gpu_indexes[0].get(),
117 output_indexes.gpu_index(0).get(),
118 );
119 assert_eq!(
120 streams.gpu_indexes[0],
121 lut_indexes.gpu_index(0),
122 "GPU error: first stream is on GPU {}, first lut indexes pointer is on GPU {}",
123 streams.gpu_indexes[0].get(),
124 lut_indexes.gpu_index(0).get(),
125 );
126
127 programmable_bootstrap_multi_bit_async(
128 streams,
129 &mut output.0.d_vec,
130 output_indexes,
131 &accumulator.0.d_vec,
132 lut_indexes,
133 &input.0.d_vec,
134 input_indexes,
135 &multi_bit_bsk.d_vec,
136 input.lwe_dimension(),
137 multi_bit_bsk.glwe_dimension(),
138 multi_bit_bsk.polynomial_size(),
139 multi_bit_bsk.decomp_base_log(),
140 multi_bit_bsk.decomp_level_count(),
141 multi_bit_bsk.grouping_factor(),
142 input.lwe_ciphertext_count().0 as u32,
143 );
144}
145
146#[allow(clippy::too_many_arguments)]
147pub fn cuda_multi_bit_programmable_bootstrap_lwe_ciphertext<Scalar>(
148 input: &CudaLweCiphertextList<Scalar>,
149 output: &mut CudaLweCiphertextList<Scalar>,
150 accumulator: &CudaGlweCiphertextList<Scalar>,
151 lut_indexes: &CudaVec<Scalar>,
152 output_indexes: &CudaVec<Scalar>,
153 input_indexes: &CudaVec<Scalar>,
154 multi_bit_bsk: &CudaLweMultiBitBootstrapKey,
155 streams: &CudaStreams,
156) where
157 Scalar: UnsignedTorus + CastInto<usize>,
159{
160 unsafe {
161 cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_async(
162 input,
163 output,
164 accumulator,
165 lut_indexes,
166 output_indexes,
167 input_indexes,
168 multi_bit_bsk,
169 streams,
170 );
171 }
172 streams.synchronize();
173}