use pin_project::{pin_project, pinned_drop};
use core::slice;
use std::{
fmt::Debug,
marker::{PhantomData, PhantomPinned},
mem::transmute,
os::raw::c_void,
pin::Pin,
ptr::{self, NonNull},
};
use crate::{
datatype::DataType,
device::Device,
ffi::{DLManagedTensor, DLTensor},
};
/// A non-owning view of a DLPack [`DLTensor`].
///
/// The view has no `Drop` impl, so nothing is freed when it goes away. `#[repr(transparent)]`
/// gives it exactly the layout of `DLTensor`. The lifetime `'tensor` is carried only by
/// `_marker`: `fn(&'tensor ()) -> &'tensor ()` is contravariant in its argument and covariant in
/// its result, which makes the type invariant in `'tensor`.
#[derive(Debug)]
#[repr(transparent)]
pub struct Tensor<'tensor> {
    /// The raw FFI descriptor wrapped by this view.
    pub inner: DLTensor,
    _marker: PhantomData<fn(&'tensor ()) -> &'tensor ()>,
}
impl<'tensor> From<Tensor<'tensor>> for DLTensor {
fn from(ts: Tensor<'tensor>) -> Self {
ts.inner
}
}
impl<'tensor> From<DLTensor> for Tensor<'tensor> {
fn from(dts: DLTensor) -> Self {
Tensor {
inner: dts,
_marker: PhantomData,
}
}
}
impl<'tensor> Tensor<'tensor> {
    /// Builds a tensor view from raw DLPack components.
    ///
    /// Nothing is validated or copied: the pointers are stored as given. They are dereferenced
    /// only later, by [`Tensor::shape`] and [`Tensor::strides`] (and by [`Tensor::size`], which
    /// goes through `shape`).
    pub fn new(
        data: *mut c_void,
        device: Device,
        ndim: i32,
        dtype: DataType,
        shape: *mut i64,
        strides: *mut i64,
        byte_offset: u64,
    ) -> Self {
        let inner = DLTensor {
            data,
            device: device.into(),
            ndim,
            dtype: dtype.into(),
            shape,
            strides,
            byte_offset,
        };
        Tensor {
            inner,
            _marker: PhantomData,
        }
    }

    /// Consumes the view and returns the underlying FFI descriptor.
    pub fn into_inner(self) -> DLTensor {
        self.inner
    }

    /// Moves the descriptor onto the heap and returns a pointer to it.
    ///
    /// Taking the address of `self.inner` directly would yield a pointer into this function's
    /// own by-value argument, which dangles as soon as the call returns. Boxing keeps the
    /// pointer valid. The caller owns the allocation and can release it with
    /// `Box::from_raw(ptr as *mut DLTensor)`. Only the descriptor is boxed; the memory behind
    /// `data`, `shape` and `strides` is left untouched.
    pub fn into_raw(self) -> *const DLTensor {
        Box::into_raw(Box::new(self.inner)) as *const DLTensor
    }

    /// Builds a view by copying the descriptor behind `ptr`.
    ///
    /// # Safety
    ///
    /// `ptr` must be non-null, properly aligned and point to an initialized `DLTensor`. The
    /// descriptor is copied out. The allocation behind `ptr` is neither freed nor retained, so a
    /// pointer obtained from [`Tensor::into_raw`] must still be released separately.
    pub unsafe fn from_raw(ptr: *mut DLTensor) -> Self {
        debug_assert!(!ptr.is_null());
        Tensor {
            inner: *ptr,
            _marker: PhantomData,
        }
    }

    /// Returns the raw data pointer exactly as stored (not adjusted by `byte_offset`).
    pub fn data(&self) -> *mut c_void {
        self.inner.data
    }

    /// Converts the stored FFI device into a [`Device`].
    pub fn device(&self) -> Device {
        self.inner.device.into()
    }

    /// Returns `lanes * bits / 8`, the byte width of one element, rounded down.
    ///
    /// Types narrower than a byte (`lanes * bits < 8`) therefore report 0.
    pub fn itemsize(&self) -> usize {
        let ty = self.dtype();
        ty.lanes() * ty.bits() / 8_usize
    }

    /// Returns the number of dimensions (`ndim` cast with `as usize`; a negative value wraps).
    pub fn ndim(&self) -> usize {
        self.inner.ndim as usize
    }

    /// Converts the stored FFI dtype into a [`DataType`].
    pub fn dtype(&self) -> DataType {
        self.inner.dtype.into()
    }

    /// Returns the extent of each dimension.
    ///
    /// Returns `None` if the shape pointer or the data pointer is null, or if `ndim` is negative.
    pub fn shape(&self) -> Option<&[usize]> {
        self.int64_array(self.inner.shape)
    }

    /// Returns the per-dimension strides (in elements, per DLPack — confirm).
    ///
    /// Returns `None` if the strides pointer or the data pointer is null, or if `ndim` is
    /// negative. DLPack uses null strides to mean compact row-major.
    pub fn strides(&self) -> Option<&[usize]> {
        self.int64_array(self.inner.strides)
    }

    /// Views the `ndim`-long `int64_t` array at `ptr` (shape or strides) as `&[usize]`.
    ///
    /// NOTE(review): this reinterprets `i64` storage as `usize`. That is layout-compatible only
    /// on 64-bit targets, and negative entries show up as huge values. Confirm the crate only
    /// targets 64-bit platforms.
    fn int64_array(&self, ptr: *mut i64) -> Option<&[usize]> {
        if ptr.is_null() || self.inner.data.is_null() {
            return None;
        }
        // A negative `ndim` would otherwise wrap to an enormous length in `from_raw_parts`.
        let len = usize::try_from(self.inner.ndim).ok()?;
        // SAFETY: `ptr` is non-null and, per the DLPack contract, points to `ndim` contiguous
        // `int64_t` values that outlive this view. Reading them as `usize` relies on the two
        // types having identical size and alignment (64-bit targets).
        let values = unsafe { slice::from_raw_parts(ptr.cast::<usize>(), len) };
        Some(values)
    }

    /// Returns the byte offset of the first element (`u64 as isize`; values above
    /// `isize::MAX` wrap negative).
    pub fn byte_offset(&self) -> isize {
        self.inner.byte_offset as isize
    }

    /// Returns the total number of bytes spanned by the elements, or `None` when
    /// [`Tensor::shape`] is `None`.
    ///
    /// Computed as `product(shape) * ceil(bits * lanes / 8)`, matching DLPack's reference
    /// `GetDataSize`. The per-element byte count must be rounded up *before* multiplying by the
    /// element count. `product * (bits * lanes + 7) / 8` over-counts; for example, two `float32`
    /// elements would yield 9 bytes instead of 8.
    pub fn size(&self) -> Option<usize> {
        let ty = self.dtype();
        let bytes_per_element = (ty.bits() * ty.lanes() + 7) / 8;
        self.shape()
            .map(|dims| dims.iter().product::<usize>() * bytes_per_element)
    }
}
/// Holder for a DLPack manager context, embedded in `ManagedTensorProxy`.
///
/// `ptr` does not hold the context pointer itself. It points at a slot that holds it, and the
/// conversions into `DLManagedTensor` read the context by dereferencing `ptr`; `None` maps to a
/// null context. `C` is only a type-level tag (`PhantomData`). `PhantomPinned` makes the type
/// `!Unpin`.
#[derive(Debug)]
#[repr(C)]
pub struct ManagerContext<C> {
    /// Pointer to the slot holding the context pointer; `None` means "no context".
    pub ptr: Option<NonNull<*mut c_void>>,
    ty: PhantomData<C>,
    _pin: PhantomPinned,
}

impl<C> ManagerContext<C> {
    /// Wraps `ptr`, which may be `None`, in a new context holder.
    pub fn new(ptr: Option<NonNull<*mut c_void>>) -> Self {
        let ty = PhantomData;
        let _pin = PhantomPinned;
        Self { ptr, ty, _pin }
    }
}
/// Rust-side counterpart of `DLManagedTensor`, used as the payload of `ManagedTensor`.
///
/// `#[repr(C)]` keeps the fields in declaration order. Dropping a proxy runs its `PinnedDrop`
/// impl, which invokes `deleter` if one is set.
#[pin_project(PinnedDrop)]
#[repr(C)]
pub struct ManagedTensorProxy<C> {
    /// The tensor descriptor.
    pub dl_tensor: DLTensor,
    /// Manager context, structurally pinned via pin-project's `#[pin]`.
    #[pin]
    pub manager_ctx: ManagerContext<C>,
    /// Optional function invoked when the proxy is dropped.
    pub deleter: Option<fn(&mut ManagedTensor<C>)>,
}
impl<C: Debug> Debug for ManagedTensorProxy<C> {
    /// Formats the descriptor and the manager context; the `deleter` field is left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ManagedTensorProxy");
        out.field("dl_tensor", &self.dl_tensor);
        out.field("manager_ctx", &self.manager_ctx);
        out.finish()
    }
}
impl<C> ManagedTensorProxy<C> {
    /// Returns a copy of the tensor descriptor.
    pub fn dl_tensor(&self) -> DLTensor {
        self.dl_tensor
    }

    /// Returns the stored manager-context slot pointer (`None` when there is no context).
    pub fn manager_ctx(self: Pin<&mut Self>) -> Option<NonNull<*mut c_void>> {
        // Only a `Copy` read is needed, so `Pin`'s `Deref` suffices; no projection required.
        self.manager_ctx.ptr
    }

    /// Replaces the pinned manager context in place with one wrapping `manager_ctx`.
    pub fn set_manager_ctx(self: Pin<&mut Self>, manager_ctx: NonNull<*mut c_void>) {
        let replacement = ManagerContext::new(Some(manager_ctx));
        let mut projected = self.project();
        projected.manager_ctx.set(replacement);
    }
}
impl<C> From<DLManagedTensor> for ManagedTensorProxy<C> {
    /// Adopts a `DLManagedTensor`. The proxy takes over its deleter, which then runs when the
    /// proxy is dropped.
    ///
    /// `ManagerContext` stores a pointer to a slot holding the context pointer, not the context
    /// itself. Pointing that slot at `dlmt.manager_ctx` would dangle as soon as this function
    /// returns, because `dlmt` is a by-value local. A non-null context is therefore copied into
    /// a heap slot that stays valid for as long as the proxy needs it.
    ///
    /// NOTE(review): the heap slot is never freed, because nothing downstream can tell it apart
    /// from a caller-supplied slot. Each conversion with a non-null context therefore leaks one
    /// pointer-sized allocation. Separately, the C deleter is reinterpreted as a Rust-ABI fn
    /// pointer; confirm how it is meant to be invoked.
    fn from(dlmt: DLManagedTensor) -> Self {
        let ptr: Option<NonNull<*mut c_void>> = if dlmt.manager_ctx.is_null() {
            None
        } else {
            Some(NonNull::from(Box::leak(Box::new(dlmt.manager_ctx))))
        };
        let manager_ctx = ManagerContext::new(ptr);
        let deleter = dlmt.deleter.map(|del| unsafe {
            transmute::<unsafe extern "C" fn(*mut DLManagedTensor), fn(&mut ManagedTensor<C>)>(del)
        });
        ManagedTensorProxy {
            dl_tensor: dlmt.dl_tensor,
            manager_ctx,
            deleter,
        }
    }
}
impl<C> From<ManagedTensorProxy<C>> for DLManagedTensor {
    /// Converts an owned proxy into the FFI struct, transferring ownership of the tensor.
    ///
    /// The proxy is wrapped in `ManuallyDrop` so its `PinnedDrop` impl does not run. Letting it
    /// drop here would invoke the deleter, and the caller would receive a `DLManagedTensor`
    /// whose resources had already been handed to that deleter. That deleter would then run a
    /// second time when the returned struct is eventually released. The returned struct carries
    /// the deleter instead.
    fn from(pmt: ManagedTensorProxy<C>) -> Self {
        let pmt = std::mem::ManuallyDrop::new(pmt);
        let dl_tensor = pmt.dl_tensor;
        let manager_ctx = match pmt.manager_ctx.ptr {
            None => ptr::null_mut(),
            // SAFETY: assumes a `Some` slot points at a live `*mut c_void` holding the context.
            Some(nnptr) => unsafe { *nnptr.as_ptr() },
        };
        // NOTE(review): reinterprets a Rust-ABI fn pointer as a C-ABI one; the two are not
        // interchangeable in general — confirm the intended calling convention.
        let deleter = pmt.deleter.map(|del_fn| unsafe {
            transmute::<fn(&mut ManagedTensor<C>), unsafe extern "C" fn(*mut DLManagedTensor)>(
                del_fn,
            )
        });
        DLManagedTensor {
            dl_tensor,
            manager_ctx,
            deleter,
        }
    }
}
impl<C> From<Pin<&mut ManagedTensorProxy<C>>> for DLManagedTensor {
    /// Builds a `DLManagedTensor` by copying fields out of a pinned proxy.
    ///
    /// The proxy is only borrowed, so it is not dropped here and its deleter is not invoked.
    fn from(pinned: Pin<&mut ManagedTensorProxy<C>>) -> Self {
        let proxy: &ManagedTensorProxy<C> = &pinned;
        // Read the context pointer out of its slot; a missing slot becomes a null context.
        let manager_ctx = proxy
            .manager_ctx
            .ptr
            .map_or(ptr::null_mut(), |slot| unsafe { *slot.as_ptr() });
        let deleter = proxy.deleter.map(|del_fn| unsafe {
            transmute::<fn(&mut ManagedTensor<C>), unsafe extern "C" fn(*mut DLManagedTensor)>(
                del_fn,
            )
        });
        DLManagedTensor {
            dl_tensor: proxy.dl_tensor,
            manager_ctx,
            deleter,
        }
    }
}
// NOTE(review): presumably silences a `clippy::needless_lifetimes` hit in the code that
// `#[pinned_drop]` generates — confirm, since an unexplained allow can hide real lints.
#[allow(clippy::needless_lifetimes)]
#[pinned_drop]
impl<C> PinnedDrop for ManagedTensorProxy<C> {
    // Invokes the stored deleter, if any, when the proxy is dropped.
    //
    // A `DLManagedTensor` is rebuilt from the pinned proxy through the
    // `From<Pin<&mut ManagedTensorProxy<C>>>` conversion. That conversion dereferences the
    // manager-context slot and casts the deleter. A pointer to this *local copy* is then passed
    // to the deleter.
    //
    // NOTE(review): two hazards to confirm:
    // - The deleter is called through a Rust-ABI `fn(*mut DLManagedTensor)` pointer. A deleter
    //   that originated as an `unsafe extern "C" fn` (via `From<DLManagedTensor>`) is therefore
    //   called with a mismatched ABI, which is undefined behavior.
    // - The argument is the address of the stack local `dlm`, not the original
    //   `DLManagedTensor` allocation. A deleter that frees its argument would free stack memory.
    fn drop(mut self: Pin<&mut Self>) {
        let mut dlm: DLManagedTensor = self.as_mut().into();
        if let Some(fptr) = self.deleter {
            unsafe {
                // Reinterpret the stored `fn(&mut ManagedTensor<C>)` as taking a raw
                // `*mut DLManagedTensor` so it can be handed the rebuilt FFI struct.
                let cfptr = transmute::<fn(&mut ManagedTensor<C>), fn(*mut DLManagedTensor)>(fptr);
                cfptr(&mut dlm as *mut _);
            };
        }
    }
}
/// Owning wrapper around a `ManagedTensorProxy`, with the proxy's layout
/// (`#[repr(transparent)]`).
///
/// Dropping a `ManagedTensor` drops the inner proxy, which invokes its deleter if one is set.
/// As with `Tensor`, `_marker` makes the type invariant in `'tensor`.
#[derive(Debug)]
#[repr(transparent)]
pub struct ManagedTensor<'tensor, C: 'tensor> {
    /// The proxy carrying the descriptor, the context and the deleter.
    pub inner: ManagedTensorProxy<C>,
    _marker: PhantomData<fn(&'tensor ()) -> &'tensor ()>,
}
impl<'tensor, C> From<DLManagedTensor> for ManagedTensor<'tensor, C> {
fn from(dlm: DLManagedTensor) -> Self {
let proxy: ManagedTensorProxy<C> = dlm.into();
ManagedTensor {
inner: proxy,
_marker: PhantomData,
}
}
}
impl<'tensor, C> From<ManagedTensor<'tensor, C>> for DLManagedTensor {
fn from(mt: ManagedTensor<'tensor, C>) -> Self {
mt.inner.into()
}
}
impl<'tensor, C: 'tensor> ManagedTensor<'tensor, C> {
pub fn new(tensor: Tensor<'tensor>, manager_ctx: Option<NonNull<*mut c_void>>) -> Self {
let manager_ctx = ManagerContext::new(manager_ctx);
let inner = ManagedTensorProxy {
dl_tensor: tensor.into_inner(),
manager_ctx,
deleter: None,
};
ManagedTensor {
inner,
_marker: PhantomData,
}
}
pub fn set_deleter(&mut self, deleter: fn(&mut ManagedTensor<C>)) {
self.inner.deleter = Some(deleter);
}
pub fn into_raw(self) -> *const DLManagedTensor {
let ret: DLManagedTensor = self.inner.into();
&ret as *const _
}
pub unsafe fn from_raw(ptr: *mut DLManagedTensor) -> Self {
debug_assert!(!ptr.is_null());
ManagedTensor {
inner: (*ptr).into(),
_marker: PhantomData,
}
}
pub fn into_tensor(self) -> Tensor<'tensor> {
self.inner.dl_tensor.into()
}
}