extern crate alloc;
use alloc::boxed::Box;
use alloc::vec;
use alloc::vec::Vec;
use core::ffi::c_char;
use core::mem::size_of;
use crate::activations::ActivationKind;
use crate::engine::forward_dense_plan;
use crate::layers::{DenseLayerDesc, LayerPlan, LayerSpec};
use crate::rnn_api::{
rnn_dense_required_infer_scratch_from_specs,
rnn_required_dense_from_bytes_v1,
rnn_required_dense_from_topology,
rnn_unpack_dense_v1,
RnnApiError,
};
/// Buffer-size requirements for a dense model, as reported to C callers.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RnnFfiCounts {
    /// Number of layer descriptors required.
    pub layers: u64,
    /// Total number of f32 weight elements required.
    pub weights: u64,
    /// Total number of f32 bias elements required.
    pub biases: u64,
}
/// Summary of a decoded dense model: element counts plus I/O widths.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RnnFfiModelInfo {
    /// Number of decoded layers.
    pub layers: u64,
    /// Number of f32 weight elements held by the model.
    pub weights: u64,
    /// Number of f32 bias elements held by the model.
    pub biases: u64,
    /// Input width of the first layer.
    pub input_size: u64,
    /// Output width of the last layer.
    pub output_size: u64,
}
/// ABI self-description filled in by `rnn_ffi_abi_info`.
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct RnnFfiAbiInfo {
    /// Current ABI version (same value as `rnn_ffi_api_version`).
    pub abi_version: u32,
    /// `size_of::<RnnFfiAbiInfo>()` — lets callers detect struct growth.
    pub struct_size: u32,
    /// Bitwise OR of the `RNN_FFI_ABI_FLAG_*` capability bits.
    pub flags: u64,
    /// Reserved for future use; always written as 0.
    pub reserved0: u64,
    /// Reserved for future use; always written as 0.
    pub reserved1: u64,
}
/// Opaque handle to a loaded model, exposed to C as an incomplete type.
///
/// The zero-sized private field keeps the type unconstructable outside this
/// module; a non-null handle actually points at an `RnnFfiDenseModel`
/// allocated by `rnn_ffi_model_create_from_bytes_v1`.
#[repr(C)]
pub struct RnnFfiModelHandle {
    private: [u8; 0],
}
/// Status codes returned (as `i32`) by every FFI entry point.
#[repr(i32)]
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RnnFfiCode {
    /// Success.
    Ok = 0,
    /// A required pointer argument was null.
    NullPointer = 1,
    /// Argument values are inconsistent (e.g. wrong buffer lengths).
    InvalidArgument = 2,
    /// Serialized model bytes failed to decode.
    BadBytes = 3,
    /// A size/count did not fit the destination integer width.
    CapacityTooSmall = 4,
    /// Layer-level error propagated from the core API.
    Layer = 5,
    /// Model-level error propagated from the core API.
    Model = 6,
    /// Forward-pass error propagated from the engine.
    Forward = 7,
    /// Unspecified internal failure.
    Internal = 8,
}
// Current ABI version advertised by this library.
const ABI_VERSION: u32 = 2;
// Capability bit: the handle-based model API is available.
const ABI_FLAG_HANDLE_API: u64 = 1 << 0;
// Capability bit: the batched run entry points are available.
const ABI_FLAG_BATCH_API: u64 = 1 << 1;
// Public re-exports of the ABI constants for Rust-side consumers.
pub const RNN_FFI_ABI_VERSION: u32 = ABI_VERSION;
// Oldest ABI version accepted by `rnn_ffi_is_abi_compatible`.
pub const RNN_FFI_ABI_MIN_COMPAT_VERSION: u32 = 1;
// Newest accepted ABI version — the current one.
pub const RNN_FFI_ABI_MAX_COMPAT_VERSION: u32 = ABI_VERSION;
pub const RNN_FFI_ABI_FLAG_HANDLE_API: u64 = ABI_FLAG_HANDLE_API;
pub const RNN_FFI_ABI_FLAG_BATCH_API: u64 = ABI_FLAG_BATCH_API;
/// Owned, decoded dense network plus cached metadata for inference.
struct RnnFfiDenseModel {
    /// Layer descriptors in forward order.
    layers: Vec<LayerSpec>,
    /// Flat f32 weight buffer for all layers.
    weights: Vec<f32>,
    /// Flat f32 bias buffer for all layers.
    biases: Vec<f32>,
    /// Input width of the first layer (cached from `layers`).
    input_size: usize,
    /// Output width of the last layer (cached from `layers`).
    output_size: usize,
    /// Scratch length required by `forward_dense_plan` for this topology.
    infer_scratch_len: usize,
}
impl RnnFfiDenseModel {
    /// Decodes a dense model from the v1 serialized byte format.
    ///
    /// Reads the required buffer counts from the bytes, allocates
    /// exactly-sized buffers, unpacks into them, then trims each buffer to
    /// the lengths the decoder actually produced. Returns `BadBytes` for a
    /// zero-layer model, or the mapped decoder error on malformed input.
    fn from_bytes(bytes: &[u8]) -> Result<Self, RnnFfiCode> {
        let counts = rnn_required_dense_from_bytes_v1(bytes).map_err(map_rnn_error)?;
        if counts.layers == 0 {
            return Err(RnnFfiCode::BadBytes);
        }
        // Placeholder spec only pre-fills the buffer; slots up to
        // `decoded.layers` are overwritten by `rnn_unpack_dense_v1`.
        let placeholder = LayerSpec::Dense(DenseLayerDesc {
            input_size: 1,
            output_size: 1,
            weight_offset: 0,
            bias_offset: 0,
            activation: ActivationKind::Identity,
        });
        // NOTE(review): counts come straight from untrusted bytes; very large
        // values allocate eagerly here — consider a sanity cap upstream.
        let mut layer_specs = vec![placeholder; counts.layers];
        let mut weights = vec![0.0f32; counts.weights];
        let mut biases = vec![0.0f32; counts.biases];
        let decoded = rnn_unpack_dense_v1(bytes, &mut layer_specs, &mut weights, &mut biases)
            .map_err(map_rnn_error)?;
        let layers = layer_specs[..decoded.layers].to_vec();
        if layers.is_empty() {
            return Err(RnnFfiCode::BadBytes);
        }
        // Model I/O widths come from the first and last layer specs.
        let input_size = layers
            .first()
            .map(|l| l.input_size())
            .ok_or(RnnFfiCode::BadBytes)?;
        let output_size = layers
            .last()
            .map(|l| l.output_size())
            .ok_or(RnnFfiCode::BadBytes)?;
        let infer_scratch_len = rnn_dense_required_infer_scratch_from_specs(&layers)
            .map_err(map_rnn_error)?;
        Ok(Self {
            layers,
            // Keep only what the decoder actually filled.
            weights: weights[..decoded.weights].to_vec(),
            biases: biases[..decoded.biases].to_vec(),
            input_size,
            output_size,
            infer_scratch_len,
        })
    }
    /// Summarizes the model as FFI-friendly u64 counts.
    ///
    /// The usize→u64 conversions cannot fail on targets where usize is at
    /// most 64 bits; the `CapacityTooSmall` error path is kept defensively.
    fn info(&self) -> Result<RnnFfiModelInfo, RnnFfiCode> {
        Ok(RnnFfiModelInfo {
            layers: u64::try_from(self.layers.len()).map_err(|_| RnnFfiCode::CapacityTooSmall)?,
            weights: u64::try_from(self.weights.len()).map_err(|_| RnnFfiCode::CapacityTooSmall)?,
            biases: u64::try_from(self.biases.len()).map_err(|_| RnnFfiCode::CapacityTooSmall)?,
            input_size: u64::try_from(self.input_size).map_err(|_| RnnFfiCode::CapacityTooSmall)?,
            output_size: u64::try_from(self.output_size).map_err(|_| RnnFfiCode::CapacityTooSmall)?,
        })
    }
    /// Runs one inference pass.
    ///
    /// `input`/`output` lengths must match the model's input and output
    /// widths exactly, otherwise `InvalidArgument` is returned.
    fn run_single(&self, input: &[f32], output: &mut [f32]) -> Result<(), RnnFfiCode> {
        if input.len() != self.input_size || output.len() != self.output_size {
            return Err(RnnFfiCode::InvalidArgument);
        }
        // Fresh, zeroed scratch sized by the cached requirement.
        let mut scratch = vec![0.0f32; self.infer_scratch_len];
        let plan = LayerPlan {
            layers: &self.layers,
            weights: &self.weights,
            biases: &self.biases,
        };
        forward_dense_plan(&plan, input, output, &mut scratch).map_err(|_| RnnFfiCode::Forward)
    }
    /// Runs `batch_size` row-major inferences over packed buffers.
    ///
    /// Buffer lengths must equal width × batch_size exactly (multiplication
    /// is overflow-checked). A zero batch is rejected as `InvalidArgument`.
    fn run_batch(
        &self,
        input: &[f32],
        output: &mut [f32],
        batch_size: usize,
    ) -> Result<(), RnnFfiCode> {
        if batch_size == 0 {
            return Err(RnnFfiCode::InvalidArgument);
        }
        let expected_in = self
            .input_size
            .checked_mul(batch_size)
            .ok_or(RnnFfiCode::CapacityTooSmall)?;
        let expected_out = self
            .output_size
            .checked_mul(batch_size)
            .ok_or(RnnFfiCode::CapacityTooSmall)?;
        if input.len() != expected_in || output.len() != expected_out {
            return Err(RnnFfiCode::InvalidArgument);
        }
        // NOTE(review): scratch is zeroed once and reused across rows;
        // assumes forward_dense_plan fully overwrites what it reads —
        // confirm against the engine implementation.
        let mut scratch = vec![0.0f32; self.infer_scratch_len];
        let plan = LayerPlan {
            layers: &self.layers,
            weights: &self.weights,
            biases: &self.biases,
        };
        for row in 0..batch_size {
            let in_off = row * self.input_size;
            let out_off = row * self.output_size;
            forward_dense_plan(
                &plan,
                &input[in_off..in_off + self.input_size],
                &mut output[out_off..out_off + self.output_size],
                &mut scratch,
            )
            .map_err(|_| RnnFfiCode::Forward)?;
        }
        Ok(())
    }
}
/// Returns the library's current FFI ABI version (`ABI_VERSION`).
#[no_mangle]
pub extern "C" fn rnn_ffi_api_version() -> u32 {
    ABI_VERSION
}
/// Reports whether a caller-requested ABI version is supported.
///
/// Returns 1 when `requested_abi_version` lies inside the inclusive
/// `[RNN_FFI_ABI_MIN_COMPAT_VERSION, RNN_FFI_ABI_MAX_COMPAT_VERSION]`
/// window, 0 otherwise.
#[no_mangle]
pub extern "C" fn rnn_ffi_is_abi_compatible(requested_abi_version: u32) -> i32 {
    let in_window = (RNN_FFI_ABI_MIN_COMPAT_VERSION..=RNN_FFI_ABI_MAX_COMPAT_VERSION)
        .contains(&requested_abi_version);
    i32::from(in_window)
}
/// Fills `out_info` with this library's ABI description.
///
/// Returns `NullPointer` when `out_info` is null, `Ok` otherwise.
#[no_mangle]
pub extern "C" fn rnn_ffi_abi_info(out_info: *mut RnnFfiAbiInfo) -> i32 {
    if out_info.is_null() {
        return RnnFfiCode::NullPointer as i32;
    }
    let info = RnnFfiAbiInfo {
        abi_version: ABI_VERSION,
        struct_size: size_of::<RnnFfiAbiInfo>() as u32,
        flags: ABI_FLAG_HANDLE_API | ABI_FLAG_BATCH_API,
        reserved0: 0,
        reserved1: 0,
    };
    // SAFETY: out_info was checked non-null above; the caller guarantees it
    // points to writable storage for an RnnFfiAbiInfo.
    unsafe {
        out_info.write(info);
    }
    RnnFfiCode::Ok as i32
}
#[no_mangle]
pub extern "C" fn rnn_ffi_required_dense_from_bytes_v1(
bytes_ptr: *const u8,
bytes_len: usize,
out_counts: *mut RnnFfiCounts,
) -> i32 {
if out_counts.is_null() {
return RnnFfiCode::NullPointer as i32;
}
let bytes = match unsafe { slice_from_ptr(bytes_ptr, bytes_len) } {
Some(v) => v,
None => return RnnFfiCode::NullPointer as i32,
};
match rnn_required_dense_from_bytes_v1(bytes) {
Ok(c) => {
let layers = match u64::try_from(c.layers) {
Ok(v) => v,
Err(_) => return RnnFfiCode::CapacityTooSmall as i32,
};
let weights = match u64::try_from(c.weights) {
Ok(v) => v,
Err(_) => return RnnFfiCode::CapacityTooSmall as i32,
};
let biases = match u64::try_from(c.biases) {
Ok(v) => v,
Err(_) => return RnnFfiCode::CapacityTooSmall as i32,
};
unsafe {
*out_counts = RnnFfiCounts {
layers,
weights,
biases,
};
}
RnnFfiCode::Ok as i32
}
Err(e) => map_rnn_error(e) as i32,
}
}
#[no_mangle]
pub extern "C" fn rnn_ffi_required_dense_from_topology(
topology_ptr: *const u64,
topology_len: usize,
out_counts: *mut RnnFfiCounts,
) -> i32 {
if out_counts.is_null() {
return RnnFfiCode::NullPointer as i32;
}
let topology_u64 = match unsafe { slice_from_ptr(topology_ptr, topology_len) } {
Some(v) => v,
None => return RnnFfiCode::NullPointer as i32,
};
let mut topology = vec![0usize; topology_u64.len()];
for (i, dim) in topology_u64.iter().enumerate() {
topology[i] = match usize::try_from(*dim) {
Ok(v) => v,
Err(_) => return RnnFfiCode::InvalidArgument as i32,
};
}
match rnn_required_dense_from_topology(&topology) {
Ok(c) => {
let layers = match u64::try_from(c.layers) {
Ok(v) => v,
Err(_) => return RnnFfiCode::CapacityTooSmall as i32,
};
let weights = match u64::try_from(c.weights) {
Ok(v) => v,
Err(_) => return RnnFfiCode::CapacityTooSmall as i32,
};
let biases = match u64::try_from(c.biases) {
Ok(v) => v,
Err(_) => return RnnFfiCode::CapacityTooSmall as i32,
};
unsafe {
*out_counts = RnnFfiCounts {
layers,
weights,
biases,
};
}
RnnFfiCode::Ok as i32
}
Err(e) => map_rnn_error(e) as i32,
}
}
/// Decodes model bytes and writes a summary to `out_info` without keeping
/// the model alive.
///
/// Returns `NullPointer` for null arguments, otherwise the decode/info
/// result code.
#[no_mangle]
pub extern "C" fn rnn_ffi_model_info_from_bytes_v1(
    bytes_ptr: *const u8,
    bytes_len: usize,
    out_info: *mut RnnFfiModelInfo,
) -> i32 {
    if out_info.is_null() {
        return RnnFfiCode::NullPointer as i32;
    }
    // SAFETY: null (with non-zero length) is rejected by slice_from_ptr.
    let bytes = match unsafe { slice_from_ptr(bytes_ptr, bytes_len) } {
        Some(b) => b,
        None => return RnnFfiCode::NullPointer as i32,
    };
    // Build a transient model purely to derive its summary.
    match RnnFfiDenseModel::from_bytes(bytes).and_then(|model| model.info()) {
        Ok(info) => {
            // SAFETY: out_info was checked non-null above.
            unsafe {
                *out_info = info;
            }
            RnnFfiCode::Ok as i32
        }
        Err(code) => code as i32,
    }
}
/// Decodes model bytes and hands the caller an owned opaque handle.
///
/// On success writes the handle to `out_handle`; the caller must release it
/// with `rnn_ffi_model_destroy`. Returns `NullPointer` for null arguments
/// or the decode error code.
#[no_mangle]
pub extern "C" fn rnn_ffi_model_create_from_bytes_v1(
    bytes_ptr: *const u8,
    bytes_len: usize,
    out_handle: *mut *mut RnnFfiModelHandle,
) -> i32 {
    if out_handle.is_null() {
        return RnnFfiCode::NullPointer as i32;
    }
    // SAFETY: null (with non-zero length) is rejected by slice_from_ptr.
    let bytes = match unsafe { slice_from_ptr(bytes_ptr, bytes_len) } {
        Some(b) => b,
        None => return RnnFfiCode::NullPointer as i32,
    };
    match RnnFfiDenseModel::from_bytes(bytes) {
        Err(code) => code as i32,
        Ok(model) => {
            // Transfer ownership to the caller behind the opaque handle type.
            let raw = Box::into_raw(Box::new(model)) as *mut RnnFfiModelHandle;
            // SAFETY: out_handle was checked non-null above.
            unsafe {
                *out_handle = raw;
            }
            RnnFfiCode::Ok as i32
        }
    }
}
/// Releases a model created by `rnn_ffi_model_create_from_bytes_v1`.
///
/// Returns `NullPointer` for a null handle, `Ok` otherwise. Passing the
/// same handle twice, or a pointer that did not come from the create call,
/// violates the caller contract (undefined behavior).
#[no_mangle]
pub extern "C" fn rnn_ffi_model_destroy(handle: *mut RnnFfiModelHandle) -> i32 {
    if handle.is_null() {
        return RnnFfiCode::NullPointer as i32;
    }
    // SAFETY: caller guarantees `handle` came from Box::into_raw in the
    // create call and has not already been destroyed; re-boxing it here
    // reclaims and drops the allocation exactly once.
    unsafe {
        let raw = handle as *mut RnnFfiDenseModel;
        drop(Box::from_raw(raw));
    }
    RnnFfiCode::Ok as i32
}
/// Writes the summary of a live model handle into `out_info`.
///
/// Returns `NullPointer` when either pointer is null, otherwise the result
/// of summarizing the model.
#[no_mangle]
pub extern "C" fn rnn_ffi_model_get_info(
    handle: *const RnnFfiModelHandle,
    out_info: *mut RnnFfiModelInfo,
) -> i32 {
    if out_info.is_null() {
        return RnnFfiCode::NullPointer as i32;
    }
    // SAFETY: caller guarantees the handle, if non-null, came from the
    // create call and is still live.
    match unsafe { model_ref(handle) }.map(|model| model.info()) {
        None => RnnFfiCode::NullPointer as i32,
        Some(Err(code)) => code as i32,
        Some(Ok(info)) => {
            // SAFETY: out_info was checked non-null above.
            unsafe {
                *out_info = info;
            }
            RnnFfiCode::Ok as i32
        }
    }
}
/// Runs a single inference through a live model handle.
///
/// `input_len`/`output_len` must equal the model's input/output widths.
/// Returns `NullPointer` for null pointers, otherwise the run result code.
#[no_mangle]
pub extern "C" fn rnn_ffi_model_run_dense(
    handle: *const RnnFfiModelHandle,
    input_ptr: *const f32,
    input_len: usize,
    output_ptr: *mut f32,
    output_len: usize,
) -> i32 {
    // SAFETY: null pointers are rejected below; the caller guarantees the
    // handle and buffers are otherwise valid per the FFI contract.
    let resolved = unsafe {
        (
            model_ref(handle),
            slice_from_ptr(input_ptr, input_len),
            slice_from_ptr_mut(output_ptr, output_len),
        )
    };
    let (model, input, output) = match resolved {
        (Some(m), Some(i), Some(o)) => (m, i, o),
        _ => return RnnFfiCode::NullPointer as i32,
    };
    match model.run_single(input, output) {
        Ok(()) => RnnFfiCode::Ok as i32,
        Err(code) => code as i32,
    }
}
/// Runs `batch_size` packed inferences through a live model handle.
///
/// Buffer lengths must equal width × batch_size. Returns `NullPointer` for
/// null pointers, otherwise the batch run result code.
#[no_mangle]
pub extern "C" fn rnn_ffi_model_run_dense_batch(
    handle: *const RnnFfiModelHandle,
    input_ptr: *const f32,
    input_len: usize,
    output_ptr: *mut f32,
    output_len: usize,
    batch_size: usize,
) -> i32 {
    // SAFETY: null pointers are rejected below; the caller guarantees the
    // handle and buffers are otherwise valid per the FFI contract.
    let resolved = unsafe {
        (
            model_ref(handle),
            slice_from_ptr(input_ptr, input_len),
            slice_from_ptr_mut(output_ptr, output_len),
        )
    };
    let (model, input, output) = match resolved {
        (Some(m), Some(i), Some(o)) => (m, i, o),
        _ => return RnnFfiCode::NullPointer as i32,
    };
    match model.run_batch(input, output, batch_size) {
        Ok(()) => RnnFfiCode::Ok as i32,
        Err(code) => code as i32,
    }
}
/// One-shot convenience: decode model bytes, run a single inference, drop
/// the model.
///
/// Returns `NullPointer` for null pointers, otherwise the decode/run code.
#[no_mangle]
pub extern "C" fn rnn_ffi_run_dense_v1(
    bytes_ptr: *const u8,
    bytes_len: usize,
    input_ptr: *const f32,
    input_len: usize,
    output_ptr: *mut f32,
    output_len: usize,
) -> i32 {
    // SAFETY: null pointers (with non-zero lengths) are rejected by the
    // slice helpers; the caller guarantees the buffers are otherwise valid.
    let resolved = unsafe {
        (
            slice_from_ptr(bytes_ptr, bytes_len),
            slice_from_ptr(input_ptr, input_len),
            slice_from_ptr_mut(output_ptr, output_len),
        )
    };
    let (bytes, input, output) = match resolved {
        (Some(b), Some(i), Some(o)) => (b, i, o),
        _ => return RnnFfiCode::NullPointer as i32,
    };
    let code = RnnFfiDenseModel::from_bytes(bytes)
        .and_then(|model| model.run_single(input, output))
        .map(|()| RnnFfiCode::Ok)
        .unwrap_or_else(|err| err);
    code as i32
}
/// One-shot convenience: decode model bytes, run a packed batch, drop the
/// model.
///
/// Returns `NullPointer` for null pointers, otherwise the decode/run code.
#[no_mangle]
pub extern "C" fn rnn_ffi_run_dense_batch_v1(
    bytes_ptr: *const u8,
    bytes_len: usize,
    input_ptr: *const f32,
    input_len: usize,
    output_ptr: *mut f32,
    output_len: usize,
    batch_size: usize,
) -> i32 {
    // SAFETY: null pointers (with non-zero lengths) are rejected by the
    // slice helpers; the caller guarantees the buffers are otherwise valid.
    let resolved = unsafe {
        (
            slice_from_ptr(bytes_ptr, bytes_len),
            slice_from_ptr(input_ptr, input_len),
            slice_from_ptr_mut(output_ptr, output_len),
        )
    };
    let (bytes, input, output) = match resolved {
        (Some(b), Some(i), Some(o)) => (b, i, o),
        _ => return RnnFfiCode::NullPointer as i32,
    };
    let code = RnnFfiDenseModel::from_bytes(bytes)
        .and_then(|model| model.run_batch(input, output, batch_size))
        .map(|()| RnnFfiCode::Ok)
        .unwrap_or_else(|err| err);
    code as i32
}
/// Returns a static, NUL-terminated message for an `RnnFfiCode` value.
///
/// Unknown codes map to "unknown error". The returned pointer is valid for
/// the lifetime of the program and must not be freed by the caller.
#[no_mangle]
pub extern "C" fn rnn_ffi_error_message(code: i32) -> *const c_char {
    // Every message is a 'static byte string with an explicit NUL
    // terminator, so handing out a raw pointer into it is always sound.
    let msg: &'static [u8] = if code == RnnFfiCode::Ok as i32 {
        b"ok\0"
    } else if code == RnnFfiCode::NullPointer as i32 {
        b"null pointer\0"
    } else if code == RnnFfiCode::InvalidArgument as i32 {
        b"invalid argument\0"
    } else if code == RnnFfiCode::BadBytes as i32 {
        b"invalid model bytes\0"
    } else if code == RnnFfiCode::CapacityTooSmall as i32 {
        b"capacity too small\0"
    } else if code == RnnFfiCode::Layer as i32 {
        b"layer error\0"
    } else if code == RnnFfiCode::Model as i32 {
        b"model format error\0"
    } else if code == RnnFfiCode::Forward as i32 {
        b"forward error\0"
    } else if code == RnnFfiCode::Internal as i32 {
        b"internal error\0"
    } else {
        b"unknown error\0"
    };
    msg.as_ptr() as *const c_char
}
/// Maps a core-API error onto the FFI status-code enum (1:1, exhaustive).
fn map_rnn_error(err: RnnApiError) -> RnnFfiCode {
    match err {
        RnnApiError::InvalidTopology => RnnFfiCode::InvalidArgument,
        RnnApiError::CapacityTooSmall => RnnFfiCode::CapacityTooSmall,
        RnnApiError::BadBytes => RnnFfiCode::BadBytes,
        RnnApiError::Layer => RnnFfiCode::Layer,
        RnnApiError::Model => RnnFfiCode::Model,
        RnnApiError::Forward => RnnFfiCode::Forward,
    }
}
/// Builds a shared slice from an FFI pointer/length pair.
///
/// A zero `len` always yields an empty slice (the pointer may then be null
/// or dangling, matching C conventions). Otherwise returns `None` when
/// `ptr` is null, misaligned for `T`, or when the total byte size would
/// exceed `isize::MAX` — all of which would violate the documented safety
/// contract of `slice::from_raw_parts` (previously unchecked here).
///
/// # Safety
/// When `len > 0`, the caller must guarantee `ptr` points to `len`
/// initialized values of `T` valid for reads for the lifetime `'a`, and
/// that the memory is not mutated through other pointers during `'a`.
unsafe fn slice_from_ptr<'a, T>(ptr: *const T, len: usize) -> Option<&'a [T]> {
    if len == 0 {
        return Some(&[]);
    }
    if ptr.is_null() {
        return None;
    }
    // from_raw_parts requires an aligned pointer even though we never
    // dereference it here; reject misaligned FFI input instead of UB.
    if (ptr as usize) % core::mem::align_of::<T>() != 0 {
        return None;
    }
    // from_raw_parts also requires the slice to span at most isize::MAX
    // bytes; checked_mul additionally catches len * size_of overflow.
    let byte_len = len.checked_mul(size_of::<T>())?;
    if byte_len > isize::MAX as usize {
        return None;
    }
    Some(core::slice::from_raw_parts(ptr, len))
}
/// Builds a mutable slice from an FFI pointer/length pair.
///
/// A zero `len` always yields an empty slice. Otherwise returns `None` when
/// `ptr` is null, misaligned for `T`, or when the total byte size would
/// exceed `isize::MAX` — the documented safety preconditions of
/// `slice::from_raw_parts_mut` (previously unchecked here).
///
/// # Safety
/// When `len > 0`, the caller must guarantee `ptr` points to `len`
/// initialized values of `T` valid for reads and writes for the lifetime
/// `'a`, and that no other pointer accesses the memory during `'a`.
unsafe fn slice_from_ptr_mut<'a, T>(ptr: *mut T, len: usize) -> Option<&'a mut [T]> {
    if len == 0 {
        return Some(&mut []);
    }
    if ptr.is_null() {
        return None;
    }
    // Reject misaligned pointers: required by from_raw_parts_mut.
    if (ptr as usize) % core::mem::align_of::<T>() != 0 {
        return None;
    }
    // Reject slices spanning more than isize::MAX bytes (and the len *
    // size_of overflow case) — also required by from_raw_parts_mut.
    let byte_len = len.checked_mul(size_of::<T>())?;
    if byte_len > isize::MAX as usize {
        return None;
    }
    Some(core::slice::from_raw_parts_mut(ptr, len))
}
/// Reborrows an opaque FFI handle as the concrete model type.
///
/// Returns `None` for a null handle. The lifetime `'a` is unconstrained by
/// the input pointer, so callers must keep the lifetime within the scope
/// where the handle is known to be live.
///
/// # Safety
/// `handle` must be null or a pointer obtained from
/// `rnn_ffi_model_create_from_bytes_v1` that has not yet been passed to
/// `rnn_ffi_model_destroy`.
unsafe fn model_ref<'a>(handle: *const RnnFfiModelHandle) -> Option<&'a RnnFfiDenseModel> {
    if handle.is_null() {
        return None;
    }
    // SAFETY (of the deref below): guaranteed by this function's contract —
    // a non-null handle is a live Box<RnnFfiDenseModel> allocation.
    let raw = handle as *const RnnFfiDenseModel;
    Some(&*raw)
}