ort_tract/
lib.rs

1use ort_sys::OrtErrorCode;
2use tract_onnx::{Onnx, prelude::DatumType};
3
4mod api;
5pub(crate) mod error;
6mod memory;
7mod session;
8mod tensor;
9
10pub use self::api::api;
11use self::error::Error;
12
13pub(crate) struct Environment {
14	pub onnx: Onnx
15}
16
17impl Environment {
18	pub fn new_sys() -> *mut ort_sys::OrtEnv {
19		(Box::leak(Box::new(Self { onnx: tract_onnx::onnx() })) as *mut Environment).cast()
20	}
21
22	pub unsafe fn consume_sys(ptr: *mut ort_sys::OrtEnv) -> Box<Environment> {
23		Box::from_raw(ptr.cast::<Environment>())
24	}
25}
26
27fn convert_sys_to_datum_type(sys: ort_sys::ONNXTensorElementDataType) -> Result<DatumType, Error> {
28	match sys {
29		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => Ok(DatumType::Bool),
30		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => Ok(DatumType::U8),
31		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 => Ok(DatumType::U16),
32		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => Ok(DatumType::U32),
33		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => Ok(DatumType::U64),
34		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => Ok(DatumType::I8),
35		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 => Ok(DatumType::I16),
36		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 => Ok(DatumType::I32),
37		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => Ok(DatumType::I64),
38		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => Ok(DatumType::F16),
39		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => Ok(DatumType::F32),
40		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => Ok(DatumType::F64),
41		ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => Ok(DatumType::String),
42		_ => Err(Error::new(OrtErrorCode::ORT_FAIL, "Element type not supported by tract"))
43	}
44}
45
46fn convert_datum_type_to_sys(dtype: DatumType) -> ort_sys::ONNXTensorElementDataType {
47	match dtype {
48		DatumType::Bool => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
49		DatumType::U8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
50		DatumType::U16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
51		DatumType::U32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
52		DatumType::U64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
53		DatumType::I8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
54		DatumType::I16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
55		DatumType::I32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
56		DatumType::I64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
57		DatumType::F16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
58		DatumType::F32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
59		DatumType::F64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
60		DatumType::String => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
61		_ => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
62	}
63}