1use crate::ffi::{DLDataType, DLDevice, DLManagedTensor, DLTensor};
7use pyo3::prelude::*;
8use std::ffi::{c_void, CStr};
9
/// Metadata describing a tensor to be exported through the DLPack protocol.
///
/// Mirrors the fields of `DLTensor`, but owns the `shape`/`strides` vectors
/// so they can be moved into the export context and kept alive for as long
/// as the capsule exists.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    /// Raw pointer to the first element of the tensor data.
    /// NOTE(review): for non-CPU devices this is presumably a device
    /// address, not host memory — confirm with callers.
    pub data: *mut c_void,
    /// Device on which `data` resides.
    pub device: DLDevice,
    /// Element type descriptor.
    pub dtype: DLDataType,
    /// Dimension sizes; empty for a scalar (0-dim) tensor.
    pub shape: Vec<i64>,
    /// Per-dimension strides; `None` means C-contiguous.
    /// NOTE(review): DLPack convention is element strides, not bytes —
    /// confirm against `DLTensor` in `crate::ffi`.
    pub strides: Option<Vec<i64>>,
    /// Offset in bytes from `data` to the first element.
    pub byte_offset: u64,
}
30
31impl TensorInfo {
32 pub fn contiguous(
34 data: *mut c_void,
35 device: DLDevice,
36 dtype: DLDataType,
37 shape: Vec<i64>,
38 ) -> Self {
39 Self {
40 data,
41 device,
42 dtype,
43 shape,
44 strides: None,
45 byte_offset: 0,
46 }
47 }
48
49 pub fn strided(
56 data: *mut c_void,
57 device: DLDevice,
58 dtype: DLDataType,
59 shape: Vec<i64>,
60 strides: Vec<i64>,
61 ) -> Self {
62 assert_eq!(
63 strides.len(),
64 shape.len(),
65 "strides length ({}) must equal shape length ({})",
66 strides.len(),
67 shape.len()
68 );
69 Self {
70 data,
71 device,
72 dtype,
73 shape,
74 strides: Some(strides),
75 byte_offset: 0,
76 }
77 }
78
79 pub fn with_byte_offset(mut self, offset: u64) -> Self {
81 self.byte_offset = offset;
82 self
83 }
84}
85
/// Conversion of a Rust tensor type into a Python DLPack capsule.
///
/// `Send` is required because ownership of the tensor is transferred into
/// the capsule, whose deleter may later run from an arbitrary thread.
pub trait IntoDLPack: Send + Sized {
    /// Describes the tensor's memory: data pointer, device, dtype, shape,
    /// strides, and byte offset. The returned pointer must stay valid for
    /// as long as `self` is alive.
    fn tensor_info(&self) -> TensorInfo;

    /// Consumes `self` and wraps it in a `"dltensor"` PyCapsule.
    ///
    /// The tensor is kept alive inside the capsule's manager context and is
    /// dropped only when the capsule destructor or a consuming framework
    /// invokes the deleter.
    fn into_dlpack(self, py: Python<'_>) -> PyResult<Py<PyAny>> {
        let info = self.tensor_info();
        export_to_capsule(py, self, info)
    }
}
139
/// Heap-allocated state that must outlive the exported `DLManagedTensor`:
/// the tensor itself (keeping its data buffer alive) plus the owned
/// shape/strides vectors that the `DLTensor` points into.
struct ExportContext<T> {
    // Never read directly; held only so the tensor's data buffer is not
    // freed while the capsule is alive.
    #[allow(dead_code)]
    tensor: T,
    // Backing storage for `DLTensor::shape`.
    shape: Vec<i64>,
    // Backing storage for `DLTensor::strides` (`None` => contiguous).
    strides: Option<Vec<i64>>,
}
150
/// Capsule name for an unconsumed DLPack tensor (NUL-terminated for C).
static DLPACK_CAPSULE_NAME: &[u8] = b"dltensor\0";

/// Name a consumer renames the capsule to after taking ownership; the
/// capsule destructor must then NOT free the tensor — the consumer calls
/// the deleter itself.
static USED_DLTENSOR_NAME: &[u8] = b"used_dltensor\0";
157
158fn export_to_capsule<T: IntoDLPack>(
160 py: Python<'_>,
161 tensor: T,
162 info: TensorInfo,
163) -> PyResult<Py<PyAny>> {
164 if let Some(ref strides) = info.strides {
168 if strides.len() != info.shape.len() {
169 return Err(pyo3::exceptions::PyValueError::new_err(format!(
170 "strides length ({}) must equal shape length ({})",
171 strides.len(),
172 info.shape.len()
173 )));
174 }
175 }
176
177 let ctx = Box::new(ExportContext {
179 tensor,
180 shape: info.shape,
181 strides: info.strides,
182 });
183 let ctx_ptr = Box::into_raw(ctx);
184
185 let ndim = unsafe { (*ctx_ptr).shape.len() as i32 };
190 let shape_ptr = if ndim == 0 {
191 std::ptr::null_mut()
192 } else {
193 unsafe { (*ctx_ptr).shape.as_mut_ptr() }
194 };
195 let strides_ptr = if ndim == 0 {
196 std::ptr::null_mut()
197 } else {
198 unsafe {
199 (*ctx_ptr)
200 .strides
201 .as_mut()
202 .map(|s| s.as_mut_ptr())
203 .unwrap_or(std::ptr::null_mut())
204 }
205 };
206
207 let managed = Box::new(DLManagedTensor {
208 dl_tensor: DLTensor {
209 data: info.data,
210 device: info.device,
211 ndim,
212 dtype: info.dtype,
213 shape: shape_ptr,
214 strides: strides_ptr,
215 byte_offset: info.byte_offset,
216 },
217 manager_ctx: ctx_ptr as *mut c_void,
218 deleter: Some(dlpack_deleter::<T>),
219 });
220
221 let managed_ptr = Box::into_raw(managed);
222
223 let capsule_ptr = unsafe {
227 pyo3::ffi::PyCapsule_New(
228 managed_ptr as *mut c_void,
229 DLPACK_CAPSULE_NAME.as_ptr() as *const i8,
230 Some(raw_capsule_destructor),
231 )
232 };
233
234 if capsule_ptr.is_null() {
235 unsafe {
240 let _ = Box::from_raw(managed_ptr);
241 let _ = Box::from_raw(ctx_ptr);
242 }
243 return Err(pyo3::exceptions::PyMemoryError::new_err(
244 "Failed to create DLPack capsule",
245 ));
246 }
247
248 unsafe {
251 pyo3::ffi::PyCapsule_SetContext(capsule_ptr, ctx_ptr as *mut c_void);
252 }
253
254 Ok(unsafe { Py::from_owned_ptr(py, capsule_ptr) })
256}
257
/// C-level destructor installed on the capsule by `PyCapsule_New`.
///
/// Runs when the capsule object itself is garbage-collected. It frees the
/// tensor only if the capsule is still named `"dltensor"`; a consumer that
/// took ownership renames it to `"used_dltensor"`, in which case the
/// consumer is responsible for invoking the deleter and we must not.
unsafe extern "C" fn raw_capsule_destructor(capsule_ptr: *mut pyo3::ffi::PyObject) {
    if capsule_ptr.is_null() {
        return;
    }

    let name_ptr = pyo3::ffi::PyCapsule_GetName(capsule_ptr);
    // A null name means we cannot tell whether the tensor was consumed;
    // bail out rather than risk a double free.
    if name_ptr.is_null() {
        return;
    }

    let name = CStr::from_ptr(name_ptr);

    // Compare against "used_dltensor" without its trailing NUL byte: the
    // consumer now owns the tensor and will call the deleter itself.
    if name.to_bytes() == USED_DLTENSOR_NAME[..USED_DLTENSOR_NAME.len() - 1].as_ref() {
        return;
    }

    // Passing the capsule's own current name guarantees the name check
    // inside PyCapsule_GetPointer succeeds.
    let managed_ptr =
        pyo3::ffi::PyCapsule_GetPointer(capsule_ptr, name_ptr) as *mut DLManagedTensor;

    if managed_ptr.is_null() {
        return;
    }

    // Delegate all freeing to the deleter stored in the managed tensor.
    let managed = &*managed_ptr;
    if let Some(deleter) = managed.deleter {
        deleter(managed_ptr);
    }
}
301
/// DLPack deleter stored in `DLManagedTensor::deleter`.
///
/// May be called by [`raw_capsule_destructor`] or directly by a consuming
/// framework. Reconstitutes the boxes created in `export_to_capsule` so
/// Rust drops the managed tensor, the export context, and with it the
/// wrapped tensor `T` and its data buffer.
unsafe extern "C" fn dlpack_deleter<T>(managed_ptr: *mut DLManagedTensor) {
    if managed_ptr.is_null() {
        return;
    }

    // Reclaim ownership of the DLManagedTensor allocation.
    let managed = Box::from_raw(managed_ptr);

    if !managed.manager_ctx.is_null() {
        // Dropping the context also drops the wrapped tensor `T`.
        let _ctx = Box::from_raw(managed.manager_ctx as *mut ExportContext<T>);
    }
}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ffi::{cpu_device, cuda_device, dtype_f32, dtype_f64, dtype_i32};
    use pyo3::Python;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Simple contiguous f32 tensor backed by a `Vec`.
    struct TestTensor {
        data: Vec<f32>,
        shape: Vec<i64>,
    }

    impl IntoDLPack for TestTensor {
        fn tensor_info(&self) -> TensorInfo {
            TensorInfo::contiguous(
                self.data.as_ptr() as *mut c_void,
                cpu_device(),
                dtype_f32(),
                self.shape.clone(),
            )
        }
    }

    /// Tensor with explicit (possibly non-contiguous) strides.
    struct StridedTensor {
        data: Vec<f32>,
        shape: Vec<i64>,
        strides: Vec<i64>,
    }

    impl IntoDLPack for StridedTensor {
        fn tensor_info(&self) -> TensorInfo {
            TensorInfo::strided(
                self.data.as_ptr() as *mut c_void,
                cpu_device(),
                dtype_f32(),
                self.shape.clone(),
                self.strides.clone(),
            )
        }
    }

    /// Fake GPU tensor: the device pointer is an opaque integer that is
    /// never dereferenced by the export path.
    struct GpuTensor {
        device_ptr: u64,
        shape: Vec<i64>,
        device_id: i32,
    }

    impl IntoDLPack for GpuTensor {
        fn tensor_info(&self) -> TensorInfo {
            TensorInfo::contiguous(
                self.device_ptr as *mut c_void,
                cuda_device(self.device_id),
                dtype_f32(),
                self.shape.clone(),
            )
        }
    }

    /// Tensor whose first element starts `offset` bytes into the buffer.
    struct OffsetTensor {
        data: Vec<f32>,
        shape: Vec<i64>,
        offset: u64,
    }

    impl IntoDLPack for OffsetTensor {
        fn tensor_info(&self) -> TensorInfo {
            TensorInfo::contiguous(
                self.data.as_ptr() as *mut c_void,
                cpu_device(),
                dtype_f32(),
                self.shape.clone(),
            )
            .with_byte_offset(self.offset)
        }
    }

    // Counts DropTracker drops so the cleanup test can verify the capsule
    // actually releases the tensor. Only test_capsule_cleanup_on_drop
    // touches this counter, so parallel test execution is safe.
    static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

    struct DropTracker {
        data: Vec<f32>,
        shape: Vec<i64>,
    }

    impl Drop for DropTracker {
        fn drop(&mut self) {
            DROP_COUNT.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl IntoDLPack for DropTracker {
        fn tensor_info(&self) -> TensorInfo {
            TensorInfo::contiguous(
                self.data.as_ptr() as *mut c_void,
                cpu_device(),
                dtype_f32(),
                self.shape.clone(),
            )
        }
    }

    // --- TensorInfo construction ------------------------------------------

    #[test]
    fn test_tensor_info_contiguous() {
        let data = [1.0f32, 2.0, 3.0, 4.0].to_vec();
        let info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![2, 2],
        );

        assert!(info.strides.is_none());
        assert_eq!(info.byte_offset, 0);
        assert_eq!(info.shape, vec![2, 2]);
        assert!(info.device.is_cpu());
        assert!(info.dtype.is_f32());
    }

    #[test]
    fn test_tensor_info_strided() {
        let data = [1.0f32; 24].to_vec();
        let info = TensorInfo::strided(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![2, 3, 4],
            vec![12, 4, 1],
        );

        assert_eq!(info.strides, Some(vec![12, 4, 1]));
        assert_eq!(info.byte_offset, 0);
        assert_eq!(info.shape, vec![2, 3, 4]);
    }

    #[test]
    fn test_tensor_info_with_byte_offset() {
        let data = [1.0f32; 10].to_vec();
        let info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![10],
        )
        .with_byte_offset(16);

        assert_eq!(info.byte_offset, 16);
    }

    #[test]
    fn test_tensor_info_with_different_dtypes() {
        let data_f64 = [1.0f64; 10].to_vec();
        let info = TensorInfo::contiguous(
            data_f64.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f64(),
            vec![10],
        );
        assert!(info.dtype.is_f64());

        let data_i32 = [1i32; 10].to_vec();
        let info = TensorInfo::contiguous(
            data_i32.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_i32(),
            vec![10],
        );
        assert!(info.dtype.is_i32());
    }

    #[test]
    fn test_tensor_info_with_different_devices() {
        let data = [1.0f32; 10].to_vec();

        let cpu_info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![10],
        );
        assert!(cpu_info.device.is_cpu());

        let cuda_info = TensorInfo::contiguous(
            0x12345678 as *mut c_void,
            cuda_device(0),
            dtype_f32(),
            vec![10],
        );
        assert!(cuda_info.device.is_cuda());
        assert_eq!(cuda_info.device.device_id, 0);

        let cuda1_info = TensorInfo::contiguous(
            0x12345678 as *mut c_void,
            cuda_device(1),
            dtype_f32(),
            vec![10],
        );
        assert_eq!(cuda1_info.device.device_id, 1);
    }

    #[test]
    fn test_tensor_info_debug() {
        let data = [1.0f32; 10].to_vec();
        let info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![2, 5],
        );
        let debug = format!("{:?}", info);
        assert!(debug.contains("TensorInfo"));
        assert!(debug.contains("shape"));
    }

    #[test]
    fn test_tensor_info_clone() {
        let data = [1.0f32; 10].to_vec();
        let info = TensorInfo::strided(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![2, 5],
            vec![5, 1],
        )
        .with_byte_offset(8);

        let cloned = info.clone();
        assert_eq!(cloned.shape, info.shape);
        assert_eq!(cloned.strides, info.strides);
        assert_eq!(cloned.byte_offset, info.byte_offset);
    }

    #[test]
    fn test_tensor_info_empty_shape() {
        let data = [1.0f32].to_vec();
        let info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![],
        );
        assert!(info.shape.is_empty());
    }

    #[test]
    fn test_tensor_info_high_dimensional() {
        let data = vec![1.0f32; 120];
        let info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![2, 3, 4, 5],
        );
        assert_eq!(info.shape.len(), 4);
    }

    // --- Capsule export ---------------------------------------------------

    #[test]
    fn test_into_dlpack_contiguous() {
        Python::attach(|py| {
            let tensor = TestTensor {
                data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                shape: vec![2, 3],
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_into_dlpack_strided() {
        Python::attach(|py| {
            let tensor = StridedTensor {
                data: vec![1.0; 24],
                shape: vec![2, 3, 4],
                strides: vec![12, 4, 1],
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_into_dlpack_gpu_tensor() {
        Python::attach(|py| {
            let tensor = GpuTensor {
                device_ptr: 0xDEADBEEF,
                shape: vec![16, 32],
                device_id: 0,
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_into_dlpack_with_offset() {
        Python::attach(|py| {
            let tensor = OffsetTensor {
                data: vec![1.0; 20],
                shape: vec![10],
                offset: 40,
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_into_dlpack_scalar() {
        Python::attach(|py| {
            // Empty shape => 0-dim tensor; export uses null shape/strides.
            let tensor = TestTensor {
                data: vec![42.0],
                shape: vec![],
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_into_dlpack_1d() {
        Python::attach(|py| {
            let tensor = TestTensor {
                data: vec![1.0, 2.0, 3.0, 4.0, 5.0],
                shape: vec![5],
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    // --- Lifecycle --------------------------------------------------------

    #[test]
    fn test_capsule_cleanup_on_drop() {
        DROP_COUNT.store(0, Ordering::SeqCst);

        Python::attach(|py| {
            {
                let tensor = DropTracker {
                    data: vec![1.0, 2.0, 3.0],
                    shape: vec![3],
                };

                let _capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
                // `_capsule` is dropped here: the refcount reaches zero, the
                // capsule destructor runs, and the deleter drops DropTracker.
            }
            py.run(c"import gc; gc.collect()", None, None).unwrap();
        });

        // The tracked tensor must have been dropped exactly once (no leak,
        // no double free).
        assert_eq!(DROP_COUNT.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_deleter_null_check() {
        // Must be a no-op on a null pointer.
        unsafe {
            dlpack_deleter::<TestTensor>(std::ptr::null_mut());
        }
    }

    #[test]
    fn test_capsule_destructor_null_check() {
        // Must be a no-op on a null pointer.
        unsafe {
            raw_capsule_destructor(std::ptr::null_mut());
        }
    }

    #[test]
    fn test_into_dlpack_requires_send() {
        fn assert_send<T: Send>() {}
        assert_send::<TestTensor>();
        assert_send::<StridedTensor>();
        assert_send::<GpuTensor>();
        assert_send::<OffsetTensor>();
        assert_send::<DropTracker>();
    }

    #[test]
    fn test_large_shape() {
        Python::attach(|py| {
            let tensor = TestTensor {
                data: vec![1.0; 1000000],
                shape: vec![100, 100, 100],
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_non_contiguous_strides() {
        Python::attach(|py| {
            // Column-major strides for a 2x3 tensor.
            let tensor = StridedTensor {
                data: vec![1.0; 6],
                shape: vec![2, 3],
                strides: vec![1, 2],
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_zero_stride() {
        Python::attach(|py| {
            // Zero stride along the first axis (broadcast-style view).
            let tensor = StridedTensor {
                data: vec![1.0; 3],
                shape: vec![2, 3],
                strides: vec![0, 1],
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }
}