use crate::backends::cuda::engines::SharedMemoryAmount;
use crate::backends::cuda::private::crypto::glwe::list::CudaGlweList;
use crate::backends::cuda::private::crypto::lwe::list::CudaLweList;
use crate::backends::cuda::private::device::{CudaStream, GpuIndex, NumberOfGpus};
use crate::backends::cuda::private::vec::CudaVec;
use crate::backends::cuda::private::{compute_number_of_samples_on_gpu, number_of_active_gpus};
use crate::commons::crypto::bootstrap::StandardBootstrapKey;
use crate::commons::math::tensor::{AsRefSlice, AsRefTensor};
use crate::commons::numeric::UnsignedInteger;
use crate::prelude::{
CiphertextCount, DecompositionBaseLog, DecompositionLevelCount, GlweDimension,
LweCiphertextIndex, LweDimension, PolynomialSize,
};
use std::marker::PhantomData;
/// A bootstrap key that has been uploaded to one or more GPUs.
///
/// The key material lives in `d_vecs` as `f64` device buffers, one per GPU. The remaining
/// fields carry the parameters needed to launch bootstraps with it. NOTE(review): the
/// `f64` storage and the twiddle initialization done when it is used suggest a
/// Fourier-domain representation — confirm against the CUDA conversion routine.
#[derive(Debug)]
pub(crate) struct CudaBootstrapKey<T: UnsignedInteger> {
    /// Device buffers holding the converted key; entry `i` is used together with the
    /// stream at index `i`.
    pub(crate) d_vecs: Vec<CudaVec<f64>>,
    /// LWE dimension of the ciphertexts given as input to the bootstrap.
    pub(crate) input_lwe_dimension: LweDimension,
    /// Polynomial size of the GLWE ciphertexts making up the key.
    pub(crate) polynomial_size: PolynomialSize,
    /// GLWE dimension of the ciphertexts making up the key.
    pub(crate) glwe_dimension: GlweDimension,
    /// Number of levels of the gadget decomposition.
    pub(crate) decomp_level: DecompositionLevelCount,
    /// Base logarithm of the gadget decomposition.
    pub(crate) decomp_base_log: DecompositionBaseLog,
    /// Ties the key to the ciphertext integer type `T` without storing a `T`.
    pub(crate) _phantom: PhantomData<T>,
}

// SAFETY: NOTE(review) — these impls are hand-written because the auto traits are not
// derived for this type (presumably because `CudaVec` holds a raw device pointer). They
// assert that moving / sharing the key between host threads is sound whenever `T` is;
// confirm that `CudaVec` does not rely on any thread-local CUDA state.
unsafe impl<T: Send + UnsignedInteger> Send for CudaBootstrapKey<T> {}
unsafe impl<T: Sync + UnsignedInteger> Sync for CudaBootstrapKey<T> {}
/// Uploads a CPU-side standard bootstrap key to every GPU reachable through `streams`.
///
/// For each stream, in order, this function does three things:
/// * it allocates a device buffer of
///   `key_size * glwe_size * glwe_size * level_count * polynomial_size` `f64` elements;
/// * it initializes the stream's twiddles for the key's polynomial size;
/// * it asks the stream to convert the host key into that buffer.
///
/// The returned vector holds one buffer per stream, in the same order as `streams`.
///
/// `number_of_gpus` is only used as a capacity hint: every stream in `streams` receives a
/// copy of the key. NOTE(review): presumably callers pass `streams.len() == number_of_gpus.0`
/// — confirm.
///
/// # Panics
///
/// Panics if the required number of `f64` elements does not fit in a `u32`. That is the
/// element-count type handed to `malloc` here.
///
/// # Safety
///
/// This function drives CUDA operations through `streams`. The caller must ensure every
/// stream is valid for allocation, twiddle initialization and key conversion.
/// NOTE(review): the exact contract is that of the underlying `CudaStream` methods.
pub(crate) unsafe fn convert_lwe_bootstrap_key_from_cpu_to_gpu<T: UnsignedInteger, Cont>(
    streams: &[CudaStream],
    input: &StandardBootstrapKey<Cont>,
    number_of_gpus: NumberOfGpus,
) -> Vec<CudaVec<f64>>
where
    Cont: AsRefSlice<Element = T>,
{
    let mut vecs = Vec::with_capacity(number_of_gpus.0);
    let total_polynomials =
        input.key_size().0 * input.glwe_size().0 * input.glwe_size().0 * input.level_count().0;
    let alloc_size = total_polynomials * input.polynomial_size().0;
    // `malloc` takes a `u32` element count. Truncating the size with a bare `as` cast would
    // allocate a buffer too small for the conversion below, so refuse instead.
    assert!(
        alloc_size <= u32::MAX as usize,
        "bootstrap key of {} f64 elements is too large for a single device allocation",
        alloc_size
    );
    let alloc_size = alloc_size as u32;
    // The host-side key data is identical for every GPU, so borrow it once.
    let input_slice = input.as_tensor().as_slice();
    for stream in streams.iter() {
        let mut d_vec = stream.malloc::<f64>(alloc_size);
        stream.initialize_twiddles(input.polynomial_size());
        stream.convert_lwe_bootstrap_key::<T>(
            &mut d_vec,
            input_slice,
            input.key_size(),
            input.glwe_size().to_glwe_dimension(),
            input.level_count(),
            input.polynomial_size(),
        );
        vecs.push(d_vec);
    }
    vecs
}
/// Bootstraps every LWE ciphertext of `input` into `output` with the low-latency CUDA
/// bootstrap, spreading the ciphertexts over the active GPUs.
///
/// The number of GPUs actually used is derived from `number_of_available_gpus` and the
/// ciphertext count, and only that many leading entries of `streams` are used. On the
/// GPU with index `i`, this function does four things:
///
/// * it computes the GPU's share of ciphertexts with `compute_number_of_samples_on_gpu`;
/// * it uploads a device buffer of test-vector indexes `0..share`;
/// * it initializes the stream's twiddles for the key's polynomial size;
/// * it launches the bootstrap on the `i`-th device buffer of `output`, `acc`, `input`
///   and `bsk`, with ciphertext offset `share_of_gpu_0 * i`.
///
/// NOTE(review): the offset assumes every GPU before `i` holds exactly as many ciphertexts
/// as GPU 0. Also, the index buffer restarts at 0 on each GPU. Confirm both against
/// `compute_number_of_samples_on_gpu` and the kernel's use of the offset.
///
/// # Panics
///
/// Panics if `output`, `acc`, `input` or `bsk` has no device buffer for one of the active
/// GPU indexes.
///
/// # Safety
///
/// Launches CUDA work through `streams`. The caller must ensure the streams and every
/// device buffer involved are valid and consistent with the given parameters.
/// NOTE(review): the exact requirements are those of the underlying `CudaStream` methods.
pub(crate) unsafe fn execute_lwe_ciphertext_vector_low_latency_bootstrap_on_gpu<
    T: UnsignedInteger,
>(
    streams: &[CudaStream],
    output: &mut CudaLweList<T>,
    input: &CudaLweList<T>,
    acc: &CudaGlweList<T>,
    bsk: &CudaBootstrapKey<T>,
    number_of_available_gpus: NumberOfGpus,
    cuda_shared_memory: SharedMemoryAmount,
) {
    let ciphertext_count = input.lwe_ciphertext_count.0;
    let active_gpus = number_of_active_gpus(
        number_of_available_gpus,
        CiphertextCount(ciphertext_count),
    );
    // Distance, in ciphertexts, between the first inputs of two consecutive GPUs.
    let stride = compute_number_of_samples_on_gpu(
        active_gpus,
        CiphertextCount(ciphertext_count),
        GpuIndex(0),
    )
    .0;

    for (gpu_index, stream) in streams.iter().take(active_gpus.0).enumerate() {
        let samples = compute_number_of_samples_on_gpu(
            active_gpus,
            CiphertextCount(ciphertext_count),
            GpuIndex(gpu_index),
        );
        let sample_count = samples.0 as u32;

        let host_lut_indexes: Vec<u32> = (0..sample_count).collect();
        let mut device_lut_indexes = stream.malloc::<u32>(sample_count);
        stream.copy_to_gpu(&mut device_lut_indexes, &host_lut_indexes);
        stream.initialize_twiddles(bsk.polynomial_size);

        let lwe_out = output.d_vecs.get_mut(gpu_index).unwrap();
        let accumulator = acc.d_vecs.get(gpu_index).unwrap();
        let lwe_in = input.d_vecs.get(gpu_index).unwrap();
        let key = bsk.d_vecs.get(gpu_index).unwrap();
        stream.discard_bootstrap_low_latency_lwe_ciphertext_vector::<T>(
            lwe_out,
            accumulator,
            &device_lut_indexes,
            lwe_in,
            key,
            input.lwe_dimension,
            bsk.glwe_dimension,
            bsk.polynomial_size,
            bsk.decomp_base_log,
            bsk.decomp_level,
            samples,
            LweCiphertextIndex(stride * gpu_index),
            cuda_shared_memory,
        );
    }
}
/// Bootstraps every LWE ciphertext of `input` into `output` with the amortized CUDA
/// bootstrap, spreading the ciphertexts over the active GPUs.
///
/// The number of GPUs actually used is derived from `number_of_available_gpus` and the
/// ciphertext count, and only that many leading entries of `streams` are used. On the
/// GPU with index `i`, this function does four things:
///
/// * it computes the GPU's share of ciphertexts with `compute_number_of_samples_on_gpu`;
/// * it uploads a device buffer of test-vector indexes `0..share`;
/// * it initializes the stream's twiddles for the key's polynomial size;
/// * it launches the bootstrap on the `i`-th device buffer of `output`, `acc`, `input`
///   and `bsk`, with ciphertext offset `share_of_gpu_0 * i`.
///
/// NOTE(review): the offset assumes every GPU before `i` holds exactly as many ciphertexts
/// as GPU 0. Also, the index buffer restarts at 0 on each GPU. Confirm both against
/// `compute_number_of_samples_on_gpu` and the kernel's use of the offset.
///
/// # Panics
///
/// Panics if `output`, `acc`, `input` or `bsk` has no device buffer for one of the active
/// GPU indexes.
///
/// # Safety
///
/// Launches CUDA work through `streams`. The caller must ensure the streams and every
/// device buffer involved are valid and consistent with the given parameters.
/// NOTE(review): the exact requirements are those of the underlying `CudaStream` methods.
pub(crate) unsafe fn execute_lwe_ciphertext_vector_amortized_bootstrap_on_gpu<
    T: UnsignedInteger,
>(
    streams: &[CudaStream],
    output: &mut CudaLweList<T>,
    input: &CudaLweList<T>,
    acc: &CudaGlweList<T>,
    bsk: &CudaBootstrapKey<T>,
    number_of_available_gpus: NumberOfGpus,
    cuda_shared_memory: SharedMemoryAmount,
) {
    let ciphertext_count = input.lwe_ciphertext_count.0;
    let active_gpus = number_of_active_gpus(
        number_of_available_gpus,
        CiphertextCount(ciphertext_count),
    );
    // Distance, in ciphertexts, between the first inputs of two consecutive GPUs.
    let stride = compute_number_of_samples_on_gpu(
        active_gpus,
        CiphertextCount(ciphertext_count),
        GpuIndex(0),
    )
    .0;

    for (gpu_index, stream) in streams.iter().take(active_gpus.0).enumerate() {
        let samples = compute_number_of_samples_on_gpu(
            active_gpus,
            CiphertextCount(ciphertext_count),
            GpuIndex(gpu_index),
        );
        let sample_count = samples.0 as u32;

        let host_lut_indexes: Vec<u32> = (0..sample_count).collect();
        let mut device_lut_indexes = stream.malloc::<u32>(sample_count);
        stream.copy_to_gpu(&mut device_lut_indexes, &host_lut_indexes);
        stream.initialize_twiddles(bsk.polynomial_size);

        let lwe_out = output.d_vecs.get_mut(gpu_index).unwrap();
        let accumulator = acc.d_vecs.get(gpu_index).unwrap();
        let lwe_in = input.d_vecs.get(gpu_index).unwrap();
        let key = bsk.d_vecs.get(gpu_index).unwrap();
        stream.discard_bootstrap_amortized_lwe_ciphertext_vector::<T>(
            lwe_out,
            accumulator,
            &device_lut_indexes,
            lwe_in,
            key,
            input.lwe_dimension,
            bsk.glwe_dimension,
            bsk.polynomial_size,
            bsk.decomp_base_log,
            bsk.decomp_level,
            samples,
            LweCiphertextIndex(stride * gpu_index),
            cuda_shared_memory,
        );
    }
}