use std::convert::TryFrom;
use std::ffi::CString;
use std::fmt;
use std::ptr::NonNull;
use std::sync::atomic::AtomicI32;
use tvm_macros::Object;
use tvm_sys::ffi::{self, TVMObjectFree, TVMObjectRetain, TVMObjectTypeKey2Index};
use tvm_sys::{ArgValue, RetValue};
use crate::errors::Error;
/// Function-pointer type stored in the `fdeleter` field of [`Object`];
/// `Object::base::<T>` fills it with the type-erased `delete::<T>`.
type Deleter = unsafe extern "C" fn(object: *mut Object) -> ();
/// Common header shared by TVM runtime objects.
///
/// `#[repr(C)]` fixes the field order. Presumably this mirrors the C++
/// `tvm::runtime::Object` header so pointers can cross the FFI boundary.
/// NOTE(review): confirm against the C++ definition.
#[derive(Debug, Object)]
#[ref_name = "ObjectRef"]
#[type_key = "runtime.Object"]
#[repr(C)]
pub struct Object {
    /// Runtime type index, resolved from the type key by `Object::get_type_index`
    /// (0 for `"runtime.Object"`).
    pub(self) type_index: u32,
    /// Reference count, initialised to 0 by `Object::new` and read by `count()`.
    /// `inc_ref` and `dec_ref` delegate changes to `TVMObjectRetain` and `TVMObjectFree`.
    pub(self) ref_count: AtomicI32,
    /// Destructor for the concrete type; `Object::base::<T>` installs `delete::<T>`.
    pub(self) fdeleter: Deleter,
}
/// Type-erased deleter: reinterprets the `Object` header pointer as a pointer to
/// the concrete `T` and forwards it to `T::typed_delete`.
unsafe extern "C" fn delete<T: IsObject>(object: *mut Object) {
    T::typed_delete(object.cast::<T>())
}
/// Asks the TVM runtime whether the type with index `child_type_index` derives
/// from the type with index `parent_type_index`.
///
/// # Panics
/// Failure handling is delegated to `crate::check_call!`, which is defined elsewhere.
/// NOTE(review): presumably it panics when the FFI call returns non-zero.
fn derived_from(child_type_index: u32, parent_type_index: u32) -> bool {
    let mut is_derived = 0;
    crate::check_call!(ffi::TVMObjectDerivedFrom(
        child_type_index,
        parent_type_index,
        &mut is_derived
    ));
    is_derived != 0
}
impl Object {
fn new(type_index: u32, deleter: Deleter) -> Object {
Object {
type_index,
ref_count: AtomicI32::new(0),
fdeleter: deleter,
}
}
fn get_type_index<T: IsObject>() -> u32 {
let type_key = T::TYPE_KEY;
let cstring = CString::new(type_key).expect("type key must not contain null characters");
if type_key == "runtime.Object" {
return 0;
} else {
let mut index = 0;
unsafe {
if TVMObjectTypeKey2Index(cstring.as_ptr(), &mut index) != 0 {
panic!(crate::get_last_error())
}
}
return index;
}
}
pub fn count(&self) -> i32 {
self.ref_count.load(std::sync::atomic::Ordering::Relaxed)
}
pub fn base<T: IsObject>() -> Object {
let index = Object::get_type_index::<T>();
Object::new(index, delete::<T>)
}
pub(self) fn inc_ref(&self) {
let raw_ptr = self as *const Object as *mut Object as *mut std::ffi::c_void;
unsafe {
assert_eq!(TVMObjectRetain(raw_ptr), 0);
}
}
pub(self) fn dec_ref(&self) {
let raw_ptr = self as *const Object as *mut Object as *mut std::ffi::c_void;
unsafe {
assert_eq!(TVMObjectFree(raw_ptr), 0);
}
}
}
/// Implemented by Rust types that are TVM runtime objects.
///
/// # Safety
///
/// Pointers to implementors are reinterpreted as `*mut Object` and back: `delete`
/// casts `*mut Object` to `*mut T`, and `ObjectPtr::cast` changes the pointee type.
/// An implementor therefore presumably must start with an [`Object`] header in a
/// `#[repr(C)]` layout.
/// NOTE(review): confirm this is the full contract expected by `tvm_macros::Object`.
pub unsafe trait IsObject: AsRef<Object> + std::fmt::Debug {
    /// TVM type key used to look up this type's runtime type index.
    const TYPE_KEY: &'static str;

    /// Destroys an object of this type.
    ///
    /// The default rebuilds a `Box<Self>` from `object` and drops it. It is only
    /// valid for objects allocated with `Box`, which is how `ObjectPtr::new` allocates.
    unsafe extern "C" fn typed_delete(object: *mut Self) {
        let object = Box::from_raw(object);
        drop(object)
    }
}
/// Owning, reference-counted pointer to a TVM object of type `T`.
///
/// Cloning retains the object and dropping releases it, as implemented by the
/// `Clone` and `Drop` impls for `ObjectPtr`.
#[repr(C)]
pub struct ObjectPtr<T: IsObject> {
    /// Non-null pointer to the pointee. Being the only field of a `#[repr(C)]`
    /// struct, it keeps `ObjectPtr<T>` pointer-sized.
    pub ptr: NonNull<T>,
}
impl ObjectPtr<Object> {
    /// Wraps a raw object pointer, returning `None` when it is null.
    ///
    /// The reference count is *not* incremented. The returned `ObjectPtr` adopts one
    /// existing reference, which its `Drop` later releases.
    pub fn from_raw(object_ptr: *mut Object) -> Option<ObjectPtr<Object>> {
        let ptr = NonNull::new(object_ptr)?;
        // NOTE(review): this dereferences a caller-supplied raw pointer from a safe
        // fn; it assumes the pointer refers to a live object header.
        debug_assert!(unsafe { ptr.as_ref().count() } >= 0);
        Some(ObjectPtr { ptr })
    }
}
impl<T: IsObject> Clone for ObjectPtr<T> {
fn clone(&self) -> Self {
unsafe { self.ptr.as_ref().as_ref().inc_ref() }
ObjectPtr { ptr: self.ptr }
}
}
impl<T: IsObject> Drop for ObjectPtr<T> {
fn drop(&mut self) {
unsafe { self.ptr.as_ref().as_ref().dec_ref() }
}
}
impl<T: IsObject> ObjectPtr<T> {
pub fn leak<'a>(object_ptr: ObjectPtr<T>) -> &'a mut T
where
T: 'a,
{
unsafe { &mut *std::mem::ManuallyDrop::new(object_ptr).ptr.as_ptr() }
}
pub fn new(object: T) -> ObjectPtr<T> {
object.as_ref().inc_ref();
let object_ptr = Box::new(object);
let object_ptr = Box::leak(object_ptr);
let ptr = NonNull::from(object_ptr);
ObjectPtr { ptr }
}
pub fn count(&self) -> i32 {
self.as_ref()
.ref_count
.load(std::sync::atomic::Ordering::Relaxed)
}
unsafe fn cast<U: IsObject>(self) -> ObjectPtr<U> {
let ptr = self.ptr.cast();
std::mem::forget(self);
ObjectPtr { ptr }
}
pub fn upcast<U>(self) -> ObjectPtr<U>
where
U: IsObject,
T: AsRef<U>,
{
unsafe { self.cast() }
}
pub fn downcast<U>(self) -> Result<ObjectPtr<U>, Error>
where
U: IsObject + AsRef<T>,
{
let child_index = Object::get_type_index::<U>();
let object_index = self.as_ref().type_index;
let is_derived = if child_index == object_index {
true
} else {
derived_from(object_index, child_index)
};
if is_derived {
Ok(unsafe { self.cast() })
} else {
Err(Error::downcast("TODOget_type_key".into(), U::TYPE_KEY))
}
}
pub unsafe fn into_raw(self) -> *mut T {
self.ptr.as_ptr()
}
}
impl<T: IsObject> std::ops::Deref for ObjectPtr<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
unsafe { self.ptr.as_ref() }
}
}
impl<T: IsObject> fmt::Debug for ObjectPtr<T> {
    /// Debug-prints the pointee, so the handle is transparent in debug output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let pointee: &T = &**self;
        write!(f, "{:?}", pointee)
    }
}
/// Moves an object's reference into an `ObjectHandle` return value.
///
/// The reference count is unchanged. The `ObjectPtr` is leaked rather than dropped,
/// so the reference it held now travels with the handle.
/// NOTE(review): NDArrays and modules also become `ObjectHandle` here, unlike the
/// `ArgValue` conversion, which picks a variant by type key.
impl<T: IsObject> From<ObjectPtr<T>> for RetValue {
    fn from(object_ptr: ObjectPtr<T>) -> RetValue {
        // `leak` returns a reference, which can never be null, so no null check is needed.
        let raw_object_ptr = ObjectPtr::leak(object_ptr) as *mut T as *mut std::ffi::c_void;
        RetValue::ObjectHandle(raw_object_ptr)
    }
}
/// Reclaims an object reference carried by a return value.
///
/// Object, module and NDArray handles are accepted, and the result is checked
/// against `T` with `downcast`. Object and module handles are adopted by
/// `ObjectPtr::from_raw` without an extra retain.
///
/// # Errors
/// - `Error::Null` for a null handle.
/// - A downcast error when the value is some other kind of `RetValue`, or when the
///   object's runtime type is not `T`.
impl<T: IsObject> TryFrom<RetValue> for ObjectPtr<T> {
    type Error = Error;

    fn try_from(ret_value: RetValue) -> Result<ObjectPtr<T>, Self::Error> {
        use crate::ffi::DLTensor;
        use crate::ndarray::NDArrayContainer;

        match ret_value {
            RetValue::ObjectHandle(handle) | RetValue::ModuleHandle(handle) => {
                let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?;
                debug_assert!(optr.count() >= 1);
                optr.downcast()
            }
            RetValue::NDArrayHandle(handle) => {
                let optr: ObjectPtr<NDArrayContainer> =
                    NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?;
                debug_assert!(optr.count() >= 1);
                optr.upcast::<Object>().downcast()
            }
            _ => Err(Error::downcast(format!("{:?}", ret_value), T::TYPE_KEY)),
        }
    }
}
/// Converts an owned object pointer into an FFI argument and transfers its reference.
///
/// The variant is chosen from the *static* type key of `T`:
/// - `"runtime.NDArray"` becomes `NDArrayHandle`;
/// - `"runtime.Module"` becomes `ModuleHandle`;
/// - anything else becomes `ObjectHandle`.
///
/// NOTE(review): an `ObjectPtr<Object>` whose dynamic type is an NDArray or module
/// still becomes an `ObjectHandle`.
impl<'a, T: IsObject> From<ObjectPtr<T>> for ArgValue<'a> {
    fn from(object_ptr: ObjectPtr<T>) -> ArgValue<'a> {
        debug_assert!(object_ptr.count() >= 1);
        let erased: ObjectPtr<Object> = object_ptr.upcast();

        if T::TYPE_KEY == "runtime.NDArray" {
            use crate::ndarray::NDArrayContainer;
            let handle = NDArrayContainer::leak(erased.downcast().unwrap())
                as *mut NDArrayContainer as *mut std::ffi::c_void;
            assert!(!handle.is_null());
            return ArgValue::NDArrayHandle(handle);
        }

        let handle = ObjectPtr::leak(erased) as *mut Object as *mut std::ffi::c_void;
        assert!(!handle.is_null());
        if T::TYPE_KEY == "runtime.Module" {
            ArgValue::ModuleHandle(handle)
        } else {
            ArgValue::ObjectHandle(handle)
        }
    }
}
impl<'a, T: IsObject> TryFrom<ArgValue<'a>> for ObjectPtr<T> {
type Error = Error;
fn try_from(arg_value: ArgValue<'a>) -> Result<ObjectPtr<T>, Self::Error> {
use crate::ffi::DLTensor;
use crate::ndarray::NDArrayContainer;
match arg_value {
ArgValue::ObjectHandle(handle) | ArgValue::ModuleHandle(handle) => {
let optr = ObjectPtr::from_raw(handle as *mut Object).ok_or(Error::Null)?;
debug_assert!(optr.count() >= 1);
optr.downcast()
}
ArgValue::NDArrayHandle(handle) => {
let optr =
NDArrayContainer::from_raw(handle as *mut DLTensor).ok_or(Error::Null)?;
debug_assert!(optr.count() >= 1);
optr.upcast::<Object>().downcast()
}
_ => Err(Error::downcast(format!("{:?}", arg_value), "ObjectHandle")),
}
}
}
/// Hashes by structure through the runtime's `structural_hash`. This is
/// consistent with the structural `PartialEq` impl for `ObjectPtr`.
///
/// # Panics
/// Panics if `structural_hash` returns an error.
impl<T: IsObject> std::hash::Hash for ObjectPtr<T> {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        let as_object_ref = ObjectRef(Some(self.clone().upcast()));
        let digest = super::structural_hash(as_object_ref, false).unwrap();
        state.write_i64(digest);
    }
}
/// Structural equality as computed by the runtime's `structural_equal`.
///
/// # Panics
/// Panics if `structural_equal` returns an error.
impl<T: IsObject> PartialEq for ObjectPtr<T> {
    fn eq(&self, other: &Self) -> bool {
        let wrap = |p: &Self| ObjectRef(Some(p.clone().upcast()));
        super::structural_equal(wrap(self), wrap(other), false, false).unwrap()
    }
}
/// Declares `ObjectPtr`'s structural equality to be a full equivalence relation.
/// NOTE(review): this assumes `structural_equal` is reflexive for every object.
impl<T: IsObject> Eq for ObjectPtr<T> {}
/// Reference-counting tests for `ObjectPtr` and its `RetValue`/`ArgValue`
/// conversions. They call into the TVM runtime through FFI (retain/free and the
/// global function registry).
#[cfg(test)]
mod tests {
    use super::{Object, ObjectPtr};
    use anyhow::{ensure, Result};
    use std::convert::TryInto;
    use tvm_sys::{ArgValue, RetValue};

    /// A freshly created pointer owns exactly one reference.
    #[test]
    fn test_new_object() -> anyhow::Result<()> {
        let object = Object::base::<Object>();
        let ptr = ObjectPtr::new(object);
        assert_eq!(ptr.count(), 1);
        Ok(())
    }

    /// `leak` hands out a reference without releasing the count.
    /// The object is intentionally never freed.
    #[test]
    fn test_leak() -> anyhow::Result<()> {
        let ptr = ObjectPtr::new(Object::base::<Object>());
        assert_eq!(ptr.count(), 1);
        let object = ObjectPtr::leak(ptr);
        assert_eq!(object.count(), 1);
        Ok(())
    }

    /// Cloning adds a reference; dropping one handle removes it.
    #[test]
    fn test_clone() -> anyhow::Result<()> {
        let ptr = ObjectPtr::new(Object::base::<Object>());
        assert_eq!(ptr.count(), 1);
        let ptr2 = ptr.clone();
        assert_eq!(ptr2.count(), 2);
        drop(ptr);
        assert_eq!(ptr2.count(), 1);
        Ok(())
    }

    /// ObjectPtr -> RetValue -> ObjectPtr: the clone's reference is carried through
    /// the handle and adopted by `ptr2`, so the count stays at 2 until `ptr2` drops.
    #[test]
    fn roundtrip_retvalue() -> Result<()> {
        let ptr = ObjectPtr::new(Object::base::<Object>());
        assert_eq!(ptr.count(), 1);
        let ret_value: RetValue = ptr.clone().into();
        let ptr2: ObjectPtr<Object> = ret_value.try_into()?;
        assert_eq!(ptr.count(), ptr2.count());
        assert_eq!(ptr.count(), 2);
        ensure!(
            ptr.type_index == ptr2.type_index,
            "type indices do not match"
        );
        ensure!(
            ptr.fdeleter == ptr2.fdeleter,
            "objects have different deleters"
        );
        drop(ptr2);
        assert_eq!(ptr.count(), 1);
        Ok(())
    }

    /// Same round trip through `ArgValue`. Converting into the argument transfers
    /// the clone's reference rather than releasing it.
    #[test]
    fn roundtrip_argvalue() -> Result<()> {
        let ptr = ObjectPtr::new(Object::base::<Object>());
        assert_eq!(ptr.count(), 1);
        let ptr_clone = ptr.clone();
        assert_eq!(ptr.count(), 2);
        let arg_value: ArgValue = ptr_clone.into();
        assert_eq!(ptr.count(), 2);
        let ptr2: ObjectPtr<Object> = arg_value.try_into()?;
        assert_eq!(ptr2.count(), 2);
        assert_eq!(ptr.count(), ptr2.count());
        assert_eq!(ptr.count(), 2);
        ensure!(
            ptr.type_index == ptr2.type_index,
            "type indices do not match"
        );
        ensure!(
            ptr.fdeleter == ptr2.fdeleter,
            "objects have different deleters"
        );
        drop(ptr2);
        assert_eq!(ptr.count(), 1);
        Ok(())
    }

    /// Registered as the packed function `my_func2`. It checks the count the
    /// callee observes and returns its argument unchanged.
    fn test_fn(o: ObjectPtr<Object>) -> ObjectPtr<Object> {
        assert_eq!(o.count(), 3);
        return o;
    }

    /// Passes an object through a registered packed function and back.
    #[test]
    fn test_ref_count_boundary3() {
        use super::*;
        use crate::function::{register, Function};
        let ptr = ObjectPtr::new(Object::base::<Object>());
        assert_eq!(ptr.count(), 1);
        let stay = ptr.clone();
        assert_eq!(ptr.count(), 2);
        register(test_fn, "my_func2").unwrap();
        let func = Function::get("my_func2").unwrap();
        // `ptr` is moved into the argument list; its reference is transferred, not released.
        let same = func.invoke(vec![ptr.into()]).unwrap();
        let same: ObjectPtr<Object> = same.try_into().unwrap();
        drop(same);
        // NOTE(review): pins the count observed after the round trip; confirm which
        // references are expected to remain outstanding at this point.
        assert_eq!(stay.count(), 3);
    }
}