use std::ffi::CStr;
use pyo3::conversion::{FromPyObject, IntoPyObject};
use pyo3::{Borrowed, Bound, PyAny, PyErr};
use crate::{SafeManagedTensor, SafeManagedTensorVersioned, ffi};
const DLTENSOR: &CStr = c"dltensor";
const USED_DLTENSOR: &CStr = c"used_dltensor";
const DLTENSOR_VERSIONED: &CStr = c"dltensor_versioned";
const USED_DLTENSOR_VERSIONED: &CStr = c"used_dltensor_versioned";
/// `PyCapsule` destructor registered on capsules exported from a [`SafeManagedTensor`].
///
/// If the capsule has been renamed to `"used_dltensor"`, a consumer has taken ownership of the
/// tensor and nothing is freed here. Otherwise the pointer stored under `"dltensor"` is turned
/// back into a `SafeManagedTensor` and dropped. If that pointer cannot be read, the pending
/// Python error is reported via `PyErr_WriteUnraisable` and the tensor is left untouched.
///
/// NOTE(review): the DLPack reference deleter saves/restores any in-flight exception around
/// `PyCapsule_GetPointer`; this one does not — confirm whether that matters for callers.
unsafe extern "C" fn dlpack_capsule_deleter(capsule: *mut pyo3::ffi::PyObject) {
    // SAFETY: CPython calls capsule destructors with the capsule being destroyed. This deleter
    // is only registered on capsules whose "dltensor" pointer came from
    // `SafeManagedTensor::into_raw`, so reclaiming it with `from_raw` takes back that ownership.
    unsafe {
        let consumed = pyo3::ffi::PyCapsule_IsValid(capsule, USED_DLTENSOR.as_ptr()) == 1;
        if !consumed {
            let raw = pyo3::ffi::PyCapsule_GetPointer(capsule, DLTENSOR.as_ptr());
            if raw.is_null() {
                pyo3::ffi::PyErr_WriteUnraisable(capsule);
            } else {
                drop(SafeManagedTensor::from_raw(raw.cast::<ffi::ManagedTensor>()));
            }
        }
    }
}
/// `PyCapsule` destructor registered on capsules exported from a
/// [`SafeManagedTensorVersioned`].
///
/// If the capsule has been renamed to `"used_dltensor_versioned"`, a consumer has taken
/// ownership of the tensor and nothing is freed here. Otherwise the pointer stored under
/// `"dltensor_versioned"` is turned back into a `SafeManagedTensorVersioned` and dropped. If
/// that pointer cannot be read, the pending Python error is reported via
/// `PyErr_WriteUnraisable` and the tensor is left untouched.
///
/// NOTE(review): the DLPack reference deleter saves/restores any in-flight exception around
/// `PyCapsule_GetPointer`; this one does not — confirm whether that matters for callers.
unsafe extern "C" fn dlpack_capsule_deleter_versioned(capsule: *mut pyo3::ffi::PyObject) {
    // SAFETY: CPython calls capsule destructors with the capsule being destroyed. This deleter
    // is only registered on capsules whose "dltensor_versioned" pointer came from
    // `SafeManagedTensorVersioned::into_raw`, so `from_raw` takes back that ownership.
    unsafe {
        let consumed =
            pyo3::ffi::PyCapsule_IsValid(capsule, USED_DLTENSOR_VERSIONED.as_ptr()) == 1;
        if !consumed {
            let raw = pyo3::ffi::PyCapsule_GetPointer(capsule, DLTENSOR_VERSIONED.as_ptr());
            if raw.is_null() {
                pyo3::ffi::PyErr_WriteUnraisable(capsule);
            } else {
                drop(SafeManagedTensorVersioned::from_raw(
                    raw.cast::<ffi::ManagedTensorVersioned>(),
                ));
            }
        }
    }
}
/// Takes the DLPack tensor pointer out of `capsule` and marks the capsule as consumed.
///
/// The pointer is looked up under `name`. On success the capsule is renamed to `used_name`,
/// which tells its deleter that ownership has moved to the caller, and the pointer is returned.
///
/// On failure a null pointer is returned, a Python exception is left set, and the capsule is
/// NOT renamed. Failure means `capsule` is not a capsule, carries a different name (for example
/// it was already consumed, or it is the other DLPack flavour), or the rename itself failed.
/// Leaving a foreign capsule's name untouched keeps its own deleter working.
///
/// `used_name` is `'static` because CPython stores the name pointer inside the capsule
/// without copying it.
///
/// The caller must pass a live Python object pointer while attached to the interpreter.
fn capsule_to_raw_dlpack(
    capsule: *mut pyo3::ffi::PyObject,
    name: &CStr,
    used_name: &'static CStr,
) -> *mut std::ffi::c_void {
    // SAFETY: `capsule` is a live object pointer supplied by the caller; both names are valid
    // NUL-terminated strings, and `used_name` outlives the capsule that will reference it.
    unsafe {
        let ptr = pyo3::ffi::PyCapsule_GetPointer(capsule, name.as_ptr());
        if ptr.is_null() {
            // PyCapsule_GetPointer already set the exception; do not rename a capsule we
            // could not take ownership from.
            return std::ptr::null_mut();
        }
        if pyo3::ffi::PyCapsule_SetName(capsule, used_name.as_ptr()) != 0 {
            // Without the rename the capsule's deleter would also free the tensor, so we must
            // not hand the pointer out. PyCapsule_SetName has set the exception.
            return std::ptr::null_mut();
        }
        ptr
    }
}
/// Wraps a raw DLPack tensor pointer in a new `PyCapsule` named `name`, using `deleter` as
/// the capsule destructor.
///
/// Returns a new (owned) object pointer. If `PyCapsule_New` fails, it returns null with a
/// Python exception set. In that case no capsule exists, `deleter` will never run, and
/// ownership of `ptr` stays with the caller.
///
/// `name` is `'static` because CPython stores the name pointer inside the capsule without
/// copying it, so it must outlive the capsule.
fn raw_dlpack_to_capsule(
    ptr: *mut std::ffi::c_void,
    name: &'static CStr,
    deleter: unsafe extern "C" fn(*mut pyo3::ffi::PyObject),
) -> *mut pyo3::ffi::PyObject {
    // SAFETY: `name` is a 'static NUL-terminated string, so the capsule may keep pointing at
    // it; the caller is expected to be attached to the interpreter.
    unsafe { pyo3::ffi::PyCapsule_New(ptr, name.as_ptr(), Some(deleter)) }
}
impl<'py> FromPyObject<'_, 'py> for SafeManagedTensor {
type Error = PyErr;
fn extract(ob: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
let ptr = capsule_to_raw_dlpack(ob.as_ptr(), DLTENSOR, USED_DLTENSOR);
unsafe { Ok(SafeManagedTensor::from_raw(ptr as *mut _)) }
}
}
impl<'py> IntoPyObject<'py> for SafeManagedTensor {
type Target = PyAny;
type Output = Bound<'py, PyAny>;
type Error = pyo3::PyErr;
fn into_pyobject(self, py: pyo3::Python<'py>) -> pyo3::PyResult<Self::Output> {
unsafe {
let capsule =
raw_dlpack_to_capsule(self.into_raw() as *mut _, DLTENSOR, dlpack_capsule_deleter);
Bound::from_owned_ptr_or_err(py, capsule)
}
}
}
impl<'py> FromPyObject<'_, 'py> for SafeManagedTensorVersioned {
type Error = PyErr;
fn extract(ob: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
let ptr = capsule_to_raw_dlpack(ob.as_ptr(), DLTENSOR_VERSIONED, USED_DLTENSOR_VERSIONED);
unsafe { Ok(SafeManagedTensorVersioned::from_raw(ptr as *mut _)) }
}
}
impl<'py> IntoPyObject<'py> for SafeManagedTensorVersioned {
type Target = PyAny;
type Output = Bound<'py, PyAny>;
type Error = pyo3::PyErr;
fn into_pyobject(self, py: pyo3::Python<'py>) -> pyo3::PyResult<Self::Output> {
unsafe {
let capsule = raw_dlpack_to_capsule(
self.into_raw() as *mut _,
DLTENSOR_VERSIONED,
dlpack_capsule_deleter_versioned,
);
Bound::from_owned_ptr_or_err(py, capsule)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::TensorView;
    use pyo3::Python;

    /// Round-trips a tensor through a `"dltensor"` capsule and checks that the imported tensor
    /// still points at the original data buffer.
    #[test]
    fn test_dlpack() {
        Python::initialize();
        Python::attach(|py| {
            let original =
                SafeManagedTensor::new(vec![1i32, 2, 3]).expect("fail to make safe managed tensor");
            let expected_data = original.data_ptr();

            let exported = original.into_pyobject(py).expect("fail to convert to pyobject");
            let imported =
                SafeManagedTensor::extract(exported.as_borrowed()).expect("fail to extract bound");

            assert_eq!(expected_data, imported.data_ptr());
        });
    }

    /// Round-trips a tensor through a `"dltensor_versioned"` capsule and checks that the
    /// imported tensor still points at the original data buffer.
    #[test]
    fn test_dlpack_versioned() {
        Python::initialize();
        Python::attach(|py| {
            let original = SafeManagedTensorVersioned::new(vec![1i32, 2, 3])
                .expect("fail to make safe managed tensor");
            let expected_data = original.data_ptr();

            let exported = original.into_pyobject(py).expect("fail to convert to pyobject");
            let imported = SafeManagedTensorVersioned::extract(exported.as_borrowed())
                .expect("fail to extract bound");

            assert_eq!(expected_data, imported.data_ptr());
        });
    }
}