use std::{
ffi::CStr,
iter::{IntoIterator, Iterator},
mem,
os::raw::{c_char, c_int, c_uint, c_ulonglong, c_void},
panic,
ptr,
};
use tokio::{
runtime::{Builder, Runtime},
time::Duration,
};
use crate::{
client::{Client, ClientError, Task},
mask::model::{FromPrimitives, IntoPrimitives, Model},
};
/// A raw, C-compatible view of a contiguous buffer of primitive model values.
///
/// The buffer behind `ptr` is owned by the client (it is the client's cached model); callers
/// must not free it. It stays valid until the client's cached model is replaced (by
/// `new_model`/`get_model`) or taken/dropped (by `update_model`/`drop_model`/`drop_client`).
/// The functions in this file signal failure with a null `ptr`.
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct PrimitiveModel {
/// Pointer to the first element of the buffer, or null on failure.
pub ptr: *mut c_void,
/// Number of elements (not bytes) in the buffer.
pub len: c_ulonglong,
/// Element type: `1` = `f32`, `2` = `f64`, `3` = `i32`, `4` = `i64`; `0` for the failure value.
pub dtype: c_uint,
}
/// Client-owned buffer backing the `PrimitiveModel` most recently handed out by `new_model` or
/// `get_model`, with one variant per supported `dtype` (`1` = `F32` … `4` = `I64`).
#[derive(Clone, Debug)]
pub(crate) enum CachedModel {
F32(Vec<f32>),
F64(Vec<f64>),
I32(Vec<i32>),
I64(Vec<i64>),
}
/// Opaque handle given to C callers: a xaynet `Client` together with the tokio `Runtime` that
/// `run_client` uses to drive it. Created by `new_client`, destroyed by `drop_client`.
pub struct FFIClient {
client: Client,
runtime: Runtime,
}
#[allow(unused_unsafe)]
#[allow(clippy::unnecessary_cast)]
#[no_mangle]
pub unsafe extern "C" fn new_client(address: *const c_char, period: c_ulonglong) -> *mut FFIClient {
if address.is_null() || period == 0 {
return ptr::null_mut() as *mut FFIClient;
}
let address = if let Ok(address) = unsafe {
CStr::from_ptr(address)
}
.to_str()
{
address
} else {
return ptr::null_mut() as *mut FFIClient;
};
let runtime = if let Ok(runtime) = Builder::new()
.threaded_scheduler()
.core_threads(1)
.max_threads(4)
.thread_name("xaynet-client-runtime-worker")
.enable_all()
.build()
{
runtime
} else {
return ptr::null_mut() as *mut FFIClient;
};
let client = if let Ok(client) =
runtime.enter(move || Client::new_with_addr(period as u64, 0, address))
{
client
} else {
return ptr::null_mut() as *mut FFIClient;
};
Box::into_raw(Box::new(FFIClient { runtime, client }))
}
/// Runs the client to completion on its own runtime, blocking the calling thread.
///
/// Return codes:
/// - `-1`: `client` is null
/// - `0`: the client finished successfully
/// - `1`: the client panicked (the panic is caught and not propagated to the caller)
/// - `2`..=`9`: `ClientError::{ParticipantInitErr, ParticipantErr, DeserialiseErr, NetworkErr,
///   ParseErr, GeneralErr, Fetch, PetMessage}` respectively
///
/// # Safety
/// `client` must be null or a live pointer obtained from `new_client`.
#[allow(unused_unsafe)]
#[allow(clippy::unnecessary_cast)]
#[no_mangle]
pub unsafe extern "C" fn run_client(client: *mut FFIClient) -> c_int {
    let ffi_client = match unsafe { client.as_mut() } {
        Some(ffi_client) => ffi_client,
        None => return -1_i32 as c_int,
    };
    let runtime = &ffi_client.runtime;
    let inner = &mut ffi_client.client;

    // Unwinding into the C caller is not allowed, so a panic is captured and mapped to code 1.
    let outcome = panic::catch_unwind(panic::AssertUnwindSafe(|| {
        runtime.handle().block_on(inner.start())
    }));

    let code: i32 = match outcome {
        Ok(Ok(_)) => 0,
        Err(_) => 1,
        Ok(Err(error)) => match error {
            ClientError::ParticipantInitErr(_) => 2,
            ClientError::ParticipantErr(_) => 3,
            ClientError::DeserialiseErr(_) => 4,
            ClientError::NetworkErr(_) => 5,
            ClientError::ParseErr => 6,
            ClientError::GeneralErr => 7,
            ClientError::Fetch(_) => 8,
            ClientError::PetMessage(_) => 9,
        },
    };
    code as c_int
}
/// Destroys a client previously created by `new_client`. A null `client` is ignored.
///
/// If `timeout` is non-zero, the client's runtime is shut down with
/// `Runtime::shutdown_timeout(Duration::from_secs(timeout))`; otherwise the runtime is simply
/// dropped together with the rest of the handle.
///
/// # Safety
/// `client` must be null or a pointer obtained from `new_client` that has not already been
/// passed to `drop_client`. It must not be used after this call.
#[allow(unused_unsafe)]
#[allow(clippy::unnecessary_cast)]
#[no_mangle]
pub unsafe extern "C" fn drop_client(client: *mut FFIClient, timeout: c_ulonglong) {
    if client.is_null() {
        return;
    }
    // SAFETY: `client` is non-null and was produced by `Box::into_raw` in `new_client`.
    let client = unsafe { Box::from_raw(client) };
    // Compare the full 64-bit value: casting to `usize` first would truncate on 32-bit targets
    // and silently treat timeouts such as 2^32 seconds as zero.
    if timeout != 0 {
        client
            .runtime
            .shutdown_timeout(Duration::from_secs(timeout as u64));
    }
}
/// Returns the client's `has_new_coord_pk_since_last_check` flag and resets it to `false`.
///
/// Returns `false` for a null `client`.
///
/// # Safety
/// `client` must be null or a live pointer obtained from `new_client`.
#[allow(unused_unsafe)]
#[no_mangle]
pub unsafe extern "C" fn is_next_round(client: *mut FFIClient) -> bool {
    match unsafe { client.as_mut() } {
        Some(ffi_client) => mem::replace(
            &mut ffi_client.client.has_new_coord_pk_since_last_check,
            false,
        ),
        None => false,
    }
}
/// Returns the client's `has_new_global_model_since_last_check` flag and resets it to `false`.
///
/// Returns `false` for a null `client`.
///
/// # Safety
/// `client` must be null or a live pointer obtained from `new_client`.
#[allow(unused_unsafe)]
#[no_mangle]
pub unsafe extern "C" fn has_next_model(client: *mut FFIClient) -> bool {
    match unsafe { client.as_mut() } {
        Some(ffi_client) => mem::replace(
            &mut ffi_client.client.has_new_global_model_since_last_check,
            false,
        ),
        None => false,
    }
}
/// Returns whether the client's participant currently has the `Task::Update` task.
///
/// Returns `false` for a null `client`. Unlike `is_next_round`, this does not reset any state.
///
/// # Safety
/// `client` must be null or a live pointer obtained from `new_client`.
#[allow(unused_unsafe)]
#[no_mangle]
pub unsafe extern "C" fn is_update_participant(client: *mut FFIClient) -> bool {
    unsafe { client.as_ref() }
        .map_or(false, |ffi_client| ffi_client.client.participant.task == Task::Update)
}
#[allow(unused_unsafe)]
#[allow(clippy::unnecessary_cast)]
#[no_mangle]
pub unsafe extern "C" fn new_model(
client: *mut FFIClient,
dtype: c_uint,
len: c_ulonglong,
) -> PrimitiveModel {
let max_len = match dtype {
1 | 3 => isize::MAX / 4,
2 | 4 => isize::MAX / 8,
_ => 0,
} as c_ulonglong;
if client.is_null() || dtype == 0 || dtype > 4 || len == 0 || len > max_len {
return PrimitiveModel {
ptr: ptr::null_mut() as *mut c_void,
len: 0_u64 as c_ulonglong,
dtype: 0_u32 as c_uint,
};
}
let client = unsafe {
&mut (*client).client
};
let ptr = match dtype {
1 => {
let mut cached_model = vec![0_f32; len as usize];
let ptr = cached_model.as_mut_ptr() as *mut c_void;
client.cached_model = Some(CachedModel::F32(cached_model));
ptr
}
2 => {
let mut cached_model = vec![0_f64; len as usize];
let ptr = cached_model.as_mut_ptr() as *mut c_void;
client.cached_model = Some(CachedModel::F64(cached_model));
ptr
}
3 => {
let mut cached_model = vec![0_i32; len as usize];
let ptr = cached_model.as_mut_ptr() as *mut c_void;
client.cached_model = Some(CachedModel::I32(cached_model));
ptr
}
4 => {
let mut cached_model = vec![0_i64; len as usize];
let ptr = cached_model.as_mut_ptr() as *mut c_void;
client.cached_model = Some(CachedModel::I64(cached_model));
ptr
}
_ => unreachable!(),
};
PrimitiveModel { ptr, len, dtype }
}
#[allow(unused_unsafe)]
#[allow(clippy::unnecessary_cast)]
#[no_mangle]
pub unsafe extern "C" fn get_model(client: *mut FFIClient, dtype: c_uint) -> PrimitiveModel {
if client.is_null() || dtype == 0 || dtype > 4 {
return PrimitiveModel {
ptr: ptr::null_mut() as *mut c_void,
len: 0_u64 as c_ulonglong,
dtype: 0_u32 as c_uint,
};
}
let client = unsafe {
&mut (*client).client
};
if let Some(ref global_model) = client.global_model {
if !client.has_new_global_model_since_last_cache {
match dtype {
1 => {
if let Some(CachedModel::F32(ref mut cached_model)) = client.cached_model {
return PrimitiveModel {
ptr: cached_model.as_mut_ptr() as *mut c_void,
len: cached_model.len() as c_ulonglong,
dtype,
};
}
}
2 => {
if let Some(CachedModel::F64(ref mut cached_model)) = client.cached_model {
return PrimitiveModel {
ptr: cached_model.as_mut_ptr() as *mut c_void,
len: cached_model.len() as c_ulonglong,
dtype,
};
}
}
3 => {
if let Some(CachedModel::I32(ref mut cached_model)) = client.cached_model {
return PrimitiveModel {
ptr: cached_model.as_mut_ptr() as *mut c_void,
len: cached_model.len() as c_ulonglong,
dtype,
};
}
}
4 => {
if let Some(CachedModel::I64(ref mut cached_model)) = client.cached_model {
return PrimitiveModel {
ptr: cached_model.as_mut_ptr() as *mut c_void,
len: cached_model.len() as c_ulonglong,
dtype,
};
}
}
_ => unreachable!(),
}
}
client.has_new_global_model_since_last_cache = false;
let len = global_model.len() as c_ulonglong;
let ptr = match dtype {
1 => {
if let Ok(mut cached_model) = global_model
.to_primitives()
.map(|res| res.map_err(|_| ()))
.collect::<Result<Vec<f32>, ()>>()
{
let ptr = cached_model.as_mut_ptr() as *mut c_void;
client.cached_model = Some(CachedModel::F32(cached_model));
ptr
} else {
client.cached_model = None;
ptr::null_mut() as *mut c_void
}
}
2 => {
if let Ok(mut cached_model) = global_model
.to_primitives()
.map(|res| res.map_err(|_| ()))
.collect::<Result<Vec<f64>, ()>>()
{
let ptr = cached_model.as_mut_ptr() as *mut c_void;
client.cached_model = Some(CachedModel::F64(cached_model));
ptr
} else {
client.cached_model = None;
ptr::null_mut() as *mut c_void
}
}
3 => {
if let Ok(mut cached_model) = global_model
.to_primitives()
.map(|res| res.map_err(|_| ()))
.collect::<Result<Vec<i32>, ()>>()
{
let ptr = cached_model.as_mut_ptr() as *mut c_void;
client.cached_model = Some(CachedModel::I32(cached_model));
ptr
} else {
client.cached_model = None;
ptr::null_mut() as *mut c_void
}
}
4 => {
if let Ok(mut cached_model) = global_model
.to_primitives()
.map(|res| res.map_err(|_| ()))
.collect::<Result<Vec<i64>, ()>>()
{
let ptr = cached_model.as_mut_ptr() as *mut c_void;
client.cached_model = Some(CachedModel::I64(cached_model));
ptr
} else {
client.cached_model = None;
ptr::null_mut() as *mut c_void
}
}
_ => unreachable!(),
};
return PrimitiveModel { ptr, len, dtype };
}
client.cached_model = None;
PrimitiveModel {
ptr: ptr::null_mut() as *mut c_void,
len: 0_u64 as c_ulonglong,
dtype: 0_u32 as c_uint,
}
}
/// Moves the client's cached primitive model into its local model.
///
/// The cached buffer is consumed, so any pointer previously returned by `new_model` or
/// `get_model` becomes invalid.
///
/// Return codes:
/// - `-1`: `client` is null
/// - `0`: the local model was set
/// - `1`: there was no cached model (the local model is left unchanged)
///
/// # Safety
/// `client` must be null or a live pointer obtained from `new_client`.
#[allow(unused_unsafe)]
#[allow(clippy::unnecessary_cast)]
#[no_mangle]
pub unsafe extern "C" fn update_model(client: *mut FFIClient) -> c_int {
    let ffi_client = match unsafe { client.as_mut() } {
        Some(ffi_client) => ffi_client,
        None => return -1_i32 as c_int,
    };
    let inner = &mut ffi_client.client;
    let local_model = match inner.cached_model.take() {
        None => return 1_i32 as c_int,
        Some(CachedModel::F32(values)) => Model::from_primitives_bounded(values.into_iter()),
        Some(CachedModel::F64(values)) => Model::from_primitives_bounded(values.into_iter()),
        Some(CachedModel::I32(values)) => Model::from_primitives_bounded(values.into_iter()),
        Some(CachedModel::I64(values)) => Model::from_primitives_bounded(values.into_iter()),
    };
    inner.local_model = Some(local_model);
    0_i32 as c_int
}
/// Discards the client's cached primitive model, if any; a null `client` is ignored.
///
/// Any pointer previously returned by `new_model` or `get_model` becomes invalid.
///
/// # Safety
/// `client` must be null or a live pointer obtained from `new_client`.
#[allow(unused_unsafe)]
#[no_mangle]
pub unsafe extern "C" fn drop_model(client: *mut FFIClient) {
    if let Some(ffi_client) = unsafe { client.as_mut() } {
        ffi_client.client.cached_model = None;
    }
}
pub use self::dart::*;

/// Variants of the boolean query functions for Dart callers. Each one returns the `bool`
/// result as a `c_uint`: `1` for `true`, `0` for `false`.
mod dart {
    use std::os::raw::c_uint;

    /// `is_next_round`, returning `1`/`0` instead of `true`/`false`.
    ///
    /// # Safety
    /// Same contract as `is_next_round`.
    #[allow(unused_unsafe)]
    #[no_mangle]
    #[doc(hidden)]
    pub unsafe extern "C" fn is_next_round_dart(client: *mut super::FFIClient) -> c_uint {
        c_uint::from(unsafe { super::is_next_round(client) })
    }

    /// `has_next_model`, returning `1`/`0` instead of `true`/`false`.
    ///
    /// # Safety
    /// Same contract as `has_next_model`.
    #[allow(unused_unsafe)]
    #[no_mangle]
    #[doc(hidden)]
    pub unsafe extern "C" fn has_next_model_dart(client: *mut super::FFIClient) -> c_uint {
        c_uint::from(unsafe { super::has_next_model(client) })
    }

    /// `is_update_participant`, returning `1`/`0` instead of `true`/`false`.
    ///
    /// # Safety
    /// Same contract as `is_update_participant`.
    #[allow(unused_unsafe)]
    #[no_mangle]
    #[doc(hidden)]
    pub unsafe extern "C" fn is_update_participant_dart(client: *mut super::FFIClient) -> c_uint {
        c_uint::from(unsafe { super::is_update_participant(client) })
    }
}
// Unit tests for the FFI functions. Every test creates a client for the address "0.0.0.0:0000".
// `test_run_client` expects `run_client` to fail with code 5 (`ClientError::NetworkErr`) for it.
#[cfg(test)]
mod tests {
use std::{ffi::CString, iter::FromIterator};
use num::rational::Ratio;
use super::*;
// `new_client` returns a non-null handle for a valid address and a non-zero period.
#[test]
fn test_new_client() {
let client = unsafe { new_client(CString::new("0.0.0.0:0000").unwrap().as_ptr(), 10) };
assert!(!client.is_null());
unsafe { drop_client(client, 0) };
}
// `run_client` reports code 5 (`ClientError::NetworkErr`) for this address.
#[test]
fn test_run_client() {
let client = unsafe { new_client(CString::new("0.0.0.0:0000").unwrap().as_ptr(), 10) };
assert_eq!(unsafe { run_client(client) }, 5);
unsafe { drop_client(client, 0) };
}
// Builds a model of `len` copies of `val` (as `Ratio::from_float(val)`).
fn dummy_model(val: f64, len: usize) -> Model {
Model::from_iter(vec![Ratio::from_float(val).unwrap(); len].into_iter())
}
// Generates `test_new_model_<prim>`: `new_model` must cache a zeroed buffer of the requested
// type and return a view (ptr, len, dtype) of exactly that buffer.
macro_rules! test_new_model {
($prim:ty, $dtype:expr) => {
paste::item! {
#[allow(unused_unsafe)]
#[test]
fn [<test_new_model_ $prim>]() {
let client = unsafe { new_client(CString::new("0.0.0.0:0000").unwrap().as_ptr(), 10) };
let model = dummy_model(0., 10);
let prim_model = unsafe { new_model(client, $dtype as c_uint, 10 as c_ulonglong) };
if let Some(CachedModel::[<$prim:upper>](ref cached_model)) = unsafe { &mut *client }.client.cached_model {
assert_eq!(prim_model.ptr, cached_model.as_ptr() as *mut c_void);
assert_eq!(prim_model.len, cached_model.len() as c_ulonglong);
assert_eq!(prim_model.dtype, $dtype as c_uint);
assert_eq!(model, Model::from_primitives_bounded(cached_model.iter().cloned()));
} else {
panic!();
}
unsafe { drop_client(client, 0) };
}
}
};
}
test_new_model!(f32, 1);
test_new_model!(f64, 2);
test_new_model!(i32, 3);
test_new_model!(i64, 4);
// Generates `test_get_model_<prim>`. Without a global model, `get_model` returns the failure
// value and clears the cache. Once a global model is set, it returns a view of the freshly
// converted cached buffer.
macro_rules! test_get_model {
($prim:ty, $dtype:expr) => {
paste::item! {
#[allow(unused_unsafe)]
#[test]
fn [<test_get_model_ $prim>]() {
let client = unsafe { new_client(CString::new("0.0.0.0:0000").unwrap().as_ptr(), 10) };
assert!(unsafe { &*client }.client.global_model.is_none());
let prim_model = unsafe { get_model(client, $dtype as c_uint) };
assert!(unsafe { &*client }.client.cached_model.is_none());
assert!(prim_model.ptr.is_null());
assert_eq!(prim_model.len, 0);
assert_eq!(prim_model.dtype, 0);
let model = dummy_model(0., 10);
unsafe { &mut *client }.client.global_model = Some(model.clone());
let prim_model = unsafe { get_model(client, $dtype as c_uint) };
if let Some(CachedModel::[<$prim:upper>](ref cached_model)) = unsafe { &mut *client }.client.cached_model {
assert_eq!(prim_model.ptr, cached_model.as_ptr() as *mut c_void);
assert_eq!(prim_model.len, cached_model.len() as c_ulonglong);
assert_eq!(prim_model.dtype, $dtype as c_uint);
assert_eq!(model, Model::from_primitives_bounded(cached_model.iter().cloned()));
} else {
panic!();
}
unsafe { drop_client(client, 0) };
}
}
};
}
test_get_model!(f32, 1);
test_get_model!(f64, 2);
test_get_model!(i32, 3);
test_get_model!(i64, 4);
// Generates `test_update_model_<prim>`. After `get_model`, `update_model` must consume the
// cached buffer (leaving `cached_model` as `None`). It must also set `local_model` equal to
// the original global model.
macro_rules! test_update_model {
($prim:ty, $dtype:expr) => {
paste::item! {
#[test]
fn [<test_update_model_ $prim>]() {
let client = unsafe { new_client(CString::new("0.0.0.0:0000").unwrap().as_ptr(), 10) };
let model = dummy_model(0., 10);
unsafe { &mut *client }.client.global_model = Some(model.clone());
let prim_model = unsafe { get_model(client, $dtype as c_uint) };
if let Some(CachedModel::[<$prim:upper>](ref cached_model)) = unsafe { &mut *client }.client.cached_model {
assert_eq!(prim_model.ptr, cached_model.as_ptr() as *mut c_void);
assert_eq!(prim_model.len, cached_model.len() as c_ulonglong);
assert_eq!(prim_model.dtype, $dtype as c_uint);
} else {
panic!();
}
assert!(unsafe { &*client }.client.local_model.is_none());
assert_eq!(unsafe { update_model(client) }, 0);
assert!(unsafe { &mut *client }.client.cached_model.is_none());
if let Some(ref local_model) = unsafe { &*client }.client.local_model {
assert_eq!(&model, local_model);
} else {
panic!();
}
unsafe { drop_client(client, 0) };
}
}
};
}
test_update_model!(f32, 1);
test_update_model!(f64, 2);
test_update_model!(i32, 3);
test_update_model!(i64, 4);
}