use std::ffi::{CStr, CString};
use std::path::Path;
use std::ptr::{null, null_mut};
use tract_api::*;
use tract_proxy_sys as sys;
use anyhow::{Context, Result};
mod ndarray_interop;
pub use ndarray_interop::__ndarray_interop;
/// Evaluates an FFI call and turns its status code into a `Result<()>`.
///
/// If `$expr` equals `sys::TRACT_RESULT_TRACT_RESULT_KO`, the message returned
/// by `sys::tract_get_last_error` is copied (lossily, if it is not valid
/// UTF-8) into an `anyhow` error. Any other status yields `Ok(())`.
///
/// The whole expansion sits inside an `unsafe` block, so `$expr` may itself be
/// an unsafe FFI call.
macro_rules! check {
    ($expr:expr) => {
        unsafe {
            if $expr == sys::TRACT_RESULT_TRACT_RESULT_KO {
                let msg = sys::tract_get_last_error();
                if msg.is_null() {
                    // Building a `CStr` from a null pointer is undefined
                    // behavior, so report a generic failure instead.
                    Err(anyhow::anyhow!("tract call failed without reporting an error message"))
                } else {
                    Err(anyhow::anyhow!(CStr::from_ptr(msg).to_string_lossy().into_owned()))
                }
            } else {
                Ok(())
            }
        }
    };
}
/// Declares an owning newtype around a raw `sys` handle.
///
/// `$new_type` is a tuple struct holding `*mut sys::$c_type` (followed by any
/// extra `$typ` fields). It derives `Debug` and, when dropped, passes a
/// pointer to its handle to `sys::$dest`.
///
/// `Clone` is intentionally not derived here; types that support cloning get
/// it through `wrapper_clone!`.
macro_rules! wrapper {
    ($new_type:ident, $c_type:ident, $dest:ident $(, $typ:ty )*) => {
        #[derive(Debug)]
        pub struct $new_type(*mut sys::$c_type $(, $typ)*);

        impl Drop for $new_type {
            fn drop(&mut self) {
                // `drop` cannot report failure, so the status returned by the
                // destructor is discarded.
                let handle = &mut self.0;
                unsafe {
                    sys::$dest(handle);
                }
            }
        }
    };
}
/// Implements `Clone` for a handle newtype by delegating to `sys::$clone_fn`.
///
/// The clone function receives the source handle and an out-pointer that
/// receives the new handle; the result is wrapped in a fresh `$new_type`.
///
/// # Panics
/// `Clone::clone` cannot return an error, so a failed clone panics with the
/// message reported by the C API instead of silently wrapping a null handle.
macro_rules! wrapper_clone {
    ($new_type:ident, $clone_fn:ident) => {
        impl Clone for $new_type {
            fn clone(&self) -> Self {
                let mut clone = null_mut();
                if let Err(e) = check!(sys::$clone_fn(self.0, &mut clone)) {
                    panic!("{} failed: {}", stringify!($clone_fn), e);
                }
                $new_type(clone)
            }
        }
    };
}
/// Creates a new [`Nnef`] handle through `tract_nnef_create`.
///
/// # Errors
/// Returns the error reported by the C API if creation fails.
pub fn nnef() -> Result<Nnef> {
    let mut handle = null_mut();
    let status = check!(sys::tract_nnef_create(&mut handle));
    status.map(|()| Nnef(handle))
}
/// Creates a new [`Onnx`] handle through `tract_onnx_create`.
///
/// # Errors
/// Returns the error reported by the C API if creation fails.
pub fn onnx() -> Result<Onnx> {
    let mut handle = null_mut();
    let status = check!(sys::tract_onnx_create(&mut handle));
    status.map(|()| Onnx(handle))
}
/// Returns the version string reported by `tract_version`.
///
/// # Panics
/// Panics if the returned C string is not valid UTF-8.
pub fn version() -> &'static str {
    let raw: &'static CStr = unsafe { CStr::from_ptr(sys::tract_version()) };
    raw.to_str().unwrap()
}
wrapper!(Nnef, TractNnef, tract_nnef_destroy);
impl NnefInterface for Nnef {
type Model = Model;
fn load(&self, path: impl AsRef<Path>) -> Result<Model> {
let path = path.as_ref();
let path = CString::new(
path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
)?;
let mut model = null_mut();
check!(sys::tract_nnef_load(self.0, path.as_ptr(), &mut model))?;
Ok(Model(model))
}
fn load_buffer(&self, data: &[u8]) -> Result<Model> {
let mut model = null_mut();
check!(sys::tract_nnef_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
Ok(Model(model))
}
fn enable_tract_core(&mut self) -> Result<()> {
check!(sys::tract_nnef_enable_tract_core(self.0))
}
fn enable_tract_extra(&mut self) -> Result<()> {
check!(sys::tract_nnef_enable_tract_extra(self.0))
}
fn enable_tract_transformers(&mut self) -> Result<()> {
check!(sys::tract_nnef_enable_tract_transformers(self.0))
}
fn enable_onnx(&mut self) -> Result<()> {
check!(sys::tract_nnef_enable_onnx(self.0))
}
fn enable_pulse(&mut self) -> Result<()> {
check!(sys::tract_nnef_enable_pulse(self.0))
}
fn enable_extended_identifier_syntax(&mut self) -> Result<()> {
check!(sys::tract_nnef_enable_extended_identifier_syntax(self.0))
}
fn write_model_to_dir(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
let path = path.as_ref();
let path = CString::new(
path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
)?;
check!(sys::tract_nnef_write_model_to_dir(self.0, path.as_ptr(), model.0))?;
Ok(())
}
fn write_model_to_tar(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
let path = path.as_ref();
let path = CString::new(
path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
)?;
check!(sys::tract_nnef_write_model_to_tar(self.0, path.as_ptr(), model.0))?;
Ok(())
}
fn write_model_to_tar_gz(&self, path: impl AsRef<Path>, model: &Model) -> Result<()> {
let path = path.as_ref();
let path = CString::new(
path.to_str().with_context(|| format!("Failed to re-encode {path:?} to uff-8"))?,
)?;
check!(sys::tract_nnef_write_model_to_tar_gz(self.0, path.as_ptr(), model.0))?;
Ok(())
}
}
wrapper!(Onnx, TractOnnx, tract_onnx_destroy);

impl OnnxInterface for Onnx {
    type InferenceModel = InferenceModel;

    /// Loads an inference model from `path` via `tract_onnx_load`.
    ///
    /// # Errors
    /// Fails if the path is not valid UTF-8, contains an interior NUL byte, or
    /// the C API reports an error.
    fn load(&self, path: impl AsRef<Path>) -> Result<InferenceModel> {
        let path = path.as_ref();
        let path = CString::new(
            path.to_str().with_context(|| format!("Failed to re-encode {path:?} to utf-8"))?,
        )?;
        let mut model = null_mut();
        check!(sys::tract_onnx_load(self.0, path.as_ptr(), &mut model))?;
        Ok(InferenceModel(model))
    }

    /// Loads an inference model from an in-memory buffer via
    /// `tract_onnx_load_buffer`.
    fn load_buffer(&self, data: &[u8]) -> Result<InferenceModel> {
        let mut model = null_mut();
        check!(sys::tract_onnx_load_buffer(self.0, data.as_ptr() as _, data.len(), &mut model))?;
        Ok(InferenceModel(model))
    }
}
wrapper!(InferenceModel, TractInferenceModel, tract_inference_model_destroy);

impl InferenceModelInterface for InferenceModel {
    type Model = Model;
    type InferenceFact = InferenceFact;

    /// Returns the count reported by `tract_inference_model_input_count`.
    fn input_count(&self) -> Result<usize> {
        let mut count = 0;
        check!(sys::tract_inference_model_input_count(self.0, &mut count))?;
        Ok(count)
    }

    /// Returns the count reported by `tract_inference_model_output_count`.
    fn output_count(&self) -> Result<usize> {
        let mut count = 0;
        check!(sys::tract_inference_model_output_count(self.0, &mut count))?;
        Ok(count)
    }

    /// Returns the name of input `id`.
    ///
    /// # Errors
    /// Fails if the C API reports an error or the name is not valid UTF-8.
    fn input_name(&self, id: usize) -> Result<String> {
        let mut ptr = null_mut();
        check!(sys::tract_inference_model_input_name(self.0, id, &mut ptr))?;
        // Copy and free before decoding, so a UTF-8 failure cannot leak the
        // C-allocated string.
        let name = unsafe {
            let owned = CStr::from_ptr(ptr).to_owned();
            sys::tract_free_cstring(ptr);
            owned
        };
        Ok(name.to_str()?.to_owned())
    }

    /// Returns the name of output `id`.
    ///
    /// # Errors
    /// Fails if the C API reports an error or the name is not valid UTF-8.
    fn output_name(&self, id: usize) -> Result<String> {
        let mut ptr = null_mut();
        check!(sys::tract_inference_model_output_name(self.0, id, &mut ptr))?;
        // Copy and free before decoding, so a UTF-8 failure cannot leak the
        // C-allocated string.
        let name = unsafe {
            let owned = CStr::from_ptr(ptr).to_owned();
            sys::tract_free_cstring(ptr);
            owned
        };
        Ok(name.to_str()?.to_owned())
    }

    /// Returns the fact of input `id` as a new owned [`InferenceFact`].
    fn input_fact(&self, id: usize) -> Result<InferenceFact> {
        let mut ptr = null_mut();
        check!(sys::tract_inference_model_input_fact(self.0, id, &mut ptr))?;
        Ok(InferenceFact(ptr))
    }

    /// Converts `fact` with `as_fact` and sets it on input `id`.
    fn set_input_fact(
        &mut self,
        id: usize,
        fact: impl AsFact<Self, Self::InferenceFact>,
    ) -> Result<()> {
        let fact = fact.as_fact(self)?;
        check!(sys::tract_inference_model_set_input_fact(self.0, id, fact.0))
    }

    /// Returns the fact of output `id` as a new owned [`InferenceFact`].
    fn output_fact(&self, id: usize) -> Result<InferenceFact> {
        let mut ptr = null_mut();
        check!(sys::tract_inference_model_output_fact(self.0, id, &mut ptr))?;
        Ok(InferenceFact(ptr))
    }

    /// Converts `fact` with `as_fact` and sets it on output `id`.
    fn set_output_fact(
        &mut self,
        id: usize,
        fact: impl AsFact<InferenceModel, InferenceFact>,
    ) -> Result<()> {
        let fact = fact.as_fact(self)?;
        check!(sys::tract_inference_model_set_output_fact(self.0, id, fact.0))
    }

    /// Calls `tract_inference_model_analyse` on this model.
    fn analyse(&mut self) -> Result<()> {
        check!(sys::tract_inference_model_analyse(self.0))
    }

    /// Consumes this inference model and converts it into a [`Model`].
    ///
    /// The C function receives a pointer to the handle itself.
    /// NOTE(review): presumably it takes ownership and nulls the handle, so
    /// the `Drop` that still runs on `self` is harmless — confirm against the
    /// C API.
    fn into_model(mut self) -> Result<Self::Model> {
        let mut ptr = null_mut();
        check!(sys::tract_inference_model_into_model(&mut self.0, &mut ptr))?;
        Ok(Model(ptr))
    }
}
wrapper!(Model, TractModel, tract_model_destroy);

impl ModelInterface for Model {
    type Fact = Fact;
    type Tensor = Tensor;
    type Runnable = Runnable;

    /// Returns the count reported by `tract_model_input_count`.
    fn input_count(&self) -> Result<usize> {
        let mut count = 0;
        check!(sys::tract_model_input_count(self.0, &mut count))?;
        Ok(count)
    }

    /// Returns the count reported by `tract_model_output_count`.
    fn output_count(&self) -> Result<usize> {
        let mut count = 0;
        check!(sys::tract_model_output_count(self.0, &mut count))?;
        Ok(count)
    }

    /// Returns the name of input `id`.
    ///
    /// # Errors
    /// Fails if the C API reports an error or the name is not valid UTF-8.
    fn input_name(&self, id: usize) -> Result<String> {
        let mut ptr = null_mut();
        check!(sys::tract_model_input_name(self.0, id, &mut ptr))?;
        // Copy and free before decoding, so a UTF-8 failure cannot leak the
        // C-allocated string.
        let name = unsafe {
            let owned = CStr::from_ptr(ptr).to_owned();
            sys::tract_free_cstring(ptr);
            owned
        };
        Ok(name.to_str()?.to_owned())
    }

    /// Returns the name of output `id`.
    ///
    /// # Errors
    /// Fails if the C API reports an error or the name is not valid UTF-8.
    fn output_name(&self, id: usize) -> Result<String> {
        let mut ptr = null_mut();
        check!(sys::tract_model_output_name(self.0, id, &mut ptr))?;
        // Copy and free before decoding, so a UTF-8 failure cannot leak the
        // C-allocated string.
        let name = unsafe {
            let owned = CStr::from_ptr(ptr).to_owned();
            sys::tract_free_cstring(ptr);
            owned
        };
        Ok(name.to_str()?.to_owned())
    }

    /// Returns the fact of input `id` as a new owned [`Fact`].
    fn input_fact(&self, id: usize) -> Result<Fact> {
        let mut ptr = null_mut();
        check!(sys::tract_model_input_fact(self.0, id, &mut ptr))?;
        Ok(Fact(ptr))
    }

    /// Returns the fact of output `id` as a new owned [`Fact`].
    fn output_fact(&self, id: usize) -> Result<Fact> {
        let mut ptr = null_mut();
        check!(sys::tract_model_output_fact(self.0, id, &mut ptr))?;
        Ok(Fact(ptr))
    }

    /// Consumes the model and converts it into a [`Runnable`].
    ///
    /// The C function receives a pointer to the handle itself.
    /// NOTE(review): presumably it takes ownership and nulls the handle —
    /// confirm against the C API.
    fn into_runnable(self) -> Result<Runnable> {
        let mut model = self;
        let mut runnable = null_mut();
        check!(sys::tract_model_into_runnable(&mut model.0, &mut runnable))?;
        Ok(Runnable(runnable))
    }

    /// Applies the transform described by `spec`, passed to the C API as the
    /// string produced by `to_transform_string`.
    fn transform(&mut self, spec: impl Into<TransformSpec>) -> Result<()> {
        let transform = spec.into().to_transform_string();
        let t = CString::new(transform)?;
        check!(sys::tract_model_transform(self.0, t.as_ptr()))
    }

    /// Returns the names of all model properties.
    ///
    /// # Errors
    /// Fails if the C API reports an error or a name is not valid UTF-8.
    fn property_keys(&self) -> Result<Vec<String>> {
        let mut len = 0;
        check!(sys::tract_model_property_count(self.0, &mut len))?;
        let mut keys = vec![null_mut(); len];
        check!(sys::tract_model_property_names(self.0, keys.as_mut_ptr()))?;
        // Copy and free every C string first; decoding comes afterwards so a
        // UTF-8 failure on one name cannot leak it or the remaining ones.
        let owned: Vec<CString> = unsafe {
            keys.into_iter()
                .map(|pc| {
                    let s = CStr::from_ptr(pc).to_owned();
                    sys::tract_free_cstring(pc);
                    s
                })
                .collect()
        };
        owned.iter().map(|s| Ok(s.to_str()?.to_owned())).collect()
    }

    /// Returns the property called `name` as a new owned [`Tensor`].
    fn property(&self, name: impl AsRef<str>) -> Result<Tensor> {
        let mut v = null_mut();
        let name = CString::new(name.as_ref())?;
        check!(sys::tract_model_property(self.0, name.as_ptr(), &mut v))?;
        Ok(Tensor(v))
    }

    /// Parses `spec` into a [`Fact`] using `tract_model_parse_fact`.
    fn parse_fact(&self, spec: &str) -> Result<Self::Fact> {
        let spec = CString::new(spec)?;
        let mut ptr = null_mut();
        check!(sys::tract_model_parse_fact(self.0, spec.as_ptr(), &mut ptr))?;
        Ok(Fact(ptr))
    }
}
wrapper!(Runtime, TractRuntime, tract_runtime_release);

/// Looks up a runtime by name through `tract_runtime_for_name`.
///
/// # Errors
/// Fails if `name` contains an interior NUL byte or the C API reports an error.
pub fn runtime_for_name(name: &str) -> Result<Runtime> {
    let mut handle = null_mut();
    let c_name = CString::new(name)?;
    let status = check!(sys::tract_runtime_for_name(c_name.as_ptr(), &mut handle));
    status.map(|()| Runtime(handle))
}
impl RuntimeInterface for Runtime {
    type Runnable = Runnable;
    type Model = Model;

    /// Returns the runtime name reported by `tract_runtime_name`.
    ///
    /// # Errors
    /// Fails if the C API reports an error or the name is not valid UTF-8.
    fn name(&self) -> Result<String> {
        let mut ptr = null_mut();
        check!(sys::tract_runtime_name(self.0, &mut ptr))?;
        // Copy and free before decoding, so a UTF-8 failure cannot leak the
        // C-allocated string.
        let name = unsafe {
            let owned = CStr::from_ptr(ptr).to_owned();
            sys::tract_free_cstring(ptr);
            owned
        };
        Ok(name.to_str()?.to_owned())
    }

    /// Consumes `model` and prepares it into a [`Runnable`] on this runtime.
    ///
    /// The C function receives a pointer to the model handle itself.
    /// NOTE(review): presumably it takes ownership and nulls the handle —
    /// confirm against the C API.
    fn prepare(&self, model: Self::Model) -> Result<Self::Runnable> {
        let mut model = model;
        let mut runnable = null_mut();
        check!(sys::tract_runtime_prepare(self.0, &mut model.0, &mut runnable))?;
        Ok(Runnable(runnable))
    }
}
wrapper!(Runnable, TractRunnable, tract_runnable_release);

// SAFETY: NOTE(review): these impls assume the underlying tract runnable may
// be shared and moved across threads; that cannot be verified from this file.
unsafe impl Send for Runnable {}
unsafe impl Sync for Runnable {}

impl RunnableInterface for Runnable {
    type Tensor = Tensor;
    type State = State;
    type Fact = Fact;

    /// Spawns a fresh state and runs `inputs` through it.
    fn run(&self, inputs: impl IntoInputs<Tensor>) -> Result<Vec<Tensor>> {
        StateInterface::run(&mut self.spawn_state()?, inputs.into_inputs()?)
    }

    /// Creates a new [`State`] via `tract_runnable_spawn_state`.
    fn spawn_state(&self) -> Result<State> {
        let mut state = null_mut();
        check!(sys::tract_runnable_spawn_state(self.0, &mut state))?;
        Ok(State(state))
    }

    /// Returns the count reported by `tract_runnable_input_count`.
    fn input_count(&self) -> Result<usize> {
        let mut count = 0;
        check!(sys::tract_runnable_input_count(self.0, &mut count))?;
        Ok(count)
    }

    /// Returns the count reported by `tract_runnable_output_count`.
    fn output_count(&self) -> Result<usize> {
        let mut count = 0;
        check!(sys::tract_runnable_output_count(self.0, &mut count))?;
        Ok(count)
    }

    /// Returns the fact of input `id` as a new owned [`Fact`].
    fn input_fact(&self, id: usize) -> Result<Self::Fact> {
        let mut ptr = null_mut();
        check!(sys::tract_runnable_input_fact(self.0, id, &mut ptr))?;
        Ok(Fact(ptr))
    }

    /// Returns the fact of output `id` as a new owned [`Fact`].
    fn output_fact(&self, id: usize) -> Result<Self::Fact> {
        let mut ptr = null_mut();
        check!(sys::tract_runnable_output_fact(self.0, id, &mut ptr))?;
        Ok(Fact(ptr))
    }

    /// Returns the names of all properties.
    ///
    /// # Errors
    /// Fails if the C API reports an error or a name is not valid UTF-8.
    fn property_keys(&self) -> Result<Vec<String>> {
        let mut len = 0;
        check!(sys::tract_runnable_property_count(self.0, &mut len))?;
        let mut keys = vec![null_mut(); len];
        check!(sys::tract_runnable_property_names(self.0, keys.as_mut_ptr()))?;
        // Copy and free every C string first; decoding comes afterwards so a
        // UTF-8 failure on one name cannot leak it or the remaining ones.
        let owned: Vec<CString> = unsafe {
            keys.into_iter()
                .map(|pc| {
                    let s = CStr::from_ptr(pc).to_owned();
                    sys::tract_free_cstring(pc);
                    s
                })
                .collect()
        };
        owned.iter().map(|s| Ok(s.to_str()?.to_owned())).collect()
    }

    /// Returns the property called `name` as a new owned [`Tensor`].
    fn property(&self, name: impl AsRef<str>) -> Result<Tensor> {
        let mut v = null_mut();
        let name = CString::new(name.as_ref())?;
        check!(sys::tract_runnable_property(self.0, name.as_ptr(), &mut v))?;
        Ok(Tensor(v))
    }

    /// Equivalent to [`Self::profile_json`] called without inputs.
    fn cost_json(&self) -> Result<String> {
        let input: Option<Vec<Tensor>> = None;
        self.profile_json(input)
    }

    /// Returns the JSON produced by `tract_runnable_profile_json`.
    ///
    /// When `inputs` is `Some`, every item is converted to a [`Tensor`] and
    /// the number of tensors must equal `input_count()`; when it is `None`, a
    /// null input array is passed to the C API.
    ///
    /// # Errors
    /// Fails on conversion errors, an input-count mismatch, a C API error, a
    /// null result, or JSON that is not valid UTF-8.
    fn profile_json<I, IV, IE>(&self, inputs: Option<I>) -> Result<String>
    where
        I: IntoIterator<Item = IV>,
        IV: TryInto<Self::Tensor, Error = IE>,
        IE: Into<anyhow::Error>,
    {
        let inputs = if let Some(inputs) = inputs {
            let inputs = inputs
                .into_iter()
                .map(|i| i.try_into().map_err(|e| e.into()))
                .collect::<Result<Vec<Tensor>>>()?;
            anyhow::ensure!(self.input_count()? == inputs.len());
            Some(inputs)
        } else {
            None
        };
        // `inputs` stays alive until the end of the function, so the raw
        // handles collected here remain valid during the FFI call.
        let mut iptrs: Option<Vec<*mut sys::TractTensor>> =
            inputs.as_ref().map(|is| is.iter().map(|v| v.0).collect());
        // `c_char` is `i8` on some targets and `u8` on others, so use the
        // platform alias rather than hard-coding `i8`.
        let mut json: *mut std::ffi::c_char = null_mut();
        let values = iptrs.as_mut().map(|it| it.as_mut_ptr()).unwrap_or(null_mut());
        check!(sys::tract_runnable_profile_json(self.0, values, &mut json))?;
        anyhow::ensure!(!json.is_null());
        unsafe {
            let s = CStr::from_ptr(json).to_owned();
            sys::tract_free_cstring(json);
            Ok(s.to_str()?.to_owned())
        }
    }
}
/// Owning wrapper around a raw `sys::TractState` handle.
pub struct State(*mut sys::TractState);

impl Drop for State {
    /// Hands a pointer to the handle to `tract_state_destroy`; the returned
    /// status is discarded.
    fn drop(&mut self) {
        let handle = &mut self.0;
        unsafe {
            sys::tract_state_destroy(handle);
        }
    }
}
impl std::fmt::Debug for State {
    /// Formats the state as `State(<pointer>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let handle = self.0;
        f.write_fmt(format_args!("State({:?})", handle))
    }
}
impl Clone for State {
fn clone(&self) -> Self {
let mut clone = null_mut();
unsafe {
sys::tract_state_clone(self.0, &mut clone);
}
State(clone)
}
}
unsafe impl Send for State {}
impl StateInterface for State {
type Tensor = Tensor;
type Fact = Fact;
fn run(&mut self, inputs: impl IntoInputs<Tensor>) -> Result<Vec<Tensor>> {
let inputs = inputs.into_inputs()?;
let mut outputs = vec![null_mut(); self.output_count()?];
let mut inputs: Vec<_> = inputs.iter().map(|v| v.0).collect();
check!(sys::tract_state_run(self.0, inputs.as_mut_ptr(), outputs.as_mut_ptr()))?;
let outputs = outputs.into_iter().map(Tensor).collect();
Ok(outputs)
}
fn input_count(&self) -> Result<usize> {
let mut count = 0;
check!(sys::tract_state_input_count(self.0, &mut count))?;
Ok(count)
}
fn output_count(&self) -> Result<usize> {
let mut count = 0;
check!(sys::tract_state_output_count(self.0, &mut count))?;
Ok(count)
}
}
wrapper!(Tensor, TractTensor, tract_tensor_destroy);
wrapper_clone!(Tensor, tract_tensor_clone);

// SAFETY: NOTE(review): these impls assume the underlying tract tensor may be
// shared and moved across threads; that cannot be verified from this file.
unsafe impl Send for Tensor {}
unsafe impl Sync for Tensor {}

impl TensorInterface for Tensor {
    /// Builds a tensor from raw bytes.
    ///
    /// `data` must hold exactly `product(shape) * dt.size_of()` bytes.
    ///
    /// # Errors
    /// Fails if the byte size overflows `usize`, if `data` has the wrong
    /// length, or if the C API reports an error.
    fn from_bytes(dt: DatumType, shape: &[usize], data: &[u8]) -> Result<Self> {
        // Checked arithmetic: a wrapped product could accept a buffer that is
        // too short, which the C side would then read past.
        let expected = if shape.contains(&0) {
            0
        } else {
            shape
                .iter()
                .try_fold(dt.size_of(), |acc, &dim| acc.checked_mul(dim))
                .with_context(|| format!("Byte size of shape {shape:?} overflows usize"))?
        };
        anyhow::ensure!(
            data.len() == expected,
            "Expected {} byte(s) for shape {:?}, got {}",
            expected,
            shape,
            data.len()
        );
        let mut value = null_mut();
        check!(sys::tract_tensor_from_bytes(
            dt as _,
            shape.len(),
            shape.as_ptr(),
            data.as_ptr() as _,
            &mut value
        ))?;
        Ok(Tensor(value))
    }

    /// Returns the datum type, shape and raw bytes of the tensor.
    ///
    /// The slices are built from pointers returned by `tract_tensor_as_bytes`
    /// and are tied to the lifetime of `&self`.
    /// NOTE(review): assumes those pointers are non-null and stay valid for as
    /// long as the tensor lives — confirm against the C API.
    fn as_bytes(&self) -> Result<(DatumType, &[usize], &[u8])> {
        let mut rank = 0;
        // Placeholder value; overwritten by the FFI call.
        let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
        let mut shape = null();
        let mut data = null();
        check!(sys::tract_tensor_as_bytes(self.0, &mut dt, &mut rank, &mut shape, &mut data))?;
        unsafe {
            // NOTE(review): assumes `DatumType` has the same layout and
            // discriminants as the C enum — confirm.
            let dt: DatumType = std::mem::transmute(dt);
            let shape = std::slice::from_raw_parts(shape, rank);
            let len: usize = shape.iter().product();
            let data = std::slice::from_raw_parts(data as *const u8, len * dt.size_of());
            Ok((dt, shape, data))
        }
    }

    /// Returns only the datum type, passing null for the other out-parameters
    /// of `tract_tensor_as_bytes`.
    fn datum_type(&self) -> Result<DatumType> {
        // Placeholder value; overwritten by the FFI call.
        let mut dt = sys::DatumType_TRACT_DATUM_TYPE_BOOL as _;
        check!(sys::tract_tensor_as_bytes(
            self.0,
            &mut dt,
            null_mut(),
            null_mut(),
            null_mut()
        ))?;
        unsafe {
            let dt: DatumType = std::mem::transmute(dt);
            Ok(dt)
        }
    }

    /// Returns a new tensor converted to datum type `to` via
    /// `tract_tensor_convert_to`.
    fn convert_to(&self, to: DatumType) -> Result<Self> {
        let mut new = null_mut();
        check!(sys::tract_tensor_convert_to(self.0, to as _, &mut new))?;
        Ok(Tensor(new))
    }
}
impl PartialEq for Tensor {
    /// Two tensors are equal when datum type, shape and raw bytes all match.
    /// If either tensor's bytes cannot be retrieved, they compare unequal.
    fn eq(&self, other: &Self) -> bool {
        self.as_bytes()
            .and_then(|(lhs_dt, lhs_shape, lhs_data)| {
                other.as_bytes().map(|(rhs_dt, rhs_shape, rhs_data)| {
                    lhs_dt == rhs_dt && lhs_shape == rhs_shape && lhs_data == rhs_data
                })
            })
            .unwrap_or(false)
    }
}
wrapper!(Fact, TractFact, tract_fact_destroy);
wrapper_clone!(Fact, tract_fact_clone);

impl Fact {
    /// Parses the string form of `spec` into a fact using
    /// `tract_model_parse_fact` on `model`.
    fn new(model: &Model, spec: impl ToString) -> Result<Fact> {
        let text = spec.to_string();
        let c_spec = CString::new(text)?;
        let mut handle = null_mut();
        let status = check!(sys::tract_model_parse_fact(model.0, c_spec.as_ptr(), &mut handle));
        status.map(|()| Fact(handle))
    }

    /// Returns the text produced by `tract_fact_dump`.
    fn dump(&self) -> Result<String> {
        let mut raw = null_mut();
        check!(sys::tract_fact_dump(self.0, &mut raw))?;
        // Take an owned copy and release the C string before decoding.
        let owned = unsafe {
            let copy = CStr::from_ptr(raw).to_owned();
            sys::tract_free_cstring(raw);
            copy
        };
        let text = owned.to_str()?;
        Ok(text.to_owned())
    }
}
impl FactInterface for Fact {
    type Dim = Dim;

    /// Returns the datum type reported by `tract_fact_datum_type`.
    fn datum_type(&self) -> Result<DatumType> {
        let mut dt = 0u32;
        // The callee writes through this pointer, so derive it as `*mut`
        // rather than going through a `*const` cast.
        check!(sys::tract_fact_datum_type(self.0, &mut dt as *mut u32 as _))?;
        // NOTE(review): assumes every value written by the C API is a valid
        // `DatumType` discriminant — confirm.
        Ok(unsafe { std::mem::transmute::<u32, DatumType>(dt) })
    }

    /// Returns the rank reported by `tract_fact_rank`.
    fn rank(&self) -> Result<usize> {
        let mut rank = 0;
        check!(sys::tract_fact_rank(self.0, &mut rank))?;
        Ok(rank)
    }

    /// Returns the dimension along `axis` as a new owned [`Dim`].
    fn dim(&self, axis: usize) -> Result<Self::Dim> {
        let mut ptr = null_mut();
        check!(sys::tract_fact_dim(self.0, axis, &mut ptr))?;
        Ok(Dim(ptr))
    }
}
impl std::fmt::Display for Fact {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.dump() {
Ok(s) => f.write_str(&s),
Err(_) => Err(std::fmt::Error),
}
}
}
wrapper!(InferenceFact, TractInferenceFact, tract_inference_fact_destroy);
wrapper_clone!(InferenceFact, tract_inference_fact_clone);

impl InferenceFact {
    /// Parses the string form of `spec` into an inference fact using
    /// `tract_inference_fact_parse` on `model`.
    fn new(model: &InferenceModel, spec: impl ToString) -> Result<InferenceFact> {
        let text = spec.to_string();
        let c_spec = CString::new(text)?;
        let mut handle = null_mut();
        let status =
            check!(sys::tract_inference_fact_parse(model.0, c_spec.as_ptr(), &mut handle));
        status.map(|()| InferenceFact(handle))
    }

    /// Returns the text produced by `tract_inference_fact_dump`.
    fn dump(&self) -> Result<String> {
        let mut raw = null_mut();
        check!(sys::tract_inference_fact_dump(self.0, &mut raw))?;
        // Take an owned copy and release the C string before decoding.
        let owned = unsafe {
            let copy = CStr::from_ptr(raw).to_owned();
            sys::tract_free_cstring(raw);
            copy
        };
        let text = owned.to_str()?;
        Ok(text.to_owned())
    }
}
impl InferenceFactInterface for InferenceFact {
    /// Creates an empty inference fact via `tract_inference_fact_empty`.
    fn empty() -> Result<InferenceFact> {
        let mut handle = null_mut();
        let status = check!(sys::tract_inference_fact_empty(&mut handle));
        status.map(|()| InferenceFact(handle))
    }
}

impl Default for InferenceFact {
    /// Returns the empty inference fact.
    ///
    /// # Panics
    /// Panics if creating the empty fact fails.
    fn default() -> Self {
        let created = <Self as InferenceFactInterface>::empty();
        created.unwrap()
    }
}
impl std::fmt::Display for InferenceFact {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.dump() {
Ok(s) => f.write_str(&s),
Err(_) => Err(std::fmt::Error),
}
}
}
// NOTE(review): these macros are not defined in this file (presumably exported
// by `tract_api`). They are expected to supply the `AsFact` conversions used by
// `set_input_fact` / `set_output_fact` — confirm against `tract_api`.
as_inference_fact_impl!(InferenceModel, InferenceFact);
as_fact_impl!(Model, Fact);
wrapper!(Dim, TractDim, tract_dim_destroy);
wrapper_clone!(Dim, tract_dim_clone);

impl Dim {
    /// Returns the text produced by `tract_dim_dump`.
    fn dump(&self) -> Result<String> {
        let mut raw = null_mut();
        check!(sys::tract_dim_dump(self.0, &mut raw))?;
        // Take an owned copy and release the C string before decoding.
        let owned = unsafe {
            let copy = CStr::from_ptr(raw).to_owned();
            sys::tract_free_cstring(raw);
            copy
        };
        let text = owned.to_str()?;
        Ok(text.to_owned())
    }
}
impl DimInterface for Dim {
    /// Evaluates this dimension with the given `(name, value)` pairs through
    /// `tract_dim_eval` and returns the result as a new [`Dim`].
    ///
    /// # Errors
    /// Fails if a name contains an interior NUL byte or the C API reports an
    /// error.
    fn eval(&self, values: impl IntoIterator<Item = (impl AsRef<str>, i64)>) -> Result<Self> {
        let (symbols, assignments): (Vec<_>, Vec<i64>) = values.into_iter().unzip();
        // The owned C strings must outlive the pointer array handed to the call.
        let mut c_names: Vec<CString> = Vec::with_capacity(symbols.len());
        for symbol in symbols {
            c_names.push(CString::new(symbol.as_ref())?);
        }
        let name_ptrs: Vec<_> = c_names.iter().map(|name| name.as_ptr()).collect();
        let mut result = null_mut();
        check!(sys::tract_dim_eval(
            self.0,
            name_ptrs.len(),
            name_ptrs.as_ptr(),
            assignments.as_ptr(),
            &mut result
        ))?;
        Ok(Dim(result))
    }

    /// Returns the value written by `tract_dim_to_int64`.
    fn to_int64(&self) -> Result<i64> {
        let mut value = 0;
        let status = check!(sys::tract_dim_to_int64(self.0, &mut value));
        status.map(|()| value)
    }
}
impl std::fmt::Display for Dim {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self.dump() {
Ok(s) => f.write_str(&s),
Err(_) => Err(std::fmt::Error),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A cloned tensor compares equal to its source, and both handles are
    /// dropped at the end of the test.
    #[test]
    fn clone_tensor_no_double_free() {
        let original = Tensor::from_slice::<f32>(&[2, 2], &[1.0, 2.0, 3.0, 4.0]).unwrap();
        let duplicate = original.clone();
        assert_eq!(original, duplicate);
    }
}