use super::super::utils::{catch_panic, get_mut_checked};
use crate::c_api::high_level_api::keys::ServerKey;
use crate::c_api::utils::get_ref_checked;
use rayon::{ThreadPool, ThreadPoolBuilder};
use std::ffi::{c_int, c_void};
/// Threading context handed to C callers as a raw pointer.
///
/// It wraps a rayon [`ThreadPool`]. Instances are heap-allocated by
/// `tfhe_threading_context_create` and must be released with
/// `tfhe_threading_context_destroy`.
pub struct TfheThreadingContext {
// Pool used by `tfhe_threading_context_set_server_key` (broadcast to its threads)
// and by `tfhe_threading_context_run` (runs the user callback inside it).
pool: ThreadPool,
}
/// Creates a threading context backed by its own rayon thread pool.
///
/// * `num_threads` - forwarded to `ThreadPoolBuilder::num_threads`; rayon
///   documents `0` as "let the runtime pick the number of threads".
/// * `context` - out-parameter. It is first reset to null, then, on success,
///   set to a heap-allocated [`TfheThreadingContext`] that the caller must
///   release with `tfhe_threading_context_destroy`.
///
/// Returns the status code produced by `catch_panic` (presumably `0` on
/// success and non-zero when the closure panicked - confirm against its
/// definition). A null `context` or a pool construction failure is reported
/// through that status code.
///
/// # Safety
///
/// `context` must be null (reported as an error) or point to memory that is
/// valid and properly aligned for writing a `*mut TfheThreadingContext`.
#[no_mangle]
pub unsafe extern "C" fn tfhe_threading_context_create(
    num_threads: usize,
    context: *mut *mut TfheThreadingContext,
) -> c_int {
    catch_panic(|| {
        // Writing through a null out-pointer is undefined behaviour that
        // `catch_panic` cannot intercept, so turn it into a panic (and thus
        // an error status) before the first write.
        assert!(!context.is_null(), "`context` out-pointer must not be null");

        // Start from a known state: if building the pool panics below, the
        // caller observes a null context instead of a stale pointer.
        *context = std::ptr::null_mut();

        let pool = ThreadPoolBuilder::new()
            .num_threads(num_threads)
            .build()
            .unwrap();

        // Ownership moves to the caller; `tfhe_threading_context_destroy`
        // turns the raw pointer back into a `Box` to free it.
        *context = Box::into_raw(Box::new(TfheThreadingContext { pool }));
    })
}
/// Releases a threading context previously returned through
/// `tfhe_threading_context_create`.
///
/// Passing a null pointer is a no-op.
///
/// # Safety
///
/// A non-null `context` must come from `Box::into_raw` (as done by the create
/// function) and must not be used or destroyed again after this call.
#[no_mangle]
pub unsafe extern "C" fn tfhe_threading_context_destroy(context: *mut TfheThreadingContext) {
    // This function returns nothing, so the status from `catch_panic` is
    // intentionally discarded.
    let _ = catch_panic(|| {
        if !context.is_null() {
            // Rebuild the owning `Box` so the context is dropped and freed.
            drop(Box::from_raw(context));
        }
    });
}
/// Sets the given server key on the threads of the context's pool.
///
/// The closure passed to `ThreadPool::broadcast` calls
/// `crate::high_level_api::set_server_key` with its own clone of the key;
/// rayon documents `broadcast` as running the closure on every thread of the
/// pool.
///
/// Returns the status code produced by `catch_panic`; an invalid
/// `server_key` or `context` pointer makes the corresponding `unwrap` panic
/// inside it.
///
/// # Safety
///
/// `context` and `server_key` must each be null/invalid in a way the checked
/// accessors detect, or point to live objects of the expected type.
#[no_mangle]
pub unsafe extern "C" fn tfhe_threading_context_set_server_key(
    context: *mut TfheThreadingContext,
    server_key: *const ServerKey,
) -> c_int {
    catch_panic(|| {
        // Validate the key pointer first, then the context pointer.
        let inner_key = &get_ref_checked(server_key).unwrap().0;
        let threading_ctx = get_mut_checked(context).unwrap();

        // Each invocation gets its own clone because `set_server_key` is
        // handed an owned key.
        threading_ctx
            .pool
            .broadcast(|_| crate::high_level_api::set_server_key(inner_key.clone()));
    })
}
/// Runs a user-provided C callback inside the context's thread pool.
///
/// * `context` - threading context whose pool executes the callback.
/// * `func` - callback invoked once with `data` via `ThreadPool::install`.
/// * `data` - opaque pointer forwarded untouched to `func`.
///
/// Returns the non-zero status from `catch_panic` if one was reported (for
/// example when `context` fails validation); otherwise returns the value
/// produced by `func`.
///
/// # Safety
///
/// `context` must be a valid context pointer, `func` must be a valid function
/// pointer, and the caller guarantees that `func` and `data` may be used from
/// another thread.
#[no_mangle]
pub unsafe extern "C" fn tfhe_threading_context_run(
    context: *mut TfheThreadingContext,
    func: extern "C" fn(*mut c_void) -> c_int,
    data: *mut c_void,
) -> c_int {
    // Raw pointers are not `Send`; these wrappers move the thread-safety
    // obligation onto the C caller, as their names spell out.
    struct TheUserEnsuresDataIsThreadSafe(*mut c_void);
    unsafe impl Send for TheUserEnsuresDataIsThreadSafe {}

    struct TheUserEnsuresTheFuncIsThreadSafe(extern "C" fn(*mut c_void) -> c_int);
    unsafe impl Send for TheUserEnsuresTheFuncIsThreadSafe {}

    // `&mut self` is kept on purpose even though nothing is mutated; going
    // through a method makes the closure below capture the whole wrapper
    // (which is `Send`) rather than its inner field.
    #[allow(clippy::needless_pass_by_ref_mut)]
    impl TheUserEnsuresTheFuncIsThreadSafe {
        fn execute(&mut self, data: &TheUserEnsuresDataIsThreadSafe) -> c_int {
            (self.0)(data.0)
        }
    }

    let mut callback = TheUserEnsuresTheFuncIsThreadSafe(func);
    let payload = TheUserEnsuresDataIsThreadSafe(data);

    // Stays 0 unless the callback actually ran and returned something else.
    let mut callback_status = 0;
    let panic_status = catch_panic(|| {
        let threading_ctx = get_mut_checked(context).unwrap();
        callback_status = threading_ctx
            .pool
            .install(move || callback.execute(&payload));
    });

    // A failure reported by `catch_panic` takes precedence over the
    // callback's own return value.
    match panic_status {
        0 => callback_status,
        failure => failure,
    }
}