dlpark 0.7.0

DLPack Rust bindings for Python
Documentation
use std::ffi::CStr;

use pyo3::conversion::{FromPyObject, IntoPyObject};
use pyo3::{Borrowed, Bound, PyAny, PyErr};

use crate::{SafeManagedTensor, SafeManagedTensorVersioned, ffi};

// Capsule names used throughout this module.
//
// A capsule carrying a tensor is created under `dltensor` (or `dltensor_versioned`).
// Once its pointer has been taken out, `capsule_to_raw_dlpack` renames it to the
// matching `used_*` name, and the capsule deleters below skip capsules that carry a
// `used_*` name.
const DLTENSOR: &CStr = c"dltensor";
const USED_DLTENSOR: &CStr = c"used_dltensor";
const DLTENSOR_VERSIONED: &CStr = c"dltensor_versioned";
const USED_DLTENSOR_VERSIONED: &CStr = c"used_dltensor_versioned";

/// Capsule destructor installed on capsules named `dltensor`.
///
/// Returns immediately when the capsule already carries the `used_dltensor` name.
/// Otherwise the pointer stored under `dltensor` is read back and the tensor it refers
/// to is dropped. If the pointer cannot be read, the failure is reported through
/// `PyErr_WriteUnraisable` and nothing is dropped.
///
/// # Safety
///
/// `capsule` must be a valid pointer to a Python object. Any pointer stored under the
/// `dltensor` name must be one that `SafeManagedTensor::from_raw` accepts.
unsafe extern "C" fn dlpack_capsule_deleter(capsule: *mut pyo3::ffi::PyObject) {
    // SAFETY: the caller guarantees `capsule` is a valid Python object pointer.
    let already_consumed =
        unsafe { pyo3::ffi::PyCapsule_IsValid(capsule, USED_DLTENSOR.as_ptr()) } == 1;
    if already_consumed {
        return;
    }

    // SAFETY: same pointer validity guarantee as above.
    let raw = unsafe { pyo3::ffi::PyCapsule_GetPointer(capsule, DLTENSOR.as_ptr()) };
    if raw.is_null() {
        // SAFETY: same pointer validity guarantee as above.
        unsafe { pyo3::ffi::PyErr_WriteUnraisable(capsule) };
        return;
    }

    // Rebuild the owning wrapper only so that dropping it releases the tensor.
    // SAFETY: `raw` is non-null and, per this function's contract, a pointer that
    // `from_raw` accepts.
    drop(unsafe { SafeManagedTensor::from_raw(raw as *mut ffi::ManagedTensor) });
}

/// Capsule destructor installed on capsules named `dltensor_versioned`.
///
/// Returns immediately when the capsule already carries the `used_dltensor_versioned`
/// name. Otherwise the pointer stored under `dltensor_versioned` is read back and the
/// tensor it refers to is dropped. If the pointer cannot be read, the failure is
/// reported through `PyErr_WriteUnraisable` and nothing is dropped.
///
/// # Safety
///
/// `capsule` must be a valid pointer to a Python object. Any pointer stored under the
/// `dltensor_versioned` name must be one that `SafeManagedTensorVersioned::from_raw`
/// accepts.
unsafe extern "C" fn dlpack_capsule_deleter_versioned(capsule: *mut pyo3::ffi::PyObject) {
    // SAFETY: the caller guarantees `capsule` is a valid Python object pointer.
    let already_consumed =
        unsafe { pyo3::ffi::PyCapsule_IsValid(capsule, USED_DLTENSOR_VERSIONED.as_ptr()) } == 1;
    if already_consumed {
        return;
    }

    // SAFETY: same pointer validity guarantee as above.
    let raw = unsafe { pyo3::ffi::PyCapsule_GetPointer(capsule, DLTENSOR_VERSIONED.as_ptr()) };
    if raw.is_null() {
        // SAFETY: same pointer validity guarantee as above.
        unsafe { pyo3::ffi::PyErr_WriteUnraisable(capsule) };
        return;
    }

    // Rebuild the owning wrapper only so that dropping it releases the tensor.
    // SAFETY: `raw` is non-null and, per this function's contract, a pointer that
    // `from_raw` accepts.
    drop(unsafe {
        SafeManagedTensorVersioned::from_raw(raw as *mut ffi::ManagedTensorVersioned)
    });
}

/// Takes the DLPack pointer out of `capsule` and marks the capsule as consumed.
///
/// The pointer stored under `name` is looked up with `PyCapsule_GetPointer`.
///
/// * On success the capsule is renamed to `used_name` and the pointer is returned.
/// * On failure a null pointer is returned and the capsule is left untouched. The
///   Python exception raised by `PyCapsule_GetPointer` stays pending for the caller
///   to collect.
///
/// The rename is skipped on failure on purpose: relabelling a capsule that was not
/// actually consumed (for example one registered under a different name) would stop
/// its own destructor from recognising it, and would issue a second C-API call while
/// an exception is already pending.
///
/// `capsule` must point to a live Python object. `used_name` is handed to
/// `PyCapsule_SetName` as a raw pointer, so it must stay valid for as long as the
/// capsule exists; the callers in this file pass constant C-string literals.
fn capsule_to_raw_dlpack(
    capsule: *mut pyo3::ffi::PyObject,
    name: &CStr,
    used_name: &CStr,
) -> *mut std::ffi::c_void {
    // SAFETY: relies on the caller passing a valid Python object pointer; both names
    // are NUL-terminated because they are `CStr`s.
    unsafe {
        let ptr = pyo3::ffi::PyCapsule_GetPointer(capsule, name.as_ptr());
        if !ptr.is_null() {
            // `GetPointer` succeeded, so `capsule` is a legal capsule and the rename
            // cannot be rejected; the status code is therefore not inspected.
            pyo3::ffi::PyCapsule_SetName(capsule, used_name.as_ptr());
        }
        ptr
    }
}

/// Wraps a raw DLPack pointer in a new Python capsule.
///
/// The capsule is created with `PyCapsule_New` under `name`, with `deleter` installed
/// as its destructor. The result of `PyCapsule_New` is returned as-is, so callers must
/// handle a null return.
///
/// `name` is passed to CPython as a raw pointer and must stay valid for as long as the
/// capsule exists; the callers in this file pass constant C-string literals.
fn raw_dlpack_to_capsule(
    ptr: *mut std::ffi::c_void,
    name: &CStr,
    deleter: unsafe extern "C" fn(*mut pyo3::ffi::PyObject),
) -> *mut pyo3::ffi::PyObject {
    let capsule_name = name.as_ptr();
    let destructor = Some(deleter);
    // SAFETY: `capsule_name` is NUL-terminated because it comes from a `CStr`;
    // `PyCapsule_New` does not dereference `ptr`.
    unsafe { pyo3::ffi::PyCapsule_New(ptr, capsule_name, destructor) }
}

impl<'py> FromPyObject<'_, 'py> for SafeManagedTensor {
    type Error = PyErr;
    fn extract(ob: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
        let ptr = capsule_to_raw_dlpack(ob.as_ptr(), DLTENSOR, USED_DLTENSOR);
        unsafe { Ok(SafeManagedTensor::from_raw(ptr as *mut _)) }
    }
}

/// Converts a [`SafeManagedTensor`] into a Python capsule named `dltensor`.
///
/// The tensor is turned into its raw pointer and stored in a new capsule whose
/// destructor is `dlpack_capsule_deleter`. A null capsule pointer from the helper is
/// turned into an `Err` by `Bound::from_owned_ptr_or_err`.
impl<'py> IntoPyObject<'py> for SafeManagedTensor {
    type Target = PyAny;
    type Output = Bound<'py, PyAny>;
    type Error = pyo3::PyErr;

    fn into_pyobject(self, py: pyo3::Python<'py>) -> pyo3::PyResult<Self::Output> {
        // SAFETY: the capsule pointer passed to `from_owned_ptr_or_err` comes straight
        // from `PyCapsule_New` (via `raw_dlpack_to_capsule`).
        unsafe {
            let raw: *mut std::ffi::c_void = self.into_raw() as *mut _;
            Bound::from_owned_ptr_or_err(
                py,
                raw_dlpack_to_capsule(raw, DLTENSOR, dlpack_capsule_deleter),
            )
        }
    }
}

impl<'py> FromPyObject<'_, 'py> for SafeManagedTensorVersioned {
    type Error = PyErr;
    fn extract(ob: Borrowed<'_, 'py, PyAny>) -> Result<Self, Self::Error> {
        let ptr = capsule_to_raw_dlpack(ob.as_ptr(), DLTENSOR_VERSIONED, USED_DLTENSOR_VERSIONED);
        unsafe { Ok(SafeManagedTensorVersioned::from_raw(ptr as *mut _)) }
    }
}

/// Converts a [`SafeManagedTensorVersioned`] into a Python capsule named
/// `dltensor_versioned`.
///
/// The tensor is turned into its raw pointer and stored in a new capsule whose
/// destructor is `dlpack_capsule_deleter_versioned`. A null capsule pointer from the
/// helper is turned into an `Err` by `Bound::from_owned_ptr_or_err`.
impl<'py> IntoPyObject<'py> for SafeManagedTensorVersioned {
    type Target = PyAny;
    type Output = Bound<'py, PyAny>;
    type Error = pyo3::PyErr;

    fn into_pyobject(self, py: pyo3::Python<'py>) -> pyo3::PyResult<Self::Output> {
        // SAFETY: the capsule pointer passed to `from_owned_ptr_or_err` comes straight
        // from `PyCapsule_New` (via `raw_dlpack_to_capsule`).
        unsafe {
            let raw: *mut std::ffi::c_void = self.into_raw() as *mut _;
            Bound::from_owned_ptr_or_err(
                py,
                raw_dlpack_to_capsule(raw, DLTENSOR_VERSIONED, dlpack_capsule_deleter_versioned),
            )
        }
    }
}

#[cfg(test)]
mod tests {
    use pyo3::Python;

    use super::*;
    use crate::traits::TensorView;

    /// Round-trips an unversioned tensor through a capsule and checks that the data
    /// pointer is unchanged.
    #[test]
    fn test_dlpack() {
        Python::initialize();
        Python::attach(|py| {
            let original =
                SafeManagedTensor::new(vec![1i32, 2, 3]).expect("fail to make safe managed tensor");
            // Record the data pointer before ownership moves into the capsule.
            let expected_ptr = original.data_ptr();

            let capsule = original.into_pyobject(py).expect("fail to convert to pyobject");
            let restored =
                SafeManagedTensor::extract((&capsule).into()).expect("fail to extract bound");

            assert_eq!(expected_ptr, restored.data_ptr());
        });
    }

    /// Round-trips a versioned tensor through a capsule and checks that the data
    /// pointer is unchanged.
    #[test]
    fn test_dlpack_versioned() {
        Python::initialize();
        Python::attach(|py| {
            let original = SafeManagedTensorVersioned::new(vec![1i32, 2, 3])
                .expect("fail to make safe managed tensor");
            // Record the data pointer before ownership moves into the capsule.
            let expected_ptr = original.data_ptr();

            let capsule = original.into_pyobject(py).expect("fail to convert to pyobject");
            let restored = SafeManagedTensorVersioned::extract((&capsule).into())
                .expect("fail to extract bound");

            assert_eq!(expected_ptr, restored.data_ptr());
        });
    }
}