use numpy::{PyArray2, PyArrayMethods, PyUntypedArrayMethods};
use pyo3::ffi;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;
use std::ffi::c_void;
/// Expose the module-level DLPack flag bitmasks and device-type codes as
/// Python attributes on `parent`.
pub fn register(_py: Python, parent: &Bound<'_, PyModule>) -> PyResult<()> {
    // Flag bitmasks (u64), added first.
    let flag_consts: [(&str, u64); 2] = [
        ("DLPACK_FLAG_BITMASK_READ_ONLY", DLPACK_FLAG_BITMASK_READ_ONLY),
        ("DLPACK_FLAG_BITMASK_IS_COPIED", DLPACK_FLAG_BITMASK_IS_COPIED),
    ];
    for (name, value) in flag_consts {
        parent.add(name, value)?;
    }
    // Device-type codes (i32), in the same order as the declarations in
    // the `device_type` module.
    let device_consts: [(&str, i32); 14] = [
        ("KDLCPU", device_type::KDLCPU),
        ("KDLCUDA", device_type::KDLCUDA),
        ("KDLCUDA_HOST", device_type::KDLCUDA_HOST),
        ("KDLOPENCL", device_type::KDLOPENCL),
        ("KDLVULKAN", device_type::KDLVULKAN),
        ("KDLMETAL", device_type::KDLMETAL),
        ("KDLVPI", device_type::KDLVPI),
        ("KDLROCM", device_type::KDLROCM),
        ("KDLROCM_HOST", device_type::KDLROCM_HOST),
        ("KDLEXT_DEV", device_type::KDLEXT_DEV),
        ("KDLCUDA_MANAGED", device_type::KDLCUDA_MANAGED),
        ("KDLONE_API", device_type::KDLONE_API),
        ("KDLWEB_GPU", device_type::KDLWEB_GPU),
        ("KDLHEXAGON", device_type::KDLHEXAGON),
    ];
    for (name, value) in device_consts {
        parent.add(name, value)?;
    }
    Ok(())
}
/// Bit 0 of a DLPack flags word: the exported tensor is read-only.
pub const DLPACK_FLAG_BITMASK_READ_ONLY: u64 = 1 << 0;
/// Bit 1 of a DLPack flags word: the exported tensor is a copy of the source.
pub const DLPACK_FLAG_BITMASK_IS_COPIED: u64 = 1 << 1;
/// Major/minor version pair mirroring DLPack's `DLPackVersion` C struct.
/// `repr(C)` keeps the field layout ABI-compatible with the C definition.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
#[allow(dead_code)]
pub struct DLPackVersion {
    pub major: u32,
    pub minor: u32,
}
/// Integer device-type codes mirroring DLPack's `DLDeviceType` enum.
/// Note the gap at 5 and 6 — only the codes used by this binding are listed,
/// with the numeric values the DLPack header assigns them.
pub mod device_type {
    pub const KDLCPU: i32 = 1;
    pub const KDLCUDA: i32 = 2;
    pub const KDLCUDA_HOST: i32 = 3; // CUDA pinned host memory
    pub const KDLOPENCL: i32 = 4;
    pub const KDLVULKAN: i32 = 7;
    pub const KDLMETAL: i32 = 8;
    pub const KDLVPI: i32 = 9;
    pub const KDLROCM: i32 = 10;
    pub const KDLROCM_HOST: i32 = 11; // ROCm pinned host memory
    pub const KDLEXT_DEV: i32 = 12; // reserved for extension devices
    pub const KDLCUDA_MANAGED: i32 = 13;
    pub const KDLONE_API: i32 = 14;
    pub const KDLWEB_GPU: i32 = 15;
    pub const KDLHEXAGON: i32 = 16;
}
/// Mirror of DLPack's `DLDevice`: a device-type code plus a device ordinal.
/// `repr(C)` is required — this struct is embedded in the exported tensor.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct DLDevice {
    /// One of the `device_type::KDL*` codes.
    pub device_type: i32,
    /// Device index (0 is used for the CPU in this module).
    pub device_id: i32,
}
/// Type-code values mirroring DLPack's `DLDataTypeCode` enum.
/// Only the codes relevant to this binding are listed, which is why the
/// discriminants are non-contiguous (3 is assigned to another code in the
/// DLPack header). `repr(u8)` matches the width of `DLDataType::code`.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(dead_code)]
#[allow(non_camel_case_types)]
#[allow(clippy::enum_variant_names)]
pub enum DLDataTypeCode {
    kDLInt = 0,
    kDLUInt = 1,
    kDLFloat = 2,
    kDLBfloat = 4,
}
/// Element-type descriptor mirroring DLPack's `DLDataType`.
/// For f64 this module uses `{ code: kDLFloat, bits: 64, lanes: 1 }`.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct DLDataType {
    /// One of the `DLDataTypeCode` values, as a raw u8.
    pub code: u8,
    /// Bits per lane (64 for f64).
    pub bits: u8,
    /// Vector lanes; 1 for scalar element types.
    pub lanes: u16,
}
/// Mirror of DLPack's `DLTensor` C struct: a non-owning view of a tensor.
/// The `shape`/`strides` pointers reference memory owned elsewhere (here,
/// the boxed arrays inside `DLPackContext`).
#[repr(C)]
#[derive(Debug)]
pub struct DLTensor {
    /// Pointer to the underlying element storage.
    pub data: *mut c_void,
    /// Device on which `data` resides.
    pub device: DLDevice,
    /// Number of dimensions (this module always exports 2-D tensors).
    pub ndim: i32,
    /// Element-type descriptor.
    pub dtype: DLDataType,
    /// Pointer to `ndim` dimension sizes.
    pub shape: *mut i64,
    /// Pointer to `ndim` strides in *elements* (not bytes); a null pointer
    /// means the data is compact C-contiguous.
    pub strides: *mut i64,
    /// Byte offset from `data` to the first element.
    pub byte_offset: u64,
}
/// Signature of a `DLManagedTensor` deleter, per the DLPack C ABI.
type DeleterFn = unsafe extern "C" fn(*mut DLManagedTensor);

/// Mirror of DLPack's classic (unversioned) `DLManagedTensor`.
#[repr(C)]
pub struct DLManagedTensor {
    pub dl_tensor: DLTensor,
    /// Opaque owner pointer; in this module it is a raw `Box<DLPackContext>`
    /// produced by `Box::into_raw`.
    pub manager_ctx: *mut c_void,
    /// Deleter stored as a raw data pointer and transmuted back to
    /// [`DeleterFn`] before the call; null means "no deleter".
    pub deleter: *const c_void,
}
/// Keep-alive state referenced by `DLManagedTensor::manager_ctx`.
///
/// The boxed shape/strides arrays back the raw pointers stored in the
/// exported `DLTensor` (boxing keeps their addresses stable), and
/// `data_owner` holds a Python reference to the source numpy array so its
/// buffer outlives the capsule.
#[allow(unused)]
pub struct DLPackContext {
    shape: Box<[i64; 2]>,
    strides: Option<Box<[i64; 2]>>,
    data_owner: Py<PyAny>,
}
impl DLPackContext {
pub fn new(_py: Python<'_>, arr: &Bound<'_, PyArray2<f64>>) -> Self {
let readonly = arr.readonly();
let shape = readonly.shape();
let shape_box = Box::new([shape[0] as i64, shape[1] as i64]);
let strides = if readonly.is_c_contiguous() {
None
} else {
let strides_bytes = readonly.strides();
let element_size = std::mem::size_of::<f64>() as isize;
Some(Box::new([
(strides_bytes[0] / element_size) as i64,
(strides_bytes[1] / element_size) as i64,
]))
};
Self {
shape: shape_box,
strides,
data_owner: arr.clone().into_any().unbind(),
}
}
}
/// DLPack deleter installed on every capsule this module creates: frees the
/// `DLManagedTensor` box and the `DLPackContext` it owns.
unsafe extern "C" fn dlpack_deleter(managed: *mut DLManagedTensor) {
    if managed.is_null() {
        return;
    }
    // SAFETY: non-null `managed` was produced by Box::into_raw in
    // create_dlpack_capsule, so it is valid to dereference.
    let ctx_ptr = unsafe { (*managed).manager_ctx.cast::<DLPackContext>() };
    if !ctx_ptr.is_null() {
        // Dropping the context releases a Py<PyAny>, which needs a live
        // interpreter. If Python has already shut down, the context is
        // deliberately leaked rather than touching a dead interpreter.
        // SAFETY: Py_IsInitialized is callable at any time.
        if unsafe { ffi::Py_IsInitialized() } != 0 {
            Python::attach(|_py| {
                // SAFETY: ctx_ptr came from Box::into_raw and is dropped
                // exactly once, here.
                drop(unsafe { Box::from_raw(ctx_ptr) });
            });
        }
    }
    // SAFETY: reclaim the DLManagedTensor allocation itself (plain heap
    // memory, safe to free even without the interpreter).
    drop(unsafe { Box::from_raw(managed) });
}
/// Capsule destructor: runs the DLPack deleter if the capsule was never
/// consumed. A consumer renames the capsule (the name will no longer be
/// "dltensor"), in which case ownership has moved and nothing is freed here.
unsafe extern "C" fn pycapsule_destructor(capsule: *mut ffi::PyObject) {
    if capsule.is_null() {
        return;
    }
    unsafe {
        let name = ffi::PyCapsule_GetName(capsule);
        if name.is_null() {
            return;
        }
        // Only an unconsumed capsule still carries the "dltensor" name.
        if std::ffi::CStr::from_ptr(name).to_bytes() != b"dltensor" {
            return;
        }
        let managed = ffi::PyCapsule_GetPointer(capsule, c"dltensor".as_ptr())
            .cast::<DLManagedTensor>();
        if managed.is_null() {
            return;
        }
        if !(*managed).deleter.is_null() {
            // The deleter is stored type-erased as *const c_void; recover
            // the function pointer and invoke it.
            let deleter: DeleterFn = std::mem::transmute((*managed).deleter);
            deleter(managed);
        }
    }
}
#[allow(unused)]
pub fn create_dlpack_capsule<'py>(
py: Python<'py>,
arr: &Bound<'py, PyArray2<f64>>,
flags: u64,
) -> PyResult<Bound<'py, PyCapsule>> {
let data_ptr = arr.data().cast::<c_void>();
let mut ctx = Box::new(DLPackContext::new(py, arr));
let dl_tensor = DLTensor {
data: data_ptr,
device: DLDevice {
device_type: device_type::KDLCPU,
device_id: 0,
},
ndim: 2,
dtype: DLDataType {
code: DLDataTypeCode::kDLFloat as u8,
bits: 64,
lanes: 1,
},
shape: ctx.shape.as_mut_ptr(),
strides: ctx
.strides
.as_mut()
.map_or(std::ptr::null_mut(), |s| s.as_mut_ptr()),
byte_offset: 0,
};
let managed = Box::new(DLManagedTensor {
dl_tensor,
manager_ctx: Box::into_raw(ctx).cast::<c_void>(),
deleter: dlpack_deleter as *const c_void,
});
let managed_ptr = Box::into_raw(managed).cast::<c_void>();
let name_ptr = c"dltensor".as_ptr();
unsafe {
let capsule_ptr = ffi::PyCapsule_New(managed_ptr, name_ptr, Some(pycapsule_destructor));
if capsule_ptr.is_null() {
let _ = Box::from_raw(managed_ptr.cast::<DLManagedTensor>());
return Err(PyErr::fetch(py));
}
Ok(Bound::from_owned_ptr(py, capsule_ptr).cast_into()?)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity-check the C-ABI layout of the mirrored DLPack structs.
    /// NOTE(review): the alignment assertion assumes a target where i64/u64
    /// have 8-byte alignment (true on typical 64-bit platforms); the two
    /// size assertions hold on any target (2×i32, u8+u8+u16 packed).
    #[test]
    fn test_dlpack_structs_layout() {
        assert_eq!(std::mem::size_of::<DLDevice>(), 8);
        assert_eq!(std::mem::size_of::<DLDataType>(), 4);
        assert_eq!(std::mem::align_of::<DLTensor>(), 8);
    }

    /// Pin the device-type codes that must match DLPack's DLDeviceType.
    #[test]
    fn test_device_type_values() {
        assert_eq!(device_type::KDLCPU, 1);
        assert_eq!(device_type::KDLCUDA, 2);
    }

    /// Pin the float type code used when exporting f64 arrays.
    #[test]
    fn test_datatype_code_values() {
        assert_eq!(DLDataTypeCode::kDLFloat as u8, 2);
    }
}