native_neural_network 0.3.1

A `no_std` Rust library for native neural networks (`.rnn` format)
Documentation
use crate::benchmark;
use core::ffi::c_char;

pub use crate::public_api::*;

/// Error type returned by the safe `ffi_*` helpers in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FfiApiError {
    /// A required string argument (`path`, `model_name` or `precision`) was empty.
    InvalidArgument,
    /// Produced from any `benchmark::BenchmarkEncodeError`.
    EncodeFailed,
    /// Produced from any `benchmark::BenchmarkIoError`.
    StorageWriteFailed,
}

impl From<benchmark::BenchmarkEncodeError> for FfiApiError {
    fn from(_: benchmark::BenchmarkEncodeError) -> Self {
        Self::EncodeFailed
    }
}

impl From<benchmark::BenchmarkIoError> for FfiApiError {
    fn from(_: benchmark::BenchmarkIoError) -> Self {
        Self::StorageWriteFailed
    }
}

/// Borrowed, Rust-side benchmark record accepted by the safe `ffi_*` helpers.
///
/// Each field has a same-named counterpart in `benchmark::BenchmarkMetrics`
/// and is copied unchanged by `to_runtime_metrics`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FfiBenchmarkRecord<'s> {
    /// Model identifier; the `ffi_*` helpers reject an empty value.
    pub model_name: &'s str,
    /// Precision label; the `ffi_*` helpers reject an empty value.
    pub precision: &'s str,
    pub elapsed_ms: u64,
    pub iterations: u64,
    pub avg_loss: f32,
    pub last_loss: f32,
    /// Kept as `usize` here; the C-facing record carries it as `u64`.
    pub output_bytes: usize,
    pub train_samples: u64,
    pub total_params: u64,
    pub layer_count: u32,
    pub input_dim: u32,
    pub output_dim: u32,
    pub benchmark_flags: u64,
    pub weights_bytes: u64,
    pub biases_bytes: u64,
    pub min_loss: f32,
    pub max_loss: f32,
    pub loss_stddev: f32,
    pub iterations_per_sec: f32,
    pub samples_per_sec: f32,
}

impl<'a> FfiBenchmarkRecord<'a> {
    /// Converts this record into the runtime `benchmark::BenchmarkMetrics` form.
    ///
    /// Every field is copied unchanged. The borrowed strings keep lifetime `'a`.
    pub fn to_runtime_metrics(&self) -> benchmark::BenchmarkMetrics<'a> {
        // Exhaustive destructure: adding a field to the record forces this site to be updated.
        let Self {
            model_name,
            precision,
            elapsed_ms,
            iterations,
            avg_loss,
            last_loss,
            output_bytes,
            train_samples,
            total_params,
            layer_count,
            input_dim,
            output_dim,
            benchmark_flags,
            weights_bytes,
            biases_bytes,
            min_loss,
            max_loss,
            loss_stddev,
            iterations_per_sec,
            samples_per_sec,
        } = *self;

        benchmark::BenchmarkMetrics {
            model_name,
            precision,
            elapsed_ms,
            iterations,
            train_samples,
            avg_loss,
            last_loss,
            output_bytes,
            total_params,
            layer_count,
            input_dim,
            output_dim,
            benchmark_flags,
            weights_bytes,
            biases_bytes,
            min_loss,
            max_loss,
            loss_stddev,
            iterations_per_sec,
            samples_per_sec,
        }
    }
}

/// Validates `record` and encodes it into `out`, returning the number of bytes written.
///
/// # Errors
/// - [`FfiApiError::InvalidArgument`] when `model_name` or `precision` is empty.
/// - [`FfiApiError::EncodeFailed`] when the underlying encoder fails.
pub fn ffi_encode_benchmark_blob(
    record: &FfiBenchmarkRecord<'_>,
    out: &mut [u8],
) -> Result<usize, FfiApiError> {
    let has_required_fields = !record.model_name.is_empty() && !record.precision.is_empty();
    if !has_required_fields {
        return Err(FfiApiError::InvalidArgument);
    }

    // `?` routes the encoder error through `From<BenchmarkEncodeError> for FfiApiError`.
    let written = benchmark::encode_benchmark_blob(&record.to_runtime_metrics(), out)?;
    Ok(written)
}

/// Validates `record` and persists its encoded form to `path` through `storage`.
///
/// `scratch` is handed to `benchmark::persist_benchmark_blob` as working space. Returns the
/// `usize` reported by that call on success.
///
/// # Errors
/// - [`FfiApiError::InvalidArgument`] when `path`, `model_name` or `precision` is empty.
/// - Any error from `benchmark::persist_benchmark_blob`, converted through this
///   module's `From` impls for [`FfiApiError`].
pub fn ffi_persist_benchmark_blob(
    storage: &mut dyn benchmark::BenchmarkStorage,
    path: &str,
    record: &FfiBenchmarkRecord<'_>,
    scratch: &mut [u8],
) -> Result<usize, FfiApiError> {
    let any_required_empty = [path, record.model_name, record.precision]
        .iter()
        .any(|text| text.is_empty());
    if any_required_empty {
        return Err(FfiApiError::InvalidArgument);
    }

    let metrics = record.to_runtime_metrics();
    let written = benchmark::persist_benchmark_blob(storage, path, &metrics, scratch)?;
    Ok(written)
}

/// C-ABI input record read by the `rnn_ffi_*` entry points.
///
/// `#[repr(C)]`: field order and types are part of the ABI and must match the
/// C-side declaration exactly. Do not reorder or retype fields.
#[repr(C)]
pub struct RnnFfiBenchmarkRecord {
    /// Must be non-null and point to `model_name_len` bytes of UTF-8. The string must be non-empty.
    pub model_name_ptr: *const u8,
    pub model_name_len: usize,
    /// Must be non-null and point to `precision_len` bytes of UTF-8. The string must be non-empty.
    pub precision_ptr: *const u8,
    pub precision_len: usize,
    pub elapsed_ms: u64,
    pub iterations: u64,
    pub avg_loss: f32,
    pub last_loss: f32,
    /// Carried as `u64` at the ABI; converted to `usize` when the record is read.
    pub output_bytes: u64,
    pub train_samples: u64,
    pub total_params: u64,
    pub layer_count: u32,
    pub input_dim: u32,
    pub output_dim: u32,
    pub benchmark_flags: u64,
    pub weights_bytes: u64,
    pub biases_bytes: u64,
    pub min_loss: f32,
    pub max_loss: f32,
    pub loss_stddev: f32,
    pub iterations_per_sec: f32,
    pub samples_per_sec: f32,
}

/// C-ABI output view filled by `rnn_ffi_decode_benchmark_blob`.
///
/// `#[repr(C)]`: field order and types are part of the ABI and must match the
/// C-side declaration exactly. Do not reorder or retype fields.
#[repr(C)]
pub struct RnnFfiBenchmarkView {
    /// Points at the decoded model-name bytes. The bytes are not guaranteed to be
    /// NUL-terminated, so always use `model_name_len`.
    /// NOTE(review): presumably these bytes live inside the decoded blob, so the view must
    /// not outlive it. TODO confirm against `decode_benchmark_blob`.
    pub model_name_ptr: *const u8,
    pub model_name_len: usize,
    /// Points at the decoded precision bytes, with the same caveats as `model_name_ptr`.
    pub precision_ptr: *const u8,
    pub precision_len: usize,
    pub elapsed_ms: u64,
    pub iterations: u64,
    pub avg_loss: f32,
    pub last_loss: f32,
    /// Widened from `usize` when the view is filled.
    pub output_bytes: u64,
    pub train_samples: u64,
    pub total_params: u64,
    pub layer_count: u32,
    pub input_dim: u32,
    pub output_dim: u32,
    pub benchmark_flags: u64,
    pub weights_bytes: u64,
    pub biases_bytes: u64,
    pub min_loss: f32,
    pub max_loss: f32,
    pub loss_stddev: f32,
    pub iterations_per_sec: f32,
    pub samples_per_sec: f32,
}

// Status codes returned by the `extern "C"` entry points below. C callers observe these
// exact numeric values, so they must never be renumbered. `rnn_ffi_error_message` maps
// each one to a static description.
/// Success.
const RNN_FFI_OK: i32 = 0;
/// A required pointer argument (or string pointer inside a record) was null.
const RNN_FFI_NULL_POINTER: i32 = 1;
/// An argument was present but unacceptable (e.g. empty or non-UTF-8 string).
const RNN_FFI_INVALID_ARGUMENT: i32 = 2;
/// Reported for `BenchmarkEncodeError::InvalidFormat` (e.g. a malformed blob).
const RNN_FFI_BAD_BYTES: i32 = 3;
/// Reported for `BenchmarkEncodeError::BufferTooSmall`.
const RNN_FFI_CAPACITY_TOO_SMALL: i32 = 4;
/// Reported when the encoded size cannot be computed.
const RNN_FFI_INTERNAL: i32 = 8;

fn ffi_code_from_encode_error(err: benchmark::BenchmarkEncodeError) -> i32 {
    match err {
        benchmark::BenchmarkEncodeError::BufferTooSmall => RNN_FFI_CAPACITY_TOO_SMALL,
        benchmark::BenchmarkEncodeError::InvalidFormat => RNN_FFI_BAD_BYTES,
    }
}

/// Borrows `len` bytes at `ptr` as UTF-8 text.
///
/// # Errors
/// - `RNN_FFI_NULL_POINTER` when `ptr` is null, even if `len` is 0.
/// - `RNN_FFI_INVALID_ARGUMENT` when the bytes are not valid UTF-8.
///
/// # Safety
/// When non-null, `ptr` must be valid for reads of `len` bytes for the caller-chosen
/// lifetime `'a`, and that memory must not be mutated while the returned `&str` is alive.
unsafe fn ffi_read_str<'a>(ptr: *const u8, len: usize) -> Result<&'a str, i32> {
    if ptr.is_null() {
        Err(RNN_FFI_NULL_POINTER)
    } else {
        // SAFETY: non-null checked above; readability of `len` bytes is the caller's contract.
        let raw = unsafe { core::slice::from_raw_parts(ptr, len) };
        match core::str::from_utf8(raw) {
            Ok(text) => Ok(text),
            Err(_) => Err(RNN_FFI_INVALID_ARGUMENT),
        }
    }
}

/// Reads a C-side `RnnFfiBenchmarkRecord` into runtime `benchmark::BenchmarkMetrics`.
///
/// The returned metrics borrow `model_name` and `precision` directly from the caller's
/// memory for the caller-chosen lifetime `'a`.
///
/// # Errors
/// - `RNN_FFI_NULL_POINTER` if `rec_ptr` or either string pointer is null.
/// - `RNN_FFI_INVALID_ARGUMENT` if either string is not UTF-8 or is empty, or if
///   `output_bytes` does not fit in `usize` on this target.
///
/// # Safety
/// When non-null, `rec_ptr` must point to a valid, aligned `RnnFfiBenchmarkRecord`. Its
/// string pointer/length pairs must satisfy `ffi_read_str`'s contract for `'a`.
unsafe fn ffi_read_record<'a>(
    rec_ptr: *const RnnFfiBenchmarkRecord,
) -> Result<benchmark::BenchmarkMetrics<'a>, i32> {
    if rec_ptr.is_null() {
        return Err(RNN_FFI_NULL_POINTER);
    }
    // SAFETY: non-null checked above; validity and alignment are the caller's contract.
    let rec = unsafe { &*rec_ptr };

    // SAFETY: the string pointer/length pairs are covered by this function's contract.
    let model_name = unsafe { ffi_read_str(rec.model_name_ptr, rec.model_name_len)? };
    let precision = unsafe { ffi_read_str(rec.precision_ptr, rec.precision_len)? };
    if model_name.is_empty() || precision.is_empty() {
        return Err(RNN_FFI_INVALID_ARGUMENT);
    }

    // `output_bytes` crosses the ABI as `u64`. On 32-bit targets (common for `no_std`) a plain
    // `as usize` would silently truncate, so reject values that do not fit instead.
    let output_bytes = <usize as core::convert::TryFrom<u64>>::try_from(rec.output_bytes)
        .map_err(|_| RNN_FFI_INVALID_ARGUMENT)?;

    Ok(benchmark::BenchmarkMetrics {
        model_name,
        precision,
        elapsed_ms: rec.elapsed_ms,
        iterations: rec.iterations,
        train_samples: rec.train_samples,
        avg_loss: rec.avg_loss,
        last_loss: rec.last_loss,
        output_bytes,
        total_params: rec.total_params,
        layer_count: rec.layer_count,
        input_dim: rec.input_dim,
        output_dim: rec.output_dim,
        benchmark_flags: rec.benchmark_flags,
        weights_bytes: rec.weights_bytes,
        biases_bytes: rec.biases_bytes,
        min_loss: rec.min_loss,
        max_loss: rec.max_loss,
        loss_stddev: rec.loss_stddev,
        iterations_per_sec: rec.iterations_per_sec,
        samples_per_sec: rec.samples_per_sec,
    })
}

/// Returns the version number of this C ABI surface (currently `1`).
///
/// NOTE(review): presumably bumped whenever the exported layouts or functions change. TODO confirm.
#[no_mangle]
pub extern "C" fn rnn_ffi_api_version() -> u32 {
    const CURRENT_VERSION: u32 = 1;
    CURRENT_VERSION
}

/// Computes the encoded blob size for the record at `record`.
///
/// On success, stores the size in `*out_size` and returns `RNN_FFI_OK`. Failures return:
/// - `RNN_FFI_NULL_POINTER` when `out_size` is null.
/// - The code produced by the record validation.
/// - `RNN_FFI_INTERNAL` when the size cannot be computed.
///
/// `*out_size` is written only on success.
///
/// # Safety
///
/// `record` must point to a valid `RnnFfiBenchmarkRecord`.
/// `out_size` must be a valid writable pointer.
#[no_mangle]
#[allow(clippy::missing_safety_doc)]
pub unsafe extern "C" fn rnn_ffi_benchmark_encoded_size(
    record: *const RnnFfiBenchmarkRecord,
    out_size: *mut usize,
) -> i32 {
    if out_size.is_null() {
        return RNN_FFI_NULL_POINTER;
    }

    // SAFETY: `record` validity is part of this function's contract.
    let outcome = unsafe { ffi_read_record(record) }.and_then(|metrics| {
        benchmark::encoded_size_benchmark_blob(&metrics).ok_or(RNN_FFI_INTERNAL)
    });

    match outcome {
        Ok(size) => {
            // SAFETY: `out_size` is non-null (checked above) and writable per the contract.
            unsafe { out_size.write(size) };
            RNN_FFI_OK
        }
        Err(code) => code,
    }
}

/// Encodes the record at `record` into the caller-provided buffer `out_ptr[..out_len]`.
///
/// On success, writes the number of bytes produced to `*out_used` and returns `RNN_FFI_OK`.
/// On failure, `*out_used` is left untouched and an error code is returned. Encoder errors
/// are mapped by `ffi_code_from_encode_error`.
///
/// # Safety
///
/// `record` must point to a valid `RnnFfiBenchmarkRecord`.
/// `out_ptr` must be writable for `out_len` bytes.
/// `out_used` must be a valid writable pointer.
#[no_mangle]
#[allow(clippy::missing_safety_doc)]
pub unsafe extern "C" fn rnn_ffi_encode_benchmark_blob(
    record: *const RnnFfiBenchmarkRecord,
    out_ptr: *mut u8,
    out_len: usize,
    out_used: *mut usize,
) -> i32 {
    let missing_output = out_ptr.is_null() || out_used.is_null();
    if missing_output {
        return RNN_FFI_NULL_POINTER;
    }

    // SAFETY: `record` validity is part of this function's contract.
    let parsed = unsafe { ffi_read_record(record) };
    let metrics = match parsed {
        Ok(metrics) => metrics,
        Err(code) => return code,
    };

    // SAFETY: `out_ptr` is non-null (checked above) and writable for `out_len` bytes per the contract.
    let buffer = unsafe { core::slice::from_raw_parts_mut(out_ptr, out_len) };
    let written = match benchmark::encode_benchmark_blob(&metrics, buffer) {
        Ok(count) => count,
        Err(err) => return ffi_code_from_encode_error(err),
    };

    // SAFETY: `out_used` is non-null (checked above) and writable per the contract.
    unsafe { out_used.write(written) };
    RNN_FFI_OK
}

#[no_mangle]
#[allow(clippy::missing_safety_doc)]
/// # Safety
///
/// `blob_ptr` must point to `blob_len` readable bytes.
/// `out_view` must be a valid writable pointer.
pub unsafe extern "C" fn rnn_ffi_decode_benchmark_blob(
    blob_ptr: *const u8,
    blob_len: usize,
    out_view: *mut RnnFfiBenchmarkView,
) -> i32 {
    if blob_ptr.is_null() || out_view.is_null() {
        return RNN_FFI_NULL_POINTER;
    }

    let blob = unsafe { core::slice::from_raw_parts(blob_ptr, blob_len) };
    let decoded = match benchmark::decode_benchmark_blob(blob) {
        Ok(v) => v,
        Err(err) => return ffi_code_from_encode_error(err),
    };

    unsafe {
        (*out_view).model_name_ptr = decoded.model_name.as_ptr();
        (*out_view).model_name_len = decoded.model_name.len();
        (*out_view).precision_ptr = decoded.precision.as_ptr();
        (*out_view).precision_len = decoded.precision.len();
        (*out_view).elapsed_ms = decoded.elapsed_ms;
        (*out_view).iterations = decoded.iterations;
        (*out_view).avg_loss = decoded.avg_loss;
        (*out_view).last_loss = decoded.last_loss;
        (*out_view).output_bytes = decoded.output_bytes as u64;
        (*out_view).train_samples = decoded.train_samples;
        (*out_view).total_params = decoded.total_params;
        (*out_view).layer_count = decoded.layer_count;
        (*out_view).input_dim = decoded.input_dim;
        (*out_view).output_dim = decoded.output_dim;
        (*out_view).benchmark_flags = decoded.benchmark_flags;
        (*out_view).weights_bytes = decoded.weights_bytes;
        (*out_view).biases_bytes = decoded.biases_bytes;
        (*out_view).min_loss = decoded.min_loss;
        (*out_view).max_loss = decoded.max_loss;
        (*out_view).loss_stddev = decoded.loss_stddev;
        (*out_view).iterations_per_sec = decoded.iterations_per_sec;
        (*out_view).samples_per_sec = decoded.samples_per_sec;
    }

    RNN_FFI_OK
}

/// Returns a static, NUL-terminated, human-readable description of a status `code`.
///
/// Unrecognised codes map to `"unknown error"`. The returned pointer refers to static data:
/// it stays valid for the life of the program and must not be freed by the caller.
#[no_mangle]
pub extern "C" fn rnn_ffi_error_message(code: i32) -> *const c_char {
    // Each literal carries its own trailing NUL so C can treat it as a C string.
    let text: &'static [u8] = match code {
        RNN_FFI_OK => b"ok\0",
        RNN_FFI_NULL_POINTER => b"null pointer\0",
        RNN_FFI_INVALID_ARGUMENT => b"invalid argument\0",
        RNN_FFI_BAD_BYTES => b"bad bytes\0",
        RNN_FFI_CAPACITY_TOO_SMALL => b"capacity too small\0",
        RNN_FFI_INTERNAL => b"internal error\0",
        _ => b"unknown error\0",
    };
    text.as_ptr().cast::<c_char>()
}