use pyo3::prelude::*;
use pyo3::types::PyCapsule;
use std::ffi::c_void;
use std::ffi::CStr;
use std::ptr::NonNull;
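/// DLPack device type codes (subset of `DLDeviceType` in dlpack.h); the
/// discriminant values mirror the C enum so they can cross the FFI boundary
/// unchanged.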
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(i32)]
pub enum DLDeviceType {
Cpu = 1,
Cuda = 2,
CudaHost = 3,
OpenCL = 4,
Vulkan = 7,
Metal = 8,
Rocm = 10,
}
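/// DLPack data type descriptor: type code, bit width and vector lane count,
/// laid out to match `DLDataType` in dlpack.h.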
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct DLDataType {
pub code: u8,
pub bits: u8,
pub lanes: u16,
}
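/// DLPack device descriptor; `device_type` holds a raw `DLDeviceType` code.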
#[derive(Debug, Clone, Copy)]
#[repr(C)]
pub struct DLDevice {
pub device_type: i32,
pub device_id: i32,
}
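/// Borrowed DLPack tensor view. Field order and types must match `DLTensor`
/// in dlpack.h exactly, since this struct crosses the C ABI.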
#[derive(Debug)]
#[repr(C)]
pub struct DLTensor {
pub data: *mut c_void,
pub device: DLDevice,
pub ndim: i32,
pub dtype: DLDataType,
pub shape: *mut i64,
pub strides: *mut i64,
pub byte_offset: u64,
}
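/// A DLPack tensor bundled with its ownership hooks: `manager_ctx` is an
/// opaque pointer owned by the producer, and `deleter` is the callback the
/// consumer must invoke exactly once when it is done with the tensor.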
#[repr(C)]
pub struct DLManagedTensor {
pub dl_tensor: DLTensor,
pub manager_ctx: *mut c_void,
pub deleter: Option<unsafe extern "C" fn(*mut DLManagedTensor)>,
}
unsafe impl Send for DLManagedTensor {}
unsafe impl Sync for DLManagedTensor {}
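/// Python-visible producer of DLPack capsules. It owns a zero-initialized
/// CPU buffer and implements the `__dlpack__` / `__dlpack_device__` protocol
/// so array frameworks can import it via `from_dlpack`.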
#[pyclass(name = "DLPackCapsule")]
pub struct DLPackCapsule {
shape: Vec<i64>,
strides: Vec<i64>,
    data: Vec<u8>,
dtype: DLDataType,
device: DLDevice,
}
#[pymethods]
impl DLPackCapsule {
#[new]
pub fn new(shape: Vec<i64>, dtype_code: u8, dtype_bits: u8) -> Self {
let n: i64 = shape.iter().product();
let bytes_per_elem = (dtype_bits as usize).div_ceil(8).max(1);
let n_bytes = (n as usize) * bytes_per_elem;
let strides = compute_row_major_strides(&shape);
Self {
shape,
strides,
data: vec![0u8; n_bytes],
dtype: DLDataType {
code: dtype_code,
bits: dtype_bits,
lanes: 1,
},
device: DLDevice {
device_type: DLDeviceType::Cpu as i32,
device_id: 0,
},
}
}
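    /// `__dlpack_device__`: returns the `(device_type, device_id)` pair
    /// describing where the data lives, as required by the protocol.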
#[pyo3(name = "__dlpack_device__")]
pub fn dlpack_device(&self) -> (i32, i32) {
(self.device.device_type, self.device.device_id)
}
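    /// `__dlpack__`: exports the tensor as a `PyCapsule` named "dltensor".
    /// The capsule carries a `DLManagedTensor` that owns a copy of the data,
    /// shape and strides; the consumer releases it by calling `deleter`, or
    /// the capsule destructor does so if the capsule is never consumed.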
#[pyo3(name = "__dlpack__")]
    pub fn dlpack<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
        // Copy the data, shape and strides into a heap-allocated context so the
        // exported tensor stays valid even if `self` is dropped before the
        // consumer calls the deleter, as the DLPack ownership contract requires.
        let mut ctx: Box<(Vec<u8>, Vec<i64>, Vec<i64>)> =
            Box::new((self.data.clone(), self.shape.clone(), self.strides.clone()));
        let managed = Box::new(DLManagedTensor {
            dl_tensor: DLTensor {
                // Point at the copied element buffer, not the shape buffer.
                data: ctx.0.as_mut_ptr() as *mut c_void,
                device: self.device,
                ndim: self.shape.len() as i32,
                dtype: self.dtype,
                shape: ctx.1.as_mut_ptr(),
                strides: ctx.2.as_mut_ptr(),
                byte_offset: 0,
            },
            // The context is reclaimed in `dlpack_deleter`.
            manager_ctx: Box::into_raw(ctx) as *mut c_void,
            deleter: Some(dlpack_deleter),
        });
        let raw_ptr = Box::into_raw(managed);
        let non_null = NonNull::new(raw_ptr as *mut c_void)
            .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("null managed tensor ptr"))?;
        unsafe {
            PyCapsule::new_with_pointer_and_destructor(
                py,
                non_null,
                DLTENSOR_CAPSULE_NAME,
                Some(capsule_destructor),
            )
        }
    }
pub fn shape(&self) -> Vec<i64> {
self.shape.clone()
}
pub fn ndim(&self) -> usize {
self.shape.len()
}
pub fn dtype_code(&self) -> u8 {
self.dtype.code
}
pub fn dtype_bits(&self) -> u8 {
self.dtype.bits
}
}
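/// Capsule name mandated by the DLPack protocol; consumers rename the capsule
/// to "used_dltensor" once they have taken ownership.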
const DLTENSOR_CAPSULE_NAME: &CStr = c"dltensor";
unsafe extern "C" fn capsule_destructor(capsule: *mut pyo3::ffi::PyObject) {
let ptr = unsafe { pyo3::ffi::PyCapsule_GetPointer(capsule, DLTENSOR_CAPSULE_NAME.as_ptr()) };
if !ptr.is_null() {
let managed_ptr = ptr as *mut DLManagedTensor;
if let Some(deleter) = unsafe { (*managed_ptr).deleter } {
unsafe { deleter(managed_ptr) };
}
}
}
unsafe extern "C" fn dlpack_deleter(managed: *mut DLManagedTensor) {
if !managed.is_null() {
let _ = unsafe { Box::from_raw(managed) };
}
}
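/// Computes compact row-major (C-contiguous) strides in elements,
/// e.g. shape `[2, 3, 4]` yields strides `[12, 4, 1]`.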
fn compute_row_major_strides(shape: &[i64]) -> Vec<i64> {
let n = shape.len();
let mut strides = vec![1i64; n];
if n > 1 {
for i in (0..n - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
}
strides
}
pub fn register_dlpack_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<DLPackCapsule>()?;
Ok(())
}
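/// DLPack data type codes, matching `DLDataTypeCode` in dlpack.h
/// (kDLInt = 0, kDLUInt = 1, kDLFloat = 2, kDLOpaqueHandle = 3, kDLBfloat = 4).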
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DLDataTypeCode {
Int = 0,
UInt = 1,
Float = 2,
    OpaqueHandle = 3,
    BFloat = 4,
}
impl TryFrom<u8> for DLDataTypeCode {
type Error = DlpackError;
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
0 => Ok(Self::Int),
1 => Ok(Self::UInt),
2 => Ok(Self::Float),
            3 => Ok(Self::OpaqueHandle),
            4 => Ok(Self::BFloat),
other => Err(DlpackError::UnsupportedDtype {
code: other,
bits: 0,
}),
}
}
}
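/// Owned summary of a validated `DLTensor`: shape plus decoded dtype and
/// device information, with no borrowed pointers.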
#[derive(Debug, Clone)]
pub struct DLTensorInfo {
pub shape: Vec<i64>,
pub dtype_code: DLDataTypeCode,
pub dtype_bits: u8,
pub device_type: DLDeviceType,
}
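/// Errors produced while validating or importing DLPack tensors.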
#[derive(Debug, thiserror::Error)]
pub enum DlpackError {
#[error("unsupported device: expected CPU")]
NonCpuDevice,
#[error("unsupported dtype: {code}:{bits}")]
UnsupportedDtype {
code: u8,
bits: u8,
},
#[error("null data pointer")]
NullPointer,
}
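/// Checks that a `DLTensor` has a non-null data pointer and a known dtype
/// code, and copies its shape into an owned `DLTensorInfo`. The device is
/// decoded but not restricted here; callers decide whether non-CPU devices
/// are acceptable.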
pub fn validate_dlpack_tensor(tensor: &DLTensor) -> Result<DLTensorInfo, DlpackError> {
if tensor.data.is_null() {
return Err(DlpackError::NullPointer);
}
let device_type = decode_device_type(tensor.device.device_type);
let dtype_code = DLDataTypeCode::try_from(tensor.dtype.code)?;
let shape = if tensor.ndim == 0 || tensor.shape.is_null() {
Vec::new()
} else {
unsafe {
std::slice::from_raw_parts(tensor.shape as *const i64, tensor.ndim as usize).to_vec()
}
};
Ok(DLTensorInfo {
shape,
dtype_code,
dtype_bits: tensor.dtype.bits,
device_type,
})
}
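/// Builds a `DLTensor` view over an existing `f64` slice. The tensor borrows
/// both `data` and `shape`, so the caller must keep them alive for as long as
/// the tensor is used; `strides` is left null, which the DLPack spec defines
/// as compact row-major layout.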
pub fn dlpack_from_slice(data: &[f64], shape: &[i64]) -> DLTensor {
DLTensor {
data: data.as_ptr() as *mut c_void,
device: DLDevice {
device_type: DLDeviceType::Cpu as i32,
device_id: 0,
},
ndim: shape.len() as i32,
dtype: DLDataType {
code: DLDataTypeCode::Float as u8,
bits: 64,
lanes: 1,
},
shape: shape.as_ptr() as *mut i64,
        strides: std::ptr::null_mut(),
        byte_offset: 0,
}
}
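/// Copies a CPU, float64, single-lane `DLTensor` into a `Vec<f64>`. The data
/// is assumed to be contiguous row-major starting at `data + byte_offset`;
/// explicit strides, if any, are ignored.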
pub fn dlpack_to_vec_f64(tensor: &DLTensor) -> Result<Vec<f64>, DlpackError> {
if tensor.data.is_null() {
return Err(DlpackError::NullPointer);
}
let device_type = tensor.device.device_type;
if device_type != DLDeviceType::Cpu as i32 {
return Err(DlpackError::NonCpuDevice);
}
if tensor.dtype.code != DLDataTypeCode::Float as u8
|| tensor.dtype.bits != 64
|| tensor.dtype.lanes != 1
{
return Err(DlpackError::UnsupportedDtype {
code: tensor.dtype.code,
bits: tensor.dtype.bits,
});
}
let n_elems = if tensor.ndim == 0 {
1usize
} else if tensor.shape.is_null() {
0usize
} else {
let shape =
unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, tensor.ndim as usize) };
shape.iter().map(|&d| d as usize).product()
};
let base = unsafe { (tensor.data as *const u8).add(tensor.byte_offset as usize) as *const f64 };
let slice = unsafe { std::slice::from_raw_parts(base, n_elems) };
Ok(slice.to_vec())
}
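/// Maps a raw DLPack device code to `DLDeviceType`; unknown codes fall back
/// to `Cpu`.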
fn decode_device_type(raw: i32) -> DLDeviceType {
match raw {
1 => DLDeviceType::Cpu,
2 => DLDeviceType::Cuda,
3 => DLDeviceType::CudaHost,
4 => DLDeviceType::OpenCL,
7 => DLDeviceType::Vulkan,
8 => DLDeviceType::Metal,
10 => DLDeviceType::Rocm,
        _ => DLDeviceType::Cpu,
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_validate_valid_f64_cpu_tensor() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = vec![2_i64, 3];
        let tensor = dlpack_from_slice(&data, &shape);
        let info = validate_dlpack_tensor(&tensor).expect("validate_dlpack_tensor failed");
        assert_eq!(info.shape, vec![2, 3]);
        assert_eq!(info.dtype_code, DLDataTypeCode::Float);
        assert_eq!(info.dtype_bits, 64);
        assert_eq!(info.device_type, DLDeviceType::Cpu);
}
#[test]
fn test_validate_null_pointer_returns_err() {
        let data = [0.0_f64; 3];
        let shape = vec![3_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
tensor.data = std::ptr::null_mut();
let result = validate_dlpack_tensor(&tensor);
assert!(
matches!(result, Err(DlpackError::NullPointer)),
"expected NullPointer error"
);
}
#[test]
fn test_validate_shape_fields() {
let data = vec![10.0_f64; 12];
let shape = vec![3_i64, 4];
let tensor = dlpack_from_slice(&data, &shape);
let info = validate_dlpack_tensor(&tensor).expect("validate failed");
assert_eq!(info.shape, vec![3, 4]);
}
#[test]
fn test_dlpack_from_slice_shape_fields() {
let data = vec![1.0_f64, 2.0, 3.0];
let shape = vec![3_i64];
let tensor = dlpack_from_slice(&data, &shape);
assert_eq!(tensor.ndim, 1);
assert!(!tensor.data.is_null());
assert!(!tensor.shape.is_null());
        assert_eq!(tensor.dtype.code, 2);
        assert_eq!(tensor.dtype.bits, 64);
}
#[test]
fn test_dlpack_from_slice_2d() {
let data = vec![0.0_f64; 6];
let shape = vec![2_i64, 3];
let tensor = dlpack_from_slice(&data, &shape);
assert_eq!(tensor.ndim, 2);
let s = unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, 2) };
assert_eq!(s, [2, 3]);
}
#[test]
fn test_dlpack_to_vec_f64_round_trip() {
let original = vec![1.0_f64, 2.5, 3.15, -7.0, 0.0];
let shape = vec![5_i64];
let tensor = dlpack_from_slice(&original, &shape);
let recovered = dlpack_to_vec_f64(&tensor).expect("dlpack_to_vec_f64 failed");
assert_eq!(recovered, original);
}
#[test]
fn test_dlpack_to_vec_f64_2d() {
let original: Vec<f64> = (0..6).map(|i| i as f64).collect();
let shape = vec![2_i64, 3];
let tensor = dlpack_from_slice(&original, &shape);
let recovered = dlpack_to_vec_f64(&tensor).expect("dlpack_to_vec_f64 failed");
assert_eq!(recovered, original);
}
#[test]
fn test_dlpack_to_vec_f64_null_pointer_err() {
let data = vec![0.0_f64];
let shape = vec![1_i64];
let mut tensor = dlpack_from_slice(&data, &shape);
tensor.data = std::ptr::null_mut();
assert!(matches!(
dlpack_to_vec_f64(&tensor),
Err(DlpackError::NullPointer)
));
}
#[test]
fn test_dlpack_to_vec_f64_non_cpu_err() {
let data = vec![0.0_f64];
let shape = vec![1_i64];
let mut tensor = dlpack_from_slice(&data, &shape);
tensor.device.device_type = DLDeviceType::Cuda as i32;
assert!(matches!(
dlpack_to_vec_f64(&tensor),
Err(DlpackError::NonCpuDevice)
));
}
#[test]
fn test_dlpack_to_vec_f64_wrong_dtype_err() {
let data = vec![0.0_f64];
let shape = vec![1_i64];
let mut tensor = dlpack_from_slice(&data, &shape);
tensor.dtype.code = 0;
assert!(matches!(
dlpack_to_vec_f64(&tensor),
Err(DlpackError::UnsupportedDtype { .. })
));
}
#[test]
fn test_dtype_code_try_from() {
assert_eq!(DLDataTypeCode::try_from(0u8).unwrap(), DLDataTypeCode::Int);
assert_eq!(DLDataTypeCode::try_from(1u8).unwrap(), DLDataTypeCode::UInt);
assert_eq!(
DLDataTypeCode::try_from(2u8).unwrap(),
DLDataTypeCode::Float
);
        assert_eq!(
            DLDataTypeCode::try_from(3u8).unwrap(),
            DLDataTypeCode::OpaqueHandle
        );
        assert_eq!(
            DLDataTypeCode::try_from(4u8).unwrap(),
            DLDataTypeCode::BFloat
        );
assert!(DLDataTypeCode::try_from(99u8).is_err());
}
}