use crate::ffi::{
DLDataType, DLDevice, DLManagedTensor, DLManagedTensorVersioned, DLPackVersion, DLTensor,
DLPACK_FLAG_BITMASK_READ_ONLY, DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION,
};
use crate::{
DLPACK_CAPSULE_NAME, DLPACK_CAPSULE_NAME_USED, DLPACK_VERSIONED_CAPSULE_NAME,
DLPACK_VERSIONED_CAPSULE_NAME_USED,
};
use pyo3::prelude::*;
use std::ffi::{c_void, CStr};
/// Description of the memory an exporter hands to DLPack.
///
/// All fields are public, so callers may also build this directly; the export path
/// re-validates the `strides`/`shape` length invariant before use.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    /// Base pointer of the tensor storage (host or device address).
    pub data: *mut c_void,
    /// Device the storage lives on.
    pub device: DLDevice,
    /// Element data type.
    pub dtype: DLDataType,
    /// Extent of every dimension; empty for a scalar.
    pub shape: Vec<i64>,
    /// Per-dimension strides (expected in elements, per DLPack), or `None` for a compact
    /// layout.
    pub strides: Option<Vec<i64>>,
    /// Offset in bytes from `data` to the first element.
    pub byte_offset: u64,
}

impl TensorInfo {
    /// Describes a tensor with no explicit strides and a zero byte offset.
    pub fn contiguous(
        data: *mut c_void,
        device: DLDevice,
        dtype: DLDataType,
        shape: Vec<i64>,
    ) -> Self {
        Self {
            byte_offset: 0,
            strides: None,
            shape,
            dtype,
            device,
            data,
        }
    }

    /// Describes a tensor with explicit per-dimension strides and a zero byte offset.
    ///
    /// # Panics
    /// Panics if `strides.len()` differs from `shape.len()`.
    pub fn strided(
        data: *mut c_void,
        device: DLDevice,
        dtype: DLDataType,
        shape: Vec<i64>,
        strides: Vec<i64>,
    ) -> Self {
        assert_eq!(
            strides.len(),
            shape.len(),
            "strides length ({}) must equal shape length ({})",
            strides.len(),
            shape.len()
        );
        Self {
            strides: Some(strides),
            ..Self::contiguous(data, device, dtype, shape)
        }
    }

    /// Returns the same description with `byte_offset` replaced by `offset`.
    pub fn with_byte_offset(self, offset: u64) -> Self {
        Self {
            byte_offset: offset,
            ..self
        }
    }
}
/// A tensor that can be handed to Python consumers as a DLPack capsule.
///
/// Exporting moves `self` into a heap-allocated context owned by the capsule. That value is
/// dropped when the DLPack deleter runs, either from a consumer or from the capsule's
/// destructor. The `Send` bound presumably exists because a consumer may call the deleter
/// from another thread.
///
/// NOTE(review): `tensor_info` is called *before* `self` is moved into that context, so the
/// returned `data` pointer must reference storage that does not move with `self` (a `Vec`'s
/// heap buffer, device memory, ...). A pointer into an inline array field would dangle.
pub trait IntoDLPack: Send + Sized {
    /// Describes this tensor's memory without consuming it.
    fn tensor_info(&self) -> TensorInfo;

    /// Exports `self` as an unversioned DLPack capsule.
    fn into_dlpack(self, py: Python<'_>) -> PyResult<Py<PyAny>> {
        let layout = self.tensor_info();
        export_to_capsule(py, self, layout)
    }

    /// Exports `self` as a versioned DLPack capsule with the read-only flag set.
    fn into_dlpack_readonly(self, py: Python<'_>) -> PyResult<Py<PyAny>> {
        let layout = self.tensor_info();
        export_to_capsule_versioned(py, self, layout, DLPACK_FLAG_BITMASK_READ_ONLY)
    }
}
/// Heap-allocated owner behind every exported capsule.
///
/// A pointer to this box is stored in `managed`'s `manager_ctx`, and the DLPack deleter
/// reclaims and drops it. The `DLTensor` inside `managed` holds raw pointers into the
/// `shape`/`strides` buffers below (and, presumably, into memory owned by `tensor`). Those
/// fields therefore exist only to keep that memory alive until the deleter runs.
struct ManagedContext<T, M> {
    /// The DLPack struct given to consumers; its address is the capsule payload.
    managed: M,
    // Never read: owning the exported tensor keeps the memory behind `data` alive
    // (assuming that memory is owned by `T`).
    #[allow(dead_code)]
    tensor: T,
    // Never read: backs the raw `shape` pointer stored in `managed`'s DLTensor.
    #[allow(dead_code)]
    shape: Vec<i64>,
    // Never read: backs the raw `strides` pointer stored in `managed`'s DLTensor, if any.
    #[allow(dead_code)]
    strides: Option<Vec<i64>>,
}
fn validate_strides(info: &TensorInfo) -> PyResult<()> {
if let Some(ref strides) = info.strides {
if strides.len() != info.shape.len() {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"strides length ({}) must equal shape length ({})",
strides.len(),
info.shape.len()
)));
}
}
Ok(())
}
/// Assembles a `DLTensor` view over caller-owned `shape`/`strides` buffers.
///
/// The result stores raw pointers into `shape` and `strides`. Those buffers must stay alive
/// and unmoved for as long as the tensor is used.
///
/// A scalar (empty `shape`) gets null `shape` and `strides` pointers. A `None` `strides`
/// also yields a null strides pointer.
fn build_dl_tensor(
    data: *mut c_void,
    device: DLDevice,
    dtype: DLDataType,
    byte_offset: u64,
    shape: &[i64],
    strides: &Option<Vec<i64>>,
) -> DLTensor {
    let (shape_ptr, strides_ptr) = if shape.is_empty() {
        (std::ptr::null_mut(), std::ptr::null_mut())
    } else {
        let explicit_strides = match strides {
            Some(values) => values.as_ptr() as *mut i64,
            None => std::ptr::null_mut(),
        };
        (shape.as_ptr() as *mut i64, explicit_strides)
    };
    DLTensor {
        data,
        device,
        // NOTE(review): `as` would truncate if `shape.len()` exceeded `i32::MAX`.
        ndim: shape.len() as i32,
        dtype,
        shape: shape_ptr,
        strides: strides_ptr,
        byte_offset,
    }
}
/// Reclaims and drops a `Box<C>` that was leaked with `Box::into_raw` and stored as an
/// opaque `manager_ctx` pointer. A null pointer is ignored.
///
/// A panic raised while dropping `C` is caught and discarded, because this runs underneath
/// `extern "C"` deleters that must not unwind.
///
/// # Safety
/// `manager_ctx` must be null or a pointer obtained from `Box::<C>::into_raw` that has not
/// already been freed.
unsafe fn free_managed_ctx<C>(manager_ctx: *mut c_void) {
    let boxed = manager_ctx.cast::<C>();
    if boxed.is_null() {
        return;
    }
    let reclaim = std::panic::AssertUnwindSafe(move || drop(Box::from_raw(boxed)));
    let _ = std::panic::catch_unwind(reclaim);
}
unsafe fn into_capsule<C>(
py: Python<'_>,
ctx_ptr: *mut C,
managed_field: *mut c_void,
name: &CStr,
destructor: unsafe extern "C" fn(*mut pyo3::ffi::PyObject),
err_msg: &str,
) -> PyResult<Py<PyAny>> {
let capsule_ptr = pyo3::ffi::PyCapsule_New(managed_field, name.as_ptr(), Some(destructor));
if capsule_ptr.is_null() {
let _ = Box::from_raw(ctx_ptr);
return Err(pyo3::exceptions::PyMemoryError::new_err(err_msg.to_owned()));
}
Ok(Bound::from_owned_ptr(py, capsule_ptr).unbind())
}
/// Returns the payload of `capsule_ptr` if a consumer has not yet claimed it.
///
/// Returns null in any of these cases:
/// - the capsule pointer is null;
/// - the capsule has no name;
/// - the capsule has been renamed to `used_name`, meaning a consumer took ownership and is
///   now responsible for calling the deleter.
///
/// # Safety
/// `capsule_ptr` must be null or a valid `PyCapsule`, and the GIL must be held.
unsafe fn unconsumed_managed_ptr(
    capsule_ptr: *mut pyo3::ffi::PyObject,
    used_name: &CStr,
) -> *mut c_void {
    if capsule_ptr.is_null() {
        return std::ptr::null_mut();
    }
    let current_name = pyo3::ffi::PyCapsule_GetName(capsule_ptr);
    let already_consumed = current_name.is_null() || CStr::from_ptr(current_name) == used_name;
    if already_consumed {
        std::ptr::null_mut()
    } else {
        pyo3::ffi::PyCapsule_GetPointer(capsule_ptr, current_name)
    }
}
/// Common accessor surface over the two DLPack managed-tensor structs.
///
/// This lets capsule export and cleanup be written once for both the legacy and the
/// versioned layouts.
trait ManagedStruct: Sized {
    /// Capsule name used when exporting this kind of managed tensor.
    const CAPSULE_NAME: &'static CStr;
    /// Capsule name that marks the payload as claimed by a consumer.
    const USED_NAME: &'static CStr;
    /// Opaque owner pointer stored in the struct.
    fn manager_ctx(&self) -> *mut c_void;
    /// Overwrites the opaque owner pointer.
    fn set_manager_ctx(&mut self, ctx: *mut c_void);
    /// The deleter callback, if one was installed.
    fn deleter(&self) -> Option<unsafe extern "C" fn(*mut Self)>;
}

/// Implements [`ManagedStruct`] for a struct that exposes `manager_ctx` and `deleter` fields.
macro_rules! impl_managed_struct {
    ($ty:ty, $capsule_name:expr, $used_name:expr) => {
        impl ManagedStruct for $ty {
            const CAPSULE_NAME: &'static CStr = $capsule_name;
            const USED_NAME: &'static CStr = $used_name;

            fn manager_ctx(&self) -> *mut c_void {
                self.manager_ctx
            }

            fn set_manager_ctx(&mut self, ctx: *mut c_void) {
                self.manager_ctx = ctx;
            }

            fn deleter(&self) -> Option<unsafe extern "C" fn(*mut Self)> {
                self.deleter
            }
        }
    };
}

impl_managed_struct!(DLManagedTensor, DLPACK_CAPSULE_NAME, DLPACK_CAPSULE_NAME_USED);
impl_managed_struct!(
    DLManagedTensorVersioned,
    DLPACK_VERSIONED_CAPSULE_NAME,
    DLPACK_VERSIONED_CAPSULE_NAME_USED
);
/// `PyCapsule` destructor shared by both capsule kinds.
///
/// If no consumer has claimed the capsule (it has not been renamed to `M::USED_NAME`), this
/// runs the managed struct's own deleter so the exported tensor is released. Claimed
/// capsules are left to the consumer.
///
/// # Safety
/// Called by CPython with a capsule created by this module, or with null.
unsafe extern "C" fn raw_capsule_destructor<M: ManagedStruct>(
    capsule_ptr: *mut pyo3::ffi::PyObject,
) {
    let managed = unconsumed_managed_ptr(capsule_ptr, M::USED_NAME).cast::<M>();
    let Some(deleter) = managed.as_ref().and_then(|m| m.deleter()) else {
        return;
    };
    deleter(managed);
}
/// DLPack `deleter` installed on every exported managed tensor.
///
/// Follows the struct's `manager_ctx` back to the boxed `ManagedContext<T, M>` and drops it.
/// This releases the exported `T` together with the `shape`/`strides` buffers. A null
/// `managed_ptr` is a no-op.
///
/// # Safety
/// `managed_ptr` must be null or point at the `managed` field of a live boxed
/// `ManagedContext<T, M>` whose `manager_ctx` refers to that box. Call it at most once.
unsafe extern "C" fn dlpack_deleter<T, M: ManagedStruct>(managed_ptr: *mut M) {
    if !managed_ptr.is_null() {
        let owner = (*managed_ptr).manager_ctx();
        free_managed_ctx::<ManagedContext<T, M>>(owner);
    }
}
fn export_to_capsule_with<T: IntoDLPack, M: ManagedStruct>(
py: Python<'_>,
tensor: T,
info: TensorInfo,
err_msg: &str,
build: impl FnOnce(DLTensor) -> M,
) -> PyResult<Py<PyAny>> {
validate_strides(&info)?;
let TensorInfo {
data,
device,
dtype,
shape,
strides,
byte_offset,
} = info;
let dl_tensor = build_dl_tensor(data, device, dtype, byte_offset, &shape, &strides);
let ctx_ptr = Box::into_raw(Box::new(ManagedContext {
managed: build(dl_tensor),
tensor,
shape,
strides,
}));
unsafe {
(*ctx_ptr).managed.set_manager_ctx(ctx_ptr as *mut c_void);
let managed_field = &mut (*ctx_ptr).managed as *mut M as *mut c_void;
into_capsule(
py,
ctx_ptr,
managed_field,
M::CAPSULE_NAME,
raw_capsule_destructor::<M>,
err_msg,
)
}
}
fn export_to_capsule<T: IntoDLPack>(
py: Python<'_>,
tensor: T,
info: TensorInfo,
) -> PyResult<Py<PyAny>> {
export_to_capsule_with::<T, DLManagedTensor>(
py,
tensor,
info,
"Failed to create DLPack capsule",
|dl_tensor| DLManagedTensor {
dl_tensor,
manager_ctx: std::ptr::null_mut(),
deleter: Some(dlpack_deleter::<T, DLManagedTensor>),
},
)
}
fn export_to_capsule_versioned<T: IntoDLPack>(
py: Python<'_>,
tensor: T,
info: TensorInfo,
flags: u64,
) -> PyResult<Py<PyAny>> {
export_to_capsule_with::<T, DLManagedTensorVersioned>(
py,
tensor,
info,
"Failed to create versioned DLPack capsule",
|dl_tensor| DLManagedTensorVersioned {
version: DLPackVersion {
major: DLPACK_MAJOR_VERSION,
minor: DLPACK_MINOR_VERSION,
},
manager_ctx: std::ptr::null_mut(),
deleter: Some(dlpack_deleter::<T, DLManagedTensorVersioned>),
flags,
dl_tensor,
},
)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ffi::{cpu_device, cuda_device, dtype_f32, dtype_f64, dtype_i32};
    use pyo3::Python;
    use std::sync::atomic::{AtomicUsize, Ordering};

    struct TestTensor {
        data: Vec<f32>,
        shape: Vec<i64>,
    }

    impl IntoDLPack for TestTensor {
        fn tensor_info(&self) -> TensorInfo {
            TensorInfo::contiguous(
                self.data.as_ptr() as *mut c_void,
                cpu_device(),
                dtype_f32(),
                self.shape.clone(),
            )
        }
    }

    struct StridedTensor {
        data: Vec<f32>,
        shape: Vec<i64>,
        strides: Vec<i64>,
    }

    impl IntoDLPack for StridedTensor {
        fn tensor_info(&self) -> TensorInfo {
            TensorInfo::strided(
                self.data.as_ptr() as *mut c_void,
                cpu_device(),
                dtype_f32(),
                self.shape.clone(),
                self.strides.clone(),
            )
        }
    }

    struct GpuTensor {
        device_ptr: u64,
        shape: Vec<i64>,
        device_id: i32,
    }

    impl IntoDLPack for GpuTensor {
        fn tensor_info(&self) -> TensorInfo {
            TensorInfo::contiguous(
                self.device_ptr as *mut c_void,
                cuda_device(self.device_id),
                dtype_f32(),
                self.shape.clone(),
            )
        }
    }

    struct OffsetTensor {
        data: Vec<f32>,
        shape: Vec<i64>,
        offset: u64,
    }

    impl IntoDLPack for OffsetTensor {
        fn tensor_info(&self) -> TensorInfo {
            TensorInfo::contiguous(
                self.data.as_ptr() as *mut c_void,
                cpu_device(),
                dtype_f32(),
                self.shape.clone(),
            )
            .with_byte_offset(self.offset)
        }
    }

    /// Builds `TensorInfo` by hand to bypass the panic in `TensorInfo::strided`, so the
    /// export-time validation can be exercised.
    struct MismatchedStridesTensor {
        data: Vec<f32>,
    }

    impl IntoDLPack for MismatchedStridesTensor {
        fn tensor_info(&self) -> TensorInfo {
            TensorInfo {
                data: self.data.as_ptr() as *mut c_void,
                device: cpu_device(),
                dtype: dtype_f32(),
                shape: vec![2, 2],
                strides: Some(vec![1]),
                byte_offset: 0,
            }
        }
    }

    // Only `DropTracker` touches this counter, and only `test_capsule_cleanup_on_drop`
    // constructs one, so parallel tests do not race on it.
    static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

    struct DropTracker {
        data: Vec<f32>,
        shape: Vec<i64>,
    }

    impl Drop for DropTracker {
        fn drop(&mut self) {
            DROP_COUNT.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl IntoDLPack for DropTracker {
        fn tensor_info(&self) -> TensorInfo {
            TensorInfo::contiguous(
                self.data.as_ptr() as *mut c_void,
                cpu_device(),
                dtype_f32(),
                self.shape.clone(),
            )
        }
    }

    #[test]
    fn test_tensor_info_contiguous() {
        let data = [1.0f32, 2.0, 3.0, 4.0].to_vec();
        let info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![2, 2],
        );
        assert!(info.strides.is_none());
        assert_eq!(info.byte_offset, 0);
        assert_eq!(info.shape, vec![2, 2]);
        assert!(info.device.is_cpu());
        assert!(info.dtype.is_f32());
    }

    #[test]
    fn test_tensor_info_strided() {
        let data = [1.0f32; 24].to_vec();
        let info = TensorInfo::strided(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![2, 3, 4],
            vec![12, 4, 1],
        );
        assert_eq!(info.strides, Some(vec![12, 4, 1]));
        assert_eq!(info.byte_offset, 0);
        assert_eq!(info.shape, vec![2, 3, 4]);
    }

    #[test]
    fn test_tensor_info_with_byte_offset() {
        let data = [1.0f32; 10].to_vec();
        let info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![10],
        )
        .with_byte_offset(16);
        assert_eq!(info.byte_offset, 16);
    }

    #[test]
    fn test_tensor_info_with_different_dtypes() {
        let data_f64 = [1.0f64; 10].to_vec();
        let info = TensorInfo::contiguous(
            data_f64.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f64(),
            vec![10],
        );
        assert!(info.dtype.is_f64());
        let data_i32 = [1i32; 10].to_vec();
        let info = TensorInfo::contiguous(
            data_i32.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_i32(),
            vec![10],
        );
        assert!(info.dtype.is_i32());
    }

    #[test]
    fn test_tensor_info_with_different_devices() {
        let data = [1.0f32; 10].to_vec();
        let cpu_info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![10],
        );
        assert!(cpu_info.device.is_cpu());
        let cuda_info = TensorInfo::contiguous(
            0x12345678 as *mut c_void,
            cuda_device(0),
            dtype_f32(),
            vec![10],
        );
        assert!(cuda_info.device.is_cuda());
        assert_eq!(cuda_info.device.device_id, 0);
        let cuda1_info = TensorInfo::contiguous(
            0x12345678 as *mut c_void,
            cuda_device(1),
            dtype_f32(),
            vec![10],
        );
        assert_eq!(cuda1_info.device.device_id, 1);
    }

    #[test]
    fn test_tensor_info_debug() {
        let data = [1.0f32; 10].to_vec();
        let info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![2, 5],
        );
        let debug = format!("{:?}", info);
        assert!(debug.contains("TensorInfo"));
        assert!(debug.contains("shape"));
    }

    #[test]
    fn test_tensor_info_clone() {
        let data = [1.0f32; 10].to_vec();
        let info = TensorInfo::strided(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![2, 5],
            vec![5, 1],
        )
        .with_byte_offset(8);
        let cloned = info.clone();
        assert_eq!(cloned.shape, info.shape);
        assert_eq!(cloned.strides, info.strides);
        assert_eq!(cloned.byte_offset, info.byte_offset);
    }

    #[test]
    fn test_tensor_info_empty_shape() {
        let data = [1.0f32].to_vec();
        let info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![],
        );
        assert!(info.shape.is_empty());
    }

    #[test]
    fn test_tensor_info_high_dimensional() {
        let data = vec![1.0f32; 120];
        let info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![2, 3, 4, 5],
        );
        assert_eq!(info.shape.len(), 4);
    }

    #[test]
    fn test_into_dlpack_contiguous() {
        Python::attach(|py| {
            let tensor = TestTensor {
                data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                shape: vec![2, 3],
            };
            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_into_dlpack_strided() {
        Python::attach(|py| {
            let tensor = StridedTensor {
                data: vec![1.0; 24],
                shape: vec![2, 3, 4],
                strides: vec![12, 4, 1],
            };
            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_into_dlpack_exposes_layout() {
        Python::attach(|py| {
            let tensor = StridedTensor {
                data: vec![1.0; 6],
                shape: vec![2, 3],
                strides: vec![3, 1],
            };
            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            unsafe {
                let name_ptr = pyo3::ffi::PyCapsule_GetName(capsule.as_ptr());
                assert!(!name_ptr.is_null());
                assert_eq!(
                    CStr::from_ptr(name_ptr).to_bytes(),
                    DLPACK_CAPSULE_NAME.to_bytes()
                );
                let managed = pyo3::ffi::PyCapsule_GetPointer(capsule.as_ptr(), name_ptr)
                    as *mut DLManagedTensor;
                assert!(!managed.is_null());
                assert!((*managed).deleter.is_some());
                let dl = &(*managed).dl_tensor;
                assert_eq!(dl.ndim, 2);
                assert_eq!(dl.byte_offset, 0);
                assert_eq!(std::slice::from_raw_parts(dl.shape, 2), &[2, 3]);
                assert_eq!(std::slice::from_raw_parts(dl.strides, 2), &[3, 1]);
            }
        });
    }

    #[test]
    fn test_mismatched_strides_rejected() {
        Python::attach(|py| {
            let tensor = MismatchedStridesTensor {
                data: vec![0.0; 4],
            };
            let err = tensor
                .into_dlpack(py)
                .expect_err("mismatched strides must be rejected");
            assert!(err.is_instance_of::<pyo3::exceptions::PyValueError>(py));
        });
    }

    #[test]
    fn test_into_dlpack_gpu_tensor() {
        Python::attach(|py| {
            let tensor = GpuTensor {
                device_ptr: 0xDEADBEEF,
                shape: vec![16, 32],
                device_id: 0,
            };
            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_into_dlpack_with_offset() {
        Python::attach(|py| {
            let tensor = OffsetTensor {
                data: vec![1.0; 20],
                shape: vec![10],
                offset: 40,
            };
            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_into_dlpack_scalar() {
        Python::attach(|py| {
            let tensor = TestTensor {
                data: vec![42.0],
                shape: vec![],
            };
            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_into_dlpack_1d() {
        Python::attach(|py| {
            let tensor = TestTensor {
                data: vec![1.0, 2.0, 3.0, 4.0, 5.0],
                shape: vec![5],
            };
            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_into_dlpack_readonly_is_versioned() {
        Python::attach(|py| {
            let tensor = TestTensor {
                data: vec![1.0, 2.0, 3.0, 4.0],
                shape: vec![2, 2],
            };
            let capsule = tensor
                .into_dlpack_readonly(py)
                .expect("Failed to create read-only capsule");
            unsafe {
                let name_ptr = pyo3::ffi::PyCapsule_GetName(capsule.as_ptr());
                assert!(!name_ptr.is_null());
                let name = CStr::from_ptr(name_ptr);
                assert_eq!(name.to_bytes(), b"dltensor_versioned");
                let managed_ptr = pyo3::ffi::PyCapsule_GetPointer(capsule.as_ptr(), name_ptr)
                    as *mut DLManagedTensorVersioned;
                assert!(!managed_ptr.is_null());
                assert_eq!(
                    (*managed_ptr).flags & DLPACK_FLAG_BITMASK_READ_ONLY,
                    DLPACK_FLAG_BITMASK_READ_ONLY
                );
                assert_eq!((*managed_ptr).version.major, DLPACK_MAJOR_VERSION);
            }
        });
    }

    #[test]
    fn test_capsule_cleanup_on_drop() {
        DROP_COUNT.store(0, Ordering::SeqCst);
        Python::attach(|py| {
            {
                let tensor = DropTracker {
                    data: vec![1.0, 2.0, 3.0],
                    shape: vec![3],
                };
                let _capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
                // The tracker now lives inside the capsule's context and must still be alive.
                assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 0);
            }
            // Releasing the only capsule reference runs its destructor, which calls the
            // DLPack deleter and drops the tracker exactly once.
            assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 1);
            py.run(c"import gc; gc.collect()", None, None).unwrap();
            assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 1);
        });
    }

    #[test]
    fn test_deleter_null_check() {
        unsafe {
            dlpack_deleter::<TestTensor, DLManagedTensor>(std::ptr::null_mut());
        }
    }

    #[test]
    fn test_capsule_destructor_null_check() {
        unsafe {
            raw_capsule_destructor::<DLManagedTensor>(std::ptr::null_mut());
        }
    }

    #[test]
    fn test_versioned_deleter_null_check() {
        unsafe {
            dlpack_deleter::<TestTensor, DLManagedTensorVersioned>(std::ptr::null_mut());
        }
    }

    #[test]
    fn test_versioned_capsule_destructor_null_check() {
        unsafe {
            raw_capsule_destructor::<DLManagedTensorVersioned>(std::ptr::null_mut());
        }
    }

    #[test]
    fn test_into_dlpack_requires_send() {
        fn assert_send<T: Send>() {}
        assert_send::<TestTensor>();
        assert_send::<StridedTensor>();
        assert_send::<GpuTensor>();
        assert_send::<OffsetTensor>();
        assert_send::<DropTracker>();
    }

    #[test]
    fn test_large_shape() {
        Python::attach(|py| {
            let tensor = TestTensor {
                data: vec![1.0; 1000000],
                shape: vec![100, 100, 100],
            };
            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_non_contiguous_strides() {
        Python::attach(|py| {
            let tensor = StridedTensor {
                data: vec![1.0; 6],
                shape: vec![2, 3],
                strides: vec![1, 2],
            };
            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_zero_stride() {
        Python::attach(|py| {
            let tensor = StridedTensor {
                data: vec![1.0; 3],
                shape: vec![2, 3],
                strides: vec![0, 1],
            };
            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }
}