use std::ffi::{c_void, CStr};
use std::ptr::NonNull;
use pyo3::exceptions::{PyRuntimeError, PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyCapsule, PyCapsuleMethods};
use scirs2_numpy::dlpack::{
DLDataType, DLDataTypeCode, DLDevice, DLDeviceType, DLManagedTensor, DLTensor,
};
/// Capsule name the DLPack protocol requires for a live, not-yet-consumed tensor.
const DLTENSOR_NAME: &CStr = c"dltensor";
/// Name a consumer renames the capsule to after taking ownership, signalling to
/// the capsule destructor that the tensor's deleter has already been run.
const USED_DLTENSOR_NAME: &CStr = c"used_dltensor";
/// Heap-allocated owner of everything an exported `DLManagedTensor` points at.
///
/// `managed` MUST remain the first field and the struct MUST be `#[repr(C)]`:
/// `to_dlpack` hands out a pointer to this struct that consumers and
/// `backing_store_deleter` reinterpret as a `*mut DLManagedTensor`. With the
/// default Rust layout the compiler is free to reorder fields, which would
/// make that pointer cast unsound; `#[repr(C)]` pins `managed` at offset 0.
#[repr(C)]
struct BackingStore {
    // The tensor header handed to DLPack consumers; its internal pointers
    // reference the three Vecs below, which stay valid while the Box lives.
    managed: DLManagedTensor,
    data: Vec<f64>,
    shape: Vec<i64>,
    strides: Vec<i64>,
}
impl BackingStore {
    /// Reclaims and drops a `BackingStore` previously leaked with `Box::into_raw`.
    /// A null pointer is a no-op.
    ///
    /// # Safety
    /// `ptr` must be null, or a pointer obtained from `Box::into_raw` on a
    /// `Box<BackingStore>` that has not been freed yet.
    unsafe fn drop_raw(ptr: *mut BackingStore) {
        if ptr.is_null() {
            return;
        }
        // SAFETY: per the contract above, the pointer came from Box::into_raw
        // and is reclaimed at most once.
        let owned = unsafe { Box::from_raw(ptr) };
        drop(owned);
    }
}
/// `DLManagedTensor::deleter` callback installed by `to_dlpack`: frees the
/// entire `BackingStore` that owns the tensor header and its buffers.
///
/// NOTE(review): the pointer cast assumes `managed` sits at offset 0 of
/// `BackingStore` — sound only with a C-compatible struct layout; verify the
/// struct declaration guarantees this (`#[repr(C)]`).
unsafe extern "C" fn backing_store_deleter(managed: *mut DLManagedTensor) {
    if !managed.is_null() {
        // SAFETY: `managed` was produced in to_dlpack from a leaked
        // Box<BackingStore> whose first field is the DLManagedTensor.
        unsafe { BackingStore::drop_raw(managed.cast::<BackingStore>()) };
    }
}
/// Capsule destructor: runs the tensor's deleter if the capsule was never
/// consumed (i.e. it is still named "dltensor").
///
/// If a consumer already took ownership, the capsule was renamed to
/// "used_dltensor"; `PyCapsule_GetPointer` then fails with a name mismatch,
/// returns null, AND sets a Python exception. A destructor must not leave an
/// exception pending (it runs at arbitrary points during GC/dealloc), so we
/// clear it explicitly — the previous version leaked the error state.
unsafe extern "C" fn capsule_destructor(capsule: *mut pyo3::ffi::PyObject) {
    let ptr = unsafe { pyo3::ffi::PyCapsule_GetPointer(capsule, DLTENSOR_NAME.as_ptr()) };
    if ptr.is_null() {
        // Name mismatch (already consumed) or invalid capsule: swallow the
        // exception PyCapsule_GetPointer just raised; nothing left to free here.
        unsafe { pyo3::ffi::PyErr_Clear() };
        return;
    }
    let managed_ptr = ptr.cast::<DLManagedTensor>();
    // SAFETY: the pointer was stored by to_dlpack and is a live DLManagedTensor
    // because the capsule still carries the "dltensor" name.
    if let Some(deleter) = unsafe { (*managed_ptr).deleter } {
        unsafe { deleter(managed_ptr) };
    }
}
#[pyfunction]
pub fn from_dlpack(py: Python<'_>, capsule: &Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
let cap = capsule.cast::<PyCapsule>().map_err(|_| {
PyTypeError::new_err(
"from_dlpack: argument must be a PyCapsule (the result of tensor.__dlpack__()). \
Got a non-capsule object instead.",
)
})?;
let name_opt = cap.name().map_err(|e| {
PyValueError::new_err(format!("from_dlpack: could not read capsule name: {e}"))
})?;
let name_matches = match name_opt {
None => false,
Some(cn) => {
let name_cstr = unsafe { cn.as_cstr() };
name_cstr == DLTENSOR_NAME
}
};
if !name_matches {
return Err(PyValueError::new_err(
"from_dlpack: expected a PyCapsule named 'dltensor'. \
Pass the result of tensor.__dlpack__() directly.",
));
}
let nn_ptr: NonNull<c_void> = cap
.pointer_checked(Some(DLTENSOR_NAME))
.map_err(|e| PyRuntimeError::new_err(format!("from_dlpack: null capsule pointer: {e}")))?;
let managed_ptr = nn_ptr.as_ptr() as *mut DLManagedTensor;
let dl_tensor: &DLTensor = unsafe { &(*managed_ptr).dl_tensor };
if dl_tensor.device.device_type != DLDeviceType::Cpu as i32 {
return Err(PyTypeError::new_err(format!(
"from_dlpack: only CPU tensors are supported (got device type {}). \
Copy the tensor to CPU before calling from_dlpack.",
dl_tensor.device.device_type
)));
}
if dl_tensor.data.is_null() {
return Err(PyValueError::new_err(
"from_dlpack: tensor has a null data pointer.",
));
}
let n_elems: usize = if dl_tensor.ndim == 0 || dl_tensor.shape.is_null() {
1
} else {
let shape_slice = unsafe {
std::slice::from_raw_parts(dl_tensor.shape as *const i64, dl_tensor.ndim as usize)
};
shape_slice.iter().map(|&d| d as usize).product()
};
let base_ptr = unsafe { (dl_tensor.data as *const u8).add(dl_tensor.byte_offset as usize) };
let dtype = dl_tensor.dtype;
let flat_vec: Vec<f64> = match (dtype.code, dtype.bits, dtype.lanes) {
(2, 32, 1) => {
let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const f32, n_elems) };
slice.iter().map(|&v| v as f64).collect()
}
(2, 64, 1) => {
let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const f64, n_elems) };
slice.to_vec()
}
(0, 8, 1) => {
let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const i8, n_elems) };
slice.iter().map(|&v| v as f64).collect()
}
(0, 16, 1) => {
let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const i16, n_elems) };
slice.iter().map(|&v| v as f64).collect()
}
(0, 32, 1) => {
let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const i32, n_elems) };
slice.iter().map(|&v| v as f64).collect()
}
(0, 64, 1) => {
let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const i64, n_elems) };
slice.iter().map(|&v| v as f64).collect()
}
(1, 8, 1) => {
let slice = unsafe { std::slice::from_raw_parts(base_ptr, n_elems) };
slice.iter().map(|&v| v as f64).collect()
}
(1, 16, 1) => {
let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const u16, n_elems) };
slice.iter().map(|&v| v as f64).collect()
}
(1, 32, 1) => {
let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const u32, n_elems) };
slice.iter().map(|&v| v as f64).collect()
}
(1, 64, 1) => {
let slice = unsafe { std::slice::from_raw_parts(base_ptr as *const u64, n_elems) };
slice.iter().map(|&v| v as f64).collect()
}
(code, bits, _) => {
return Err(PyTypeError::new_err(format!(
"from_dlpack: unsupported dtype (code={code}, bits={bits}). \
Supported: int8/16/32/64, uint8/16/32/64, float32, float64.",
)));
}
};
let shape_vec: Vec<usize> = if dl_tensor.ndim == 0 || dl_tensor.shape.is_null() {
vec![n_elems]
} else {
let shape_slice = unsafe {
std::slice::from_raw_parts(dl_tensor.shape as *const i64, dl_tensor.ndim as usize)
};
shape_slice.iter().map(|&d| d as usize).collect()
};
let rename_result =
unsafe { pyo3::ffi::PyCapsule_SetName(cap.as_ptr(), USED_DLTENSOR_NAME.as_ptr()) };
let _ = rename_result;
if let Some(deleter) = unsafe { (*managed_ptr).deleter } {
unsafe { deleter(managed_ptr) };
}
let numpy = py.import("numpy").map_err(|e| {
PyRuntimeError::new_err(format!("from_dlpack: could not import numpy: {e}"))
})?;
let arr = numpy.getattr("array")?.call1((flat_vec,))?;
let shaped = arr.call_method1("reshape", (shape_vec,))?;
Ok(shaped.into())
}
#[pyfunction]
pub fn to_dlpack(py: Python<'_>, array: &Bound<'_, PyAny>) -> PyResult<Py<PyAny>> {
let numpy = py
.import("numpy")
.map_err(|e| PyRuntimeError::new_err(format!("to_dlpack: could not import numpy: {e}")))?;
let arr = numpy.getattr("asarray")?.call1((array,))?;
let arr_f64 = numpy
.getattr("ascontiguousarray")?
.call((arr,), Some(&pyo3::types::PyDict::new(py)))?;
let shape_obj = arr_f64.getattr("shape")?;
let shape_tuple: Vec<i64> = shape_obj.extract::<Vec<i64>>().map_err(|e| {
PyTypeError::new_err(format!("to_dlpack: could not extract array shape: {e}"))
})?;
let flat_list = arr_f64.call_method0("flatten")?;
let data_vec: Vec<f64> = flat_list.extract::<Vec<f64>>().map_err(|e| {
PyTypeError::new_err(format!(
"to_dlpack: array must be convertible to float64: {e}"
))
})?;
let strides_vec: Vec<i64> = compute_c_strides(&shape_tuple);
let n = shape_tuple.len();
let mut store = Box::new(BackingStore {
managed: DLManagedTensor {
dl_tensor: DLTensor {
data: std::ptr::null_mut(), device: DLDevice {
device_type: DLDeviceType::Cpu as i32,
device_id: 0,
},
ndim: n as i32,
dtype: DLDataType {
code: DLDataTypeCode::Float as u8,
bits: 64,
lanes: 1,
},
shape: std::ptr::null_mut(), strides: std::ptr::null_mut(), byte_offset: 0,
},
manager_ctx: std::ptr::null_mut(),
deleter: Some(backing_store_deleter),
},
data: data_vec,
shape: shape_tuple,
strides: strides_vec,
});
store.managed.dl_tensor.data = store.data.as_mut_ptr() as *mut c_void;
store.managed.dl_tensor.shape = store.shape.as_mut_ptr();
store.managed.dl_tensor.strides = store.strides.as_mut_ptr();
let raw_store: *mut BackingStore = Box::into_raw(store);
let managed_nn = NonNull::new(raw_store as *mut c_void)
.ok_or_else(|| PyRuntimeError::new_err("to_dlpack: null BackingStore pointer"))?;
let capsule = unsafe {
PyCapsule::new_with_pointer_and_destructor(
py,
managed_nn,
DLTENSOR_NAME,
Some(capsule_destructor),
)
}
.map_err(|e| PyRuntimeError::new_err(format!("to_dlpack: failed to create capsule: {e}")))?;
Ok(capsule.into())
}
/// Returns the row-major (C-order) element strides for `shape`: the last axis
/// has stride 1 and each earlier axis strides over the product of the axes
/// after it. An empty shape yields an empty stride list.
fn compute_c_strides(shape: &[i64]) -> Vec<i64> {
    let rank = shape.len();
    let mut strides = Vec::with_capacity(rank);
    let mut running = 1i64;
    // Walk axes from innermost to outermost, accumulating the running product.
    for axis in (0..rank).rev() {
        strides.push(running);
        if axis > 0 {
            running *= shape[axis];
        }
    }
    strides.reverse();
    strides
}
/// Registers the DLPack interop functions (`from_dlpack`, `to_dlpack`) on the
/// given Python module.
pub fn register_dlpack_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
    let from_fn = wrap_pyfunction!(from_dlpack, m)?;
    m.add_function(from_fn)?;
    let to_fn = wrap_pyfunction!(to_dlpack, m)?;
    m.add_function(to_fn)?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: the module compiles and its symbols link.
    #[test]
    fn dlpack_module_symbol_exists() {
        let _msg = "dlpack module compiled successfully";
    }

    #[test]
    fn compute_c_strides_1d() {
        let got = compute_c_strides(&[5]);
        assert_eq!(got, vec![1]);
    }

    #[test]
    fn compute_c_strides_2d() {
        let got = compute_c_strides(&[2, 3]);
        assert_eq!(got, vec![3, 1]);
    }

    #[test]
    fn compute_c_strides_3d() {
        let got = compute_c_strides(&[2, 3, 4]);
        assert_eq!(got, vec![12, 4, 1]);
    }

    #[test]
    fn compute_c_strides_empty() {
        let got = compute_c_strides(&[]);
        assert_eq!(got, Vec::<i64>::new());
    }
}