use alloc::{
boxed::Box,
string::{String, ToString}
};
use core::{
ffi::CStr,
fmt,
ptr::{self, NonNull}
};
use smallvec::{SmallVec, smallvec};
use crate::{
Result, ortsys,
util::{self, run_on_drop, with_cstr, with_cstr_ptr_array},
value::{Shape, SymbolicDimensions, TensorElementType}
};
/// The type of a value, as reported by an ONNX Runtime type info object.
///
/// Values of this enum are produced from an `OrtTypeInfo` by `ValueType::from_type_info` and are rendered by the
/// `Display` impl as `Tensor<..>(..)`, `Map<.., ..>`, `Sequence<..>` or `Option<..>`.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum ValueType {
/// A tensor. `from_type_info` also maps sparse tensor type info to this variant.
Tensor {
/// Element type of the tensor.
ty: TensorElementType,
/// Dimensions of the tensor. The `Display` impl renders a dimension of `-1` as its symbol, or `dyn` if the
/// symbol is empty.
shape: Shape,
/// Symbolic name of each dimension; looked up by dimension index when formatting.
dimension_symbols: SymbolicDimensions
},
/// A sequence whose elements all have the boxed type.
Sequence(Box<ValueType>),
/// A map from `key` to `value`; both are described by a tensor element type.
Map {
/// Element type of the map's keys.
key: TensorElementType,
/// Element type of the map's values.
value: TensorElementType
},
/// An optional wrapper around the boxed type.
Optional(Box<ValueType>)
}
impl ValueType {
/// Converts an `OrtTypeInfo` into a [`ValueType`] and releases the type info.
///
/// Tensor and sparse tensor infos become [`ValueType::Tensor`] and maps become [`ValueType::Map`]. Sequences and
/// optionals wrap the recursively converted type of their element / contained value, so nested containers
/// (e.g. a sequence of sequences) are handled by the same code path.
///
/// # Safety
/// `typeinfo_ptr` must point to a valid `OrtTypeInfo` owned by the caller. It is passed to `ReleaseTypeInfo` by a
/// `run_on_drop` guard created at the top of this function, so it must not be used or released again afterwards.
///
/// # Panics
/// Panics through the `.expect("infallible")` calls if ONNX Runtime reports an error, and through
/// `unreachable!()` if the ONNX type is not a tensor, sparse tensor, sequence, map or optional.
pub(crate) unsafe fn from_type_info(typeinfo_ptr: NonNull<ort_sys::OrtTypeInfo>) -> Self {
    let _guard = util::run_on_drop(|| {
        ortsys![unsafe ReleaseTypeInfo(typeinfo_ptr.as_ptr())];
    });

    let mut ty: ort_sys::ONNXType = ort_sys::ONNXType::ONNX_TYPE_UNKNOWN;
    ortsys![unsafe GetOnnxTypeFromTypeInfo(typeinfo_ptr.as_ptr(), &mut ty).expect("infallible")];
    match ty {
        ort_sys::ONNXType::ONNX_TYPE_TENSOR | ort_sys::ONNXType::ONNX_TYPE_SPARSETENSOR => {
            let mut info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = ptr::null();
            ortsys![unsafe CastTypeInfoToTensorInfo(typeinfo_ptr.as_ptr(), &mut info_ptr).expect("infallible"); nonNull(info_ptr)];
            unsafe { extract_data_type_from_tensor_info(info_ptr) }
        }
        ort_sys::ONNXType::ONNX_TYPE_SEQUENCE => {
            let mut info_ptr: *const ort_sys::OrtSequenceTypeInfo = ptr::null();
            ortsys![unsafe CastTypeInfoToSequenceTypeInfo(typeinfo_ptr.as_ptr(), &mut info_ptr).expect("infallible"); nonNull(info_ptr)];

            let mut element_type_info: *mut ort_sys::OrtTypeInfo = ptr::null_mut();
            ortsys![unsafe GetSequenceElementType(info_ptr.as_ptr(), &mut element_type_info).expect("infallible"); nonNull(element_type_info)];
            // The recursive call releases `element_type_info` itself, so no separate guard is needed here, and
            // every element kind this function understands is supported instead of only tensors and maps.
            ValueType::Sequence(Box::new(unsafe { ValueType::from_type_info(element_type_info) }))
        }
        ort_sys::ONNXType::ONNX_TYPE_MAP => {
            let mut info_ptr: *const ort_sys::OrtMapTypeInfo = ptr::null();
            ortsys![unsafe CastTypeInfoToMapTypeInfo(typeinfo_ptr.as_ptr(), &mut info_ptr).expect("infallible"); nonNull(info_ptr)];
            unsafe { extract_data_type_from_map_info(info_ptr) }
        }
        ort_sys::ONNXType::ONNX_TYPE_OPTIONAL => {
            let mut info_ptr: *const ort_sys::OrtOptionalTypeInfo = ptr::null();
            ortsys![unsafe CastTypeInfoToOptionalTypeInfo(typeinfo_ptr.as_ptr(), &mut info_ptr).expect("infallible"); nonNull(info_ptr)];

            let mut contained_type: *mut ort_sys::OrtTypeInfo = ptr::null_mut();
            ortsys![unsafe GetOptionalContainedTypeInfo(info_ptr.as_ptr(), &mut contained_type).expect("infallible"); nonNull(contained_type)];
            // As with sequences, the recursive call releases `contained_type`.
            ValueType::Optional(Box::new(unsafe { ValueType::from_type_info(contained_type) }))
        }
        _ => unreachable!()
    }
}
/// Builds a new `OrtTensorTypeAndShapeInfo` describing this type, or returns `None` if it is not a tensor.
///
/// The info is created with `CreateTensorTypeAndShapeInfo` and populated with the element type, dimensions and
/// symbolic dimension names. Nothing in this function releases the returned pointer; that is left to the caller.
///
/// # Panics
/// Panics if any of the ONNX Runtime calls reports an error, or with `invalid dimension symbols` if
/// `with_cstr_ptr_array` fails.
pub(crate) fn to_tensor_type_info(&self) -> Option<*mut ort_sys::OrtTensorTypeAndShapeInfo> {
    if let Self::Tensor { ty, shape, dimension_symbols } = self {
        let mut tensor_info = ptr::null_mut();
        ortsys![unsafe CreateTensorTypeAndShapeInfo(&mut tensor_info).expect("infallible")];
        ortsys![unsafe SetTensorElementType(tensor_info, (*ty).into()).expect("infallible")];
        ortsys![unsafe SetDimensions(tensor_info, shape.as_ptr(), shape.len()).expect("infallible")];

        // The symbol strings are only exposed as C string pointers for the duration of the callback.
        let symbols_set = with_cstr_ptr_array(dimension_symbols, &|symbol_ptrs| {
            ortsys![unsafe SetSymbolicDimensions(tensor_info, symbol_ptrs.as_ptr().cast_mut(), dimension_symbols.len()).expect("infallible")];
            Ok(())
        });
        symbols_set.expect("invalid dimension symbols");

        Some(tensor_info)
    } else {
        None
    }
}
/// Builds a new `OrtTypeInfo` describing this type (only available with the `api-22` feature).
///
/// Nothing in this function releases the returned pointer; that is left to the caller. Intermediate infos
/// created along the way (the tensor info, or the element / contained type info) are released by `run_on_drop`
/// guards inside the respective match arm.
///
/// # Errors
/// Returns the error from `CreateTensorTypeInfo`, `CreateSequenceTypeInfo` or `CreateOptionalTypeInfo`, or from
/// the recursive conversion of a sequence element / optional contained type.
///
/// # Panics
/// `Map` is not implemented and panics via `todo!()`; this also applies to a `Map` nested inside a `Sequence`
/// or `Optional`.
#[cfg(feature = "api-22")]
pub(crate) fn to_type_info(&self) -> Result<*mut ort_sys::OrtTypeInfo> {
let mut info_ptr: *mut ort_sys::OrtTypeInfo = ptr::null_mut();
match self {
Self::Tensor { .. } => {
// `self` is a tensor in this arm, so `to_tensor_type_info` returns `Some`.
let tensor_type_info = self.to_tensor_type_info().expect("infallible");
let _guard = util::run_on_drop(|| ortsys![unsafe ReleaseTensorTypeAndShapeInfo(tensor_type_info)]);
ortsys![@editor: unsafe CreateTensorTypeInfo(tensor_type_info, &mut info_ptr)?];
}
Self::Map { .. } => {
// Not implemented yet: reaching this arm panics.
todo!();
}
Self::Sequence(ty) => {
// Build the element type first, then wrap it; the element info is released by the guard.
let el_type = ty.to_type_info()?;
let _guard = util::run_on_drop(|| ortsys![unsafe ReleaseTypeInfo(el_type)]);
ortsys![@editor: unsafe CreateSequenceTypeInfo(el_type, &mut info_ptr)?];
}
Self::Optional(ty) => {
// Same shape as the sequence arm, for the contained type.
let ty = ty.to_type_info()?;
let _guard = util::run_on_drop(|| ortsys![unsafe ReleaseTypeInfo(ty)]);
ortsys![@editor: unsafe CreateOptionalTypeInfo(ty, &mut info_ptr)?];
}
}
Ok(info_ptr)
}
#[must_use]
pub fn tensor_shape(&self) -> Option<&Shape> {
match self {
ValueType::Tensor { shape, .. } => Some(shape),
_ => None
}
}
#[must_use]
pub fn tensor_type(&self) -> Option<TensorElementType> {
match self {
ValueType::Tensor { ty, .. } => Some(*ty),
_ => None
}
}
#[inline]
#[must_use]
pub fn is_tensor(&self) -> bool {
matches!(self, ValueType::Tensor { .. })
}
#[inline]
#[must_use]
pub fn is_sequence(&self) -> bool {
matches!(self, ValueType::Sequence { .. })
}
#[inline]
#[must_use]
pub fn is_map(&self) -> bool {
matches!(self, ValueType::Map { .. })
}
}
impl fmt::Display for ValueType {
    /// Formats the type in a compact, human-readable form.
    ///
    /// - Tensors: `Tensor<{ty}>(d0, d1, ...)`. A dimension of `-1` is printed as its symbol from
    ///   `dimension_symbols`, or as `dyn` when the symbol is empty or missing.
    /// - Maps: `Map<{key}, {value}>`.
    /// - Sequences: `Sequence<{inner}>`.
    /// - Optionals: `Option<{inner}>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ValueType::Tensor { ty, shape, dimension_symbols } => {
                write!(f, "Tensor<{ty}>(")?;
                for (i, dimension) in shape.iter().copied().enumerate() {
                    // Separator goes before every dimension except the first.
                    if i > 0 {
                        f.write_str(", ")?;
                    }
                    if dimension == -1 {
                        // `shape` and `dimension_symbols` are independent public fields, so their lengths may
                        // disagree. Formatting must not panic on that; a missing symbol is treated like an
                        // empty one.
                        match (i < dimension_symbols.len()).then(|| &dimension_symbols[i]) {
                            Some(sym) if !sym.is_empty() => f.write_str(sym)?,
                            _ => f.write_str("dyn")?
                        }
                    } else {
                        dimension.fmt(f)?;
                    }
                }
                f.write_str(")")
            }
            ValueType::Map { key, value } => write!(f, "Map<{key}, {value}>"),
            ValueType::Sequence(inner) => write!(f, "Sequence<{inner}>"),
            ValueType::Optional(inner) => write!(f, "Option<{inner}>")
        }
    }
}
/// A named value together with its [`ValueType`] and, when available, the underlying `OrtValueInfo` handle.
#[derive(Debug)]
pub struct Outlet {
// Name of the value.
name: String,
// Type of the value.
dtype: ValueType,
// Raw `OrtValueInfo` handle, if one exists for this outlet. `make_value_info` always yields `None` when the
// `api-22` feature is disabled.
value_info: Option<NonNull<ort_sys::OrtValueInfo>>,
// Whether the `Drop` impl should release `value_info`.
drop: bool
}
impl Outlet {
/// Creates an `Outlet` with the given name and type.
///
/// An `OrtValueInfo` is created via `make_value_info` when possible. If one was created, the new `Outlet` is
/// flagged to release it on drop.
pub fn new<S: Into<String>>(name: S, dtype: ValueType) -> Self {
    let name: String = name.into();
    let value_info = Self::make_value_info(&name, &dtype);
    // Only flag the handle for release if there actually is one.
    let drop = value_info.is_some();
    Self { name, dtype, value_info, drop }
}
/// Wraps an existing `OrtValueInfo`, reading its name and type from ONNX Runtime.
///
/// A null or non-UTF-8 name is replaced with an empty string. `drop` controls whether the resulting `Outlet`
/// releases `raw` when it is dropped.
///
/// # Safety
/// `raw` must point to a valid `OrtValueInfo`. The type info obtained from it is handed to
/// `ValueType::from_type_info`, which releases it.
///
/// # Errors
/// Returns the error reported by `GetValueInfoName` or `GetValueInfoTypeInfo`.
#[cfg(feature = "api-22")]
pub(crate) unsafe fn from_raw(raw: NonNull<ort_sys::OrtValueInfo>, drop: bool) -> Result<Self> {
    let mut name_ptr = ptr::null();
    ortsys![unsafe GetValueInfoName(raw.as_ptr(), &mut name_ptr)?];
    let name = if name_ptr.is_null() {
        String::new()
    } else {
        match unsafe { CStr::from_ptr(name_ptr) }.to_str() {
            Ok(valid) => valid.to_string(),
            Err(_) => String::new()
        }
    };

    let mut type_info = ptr::null();
    ortsys![unsafe GetValueInfoTypeInfo(raw.as_ptr(), &mut type_info)?; nonNull(type_info)];
    let dtype = unsafe { ValueType::from_type_info(type_info) };

    let value_info = Some(raw);
    Ok(Self { name, dtype, value_info, drop })
}
/// Returns the name of this outlet.
#[inline]
pub fn name(&self) -> &str {
    self.name.as_str()
}
/// Returns the type of this outlet.
#[inline]
pub fn dtype(&self) -> &ValueType {
&self.dtype
}
/// Creates an `OrtValueInfo` for the given name and type (`api-22` variant).
///
/// Returns `None` if the type cannot be converted to an `OrtTypeInfo`, if `with_cstr` fails for `name`, or if
/// `CreateValueInfo` fails. The intermediate type info is released by a `run_on_drop` guard.
///
/// # Panics
/// `ValueType::to_type_info` panics via `todo!()` for `Map` types, and so does this function.
#[cfg(feature = "api-22")]
pub(crate) fn make_value_info(name: &str, dtype: &ValueType) -> Option<NonNull<ort_sys::OrtValueInfo>> {
    let type_info = dtype.to_type_info().ok()?;
    let _guard = run_on_drop(|| ortsys![unsafe ReleaseTypeInfo(type_info)]);

    let created = with_cstr(name.as_bytes(), &|c_name| {
        let mut value_info: *mut ort_sys::OrtValueInfo = ptr::null_mut();
        ortsys![@editor: unsafe CreateValueInfo(c_name.as_ptr(), type_info, &mut value_info)?; nonNull(value_info)];
        Ok(value_info)
    });
    created.ok()
}
/// Variant of `make_value_info` used when the `api-22` feature is disabled: no `OrtValueInfo` is created and
/// the result is always `None`.
#[cfg(not(feature = "api-22"))]
pub(crate) fn make_value_info(_name: &str, _dtype: &ValueType) -> Option<NonNull<ort_sys::OrtValueInfo>> {
None
}
/// Consumes the `Outlet` and returns its `OrtValueInfo` handle, if any, without releasing it.
///
/// The drop flag is cleared first so that the `Drop` impl, which still runs when `self` goes out of scope,
/// leaves the handle alone.
#[inline]
pub(crate) fn into_value_info_ptr(mut self) -> Option<NonNull<ort_sys::OrtValueInfo>> {
    self.drop = false;
    // `Option<NonNull<_>>` is `Copy`, so this copies the handle out of `self`.
    self.value_info
}
}
impl Drop for Outlet {
/// Releases the underlying `OrtValueInfo` when this `Outlet` is flagged to do so.
///
/// The release is only compiled in with the `api-22` feature; without it, this does nothing.
fn drop(&mut self) {
#[cfg(feature = "api-22")]
if self.drop {
// The constructors in this file only set `drop` when `value_info` is `Some`, so the `expect` guards an
// invariant rather than an expected failure.
ortsys![unsafe ReleaseValueInfo(self.value_info.expect("OrtValueInfo should not be null").as_ptr())];
}
}
}
/// Reads element type, dimensions and symbolic dimension names from a tensor type info and returns them as a
/// [`ValueType::Tensor`].
///
/// Symbolic names that are null or not valid UTF-8 become empty strings. This function does not release
/// `info_ptr`.
///
/// # Safety
/// `info_ptr` must point to a valid `OrtTensorTypeAndShapeInfo` for the duration of the call.
///
/// # Panics
/// Panics if any of the ONNX Runtime calls reports an error, or if the element type is reported as undefined.
pub(crate) unsafe fn extract_data_type_from_tensor_info(info_ptr: NonNull<ort_sys::OrtTensorTypeAndShapeInfo>) -> ValueType {
    let mut type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
    ortsys![unsafe GetTensorElementType(info_ptr.as_ptr(), &mut type_sys).expect("infallible")];
    assert_ne!(type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);

    let mut num_dims = 0;
    ortsys![unsafe GetDimensionsCount(info_ptr.as_ptr(), &mut num_dims).expect("infallible")];

    let mut node_dims = Shape::empty(num_dims);
    ortsys![unsafe GetDimensions(info_ptr.as_ptr(), node_dims.as_mut_ptr(), num_dims).expect("infallible")];

    let mut symbolic_dims: SmallVec<[_; 4]> = smallvec![ptr::null(); num_dims];
    ortsys![unsafe GetSymbolicDimensions(info_ptr.as_ptr(), symbolic_dims.as_mut_ptr(), num_dims).expect("infallible")];

    let dimension_symbols = symbolic_dims
        .into_iter()
        .map(|c| {
            // Entries start out null; `CStr::from_ptr` on a null pointer is undefined behaviour, so any entry
            // the call above left untouched is mapped to an empty symbol instead.
            if c.is_null() {
                String::new()
            } else {
                unsafe { CStr::from_ptr(c) }.to_str().map_or_else(|_| String::new(), str::to_string)
            }
        })
        .collect();

    ValueType::Tensor {
        ty: type_sys.into(),
        shape: node_dims,
        dimension_symbols
    }
}
/// Reads the key and value element types from a map type info and returns them as a [`ValueType::Map`].
///
/// The value type is obtained as a separate `OrtTypeInfo`, which is released by a `run_on_drop` guard before
/// returning. This function does not release `info_ptr`.
///
/// # Safety
/// `info_ptr` must point to a valid `OrtMapTypeInfo` for the duration of the call.
///
/// # Panics
/// Panics if any of the ONNX Runtime calls reports an error, or if the key or value element type is reported as
/// undefined. The `nonNull` checks are expected to reject a missing value type info or a value type that is not a
/// tensor (TODO confirm the exact failure mode of the `ortsys!` macro's `nonNull` form).
unsafe fn extract_data_type_from_map_info(info_ptr: NonNull<ort_sys::OrtMapTypeInfo>) -> ValueType {
    let mut key_type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
    ortsys![unsafe GetMapKeyType(info_ptr.as_ptr(), &mut key_type_sys).expect("infallible")];
    assert_ne!(key_type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);

    // Same null-checking pattern as the other type info queries in this file: never hand an unchecked
    // out-pointer to a later ONNX Runtime call.
    let mut value_type_info: *mut ort_sys::OrtTypeInfo = ptr::null_mut();
    ortsys![unsafe GetMapValueType(info_ptr.as_ptr(), &mut value_type_info).expect("infallible"); nonNull(value_type_info)];
    let _guard = util::run_on_drop(|| {
        ortsys![unsafe ReleaseTypeInfo(value_type_info.as_ptr())];
    });

    let mut value_info_ptr: *const ort_sys::OrtTensorTypeAndShapeInfo = ptr::null();
    ortsys![unsafe CastTypeInfoToTensorInfo(value_type_info.as_ptr(), &mut value_info_ptr).expect("infallible"); nonNull(value_info_ptr)];

    let mut value_type_sys = ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
    ortsys![unsafe GetTensorElementType(value_info_ptr.as_ptr(), &mut value_type_sys).expect("infallible")];
    assert_ne!(value_type_sys, ort_sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);

    ValueType::Map {
        key: key_type_sys.into(),
        value: value_type_sys.into()
    }
}
#[cfg(test)]
mod tests {
    use core::ptr::NonNull;

    use super::ValueType;
    use crate::{
        ortsys,
        value::{Shape, SymbolicDimensions, TensorElementType}
    };

    /// Converts a tensor type to an `OrtTensorTypeAndShapeInfo` and back, and checks that the round trip
    /// reproduces the original type, including the symbolic name of the dynamic first dimension.
    #[test]
    fn test_to_from_tensor_info() -> crate::Result<()> {
        let original = ValueType::Tensor {
            ty: TensorElementType::Float32,
            shape: Shape::new([-1, 32, 4, 32]),
            dimension_symbols: SymbolicDimensions::new(["d1".to_string(), String::default(), String::default(), String::default()])
        };

        let raw_info = original.to_tensor_type_info().expect("");
        let info = NonNull::new(raw_info).expect("");
        let round_tripped = unsafe { super::extract_data_type_from_tensor_info(info) };
        // The info was created by `to_tensor_type_info`, so the test is responsible for releasing it.
        ortsys![unsafe ReleaseTensorTypeAndShapeInfo(info.as_ptr())];

        assert_eq!(original, round_tripped);
        Ok(())
    }
}