use std::ffi::CStr;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use crate::dlpack::{
array_from_dlpack_f32, array_from_dlpack_f64, DLDataType, DLDeviceType, DLManagedTensor,
DLTensor, DlpackError,
};
// Capsule names mandated by the DLPack Python protocol: a freshly exported
// capsule is named "dltensor"; a consumer that takes ownership renames it to
// "used_dltensor" so the same tensor cannot be consumed twice.
const DLTENSOR_NAME: &CStr = c"dltensor";
const USED_DLTENSOR_NAME: &CStr = c"used_dltensor";
/// Metadata extracted from a non-CPU DLPack tensor: device identity, shape,
/// element type, and byte offset. Carries no pointer to the device data.
#[derive(Debug, Clone)]
pub struct CudaTensorInfo {
    /// Ordinal of the device holding the data (e.g. CUDA device index).
    pub device_id: i32,
    /// Dimension sizes, outermost first; empty for zero-dim (scalar) tensors.
    pub shape: Vec<usize>,
    /// Element type as reported by DLPack (code / bits / lanes).
    pub dtype: DLDataType,
    /// Offset in bytes from the data pointer to the first element.
    pub byte_offset: u64,
    /// Raw DLPack device-type code (2 = CUDA, 10 = ROCm, ...).
    pub device_type_code: i32,
}
impl CudaTensorInfo {
    /// Total number of elements: the product of all dimensions.
    /// An empty shape (zero-dim tensor) yields 1.
    pub fn numel(&self) -> usize {
        self.shape.iter().fold(1, |acc, &dim| acc * dim)
    }

    /// Bit width of a single element, as reported by the DLPack dtype.
    pub fn dtype_bits(&self) -> u8 {
        self.dtype.bits
    }

    /// Human-readable "<device>:<id>" label, e.g. "cuda:0" or "rocm:1".
    pub fn device_str(&self) -> String {
        // Names follow the DLPack device-type codes (kDLCUDA = 2,
        // kDLCUDAHost = 3, kDLOpenCL = 4, kDLVulkan = 7, kDLMetal = 8,
        // kDLROCM = 10); anything else renders as "unknown".
        let kind = match self.device_type_code {
            2 => "cuda",
            3 => "cuda_host",
            4 => "opencl",
            7 => "vulkan",
            8 => "metal",
            10 => "rocm",
            _ => "unknown",
        };
        format!("{kind}:{}", self.device_id)
    }
}
/// Outcome of dispatching a DLPack tensor by its device type.
pub enum DLPackDispatchResult<'a, T> {
    /// CPU tensor: a borrowed ndarray view over the DLPack buffer.
    Cpu(ndarray::ArrayViewD<'a, T>),
    /// Non-CPU tensor: extracted device metadata only (no data access).
    Gpu(CudaTensorInfo),
    /// NOTE(review): this variant is not currently produced by the dispatch
    /// helpers in this file — they return `Gpu` for every non-CPU device.
    /// Presumably reserved for future per-device handling; confirm intent.
    OtherDevice {
        device_type: i32,
        device_id: i32,
    },
}
/// Extract [`CudaTensorInfo`] from a `DLTensor` residing on a non-CPU device.
///
/// # Errors
/// - [`DlpackError::NullPointer`] if the tensor has a null data pointer.
/// - [`DlpackError::NonCpuDevice`] for CPU tensors (use the normal DLPack
///   array-view path for those instead).
pub fn cuda_tensor_info_from_dltensor(tensor: &DLTensor) -> Result<CudaTensorInfo, DlpackError> {
    // Guard clauses: no backing storage, or a plain CPU tensor.
    if tensor.data.is_null() {
        return Err(DlpackError::NullPointer);
    }
    if tensor.device.device_type == DLDeviceType::Cpu as i32 {
        return Err(DlpackError::NonCpuDevice);
    }

    // Copy the shape out of the raw i64 array. A zero-dim (scalar) tensor may
    // legitimately carry ndim == 0 and/or a null shape pointer.
    let rank = tensor.ndim.max(0) as usize;
    let mut shape = Vec::with_capacity(rank);
    if rank > 0 && !tensor.shape.is_null() {
        // SAFETY: the shape pointer was checked non-null and rank > 0; DLPack
        // producers keep `ndim` contiguous i64 dims alive with the tensor.
        let dims = unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, rank) };
        // NOTE(review): a negative dim would wrap to a huge usize here —
        // assumed not to occur in well-formed DLPack tensors; confirm.
        shape.extend(dims.iter().map(|&d| d as usize));
    }

    Ok(CudaTensorInfo {
        device_id: tensor.device.device_id,
        shape,
        dtype: tensor.dtype,
        byte_offset: tensor.byte_offset,
        device_type_code: tensor.device.device_type,
    })
}
pub unsafe fn dlpack_auto_dispatch_f32<'a>(
tensor: *const DLTensor,
) -> Result<DLPackDispatchResult<'a, f32>, DlpackError> {
let t = unsafe { &*tensor };
match t.device.device_type {
dt if dt == DLDeviceType::Cpu as i32 => {
let view = unsafe { array_from_dlpack_f32(tensor)? };
Ok(DLPackDispatchResult::Cpu(view))
}
_ => {
let info = cuda_tensor_info_from_dltensor(t)?;
Ok(DLPackDispatchResult::Gpu(info))
}
}
}
pub unsafe fn dlpack_auto_dispatch_f64<'a>(
tensor: *const DLTensor,
) -> Result<DLPackDispatchResult<'a, f64>, DlpackError> {
let t = unsafe { &*tensor };
match t.device.device_type {
dt if dt == DLDeviceType::Cpu as i32 => {
let view = unsafe { array_from_dlpack_f64(tensor)? };
Ok(DLPackDispatchResult::Cpu(view))
}
_ => {
let info = cuda_tensor_info_from_dltensor(t)?;
Ok(DLPackDispatchResult::Gpu(info))
}
}
}
/// Extract [`CudaTensorInfo`] from a Python DLPack capsule (as returned by a
/// tensor's `__dlpack__()` method).
///
/// # Errors
/// Returns `ValueError` for already-consumed capsules, CPU tensors, and null
/// data pointers; propagates the Python `TypeError`/`ValueError` raised by
/// `PyCapsule_GetPointer` for non-capsule objects or name mismatches.
pub fn cuda_tensor_info(capsule: &Bound<'_, PyAny>) -> PyResult<CudaTensorInfo> {
    let raw_obj: *mut pyo3::ffi::PyObject = capsule.as_ptr();
    // A consumer that takes ownership renames the capsule to "used_dltensor";
    // detect that first so the caller gets an actionable message instead of a
    // generic capsule-name error from GetPointer below.
    let is_used =
        unsafe { pyo3::ffi::PyCapsule_IsValid(raw_obj, USED_DLTENSOR_NAME.as_ptr()) == 1 };
    if is_used {
        return Err(pyo3::exceptions::PyValueError::new_err(
            "DLPack capsule has already been consumed ('used_dltensor'). \
            Call __dlpack__() again on the original tensor.",
        ));
    }
    // PyCapsule_GetPointer validates the "dltensor" name; on any failure it
    // sets a Python exception and returns null, which we surface via fetch.
    let raw_ptr = unsafe { pyo3::ffi::PyCapsule_GetPointer(raw_obj, DLTENSOR_NAME.as_ptr()) };
    if raw_ptr.is_null() {
        return Err(PyErr::fetch(capsule.py()));
    }
    // SAFETY: a valid "dltensor" capsule holds a pointer to a DLManagedTensor
    // kept alive by the producer; we only borrow it here. Note we do NOT
    // rename the capsule — this function reads metadata without taking
    // ownership of the tensor.
    let managed = unsafe { &*(raw_ptr as *const DLManagedTensor) };
    let dl_tensor = &managed.dl_tensor;
    // Translate library errors into Python ValueError with actionable text.
    cuda_tensor_info_from_dltensor(dl_tensor).map_err(|e| match e {
        DlpackError::NonCpuDevice => pyo3::exceptions::PyValueError::new_err(
            "cuda_tensor_info requires a non-CPU DLPack tensor. \
            Use the standard DLPack CPU path for CPU tensors.",
        ),
        DlpackError::NullPointer => {
            pyo3::exceptions::PyValueError::new_err("DLPack tensor has a null data pointer.")
        }
        other => pyo3::exceptions::PyValueError::new_err(format!("DLPack error: {other}")),
    })
}
/// Python-facing wrapper: extract device metadata from a DLPack capsule and
/// return it as a plain dict with keys `device_id`, `shape`, `device_type`,
/// and `device_str`.
#[pyfunction]
pub fn get_cuda_tensor_info(py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult<Py<PyDict>> {
    let info = cuda_tensor_info(obj)?;
    // Render the label before moving fields out of `info`.
    let device_str = info.device_str();
    let dict = PyDict::new(py);
    dict.set_item("device_id", info.device_id)?;
    // Move the shape vec instead of cloning it; `info` is not used afterwards.
    dict.set_item("shape", info.shape)?;
    dict.set_item("device_type", info.device_type_code)?;
    dict.set_item("device_str", device_str)?;
    Ok(dict.into())
}
/// Register this module's Python-facing functions on `m`.
pub fn register_dlpack_cuda_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
    let func = wrap_pyfunction!(get_cuda_tensor_info, m)?;
    m.add_function(func)?;
    Ok(())
}
/// Test helper: build a `DLTensor` pretending to live on a non-CPU device.
///
/// The data pointer is a non-null placeholder (the device-metadata path only
/// null-checks it, never dereferences it). The caller must keep `shape` alive
/// for as long as the returned tensor is used, since it is borrowed raw.
#[cfg(test)]
fn make_non_cpu_dltensor(device_type: i32, device_id: i32, shape: &[i64]) -> DLTensor {
    use crate::dlpack::{DLDataTypeCode, DLDevice};
    use std::ffi::c_void;

    // Any non-null address works; it is only ever compared against null.
    static PLACEHOLDER: u8 = 0;
    let data_ptr = &PLACEHOLDER as *const u8 as *mut c_void;

    DLTensor {
        data: data_ptr,
        device: DLDevice { device_type, device_id },
        ndim: shape.len() as i32,
        // Tests model a plain f32 tensor throughout.
        dtype: DLDataType {
            code: DLDataTypeCode::Float as u8,
            bits: 32,
            lanes: 1,
        },
        shape: shape.as_ptr() as *mut i64,
        strides: std::ptr::null_mut(),
        byte_offset: 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dlpack::{dlpack_from_slice, DLDeviceType};
    use std::ffi::c_void;

    #[test]
    fn test_cuda_tensor_info_rejects_cpu_device() {
        let data = [1.0_f64, 2.0, 3.0];
        let shape = [3_i64];
        let tensor = dlpack_from_slice(&data, &shape);
        let result = cuda_tensor_info_from_dltensor(&tensor);
        assert!(
            matches!(result, Err(DlpackError::NonCpuDevice)),
            "CPU tensor should be rejected by cuda_tensor_info_from_dltensor"
        );
    }

    #[test]
    fn test_cuda_tensor_info_rejects_null_data() {
        let shape = [4_i64, 4];
        let mut tensor = make_non_cpu_dltensor(2, 0, &shape);
        tensor.data = std::ptr::null_mut();
        let result = cuda_tensor_info_from_dltensor(&tensor);
        assert!(
            matches!(result, Err(DlpackError::NullPointer)),
            "null data pointer should be rejected"
        );
    }

    #[test]
    fn test_cuda_tensor_info_extracts_shape() {
        let shape = [3_i64, 4, 5];
        let tensor = make_non_cpu_dltensor(2, 0, &shape);
        let info = cuda_tensor_info_from_dltensor(&tensor)
            .expect("CUDA tensor should produce CudaTensorInfo");
        assert_eq!(info.shape, vec![3, 4, 5], "shape mismatch");
        assert_eq!(info.numel(), 60, "numel mismatch");
    }

    #[test]
    fn test_cuda_tensor_info_extracts_device_id() {
        let shape = [8_i64];
        let tensor = make_non_cpu_dltensor(2, 3, &shape);
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("should produce CudaTensorInfo");
        assert_eq!(info.device_id, 3, "device_id mismatch");
        assert_eq!(
            info.device_type_code, 2,
            "device_type_code should be CUDA (2)"
        );
    }

    #[test]
    fn test_cuda_tensor_info_device_str() {
        let shape = [1_i64];
        let tensor = make_non_cpu_dltensor(2, 0, &shape);
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("should succeed");
        assert_eq!(info.device_str(), "cuda:0");
    }

    #[test]
    fn test_rocm_tensor_info_device_str() {
        let shape = [1_i64];
        let tensor = make_non_cpu_dltensor(10, 1, &shape);
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("should succeed");
        assert_eq!(info.device_str(), "rocm:1");
    }

    #[test]
    fn test_cuda_tensor_info_zero_dim_tensor() {
        use crate::dlpack::{DLDataType, DLDataTypeCode};
        // Zero-dim tensors carry ndim == 0 and a null shape pointer; the
        // extractor must still succeed with an empty shape.
        static SENTINEL: u8 = 0;
        let tensor = DLTensor {
            data: &SENTINEL as *const u8 as *mut c_void,
            device: crate::dlpack::DLDevice {
                device_type: 2,
                device_id: 0,
            },
            ndim: 0,
            dtype: DLDataType {
                code: DLDataTypeCode::Float as u8,
                bits: 32,
                lanes: 1,
            },
            shape: std::ptr::null_mut(),
            strides: std::ptr::null_mut(),
            byte_offset: 0,
        };
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("zero-dim should succeed");
        assert!(info.shape.is_empty(), "zero-dim shape should be empty");
        assert_eq!(info.numel(), 1, "empty product is 1");
    }

    #[test]
    fn test_dlpack_auto_dispatch_cpu_f32_returns_array() {
        let data = [1.0_f32, 2.0, 3.0, 4.0];
        let shape = [2_i64, 2];
        let tensor = crate::dlpack::DLTensor {
            data: data.as_ptr() as *mut c_void,
            device: crate::dlpack::DLDevice {
                device_type: DLDeviceType::Cpu as i32,
                device_id: 0,
            },
            ndim: 2,
            dtype: crate::dlpack::DLDataType {
                code: crate::dlpack::DLDataTypeCode::Float as u8,
                bits: 32,
                lanes: 1,
            },
            shape: shape.as_ptr() as *mut i64,
            strides: std::ptr::null_mut(),
            byte_offset: 0,
        };
        let result = unsafe { dlpack_auto_dispatch_f32(&tensor as *const _) }
            .expect("CPU dispatch should succeed");
        assert!(
            matches!(result, DLPackDispatchResult::Cpu(_)),
            "CPU tensor should return Cpu variant"
        );
        if let DLPackDispatchResult::Cpu(view) = result {
            assert_eq!(view.shape(), &[2, 2]);
            assert_eq!(view[[0, 0]], 1.0_f32);
        }
    }

    #[test]
    fn test_dlpack_auto_dispatch_cuda_f32_returns_gpu_info() {
        let shape = [8_i64];
        let tensor = make_non_cpu_dltensor(DLDeviceType::Cuda as i32, 0, &shape);
        let result = unsafe { dlpack_auto_dispatch_f32(&tensor as *const _) }
            .expect("CUDA dispatch should succeed");
        assert!(
            matches!(result, DLPackDispatchResult::Gpu(_)),
            "CUDA tensor should return Gpu variant"
        );
        if let DLPackDispatchResult::Gpu(info) = result {
            assert_eq!(info.shape, vec![8]);
            assert_eq!(info.device_type_code, 2);
        }
    }

    #[test]
    fn test_dlpack_auto_dispatch_cpu_f64_returns_array() {
        let data = [10.0_f64, 20.0, 30.0];
        let shape = [3_i64];
        let tensor = dlpack_from_slice(&data, &shape);
        let result = unsafe { dlpack_auto_dispatch_f64(&tensor as *const _) }
            .expect("CPU f64 dispatch should succeed");
        assert!(
            matches!(result, DLPackDispatchResult::Cpu(_)),
            "CPU f64 tensor should return Cpu variant"
        );
    }

    #[test]
    fn test_dlpack_auto_dispatch_cuda_f64_returns_gpu_info() {
        let shape = [4_i64, 4];
        let tensor = make_non_cpu_dltensor(DLDeviceType::Cuda as i32, 1, &shape);
        let result = unsafe { dlpack_auto_dispatch_f64(&tensor as *const _) }
            .expect("CUDA f64 dispatch should succeed");
        if let DLPackDispatchResult::Gpu(info) = result {
            assert_eq!(info.shape, vec![4, 4]);
            assert_eq!(info.device_id, 1);
        } else {
            panic!("expected Gpu variant");
        }
    }

    #[test]
    fn test_dlpack_other_device_passthrough() {
        // NOTE: dispatch currently reports every non-CPU device as Gpu;
        // the OtherDevice variant is not produced.
        let shape = [16_i64];
        let tensor = make_non_cpu_dltensor(DLDeviceType::Metal as i32, 0, &shape);
        let result = unsafe { dlpack_auto_dispatch_f32(&tensor as *const _) }
            .expect("Metal dispatch should succeed");
        assert!(
            matches!(result, DLPackDispatchResult::Gpu(_)),
            "Metal tensor should return Gpu variant"
        );
        if let DLPackDispatchResult::Gpu(info) = result {
            assert_eq!(info.device_str(), "metal:0");
        }
    }

    #[test]
    fn test_dlpack_rocm_passthrough() {
        let shape = [32_i64];
        let tensor = make_non_cpu_dltensor(DLDeviceType::Rocm as i32, 2, &shape);
        let result = unsafe { dlpack_auto_dispatch_f64(&tensor as *const _) }
            .expect("ROCm dispatch should succeed");
        if let DLPackDispatchResult::Gpu(info) = result {
            assert_eq!(info.device_type_code, 10);
            assert_eq!(info.device_id, 2);
        } else {
            panic!("expected Gpu variant for ROCm device");
        }
    }

    #[test]
    fn test_cuda_tensor_numel_empty_shape() {
        let shape: [i64; 0] = [];
        let tensor = make_non_cpu_dltensor(2, 0, &shape);
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("should succeed");
        assert_eq!(info.numel(), 1, "empty shape product is 1");
    }

    #[test]
    fn test_cuda_tensor_dtype_bits() {
        let shape = [4_i64];
        let tensor = make_non_cpu_dltensor(2, 0, &shape);
        let info = cuda_tensor_info_from_dltensor(&tensor).expect("should succeed");
        assert_eq!(info.dtype_bits(), 32, "dtype bits should be 32");
    }
}