1use ort_sys::OrtErrorCode;
2use tract_onnx::{Onnx, prelude::DatumType};
3
4mod api;
5pub(crate) mod error;
6mod memory;
7mod session;
8mod tensor;
9
10pub use self::api::api;
11use self::error::Error;
12
13pub(crate) struct Environment {
14 pub onnx: Onnx
15}
16
17impl Environment {
18 pub fn new_sys() -> *mut ort_sys::OrtEnv {
19 (Box::leak(Box::new(Self { onnx: tract_onnx::onnx() })) as *mut Environment).cast()
20 }
21
22 pub unsafe fn consume_sys(ptr: *mut ort_sys::OrtEnv) -> Box<Environment> {
23 Box::from_raw(ptr.cast::<Environment>())
24 }
25}
26
27fn convert_sys_to_datum_type(sys: ort_sys::ONNXTensorElementDataType) -> Result<DatumType, Error> {
28 match sys {
29 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => Ok(DatumType::Bool),
30 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 => Ok(DatumType::U8),
31 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 => Ok(DatumType::U16),
32 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 => Ok(DatumType::U32),
33 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => Ok(DatumType::U64),
34 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 => Ok(DatumType::I8),
35 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 => Ok(DatumType::I16),
36 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 => Ok(DatumType::I32),
37 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 => Ok(DatumType::I64),
38 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => Ok(DatumType::F16),
39 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => Ok(DatumType::F32),
40 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => Ok(DatumType::F64),
41 ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => Ok(DatumType::String),
42 _ => Err(Error::new(OrtErrorCode::ORT_FAIL, "Element type not supported by tract"))
43 }
44}
45
46fn convert_datum_type_to_sys(dtype: DatumType) -> ort_sys::ONNXTensorElementDataType {
47 match dtype {
48 DatumType::Bool => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
49 DatumType::U8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
50 DatumType::U16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
51 DatumType::U32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
52 DatumType::U64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
53 DatumType::I8 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
54 DatumType::I16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
55 DatumType::I32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
56 DatumType::I64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
57 DatumType::F16 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
58 DatumType::F32 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
59 DatumType::F64 => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
60 DatumType::String => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
61 _ => ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
62 }
63}