use std::{
ffi::c_void,
marker::PhantomData,
ptr::{null, null_mut},
};
use ndarray::{ArrayViewD, ArrayViewMutD};
use ort2_sys::{self as ffi, ONNXTensorElementDataType, ONNXType};
use smart_default::SmartDefault;
use tracing::*;
use crate::{
allocator::AllocatorTrait,
api::{api, ok},
error::Result,
memory::MemoryInfo,
};
/// A contiguous buffer that can back a tensor created through `TensorBuilder::borrow`.
///
/// The implementations in this file (`&[T]` and `Vec<T>`) report the address of
/// their first element and their length in bytes.
pub trait TensorTrait {
/// Raw pointer to the start of the buffer.
///
/// The pointer is typed `*mut` even though it is obtained from `&self`.
// NOTE(review): whether the runtime writes through this pointer is not visible here — confirm.
fn data(&self) -> *mut c_void;
/// Size of the buffer; the implementations in this file return
/// `len * size_of::<T>()`, i.e. a byte count.
fn size(&self) -> usize;
}
/// Queries shared by every wrapper around a raw `OrtTensorTypeAndShapeInfo` handle.
///
/// Implementors only supply [`inner`](Self::inner); the element type and the
/// dimensions are read through the provided methods.
pub trait TensorTypeAndShapeInfoTrait {
    /// Raw handle that the provided methods pass to the runtime.
    fn inner(&self) -> *mut ffi::OrtTensorTypeAndShapeInfo;

    /// Element data type obtained through `GetTensorElementType`.
    ///
    /// # Errors
    /// Propagates whatever error the `ok!` call yields.
    fn typ(&self) -> Result<ONNXTensorElementDataType> {
        // Start from UNDEFINED; the call overwrites it through the out-pointer.
        let mut element = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
        ok!(GetTensorElementType, self.inner(), &mut element)?;
        Ok(element)
    }

    /// Dimensions obtained through `GetDimensionsCount` followed by `GetDimensions`.
    ///
    /// The values are returned exactly as written by the runtime (signed 64-bit).
    ///
    /// # Errors
    /// Propagates whatever error either `ok!` call yields.
    fn shape(&self) -> Result<Vec<i64>> {
        // First ask how many dimensions there are so the buffer can be sized.
        let mut rank: usize = 0;
        ok!(GetDimensionsCount, self.inner(), &mut rank)?;

        // Then let the runtime fill a zero-initialised buffer of that length.
        let mut dims = vec![0i64; rank];
        ok!(GetDimensions, self.inner(), dims.as_mut_ptr(), rank)?;
        Ok(dims)
    }
}
/// Tensor type-and-shape handle produced by `TypeInfo::tensor_typ`.
///
/// The `'a` lifetime comes from a borrow of the `TypeInfo` it was cast from.
// NOTE(review): no `Drop` impl for this type is visible in this file, so the
// handle appears to be non-owning — confirm against the runtime documentation.
pub struct TensorTypeAndShapeInfoCasted<'a> {
/// Raw handle written by `CastTypeInfoToTensorInfo`.
inner: *mut ffi::OrtTensorTypeAndShapeInfo,
/// Zero-sized marker carrying the `&'a TypeInfo` borrow without storing it.
marker: PhantomData<&'a TypeInfo>,
}
/// Supplies the stored handle so the trait's provided `typ` and `shape`
/// methods can query it.
impl TensorTypeAndShapeInfoTrait for TensorTypeAndShapeInfoCasted<'_> {
/// Returns the stored raw pointer unchanged.
fn inner(&self) -> *mut ffi::OrtTensorTypeAndShapeInfo {
self.inner
}
}
/// Debug output listing the raw handle, element type and shape.
///
/// Formatting is kept panic-free: when the element type or the shape cannot be
/// read, the corresponding field is rendered as `<unavailable>` instead of
/// aborting the formatter.
impl std::fmt::Debug for TensorTypeAndShapeInfoCasted<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A null handle must not be handed to the runtime; the ONNX Runtime C API
        // documents that the cast producing this handle may yield null for
        // non-tensor type infos. Treat that case like a failed query.
        let (typ, shape) = if self.inner.is_null() {
            (None, None)
        } else {
            (self.typ().ok(), self.shape().ok())
        };

        let mut out = f.debug_struct("TensorTypeAndShapeInfo");
        out.field("inner", &self.inner);
        match typ {
            Some(typ) => {
                out.field("typ", &typ);
            }
            None => {
                out.field("typ", &format_args!("<unavailable>"));
            }
        }
        match shape {
            Some(shape) => {
                out.field("shape", &shape);
            }
            None => {
                out.field("shape", &format_args!("<unavailable>"));
            }
        }
        out.finish()
    }
}
/// Owning wrapper around a raw `OrtTypeInfo` pointer.
///
/// The pointer is passed to `ReleaseTypeInfo` when the wrapper is dropped.
pub struct TypeInfo {
/// Raw handle used by every method on this type.
inner: *mut ffi::OrtTypeInfo,
}
/// Debug output listing the raw handle and, when it can be read, the type.
///
/// For tensors the `typ` field shows the tensor element type and shape info;
/// for everything else it shows the plain `ONNXType`. If the ONNX type itself
/// cannot be read the field is omitted. If the tensor info cannot be obtained
/// the plain `ONNXType` is printed instead of panicking inside the formatter.
impl std::fmt::Debug for TypeInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TypeInfo");
        out.field("inner", &self.inner);

        if let Ok(kind) = self.typ() {
            // Only tensors carry element type / shape details worth expanding.
            let tensor = match kind {
                ONNXType::ONNX_TYPE_TENSOR => self.tensor_typ().ok(),
                _ => None,
            };
            match tensor {
                Some(info) => {
                    out.field("typ", &info);
                }
                // Non-tensor, or the tensor cast failed: show the bare type.
                None => {
                    out.field("typ", &kind);
                }
            }
        }

        out.finish()
    }
}
impl TypeInfo {
pub fn typ(&self) -> Result<ONNXType> {
let mut typ = ONNXType::ONNX_TYPE_UNKNOWN;
ok!(GetOnnxTypeFromTypeInfo, self.inner, &mut typ)?;
Ok(typ)
}
pub fn tensor_typ(&self) -> Result<TensorTypeAndShapeInfoCasted> {
let mut inner = null();
ok!(CastTypeInfoToTensorInfo, self.inner, &mut inner)?;
Ok(TensorTypeAndShapeInfoCasted {
inner: inner as *mut _,
marker: PhantomData,
})
}
pub(crate) fn new(inner: *mut ffi::OrtTypeInfo) -> Self {
Self { inner }
}
}
/// Releases the wrapped `OrtTypeInfo` when the wrapper goes out of scope.
impl Drop for TypeInfo {
fn drop(&mut self) {
// Hand the raw pointer back to the runtime via `ReleaseTypeInfo`.
api!(ReleaseTypeInfo, self.inner);
}
}
/// Owning wrapper around a raw `OrtValue` pointer.
///
/// The pointer is passed to `ReleaseValue` when the wrapper is dropped.
// NOTE(review): `'a` is presumably the lifetime of the buffer or allocator the
// value was created from (see `TensorBuilder::borrow` / `alloc`) — confirm.
pub struct Value<'a> {
/// Raw handle used by every method on this type.
inner: *mut ffi::OrtValue,
/// Zero-sized marker that carries the `'a` lifetime without storing a reference.
marker: PhantomData<&'a ()>,
}
/// Debug output listing the raw handle and the value's type info.
///
/// This impl is also used by the `trace!(?self, ..)` call in `Drop`, so a
/// failed `typ()` lookup is rendered as `<unavailable>` rather than panicking
/// (a panic here could otherwise surface inside a destructor).
impl std::fmt::Debug for Value<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Value");
        out.field("inner", &self.inner);
        match self.typ() {
            Ok(info) => {
                out.field("typ", &info);
            }
            Err(_) => {
                out.field("typ", &format_args!("<unavailable>"));
            }
        }
        out.finish()
    }
}
impl Value<'_> {
    /// Wraps a raw `OrtValue` pointer.
    ///
    /// The second argument is never read.
    // NOTE(review): the `&T` borrow is not connected to the returned lifetime by
    // this signature; callers rely on their own return-type elision — confirm intent.
    pub(crate) fn new<T>(inner: *mut ffi::OrtValue, _: &T) -> Self {
        Self {
            inner,
            marker: PhantomData,
        }
    }

    /// Raw handle of the wrapped `OrtValue`.
    pub fn inner(&self) -> *mut ffi::OrtValue {
        self.inner
    }

    /// Starts building a tensor value with default builder settings.
    pub fn tensor() -> TensorBuilder {
        TensorBuilder::default()
    }

    /// Type information obtained through `GetTypeInfo`.
    ///
    /// The returned [`TypeInfo`] owns the handle and releases it on drop.
    ///
    /// # Errors
    /// Propagates whatever error the `ok!` call yields.
    pub fn typ(&self) -> Result<TypeInfo> {
        let mut raw = null_mut();
        ok!(GetTypeInfo, self.inner, &mut raw)?;
        Ok(TypeInfo::new(raw))
    }

    /// Pointer to the tensor's data obtained through `GetTensorMutableData`.
    ///
    /// # Errors
    /// Propagates whatever error the `ok!` call yields.
    pub fn data(&self) -> Result<*mut c_void> {
        let mut buffer = null_mut();
        ok!(GetTensorMutableData, self.inner, &mut buffer)?;
        Ok(buffer)
    }

    /// Reads the tensor shape and converts every dimension to `usize`.
    ///
    /// The conversion is a plain `as` cast, exactly as `view`/`view_mut` have
    /// always done; a negative dimension would wrap to a huge value.
    fn tensor_dims(&self) -> Result<Vec<usize>> {
        let info = self.typ()?;
        let dims = info
            .tensor_typ()?
            .shape()?
            .into_iter()
            .map(|d| d as usize)
            .collect();
        Ok(dims)
    }

    /// Read-only n-dimensional view over the tensor data, interpreted as `T`.
    ///
    /// `T` is **not** checked against the tensor's element type; choosing a
    /// `T` that does not match the stored data is the caller's responsibility.
    ///
    /// # Errors
    /// Propagates errors from reading the type info, the shape or the data pointer.
    pub fn view<T>(&self) -> Result<ArrayViewD<T>> {
        let dims = self.tensor_dims()?;
        let elements = self.data()? as *const T;
        // SAFETY: `elements` and `dims` both come from the runtime for this same
        // value. Soundness still relies on the caller picking a `T` that matches
        // the tensor's element type, which is not verified here.
        Ok(unsafe { ArrayViewD::from_shape_ptr(dims, elements) })
    }

    /// Mutable n-dimensional view over the tensor data, interpreted as `T`.
    ///
    /// `T` is **not** checked against the tensor's element type; choosing a
    /// `T` that does not match the stored data is the caller's responsibility.
    ///
    /// # Errors
    /// Propagates errors from reading the type info, the shape or the data pointer.
    pub fn view_mut<T>(&mut self) -> Result<ArrayViewMutD<T>> {
        let dims = self.tensor_dims()?;
        let elements = self.data()? as *mut T;
        // SAFETY: same reasoning as in `view`; `&mut self` additionally prevents
        // another view from being created through this wrapper while this one lives.
        Ok(unsafe { ArrayViewMutD::from_shape_ptr(dims, elements) })
    }
}
/// Releases the wrapped `OrtValue` when the wrapper goes out of scope.
impl Drop for Value<'_> {
fn drop(&mut self) {
// Emit the trace event first: `?self` formats the value through its `Debug`
// impl, which reads `self.inner`, so it has to happen before the release below.
trace!(?self, "dropping");
// Hand the raw pointer back to the runtime via `ReleaseValue`.
api!(ReleaseValue, self.inner);
}
}
/// Builder for tensor `Value`s, obtained from `Value::tensor()`.
///
/// Configure it with the `with_*` methods, then finish with `borrow` (wrap an
/// existing buffer) or `alloc` (let an allocator provide the storage).
#[derive(SmartDefault)]
pub struct TensorBuilder {
/// Tensor dimensions passed to the runtime; empty by default.
shape: Vec<i64>,
/// Element data type; defaults to `ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED`.
#[default(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)]
typ: ONNXTensorElementDataType,
/// Memory description handed to the runtime by `borrow`; `alloc` does not read it.
mem_info: MemoryInfo,
}
impl TensorBuilder {
    /// Replaces the tensor dimensions with a copy of `shape`.
    pub fn with_shape(self, shape: impl AsRef<[i64]>) -> Self {
        Self {
            shape: shape.as_ref().to_vec(),
            ..self
        }
    }

    /// Replaces the element data type.
    pub fn with_typ(self, typ: ONNXTensorElementDataType) -> Self {
        Self { typ, ..self }
    }

    /// Replaces the memory description used by [`borrow`](Self::borrow).
    pub fn with_memory_info(self, mem_info: MemoryInfo) -> Self {
        Self { mem_info, ..self }
    }

    /// Creates a tensor value on top of `input`'s existing buffer through
    /// `CreateTensorWithDataAsOrtValue`.
    ///
    /// `input.size()` and `input.data()` are forwarded as the buffer length and
    /// address, together with the configured memory info, shape and type.
    ///
    /// # Errors
    /// Propagates whatever error the `ok!` call yields.
    pub fn borrow(self, input: &impl TensorTrait) -> Result<Value> {
        // Query the length before the address, as callers may depend on that order.
        let length = input.size();
        let buffer = input.data();

        let mut raw = null_mut();
        ok!(
            CreateTensorWithDataAsOrtValue,
            self.mem_info.inner(),
            buffer,
            length,
            self.shape.as_ptr(),
            self.shape.len(),
            self.typ,
            &mut raw
        )?;
        Ok(Value::new(raw, &buffer))
    }

    /// Creates a tensor value whose storage comes from `alloc`, through
    /// `CreateTensorAsOrtValue`.
    ///
    /// Only the configured shape and type are used; the memory info is ignored.
    ///
    /// # Errors
    /// Propagates whatever error the `ok!` call yields.
    pub fn alloc(self, alloc: &impl AllocatorTrait) -> Result<Value> {
        let mut raw = null_mut();
        ok!(
            CreateTensorAsOrtValue,
            alloc.inner(),
            self.shape.as_ptr(),
            self.shape.len(),
            self.typ,
            &mut raw
        )?;
        let value = Value {
            inner: raw,
            marker: PhantomData,
        };
        Ok(value)
    }
}
/// Lets a borrowed slice serve as the backing buffer of a tensor.
impl<T> TensorTrait for &[T]
where
    T: Sized,
{
    /// Address of the first element, exposed as an untyped mutable pointer.
    ///
    /// The constness of the slice is cast away here; nothing in this impl
    /// writes through the pointer.
    fn data(&self) -> *mut c_void {
        self.as_ptr() as *mut c_void
    }

    /// Length of the slice in bytes (`len * size_of::<T>()`).
    fn size(&self) -> usize {
        // `size_of_val` on the slice yields the byte length directly, so no
        // manual multiplication (and no lint suppression) is needed.
        std::mem::size_of_val(*self)
    }
}
/// Lets an owned vector serve as the backing buffer of a tensor.
impl<T> TensorTrait for Vec<T>
where
    T: Sized,
{
    /// Address of the first element, exposed as an untyped mutable pointer.
    ///
    /// The pointer is derived from a shared borrow; nothing in this impl
    /// writes through it.
    fn data(&self) -> *mut c_void {
        self.as_ptr() as *mut c_void
    }

    /// Length of the initialised elements in bytes (`len * size_of::<T>()`);
    /// spare capacity is not counted.
    fn size(&self) -> usize {
        // Fully qualified so it does not depend on `size_of` being in the
        // prelude, and equivalent to the manual `len * size_of::<T>()`.
        std::mem::size_of_val(self.as_slice())
    }
}