use crate::benchmark;
use core::ffi::c_char;
pub use crate::public_api::*;
/// Error type returned by the safe wrappers `ffi_encode_benchmark_blob` and
/// `ffi_persist_benchmark_blob` in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FfiApiError {
/// A required string input (`path`, `model_name` or `precision`) was empty.
InvalidArgument,
/// Produced from any `benchmark::BenchmarkEncodeError` via `From`.
EncodeFailed,
/// Produced from any `benchmark::BenchmarkIoError` via `From`.
StorageWriteFailed,
}
impl From<benchmark::BenchmarkEncodeError> for FfiApiError {
fn from(_: benchmark::BenchmarkEncodeError) -> Self {
Self::EncodeFailed
}
}
impl From<benchmark::BenchmarkIoError> for FfiApiError {
fn from(_: benchmark::BenchmarkIoError) -> Self {
Self::StorageWriteFailed
}
}
/// Borrowed benchmark record accepted by the safe wrappers in this file.
///
/// Every field is forwarded unchanged to the same-named field of
/// `benchmark::BenchmarkMetrics` by `to_runtime_metrics`. The meaning and
/// units of the numeric fields are not defined here; see the `benchmark`
/// module.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct FfiBenchmarkRecord<'a> {
/// Must be non-empty, otherwise the wrappers return `InvalidArgument`.
pub model_name: &'a str,
/// Must be non-empty, otherwise the wrappers return `InvalidArgument`.
pub precision: &'a str,
pub elapsed_ms: u64,
pub iterations: u64,
pub avg_loss: f32,
pub last_loss: f32,
// `usize` here, while the C-layout `RnnFfiBenchmarkRecord` carries this as `u64`.
pub output_bytes: usize,
pub train_samples: u64,
pub total_params: u64,
pub layer_count: u32,
pub input_dim: u32,
pub output_dim: u32,
pub benchmark_flags: u64,
pub weights_bytes: u64,
pub biases_bytes: u64,
pub min_loss: f32,
pub max_loss: f32,
pub loss_stddev: f32,
pub iterations_per_sec: f32,
pub samples_per_sec: f32,
}
impl<'a> FfiBenchmarkRecord<'a> {
    /// Builds a `benchmark::BenchmarkMetrics` carrying the same values as
    /// this record.
    ///
    /// The conversion is a plain field-for-field copy; the string fields keep
    /// borrowing the same `'a` data.
    pub fn to_runtime_metrics(&self) -> benchmark::BenchmarkMetrics<'a> {
        // The record is `Copy`, so destructure a copy of it. Listing every
        // field makes the compiler flag this spot if a field is ever added.
        let Self {
            model_name,
            precision,
            elapsed_ms,
            iterations,
            avg_loss,
            last_loss,
            output_bytes,
            train_samples,
            total_params,
            layer_count,
            input_dim,
            output_dim,
            benchmark_flags,
            weights_bytes,
            biases_bytes,
            min_loss,
            max_loss,
            loss_stddev,
            iterations_per_sec,
            samples_per_sec,
        } = *self;

        benchmark::BenchmarkMetrics {
            model_name,
            precision,
            elapsed_ms,
            iterations,
            train_samples,
            avg_loss,
            last_loss,
            output_bytes,
            total_params,
            layer_count,
            input_dim,
            output_dim,
            benchmark_flags,
            weights_bytes,
            biases_bytes,
            min_loss,
            max_loss,
            loss_stddev,
            iterations_per_sec,
            samples_per_sec,
        }
    }
}
/// Encodes `record` into `out` using `benchmark::encode_benchmark_blob`.
///
/// Returns the `usize` reported by the encoder on success.
///
/// # Errors
///
/// * [`FfiApiError::InvalidArgument`] if `model_name` or `precision` is empty.
/// * [`FfiApiError::EncodeFailed`] if the underlying encoder fails.
pub fn ffi_encode_benchmark_blob(
    record: &FfiBenchmarkRecord<'_>,
    out: &mut [u8],
) -> Result<usize, FfiApiError> {
    let labels_present = !record.model_name.is_empty() && !record.precision.is_empty();
    if !labels_present {
        return Err(FfiApiError::InvalidArgument);
    }

    let metrics = record.to_runtime_metrics();
    // `?` converts the encoder error through the `From` impl on `FfiApiError`.
    let used = benchmark::encode_benchmark_blob(&metrics, out)?;
    Ok(used)
}
/// Persists `record` at `path` through `storage`, delegating to
/// `benchmark::persist_benchmark_blob`.
///
/// `scratch` is handed to the persistence routine unchanged (presumably as
/// the staging buffer for the encoded blob — confirm in `benchmark`).
/// Returns the `usize` reported by that routine on success.
///
/// # Errors
///
/// * [`FfiApiError::InvalidArgument`] if `path`, `model_name` or `precision`
///   is empty.
/// * Otherwise the underlying error converted through `FfiApiError::from`.
pub fn ffi_persist_benchmark_blob(
    storage: &mut dyn benchmark::BenchmarkStorage,
    path: &str,
    record: &FfiBenchmarkRecord<'_>,
    scratch: &mut [u8],
) -> Result<usize, FfiApiError> {
    let required = [path, record.model_name, record.precision];
    if required.iter().any(|text| text.is_empty()) {
        return Err(FfiApiError::InvalidArgument);
    }

    let metrics = record.to_runtime_metrics();
    match benchmark::persist_benchmark_blob(storage, path, &metrics, scratch) {
        Ok(written) => Ok(written),
        Err(err) => Err(FfiApiError::from(err)),
    }
}
/// C-layout input record consumed by the `rnn_ffi_*` entry points.
///
/// The two strings are passed as pointer/length pairs. `ffi_read_record`
/// rejects a null pointer, requires the bytes to be valid UTF-8, and rejects
/// empty strings. The remaining fields are copied into
/// `benchmark::BenchmarkMetrics` as-is.
#[repr(C)]
pub struct RnnFfiBenchmarkRecord {
/// Start of the model name bytes; must not be null.
pub model_name_ptr: *const u8,
/// Length in bytes of the model name.
pub model_name_len: usize,
/// Start of the precision label bytes; must not be null.
pub precision_ptr: *const u8,
/// Length in bytes of the precision label.
pub precision_len: usize,
pub elapsed_ms: u64,
pub iterations: u64,
pub avg_loss: f32,
pub last_loss: f32,
// Fixed-width `u64` in the C layout; converted to `usize` when read.
pub output_bytes: u64,
pub train_samples: u64,
pub total_params: u64,
pub layer_count: u32,
pub input_dim: u32,
pub output_dim: u32,
pub benchmark_flags: u64,
pub weights_bytes: u64,
pub biases_bytes: u64,
pub min_loss: f32,
pub max_loss: f32,
pub loss_stddev: f32,
pub iterations_per_sec: f32,
pub samples_per_sec: f32,
}
/// C-layout output record filled in by `rnn_ffi_decode_benchmark_blob`.
///
/// Field order and types match `RnnFfiBenchmarkRecord`. The string pointers
/// are copied from the decoded value's `model_name` / `precision`.
/// NOTE(review): they presumably point into the caller's blob and are only
/// valid while that blob is alive — confirm against
/// `benchmark::decode_benchmark_blob`.
#[repr(C)]
pub struct RnnFfiBenchmarkView {
/// Start of the decoded model name bytes (no NUL terminator is added here).
pub model_name_ptr: *const u8,
/// Length in bytes of the decoded model name.
pub model_name_len: usize,
/// Start of the decoded precision label bytes.
pub precision_ptr: *const u8,
/// Length in bytes of the decoded precision label.
pub precision_len: usize,
pub elapsed_ms: u64,
pub iterations: u64,
pub avg_loss: f32,
pub last_loss: f32,
// Widened from the decoded value with `as u64`.
pub output_bytes: u64,
pub train_samples: u64,
pub total_params: u64,
pub layer_count: u32,
pub input_dim: u32,
pub output_dim: u32,
pub benchmark_flags: u64,
pub weights_bytes: u64,
pub biases_bytes: u64,
pub min_loss: f32,
pub max_loss: f32,
pub loss_stddev: f32,
pub iterations_per_sec: f32,
pub samples_per_sec: f32,
}
// Status codes returned by the `extern "C"` entry points in this file.
// `rnn_ffi_error_message` maps each one to a static description.
// Values 5 through 7 are not defined in this file.

/// Call succeeded.
const RNN_FFI_OK: i32 = 0;
/// A required pointer argument was null.
const RNN_FFI_NULL_POINTER: i32 = 1;
/// An input was rejected (non-UTF-8 or empty string in a record).
const RNN_FFI_INVALID_ARGUMENT: i32 = 2;
/// Mapped from `BenchmarkEncodeError::InvalidFormat`.
const RNN_FFI_BAD_BYTES: i32 = 3;
/// Mapped from `BenchmarkEncodeError::BufferTooSmall`.
const RNN_FFI_CAPACITY_TOO_SMALL: i32 = 4;
/// Returned when `encoded_size_benchmark_blob` yields `None`.
const RNN_FFI_INTERNAL: i32 = 8;
fn ffi_code_from_encode_error(err: benchmark::BenchmarkEncodeError) -> i32 {
match err {
benchmark::BenchmarkEncodeError::BufferTooSmall => RNN_FFI_CAPACITY_TOO_SMALL,
benchmark::BenchmarkEncodeError::InvalidFormat => RNN_FFI_BAD_BYTES,
}
}
/// Borrows `len` bytes starting at `ptr` as a UTF-8 string slice.
///
/// # Errors
///
/// * `RNN_FFI_NULL_POINTER` if `ptr` is null (even when `len` is zero).
/// * `RNN_FFI_INVALID_ARGUMENT` if `len` exceeds `isize::MAX` or the bytes
///   are not valid UTF-8.
///
/// # Safety
///
/// When `ptr` is non-null and `len <= isize::MAX`, the caller must guarantee
/// that `ptr` is valid for reads of `len` bytes and that those bytes are not
/// mutated for the whole of the caller-chosen lifetime `'a`.
unsafe fn ffi_read_str<'a>(ptr: *const u8, len: usize) -> Result<&'a str, i32> {
    if ptr.is_null() {
        return Err(RNN_FFI_NULL_POINTER);
    }
    // `slice::from_raw_parts` is undefined behaviour for lengths above
    // `isize::MAX`. The length comes from foreign code, so reject it here.
    if len > isize::MAX as usize {
        return Err(RNN_FFI_INVALID_ARGUMENT);
    }
    // SAFETY: `ptr` is non-null and `len` is within the slice size limit;
    // validity of the memory itself is the caller's obligation (see above).
    let bytes = unsafe { core::slice::from_raw_parts(ptr, len) };
    core::str::from_utf8(bytes).map_err(|_| RNN_FFI_INVALID_ARGUMENT)
}
/// Converts a C-layout record into a borrowed `benchmark::BenchmarkMetrics`.
///
/// # Errors
///
/// * `RNN_FFI_NULL_POINTER` if `rec_ptr` or either string pointer is null.
/// * `RNN_FFI_INVALID_ARGUMENT` if a string is not valid UTF-8, a string is
///   empty, or `output_bytes` does not fit in `usize` on this target.
///
/// # Safety
///
/// If non-null, `rec_ptr` must point to a properly aligned, initialized
/// `RnnFfiBenchmarkRecord`, and each string pointer/length pair in it must
/// satisfy the contract of `ffi_read_str` for the caller-chosen lifetime `'a`.
unsafe fn ffi_read_record<'a>(
    rec_ptr: *const RnnFfiBenchmarkRecord,
) -> Result<benchmark::BenchmarkMetrics<'a>, i32> {
    // Local import so the checked conversion also resolves on editions where
    // `TryFrom` is not in the prelude.
    use core::convert::TryFrom;

    if rec_ptr.is_null() {
        return Err(RNN_FFI_NULL_POINTER);
    }
    // SAFETY: non-null was checked above; alignment and initialization are
    // the caller's obligation.
    let rec = unsafe { &*rec_ptr };
    // SAFETY: the caller guarantees both pointer/length pairs are readable.
    let model_name = unsafe { ffi_read_str(rec.model_name_ptr, rec.model_name_len)? };
    let precision = unsafe { ffi_read_str(rec.precision_ptr, rec.precision_len)? };
    if model_name.is_empty() || precision.is_empty() {
        return Err(RNN_FFI_INVALID_ARGUMENT);
    }
    // A plain `as usize` cast would silently truncate on targets where
    // `usize` is narrower than 64 bits; report that as an invalid argument.
    let output_bytes =
        usize::try_from(rec.output_bytes).map_err(|_| RNN_FFI_INVALID_ARGUMENT)?;

    Ok(benchmark::BenchmarkMetrics {
        model_name,
        precision,
        elapsed_ms: rec.elapsed_ms,
        iterations: rec.iterations,
        train_samples: rec.train_samples,
        avg_loss: rec.avg_loss,
        last_loss: rec.last_loss,
        output_bytes,
        total_params: rec.total_params,
        layer_count: rec.layer_count,
        input_dim: rec.input_dim,
        output_dim: rec.output_dim,
        benchmark_flags: rec.benchmark_flags,
        weights_bytes: rec.weights_bytes,
        biases_bytes: rec.biases_bytes,
        min_loss: rec.min_loss,
        max_loss: rec.max_loss,
        loss_stddev: rec.loss_stddev,
        iterations_per_sec: rec.iterations_per_sec,
        samples_per_sec: rec.samples_per_sec,
    })
}
/// Returns the version number of this C API (currently `1`).
#[no_mangle]
pub extern "C" fn rnn_ffi_api_version() -> u32 {
    const API_VERSION: u32 = 1;
    API_VERSION
}
/// Computes the encoded size of `record` and stores it in `*out_size`.
///
/// Returns `RNN_FFI_OK` on success. `*out_size` is written only on success.
/// Failure codes:
///
/// * `RNN_FFI_NULL_POINTER` if `out_size` is null, or any code produced by
///   `ffi_read_record` for a bad `record`.
/// * `RNN_FFI_INTERNAL` if `benchmark::encoded_size_benchmark_blob` returns
///   `None`.
///
/// # Safety
///
/// If non-null, `out_size` must be valid for a `usize` write and `record`
/// must satisfy the contract of `ffi_read_record`.
#[no_mangle]
#[allow(clippy::missing_safety_doc)]
pub unsafe extern "C" fn rnn_ffi_benchmark_encoded_size(
    record: *const RnnFfiBenchmarkRecord,
    out_size: *mut usize,
) -> i32 {
    if out_size.is_null() {
        return RNN_FFI_NULL_POINTER;
    }

    // SAFETY: the validity of `record` is the caller's obligation.
    let parsed = unsafe { ffi_read_record(record) };
    let outcome = parsed.and_then(|metrics| {
        benchmark::encoded_size_benchmark_blob(&metrics).ok_or(RNN_FFI_INTERNAL)
    });

    match outcome {
        Ok(size) => {
            // SAFETY: `out_size` was checked to be non-null above.
            unsafe { out_size.write(size) };
            RNN_FFI_OK
        }
        Err(code) => code,
    }
}
#[no_mangle]
#[allow(clippy::missing_safety_doc)]
pub unsafe extern "C" fn rnn_ffi_encode_benchmark_blob(
record: *const RnnFfiBenchmarkRecord,
out_ptr: *mut u8,
out_len: usize,
out_used: *mut usize,
) -> i32 {
if out_ptr.is_null() || out_used.is_null() {
return RNN_FFI_NULL_POINTER;
}
let metrics = match unsafe { ffi_read_record(record) } {
Ok(v) => v,
Err(code) => return code,
};
let out = unsafe { core::slice::from_raw_parts_mut(out_ptr, out_len) };
match benchmark::encode_benchmark_blob(&metrics, out) {
Ok(used) => {
unsafe {
*out_used = used;
}
RNN_FFI_OK
}
Err(err) => ffi_code_from_encode_error(err),
}
}
#[no_mangle]
#[allow(clippy::missing_safety_doc)]
pub unsafe extern "C" fn rnn_ffi_decode_benchmark_blob(
blob_ptr: *const u8,
blob_len: usize,
out_view: *mut RnnFfiBenchmarkView,
) -> i32 {
if blob_ptr.is_null() || out_view.is_null() {
return RNN_FFI_NULL_POINTER;
}
let blob = unsafe { core::slice::from_raw_parts(blob_ptr, blob_len) };
let decoded = match benchmark::decode_benchmark_blob(blob) {
Ok(v) => v,
Err(err) => return ffi_code_from_encode_error(err),
};
unsafe {
(*out_view).model_name_ptr = decoded.model_name.as_ptr();
(*out_view).model_name_len = decoded.model_name.len();
(*out_view).precision_ptr = decoded.precision.as_ptr();
(*out_view).precision_len = decoded.precision.len();
(*out_view).elapsed_ms = decoded.elapsed_ms;
(*out_view).iterations = decoded.iterations;
(*out_view).avg_loss = decoded.avg_loss;
(*out_view).last_loss = decoded.last_loss;
(*out_view).output_bytes = decoded.output_bytes as u64;
(*out_view).train_samples = decoded.train_samples;
(*out_view).total_params = decoded.total_params;
(*out_view).layer_count = decoded.layer_count;
(*out_view).input_dim = decoded.input_dim;
(*out_view).output_dim = decoded.output_dim;
(*out_view).benchmark_flags = decoded.benchmark_flags;
(*out_view).weights_bytes = decoded.weights_bytes;
(*out_view).biases_bytes = decoded.biases_bytes;
(*out_view).min_loss = decoded.min_loss;
(*out_view).max_loss = decoded.max_loss;
(*out_view).loss_stddev = decoded.loss_stddev;
(*out_view).iterations_per_sec = decoded.iterations_per_sec;
(*out_view).samples_per_sec = decoded.samples_per_sec;
}
RNN_FFI_OK
}
/// Returns a static, NUL-terminated description of an `RNN_FFI_*` status
/// code.
///
/// Unrecognized codes map to `"unknown error"`. The returned pointer refers
/// to a string literal in the binary and must not be freed by the caller.
#[no_mangle]
pub extern "C" fn rnn_ffi_error_message(code: i32) -> *const c_char {
    // Each literal carries its own trailing NUL so it can be handed to C.
    let text: &'static [u8] = match code {
        RNN_FFI_OK => b"ok\0",
        RNN_FFI_NULL_POINTER => b"null pointer\0",
        RNN_FFI_INVALID_ARGUMENT => b"invalid argument\0",
        RNN_FFI_BAD_BYTES => b"bad bytes\0",
        RNN_FFI_CAPACITY_TOO_SMALL => b"capacity too small\0",
        RNN_FFI_INTERNAL => b"internal error\0",
        _ => b"unknown error\0",
    };
    text.as_ptr().cast::<c_char>()
}