#![allow(clippy::missing_safety_doc)]
use anyhow::Context;
use tract_libcli::annotations::Annotations;
use tract_libcli::profile::BenchLimits;
use std::cell::RefCell;
use std::convert::TryFrom;
use std::ffi::{c_char, c_void, CStr, CString};
use std::sync::Arc;
use tract_data::internal::parse_tdim;
use tract_pulse::model::{PulsedModel, PulsedModelExt};
use tract_nnef::internal as native;
use tract_nnef::tract_core::prelude::*;
use tract_onnx::prelude::InferenceModelExt;
use tract_onnx::prelude::{self as onnx, InferenceFact};
/// Status code returned by every fallible function of the tract C API.
///
/// When a function returns `TRACT_RESULT_KO`, a description of the failure can be fetched
/// with `tract_get_last_error` from the same thread (the message is stored thread-locally).
#[repr(C)]
#[allow(non_camel_case_types)]
#[derive(Debug, PartialEq, Eq)]
pub enum TRACT_RESULT {
TRACT_RESULT_OK = 0,
TRACT_RESULT_KO = 1,
}
/// C-visible tag for tensor element types supported by the bindings.
///
/// Discriminants are spelled out explicitly because C callers see these raw values; do not
/// renumber. As the values show, the high nibble encodes the family (0 bool, 1 unsigned,
/// 2 signed, 3 float, 4 complex integer, 5 complex float) and the low nibble the width in
/// bytes of one component.
#[repr(C)]
#[allow(non_camel_case_types)]
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum TractDatumType {
TRACT_DATUM_TYPE_BOOL = 0x01,
TRACT_DATUM_TYPE_U8 = 0x11,
TRACT_DATUM_TYPE_U16 = 0x12,
TRACT_DATUM_TYPE_U32 = 0x14,
TRACT_DATUM_TYPE_U64 = 0x18,
TRACT_DATUM_TYPE_I8 = 0x21,
TRACT_DATUM_TYPE_I16 = 0x22,
TRACT_DATUM_TYPE_I32 = 0x24,
TRACT_DATUM_TYPE_I64 = 0x28,
TRACT_DATUM_TYPE_F16 = 0x32,
TRACT_DATUM_TYPE_F32 = 0x34,
TRACT_DATUM_TYPE_F64 = 0x38,
TRACT_DATUM_TYPE_COMPLEX_I16 = 0x42,
TRACT_DATUM_TYPE_COMPLEX_I32 = 0x44,
TRACT_DATUM_TYPE_COMPLEX_I64 = 0x48,
TRACT_DATUM_TYPE_COMPLEX_F16 = 0x52,
TRACT_DATUM_TYPE_COMPLEX_F32 = 0x54,
TRACT_DATUM_TYPE_COMPLEX_F64 = 0x58,
}
/// Infallible conversion from the C tag to tract's native `DatumType`.
///
/// Every `TractDatumType` variant has a native counterpart, so the match is exhaustive
/// (adding a C variant without a mapping is a compile error).
impl From<TractDatumType> for DatumType {
fn from(it: TractDatumType) -> Self {
use DatumType::*;
use TractDatumType::*;
match it {
TRACT_DATUM_TYPE_BOOL => Bool,
TRACT_DATUM_TYPE_U8 => U8,
TRACT_DATUM_TYPE_U16 => U16,
TRACT_DATUM_TYPE_U32 => U32,
TRACT_DATUM_TYPE_U64 => U64,
TRACT_DATUM_TYPE_I8 => I8,
TRACT_DATUM_TYPE_I16 => I16,
TRACT_DATUM_TYPE_I32 => I32,
TRACT_DATUM_TYPE_I64 => I64,
TRACT_DATUM_TYPE_F16 => F16,
TRACT_DATUM_TYPE_F32 => F32,
TRACT_DATUM_TYPE_F64 => F64,
TRACT_DATUM_TYPE_COMPLEX_I16 => ComplexI16,
TRACT_DATUM_TYPE_COMPLEX_I32 => ComplexI32,
TRACT_DATUM_TYPE_COMPLEX_I64 => ComplexI64,
TRACT_DATUM_TYPE_COMPLEX_F16 => ComplexF16,
TRACT_DATUM_TYPE_COMPLEX_F32 => ComplexF32,
TRACT_DATUM_TYPE_COMPLEX_F64 => ComplexF64,
}
}
}
/// Fallible conversion from tract's native `DatumType` to the C tag.
///
/// Fails (with a message naming the type) for native types that have no `TractDatumType`
/// counterpart.
impl TryFrom<DatumType> for TractDatumType {
type Error = TractError;
fn try_from(it: DatumType) -> TractResult<Self> {
use DatumType::*;
use TractDatumType::*;
match it {
Bool => Ok(TRACT_DATUM_TYPE_BOOL),
U8 => Ok(TRACT_DATUM_TYPE_U8),
U16 => Ok(TRACT_DATUM_TYPE_U16),
U32 => Ok(TRACT_DATUM_TYPE_U32),
U64 => Ok(TRACT_DATUM_TYPE_U64),
I8 => Ok(TRACT_DATUM_TYPE_I8),
I16 => Ok(TRACT_DATUM_TYPE_I16),
I32 => Ok(TRACT_DATUM_TYPE_I32),
I64 => Ok(TRACT_DATUM_TYPE_I64),
F16 => Ok(TRACT_DATUM_TYPE_F16),
F32 => Ok(TRACT_DATUM_TYPE_F32),
F64 => Ok(TRACT_DATUM_TYPE_F64),
ComplexI16 => Ok(TRACT_DATUM_TYPE_COMPLEX_I16),
ComplexI32 => Ok(TRACT_DATUM_TYPE_COMPLEX_I32),
ComplexI64 => Ok(TRACT_DATUM_TYPE_COMPLEX_I64),
ComplexF16 => Ok(TRACT_DATUM_TYPE_COMPLEX_F16),
ComplexF32 => Ok(TRACT_DATUM_TYPE_COMPLEX_F32),
ComplexF64 => Ok(TRACT_DATUM_TYPE_COMPLEX_F64),
// Any other native variant has no C tag in `TractDatumType` and is rejected.
_ => anyhow::bail!("tract C bindings do not support {:?} type", it),
}
}
}
// Last error message recorded by `wrap` on the current thread; `None` until a call fails.
// `tract_get_last_error` hands out a pointer into this CString, so the pointer stays valid
// only until the next failing call on the same thread overwrites it.
thread_local! {
pub(crate) static LAST_ERROR: RefCell<Option<CString>> = RefCell::new(None);
}
/// Runs `func` and converts its outcome into a C status code.
///
/// On error, the `anyhow` debug representation (which includes the context chain) is
/// stored in the thread-local `LAST_ERROR`, where `tract_get_last_error` can read it. If the
/// `TRACT_ERROR_STDERR` environment variable is set (to any value), the message is also
/// printed to stderr.
fn wrap<F: FnOnce() -> anyhow::Result<()>>(func: F) -> TRACT_RESULT {
    match func() {
        Ok(()) => TRACT_RESULT::TRACT_RESULT_OK,
        Err(e) => {
            let msg = format!("{e:?}");
            if std::env::var("TRACT_ERROR_STDERR").is_ok() {
                eprintln!("{msg}");
            }
            // A CString cannot contain interior NUL bytes. Escape them rather than replacing
            // the whole message, so the caller still learns what actually went wrong.
            let msg = CString::new(msg.replace('\0', "\\0"))
                .expect("interior NUL bytes have just been escaped");
            LAST_ERROR.with(|p| *p.borrow_mut() = Some(msg));
            TRACT_RESULT::TRACT_RESULT_KO
        }
    }
}
/// Returns the last error message recorded on the calling thread, or null if no call has
/// failed yet on this thread.
///
/// The returned string is owned by the library and must not be freed by the caller.
#[no_mangle]
pub extern "C" fn tract_get_last_error() -> *const std::ffi::c_char {
    LAST_ERROR.with(|slot| match slot.borrow().as_ref() {
        Some(message) => message.as_ptr(),
        None => std::ptr::null(),
    })
}
/// Returns the crate version (`CARGO_PKG_VERSION`) as a static NUL-terminated string.
///
/// The string has static lifetime and must not be freed.
#[no_mangle]
pub extern "C" fn tract_version() -> *const std::ffi::c_char {
    const VERSION: &str = concat!(env!("CARGO_PKG_VERSION"), "\0");
    // SAFETY: VERSION ends with the NUL appended above, and a Cargo package version never
    // contains an interior NUL.
    let version = unsafe { CStr::from_bytes_with_nul_unchecked(VERSION.as_bytes()) };
    version.as_ptr()
}
/// Frees a string previously returned by this library (e.g. a model, input or output name,
/// or a fact dump). Passing null is a no-op.
#[no_mangle]
pub unsafe extern "C" fn tract_free_cstring(ptr: *mut std::ffi::c_char) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: the caller guarantees `ptr` was produced by `CString::into_raw` in this
    // library and has not been freed already.
    drop(unsafe { CString::from_raw(ptr) });
}
// Early-returns an `anyhow` error from the enclosing function or closure if any of the
// given raw-pointer expressions is null. The error message names the offending expression
// via `stringify!`. Expressions are checked left to right, so `check_not_null!(p, *p)`
// verifies `p` before dereferencing it.
macro_rules! check_not_null {
($($ptr:expr),*) => {
$(
if $ptr.is_null() {
anyhow::bail!(concat!("Unexpected null pointer ", stringify!($ptr)));
}
)*
}
}
// Expands to a `wrap`ped call that frees a boxed handle through a `*mut *mut T`.
// It checks that both `$ptr` and `*$ptr` are non-null, rebuilds and drops the `Box`, then
// resets `*$ptr` to null so the caller's handle no longer dangles. Releasing an
// already-released (nulled) handle is therefore reported as an error, not a double free.
macro_rules! release {
($ptr:expr) => {
wrap(|| unsafe {
check_not_null!($ptr, *$ptr);
let _ = Box::from_raw(*$ptr);
*$ptr = std::ptr::null_mut();
Ok(())
})
};
}
/// Opaque handle to a tract NNEF framework, used to load and write NNEF models.
/// Created by `tract_nnef_create`, released by `tract_nnef_destroy`.
pub struct TractNnef(native::Nnef);
/// Creates a new NNEF framework handle and stores it in `*nnef`.
///
/// The handle must be released with `tract_nnef_destroy`.
#[no_mangle]
pub unsafe extern "C" fn tract_nnef_create(nnef: *mut *mut TractNnef) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(nnef);
        let framework = Box::new(TractNnef(tract_nnef::nnef()));
        nnef.write(Box::into_raw(framework));
        Ok(())
    })
}
/// Calls `enable_tract_core` on the NNEF framework. Going by the method name, this enables
/// the tract-core extensions when loading and writing models.
#[no_mangle]
pub unsafe extern "C" fn tract_nnef_enable_tract_core(nnef: *mut TractNnef) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(nnef);
        let framework = &mut (*nnef).0;
        framework.enable_tract_core();
        Ok(())
    })
}
/// Calls `enable_onnx` (from `tract_onnx::WithOnnx`) on the NNEF framework, registering
/// the ONNX-specific extensions.
#[no_mangle]
pub unsafe extern "C" fn tract_nnef_enable_onnx(nnef: *mut TractNnef) -> TRACT_RESULT {
    wrap(|| unsafe {
        use tract_onnx::WithOnnx;
        check_not_null!(nnef);
        let framework = &mut (*nnef).0;
        framework.enable_onnx();
        Ok(())
    })
}
/// Calls `enable_pulse` (from `tract_pulse::WithPulse`) on the NNEF framework, registering
/// the pulse (streaming) extensions.
#[no_mangle]
pub unsafe extern "C" fn tract_nnef_enable_pulse(nnef: *mut TractNnef) -> TRACT_RESULT {
    wrap(|| unsafe {
        use tract_pulse::WithPulse;
        check_not_null!(nnef);
        let framework = &mut (*nnef).0;
        framework.enable_pulse();
        Ok(())
    })
}
/// Destroys an NNEF framework handle created by `tract_nnef_create`.
///
/// Fails if `nnef` or `*nnef` is null; on success `*nnef` is reset to null.
#[no_mangle]
pub unsafe extern "C" fn tract_nnef_destroy(nnef: *mut *mut TractNnef) -> TRACT_RESULT {
release!(nnef)
}
/// Loads an NNEF model from `path` into a new typed-model handle stored in `*model`.
///
/// `*model` is set to null first, so it stays null if loading fails. The resulting handle
/// must be released with `tract_model_destroy`.
#[no_mangle]
pub unsafe extern "C" fn tract_nnef_model_for_path(
    nnef: *const TractNnef,
    path: *const c_char,
    model: *mut *mut TractModel,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(nnef, model, path);
        *model = std::ptr::null_mut();
        let location = CStr::from_ptr(path).to_str()?;
        let framework = &(*nnef).0;
        let loaded = framework
            .model_for_path(location)
            .with_context(|| format!("opening file {location:?}"))?;
        *model = Box::into_raw(Box::new(TractModel(loaded)));
        Ok(())
    })
}
#[no_mangle]
pub unsafe extern "C" fn tract_nnef_write_model_to_tar(
nnef: *const TractNnef,
path: *const c_char,
model: *const TractModel,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(nnef, model, path);
let path = CStr::from_ptr(path).to_str()?;
let f = std::fs::File::create(path).with_context(|| format!("creating file {path:?}"))?;
(*nnef).0.write_to_tar(&(*model).0, f)?;
Ok(())
})
}
/// Serializes `model` as a gzip-compressed NNEF tar archive at `path`, creating or
/// truncating the file.
///
/// The gzip stream is finished explicitly. Dropping a `GzEncoder` also tries to write the
/// trailer, but ignores any I/O error while doing so, which would hand the caller a
/// truncated archive reported as success.
#[no_mangle]
pub unsafe extern "C" fn tract_nnef_write_model_to_tar_gz(
    nnef: *const TractNnef,
    path: *const c_char,
    model: *const TractModel,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(nnef, model, path);
        let path = CStr::from_ptr(path).to_str()?;
        let f = std::fs::File::create(path).with_context(|| format!("creating file {path:?}"))?;
        let mut gz = flate2::write::GzEncoder::new(f, flate2::Compression::default());
        // Pass the encoder by `&mut` (which is also `Write`) so we keep ownership and can
        // call `finish` afterwards.
        (*nnef).0.write_to_tar(&(*model).0, &mut gz)?;
        gz.finish().with_context(|| format!("finishing gzip stream for {path:?}"))?;
        Ok(())
    })
}
/// Serializes `model` as an NNEF directory tree rooted at `path`.
#[no_mangle]
pub unsafe extern "C" fn tract_nnef_write_model_to_dir(
    nnef: *const TractNnef,
    path: *const c_char,
    model: *const TractModel,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(nnef, model, path);
        let location = CStr::from_ptr(path).to_str()?;
        let (framework, typed) = (&(*nnef).0, &(*model).0);
        framework
            .write_to_dir(typed, location)
            .with_context(|| format!("writing model to dir {location:?}"))?;
        Ok(())
    })
}
/// Opaque handle to a tract ONNX loader. Created by `tract_onnx_create`, released by
/// `tract_onnx_destroy`.
pub struct TractOnnx(tract_onnx::Onnx);
/// Creates a new ONNX loader handle and stores it in `*onnx`.
///
/// The handle must be released with `tract_onnx_destroy`.
#[no_mangle]
pub unsafe extern "C" fn tract_onnx_create(onnx: *mut *mut TractOnnx) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(onnx);
        let loader = Box::new(TractOnnx(onnx::onnx()));
        onnx.write(Box::into_raw(loader));
        Ok(())
    })
}
/// Destroys an ONNX loader handle created by `tract_onnx_create`.
///
/// Fails if `onnx` or `*onnx` is null; on success `*onnx` is reset to null.
#[no_mangle]
pub unsafe extern "C" fn tract_onnx_destroy(onnx: *mut *mut TractOnnx) -> TRACT_RESULT {
release!(onnx)
}
/// Loads an ONNX model from `path` into a new inference-model handle stored in `*model`.
///
/// `*model` is set to null first, so it stays null if loading fails. The resulting handle
/// must be released with `tract_inference_model_destroy` (or consumed by one of the
/// `tract_inference_model_into_*` conversions).
#[no_mangle]
pub unsafe extern "C" fn tract_onnx_model_for_path(
    onnx: *const TractOnnx,
    path: *const c_char,
    model: *mut *mut TractInferenceModel,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(onnx, path, model);
        *model = std::ptr::null_mut();
        let location = CStr::from_ptr(path).to_str()?;
        let loader = &(*onnx).0;
        let loaded = loader
            .model_for_path(location)
            .with_context(|| format!("opening file {location:?}"))?;
        *model = Box::into_raw(Box::new(TractInferenceModel(loaded)));
        Ok(())
    })
}
/// Opaque handle to an inference model (as loaded by `tract_onnx_model_for_path`). It can
/// be converted into a `TractModel` by `tract_inference_model_into_typed` or
/// `tract_inference_model_into_optimized`.
pub struct TractInferenceModel(onnx::InferenceModel);
/// Reports the number of inputs and/or outputs of an inference model.
///
/// Either out-pointer may be null, in which case that count is not computed.
#[no_mangle]
pub unsafe extern "C" fn tract_inference_model_nbio(
    model: *const TractInferenceModel,
    inputs: *mut usize,
    outputs: *mut usize,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model);
        let graph = &(*model).0;
        if !inputs.is_null() {
            inputs.write(graph.input_outlets()?.len());
        }
        if !outputs.is_null() {
            outputs.write(graph.output_outlets()?.len());
        }
        Ok(())
    })
}
/// Stores in `*name` the node name of the `input`-th input of an inference model.
///
/// The returned string must be freed with `tract_free_cstring`. `*name` is set to null
/// first and stays null on failure. An out-of-range `input` is reported as an error rather
/// than panicking across the FFI boundary.
#[no_mangle]
pub unsafe extern "C" fn tract_inference_model_input_name(
    model: *const TractInferenceModel,
    input: usize,
    name: *mut *mut c_char,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model, name);
        *name = std::ptr::null_mut();
        let m = &(*model).0;
        let outlets = m.input_outlets()?;
        let outlet = outlets.get(input).copied().with_context(|| {
            format!("input index {input} out of range (model has {} inputs)", outlets.len())
        })?;
        *name = CString::new(&*m.nodes[outlet.node].name)?.into_raw();
        Ok(())
    })
}
#[no_mangle]
pub unsafe extern "C" fn tract_inference_model_output_name(
model: *const TractInferenceModel,
output: usize,
name: *mut *mut i8,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(model, name);
*name = std::ptr::null_mut();
let m = &(*model).0;
let outlet = m.output_outlets()?[output];
*name = CString::new(&*m.nodes[outlet.node].name)?.into_raw();
Ok(())
})
}
/// Stores in `*fact` a copy of the fact attached to input `input_id` of an inference model.
///
/// `*fact` is set to null first and stays null on failure. The copy must be released with
/// `tract_inference_fact_destroy`.
#[no_mangle]
pub unsafe extern "C" fn tract_inference_model_input_fact(
    model: *const TractInferenceModel,
    input_id: usize,
    fact: *mut *mut TractInferenceFact,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model, fact);
        *fact = std::ptr::null_mut();
        let copied = (*model).0.input_fact(input_id)?.clone();
        *fact = Box::into_raw(Box::new(TractInferenceFact(copied)));
        Ok(())
    })
}
#[no_mangle]
pub unsafe extern "C" fn tract_inference_model_set_input_fact(
model: *mut TractInferenceModel,
input_id: usize,
fact: *const TractInferenceFact,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(model);
let f = fact.as_ref().map(|f| &f.0).cloned().unwrap_or_default();
(*model).0.set_input_fact(input_id, f)?;
Ok(())
})
}
/// Stores in `*fact` a copy of the fact attached to output `output_id` of an inference
/// model.
///
/// `*fact` is set to null first and stays null on failure. The copy must be released with
/// `tract_inference_fact_destroy`.
#[no_mangle]
pub unsafe extern "C" fn tract_inference_model_output_fact(
    model: *const TractInferenceModel,
    output_id: usize,
    fact: *mut *mut TractInferenceFact,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model, fact);
        *fact = std::ptr::null_mut();
        let copied = (*model).0.output_fact(output_id)?.clone();
        *fact = Box::into_raw(Box::new(TractInferenceFact(copied)));
        Ok(())
    })
}
#[no_mangle]
pub unsafe extern "C" fn tract_inference_model_set_output_fact(
model: *mut TractInferenceModel,
output_id: usize,
fact: *const TractInferenceFact,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(model);
let f = fact.as_ref().map(|f| &f.0).cloned().unwrap_or_default();
(*model).0.set_output_fact(output_id, f)?;
Ok(())
})
}
/// Runs `InferenceModel::analyse` in place. `obstinate` is forwarded unchanged and the
/// analysis result value is discarded; only failure is reported.
#[no_mangle]
pub unsafe extern "C" fn tract_inference_model_analyse(
    model: *mut TractInferenceModel,
    obstinate: bool,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model);
        let inference = &mut (*model).0;
        inference.analyse(obstinate)?;
        Ok(())
    })
}
/// Consumes the inference model in `*model` and stores an optimized typed model in
/// `*optimized`.
///
/// `*model` is reset to null before conversion, so it is consumed even if the conversion
/// fails. `*optimized` is null on failure.
#[no_mangle]
pub unsafe extern "C" fn tract_inference_model_into_optimized(
    model: *mut *mut TractInferenceModel,
    optimized: *mut *mut TractModel,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model, *model, optimized);
        *optimized = std::ptr::null_mut();
        let owned = Box::from_raw(std::ptr::replace(model, std::ptr::null_mut()));
        let typed = owned.0.into_optimized()?;
        *optimized = Box::into_raw(Box::new(TractModel(typed)));
        Ok(())
    })
}
/// Consumes the inference model in `*model` and stores the corresponding typed model in
/// `*typed`.
///
/// `*model` is reset to null before conversion, so it is consumed even if the conversion
/// fails. `*typed` is null on failure.
#[no_mangle]
pub unsafe extern "C" fn tract_inference_model_into_typed(
    model: *mut *mut TractInferenceModel,
    typed: *mut *mut TractModel,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model, *model, typed);
        *typed = std::ptr::null_mut();
        let owned = Box::from_raw(std::ptr::replace(model, std::ptr::null_mut()));
        let converted = owned.0.into_typed()?;
        *typed = Box::into_raw(Box::new(TractModel(converted)));
        Ok(())
    })
}
/// Destroys an inference-model handle.
///
/// Fails if `model` or `*model` is null; on success `*model` is reset to null.
#[no_mangle]
pub unsafe extern "C" fn tract_inference_model_destroy(
model: *mut *mut TractInferenceModel,
) -> TRACT_RESULT {
release!(model)
}
/// Opaque handle to a typed model (`TypedModel`). Released by `tract_model_destroy`, or
/// consumed by `tract_model_into_runnable`.
pub struct TractModel(TypedModel);
/// Reports the number of inputs and/or outputs of a typed model.
///
/// Either out-pointer may be null, in which case that count is not computed.
#[no_mangle]
pub unsafe extern "C" fn tract_model_nbio(
    model: *const TractModel,
    inputs: *mut usize,
    outputs: *mut usize,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model);
        let typed = &(*model).0;
        if !inputs.is_null() {
            inputs.write(typed.input_outlets()?.len());
        }
        if !outputs.is_null() {
            outputs.write(typed.output_outlets()?.len());
        }
        Ok(())
    })
}
/// Stores in `*name` the node name of the `input`-th input of a typed model.
///
/// The returned string must be freed with `tract_free_cstring`. `*name` is set to null
/// first and stays null on failure. An out-of-range `input` is reported as an error rather
/// than panicking across the FFI boundary.
#[no_mangle]
pub unsafe extern "C" fn tract_model_input_name(
    model: *const TractModel,
    input: usize,
    name: *mut *mut c_char,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model, name);
        *name = std::ptr::null_mut();
        let m = &(*model).0;
        let outlets = m.input_outlets()?;
        let outlet = outlets.get(input).copied().with_context(|| {
            format!("input index {input} out of range (model has {} inputs)", outlets.len())
        })?;
        *name = CString::new(&*m.nodes[outlet.node].name)?.into_raw();
        Ok(())
    })
}
/// Stores in `*fact` a copy of the typed fact of input `input_id`.
///
/// `*fact` is set to null first and stays null on failure. The copy must be released with
/// `tract_fact_destroy`.
#[no_mangle]
pub unsafe extern "C" fn tract_model_input_fact(
    model: *const TractModel,
    input_id: usize,
    fact: *mut *mut TractFact,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model, fact);
        *fact = std::ptr::null_mut();
        let copied = (*model).0.input_fact(input_id)?.clone();
        *fact = Box::into_raw(Box::new(TractFact(copied)));
        Ok(())
    })
}
/// Stores in `*name` the node name of the `output`-th output of a typed model.
///
/// The returned string must be freed with `tract_free_cstring`. `*name` is set to null
/// first and stays null on failure. An out-of-range `output` is reported as an error
/// rather than panicking across the FFI boundary.
#[no_mangle]
pub unsafe extern "C" fn tract_model_output_name(
    model: *const TractModel,
    output: usize,
    name: *mut *mut c_char,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model, name);
        *name = std::ptr::null_mut();
        let m = &(*model).0;
        let outlets = m.output_outlets()?;
        let outlet = outlets.get(output).copied().with_context(|| {
            format!("output index {output} out of range (model has {} outputs)", outlets.len())
        })?;
        *name = CString::new(&*m.nodes[outlet.node].name)?.into_raw();
        Ok(())
    })
}
/// Stores in `*fact` a copy of the typed fact of output `output_id`.
///
/// `*fact` is set to null first and stays null on failure. The copy must be released with
/// `tract_fact_destroy`. (The index parameter was previously misnamed `input_id`; C
/// callers are unaffected because arguments are positional.)
#[no_mangle]
pub unsafe extern "C" fn tract_model_output_fact(
    model: *const TractModel,
    output_id: usize,
    fact: *mut *mut TractFact,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model, fact);
        *fact = std::ptr::null_mut();
        let f = (*model).0.output_fact(output_id)?;
        *fact = Box::into_raw(Box::new(TractFact(f.clone())));
        Ok(())
    })
}
#[no_mangle]
pub unsafe extern "C" fn tract_model_concretize_symbols(
model: *mut TractModel,
nb_symbols: usize,
symbols: *const *const i8,
values: *const i64,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(model, symbols, values);
let mut table = SymbolValues::default();
let model = &mut (*model).0;
for i in 0..nb_symbols {
let name = CStr::from_ptr(*symbols.add(i)).to_str().with_context(|| {
format!("failed to parse symbol name for {i}th symbol (not utf8)")
})?;
table = table.with(&model.symbol_table.sym(name), *values.add(i));
}
let mut new = model.concretize_dims(&table)?;
std::mem::swap(model, &mut new);
Ok(())
})
}
#[no_mangle]
pub unsafe extern "C" fn tract_model_pulse_simple(
model: *mut *mut TractModel,
stream_symbol: *const i8,
pulse_expr: *const i8,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(model, *model, stream_symbol, pulse_expr);
let model = &mut (**model).0;
let stream_sym = model.symbol_table.sym(
CStr::from_ptr(stream_symbol)
.to_str()
.context("failed to parse stream symbol name (not utf8)")?,
);
let pulse_dim = parse_tdim(
&model.symbol_table,
CStr::from_ptr(pulse_expr)
.to_str()
.context("failed to parse stream symbol name (not utf8)")?,
)?;
let mut pulsed = PulsedModel::new(model, stream_sym, &pulse_dim)?.into_typed()?;
std::mem::swap(model, &mut pulsed);
Ok(())
})
}
/// Runs tract's declutter pass on the typed model in place.
#[no_mangle]
pub unsafe extern "C" fn tract_model_declutter(model: *mut TractModel) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model);
        let typed = &mut (*model).0;
        typed.declutter()
    })
}
/// Runs tract's optimization pass on the typed model in place.
#[no_mangle]
pub unsafe extern "C" fn tract_model_optimize(model: *mut TractModel) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model);
        let typed = &mut (*model).0;
        typed.optimize()
    })
}
#[no_mangle]
pub unsafe extern "C" fn tract_model_profile_json(
model: *mut TractModel,
inputs: *mut *mut TractValue,
json: *mut *mut i8,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(model, json);
let model = &(*model).0;
let mut annotations = Annotations::from_model(model)?;
tract_libcli::profile::extract_costs(&mut annotations, model)?;
if !inputs.is_null() {
let input_len = model.inputs.len();
let values:TVec<TValue> =
std::slice::from_raw_parts(inputs, input_len).iter().map(|tv| (**tv).0.clone()).collect();
tract_libcli::profile::profile(model, &BenchLimits::default(), &mut annotations, &values)?;
}
let export = tract_libcli::export::GraphPerfInfo::from(model, &annotations);
*json = CString::new(serde_json::to_string(&export)?)?.into_raw();
Ok(())
})
}
/// Consumes the typed model in `*model` and stores a runnable plan in `*runnable`.
///
/// `*model` must be non-null: it is checked before `Box::from_raw`, which would otherwise
/// be UB. The model is reset to null before conversion, so it is consumed even if the
/// conversion fails. `*runnable` is null on failure and must otherwise be released with
/// `tract_runnable_release`.
#[no_mangle]
pub unsafe extern "C" fn tract_model_into_runnable(
    model: *mut *mut TractModel,
    runnable: *mut *mut TractRunnable,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model, *model, runnable);
        *runnable = std::ptr::null_mut();
        let m = Box::from_raw(*model);
        *model = std::ptr::null_mut();
        let runnable_model = m.0.into_runnable()?;
        *runnable = Box::into_raw(Box::new(TractRunnable(Arc::new(runnable_model))));
        Ok(())
    })
}
/// Stores in `*count` the number of properties attached to the typed model.
#[no_mangle]
pub unsafe extern "C" fn tract_model_property_count(
    model: *const TractModel,
    count: *mut usize,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model, count);
        let properties = &(*model).0.properties;
        count.write(properties.len());
        Ok(())
    })
}
#[no_mangle]
pub unsafe extern "C" fn tract_model_property_names(
model: *const TractModel,
names: *mut *mut i8,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(model, names);
for (ix, name) in (*model).0.properties.keys().enumerate() {
*names.add(ix) = CString::new(&**name)?.into_raw();
}
Ok(())
})
}
#[no_mangle]
pub unsafe extern "C" fn tract_model_property(
model: *const TractModel,
name: *const i8,
value: *mut *mut TractValue,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(model, name, value);
let name = CStr::from_ptr(name)
.to_str()
.context("failed to parse property name (not utf8)")?
.to_owned();
let v = (*model).0.properties.get(&name).context("Property not found")?;
*value = Box::into_raw(Box::new(TractValue(v.clone().into_tvalue())));
Ok(())
})
}
/// Destroys a typed-model handle.
///
/// Fails if `model` or `*model` is null; on success `*model` is reset to null.
#[no_mangle]
pub unsafe extern "C" fn tract_model_destroy(model: *mut *mut TractModel) -> TRACT_RESULT {
release!(model)
}
/// Opaque handle to a runnable plan. The plan is held in an `Arc`, so states spawned by
/// `tract_runnable_spawn_state` share it (they hold clones of the `Arc`).
pub struct TractRunnable(Arc<native::TypedRunnableModel<native::TypedModel>>);
#[no_mangle]
pub unsafe extern "C" fn tract_runnable_spawn_state(
runnable: *mut TractRunnable,
state: *mut *mut TractState,
) -> TRACT_RESULT {
wrap(|| unsafe {
check_not_null!(runnable, state);
*state = std::ptr::null_mut();
let s = native::TypedSimpleState::new((*runnable).0.clone())?;
*state = Box::into_raw(Box::new(TractState(s)));
Ok(())
})
}
/// Runs the plan once on a throw-away state.
///
/// `inputs` must point to one `TractValue` handle per model input. `outputs` must have
/// room for one pointer per model output. Each written output is a new handle the caller
/// releases with `tract_value_destroy`. Both arrays are null-checked, as in
/// `tract_state_run`.
#[no_mangle]
pub unsafe extern "C" fn tract_runnable_run(
    runnable: *mut TractRunnable,
    inputs: *mut *mut TractValue,
    outputs: *mut *mut TractValue,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(runnable, inputs, outputs);
        let mut s = native::TypedSimpleState::new((*runnable).0.clone())?;
        state_run(&mut s, inputs, outputs)
    })
}
/// Reports the number of inputs and/or outputs of the model behind a runnable.
///
/// Either out-pointer may be null, in which case that count is not computed.
#[no_mangle]
pub unsafe extern "C" fn tract_runnable_nbio(
    runnable: *const TractRunnable,
    inputs: *mut usize,
    outputs: *mut usize,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(runnable);
        let typed = (*runnable).0.model();
        if !inputs.is_null() {
            inputs.write(typed.input_outlets()?.len());
        }
        if !outputs.is_null() {
            outputs.write(typed.output_outlets()?.len());
        }
        Ok(())
    })
}
/// Releases this handle's reference to a runnable plan (states spawned from it keep their
/// own `Arc` clone).
///
/// Fails if `runnable` or `*runnable` is null; on success `*runnable` is reset to null.
#[no_mangle]
pub unsafe extern "C" fn tract_runnable_release(runnable: *mut *mut TractRunnable) -> TRACT_RESULT {
release!(runnable)
}
/// Opaque handle to a tensor value (`TValue`) used as a model input or output. Released by
/// `tract_value_destroy`.
pub struct TractValue(TValue);
/// Creates a tensor value of type `datum_type` and the given shape from raw bytes, storing
/// the new handle in `*value`.
///
/// `shape` points to `rank` dimensions and may be null when `rank == 0`. `data` points to
/// `product(shape) * size_of(datum_type)` bytes and may be null only when that size is
/// zero. The bytes are only borrowed for the duration of the call while the owned tensor
/// is built. `*value` is set to null first and stays null on failure (null pointers,
/// size overflow, or tensor construction errors).
#[no_mangle]
pub unsafe extern "C" fn tract_value_create(
    datum_type: TractDatumType,
    rank: usize,
    shape: *const usize,
    data: *mut c_void,
    value: *mut *mut TractValue,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(value);
        *value = std::ptr::null_mut();
        let dt: DatumType = datum_type.into();
        // `slice::from_raw_parts` requires a non-null pointer even for an empty slice, so
        // handle the empty cases without touching the (possibly null) pointers.
        let shape: &[usize] = if rank == 0 {
            &[]
        } else {
            check_not_null!(shape);
            std::slice::from_raw_parts(shape, rank)
        };
        let byte_len = shape
            .iter()
            .try_fold(dt.size_of(), |acc, &dim| acc.checked_mul(dim))
            .context("tensor size in bytes overflows usize")?;
        let content: &[u8] = if byte_len == 0 {
            &[]
        } else {
            check_not_null!(data);
            std::slice::from_raw_parts(data as *const u8, byte_len)
        };
        let it = Tensor::from_raw_dt(dt, shape, content)?;
        *value = Box::into_raw(Box::new(TractValue(it.into_tvalue())));
        Ok(())
    })
}
/// Destroys a value handle.
///
/// Fails if `value` or `*value` is null; on success `*value` is reset to null.
#[no_mangle]
pub unsafe extern "C" fn tract_value_destroy(value: *mut *mut TractValue) -> TRACT_RESULT {
release!(value)
}
/// Exposes the datum type, rank, shape pointer and data pointer of a value.
///
/// Every out-pointer is optional (null means "not requested"). The shape and data pointers
/// borrow the value's storage and stay valid only while the value handle is alive. Fails
/// if the value's datum type has no `TractDatumType` counterpart and `datum_type` was
/// requested.
#[no_mangle]
pub unsafe extern "C" fn tract_value_inspect(
    value: *mut TractValue,
    datum_type: *mut TractDatumType,
    rank: *mut usize,
    shape: *mut *const usize,
    data: *mut *const c_void,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(value);
        let tensor: &TValue = &(*value).0;
        if !datum_type.is_null() {
            let tag: TractDatumType = tensor.datum_type().try_into()?;
            datum_type.write(tag);
        }
        if !rank.is_null() {
            rank.write(tensor.rank());
        }
        if !shape.is_null() {
            shape.write(tensor.shape().as_ptr());
        }
        if !data.is_null() {
            data.write(tensor.as_ptr_unchecked::<u8>() as *const c_void);
        }
        Ok(())
    })
}
// Concrete state type behind `TractState`: a simple state over a typed model whose plan
// is shared through an `Arc` (the same `Arc` type held by `TractRunnable`).
type NativeState = native::TypedSimpleState<
native::TypedModel,
Arc<native::TypedRunnableModel<native::TypedModel>>,
>;
/// Opaque handle to an execution state spawned by `tract_runnable_spawn_state`. Released
/// by `tract_state_destroy`.
pub struct TractState(NativeState);
/// Runs the model once using this state.
///
/// `inputs` must point to one `TractValue` handle per model input. `outputs` must have
/// room for one pointer per model output. Each written output is a new handle the caller
/// releases with `tract_value_destroy`.
#[no_mangle]
pub unsafe extern "C" fn tract_state_run(
    state: *mut TractState,
    inputs: *mut *mut TractValue,
    outputs: *mut *mut TractValue,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(state, inputs, outputs);
        let native_state = &mut (*state).0;
        state_run(native_state, inputs, outputs)
    })
}
/// Destroys a state handle.
///
/// Fails if `state` or `*state` is null; on success `*state` is reset to null.
#[no_mangle]
pub unsafe extern "C" fn tract_state_destroy(state: *mut *mut TractState) -> TRACT_RESULT {
release!(state)
}
/// Opaque handle to a `TypedFact` (typing information for a typed-model input or output).
/// Released by `tract_fact_destroy`.
pub struct TractFact(TypedFact);
/// Parses a textual fact specification into a typed fact, using the model's symbol table
/// to resolve symbolic dimensions.
///
/// `*fact` is set to null first (like the other constructors) and stays null on failure.
/// The handle must be released with `tract_fact_destroy`.
#[no_mangle]
pub unsafe extern "C" fn tract_fact_parse(
    model: *mut TractModel,
    spec: *const c_char,
    fact: *mut *mut TractFact,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model, spec, fact);
        *fact = std::ptr::null_mut();
        let spec = CStr::from_ptr(spec).to_str()?;
        let f = tract_libcli::tensor::parse_spec(&(*model).0.symbol_table, spec)?
            .to_typed_fact()?
            .into_owned();
        *fact = Box::into_raw(Box::new(TractFact(f)));
        Ok(())
    })
}
/// Renders a typed fact with its `Debug` representation into `*spec`.
///
/// The string must be freed with `tract_free_cstring`.
#[no_mangle]
pub unsafe extern "C" fn tract_fact_dump(
    fact: *const TractFact,
    spec: *mut *mut c_char,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(fact, spec);
        let rendered = format!("{:?}", (*fact).0);
        *spec = CString::new(rendered)?.into_raw();
        Ok(())
    })
}
/// Destroys a typed-fact handle.
///
/// Fails if `fact` or `*fact` is null; on success `*fact` is reset to null.
#[no_mangle]
pub unsafe extern "C" fn tract_fact_destroy(fact: *mut *mut TractFact) -> TRACT_RESULT {
release!(fact)
}
/// Opaque handle to an `InferenceFact` (typing information for an inference-model input or
/// output). Released by `tract_inference_fact_destroy`.
pub struct TractInferenceFact(InferenceFact);
/// Parses a textual fact specification into an inference fact, using the model's symbol
/// table to resolve symbolic dimensions.
///
/// `*fact` is set to null first (like the other constructors) and stays null on failure.
/// The handle must be released with `tract_inference_fact_destroy`.
#[no_mangle]
pub unsafe extern "C" fn tract_inference_fact_parse(
    model: *mut TractInferenceModel,
    spec: *const c_char,
    fact: *mut *mut TractInferenceFact,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(model, spec, fact);
        *fact = std::ptr::null_mut();
        let spec = CStr::from_ptr(spec).to_str()?;
        let f = tract_libcli::tensor::parse_spec(&(*model).0.symbol_table, spec)?;
        *fact = Box::into_raw(Box::new(TractInferenceFact(f)));
        Ok(())
    })
}
/// Renders an inference fact with its `Debug` representation into `*spec`.
///
/// The string must be freed with `tract_free_cstring`.
#[no_mangle]
pub unsafe extern "C" fn tract_inference_fact_dump(
    fact: *const TractInferenceFact,
    spec: *mut *mut c_char,
) -> TRACT_RESULT {
    wrap(|| unsafe {
        check_not_null!(fact, spec);
        let rendered = format!("{:?}", (*fact).0);
        *spec = CString::new(rendered)?.into_raw();
        Ok(())
    })
}
/// Destroys an inference-fact handle.
///
/// Fails if `fact` or `*fact` is null; on success `*fact` is reset to null.
#[no_mangle]
pub unsafe extern "C" fn tract_inference_fact_destroy(
fact: *mut *mut TractInferenceFact,
) -> TRACT_RESULT {
release!(fact)
}
unsafe fn state_run(
state: &mut NativeState,
inputs: *mut *mut TractValue,
outputs: *mut *mut TractValue,
) -> TractResult<()> {
let input_len = state.model().inputs.len();
let values =
std::slice::from_raw_parts(inputs, input_len).iter().map(|tv| (**tv).0.clone()).collect();
let values = state.run(values)?;
for (i, value) in values.into_iter().enumerate() {
*(outputs.add(i)) = Box::into_raw(Box::new(TractValue(value)))
}
Ok(())
}