use std::ptr::NonNull;
use crate::ffi::{self, Flags};
use crate::traits::{MemoryLayout, TensorLike, TensorView};
use super::ManagerContext;
/// Owning wrapper around an `ffi::DlpackVersioned`, a non-null pointer to an
/// `ffi::ManagedTensorVersioned`.
///
/// Dropping the wrapper calls the tensor's `deleter` if one is set. Use
/// `into_raw` / `into_non_null` to hand ownership off without running it.
pub struct SafeManagedTensorVersioned(ffi::DlpackVersioned);
impl Drop for SafeManagedTensorVersioned {
    /// Releases the managed tensor by calling its `deleter`.
    ///
    /// If the tensor has no deleter (`None`), nothing is done.
    fn drop(&mut self) {
        let raw = self.0.as_ptr();
        // SAFETY: `raw` comes from a `NonNull` that every constructor requires to
        // point at a live `ManagedTensorVersioned` owned by `self`. `into_raw` and
        // `into_non_null` forget `self`, so this runs at most once per tensor.
        let Some(release) = (unsafe { (*raw).deleter }) else {
            return;
        };
        // SAFETY: the deleter is handed the exact pointer it was installed for,
        // and `self` is never used again after `drop`.
        unsafe { release(raw) };
    }
}
impl SafeManagedTensorVersioned {
    /// Converts a tensor-like value into an owned versioned managed tensor using
    /// `Flags::default()`.
    ///
    /// Equivalent to `Self::with_flags(t, Flags::default())`.
    ///
    /// # Errors
    ///
    /// Propagates any `T::Error` produced while building the managed tensor.
    pub fn new<T, L>(t: T) -> std::result::Result<Self, T::Error>
    where
        T: TensorLike<L>,
        L: MemoryLayout,
    {
        Self::with_flags(t, Flags::default())
    }

    /// Converts a tensor-like value into an owned versioned managed tensor.
    ///
    /// `flags` is forwarded to `ManagerContext::into_dlpack_versioned`.
    ///
    /// # Errors
    ///
    /// Propagates any `T::Error` returned by
    /// `ManagerContext::into_dlpack_versioned`.
    pub fn with_flags<T, L>(t: T, flags: Flags) -> std::result::Result<Self, T::Error>
    where
        T: TensorLike<L>,
        L: MemoryLayout,
    {
        let ctx = ManagerContext::new(t);
        ctx.into_dlpack_versioned(flags).map(Self)
    }

    /// Takes ownership of a raw `ffi::ManagedTensorVersioned` pointer.
    ///
    /// # Safety
    ///
    /// - `ptr` must be non-null. This is only checked in debug builds.
    /// - `ptr` must point to a valid, initialized `ffi::ManagedTensorVersioned`
    ///   that stays valid until the returned value is dropped or released.
    /// - Ownership is transferred. When dropped, the returned value calls the
    ///   tensor's `deleter` (if any). The caller must therefore not release the
    ///   tensor itself, nor create a second owner from the same pointer.
    pub unsafe fn from_raw(ptr: *mut ffi::ManagedTensorVersioned) -> Self {
        debug_assert!(
            !ptr.is_null(),
            "SafeManagedTensorVersioned::from_raw called with a null pointer"
        );
        // SAFETY: the caller guarantees `ptr` is non-null (see `# Safety`).
        unsafe { Self(NonNull::new_unchecked(ptr)) }
    }

    /// Takes ownership of a non-null managed-tensor pointer.
    ///
    /// # Safety
    ///
    /// - `ptr` must point to a valid, initialized `ffi::ManagedTensorVersioned`
    ///   that stays valid until the returned value is dropped or released.
    /// - Ownership is transferred. When dropped, the returned value calls the
    ///   tensor's `deleter` (if any). The caller must therefore not release the
    ///   tensor itself, nor create a second owner from the same pointer.
    pub unsafe fn from_non_null(ptr: ffi::DlpackVersioned) -> Self {
        Self(ptr)
    }

    /// Gives up ownership and returns the raw pointer without running the
    /// deleter.
    ///
    /// # Safety
    ///
    /// The body performs no unsafe operation. The caller becomes responsible
    /// for eventually releasing the tensor, either by calling its `deleter` or
    /// by re-wrapping it with `Self::from_raw`. Otherwise the tensor is leaked.
    #[must_use = "losing the pointer will leak the managed tensor"]
    pub unsafe fn into_raw(self) -> *mut ffi::ManagedTensorVersioned {
        self.into_non_null().as_ptr()
    }

    /// Gives up ownership and returns the non-null pointer without running the
    /// deleter.
    ///
    /// The caller becomes responsible for eventually releasing the tensor
    /// (e.g. via `Self::from_non_null`). Otherwise it is leaked.
    #[must_use = "losing the pointer will leak the managed tensor"]
    pub fn into_non_null(self) -> ffi::DlpackVersioned {
        let ptr = self.0;
        // Skip `Drop` so the deleter is not called; ownership moves to the caller.
        std::mem::forget(self);
        ptr
    }

    /// Returns the `flags` field of the managed tensor.
    pub fn flags(&self) -> &Flags {
        // SAFETY: `self.0` is non-null and points to a live tensor owned by
        // `self`, so borrowing it for the lifetime of `&self` is sound.
        unsafe { &self.0.as_ref().flags }
    }

    /// Returns `true` if `Flags::READ_ONLY` is set.
    pub fn read_only(&self) -> bool {
        self.flags().contains(Flags::READ_ONLY)
    }

    /// Returns `true` if `Flags::IS_COPIED` is set.
    pub fn is_copied(&self) -> bool {
        self.flags().contains(Flags::IS_COPIED)
    }

    /// Returns `true` if `Flags::IS_SUBBYTE_TYPE_PADDED` is set.
    pub fn is_subbyte_type_padded(&self) -> bool {
        self.flags().contains(Flags::IS_SUBBYTE_TYPE_PADDED)
    }

    /// Misspelled alias of [`Self::is_subbyte_type_padded`], kept so existing
    /// callers keep compiling.
    pub fn is_subbtype_type_padded(&self) -> bool {
        self.is_subbyte_type_padded()
    }
}
impl TensorView for SafeManagedTensorVersioned {
    /// Borrows the `dl_tensor` embedded in the managed tensor.
    fn dl_tensor(&self) -> &ffi::Tensor {
        // SAFETY: the wrapped pointer is non-null by type. It is assumed to
        // point at a live `ManagedTensorVersioned` owned by `self`, which is the
        // invariant every constructor must establish. Borrowing it for the
        // lifetime of `&self` is therefore sound.
        let managed = unsafe { self.0.as_ref() };
        &managed.dl_tensor
    }
}
impl AsRef<SafeManagedTensorVersioned> for SafeManagedTensorVersioned {
    /// Identity conversion.
    ///
    /// Lets an owned `SafeManagedTensorVersioned` be passed wherever an
    /// `impl AsRef<SafeManagedTensorVersioned>` is accepted.
    fn as_ref(&self) -> &SafeManagedTensorVersioned {
        self
    }
}