use crate::c_api::utils::*;
use std::os::raw::c_int;
use super::{ShortintCiphertext, ShortintServerKey};
pub type AccumulatorCallback = Option<extern "C" fn(u64) -> u64>;
pub type BivariateAccumulatorCallback = Option<extern "C" fn(u64, u64) -> u64>;
pub struct ShortintPBSAccumulator(pub(in crate::c_api) crate::shortint::server_key::Accumulator);
pub struct ShortintBivariatePBSAccumulator(
pub(in crate::c_api) crate::shortint::server_key::Accumulator,
);
#[no_mangle]
pub unsafe extern "C" fn shortint_server_key_generate_pbs_accumulator(
server_key: *const ShortintServerKey,
accumulator_callback: AccumulatorCallback,
result: *mut *mut ShortintPBSAccumulator,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
*result = std::ptr::null_mut();
let accumulator_callback = accumulator_callback.unwrap();
let server_key = get_ref_checked(server_key).unwrap();
#[allow(clippy::redundant_closure)]
let heap_allocated_accumulator = Box::new(ShortintPBSAccumulator(
server_key
.0
.generate_accumulator(|x: u64| accumulator_callback(x)),
));
*result = Box::into_raw(heap_allocated_accumulator);
})
}
#[no_mangle]
pub unsafe extern "C" fn shortint_server_key_programmable_bootstrap(
server_key: *const ShortintServerKey,
accumulator: *const ShortintPBSAccumulator,
ct_in: *const ShortintCiphertext,
result: *mut *mut ShortintCiphertext,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
*result = std::ptr::null_mut();
let server_key = get_ref_checked(server_key).unwrap();
let accumulator = get_ref_checked(accumulator).unwrap();
let ct_in = get_ref_checked(ct_in).unwrap();
let heap_allocated_result = Box::new(ShortintCiphertext(
server_key
.0
.keyswitch_programmable_bootstrap(&ct_in.0, &accumulator.0),
));
*result = Box::into_raw(heap_allocated_result);
})
}
#[no_mangle]
pub unsafe extern "C" fn shortint_server_key_programmable_bootstrap_assign(
server_key: *const ShortintServerKey,
accumulator: *const ShortintPBSAccumulator,
ct_in_and_result: *mut ShortintCiphertext,
) -> c_int {
catch_panic(|| {
let server_key = get_ref_checked(server_key).unwrap();
let accumulator = get_ref_checked(accumulator).unwrap();
let ct_in_and_result = get_mut_checked(ct_in_and_result).unwrap();
server_key
.0
.keyswitch_programmable_bootstrap_assign(&mut ct_in_and_result.0, &accumulator.0);
})
}
#[no_mangle]
pub unsafe extern "C" fn shortint_server_key_generate_bivariate_pbs_accumulator(
server_key: *const ShortintServerKey,
accumulator_callback: BivariateAccumulatorCallback,
result: *mut *mut ShortintBivariatePBSAccumulator,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
*result = std::ptr::null_mut();
let accumulator_callback = accumulator_callback.unwrap();
let server_key = get_ref_checked(server_key).unwrap();
#[allow(clippy::redundant_closure)]
let heap_allocated_accumulator = Box::new(ShortintBivariatePBSAccumulator(
server_key
.0
.generate_accumulator_bivariate(|x: u64, y: u64| accumulator_callback(x, y)),
));
*result = Box::into_raw(heap_allocated_accumulator);
})
}
#[no_mangle]
pub unsafe extern "C" fn shortint_server_key_bivariate_programmable_bootstrap(
server_key: *const ShortintServerKey,
accumulator: *const ShortintBivariatePBSAccumulator,
ct_left: *const ShortintCiphertext,
ct_right: *mut ShortintCiphertext,
result: *mut *mut ShortintCiphertext,
) -> c_int {
catch_panic(|| {
check_ptr_is_non_null_and_aligned(result).unwrap();
*result = std::ptr::null_mut();
let server_key = get_ref_checked(server_key).unwrap();
let accumulator = get_ref_checked(accumulator).unwrap();
let ct_left = get_ref_checked(ct_left).unwrap();
let ct_right = get_mut_checked(ct_right).unwrap();
let heap_allocated_result = Box::new(ShortintCiphertext(
crate::shortint::engine::ShortintEngine::with_thread_local_mut(|engine| {
engine
.smart_bivariate_pbs(&server_key.0, &ct_left.0, &mut ct_right.0, &accumulator.0)
.unwrap()
}),
));
*result = Box::into_raw(heap_allocated_result);
})
}
#[no_mangle]
pub unsafe extern "C" fn shortint_server_key_bivariate_programmable_bootstrap_assign(
server_key: *const ShortintServerKey,
accumulator: *const ShortintBivariatePBSAccumulator,
ct_left_and_result: *mut ShortintCiphertext,
ct_right: *mut ShortintCiphertext,
) -> c_int {
catch_panic(|| {
let server_key = get_ref_checked(server_key).unwrap();
let accumulator = get_ref_checked(accumulator).unwrap();
let ct_left_and_result = get_mut_checked(ct_left_and_result).unwrap();
let ct_right = get_mut_checked(ct_right).unwrap();
crate::shortint::engine::ShortintEngine::with_thread_local_mut(|engine| {
engine
.smart_bivariate_pbs_assign(
&server_key.0,
&mut ct_left_and_result.0,
&mut ct_right.0,
&accumulator.0,
)
.unwrap()
});
})
}