use std::os::raw::{c_int, c_void};
use dlpk::sys::{
DLDataType, DLDataTypeCode, DLDevice, DLDeviceType, DLManagedTensorVersioned, DLPackVersion,
DLTensor, DLPACK_FLAG_BITMASK_IS_COPIED,
};
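// FFI helpers for exchanging CPU tensors with C callers via DLPack 1.0
// (`DLManagedTensorVersioned`). Borrowed tensors wrap caller-owned memory
// without copying; owned tensors copy into a Rust-managed buffer and set
// `DLPACK_FLAG_BITMASK_IS_COPIED`.

/// Computes row-major (C-contiguous) strides, in elements, for `shape`.
/// The innermost dimension has stride 1; each outer stride is the product of
/// all inner dimension extents. Returns an empty vector for rank 0.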
fn compute_row_major_strides(shape: &[i64]) -> Vec<i64> {
let ndim = shape.len();
if ndim == 0 {
return vec![];
}
let mut strides = vec![1i64; ndim];
for i in (0..ndim - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
strides
}
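/// DLPack dtype descriptors for the two element types used here: f64 and i32,
/// each single-lane.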
fn dtype_f64() -> DLDataType {
DLDataType {
code: DLDataTypeCode::kDLFloat,
bits: 64,
lanes: 1,
}
}
fn dtype_i32() -> DLDataType {
DLDataType {
code: DLDataTypeCode::kDLInt,
bits: 32,
lanes: 1,
}
}
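/// All tensors produced by this module live on CPU device 0.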
fn cpu_device() -> DLDevice {
DLDevice {
device_type: DLDeviceType::kDLCPU,
device_id: 0,
}
}
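/// Version stamp advertised in every `DLManagedTensorVersioned` (DLPack 1.0).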
fn dlpack_version() -> DLPackVersion {
DLPackVersion { major: 1, minor: 0 }
}
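/// Metadata backing a borrowed tensor. The `DLTensor` holds raw pointers into
/// these `Vec` heap buffers; boxing the context keeps them alive (and at
/// stable addresses) until the deleter runs. The element data itself remains
/// owned by the caller.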
struct BorrowedTensorContext {
shape: Vec<i64>,
strides: Vec<i64>,
}
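/// DLPack deleter for borrowed tensors: frees the metadata context and the
/// managed struct itself, but never the borrowed element data. The consumer
/// must invoke this (e.g. via `rgpot_tensor_free`) exactly once.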
unsafe extern "C" fn borrowed_deleter(ptr: *mut DLManagedTensorVersioned) {
if ptr.is_null() {
return;
}
let ctx = unsafe { (*ptr).manager_ctx.cast::<BorrowedTensorContext>() };
if !ctx.is_null() {
drop(unsafe { Box::from_raw(ctx) });
}
drop(unsafe { Box::from_raw(ptr) });
}
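/// Wraps caller-owned memory as a zero-copy DLPack tensor with row-major
/// strides.
///
/// # Safety
/// `data` must point to enough elements of `dtype` for `shape_vec`, and must
/// stay valid until the returned tensor's deleter runs.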
unsafe fn create_borrowed_tensor(
data: *mut c_void,
dtype: DLDataType,
shape_vec: Vec<i64>,
) -> *mut DLManagedTensorVersioned {
let ndim = shape_vec.len() as i32;
let strides_vec = compute_row_major_strides(&shape_vec);
let mut ctx = Box::new(BorrowedTensorContext {
shape: shape_vec,
strides: strides_vec,
});
let dl_tensor = DLTensor {
data,
device: cpu_device(),
ndim,
dtype,
shape: ctx.shape.as_mut_ptr(),
strides: ctx.strides.as_mut_ptr(),
byte_offset: 0,
};
let managed = Box::new(DLManagedTensorVersioned {
version: dlpack_version(),
manager_ctx: Box::into_raw(ctx).cast(),
deleter: Some(borrowed_deleter),
flags: 0,
dl_tensor,
});
Box::into_raw(managed)
}
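/// Context for an owned tensor: besides shape/strides metadata it owns the
/// element buffer, so the tensor outlives whatever produced the data.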
struct OwnedF64TensorContext {
_data: Vec<f64>,
shape: Vec<i64>,
strides: Vec<i64>,
}
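/// DLPack deleter for owned f64 tensors: drops the context (releasing the
/// copied data) along with the managed struct itself.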
unsafe extern "C" fn owned_f64_deleter(ptr: *mut DLManagedTensorVersioned) {
if ptr.is_null() {
return;
}
let ctx = unsafe { (*ptr).manager_ctx.cast::<OwnedF64TensorContext>() };
if !ctx.is_null() {
drop(unsafe { Box::from_raw(ctx) });
}
drop(unsafe { Box::from_raw(ptr) });
}
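// i32 twin of the f64 owned context/deleter above.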
#[allow(dead_code)]
struct OwnedI32TensorContext {
_data: Vec<i32>,
shape: Vec<i64>,
strides: Vec<i64>,
}
#[allow(dead_code)]
unsafe extern "C" fn owned_i32_deleter(ptr: *mut DLManagedTensorVersioned) {
if ptr.is_null() {
return;
}
let ctx = unsafe { (*ptr).manager_ctx.cast::<OwnedI32TensorContext>() };
if !ctx.is_null() {
drop(unsafe { Box::from_raw(ctx) });
}
drop(unsafe { Box::from_raw(ptr) });
}
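/// Builds a tensor that owns a copy of its data. The raw data pointer is
/// captured before the `Vec` moves into the context; the heap buffer it
/// points at is unaffected by the move. `DLPACK_FLAG_BITMASK_IS_COPIED`
/// signals to consumers that this memory is a copy, not a view.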
pub(crate) fn create_owned_f64_tensor(
mut data: Vec<f64>,
shape_vec: Vec<i64>,
) -> *mut DLManagedTensorVersioned {
let ndim = shape_vec.len() as i32;
let strides_vec = compute_row_major_strides(&shape_vec);
let data_ptr = data.as_mut_ptr();
let mut ctx = Box::new(OwnedF64TensorContext {
_data: data,
shape: shape_vec,
strides: strides_vec,
});
let dl_tensor = DLTensor {
data: data_ptr.cast(),
device: cpu_device(),
ndim,
dtype: dtype_f64(),
shape: ctx.shape.as_mut_ptr(),
strides: ctx.strides.as_mut_ptr(),
byte_offset: 0,
};
let managed = Box::new(DLManagedTensorVersioned {
version: dlpack_version(),
manager_ctx: Box::into_raw(ctx).cast(),
deleter: Some(owned_f64_deleter),
flags: DLPACK_FLAG_BITMASK_IS_COPIED,
dl_tensor,
});
Box::into_raw(managed)
}
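/// i32 counterpart of `create_owned_f64_tensor`.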
#[allow(dead_code)]
pub(crate) fn create_owned_i32_tensor(
mut data: Vec<i32>,
shape_vec: Vec<i64>,
) -> *mut DLManagedTensorVersioned {
let ndim = shape_vec.len() as i32;
let strides_vec = compute_row_major_strides(&shape_vec);
let data_ptr = data.as_mut_ptr();
let mut ctx = Box::new(OwnedI32TensorContext {
_data: data,
shape: shape_vec,
strides: strides_vec,
});
let dl_tensor = DLTensor {
data: data_ptr.cast(),
device: cpu_device(),
ndim,
dtype: dtype_i32(),
shape: ctx.shape.as_mut_ptr(),
strides: ctx.strides.as_mut_ptr(),
byte_offset: 0,
};
let managed = Box::new(DLManagedTensorVersioned {
version: dlpack_version(),
manager_ctx: Box::into_raw(ctx).cast(),
deleter: Some(owned_i32_deleter),
flags: DLPACK_FLAG_BITMASK_IS_COPIED,
dl_tensor,
});
Box::into_raw(managed)
}
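/// Checks that `tensor` is a non-null `(n_atoms, 3)` f64 tensor and returns
/// `n_atoms`. Only dtype and shape are validated; strides are not checked.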
#[allow(dead_code)]
pub(crate) fn validate_positions(
tensor: *const DLManagedTensorVersioned,
) -> Result<usize, String> {
if tensor.is_null() {
return Err("positions tensor is NULL".into());
}
let t = unsafe { &(*tensor).dl_tensor };
if t.dtype != dtype_f64() {
return Err(format!("positions: expected f64, got {:?}", t.dtype));
}
if t.ndim != 2 {
return Err(format!("positions: expected ndim=2, got {}", t.ndim));
}
let shape = unsafe { std::slice::from_raw_parts(t.shape, t.ndim as usize) };
if shape[1] != 3 {
return Err(format!("positions: expected shape[1]=3, got {}", shape[1]));
}
Ok(shape[0] as usize)
}
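/// Checks that `tensor` is a non-null 1-D i32 tensor of length `expected_n`.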
#[allow(dead_code)]
pub(crate) fn validate_atomic_numbers(
tensor: *const DLManagedTensorVersioned,
expected_n: usize,
) -> Result<(), String> {
if tensor.is_null() {
return Err("atomic_numbers tensor is NULL".into());
}
let t = unsafe { &(*tensor).dl_tensor };
if t.dtype != dtype_i32() {
return Err(format!("atomic_numbers: expected i32, got {:?}", t.dtype));
}
    if t.ndim != 1 {
        return Err(format!("atomic_numbers: expected ndim=1, got {}", t.ndim));
    }
let shape = unsafe { std::slice::from_raw_parts(t.shape, 1) };
if shape[0] as usize != expected_n {
return Err(format!(
"atomic_numbers: expected len={}, got {}",
expected_n, shape[0]
));
}
Ok(())
}
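/// Checks that `tensor` is a non-null 3x3 f64 cell matrix.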
#[allow(dead_code)]
pub(crate) fn validate_box_matrix(
tensor: *const DLManagedTensorVersioned,
) -> Result<(), String> {
if tensor.is_null() {
return Err("box_matrix tensor is NULL".into());
}
let t = unsafe { &(*tensor).dl_tensor };
if t.dtype != dtype_f64() {
return Err(format!("box_matrix: expected f64, got {:?}", t.dtype));
}
if t.ndim != 2 {
return Err(format!("box_matrix: expected ndim=2, got {}", t.ndim));
}
let shape = unsafe { std::slice::from_raw_parts(t.shape, 2) };
if shape[0] != 3 || shape[1] != 3 {
return Err(format!(
"box_matrix: expected shape=[3,3], got [{},{}]",
shape[0], shape[1]
));
}
Ok(())
}
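// Exported C entry points. Every `DLManagedTensorVersioned*` handed out here
// follows the DLPack ownership rule: the receiver must run the deleter
// (most conveniently via `rgpot_tensor_free`) exactly once.

/// Wraps caller-owned row-major f64 data as a borrowed `(rows, cols)` tensor.
///
/// # Safety
/// `data` must point to at least `rows * cols` f64 values and stay valid
/// until the tensor's deleter runs.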
#[no_mangle]
pub unsafe extern "C" fn rgpot_tensor_cpu_f64_2d(
data: *mut f64,
rows: i64,
cols: i64,
) -> *mut DLManagedTensorVersioned {
create_borrowed_tensor(data.cast(), dtype_f64(), vec![rows, cols])
}
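/// Wraps caller-owned i32 data as a borrowed 1-D tensor of length `len`.
///
/// # Safety
/// `data` must point to at least `len` i32 values and stay valid until the
/// tensor's deleter runs.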
#[no_mangle]
pub unsafe extern "C" fn rgpot_tensor_cpu_i32_1d(
data: *mut c_int,
len: i64,
) -> *mut DLManagedTensorVersioned {
create_borrowed_tensor(data.cast(), dtype_i32(), vec![len])
}
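/// Wraps caller-owned f64 data as a borrowed 3x3 matrix tensor.
///
/// # Safety
/// `data` must point to at least nine f64 values and stay valid until the
/// tensor's deleter runs.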
#[no_mangle]
pub unsafe extern "C" fn rgpot_tensor_cpu_f64_matrix3(
data: *mut f64,
) -> *mut DLManagedTensorVersioned {
create_borrowed_tensor(data.cast(), dtype_f64(), vec![3, 3])
}
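/// Copies `rows * cols` f64 values into a Rust-owned tensor; the caller's
/// buffer may be freed immediately after this returns.
///
/// # Safety
/// `data` must be non-null and point to at least `rows * cols` f64 values;
/// `rows` and `cols` must be non-negative.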
#[no_mangle]
pub unsafe extern "C" fn rgpot_tensor_owned_cpu_f64_2d(
data: *const f64,
rows: i64,
cols: i64,
) -> *mut DLManagedTensorVersioned {
let len = (rows * cols) as usize;
let vec = unsafe { std::slice::from_raw_parts(data, len) }.to_vec();
create_owned_f64_tensor(vec, vec![rows, cols])
}
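/// Runs the tensor's deleter, releasing everything it manages. Null-safe;
/// must not be called twice on the same pointer.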
#[no_mangle]
pub unsafe extern "C" fn rgpot_tensor_free(tensor: *mut DLManagedTensorVersioned) {
if tensor.is_null() {
return;
}
if let Some(deleter) = unsafe { (*tensor).deleter } {
unsafe { deleter(tensor) };
}
}
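/// Returns the tensor's device (always CPU 0 for tensors from this module).
///
/// # Safety
/// `tensor` must be non-null and not yet freed.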
#[no_mangle]
pub unsafe extern "C" fn rgpot_tensor_device(
tensor: *const DLManagedTensorVersioned,
) -> DLDevice {
unsafe { (*tensor).dl_tensor.device }
}
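/// Returns the raw data pointer. `byte_offset` is not applied; tensors
/// produced by this module always have `byte_offset == 0`.
///
/// # Safety
/// `tensor` must be non-null and not yet freed.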
#[no_mangle]
pub unsafe extern "C" fn rgpot_tensor_data(
tensor: *const DLManagedTensorVersioned,
) -> *const c_void {
unsafe { (*tensor).dl_tensor.data as *const c_void }
}
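/// Writes the rank to `ndim_out` and returns the shape pointer, which stays
/// valid until the tensor is freed.
///
/// # Safety
/// `tensor` and `ndim_out` must be non-null; `tensor` must not be freed yet.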
#[no_mangle]
pub unsafe extern "C" fn rgpot_tensor_shape(
tensor: *const DLManagedTensorVersioned,
ndim_out: *mut i32,
) -> *const i64 {
let t = unsafe { &(*tensor) };
unsafe { *ndim_out = t.dl_tensor.ndim };
t.dl_tensor.shape as *const i64
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn borrowed_f64_2d_has_correct_metadata() {
let mut data = [1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
let tensor = unsafe { rgpot_tensor_cpu_f64_2d(data.as_mut_ptr(), 2, 3) };
assert!(!tensor.is_null());
let t = unsafe { &(*tensor) };
assert_eq!(t.dl_tensor.device, cpu_device());
assert_eq!(t.dl_tensor.dtype, dtype_f64());
assert_eq!(t.dl_tensor.ndim, 2);
let shape = unsafe { std::slice::from_raw_parts(t.dl_tensor.shape, 2) };
assert_eq!(shape, &[2, 3]);
let strides = unsafe { std::slice::from_raw_parts(t.dl_tensor.strides, 2) };
assert_eq!(strides, &[3, 1]);
assert_eq!(t.dl_tensor.byte_offset, 0);
let data_ptr = t.dl_tensor.data as *const f64;
assert_eq!(data_ptr, data.as_ptr());
unsafe { rgpot_tensor_free(tensor) };
assert_eq!(data[0], 1.0);
}
#[test]
fn borrowed_i32_1d_has_correct_metadata() {
let mut data = [6_i32, 1, 8];
let tensor = unsafe { rgpot_tensor_cpu_i32_1d(data.as_mut_ptr(), 3) };
assert!(!tensor.is_null());
let t = unsafe { &(*tensor) };
assert_eq!(t.dl_tensor.dtype, dtype_i32());
assert_eq!(t.dl_tensor.ndim, 1);
let shape = unsafe { std::slice::from_raw_parts(t.dl_tensor.shape, 1) };
assert_eq!(shape, &[3]);
unsafe { rgpot_tensor_free(tensor) };
assert_eq!(data[2], 8);
}
#[test]
fn borrowed_matrix3_is_3x3() {
let mut data = [0.0_f64; 9];
let tensor = unsafe { rgpot_tensor_cpu_f64_matrix3(data.as_mut_ptr()) };
let t = unsafe { &(*tensor) };
assert_eq!(t.dl_tensor.ndim, 2);
let shape = unsafe { std::slice::from_raw_parts(t.dl_tensor.shape, 2) };
assert_eq!(shape, &[3, 3]);
unsafe { rgpot_tensor_free(tensor) };
}
#[test]
fn owned_f64_tensor_survives_source_drop() {
let tensor = {
let data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
create_owned_f64_tensor(data, vec![2, 3])
};
let t = unsafe { &(*tensor) };
        let slice = unsafe { std::slice::from_raw_parts(t.dl_tensor.data as *const f64, 6) };
assert_eq!(slice, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
assert!(t.flags & DLPACK_FLAG_BITMASK_IS_COPIED != 0);
unsafe { rgpot_tensor_free(tensor) };
}
#[test]
fn owned_i32_tensor_survives_source_drop() {
let tensor = {
let data = vec![1_i32, 8, 6];
create_owned_i32_tensor(data, vec![3])
};
let t = unsafe { &(*tensor) };
        let slice = unsafe { std::slice::from_raw_parts(t.dl_tensor.data as *const i32, 3) };
assert_eq!(slice, &[1, 8, 6]);
unsafe { rgpot_tensor_free(tensor) };
}
#[test]
fn owned_cpu_f64_2d_copies_data() {
let source = [10.0_f64, 20.0, 30.0, 40.0, 50.0, 60.0];
let tensor = unsafe { rgpot_tensor_owned_cpu_f64_2d(source.as_ptr(), 2, 3) };
let t = unsafe { &(*tensor) };
assert_ne!(t.dl_tensor.data as *const f64, source.as_ptr());
        let slice = unsafe { std::slice::from_raw_parts(t.dl_tensor.data as *const f64, 6) };
assert_eq!(slice, &source);
unsafe { rgpot_tensor_free(tensor) };
}
#[test]
fn free_null_is_noop() {
unsafe { rgpot_tensor_free(std::ptr::null_mut()) };
}
#[test]
fn device_accessor_returns_cpu() {
let mut data = [0.0_f64; 6];
let tensor = unsafe { rgpot_tensor_cpu_f64_2d(data.as_mut_ptr(), 2, 3) };
let dev = unsafe { rgpot_tensor_device(tensor) };
assert_eq!(dev.device_type, DLDeviceType::kDLCPU);
assert_eq!(dev.device_id, 0);
unsafe { rgpot_tensor_free(tensor) };
}
#[test]
fn data_accessor_returns_correct_pointer() {
let mut data = [42.0_f64; 3];
let tensor = unsafe { rgpot_tensor_cpu_f64_2d(data.as_mut_ptr(), 1, 3) };
let ptr = unsafe { rgpot_tensor_data(tensor) };
assert_eq!(ptr as *const f64, data.as_ptr());
unsafe { rgpot_tensor_free(tensor) };
}
#[test]
fn shape_accessor_returns_correct_values() {
let mut data = [0.0_f64; 12];
let tensor = unsafe { rgpot_tensor_cpu_f64_2d(data.as_mut_ptr(), 4, 3) };
let mut ndim: i32 = 0;
let shape_ptr = unsafe { rgpot_tensor_shape(tensor, &mut ndim) };
assert_eq!(ndim, 2);
let shape = unsafe { std::slice::from_raw_parts(shape_ptr, ndim as usize) };
assert_eq!(shape, &[4, 3]);
unsafe { rgpot_tensor_free(tensor) };
}
#[test]
fn validate_positions_ok() {
let mut data = [0.0_f64; 6];
let tensor = unsafe { rgpot_tensor_cpu_f64_2d(data.as_mut_ptr(), 2, 3) };
assert_eq!(validate_positions(tensor), Ok(2));
unsafe { rgpot_tensor_free(tensor) };
}
#[test]
fn validate_positions_null() {
assert!(validate_positions(std::ptr::null()).is_err());
}
#[test]
fn validate_positions_wrong_shape() {
let mut data = [0.0_f64; 8];
let tensor = unsafe { rgpot_tensor_cpu_f64_2d(data.as_mut_ptr(), 2, 4) };
assert!(validate_positions(tensor).is_err());
unsafe { rgpot_tensor_free(tensor) };
}
#[test]
fn validate_positions_wrong_dtype() {
let mut data = [0_i32; 6];
let tensor = unsafe { rgpot_tensor_cpu_i32_1d(data.as_mut_ptr(), 6) };
assert!(validate_positions(tensor).is_err());
unsafe { rgpot_tensor_free(tensor) };
}
#[test]
fn validate_atomic_numbers_ok() {
let mut data = [1_i32, 8];
let tensor = unsafe { rgpot_tensor_cpu_i32_1d(data.as_mut_ptr(), 2) };
assert!(validate_atomic_numbers(tensor, 2).is_ok());
unsafe { rgpot_tensor_free(tensor) };
}
#[test]
fn validate_atomic_numbers_wrong_count() {
let mut data = [1_i32, 8];
let tensor = unsafe { rgpot_tensor_cpu_i32_1d(data.as_mut_ptr(), 2) };
assert!(validate_atomic_numbers(tensor, 3).is_err());
unsafe { rgpot_tensor_free(tensor) };
}
#[test]
fn validate_box_matrix_ok() {
let mut data = [0.0_f64; 9];
let tensor = unsafe { rgpot_tensor_cpu_f64_matrix3(data.as_mut_ptr()) };
assert!(validate_box_matrix(tensor).is_ok());
unsafe { rgpot_tensor_free(tensor) };
}
#[test]
fn validate_box_matrix_wrong_shape() {
let mut data = [0.0_f64; 6];
let tensor = unsafe { rgpot_tensor_cpu_f64_2d(data.as_mut_ptr(), 2, 3) };
assert!(validate_box_matrix(tensor).is_err());
unsafe { rgpot_tensor_free(tensor) };
}
#[test]
fn row_major_strides_1d() {
assert_eq!(compute_row_major_strides(&[5]), vec![1]);
}
#[test]
fn row_major_strides_2d() {
assert_eq!(compute_row_major_strides(&[4, 3]), vec![3, 1]);
}
#[test]
fn row_major_strides_3d() {
assert_eq!(compute_row_major_strides(&[2, 3, 4]), vec![12, 4, 1]);
}
#[test]
fn row_major_strides_empty() {
assert_eq!(compute_row_major_strides(&[]), Vec::<i64>::new());
}
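    // A small end-to-end sketch combining the helpers above: an owned (2, 3)
    // tensor should pass positions validation and report two atoms.
    #[test]
    fn owned_positions_pass_validation() {
        let tensor = create_owned_f64_tensor(vec![0.0; 6], vec![2, 3]);
        assert_eq!(validate_positions(tensor), Ok(2));
        unsafe { rgpot_tensor_free(tensor) };
    }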
}