use rand::Rng;
use crate::core_crypto::gpu::get_number_of_gpus;
use crate::high_level_api::global_state::CustomMultiGpuIndexes;
use crate::prelude::*;
use crate::{
clear_gpu_thread_locals, set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, Device,
FheUint32, FheUint8, GpuIndex,
};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon::ThreadPoolBuilder;
#[test]
fn test_drop_rayon_pool_with_gpu_server_key_thread_locals() {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};

    let config = ConfigBuilder::default().build();
    let cks = ClientKey::generate(config);
    let num_gpus = get_number_of_gpus() as usize;
    let compressed_sks = CompressedServerKey::new(&cks);
    // One decompressed server key per GPU, indexed by GPU number.
    let sks_vec: Vec<_> = (0..num_gpus)
        .map(|i| compressed_sks.decompress_to_specific_gpu(GpuIndex::new(i as u32)))
        .collect();

    // Incremented by each worker once its exit handler has cleared the GPU thread locals.
    let exited_workers = Arc::new(AtomicUsize::new(0));
    let exit_counter = Arc::clone(&exited_workers);
    let pool = ThreadPoolBuilder::new()
        .num_threads(4 * num_gpus)
        .exit_handler(move |_| {
            clear_gpu_thread_locals();
            exit_counter.fetch_add(1, Ordering::SeqCst);
        })
        .build()
        .unwrap();
    // Read the actual worker count back from the pool instead of recomputing `4 * num_gpus`:
    // rayon picks its own thread count when asked for 0 threads.
    let num_workers = pool.current_num_threads();

    let results: Vec<u8> = pool.install(|| {
        (0..4 * num_gpus)
            .into_par_iter()
            .map_init(
                // Give the current worker a server key, spreading workers across GPUs by
                // their pool index.
                || {
                    let gpu_index = rayon::current_thread_index().unwrap_or(0) % num_gpus;
                    set_server_key(sks_vec[gpu_index].clone());
                },
                |(), _| {
                    let ct = FheUint8::encrypt_trivial(42u8);
                    let result: u8 = ct.decrypt(&cks);
                    result
                },
            )
            .collect()
    });
    for val in &results {
        assert_eq!(*val, 42u8);
    }

    // Dropping a rayon `ThreadPool` only signals its workers to terminate; it does not join
    // them. Wait until every worker has run the exit handler, so the thread-local cleanup is
    // actually exercised (and any failure surfaces) within this test.
    drop(pool);
    let deadline = Instant::now() + Duration::from_secs(60);
    while exited_workers.load(Ordering::SeqCst) < num_workers {
        assert!(
            Instant::now() < deadline,
            "rayon workers did not run their exit handler before the timeout"
        );
        std::thread::sleep(Duration::from_millis(10));
    }
}
#[test]
fn test_gpu_selection() {
    /// Asserts that `ct` is stored on `device` and is attached to exactly the GPUs in `gpus`.
    fn assert_placement(ct: &FheUint32, device: Device, gpus: &[GpuIndex]) {
        assert_eq!(ct.current_device(), device);
        assert_eq!(ct.gpu_indexes(), gpus);
    }

    let client_key = ClientKey::generate(ConfigBuilder::default().build());
    let compressed_sks = CompressedServerKey::new(&client_key);
    let mut rng = rand::thread_rng();
    // The highest-numbered GPU is the target for every operation in this test.
    let target_gpu = GpuIndex::new(get_number_of_gpus() - 1);

    let lhs_clear: u32 = rng.gen();
    let rhs_clear: u32 = rng.gen();
    let expected_sum = lhs_clear.wrapping_add(rhs_clear);
    let mut lhs = FheUint32::try_encrypt(lhs_clear, &client_key).unwrap();
    let mut rhs = FheUint32::try_encrypt(rhs_clear, &client_key).unwrap();

    // Freshly encrypted values start on the CPU and are not bound to any GPU.
    assert_placement(&lhs, Device::Cpu, &[]);
    assert_placement(&rhs, Device::Cpu, &[]);

    set_server_key(compressed_sks.decompress_to_specific_gpu(target_gpu));

    // Operating on CPU inputs with a GPU key active: the result ends up on the target GPU,
    // while the inputs themselves stay on the CPU.
    let sum = &lhs + &rhs;
    let decrypted_sum: u32 = sum.decrypt(&client_key);
    assert_placement(&sum, Device::CudaGpu, &[target_gpu]);
    assert_eq!(decrypted_sum, expected_sum);
    assert_placement(&lhs, Device::Cpu, &[]);
    assert_placement(&rhs, Device::Cpu, &[]);

    // Explicitly migrating the inputs places them on the same GPU as the active key.
    lhs.move_to_current_device();
    rhs.move_to_current_device();
    assert_placement(&lhs, Device::CudaGpu, &[target_gpu]);
    assert_placement(&rhs, Device::CudaGpu, &[target_gpu]);

    let sum = &lhs + &rhs;
    let decrypted_sum: u32 = sum.decrypt(&client_key);
    assert_placement(&sum, Device::CudaGpu, &[target_gpu]);
    assert_eq!(decrypted_sum, expected_sum);
}
#[test]
fn test_gpu_selection_2() {
    /// Asserts that `ct` is stored on `device` and is attached to exactly the GPUs in `gpus`.
    fn assert_placement(ct: &FheUint32, device: Device, gpus: &[GpuIndex]) {
        assert_eq!(ct.current_device(), device);
        assert_eq!(ct.gpu_indexes(), gpus);
    }

    let gpu_count = get_number_of_gpus();
    // Switching keys between two distinct GPUs requires at least two of them.
    if gpu_count < 2 {
        return;
    }

    let client_key = ClientKey::generate(ConfigBuilder::default().build());
    let compressed_sks = CompressedServerKey::new(&client_key);
    let mut rng = rand::thread_rng();
    let first_gpu = GpuIndex::new(0);
    let last_gpu = GpuIndex::new(gpu_count - 1);

    let lhs_clear: u32 = rng.gen();
    let rhs_clear: u32 = rng.gen();
    let expected_sum = lhs_clear.wrapping_add(rhs_clear);
    let mut lhs = FheUint32::try_encrypt(lhs_clear, &client_key).unwrap();
    let mut rhs = FheUint32::try_encrypt(rhs_clear, &client_key).unwrap();
    assert_placement(&lhs, Device::Cpu, &[]);
    assert_placement(&rhs, Device::Cpu, &[]);

    // Put both inputs on the last GPU and compute there.
    set_server_key(compressed_sks.decompress_to_specific_gpu(last_gpu));
    lhs.move_to_current_device();
    rhs.move_to_current_device();
    assert_placement(&lhs, Device::CudaGpu, &[last_gpu]);
    assert_placement(&rhs, Device::CudaGpu, &[last_gpu]);
    let sum_on_last = &lhs + &rhs;

    // Switching the active key to the first GPU must not relocate an existing result.
    set_server_key(compressed_sks.decompress_to_specific_gpu(first_gpu));
    let decrypted_sum: u32 = sum_on_last.decrypt(&client_key);
    assert_placement(&sum_on_last, Device::CudaGpu, &[last_gpu]);
    assert_eq!(decrypted_sum, expected_sum);

    // A new operation on the last-GPU inputs produces its result on the newly active GPU.
    let sum_on_first = &lhs + &rhs;
    let decrypted_sum: u32 = sum_on_first.decrypt(&client_key);
    assert_placement(&sum_on_first, Device::CudaGpu, &[first_gpu]);
    assert_eq!(decrypted_sum, expected_sum);
}
#[test]
fn test_specific_gpu_selection() {
    /// Asserts that `ct` is stored on `device` and is attached to exactly the GPUs in `gpus`.
    fn assert_placement(ct: &FheUint32, device: Device, gpus: &[GpuIndex]) {
        assert_eq!(ct.current_device(), device);
        assert_eq!(ct.gpu_indexes(), gpus);
    }

    let config = ConfigBuilder::default().build();
    let keys = ClientKey::generate(config);
    let compressed_server_keys = CompressedServerKey::new(&keys);
    let mut rng = rand::thread_rng();
    let total_gpus = get_number_of_gpus();

    // Exhaustively try every non-empty subset of the available GPUs. `gpu_mask` is a bitmask,
    // not a count: bit `j` set means GPU `j` belongs to the subset.
    for gpu_mask in 1u64..(1u64 << total_gpus) {
        let selected_gpus: Vec<GpuIndex> = (0..total_gpus)
            .filter(|&gpu| (gpu_mask >> gpu) & 1 == 1)
            .map(GpuIndex::new)
            .collect();
        // Indices are produced in increasing order and the mask is non-zero, so this is the
        // lowest-numbered GPU of the subset.
        let first_gpu = selected_gpus[0];
        let cuda_key = compressed_server_keys
            .decompress_to_specific_gpu(CustomMultiGpuIndexes::new(selected_gpus));

        let clear_a: u32 = rng.gen();
        let clear_b: u32 = rng.gen();
        let expected = clear_a.wrapping_add(clear_b);
        let mut a = FheUint32::try_encrypt(clear_a, &keys).unwrap();
        let mut b = FheUint32::try_encrypt(clear_b, &keys).unwrap();
        assert_placement(&a, Device::Cpu, &[]);
        assert_placement(&b, Device::Cpu, &[]);

        set_server_key(cuda_key);

        // With CPU inputs, the result is placed on the subset's first GPU only, and the
        // inputs stay on the CPU.
        let c = &a + &b;
        let decrypted: u32 = c.decrypt(&keys);
        assert_placement(&c, Device::CudaGpu, &[first_gpu]);
        assert_eq!(decrypted, expected);
        assert_placement(&a, Device::Cpu, &[]);
        assert_placement(&b, Device::Cpu, &[]);

        // Migrated inputs also land on the subset's first GPU.
        a.move_to_current_device();
        b.move_to_current_device();
        assert_placement(&a, Device::CudaGpu, &[first_gpu]);
        assert_placement(&b, Device::CudaGpu, &[first_gpu]);
        let c = &a + &b;
        let decrypted: u32 = c.decrypt(&keys);
        assert_placement(&c, Device::CudaGpu, &[first_gpu]);
        assert_eq!(decrypted, expected);

        // NOTE(review): presumably releases per-thread GPU state tied to this subset before the
        // next iteration installs a key for a different subset — confirm.
        clear_gpu_thread_locals();
    }
}
}