// tfhe/core_crypto/gpu/algorithms/lwe_multi_bit_programmable_bootstrapping.rs

use crate::core_crypto::gpu::entities::glwe_ciphertext_list::CudaGlweCiphertextList;
use crate::core_crypto::gpu::entities::lwe_ciphertext_list::CudaLweCiphertextList;
use crate::core_crypto::gpu::entities::lwe_multi_bit_bootstrap_key::CudaLweMultiBitBootstrapKey;
use crate::core_crypto::gpu::vec::CudaVec;
use crate::core_crypto::gpu::{programmable_bootstrap_multi_bit_async, CudaStreams};
use crate::core_crypto::prelude::{CastInto, UnsignedTorus};

8/// # Safety
9///
10/// - `streams` __must__ be synchronized to guarantee computation has finished, and inputs must not
11///   be dropped until streams is synchronised
12#[allow(clippy::too_many_arguments)]
13pub unsafe fn cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_async<Scalar>(
14    input: &CudaLweCiphertextList<Scalar>,
15    output: &mut CudaLweCiphertextList<Scalar>,
16    accumulator: &CudaGlweCiphertextList<Scalar>,
17    lut_indexes: &CudaVec<Scalar>,
18    output_indexes: &CudaVec<Scalar>,
19    input_indexes: &CudaVec<Scalar>,
20    multi_bit_bsk: &CudaLweMultiBitBootstrapKey,
21    streams: &CudaStreams,
22) where
23    // CastInto required for PBS modulus switch which returns a usize
24    Scalar: UnsignedTorus + CastInto<usize>,
25{
26    assert_eq!(
27        input.lwe_dimension(),
28        multi_bit_bsk.input_lwe_dimension(),
29        "Mismatched input LweDimension. LweCiphertext input LweDimension {:?}. \
30        FourierLweMultiBitBootstrapKey input LweDimension {:?}.",
31        input.lwe_dimension(),
32        multi_bit_bsk.input_lwe_dimension(),
33    );
34
35    assert_eq!(
36        output.lwe_dimension(),
37        multi_bit_bsk.output_lwe_dimension(),
38        "Mismatched output LweDimension. LweCiphertext output LweDimension {:?}. \
39        FourierLweMultiBitBootstrapKey output LweDimension {:?}.",
40        output.lwe_dimension(),
41        multi_bit_bsk.output_lwe_dimension(),
42    );
43
44    assert_eq!(
45        accumulator.glwe_dimension(),
46        multi_bit_bsk.glwe_dimension(),
47        "Mismatched GlweSize. Accumulator GlweSize {:?}. \
48        FourierLweMultiBitBootstrapKey GlweSize {:?}.",
49        accumulator.glwe_dimension(),
50        multi_bit_bsk.glwe_dimension(),
51    );
52
53    assert_eq!(
54        accumulator.polynomial_size(),
55        multi_bit_bsk.polynomial_size(),
56        "Mismatched PolynomialSize. Accumulator PolynomialSize {:?}. \
57        FourierLweMultiBitBootstrapKey PolynomialSize {:?}.",
58        accumulator.polynomial_size(),
59        multi_bit_bsk.polynomial_size(),
60    );
61
62    assert_eq!(
63        input.ciphertext_modulus(),
64        output.ciphertext_modulus(),
65        "Mismatched CiphertextModulus between input ({:?}) and output ({:?})",
66        input.ciphertext_modulus(),
67        output.ciphertext_modulus(),
68    );
69
70    assert_eq!(
71        input.ciphertext_modulus(),
72        accumulator.ciphertext_modulus(),
73        "Mismatched CiphertextModulus between input ({:?}) and accumulator ({:?})",
74        input.ciphertext_modulus(),
75        accumulator.ciphertext_modulus(),
76    );
77    assert_eq!(
78        streams.gpu_indexes[0],
79        multi_bit_bsk.d_vec.gpu_index(0),
80        "GPU error: first stream is on GPU {}, first bsk pointer is on GPU {}",
81        streams.gpu_indexes[0].get(),
82        multi_bit_bsk.d_vec.gpu_index(0).get(),
83    );
84    assert_eq!(
85        streams.gpu_indexes[0],
86        input.0.d_vec.gpu_index(0),
87        "GPU error: first stream is on GPU {}, first input pointer is on GPU {}",
88        streams.gpu_indexes[0].get(),
89        input.0.d_vec.gpu_index(0).get(),
90    );
91    assert_eq!(
92        streams.gpu_indexes[0],
93        output.0.d_vec.gpu_index(0),
94        "GPU error: first stream is on GPU {}, first output pointer is on GPU {}",
95        streams.gpu_indexes[0].get(),
96        output.0.d_vec.gpu_index(0).get(),
97    );
98    assert_eq!(
99        streams.gpu_indexes[0],
100        accumulator.0.d_vec.gpu_index(0),
101        "GPU error: first stream is on GPU {}, first accumulator pointer is on GPU {}",
102        streams.gpu_indexes[0].get(),
103        accumulator.0.d_vec.gpu_index(0).get(),
104    );
105    assert_eq!(
106        streams.gpu_indexes[0],
107        input_indexes.gpu_index(0),
108        "GPU error: first stream is on GPU {}, first input indexes pointer is on GPU {}",
109        streams.gpu_indexes[0].get(),
110        input_indexes.gpu_index(0).get(),
111    );
112    assert_eq!(
113        streams.gpu_indexes[0],
114        output_indexes.gpu_index(0),
115        "GPU error: first stream is on GPU {}, first output indexes pointer is on GPU {}",
116        streams.gpu_indexes[0].get(),
117        output_indexes.gpu_index(0).get(),
118    );
119    assert_eq!(
120        streams.gpu_indexes[0],
121        lut_indexes.gpu_index(0),
122        "GPU error: first stream is on GPU {}, first lut indexes pointer is on GPU {}",
123        streams.gpu_indexes[0].get(),
124        lut_indexes.gpu_index(0).get(),
125    );
126
127    programmable_bootstrap_multi_bit_async(
128        streams,
129        &mut output.0.d_vec,
130        output_indexes,
131        &accumulator.0.d_vec,
132        lut_indexes,
133        &input.0.d_vec,
134        input_indexes,
135        &multi_bit_bsk.d_vec,
136        input.lwe_dimension(),
137        multi_bit_bsk.glwe_dimension(),
138        multi_bit_bsk.polynomial_size(),
139        multi_bit_bsk.decomp_base_log(),
140        multi_bit_bsk.decomp_level_count(),
141        multi_bit_bsk.grouping_factor(),
142        input.lwe_ciphertext_count().0 as u32,
143    );
144}
145
146#[allow(clippy::too_many_arguments)]
147pub fn cuda_multi_bit_programmable_bootstrap_lwe_ciphertext<Scalar>(
148    input: &CudaLweCiphertextList<Scalar>,
149    output: &mut CudaLweCiphertextList<Scalar>,
150    accumulator: &CudaGlweCiphertextList<Scalar>,
151    lut_indexes: &CudaVec<Scalar>,
152    output_indexes: &CudaVec<Scalar>,
153    input_indexes: &CudaVec<Scalar>,
154    multi_bit_bsk: &CudaLweMultiBitBootstrapKey,
155    streams: &CudaStreams,
156) where
157    // CastInto required for PBS modulus switch which returns a usize
158    Scalar: UnsignedTorus + CastInto<usize>,
159{
160    unsafe {
161        cuda_multi_bit_programmable_bootstrap_lwe_ciphertext_async(
162            input,
163            output,
164            accumulator,
165            lut_indexes,
166            output_indexes,
167            input_indexes,
168            multi_bit_bsk,
169            streams,
170        );
171    }
172    streams.synchronize();
173}