//! ort2 0.1.2
//!
//! ONNX Runtime wrapper around the C/C++ API.
use std::{
    ffi::c_void,
    marker::PhantomData,
    ptr::{null, null_mut},
};

use ndarray::{ArrayViewD, ArrayViewMutD};
use ort2_sys::{self as ffi, ONNXTensorElementDataType, ONNXType};
use smart_default::SmartDefault;
use tracing::*;

use crate::{
    allocator::AllocatorTrait,
    api::{api, ok},
    error::Result,
    memory::MemoryInfo,
};

/// A contiguous host buffer that can back an ONNX Runtime tensor.
///
/// `TensorBuilder::borrow` passes `data()` and `size()` straight to
/// `CreateTensorWithDataAsOrtValue` as the buffer pointer and its length.
pub trait TensorTrait {
    /// Length of the buffer in bytes (not in elements).
    fn size(&self) -> usize;

    /// Type-erased pointer to the first byte of the buffer.
    fn data(&self) -> *mut c_void;
}

/// Read-only queries over an ONNX Runtime `OrtTensorTypeAndShapeInfo` handle.
///
/// Implementors only supply the raw handle. Every default method asks the C API
/// afresh on each call; nothing is cached.
pub trait TensorTypeAndShapeInfoTrait {
    /// Raw handle that the default methods query.
    fn inner(&self) -> *mut ffi::OrtTensorTypeAndShapeInfo;

    /// Element data type of the described tensor (`GetTensorElementType`).
    fn typ(&self) -> Result<ONNXTensorElementDataType> {
        let mut element_type =
            ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
        ok!(GetTensorElementType, self.inner(), &mut element_type)?;
        Ok(element_type)
    }

    /// Dimensions of the described tensor.
    ///
    /// Reads the rank with `GetDimensionsCount`, then fills a buffer of that
    /// length with `GetDimensions`.
    fn shape(&self) -> Result<Vec<i64>> {
        let mut rank = 0usize;
        ok!(GetDimensionsCount, self.inner(), &mut rank)?;

        let mut dims = vec![0i64; rank];
        ok!(GetDimensions, self.inner(), dims.as_mut_ptr(), rank)?;

        Ok(dims)
    }
}

/// Tensor type/shape view obtained from a [`TypeInfo`] via `TypeInfo::tensor_typ`.
///
/// The `'a` lifetime ties this view to the `TypeInfo` it was cast from. No
/// `Drop` impl is defined for it here. Presumably the handle is owned by the
/// parent `TypeInfo` (confirm against the ORT docs for `CastTypeInfoToTensorInfo`).
pub struct TensorTypeAndShapeInfoCasted<'a> {
    /// Handle produced by `CastTypeInfoToTensorInfo`.
    inner: *mut ffi::OrtTensorTypeAndShapeInfo,
    /// Borrow marker for the parent `TypeInfo`.
    marker: PhantomData<&'a TypeInfo>,
}

impl TensorTypeAndShapeInfoTrait for TensorTypeAndShapeInfoCasted<'_> {
    /// Hands back the handle this view was created with.
    fn inner(&self) -> *mut ffi::OrtTensorTypeAndShapeInfo {
        let Self { inner, .. } = *self;
        inner
    }
}

/// Prints the raw handle, plus the element type and shape when they can be
/// queried.
///
/// Formatting must not panic, because a panic inside `Debug` (for example while
/// logging) would take down the caller. Fields whose ONNX Runtime query fails
/// are therefore left out instead of being unwrapped.
impl std::fmt::Debug for TensorTypeAndShapeInfoCasted<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut s = f.debug_struct("TensorTypeAndShapeInfo");
        s.field("inner", &self.inner);
        if let Ok(typ) = self.typ() {
            s.field("typ", &typ);
        }
        if let Ok(shape) = self.shape() {
            s.field("shape", &shape);
        }
        s.finish()
    }
}
/// Owned ONNX Runtime `OrtTypeInfo` handle, released with `ReleaseTypeInfo`
/// when dropped.
pub struct TypeInfo {
    /// Raw handle owned by this wrapper.
    inner: *mut ffi::OrtTypeInfo,
}

/// Prints the raw handle and, when it can be queried, the described type.
///
/// For tensors the richer element-type/shape view is printed. If that cast
/// fails, the plain `ONNXType` is printed instead of panicking, because `Debug`
/// formatting must not abort the caller.
impl std::fmt::Debug for TypeInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut s = f.debug_struct("TypeInfo");
        s.field("inner", &self.inner);
        if let Ok(typ) = self.typ() {
            let tensor = match typ {
                ONNXType::ONNX_TYPE_TENSOR => self.tensor_typ().ok(),
                _ => None,
            };
            match &tensor {
                Some(info) => s.field("typ", info),
                None => s.field("typ", &typ),
            };
        }
        s.finish()
    }
}

impl TypeInfo {
    pub fn typ(&self) -> Result<ONNXType> {
        let mut typ = ONNXType::ONNX_TYPE_UNKNOWN;
        ok!(GetOnnxTypeFromTypeInfo, self.inner, &mut typ)?;
        Ok(typ)
    }

    pub fn tensor_typ(&self) -> Result<TensorTypeAndShapeInfoCasted> {
        let mut inner = null();
        ok!(CastTypeInfoToTensorInfo, self.inner, &mut inner)?;
        Ok(TensorTypeAndShapeInfoCasted {
            inner: inner as *mut _,
            marker: PhantomData,
        })
    }

    pub(crate) fn new(inner: *mut ffi::OrtTypeInfo) -> Self {
        Self { inner }
    }
}

impl Drop for TypeInfo {
    /// Returns the owned `OrtTypeInfo` to ONNX Runtime via `ReleaseTypeInfo`.
    fn drop(&mut self) {
        let handle = self.inner;
        api!(ReleaseTypeInfo, handle);
    }
}

/// Owned ONNX Runtime `OrtValue` handle, released with `ReleaseValue` when dropped.
///
/// The `'a` lifetime ties a value to what it was created from. For example,
/// `TensorBuilder::borrow` returns a value whose lifetime comes from the borrowed
/// input buffer (through lifetime elision on its signature).
pub struct Value<'a> {
    /// Raw handle owned by this wrapper.
    inner: *mut ffi::OrtValue,
    /// Borrow marker; holds no data.
    marker: PhantomData<&'a ()>,
}

/// Prints the raw handle and, when `GetTypeInfo` succeeds, the value's type info.
///
/// This impl is used by `Drop for Value` (through `trace!(?self, ..)`). A panic
/// here could therefore fire inside `drop`, so a failed type query is left out
/// of the output instead of being unwrapped.
impl std::fmt::Debug for Value<'_> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut s = f.debug_struct("Value");
        s.field("inner", &self.inner);
        if let Ok(typ) = self.typ() {
            s.field("typ", &typ);
        }
        s.finish()
    }
}

impl Value<'_> {
    /// Raw `OrtValue` pointer wrapped by this value.
    ///
    /// The pointer remains owned by `self` and is released when `self` is dropped.
    pub fn inner(&self) -> *mut ffi::OrtValue {
        self.inner
    }

    /// Starts building a tensor value; see [`TensorBuilder`].
    pub fn tensor() -> TensorBuilder {
        TensorBuilder::default()
    }

    /// Queries the value's type information (`GetTypeInfo`).
    ///
    /// Each call creates a new [`TypeInfo`], which releases its handle on drop.
    pub fn typ(&self) -> Result<TypeInfo> {
        let mut inner = null_mut();
        ok!(GetTypeInfo, self.inner, &mut inner)?;
        Ok(TypeInfo::new(inner))
    }

    /// Raw pointer to the tensor's data buffer (`GetTensorMutableData`).
    ///
    /// # Errors
    /// Propagates any error status returned by ONNX Runtime.
    pub fn data(&self) -> Result<*mut c_void> {
        let mut data = null_mut();
        ok!(GetTensorMutableData, self.inner, &mut data)?;
        Ok(data)
    }

    /// Wraps a raw `OrtValue`. The second argument only exists so callers can
    /// name the thing the value conceptually borrows. The resulting lifetime is
    /// fixed by the caller's return type, not by this argument.
    pub(crate) fn new<T>(inner: *mut ffi::OrtValue, _: &T) -> Self {
        Self {
            inner,
            marker: PhantomData,
        }
    }

    /// Borrows the tensor data as an n-dimensional `ndarray` view.
    ///
    /// The tensor's element type is **not** checked against `T`. The caller must
    /// choose the `T` that matches the tensor's element type.
    ///
    /// # Errors
    /// Returns an error if the data pointer or the shape cannot be queried. This
    /// includes values ONNX Runtime does not treat as tensors.
    pub fn view<T>(&self) -> Result<ArrayViewD<T>> {
        // Ask for the data pointer first. ONNX Runtime rejects this call for
        // non-tensor values. The type-info cast used by `tensor_shape`, by
        // contrast, would yield a null handle that the dimension queries then
        // dereference.
        let data = self.data()?;
        let shape = self.tensor_shape()?;
        // SAFETY (assumed, not checked here): ORT stores tensor data contiguously
        // in row-major order with `shape` elements, and `T` matches the element
        // type. The view's lifetime is tied to `&self`, which keeps the buffer alive.
        Ok(unsafe { ArrayViewD::from_shape_ptr(shape, data as *const T) })
    }

    /// Mutably borrows the tensor data as an n-dimensional `ndarray` view.
    ///
    /// As with [`Value::view`], the element type is **not** checked against `T`.
    ///
    /// # Errors
    /// Returns an error if the data pointer or the shape cannot be queried.
    pub fn view_mut<T>(&mut self) -> Result<ArrayViewMutD<T>> {
        // See `view` for why the data pointer is fetched before the shape.
        let data = self.data()?;
        let shape = self.tensor_shape()?;
        // SAFETY (assumed, not checked here): same layout/type assumptions as
        // `view`. `&mut self` makes this the only view created through this value.
        Ok(unsafe { ArrayViewMutD::from_shape_ptr(shape, data as *mut T) })
    }

    /// Dimensions of the tensor held by this value, converted to `usize`.
    ///
    /// Shared by `view` and `view_mut`.
    fn tensor_shape(&self) -> Result<Vec<usize>> {
        let type_info = self.typ()?;
        let tensor_info = type_info.tensor_typ()?;
        let dims = tensor_info.shape()?;
        // Presumably dims of a materialized tensor are non-negative, so the
        // cast does not wrap. TODO confirm for symbolic/unknown dims.
        let shape = dims.into_iter().map(|d| d as usize).collect();
        Ok(shape)
    }
}

impl Drop for Value<'_> {
    /// Releases the underlying `OrtValue` via `ReleaseValue`.
    fn drop(&mut self) {
        // Log before releasing: the `Debug` impl queries ONNX Runtime through
        // `self.inner` (via `GetTypeInfo`), so the handle must still be valid here.
        trace!(?self, "dropping");
        api!(ReleaseValue, self.inner);
    }
}

/// Builder for tensor [`Value`]s, normally obtained through `Value::tensor()`.
///
/// Defaults (via `SmartDefault`): an empty `shape`, element type
/// `ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED`, and `MemoryInfo::default()`.
#[derive(SmartDefault)]
pub struct TensorBuilder {
    /// Tensor dimensions handed to ONNX Runtime.
    shape: Vec<i64>,
    /// Element data type. NOTE(review): presumably must be set before building,
    /// since ONNX Runtime is expected to reject `UNDEFINED`. Confirm.
    #[default(ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED)]
    typ: ONNXTensorElementDataType,
    /// Memory description of caller-provided buffers; only read by `borrow`.
    mem_info: MemoryInfo,
}

impl TensorBuilder {
    pub fn with_shape(mut self, shape: impl AsRef<[i64]>) -> Self {
        self.shape = shape.as_ref().to_vec();
        self
    }

    pub fn with_typ(mut self, typ: ONNXTensorElementDataType) -> Self {
        self.typ = typ;
        self
    }

    pub fn with_memory_info(mut self, mem_info: MemoryInfo) -> Self {
        self.mem_info = mem_info;
        self
    }

    pub fn borrow(self, input: &impl TensorTrait) -> Result<Value> {
        let Self {
            shape,
            typ,
            mem_info,
        } = self;

        let size = input.size();
        let data = input.data();

        let mut inner = null_mut();

        ok!(
            CreateTensorWithDataAsOrtValue,
            mem_info.inner(),
            data,
            size,
            shape.as_ptr(),
            shape.len(),
            typ,
            &mut inner
        )?;

        Ok(Value::new(inner, &data))
    }

    pub fn alloc(self, alloc: &impl AllocatorTrait) -> Result<Value> {
        let Self { shape, typ, .. } = self;

        let mut inner = null_mut();

        ok!(
            CreateTensorAsOrtValue,
            alloc.inner(),
            shape.as_ptr(),
            shape.len(),
            typ,
            &mut inner
        )?;

        Ok(Value {
            inner,
            marker: PhantomData,
        })
    }
}

/// Lets a borrowed slice back a tensor directly.
impl<T> TensorTrait for &[T] {
    /// Start of the slice's storage, type-erased.
    fn data(&self) -> *mut c_void {
        self.as_ptr().cast_mut().cast()
    }

    /// Total size of the slice's elements in bytes.
    fn size(&self) -> usize {
        std::mem::size_of_val(*self)
    }
}

/// Lets an owned vector back a tensor (borrowed for the tensor's lifetime).
impl<T> TensorTrait for Vec<T> {
    /// Start of the vector's heap storage, type-erased.
    fn data(&self) -> *mut c_void {
        self.as_ptr().cast_mut().cast()
    }

    /// Total size of the vector's initialized elements in bytes (ignores spare capacity).
    fn size(&self) -> usize {
        std::mem::size_of_val(self.as_slice())
    }
}