pub mod ndarray_tensor;
pub mod ort_owned_tensor;
pub mod ort_tensor;
pub mod type_dynamic_tensor;
use std::{ffi, fmt, ptr, rc, result, string};
pub use ort_owned_tensor::{DynOrtTensor, OrtOwnedTensor};
pub use ort_tensor::OrtTensor;
pub use type_dynamic_tensor::FromArray;
pub use type_dynamic_tensor::InputTensor;
use super::{
ortsys,
sys::{self as sys, OnnxEnumInt},
tensor::ort_owned_tensor::TensorPointerHolder,
OrtError, OrtResult
};
/// Element type of an ONNX tensor.
///
/// Every discriminant is the raw `ONNXTensorElementDataType` constant from the ONNX Runtime
/// C API bindings (`sys`), so casting a variant with `as OnnxEnumInt` yields that constant.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
// The representation differs per platform, presumably to match the integer type bindgen
// emits for C enums on each target (see `OnnxEnumInt`) — TODO confirm.
#[cfg_attr(not(windows), repr(u32))]
#[cfg_attr(windows, repr(i32))]
pub enum TensorElementDataType {
/// 32-bit float (`ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT`); Rust `f32`.
Float32 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT as OnnxEnumInt,
/// Unsigned 8-bit integer; Rust `u8`.
Uint8 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8 as OnnxEnumInt,
/// Signed 8-bit integer; Rust `i8`.
Int8 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8 as OnnxEnumInt,
/// Unsigned 16-bit integer; Rust `u16`.
Uint16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16 as OnnxEnumInt,
/// Signed 16-bit integer; Rust `i16`.
Int16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16 as OnnxEnumInt,
/// Signed 32-bit integer; Rust `i32`.
Int32 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32 as OnnxEnumInt,
/// Signed 64-bit integer; Rust `i64`.
Int64 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64 as OnnxEnumInt,
/// String element; produced by `String`/`&str` (via `Utf8Data`) and extracted as `String`.
String = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING as OnnxEnumInt,
/// Boolean; Rust `bool`.
Bool = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL as OnnxEnumInt,
/// 16-bit float (`ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16`); Rust `half::f16`.
/// Only available with the `half` feature.
#[cfg(feature = "half")]
Float16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 as OnnxEnumInt,
/// 64-bit float (`ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE`); Rust `f64`.
Float64 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE as OnnxEnumInt,
/// Unsigned 32-bit integer; Rust `u32`.
Uint32 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32 as OnnxEnumInt,
/// Unsigned 64-bit integer; Rust `u64`.
Uint64 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 as OnnxEnumInt,
/// bfloat16 (`ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16`); Rust `half::bf16`.
/// Only available with the `half` feature.
#[cfg(feature = "half")]
Bfloat16 = sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 as OnnxEnumInt
}
impl From<TensorElementDataType> for sys::ONNXTensorElementDataType {
fn from(val: TensorElementDataType) -> Self {
match val {
TensorElementDataType::Float32 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
TensorElementDataType::Uint8 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8,
TensorElementDataType::Int8 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8,
TensorElementDataType::Uint16 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16,
TensorElementDataType::Int16 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16,
TensorElementDataType::Int32 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32,
TensorElementDataType::Int64 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64,
TensorElementDataType::String => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING,
TensorElementDataType::Bool => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL,
#[cfg(feature = "half")]
TensorElementDataType::Float16 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16,
TensorElementDataType::Float64 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE,
TensorElementDataType::Uint32 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32,
TensorElementDataType::Uint64 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64,
#[cfg(feature = "half")]
TensorElementDataType::Bfloat16 => sys::ONNXTensorElementDataType_ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
}
}
}
/// Rust element types that can be mapped to an ONNX [`TensorElementDataType`].
///
/// Implemented for the numeric and boolean primitives via `impl_type_trait!` and, through a
/// blanket impl, for every [`Utf8Data`] type (mapped to `TensorElementDataType::String`).
pub trait IntoTensorElementDataType {
/// The ONNX element type corresponding to `Self`.
fn tensor_element_data_type() -> TensorElementDataType;
/// The element's UTF-8 bytes for string-like (`Utf8Data`) types; `None` for the primitive impls.
fn try_utf8_bytes(&self) -> Option<&[u8]>;
}
// Implements `IntoTensorElementDataType` for a non-string element type `$elem`.
// The generated impl reports `TensorElementDataType::$tag` as the ONNX element type.
// It never exposes a UTF-8 byte view.
macro_rules! impl_type_trait {
    ($elem:ty, $tag:ident) => {
        impl IntoTensorElementDataType for $elem {
            fn tensor_element_data_type() -> TensorElementDataType {
                TensorElementDataType::$tag
            }

            fn try_utf8_bytes(&self) -> Option<&[u8]> {
                // Numeric and boolean elements carry no string payload.
                Option::None
            }
        }
    };
}
// Floating-point element types.
impl_type_trait!(f32, Float32);
impl_type_trait!(f64, Float64);
// Unsigned integer element types.
impl_type_trait!(u8, Uint8);
impl_type_trait!(u16, Uint16);
impl_type_trait!(u32, Uint32);
impl_type_trait!(u64, Uint64);
// Signed integer element types.
impl_type_trait!(i8, Int8);
impl_type_trait!(i16, Int16);
impl_type_trait!(i32, Int32);
impl_type_trait!(i64, Int64);
// Booleans.
impl_type_trait!(bool, Bool);
// Half-precision floats, only with the optional `half` feature.
#[cfg(feature = "half")]
impl_type_trait!(half::f16, Float16);
#[cfg(feature = "half")]
impl_type_trait!(half::bf16, Bfloat16);
/// Values that expose their contents as UTF-8 encoded bytes.
///
/// Implemented for `String` and `&str`. Every implementor also gets
/// [`IntoTensorElementDataType`] with element type `TensorElementDataType::String`.
pub trait Utf8Data {
/// Borrow the value's UTF-8 bytes.
fn utf8_bytes(&self) -> &[u8];
}
impl Utf8Data for String {
    /// Returns the string's underlying UTF-8 buffer without copying.
    fn utf8_bytes(&self) -> &[u8] {
        let text: &str = self.as_str();
        text.as_bytes()
    }
}
impl Utf8Data for &str {
    /// Returns the borrowed string slice's UTF-8 bytes without copying.
    fn utf8_bytes(&self) -> &[u8] {
        str::as_bytes(self)
    }
}
impl<T: Utf8Data> IntoTensorElementDataType for T {
fn tensor_element_data_type() -> TensorElementDataType {
TensorElementDataType::String
}
fn try_utf8_bytes(&self) -> Option<&[u8]> {
Some(self.utf8_bytes())
}
}
/// Rust element types that can be read back out of an ONNX Runtime tensor.
pub trait TensorDataToType: Sized + fmt::Debug + Clone {
/// The ONNX element type this Rust type is extracted from.
fn tensor_element_data_type() -> TensorElementDataType;
/// Extract the tensor's contents as an array of `Self` with the given `shape`.
///
/// * `shape` — shape of the resulting array.
/// * `tensor_element_len` — number of elements in the tensor. The `String` impl in this
///   file uses it to size its offset buffer; the primitive impls ignore it.
/// * `tensor_ptr` — shared handle to the tensor being read.
///
/// # Errors
/// In the impls in this file: an [`OrtError`] when an ONNX Runtime call fails, or (for
/// `String`) when an element is not valid UTF-8.
fn extract_data<'t, D>(shape: D, tensor_element_len: usize, tensor_ptr: rc::Rc<TensorPointerHolder>) -> OrtResult<TensorData<'t, Self, D>>
where
D: ndarray::Dimension;
}
/// Array data extracted from an ONNX Runtime tensor by [`TensorDataToType::extract_data`].
#[derive(Debug)]
pub enum TensorData<'t, T, D>
where
D: ndarray::Dimension
{
/// Zero-copy view over the tensor's own buffer (built this way for the primitive element types).
TensorPtr {
/// Shared handle to the tensor, stored next to the view. Presumably this keeps the tensor
/// alive while `array_view` is in use — confirm against `TensorPointerHolder`'s `Drop`.
ptr: rc::Rc<TensorPointerHolder>,
/// View over the buffer returned by `GetTensorMutableData` (see `extract_primitive_array`).
array_view: ndarray::ArrayView<'t, T, D>
},
/// Owned, decoded copies of a string tensor's elements.
Strings {
/// The decoded strings, arranged in the requested shape.
strings: ndarray::Array<T, D>
}
}
// Implements `TensorDataToType` for a plain element type `$prim`.
// The ONNX element type of `$prim` is `TensorElementDataType::$kind`.
// Extraction does not copy: the resulting view points straight into the tensor's buffer.
macro_rules! impl_prim_type_from_ort_trait {
    ($prim:ty, $kind:ident) => {
        impl TensorDataToType for $prim {
            fn tensor_element_data_type() -> TensorElementDataType {
                TensorElementDataType::$kind
            }

            fn extract_data<'t, D>(
                shape: D,
                _tensor_element_len: usize,
                tensor_ptr: rc::Rc<TensorPointerHolder>
            ) -> OrtResult<TensorData<'t, Self, D>>
            where
                D: ndarray::Dimension
            {
                // For fixed-size elements the extent is fully described by `shape`.
                let raw_tensor = tensor_ptr.tensor_ptr;
                let array_view: ndarray::ArrayView<'t, Self, D> = extract_primitive_array(shape, raw_tensor)?;
                // The `Rc` is stored together with the view that borrows from its tensor.
                Ok(TensorData::TensorPtr { ptr: tensor_ptr, array_view })
            }
        }
    };
}
fn extract_primitive_array<'t, D, T: TensorDataToType>(shape: D, tensor: *mut sys::OrtValue) -> OrtResult<ndarray::ArrayView<'t, T, D>>
where
D: ndarray::Dimension
{
let mut output_array_ptr: *mut T = ptr::null_mut();
let output_array_ptr_ptr: *mut *mut T = &mut output_array_ptr;
let output_array_ptr_ptr_void: *mut *mut std::ffi::c_void = output_array_ptr_ptr as *mut *mut std::ffi::c_void;
ortsys![unsafe GetTensorMutableData(tensor, output_array_ptr_ptr_void) -> OrtError::GetTensorMutableData; nonNull(output_array_ptr)];
let array_view = unsafe { ndarray::ArrayView::from_shape_ptr(shape, output_array_ptr) };
Ok(array_view)
}
// Floating-point element types.
impl_prim_type_from_ort_trait!(f32, Float32);
impl_prim_type_from_ort_trait!(f64, Float64);
// Unsigned integer element types.
impl_prim_type_from_ort_trait!(u8, Uint8);
impl_prim_type_from_ort_trait!(u16, Uint16);
impl_prim_type_from_ort_trait!(u32, Uint32);
impl_prim_type_from_ort_trait!(u64, Uint64);
// Signed integer element types.
impl_prim_type_from_ort_trait!(i8, Int8);
impl_prim_type_from_ort_trait!(i16, Int16);
impl_prim_type_from_ort_trait!(i32, Int32);
impl_prim_type_from_ort_trait!(i64, Int64);
// Booleans.
impl_prim_type_from_ort_trait!(bool, Bool);
// Half-precision floats, only with the optional `half` feature.
#[cfg(feature = "half")]
impl_prim_type_from_ort_trait!(half::f16, Float16);
#[cfg(feature = "half")]
impl_prim_type_from_ort_trait!(half::bf16, Bfloat16);
impl TensorDataToType for String {
    fn tensor_element_data_type() -> TensorElementDataType {
        TensorElementDataType::String
    }

    /// Copies every element of a string tensor out of ONNX Runtime and decodes it as UTF-8.
    ///
    /// ORT returns all strings concatenated into one byte buffer plus the start offset of
    /// each element. A trailing sentinel equal to the total length is appended so that each
    /// element spans `offsets[i]..offsets[i + 1]`.
    ///
    /// # Errors
    /// Propagates ORT failures from `GetStringTensorDataLength` / `GetStringTensorContent`.
    /// Returns `OrtError::StringFromUtf8Error` for the first element that is not valid UTF-8.
    ///
    /// # Panics
    /// Panics if the number of decoded strings does not match `shape`.
    fn extract_data<'t, D: ndarray::Dimension>(
        shape: D,
        tensor_element_len: usize,
        tensor_ptr: rc::Rc<TensorPointerHolder>
    ) -> OrtResult<TensorData<'t, Self, D>> {
        let raw = tensor_ptr.tensor_ptr;

        // Total byte length of all strings, concatenated.
        let mut total_length = 0;
        ortsys![unsafe GetStringTensorDataLength(raw, &mut total_length) -> OrtError::GetStringTensorDataLength];

        // One start offset per element, plus one slot reserved for the end sentinel.
        let mut contents = vec![0u8; total_length as _];
        let mut offsets = vec![0; tensor_element_len + 1];
        ortsys![unsafe GetStringTensorContent(raw, contents.as_mut_ptr() as *mut ffi::c_void, total_length, offsets.as_mut_ptr(), tensor_element_len as _) -> OrtError::GetStringTensorContent];

        // ORT only writes `tensor_element_len` offsets, so the reserved slot is still zero.
        debug_assert_eq!(0, offsets[tensor_element_len]);
        offsets[tensor_element_len] = total_length;

        let mut strings = Vec::with_capacity(tensor_element_len);
        for bounds in offsets.windows(2) {
            let (start, end) = (bounds[0] as usize, bounds[1] as usize);
            let decoded = String::from_utf8(contents[start..end].to_vec()).map_err(OrtError::StringFromUtf8Error)?;
            strings.push(decoded);
        }

        let array = ndarray::Array::from_shape_vec(shape, strings).expect("Shape extracted from tensor didn't match tensor contents");
        Ok(TensorData::Strings { strings: array })
    }
}