use alloc::{boxed::Box, format, sync::Arc};
use core::{
any::Any,
fmt::{self, Debug},
marker::PhantomData,
mem::transmute,
ops::{Deref, DerefMut},
ptr::{self, NonNull}
};
mod impl_map;
mod impl_sequence;
mod impl_tensor;
pub(crate) mod r#type;
pub use self::{
impl_map::{DynMap, DynMapRef, DynMapRefMut, DynMapValueType, Map, MapRef, MapRefMut, MapValueType, MapValueTypeMarker},
impl_sequence::{
DynSequence, DynSequenceRef, DynSequenceRefMut, DynSequenceValueType, Sequence, SequenceRef, SequenceRefMut, SequenceValueType, SequenceValueTypeMarker
},
impl_tensor::{
DefiniteTensorValueTypeMarker, DynTensor, DynTensorRef, DynTensorRefMut, DynTensorValueType, IntoTensorElementType, OwnedTensorArrayData,
PrimitiveTensorElementType, Shape, SymbolicDimensions, Tensor, TensorArrayData, TensorArrayDataMut, TensorArrayDataParts, TensorElementType, TensorRef,
TensorRefMut, TensorValueType, TensorValueTypeMarker, ToShape, Utf8Data
},
r#type::{Outlet, ValueType}
};
use crate::{
AsPointer,
error::{Error, ErrorCode, Result},
memory::MemoryInfo,
ortsys,
session::SharedSessionInner
};
/// Shared, reference-counted state behind every [`Value`] handle.
///
/// Fields are dropped in declaration order after [`Drop::drop`] has run, so `_backing` is released only after the
/// `OrtValue` itself has been released (when `drop` is set).
#[derive(Debug)]
pub(crate) struct ValueInner {
	/// Non-null pointer to the underlying ONNX Runtime value.
	pub(crate) ptr: NonNull<ort_sys::OrtValue>,
	/// Runtime type information recorded when the handle was created.
	pub(crate) dtype: ValueType,
	/// Memory info associated with the value, if any.
	pub(crate) memory_info: Option<MemoryInfo>,
	/// Whether `ReleaseValue` is called on `ptr` when this is dropped; also controls whether views are upgradable.
	pub(crate) drop: bool,
	/// Optional data kept alive alongside the value (e.g. the owning session, as passed by `Value::from_ptr`).
	_backing: Option<Box<dyn Any>>
}
impl ValueInner {
	/// Creates a shared value handle with no backing data attached.
	///
	/// When `drop` is true, `ReleaseValue` is called on `ptr` once the last `Arc` is dropped.
	pub fn new(ptr: NonNull<ort_sys::OrtValue>, dtype: ValueType, memory_info: Option<MemoryInfo>, drop: bool) -> Arc<Self> {
		Self::assemble(ptr, dtype, memory_info, drop, None)
	}

	/// Creates a shared value handle that also keeps `backing` alive for as long as the handle exists.
	pub fn new_backed(ptr: NonNull<ort_sys::OrtValue>, dtype: ValueType, memory_info: Option<MemoryInfo>, drop: bool, backing: Box<dyn Any>) -> Arc<Self> {
		Self::assemble(ptr, dtype, memory_info, drop, Some(backing))
	}

	/// Common constructor: logs creation of `ptr`, then wraps the assembled state in an `Arc`.
	fn assemble(ptr: NonNull<ort_sys::OrtValue>, dtype: ValueType, memory_info: Option<MemoryInfo>, drop: bool, backing: Option<Box<dyn Any>>) -> Arc<Self> {
		crate::logging::create!(Value, ptr);
		let state = Self {
			ptr,
			dtype,
			memory_info,
			drop,
			_backing: backing
		};
		Arc::new(state)
	}

	/// Returns `true` when backing data was attached via [`ValueInner::new_backed`].
	pub(crate) fn is_backed(&self) -> bool {
		matches!(self._backing, Some(_))
	}
}
/// Exposes the wrapped `OrtValue` pointer to the FFI layer.
impl AsPointer for ValueInner {
	type Sys = ort_sys::OrtValue;

	fn ptr(&self) -> *const Self::Sys {
		let raw: *mut ort_sys::OrtValue = self.ptr.as_ptr();
		raw
	}
}
/// Releases the underlying `OrtValue` when this handle owns it.
///
/// Any backing data is a field, so it is dropped after this body has run, i.e. after `ReleaseValue`.
impl Drop for ValueInner {
	fn drop(&mut self) {
		// Values created with `drop == false` are owned elsewhere; leave them untouched.
		if !self.drop {
			return;
		}
		ortsys![unsafe ReleaseValue(self.ptr_mut())];
		crate::logging::drop!(Value, self.ptr());
	}
}
/// An immutable view of a [`Value`], tied to the lifetime `'v` of the borrow it was created from.
///
/// The view holds its own [`Value`] handle; views produced by [`Value::view`] share the same inner state as the
/// original value.
#[derive(Debug)]
pub struct ValueRef<'v, Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> {
	/// Handle to the viewed value.
	inner: Value<Type>,
	/// Whether [`ValueRef::try_upgrade`] may hand out `inner`; initialised from the inner value's `drop` flag.
	pub(crate) upgradable: bool,
	/// Ties the view to the lifetime of its source borrow.
	lifetime: PhantomData<&'v ()>
}
impl<'v, Type: ValueTypeMarker + ?Sized> ValueRef<'v, Type> {
pub(crate) fn new(inner: Value<Type>) -> Self {
ValueRef {
upgradable: inner.inner.drop,
inner,
lifetime: PhantomData
}
}
#[inline]
pub fn downcast<OtherType: ValueTypeMarker + DowncastableTarget + ?Sized>(self) -> Result<ValueRef<'v, OtherType>> {
let dt = self.dtype();
if OtherType::can_downcast(dt) {
Ok(unsafe { transmute::<ValueRef<'v, Type>, ValueRef<'v, OtherType>>(self) })
} else {
Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &{dt} to &{}", format_value_type::<OtherType>())))
}
}
pub fn try_upgrade(self) -> Result<Value<Type>, Self> {
if !self.upgradable {
return Err(self);
}
Ok(self.inner)
}
pub fn into_dyn(self) -> ValueRef<'v, DynValueTypeMarker> {
unsafe { transmute(self) }
}
}
impl<Type: ValueTypeMarker + ?Sized> Deref for ValueRef<'_, Type> {
type Target = Value<Type>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
/// A mutable view of a [`Value`], tied to the lifetime `'v` of the borrow it was created from.
///
/// The view holds its own [`Value`] handle; views produced by [`Value::view_mut`] share the same inner state as the
/// original value.
#[derive(Debug)]
pub struct ValueRefMut<'v, Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> {
	/// Handle to the viewed value.
	inner: Value<Type>,
	/// Whether [`ValueRefMut::try_upgrade`] may hand out `inner`; initialised from the inner value's `drop` flag.
	pub(crate) upgradable: bool,
	/// Ties the view to the lifetime of its source borrow.
	lifetime: PhantomData<&'v ()>
}
impl<'v, Type: ValueTypeMarker + ?Sized> ValueRefMut<'v, Type> {
pub(crate) fn new(inner: Value<Type>) -> Self {
ValueRefMut {
upgradable: inner.inner.drop,
inner,
lifetime: PhantomData
}
}
#[inline]
pub fn downcast<OtherType: ValueTypeMarker + DowncastableTarget + ?Sized>(self) -> Result<ValueRefMut<'v, OtherType>> {
let dt = self.dtype();
if OtherType::can_downcast(dt) {
Ok(unsafe { transmute::<ValueRefMut<'v, Type>, ValueRefMut<'v, OtherType>>(self) })
} else {
Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &mut {dt} to &mut {}", format_value_type::<OtherType>())))
}
}
pub fn try_upgrade(self) -> Result<Value<Type>, Self> {
if !self.upgradable {
return Err(self);
}
Ok(self.inner)
}
pub fn into_dyn(self) -> ValueRefMut<'v, DynValueTypeMarker> {
unsafe { transmute(self) }
}
}
impl<Type: ValueTypeMarker + ?Sized> Deref for ValueRefMut<'_, Type> {
type Target = Value<Type>;
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl<Type: ValueTypeMarker + ?Sized> DerefMut for ValueRefMut<'_, Type> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.inner
}
}
/// A handle to an ONNX Runtime value, typed at compile time by the marker `Type`.
///
/// `Type` is only a compile-time marker (stored as `PhantomData`); the runtime type is available via
/// [`Value::dtype`]. Multiple handles may share one `ValueInner` through the `Arc`.
#[derive(Debug)]
pub struct Value<Type: ValueTypeMarker + ?Sized = DynValueTypeMarker> {
	/// Shared state, including the raw `OrtValue` pointer.
	pub(crate) inner: Arc<ValueInner>,
	/// Compile-time type marker; carries no data.
	pub(crate) _markers: PhantomData<Type>
}
/// A [`Value`] whose concrete type is only known at runtime; use [`Value::downcast`] to obtain a typed handle.
pub type DynValue = Value<DynValueTypeMarker>;
/// Compile-time marker describing what kind of value a [`Value`] holds.
pub trait ValueTypeMarker {
	/// Writes a human-readable name for this marker type (used in downcast error messages).
	#[doc(hidden)]
	fn fmt(f: &mut fmt::Formatter) -> fmt::Result;
	// NOTE(review): `private_trait!` is defined elsewhere; presumably it seals this trait against external impls.
	private_trait!();
}
/// Zero-sized adapter that renders a [`ValueTypeMarker`]'s name through [`fmt::Display`].
pub(crate) struct ValueTypeFormatter<T: ?Sized>(PhantomData<T>);

/// Returns a displayable formatter for the marker type `T`, for use in messages such as downcast errors.
#[inline]
pub(crate) fn format_value_type<T: ValueTypeMarker + ?Sized>() -> ValueTypeFormatter<T> {
	ValueTypeFormatter::<T>(PhantomData::<T>)
}

impl<T: ValueTypeMarker + ?Sized> fmt::Display for ValueTypeFormatter<T> {
	fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
		// `self` carries no data; the name comes entirely from the marker type.
		<T as ValueTypeMarker>::fmt(out)
	}
}
/// A marker type that a dynamically-typed value can be checked and reinterpreted as.
pub trait DowncastableTarget: ValueTypeMarker {
	/// Returns `true` if a value with runtime type `dtype` may be viewed through this marker.
	fn can_downcast(dtype: &ValueType) -> bool;
	// NOTE(review): `private_trait!` is defined elsewhere; presumably it seals this trait against external impls.
	private_trait!();
}
/// Marker for a [`Value`] whose concrete type is only known at runtime.
#[derive(Debug)]
pub struct DynValueTypeMarker;

impl ValueTypeMarker for DynValueTypeMarker {
	fn fmt(out: &mut fmt::Formatter) -> fmt::Result {
		out.write_str("DynValue")
	}
	private_impl!();
}

/// Every runtime type can be viewed dynamically, so this check always succeeds.
impl DowncastableTarget for DynValueTypeMarker {
	fn can_downcast(_dtype: &ValueType) -> bool {
		true
	}
	private_impl!();
}

// A dynamic value may hold any kind of value, so it satisfies every category marker.
impl TensorValueTypeMarker for DynValueTypeMarker {
	private_impl!();
}
impl SequenceValueTypeMarker for DynValueTypeMarker {
	private_impl!();
}
impl MapValueTypeMarker for DynValueTypeMarker {
	private_impl!();
}
// SAFETY: `Value` holds only an `Arc<ValueInner>` and a `PhantomData` marker. `ValueInner` contains a raw `NonNull`
// pointer and an `Option<Box<dyn Any>>`, neither of which is `Send`/`Sync` on its own, hence these manual impls.
// NOTE(review): soundness relies on the ONNX Runtime `OrtValue` being safe to use across threads, and on any backing
// data passed to `ValueInner::new_backed` being thread-safe. `Box<dyn Any>` does not require `Send + Sync`, so
// confirm that every caller only passes thread-safe backing data.
unsafe impl<Type: ValueTypeMarker + ?Sized> Send for Value<Type> {}
unsafe impl<Type: ValueTypeMarker + ?Sized> Sync for Value<Type> {}
impl<Type: ValueTypeMarker + ?Sized> Value<Type> {
	/// Returns the runtime [`ValueType`] recorded for this value when the handle was created.
	pub fn dtype(&self) -> &ValueType {
		&self.inner.dtype
	}

	/// Wraps a raw `OrtValue` pointer in an owned [`Value`].
	///
	/// The value's type information (`GetTypeInfo`) and memory info are queried immediately. When the last handle
	/// to the returned value is dropped, `ReleaseValue` is called on `ptr`. If `session` is `Some`, it is stored as
	/// backing data, so it is dropped only after the value has been released.
	///
	/// # Safety
	///
	/// - `ptr` must point to a valid `OrtValue`, and ownership of it is transferred to the returned value. It will be
	///   released when the last handle is dropped, so it must not be released (or used after that point) elsewhere.
	/// - `Type` is not checked against the value's runtime type. The caller must ensure the two agree, or use
	///   [`DynValueTypeMarker`] and [`Value::downcast`] instead.
	///
	/// # Panics
	///
	/// `GetTypeInfo` is invoked with `.expect("infallible")`, so an error reported there is expected to panic.
	#[must_use]
	pub unsafe fn from_ptr(ptr: NonNull<ort_sys::OrtValue>, session: Option<Arc<SharedSessionInner>>) -> Value<Type> {
		// SAFETY: the caller upholds this function's contract, which is `from_ptr_with_drop`'s contract with `drop == true`.
		unsafe { Self::from_ptr_with_drop(ptr, session, true) }
	}

	/// Like [`Value::from_ptr`], but the returned value never calls `ReleaseValue` on `ptr`.
	///
	/// Views of values created this way are not upgradable, because `upgradable` is derived from the `drop` flag.
	///
	/// # Safety
	///
	/// `ptr` must point to a valid `OrtValue` that outlives every handle created from it; the caller remains
	/// responsible for releasing it. `Type` is not checked against the value's runtime type.
	#[must_use]
	pub(crate) unsafe fn from_ptr_nodrop(ptr: NonNull<ort_sys::OrtValue>, session: Option<Arc<SharedSessionInner>>) -> Value<Type> {
		// SAFETY: the caller upholds this function's contract, which is `from_ptr_with_drop`'s contract with `drop == false`.
		unsafe { Self::from_ptr_with_drop(ptr, session, false) }
	}

	/// Shared implementation of `from_ptr` and `from_ptr_nodrop`.
	///
	/// `drop` controls whether `ReleaseValue` is called on `ptr` once the last handle goes away.
	///
	/// # Safety
	///
	/// `ptr` must point to a valid `OrtValue`. If `drop` is true, ownership is transferred to the returned value;
	/// otherwise the pointer must outlive every handle created from it. `Type` is not checked against the runtime type.
	#[must_use]
	unsafe fn from_ptr_with_drop(ptr: NonNull<ort_sys::OrtValue>, session: Option<Arc<SharedSessionInner>>, drop: bool) -> Value<Type> {
		let mut typeinfo_ptr = ptr::null_mut();
		ortsys![unsafe GetTypeInfo(ptr.as_ptr(), &mut typeinfo_ptr).expect("infallible"); nonNull(typeinfo_ptr)];
		let dtype = unsafe { ValueType::from_type_info(typeinfo_ptr) };
		let memory_info = unsafe { MemoryInfo::from_value(ptr) };
		// Store the session (if any) as backing data so it lives at least as long as the value.
		let inner = match session {
			Some(session) => ValueInner::new_backed(ptr, dtype, memory_info, drop, Box::new(session)),
			None => ValueInner::new(ptr, dtype, memory_info, drop)
		};
		Value { inner, _markers: PhantomData }
	}

	/// Returns an immutable view of this value that shares the same underlying data and borrows `self`.
	pub fn view(&self) -> ValueRef<'_, Type> {
		ValueRef::new(Value::clone_of(self))
	}

	/// Returns a mutable view of this value that shares the same underlying data and mutably borrows `self`.
	pub fn view_mut(&mut self) -> ValueRefMut<'_, Type> {
		ValueRefMut::new(Value::clone_of(self))
	}

	/// Erases the type marker, yielding a dynamically-typed [`DynValue`].
	pub fn into_dyn(self) -> DynValue {
		// SAFETY: `DynValueTypeMarker` accepts every runtime type; only the `PhantomData` marker changes.
		unsafe { self.transmute_type() }
	}

	/// Returns `true` if ONNX Runtime reports this value as a tensor (`IsTensor`).
	pub fn is_tensor(&self) -> bool {
		let mut result = 0;
		ortsys![unsafe IsTensor(self.ptr(), &mut result).expect("infallible")];
		result == 1
	}

	/// Reinterprets this value as having marker type `OtherType` without any runtime check.
	///
	/// # Safety
	///
	/// The caller must ensure the value's runtime type is valid for `OtherType`.
	#[inline(always)]
	pub(crate) unsafe fn transmute_type<OtherType: ValueTypeMarker + ?Sized>(self) -> Value<OtherType> {
		// SAFETY: `Value<A>` and `Value<B>` differ only in a `PhantomData` marker; validity is the caller's contract.
		unsafe { transmute::<Value<Type>, Value<OtherType>>(self) }
	}

	/// Reinterprets a reference to this value as a reference with marker type `OtherType`, without any runtime check.
	///
	/// # Safety
	///
	/// The caller must ensure the value's runtime type is valid for `OtherType`.
	#[inline(always)]
	pub(crate) unsafe fn transmute_type_ref<OtherType: ValueTypeMarker + ?Sized>(&self) -> &Value<OtherType> {
		// SAFETY: as for `transmute_type`, applied to a shared reference.
		unsafe { transmute::<&Value<Type>, &Value<OtherType>>(self) }
	}

	/// Creates another handle to the same underlying value by cloning the inner `Arc`; no data is copied.
	pub(crate) fn clone_of(value: &Self) -> Self {
		Self {
			inner: Arc::clone(&value.inner),
			_markers: PhantomData
		}
	}
}
impl Value<DynValueTypeMarker> {
#[inline]
pub fn downcast<OtherType: ValueTypeMarker + DowncastableTarget + ?Sized>(self) -> Result<Value<OtherType>> {
let dt = self.dtype();
if OtherType::can_downcast(dt) {
Ok(unsafe { transmute::<Value<DynValueTypeMarker>, Value<OtherType>>(self) })
} else {
Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast {dt} to {}", format_value_type::<OtherType>())))
}
}
#[inline]
pub fn downcast_ref<OtherType: ValueTypeMarker + DowncastableTarget + ?Sized>(&self) -> Result<ValueRef<'_, OtherType>> {
let dt = self.dtype();
if OtherType::can_downcast(dt) {
Ok(ValueRef::new(unsafe { transmute::<DynValue, Value<OtherType>>(Value::clone_of(self)) }))
} else {
Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &{dt} to &{}", format_value_type::<OtherType>())))
}
}
#[inline]
pub fn downcast_mut<OtherType: ValueTypeMarker + DowncastableTarget + ?Sized>(&mut self) -> Result<ValueRefMut<'_, OtherType>> {
let dt = self.dtype();
if OtherType::can_downcast(dt) {
Ok(ValueRefMut::new(unsafe { transmute::<DynValue, Value<OtherType>>(Value::clone_of(self)) }))
} else {
Err(Error::new_with_code(ErrorCode::InvalidArgument, format!("Cannot downcast &mut {dt} to &mut {}", format_value_type::<OtherType>())))
}
}
}
impl<Type: ValueTypeMarker + ?Sized> AsPointer for Value<Type> {
type Sys = ort_sys::OrtValue;
fn ptr(&self) -> *const Self::Sys {
self.inner.ptr()
}
}
#[cfg(test)]
mod tests {
	use super::{DynTensorValueType, Map, Sequence, Tensor, TensorRef, TensorRefMut, TensorValueType};

	/// Round-trips a tensor through dynamic and typed handles (owned, `ValueRef`, `ValueRefMut`) and checks that a
	/// write made through a mutable view is visible through the original tensor.
	#[test]
	fn test_casting_tensor() -> crate::Result<()> {
		let tensor: Tensor<i32> = Tensor::from_array((vec![5], vec![1, 2, 3, 4, 5]))?;
		// Typed -> dynamic -> typed, all owned.
		let dyn_tensor = tensor.into_dyn();
		let mut tensor: Tensor<i32> = dyn_tensor.downcast()?;
		{
			// Downcast a dynamic immutable view via `ValueRef::downcast`.
			let dyn_tensor_ref = tensor.view().into_dyn();
			let tensor_ref: TensorRef<i32> = dyn_tensor_ref.downcast()?;
			assert_eq!(tensor_ref.extract_tensor(), tensor.extract_tensor());
		}
		{
			// Same, but through `Value<DynValueTypeMarker>::downcast_ref`, reached via `Deref`.
			let dyn_tensor_ref = tensor.view().into_dyn();
			let tensor_ref: TensorRef<i32> = dyn_tensor_ref.downcast_ref()?;
			assert_eq!(tensor_ref.extract_tensor(), tensor.extract_tensor());
		}
		{
			// Write through a mutable view; this scope ends the mutable borrow of `tensor`.
			let mut dyn_tensor_ref = tensor.view_mut().into_dyn();
			let mut tensor_ref: TensorRefMut<i32> = dyn_tensor_ref.downcast_mut()?;
			let (_, data) = tensor_ref.extract_tensor_mut();
			data[2] = 42;
		}
		{
			// The write above must be visible through the original handle.
			let (_, data) = tensor.extract_tensor_mut();
			assert_eq!(data[2], 42);
		}
		{
			// Chain of downcasts and upcasts between dynamic, dynamic-tensor and typed-tensor markers.
			// NOTE(review): `upcast` is defined elsewhere; presumably it converts back to a dynamic tensor type.
			let tensor = tensor
				.into_dyn()
				.downcast::<DynTensorValueType>()?
				.into_dyn()
				.downcast::<TensorValueType<i32>>()?
				.upcast()
				.into_dyn();
			let tensor = tensor.view();
			let tensor = tensor.downcast_ref::<TensorValueType<i32>>()?;
			let (_, data) = tensor.extract_tensor();
			assert_eq!(data, [1, 2, 42, 4, 5]);
		}
		Ok(())
	}

	/// Builds a sequence containing a single string -> f32 map and checks that its key/value pairs round-trip.
	#[test]
	fn test_sequence_map() -> crate::Result<()> {
		let map_contents = [("meaning".to_owned(), 42.0), ("pi".to_owned(), core::f32::consts::PI)];
		let value = Sequence::new([Map::<String, f32>::new(map_contents)?])?;
		for map in value.extract_sequence() {
			let map = map.extract_key_values().into_iter().collect::<std::collections::HashMap<_, _>>();
			assert_eq!(map["meaning"], 42.0);
			assert_eq!(map["pi"], core::f32::consts::PI);
		}
		Ok(())
	}
}