use std::ffi::c_void;
use std::ptr::NonNull;
use pyo3::ffi;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;
/// DLPack device-type code for host (CPU) memory, `kDLCPU`.
const DEVICE_CPU: i32 = 1;
/// DLPack type code for IEEE floating-point elements, `kDLFloat`.
const DTYPE_FLOAT: u8 = 2;
/// Bit width of a single `f32` element, derived from its size in bytes.
const F32_BITS: u8 = (std::mem::size_of::<f32>() * 8) as u8;
/// Lane count for scalar (non-vectorised) elements.
const LANES_DENSE: u16 = 1;
/// PyCapsule name the DLPack Python protocol uses for a not-yet-consumed tensor.
static CAPSULE_NAME: &std::ffi::CStr = c"dltensor";
/// DLPack `DLDevice`: identifies the device on which a tensor's memory resides.
///
/// `#[repr(C)]` so the layout matches the C struct from `dlpack.h`; the field order
/// is part of the ABI and must not be changed.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct DLDevice {
/// DLPack device-type code; this module writes and accepts only `DEVICE_CPU`.
device_type: i32,
/// Index of the device within its type; this module always writes 0.
device_id: i32,
}
/// DLPack `DLDataType`: element type descriptor (type code, bit width, vector lanes).
///
/// `#[repr(C)]` so the layout matches the C struct from `dlpack.h`; the field order
/// is part of the ABI and must not be changed.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct DLDataType {
/// Type code; this module uses `DTYPE_FLOAT` (floating point).
code: u8,
/// Bits per lane; this module uses `F32_BITS` for `f32`.
bits: u8,
/// Number of vector lanes; `LANES_DENSE` (1) for scalar elements.
lanes: u16,
}
/// DLPack `DLTensor`: describes an n-dimensional array through raw pointers.
///
/// The struct only holds raw pointers; the memory they reference is released through
/// the enclosing `DLManagedTensor`'s `deleter`. `#[repr(C)]` so the layout matches the
/// C struct from `dlpack.h`; the field order is part of the ABI and must not be changed.
#[repr(C)]
pub struct DLTensor {
/// Base address of the element buffer; the first element lives `byte_offset` bytes past it.
data: *mut c_void,
/// Device holding the buffer.
device: DLDevice,
/// Number of dimensions, i.e. the number of entries behind `shape`.
ndim: i32,
/// Element type.
dtype: DLDataType,
/// Pointer to `ndim` dimension extents.
shape: *mut i64,
/// Per-axis strides (in elements, per DLPack), or null for a compact row-major layout;
/// `vec_to_dlpack` always writes null.
strides: *mut i64,
/// Offset in bytes from `data` to the first element.
byte_offset: u64,
}
/// DLPack `DLManagedTensor`: a `DLTensor` plus the context and callback that free it.
///
/// A pointer to this struct is what gets stored inside a `"dltensor"` PyCapsule.
/// `#[repr(C)]` so the layout matches the C struct from `dlpack.h`; the field order is
/// part of the ABI and must not be changed.
#[repr(C)]
pub struct DLManagedTensor {
/// The tensor description handed to consumers.
dl_tensor: DLTensor,
/// Producer-private context; `vec_to_dlpack` stores a `Box<ManagedTensorState>` here.
manager_ctx: *mut c_void,
/// Frees this struct and its context; `vec_to_dlpack` installs `managed_tensor_deleter`.
/// Per DLPack, whichever side ends up owning the tensor calls it exactly once.
deleter: Option<unsafe extern "C" fn(*mut DLManagedTensor)>,
}
/// Heap storage backing a tensor exported by `vec_to_dlpack`.
///
/// A boxed instance is parked in `DLManagedTensor::manager_ctx`. The exported
/// `DLTensor`'s `data` and `shape` pointers refer to these vectors' buffers, so the
/// state must stay alive until `managed_tensor_deleter` frees it.
struct ManagedTensorState {
    /// Extent of each dimension; the tensor's `shape` pointer refers to this buffer.
    shape: Vec<i64>,
    /// Element values; the tensor's `data` pointer refers to this buffer.
    data: Vec<f32>,
}
/// DLPack deleter for tensors built by `vec_to_dlpack`.
///
/// Frees the boxed `ManagedTensorState` held in `manager_ctx` (when non-null) and then
/// the `DLManagedTensor` itself. A null `managed` pointer is ignored.
///
/// # Safety
///
/// `managed` must be null or a pointer obtained from `Box::into_raw`/`Box::leak` on a
/// `DLManagedTensor` whose `manager_ctx` is null or came from a
/// `Box<ManagedTensorState>`. The pointer must not be used after this call.
unsafe extern "C" fn managed_tensor_deleter(managed: *mut DLManagedTensor) {
    if !managed.is_null() {
        // SAFETY: per the contract above, `managed` is a live Box allocation released once.
        let owned = unsafe { Box::from_raw(managed) };
        let state = owned.manager_ctx as *mut ManagedTensorState;
        if !state.is_null() {
            // SAFETY: a non-null context is always a `Box<ManagedTensorState>` raw pointer.
            drop(unsafe { Box::from_raw(state) });
        }
        // Free the outer struct only after its context, matching the original order.
        drop(owned);
    }
}
/// PyCapsule destructor for capsules created by `vec_to_dlpack`.
///
/// Under the DLPack Python protocol, a consumer that takes ownership renames the capsule
/// to `"used_dltensor"` and becomes responsible for calling the tensor's deleter. In that
/// case, or for any capsule not named `"dltensor"`, this function does nothing.
/// Otherwise the producer still owns the tensor, and its `deleter` is invoked here.
///
/// The name is checked with `PyCapsule_IsValid`, which never raises. A direct
/// `PyCapsule_GetPointer` on a renamed capsule would set a Python exception in the
/// middle of deallocation and leave it pending.
unsafe extern "C" fn capsule_destructor(capsule: *mut ffi::PyObject) {
    if capsule.is_null() {
        return;
    }
    // SAFETY: `capsule` is the non-null object CPython is deallocating; IsValid cannot fail.
    if unsafe { ffi::PyCapsule_IsValid(capsule, CAPSULE_NAME.as_ptr()) } == 0 {
        // Consumed ("used_dltensor") or foreign capsule: ownership is elsewhere.
        return;
    }
    // SAFETY: IsValid confirmed the name matches and the stored pointer is non-null,
    // so this call cannot fail or set an exception.
    let managed = unsafe { ffi::PyCapsule_GetPointer(capsule, CAPSULE_NAME.as_ptr()) }
        .cast::<DLManagedTensor>();
    if managed.is_null() {
        return;
    }
    // SAFETY: an unconsumed "dltensor" capsule still owns a live DLManagedTensor.
    if let Some(deleter) = unsafe { (*managed).deleter } {
        // SAFETY: the capsule is being destroyed, so this is the tensor's last owner.
        unsafe { deleter(managed) };
    }
}
pub fn vec_to_dlpack(py: Python<'_>, data: Vec<f32>, shape: Vec<i64>) -> PyResult<Py<PyCapsule>> {
let expected_len: usize = if shape.is_empty() {
1
} else {
shape.iter().map(|&d| d as usize).product()
};
if data.len() != expected_len {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"data length {} does not match shape product {} (shape={:?})",
data.len(),
expected_len,
shape
)));
}
let mut state = Box::new(ManagedTensorState { data, shape });
let data_ptr = state.data.as_mut_ptr() as *mut c_void;
let shape_ptr = state.shape.as_mut_ptr();
let ndim = state.shape.len() as i32;
let state_raw = Box::into_raw(state);
let dl_tensor = DLTensor {
data: data_ptr,
device: DLDevice {
device_type: DEVICE_CPU,
device_id: 0,
},
ndim,
dtype: DLDataType {
code: DTYPE_FLOAT,
bits: F32_BITS,
lanes: LANES_DENSE,
},
shape: shape_ptr,
strides: std::ptr::null_mut(), byte_offset: 0,
};
let managed = Box::new(DLManagedTensor {
dl_tensor,
manager_ctx: state_raw as *mut c_void,
deleter: Some(managed_tensor_deleter),
});
let managed_raw = Box::into_raw(managed);
let non_null_ptr = NonNull::new(managed_raw as *mut c_void)
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("null DLManagedTensor pointer"))?;
let capsule = unsafe {
PyCapsule::new_with_pointer_and_destructor(
py,
non_null_ptr,
CAPSULE_NAME,
Some(capsule_destructor),
)
}?;
Ok(capsule.unbind())
}
pub fn dlpack_to_vec(_py: Python<'_>, capsule: &Bound<'_, PyCapsule>) -> PyResult<Vec<f32>> {
let managed_ptr = unsafe { ffi::PyCapsule_GetPointer(capsule.as_ptr(), CAPSULE_NAME.as_ptr()) }
as *const DLManagedTensor;
if managed_ptr.is_null() {
return Err(pyo3::exceptions::PyValueError::new_err(
"DLPack capsule contains a null pointer",
));
}
let managed = unsafe { &*managed_ptr };
let tensor = &managed.dl_tensor;
if tensor.device.device_type != DEVICE_CPU {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"DLPack tensor is on device type {} (expected CPU=1)",
tensor.device.device_type
)));
}
if tensor.dtype.code != DTYPE_FLOAT
|| tensor.dtype.bits != F32_BITS
|| tensor.dtype.lanes != LANES_DENSE
{
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"DLPack tensor dtype is not float32 (got code={}, bits={}, lanes={})",
tensor.dtype.code, tensor.dtype.bits, tensor.dtype.lanes
)));
}
let ndim = tensor.ndim as usize;
let total_elements: usize = if ndim == 0 {
1
} else {
let shape_slice = unsafe { std::slice::from_raw_parts(tensor.shape, ndim) };
shape_slice.iter().map(|&d| d as usize).product()
};
let data_ptr =
(tensor.data as *const u8).wrapping_add(tensor.byte_offset as usize) as *const f32;
let slice = unsafe { std::slice::from_raw_parts(data_ptr, total_elements) };
Ok(slice.to_vec())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The shape pointer handed to DLPack must expose the input shape's extents.
    #[test]
    fn dlpack_shape_matches_input() {
        let shape: Vec<i64> = vec![3, 4];
        let data: Vec<f32> = vec![1.0_f32; 12];
        let mut state = Box::new(ManagedTensorState { data, shape });
        let shape_ptr = state.shape.as_mut_ptr();
        let ndim = state.shape.len() as i32;
        let state_raw = Box::into_raw(state);
        assert_eq!(ndim, 2, "ndim must be 2");
        // SAFETY: `shape_ptr` points into the Vec owned by `*state_raw`, still alive here.
        unsafe {
            assert_eq!(*shape_ptr, 3, "shape[0] must be 3");
            assert_eq!(*shape_ptr.add(1), 4, "shape[1] must be 4");
        }
        // SAFETY: `state_raw` came from `Box::into_raw` above and is freed exactly once.
        drop(unsafe { Box::from_raw(state_raw) });
    }

    #[test]
    fn dlpack_dtype_is_f32() {
        let dtype = DLDataType {
            code: DTYPE_FLOAT,
            bits: F32_BITS,
            lanes: LANES_DENSE,
        };
        assert_eq!(dtype.code, 2, "dtype code must be 2 (float)");
        assert_eq!(dtype.bits, 32, "dtype bits must be 32 (f32)");
        assert_eq!(dtype.lanes, 1, "dtype lanes must be 1 (scalar)");
    }

    #[test]
    fn dlpack_device_is_cpu() {
        let device = DLDevice {
            device_type: DEVICE_CPU,
            device_id: 0,
        };
        assert_eq!(device.device_type, 1, "device_type must be 1 (kCPU)");
        assert_eq!(device.device_id, 0, "device_id must be 0 for single CPU");
    }

    #[test]
    fn dlpack_capsule_name_is_dltensor() {
        assert_eq!(
            CAPSULE_NAME
                .to_str()
                .expect("capsule name must be valid UTF-8"),
            "dltensor"
        );
    }

    /// The `repr(C)` structs must match the C ABI of `dlpack.h` on 64-bit targets.
    #[test]
    #[cfg(target_pointer_width = "64")]
    fn dlpack_struct_layout_matches_c_abi() {
        assert_eq!(std::mem::size_of::<DLDevice>(), 8);
        assert_eq!(std::mem::size_of::<DLDataType>(), 4);
        assert_eq!(std::mem::size_of::<DLTensor>(), 48);
        assert_eq!(std::mem::size_of::<DLManagedTensor>(), 64);
        assert_eq!(std::mem::offset_of!(DLTensor, device), 8);
        assert_eq!(std::mem::offset_of!(DLTensor, ndim), 16);
        assert_eq!(std::mem::offset_of!(DLTensor, dtype), 20);
        assert_eq!(std::mem::offset_of!(DLTensor, shape), 24);
        assert_eq!(std::mem::offset_of!(DLTensor, strides), 32);
        assert_eq!(std::mem::offset_of!(DLTensor, byte_offset), 40);
        assert_eq!(std::mem::offset_of!(DLManagedTensor, manager_ctx), 48);
        assert_eq!(std::mem::offset_of!(DLManagedTensor, deleter), 56);
    }

    #[test]
    fn dlpack_deleter_null_is_safe() {
        // SAFETY: the deleter explicitly accepts null and does nothing.
        unsafe { managed_tensor_deleter(std::ptr::null_mut()) };
    }

    /// A fully populated tensor is released through its own deleter without leaking or
    /// double-freeing (run under Miri to verify the allocation bookkeeping).
    #[test]
    fn dlpack_deleter_frees_owned_tensor() {
        let mut state = Box::new(ManagedTensorState {
            data: vec![0.5_f32; 6],
            shape: vec![2, 3],
        });
        let data_ptr = state.data.as_mut_ptr().cast::<c_void>();
        let shape_ptr = state.shape.as_mut_ptr();
        let managed = Box::into_raw(Box::new(DLManagedTensor {
            dl_tensor: DLTensor {
                data: data_ptr,
                device: DLDevice {
                    device_type: DEVICE_CPU,
                    device_id: 0,
                },
                ndim: 2,
                dtype: DLDataType {
                    code: DTYPE_FLOAT,
                    bits: F32_BITS,
                    lanes: LANES_DENSE,
                },
                shape: shape_ptr,
                strides: std::ptr::null_mut(),
                byte_offset: 0,
            },
            manager_ctx: Box::into_raw(state).cast::<c_void>(),
            deleter: Some(managed_tensor_deleter),
        }));
        // SAFETY: `managed` and its context were Box-allocated above and are released once.
        unsafe {
            let deleter = (*managed).deleter.expect("deleter must be set");
            deleter(managed);
        }
    }
}