#![feature(fn_traits, unboxed_closures, min_specialization)]
#[macro_use]
pub use paste::paste;
#[macro_use]
pub use lazy_static::lazy_static;
pub use rmp_serde;
pub use uniffi::ffi::RustBuffer;
use core::pin::Pin;
pub use ray_rs_sys::*;
pub use ray_rs_sys::ray;
pub mod remote_functions;
pub use remote_functions::*;
pub use std::ffi::CString;
use std::{collections::HashMap, sync::Mutex, os::raw::c_char, ops::Deref};
use libloading::{Library, Symbol};
/// Signature of the per-function invoker symbols looked up in loaded libraries:
/// it receives the `rmp_serde`-encoded argument pointers/sizes as a `RustBuffer` and
/// returns a `RustBuffer` whose bytes become the task's return value.
type InvokerFunction = extern "C" fn(RustBuffer) -> RustBuffer;
/// Cache of already-resolved invoker symbols, keyed by function name. The `'static`
/// lifetime is obtained by transmuting the library-bound `Symbol`, so entries are only
/// valid while the owning `Library` stays loaded.
type FunctionPtrMap = HashMap<CString, Symbol<'static, InvokerFunction>>;
/// Logs a `format!`-style message through `util::log_internal`, prefixed with
/// `[rust]` and the file and line of the macro invocation.
///
/// `util` is resolved at the call site, so it must be in scope wherever this is used.
#[macro_export]
macro_rules! ray_info {
    ($($arg:tt)*) => {
        util::log_internal(format!(
            "[rust] {}:{}: {}",
            file!(),
            line!(),
            format_args!($($arg)*)
        ));
    };
}
lazy_static::lazy_static! {
static ref GLOBAL_FUNCTION_MAP: Mutex<FunctionPtrMap> = {
Mutex::new(HashMap::new())
};
}
lazy_static::lazy_static! {
static ref LIBRARIES: Mutex<Vec<Library>> = {
Mutex::new(Vec::new())
};
}
/// Loads the shared libraries named by every `--ray_code_search_path=` argument.
///
/// `argc`/`argv` are the C-style argument count and vector. For each argument that
/// starts with `--ray_code_search_path=`, the remainder is split on `:` and the
/// non-empty entries are handed to [`load_libraries_from_paths`]. All other arguments
/// are ignored, even when they are not valid UTF-8. A null `argv` or negative `argc`
/// is treated as an empty argument list.
///
/// # Panics
///
/// Panics if a `--ray_code_search_path=` value is not valid UTF-8, or (via
/// [`load_libraries_from_paths`]) if a listed library fails to load.
///
/// Although this is a safe `fn`, the caller must ensure that a non-null `argv` points
/// to at least `argc` entries, each either null or a valid NUL-terminated C string.
pub fn load_code_paths_from_cmdline(argc: i32, argv: *mut *mut c_char) {
    const FLAG: &[u8] = b"--ray_code_search_path=";

    // `slice::from_raw_parts` must never receive a null pointer (even for length 0),
    // and a negative `argc` would wrap to an enormous `usize` under an `as` cast.
    let count = match usize::try_from(argc) {
        Ok(n) if !argv.is_null() => n,
        _ => return,
    };
    // SAFETY: `argv` is non-null and, per the caller contract, has `count` entries.
    let raw_args = unsafe { std::slice::from_raw_parts(argv, count) };
    for &raw in raw_args {
        if raw.is_null() {
            continue;
        }
        // SAFETY: non-null entries are NUL-terminated C strings per the caller contract.
        let arg = unsafe { std::ffi::CStr::from_ptr(raw) }.to_bytes();
        // Match on raw bytes so unrelated non-UTF-8 arguments don't abort the worker.
        let Some(value) = arg.strip_prefix(FLAG) else {
            continue;
        };
        let value = std::str::from_utf8(value).unwrap_or_else(|err| {
            panic!("--ray_code_search_path value is not valid UTF-8: {}", err)
        });
        // Skip empty entries (a trailing `:` or an empty flag value) rather than asking
        // the dynamic loader to open an empty path.
        let paths: Vec<&str> = value.split(':').filter(|p| !p.is_empty()).collect();
        load_libraries_from_paths(&paths);
    }
}
pub fn load_libraries_from_paths(paths: &Vec<&str>) {
let mut libs = LIBRARIES.lock().unwrap();
for path in paths {
match unsafe { Library::new(path).ok() } {
Some(lib) => libs.push(lib),
None => panic!("Shared-object library not found at path: {}", path),
}
}
}
/// Executes one Rust task on behalf of the Ray core worker.
///
/// `ray_function_info.data` is read as a pointer to a `char*` naming the invoker
/// symbol. Each entry of `args` is a `DataValue*`. The pointer and size of every
/// argument buffer are encoded with `rmp_serde` into a `RustBuffer` and passed to the
/// invoker. The bytes it returns are wrapped in a newly allocated `DataValue`, whose
/// `data` is stored into the first slot of `return_values`. `_task_type` is ignored.
///
/// # Panics
///
/// Panics if the symbol cannot be found in any loaded library, or if `return_values`
/// provides no slot for the result.
pub extern "C" fn rust_worker_execute(
    _task_type: RayInt,
    ray_function_info: RaySlice,
    args: RaySlice,
    return_values: RaySlice,
) {
    // SAFETY: the caller passes `args` as an array of `args.len` valid DataValue pointers.
    let arg_values = unsafe { ray_slice_data_values(&args) };
    let (arg_ptrs, arg_sizes): (Vec<u64>, Vec<u64>) = arg_values
        .iter()
        // SAFETY: each DataValue and its `data` buffer are valid for the call's duration.
        .map(|&arg| unsafe { ((*(*arg).data).p as u64, (*(*arg).data).size as u64) })
        .unzip();
    let args_buffer =
        RustBuffer::from_vec(rmp_serde::to_vec(&(&arg_ptrs, &arg_sizes)).unwrap());

    // Borrow the name as a `CStr`: `CString::from_raw` is only sound for pointers that
    // came from `CString::into_raw`, which a caller-owned C string did not.
    // SAFETY: `ray_function_info.data` points at a NUL-terminated function-name pointer.
    let fn_name = unsafe {
        std::ffi::CStr::from_ptr(*(ray_function_info.data as *const *const c_char))
    };
    let name = fn_name.to_string_lossy();

    // The invoker is copied out as a plain fn pointer so no global lock is held while
    // user code runs (a held lock would serialize tasks and be poisoned on panic).
    let func = resolve_invoker(fn_name)
        .unwrap_or_else(|| panic!("Could not find symbol for fn of name {}", name));
    ray_info!("Executing: {}", name);
    let ret = func(args_buffer);
    ray_info!("Executed: {}", name);

    // The Vec's buffer is deliberately not freed here; it is handed to
    // `c_worker_AllocateDataValue`. NOTE(review): confirm whether that call copies the
    // bytes (then this leaks) or takes ownership of them.
    let mut ret_owned = std::mem::ManuallyDrop::new(ret.destroy_into_vec());
    // SAFETY: `return_values` holds `return_values.len` valid DataValue pointers, and the
    // freshly allocated `dv_ptr` is a valid DataValue returned by the core worker.
    unsafe {
        let ret_slots = ray_slice_data_values(&return_values);
        let slot = *ret_slots
            .first()
            .expect("rust_worker_execute: return_values has no slot for the result");
        // Allocate exactly once: a second allocation over the same buffer either leaks a
        // DataValue or makes two DataValues alias `ret_owned`.
        let dv_ptr = c_worker_AllocateDataValue(
            ret_owned.as_mut_ptr(),
            ret_owned.len() as u64,
            std::ptr::null_mut::<u8>(),
            0,
        );
        (*slot).data = (*dv_ptr).data;
    }
}

/// Reinterprets a `RaySlice` as a slice of `DataValue` pointers. A null `data`, a zero
/// length, or a length that does not fit in `usize` yields an empty slice.
///
/// # Safety
///
/// A non-null `slice.data` must point to `slice.len` valid `*mut DataValue` entries that
/// outlive `'a`.
unsafe fn ray_slice_data_values<'a>(slice: &RaySlice) -> &'a [*mut DataValue] {
    let data = slice.data as *const *mut DataValue;
    let Ok(len) = usize::try_from(slice.len) else {
        return &[];
    };
    if data.is_null() || len == 0 {
        // `from_raw_parts` forbids null even for empty slices.
        &[]
    } else {
        // SAFETY: upheld by this function's caller contract.
        unsafe { std::slice::from_raw_parts(data, len) }
    }
}

/// Finds the invoker for `fn_name`, consulting `GLOBAL_FUNCTION_MAP` first and then
/// each library in `LIBRARIES` in load order. The first match is cached and returned.
///
/// Both locks are released before returning, and the result is a copied fn pointer.
/// It stays valid because nothing visible in this file ever unloads a library.
fn resolve_invoker(fn_name: &std::ffi::CStr) -> Option<InvokerFunction> {
    let name = fn_name.to_string_lossy();
    let libs = LIBRARIES.lock().unwrap();
    let mut fn_map = GLOBAL_FUNCTION_MAP.lock().unwrap();
    if let Some(symbol) = fn_map.get(fn_name) {
        ray_info!("Using cached library symbol for {}: {:?}", name, symbol);
        return Some(**symbol);
    }
    for lib in libs.iter() {
        // SAFETY: exported invoker symbols are assumed to have the `InvokerFunction`
        // signature. NOTE(review): nothing here verifies that.
        let ret = unsafe { lib.get::<InvokerFunction>(fn_name.to_bytes_with_nul()).ok() };
        ray_info!("Loaded function {} as {:?}", name, ret);
        if let Some(symbol) = ret {
            let invoker = *symbol;
            // SAFETY: the owning `Library` is kept in `LIBRARIES`, which is never
            // drained in this file, so the symbol outlives the guard it borrows from.
            let static_symbol = unsafe {
                std::mem::transmute::<Symbol<'_, InvokerFunction>, Symbol<'static, InvokerFunction>>(
                    symbol,
                )
            };
            fn_map.insert(fn_name.to_owned(), static_symbol);
            return Some(invoker);
        }
    }
    None
}