1use crate::ffi::{
7 DLDataType, DLDevice, DLManagedTensor, DLManagedTensorVersioned, DLTensor,
8 DLPACK_FLAG_BITMASK_READ_ONLY, DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION,
9};
10use crate::{
11 DLPACK_CAPSULE_NAME, DLPACK_CAPSULE_NAME_USED, DLPACK_VERSIONED_CAPSULE_NAME,
12 DLPACK_VERSIONED_CAPSULE_NAME_USED,
13};
14use pyo3::prelude::*;
15use pyo3::types::PyCapsule;
16use std::ffi::{c_void, CStr};
17use std::ptr::NonNull;
18
19#[derive(Clone, Copy)]
25enum ManagedPtr {
26 Unversioned(NonNull<DLManagedTensor>),
27 Versioned(NonNull<DLManagedTensorVersioned>),
28}
29
30impl ManagedPtr {
31 unsafe fn dl_tensor(&self) -> &DLTensor {
37 match *self {
38 ManagedPtr::Unversioned(p) => &p.as_ref().dl_tensor,
39 ManagedPtr::Versioned(p) => &p.as_ref().dl_tensor,
40 }
41 }
42
43 unsafe fn run_deleter(&self) {
48 match *self {
49 ManagedPtr::Unversioned(p) => {
50 if let Some(deleter) = p.as_ref().deleter {
51 deleter(p.as_ptr());
52 }
53 }
54 ManagedPtr::Versioned(p) => {
55 if let Some(deleter) = p.as_ref().deleter {
56 deleter(p.as_ptr());
57 }
58 }
59 }
60 }
61}
62
/// A tensor imported from Python through the DLPack protocol.
///
/// Holds the producer's managed-tensor pointer together with a reference to the capsule
/// it arrived in. Dropping a `PyTensor` runs the producer's deleter (see the `Drop` impl).
pub struct PyTensor {
    /// Pointer to the producer's managed tensor, tagged by DLPack ABI flavour.
    managed: ManagedPtr,
    // Never read; held only so the capsule object stays referenced for as long as this
    // tensor exists.
    #[allow(dead_code)]
    capsule: Py<PyCapsule>,
}

// SAFETY(review): `ManagedPtr` wraps raw pointers, so `Send` is asserted by hand. Sending a
// `PyTensor` to another thread means its `Drop` runs the producer's deleter on that thread,
// and this type does not acquire the GIL first — confirm every supported producer's
// deleter tolerates that.
unsafe impl Send for PyTensor {}
108
109fn validate_ndim(ndim: i32) -> PyResult<()> {
114 if ndim < 0 {
115 return Err(pyo3::exceptions::PyValueError::new_err(format!(
116 "DLPack tensor has negative ndim: {ndim}"
117 )));
118 }
119 Ok(())
120}
121
122impl PyTensor {
123 fn dl_tensor(&self) -> &DLTensor {
126 unsafe { self.managed.dl_tensor() }
127 }
128
129 pub fn from_pyany(_py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult<Self> {
146 let py = obj.py();
147
148 let kwargs = pyo3::types::PyDict::new(py);
152 kwargs.set_item(
153 pyo3::intern!(py, "max_version"),
154 (DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION),
155 )?;
156
157 let capsule_obj = match obj.call_method("__dlpack__", (), Some(&kwargs)) {
158 Ok(c) => c,
159 Err(e) if e.is_instance_of::<pyo3::exceptions::PyTypeError>(py) => {
160 obj.call_method0("__dlpack__")?
161 }
162 Err(e) => return Err(e),
163 };
164
165 let capsule: Bound<'_, PyCapsule> = capsule_obj.cast_into().map_err(|e| {
166 pyo3::exceptions::PyTypeError::new_err(format!(
167 "__dlpack__ did not return a PyCapsule: {:?}",
168 e.into_inner()
169 ))
170 })?;
171 Self::from_capsule(&capsule)
172 }
173
174 pub fn from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyResult<Self> {
184 let name_ptr = unsafe { pyo3::ffi::PyCapsule_GetName(capsule.as_ptr()) };
188 if name_ptr.is_null() {
189 return Err(pyo3::exceptions::PyValueError::new_err(
190 "DLPack capsule has no name",
191 ));
192 }
193 let name = unsafe { CStr::from_ptr(name_ptr) };
194 let name_bytes = name.to_bytes();
195
196 if name_bytes == DLPACK_CAPSULE_NAME.to_bytes() {
197 Self::from_unversioned_capsule(capsule)
198 } else if name_bytes == DLPACK_VERSIONED_CAPSULE_NAME.to_bytes() {
199 Self::from_versioned_capsule(capsule)
200 } else {
201 Err(pyo3::exceptions::PyValueError::new_err(format!(
202 "unexpected DLPack capsule name: {:?}",
203 name
204 )))
205 }
206 }
207
208 fn from_unversioned_capsule(capsule: &Bound<'_, PyCapsule>) -> PyResult<Self> {
210 let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME))?;
211 let managed = NonNull::new(ptr.as_ptr() as *mut DLManagedTensor).ok_or_else(|| {
212 pyo3::exceptions::PyValueError::new_err("DLPack capsule contains null pointer")
213 })?;
214
215 validate_ndim(unsafe { managed.as_ref().dl_tensor.ndim })?;
218
219 let set_name_result = unsafe {
227 pyo3::ffi::PyCapsule_SetName(capsule.as_ptr(), DLPACK_CAPSULE_NAME_USED.as_ptr())
228 };
229 if set_name_result != 0 {
230 return Err(pyo3::exceptions::PyRuntimeError::new_err(
231 "Failed to mark DLPack capsule as consumed: PyCapsule_SetName failed",
232 ));
233 }
234
235 Ok(Self {
236 managed: ManagedPtr::Unversioned(managed),
237 capsule: capsule.clone().unbind(),
238 })
239 }
240
241 fn from_versioned_capsule(capsule: &Bound<'_, PyCapsule>) -> PyResult<Self> {
243 let ptr = capsule.pointer_checked(Some(DLPACK_VERSIONED_CAPSULE_NAME))?;
244 let managed =
245 NonNull::new(ptr.as_ptr() as *mut DLManagedTensorVersioned).ok_or_else(|| {
246 pyo3::exceptions::PyValueError::new_err("DLPack capsule contains null pointer")
247 })?;
248
249 let version = unsafe { managed.as_ref().version };
256 if version.major != DLPACK_MAJOR_VERSION {
257 return Err(pyo3::exceptions::PyValueError::new_err(format!(
258 "unsupported DLPack major version {}.{} (this build supports major version {})",
259 version.major, version.minor, DLPACK_MAJOR_VERSION
260 )));
261 }
262
263 validate_ndim(unsafe { managed.as_ref().dl_tensor.ndim })?;
267
268 let set_name_result = unsafe {
271 pyo3::ffi::PyCapsule_SetName(
272 capsule.as_ptr(),
273 DLPACK_VERSIONED_CAPSULE_NAME_USED.as_ptr(),
274 )
275 };
276 if set_name_result != 0 {
277 return Err(pyo3::exceptions::PyRuntimeError::new_err(
278 "Failed to mark DLPack capsule as consumed: PyCapsule_SetName failed",
279 ));
280 }
281
282 Ok(Self {
283 managed: ManagedPtr::Versioned(managed),
284 capsule: capsule.clone().unbind(),
285 })
286 }
287
288 pub fn device(&self) -> DLDevice {
290 self.dl_tensor().device
291 }
292
293 pub fn dtype(&self) -> DLDataType {
295 self.dl_tensor().dtype
296 }
297
298 pub fn ndim(&self) -> usize {
300 self.dl_tensor().ndim as usize
301 }
302
303 pub fn shape(&self) -> &[i64] {
307 let tensor = self.dl_tensor();
308 if tensor.shape.is_null() {
309 &[]
310 } else {
311 unsafe { std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize) }
312 }
313 }
314
315 pub fn strides(&self) -> Option<&[i64]> {
320 let tensor = self.dl_tensor();
321 if tensor.strides.is_null() {
322 None
323 } else {
324 Some(unsafe { std::slice::from_raw_parts(tensor.strides, tensor.ndim as usize) })
325 }
326 }
327
328 pub fn is_contiguous(&self) -> bool {
330 match self.strides() {
331 None => true,
332 Some(strides) => {
333 let shape = self.shape();
334 if shape.is_empty() {
335 return true;
336 }
337
338 let mut expected_stride = 1i64;
339 for i in (0..shape.len()).rev() {
340 if strides[i] != expected_stride {
341 return false;
342 }
343 expected_stride *= shape[i];
344 }
345 true
346 }
347 }
348 }
349
350 pub fn data_ptr(&self) -> *mut c_void {
357 let tensor = self.dl_tensor();
358 (tensor.data as *mut u8).wrapping_add(tensor.byte_offset as usize) as *mut c_void
362 }
363
364 pub fn data_ptr_raw(&self) -> *mut c_void {
366 self.dl_tensor().data
367 }
368
369 pub fn byte_offset(&self) -> u64 {
371 self.dl_tensor().byte_offset
372 }
373
374 pub fn numel(&self) -> usize {
376 self.shape().iter().map(|&d| d as usize).product()
377 }
378
379 pub fn itemsize(&self) -> usize {
381 self.dtype().itemsize()
382 }
383
384 pub fn nbytes(&self) -> usize {
386 self.numel() * self.itemsize()
387 }
388
389 pub fn is_read_only(&self) -> bool {
394 match self.managed {
395 ManagedPtr::Unversioned(_) => false,
396 ManagedPtr::Versioned(p) => unsafe {
397 p.as_ref().flags & DLPACK_FLAG_BITMASK_READ_ONLY != 0
398 },
399 }
400 }
401}
402
403impl Drop for PyTensor {
404 fn drop(&mut self) {
405 unsafe { self.managed.run_deleter() }
407 }
408}
409
410impl std::fmt::Debug for PyTensor {
411 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
412 f.debug_struct("PyTensor")
413 .field("shape", &self.shape())
414 .field("strides", &self.strides())
415 .field("dtype", &self.dtype())
416 .field("device", &self.device())
417 .field("byte_offset", &self.byte_offset())
418 .finish()
419 }
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ffi::{cpu_device, cuda_device, dtype_f32, dtype_f64, DLTensor};
    use pyo3::Python;
    use std::ffi::CString;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // ----------------------------------------------------------------------------------
    // Fixtures
    // ----------------------------------------------------------------------------------

    /// Raw `DLManagedTensor` pointer wrapped so it can be a `PyCapsule` payload (which
    /// must be `Send`).
    #[repr(transparent)]
    struct SendableTestPtr(*mut DLManagedTensor);
    // SAFETY: the tests only store the pointer in a capsule and read it back on the same
    // thread; it is never dereferenced elsewhere.
    unsafe impl Send for SendableTestPtr {}

    /// A `DLManagedTensor` plus the buffers its pointers refer to, all owned by the test.
    ///
    /// The boxed struct and the `Vec` heap buffers never move, so the raw pointers in
    /// `managed.dl_tensor` stay valid for as long as this value lives.
    struct TestManagedTensor {
        managed: Box<DLManagedTensor>,
        shape: Vec<i64>,
        strides: Option<Vec<i64>>,
        // Only kept alive for `managed.dl_tensor.data`.
        #[allow(dead_code)]
        data: Vec<u8>,
    }

    impl TestManagedTensor {
        /// Builds a zero-filled tensor with the given layout and no deleter.
        fn new(
            shape: Vec<i64>,
            strides: Option<Vec<i64>>,
            dtype: DLDataType,
            device: DLDevice,
        ) -> Self {
            let numel: usize = shape.iter().map(|&d| d as usize).product();
            let data = vec![0u8; numel.max(1) * dtype.itemsize()];

            let mut result = Self {
                managed: Box::new(DLManagedTensor {
                    dl_tensor: DLTensor {
                        data: std::ptr::null_mut(),
                        device,
                        ndim: shape.len() as i32,
                        dtype,
                        shape: std::ptr::null_mut(),
                        strides: std::ptr::null_mut(),
                        byte_offset: 0,
                    },
                    manager_ctx: std::ptr::null_mut(),
                    deleter: None,
                }),
                shape,
                strides,
                data,
            };

            // Point the tensor at buffers owned by `result`; their heap storage does not
            // move when `result` is returned.
            result.managed.dl_tensor.data = result.data.as_ptr() as *mut c_void;
            result.managed.dl_tensor.shape = result.shape.as_mut_ptr();
            if let Some(ref mut s) = result.strides {
                result.managed.dl_tensor.strides = s.as_mut_ptr();
            }

            result
        }

        fn with_byte_offset(mut self, offset: u64) -> Self {
            self.managed.dl_tensor.byte_offset = offset;
            self
        }

        /// Runs `f` against a `PyTensor` view of this tensor.
        ///
        /// The managed tensor has no deleter, so dropping the view releases nothing. The
        /// placeholder capsule only fills the `capsule` field. Scoping the view to the
        /// closure guarantees it never outlives `self`.
        fn with_pytensor<R>(&self, py: Python<'_>, f: impl FnOnce(&PyTensor) -> R) -> R {
            let placeholder = PyCapsule::new(py, 0u8, Some(CString::new("placeholder").unwrap()))
                .expect("Failed to create capsule");
            let view = PyTensor {
                managed: ManagedPtr::Unversioned(NonNull::from(&*self.managed)),
                capsule: placeholder.unbind(),
            };
            f(&view)
        }
    }

    /// Number of times `test_deleter` has run. Only `test_pytensor_drop_calls_deleter`
    /// registers that deleter.
    static DELETER_CALL_COUNT: AtomicUsize = AtomicUsize::new(0);

    /// Buffers backing a tensor built by `create_test_capsule`.
    struct TestTensorContext {
        data: Vec<f32>,
        shape: Vec<i64>,
        strides: Option<Vec<i64>>,
    }

    /// Leaks `ctx` and a `DLManagedTensor` pointing into it, then wraps the tensor pointer
    /// in a capsule named "dltensor".
    ///
    /// The capsule only stores a copy of the pointer. The allocations are released either
    /// by `test_deleter` (when `with_deleter` is set) or by `free_test_managed`.
    fn create_test_capsule(
        py: Python<'_>,
        ctx: Box<TestTensorContext>,
        device: DLDevice,
        dtype: DLDataType,
        byte_offset: u64,
        with_deleter: bool,
    ) -> PyResult<Bound<'_, PyCapsule>> {
        let ctx_ptr = Box::into_raw(ctx);
        // SAFETY: `ctx_ptr` was just produced by `Box::into_raw` and is uniquely owned.
        let ctx_ref = unsafe { &mut *ctx_ptr };

        let managed = Box::new(DLManagedTensor {
            dl_tensor: DLTensor {
                data: ctx_ref.data.as_ptr() as *mut c_void,
                device,
                ndim: ctx_ref.shape.len() as i32,
                dtype,
                shape: ctx_ref.shape.as_mut_ptr(),
                strides: ctx_ref
                    .strides
                    .as_mut()
                    .map(|s| s.as_mut_ptr())
                    .unwrap_or(std::ptr::null_mut()),
                byte_offset,
            },
            manager_ctx: ctx_ptr as *mut c_void,
            deleter: if with_deleter {
                Some(test_deleter)
            } else {
                None
            },
        });

        let wrapper = SendableTestPtr(Box::into_raw(managed));
        let name = CString::new("dltensor").unwrap();
        PyCapsule::new(py, wrapper, Some(name))
    }

    /// Frees a `DLManagedTensor` built by `create_test_capsule` together with its context.
    ///
    /// # Safety
    ///
    /// `managed_ptr` must come from `create_test_capsule` and must not be freed already.
    unsafe fn free_test_managed(managed_ptr: *mut DLManagedTensor) {
        // SAFETY: guaranteed by the caller.
        let managed = unsafe { Box::from_raw(managed_ptr) };
        if !managed.manager_ctx.is_null() {
            // SAFETY: `manager_ctx` is the `Box::into_raw` of the tensor's context.
            drop(unsafe { Box::from_raw(managed.manager_ctx as *mut TestTensorContext) });
        }
    }

    /// Counting deleter registered by `create_test_capsule(.., with_deleter: true)`.
    unsafe extern "C" fn test_deleter(managed_ptr: *mut DLManagedTensor) {
        if !managed_ptr.is_null() {
            DELETER_CALL_COUNT.fetch_add(1, Ordering::SeqCst);
            // SAFETY: the deleter receives the pointer it was registered on, which came
            // from `create_test_capsule`.
            unsafe { free_test_managed(managed_ptr) };
        }
    }

    /// Reads back the `DLManagedTensor` pointer stored in a `create_test_capsule` capsule.
    fn managed_ptr_of(capsule: &Bound<'_, PyCapsule>) -> *mut DLManagedTensor {
        let ptr = capsule
            .pointer_checked(Some(DLPACK_CAPSULE_NAME))
            .expect("Failed to get pointer");
        // SAFETY: the payload is a `SendableTestPtr`, a transparent `*mut DLManagedTensor`.
        unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) }
    }

    /// Builds a `PyTensor` over a deleter-less `create_test_capsule` capsule and runs `f`
    /// on it. Afterwards it drops the tensor and frees the allocations, instead of
    /// leaking them.
    fn with_capsule_tensor<R>(
        capsule: &Bound<'_, PyCapsule>,
        f: impl FnOnce(&PyTensor) -> R,
    ) -> R {
        let managed_ptr = managed_ptr_of(capsule);
        // SAFETY: the pointer is live until `free_test_managed` below.
        assert!(
            unsafe { (*managed_ptr).deleter.is_none() },
            "with_capsule_tensor requires a capsule created without a deleter"
        );
        let pytensor = PyTensor {
            managed: ManagedPtr::Unversioned(NonNull::new(managed_ptr).expect("Null pointer")),
            capsule: capsule.clone().unbind(),
        };
        let result = f(&pytensor);
        drop(pytensor);
        // SAFETY: no deleter ran, nothing references the allocations any more, and they
        // were created by `create_test_capsule`.
        unsafe { free_test_managed(managed_ptr) };
        result
    }

    /// Builds an unconsumed versioned capsule declaring `version` and asserts that
    /// `PyTensor::from_capsule` rejects it.
    fn assert_versioned_capsule_rejected(version: crate::ffi::DLPackVersion) {
        Python::attach(|py| {
            let mut shape = vec![1i64];
            let data = vec![0.0f32];
            let managed = Box::new(DLManagedTensorVersioned {
                version,
                manager_ctx: std::ptr::null_mut(),
                deleter: None,
                flags: 0,
                dl_tensor: DLTensor {
                    data: data.as_ptr() as *mut c_void,
                    device: cpu_device(),
                    ndim: 1,
                    dtype: dtype_f32(),
                    shape: shape.as_mut_ptr(),
                    strides: std::ptr::null_mut(),
                    byte_offset: 0,
                },
            });
            let managed_ptr = Box::into_raw(managed);
            // SAFETY: the name is a `'static` C string; with no destructor the capsule
            // never frees `managed_ptr`.
            let capsule_ptr = unsafe {
                pyo3::ffi::PyCapsule_New(
                    managed_ptr as *mut c_void,
                    c"dltensor_versioned".as_ptr(),
                    None,
                )
            };
            assert!(!capsule_ptr.is_null());
            // SAFETY: `PyCapsule_New` returned a new, non-null reference.
            let capsule: Bound<'_, PyCapsule> = unsafe { Bound::from_owned_ptr(py, capsule_ptr) }
                .cast_into()
                .unwrap();

            assert!(PyTensor::from_capsule(&capsule).is_err());

            // SAFETY: the rejected capsule was not consumed and owns nothing.
            unsafe { drop(Box::from_raw(managed_ptr)) };
            drop(shape);
            drop(data);
        });
    }

    // ----------------------------------------------------------------------------------
    // Layout accessors on a directly constructed tensor
    // ----------------------------------------------------------------------------------

    #[test]
    fn test_is_contiguous_no_strides() {
        let tensor = TestManagedTensor::new(vec![2, 3, 4], None, dtype_f32(), cpu_device());
        Python::attach(|py| {
            tensor.with_pytensor(py, |t| {
                assert!(t.strides().is_none());
                assert!(t.is_contiguous());
            });
        });
    }

    #[test]
    fn test_is_contiguous_with_contiguous_strides() {
        let tensor = TestManagedTensor::new(
            vec![2, 3, 4],
            Some(vec![12, 4, 1]),
            dtype_f32(),
            cpu_device(),
        );
        Python::attach(|py| {
            tensor.with_pytensor(py, |t| {
                assert_eq!(t.strides(), Some(&[12i64, 4, 1][..]));
                assert!(t.is_contiguous());
            });
        });
    }

    #[test]
    fn test_is_contiguous_with_non_contiguous_strides() {
        let tensor = TestManagedTensor::new(
            vec![2, 3, 4],
            Some(vec![1, 2, 6]),
            dtype_f32(),
            cpu_device(),
        );
        Python::attach(|py| {
            tensor.with_pytensor(py, |t| assert!(!t.is_contiguous()));
        });
    }

    #[test]
    fn test_is_contiguous_empty_tensor() {
        let tensor = TestManagedTensor::new(vec![], None, dtype_f32(), cpu_device());
        Python::attach(|py| {
            tensor.with_pytensor(py, |t| {
                assert!(t.shape().is_empty());
                assert_eq!(t.ndim(), 0);
                assert!(t.is_contiguous());
            });
        });
    }

    #[test]
    fn test_is_contiguous_1d() {
        let tensor = TestManagedTensor::new(vec![10], Some(vec![1]), dtype_f32(), cpu_device());
        Python::attach(|py| {
            tensor.with_pytensor(py, |t| {
                assert_eq!(t.strides(), Some(&[1i64][..]));
                assert!(t.is_contiguous());
            });
        });
    }

    #[test]
    fn test_numel_calculation() {
        let cases: Vec<(Vec<i64>, usize)> = vec![
            (vec![], 1), // a scalar holds exactly one element
            (vec![5], 5),
            (vec![2, 3], 6),
            (vec![2, 3, 4], 24),
            (vec![1, 1, 1, 1], 1),
            (vec![10, 20, 30], 6000),
        ];

        Python::attach(|py| {
            for (shape, expected) in cases {
                let tensor =
                    TestManagedTensor::new(shape.clone(), None, dtype_f32(), cpu_device());
                let numel = tensor.with_pytensor(py, |t| t.numel());
                assert_eq!(numel, expected, "Failed for shape {:?}", shape);
            }
        });
    }

    #[test]
    fn test_nbytes_calculation() {
        let f32_tensor = TestManagedTensor::new(vec![2, 3, 4], None, dtype_f32(), cpu_device());
        let f64_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f64(), cpu_device());
        Python::attach(|py| {
            f32_tensor.with_pytensor(py, |t| {
                assert_eq!(t.itemsize(), 4);
                assert_eq!(t.nbytes(), 96);
            });
            f64_tensor.with_pytensor(py, |t| {
                assert_eq!(t.itemsize(), 8);
                assert_eq!(t.nbytes(), 48);
            });
        });
    }

    #[test]
    fn test_data_ptr_with_offset() {
        let tensor =
            TestManagedTensor::new(vec![10], None, dtype_f32(), cpu_device()).with_byte_offset(16);
        Python::attach(|py| {
            tensor.with_pytensor(py, |t| {
                assert_eq!(t.byte_offset(), 16);
                assert!(!t.data_ptr_raw().is_null());
                assert_eq!(t.data_ptr() as usize, t.data_ptr_raw() as usize + 16);
            });
        });
    }

    #[test]
    fn test_data_ptr_no_offset() {
        let tensor = TestManagedTensor::new(vec![10], None, dtype_f32(), cpu_device());
        Python::attach(|py| {
            tensor.with_pytensor(py, |t| {
                assert_eq!(t.byte_offset(), 0);
                assert_eq!(t.data_ptr(), t.data_ptr_raw());
            });
        });
    }

    #[test]
    fn test_device_accessor() {
        let cpu_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f32(), cpu_device());
        let cuda_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f32(), cuda_device(1));
        Python::attach(|py| {
            cpu_tensor.with_pytensor(py, |t| assert!(t.device().is_cpu()));
            cuda_tensor.with_pytensor(py, |t| {
                assert!(t.device().is_cuda());
                assert_eq!(t.device().device_id, 1);
            });
        });
    }

    #[test]
    fn test_dtype_accessor() {
        let f32_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f32(), cpu_device());
        let f64_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f64(), cpu_device());
        Python::attach(|py| {
            f32_tensor.with_pytensor(py, |t| assert!(t.dtype().is_f32()));
            f64_tensor.with_pytensor(py, |t| assert!(t.dtype().is_f64()));
        });
    }

    #[test]
    fn test_ndim() {
        let shapes: Vec<Vec<i64>> = vec![
            vec![],
            vec![5],
            vec![2, 3],
            vec![2, 3, 4],
            vec![1, 2, 3, 4, 5],
        ];

        Python::attach(|py| {
            for shape in shapes {
                let tensor =
                    TestManagedTensor::new(shape.clone(), None, dtype_f32(), cpu_device());
                tensor.with_pytensor(py, |t| {
                    assert_eq!(t.ndim(), shape.len());
                    assert_eq!(t.shape(), shape.as_slice());
                });
            }
        });
    }

    #[test]
    fn test_shape_accessor() {
        let tensor = TestManagedTensor::new(vec![2i64, 3, 4], None, dtype_f32(), cpu_device());
        Python::attach(|py| {
            tensor.with_pytensor(py, |t| assert_eq!(t.shape(), &[2, 3, 4]));
        });
    }

    // ----------------------------------------------------------------------------------
    // Capsule plumbing
    // ----------------------------------------------------------------------------------

    #[test]
    fn test_capsule_creation_and_extraction() {
        Python::attach(|py| {
            let mut shape = vec![2i64, 3];
            let data = [0u8; 24].to_vec();
            let managed = Box::new(DLManagedTensor {
                dl_tensor: DLTensor {
                    data: data.as_ptr() as *mut c_void,
                    device: cpu_device(),
                    ndim: 2,
                    dtype: dtype_f32(),
                    shape: shape.as_mut_ptr(),
                    strides: std::ptr::null_mut(),
                    byte_offset: 0,
                },
                manager_ctx: std::ptr::null_mut(),
                deleter: None,
            });

            let managed_ptr = Box::into_raw(managed);
            let name = CString::new("dltensor").unwrap();
            let capsule = PyCapsule::new(py, SendableTestPtr(managed_ptr), Some(name))
                .expect("Failed to create capsule");

            let capsule_name = capsule.name().expect("Failed to get name");
            assert!(capsule_name.is_some());

            capsule
                .pointer_checked(Some(DLPACK_CAPSULE_NAME))
                .expect("Failed to extract pointer");

            // SAFETY: dropping the `SendableTestPtr` payload frees nothing, so this is
            // the sole release of the box.
            unsafe { drop(Box::from_raw(managed_ptr)) };
        });
    }

    #[test]
    fn test_capsule_wrong_name() {
        // `i32` payload is automatically `Send`; the field is never read back.
        #[allow(dead_code)]
        struct TestData(i32);

        Python::attach(|py| {
            let name = CString::new("wrong_name").unwrap();
            let capsule =
                PyCapsule::new(py, TestData(42), Some(name)).expect("Failed to create capsule");

            assert!(capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).is_err());
        });
    }

    #[test]
    fn test_pytensor_send() {
        fn assert_send<T: Send>() {}
        assert_send::<PyTensor>();
    }

    // ----------------------------------------------------------------------------------
    // PyTensor over capsule-backed tensors
    // ----------------------------------------------------------------------------------

    #[test]
    fn test_pytensor_all_accessors() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                shape: vec![2, 3],
                strides: None,
            });
            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            with_capsule_tensor(&capsule, |pytensor| {
                assert!(pytensor.device().is_cpu());
                assert!(pytensor.dtype().is_f32());
                assert_eq!(pytensor.ndim(), 2);
                assert_eq!(pytensor.shape(), &[2, 3]);
                assert!(pytensor.strides().is_none());
                assert!(pytensor.is_contiguous());
                assert!(!pytensor.data_ptr().is_null());
                assert!(!pytensor.data_ptr_raw().is_null());
                assert_eq!(pytensor.byte_offset(), 0);
                assert_eq!(pytensor.numel(), 6);
                assert_eq!(pytensor.itemsize(), 4);
                assert_eq!(pytensor.nbytes(), 24);

                let debug = format!("{:?}", pytensor);
                assert!(debug.contains("PyTensor"));
                assert!(debug.contains("shape"));
                assert!(debug.contains("dtype"));
                assert!(debug.contains("device"));
            });
        });
    }

    #[test]
    fn test_pytensor_with_strides_contiguous() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0; 24],
                shape: vec![2, 3, 4],
                strides: Some(vec![12, 4, 1]),
            });
            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            with_capsule_tensor(&capsule, |pytensor| {
                assert_eq!(pytensor.ndim(), 3);
                assert_eq!(pytensor.shape(), &[2, 3, 4]);
                assert_eq!(pytensor.strides(), Some(&[12i64, 4, 1][..]));
                assert!(pytensor.is_contiguous());
                assert_eq!(pytensor.numel(), 24);
            });
        });
    }

    #[test]
    fn test_pytensor_non_contiguous() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0; 6],
                shape: vec![2, 3],
                strides: Some(vec![1, 2]),
            });
            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            with_capsule_tensor(&capsule, |pytensor| {
                assert!(!pytensor.is_contiguous());
                assert_eq!(pytensor.strides(), Some(&[1i64, 2][..]));
            });
        });
    }

    #[test]
    fn test_pytensor_scalar() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![42.0],
                shape: vec![],
                strides: None,
            });
            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            with_capsule_tensor(&capsule, |pytensor| {
                assert_eq!(pytensor.ndim(), 0);
                assert!(pytensor.shape().is_empty());
                assert!(pytensor.is_contiguous());
                assert_eq!(pytensor.numel(), 1);
            });
        });
    }

    #[test]
    fn test_pytensor_with_byte_offset() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0; 20],
                shape: vec![10],
                strides: None,
            });
            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 16, false)
                .expect("Failed to create capsule");

            with_capsule_tensor(&capsule, |pytensor| {
                assert_eq!(pytensor.byte_offset(), 16);
                let raw = pytensor.data_ptr_raw() as usize;
                let adjusted = pytensor.data_ptr() as usize;
                assert_eq!(adjusted, raw + 16);
            });
        });
    }

    #[test]
    fn test_pytensor_cuda_device() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0; 512],
                shape: vec![16, 32],
                strides: None,
            });
            let capsule = create_test_capsule(py, ctx, cuda_device(1), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            with_capsule_tensor(&capsule, |pytensor| {
                assert!(pytensor.device().is_cuda());
                assert_eq!(pytensor.device().device_id, 1);
            });
        });
    }

    #[test]
    fn test_pytensor_f64_dtype() {
        Python::attach(|py| {
            // Six f32 slots give the 24 bytes three f64 elements need.
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0; 6],
                shape: vec![3],
                strides: None,
            });
            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f64(), 0, false)
                .expect("Failed to create capsule");

            with_capsule_tensor(&capsule, |pytensor| {
                assert!(pytensor.dtype().is_f64());
                assert_eq!(pytensor.itemsize(), 8);
                assert_eq!(pytensor.nbytes(), 24);
            });
        });
    }

    #[test]
    fn test_pytensor_empty_strides_scalar() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0],
                shape: vec![],
                strides: Some(vec![]),
            });
            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            with_capsule_tensor(&capsule, |pytensor| {
                assert!(pytensor.is_contiguous());
                assert!(pytensor.strides().is_some());
                assert!(pytensor.strides().unwrap().is_empty());
            });
        });
    }

    #[test]
    fn test_pytensor_drop_calls_deleter() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0, 2.0, 3.0],
                shape: vec![3],
                strides: None,
            });
            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, true)
                .expect("Failed to create capsule");

            let before = DELETER_CALL_COUNT.load(Ordering::SeqCst);
            let pytensor = PyTensor {
                managed: ManagedPtr::Unversioned(
                    NonNull::new(managed_ptr_of(&capsule)).expect("Null pointer"),
                ),
                capsule: capsule.clone().unbind(),
            };
            assert_eq!(DELETER_CALL_COUNT.load(Ordering::SeqCst), before);

            // The deleter frees the managed tensor and its context; no manual cleanup.
            drop(pytensor);
            assert_eq!(DELETER_CALL_COUNT.load(Ordering::SeqCst), before + 1);
        });
    }

    #[test]
    fn test_pytensor_drop_no_deleter() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0],
                shape: vec![1],
                strides: None,
            });
            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            // The helper drops the PyTensor and then frees the allocations itself; that
            // free would be a double free if `Drop` had released anything.
            with_capsule_tensor(&capsule, |_| {});
        });
    }

    // ----------------------------------------------------------------------------------
    // Round trips through the crate's exporter and from_capsule validation
    // ----------------------------------------------------------------------------------

    struct RoundTripTensor {
        data: Vec<f32>,
        shape: Vec<i64>,
    }

    impl crate::IntoDLPack for RoundTripTensor {
        fn tensor_info(&self) -> crate::TensorInfo {
            crate::TensorInfo::contiguous(
                self.data.as_ptr() as *mut c_void,
                cpu_device(),
                dtype_f32(),
                self.shape.clone(),
            )
        }
    }

    #[test]
    fn test_roundtrip_versioned_readonly() {
        use crate::IntoDLPack;
        Python::attach(|py| {
            let t = RoundTripTensor {
                data: vec![1.0, 2.0, 3.0, 4.0],
                shape: vec![2, 2],
            };
            let capsule_obj = t.into_dlpack_readonly(py).unwrap();
            let bound = capsule_obj.into_bound(py);
            let capsule: Bound<'_, PyCapsule> = bound.cast_into().unwrap();

            let tensor = PyTensor::from_capsule(&capsule).unwrap();
            assert!(tensor.is_read_only());
            assert_eq!(tensor.shape(), &[2, 2]);
            assert!(tensor.device().is_cpu());
            assert!(tensor.dtype().is_f32());

            // The capsule was marked as consumed, so it cannot be imported twice.
            assert!(PyTensor::from_capsule(&capsule).is_err());
        });
    }

    #[test]
    fn test_roundtrip_unversioned_not_readonly() {
        use crate::IntoDLPack;
        Python::attach(|py| {
            let t = RoundTripTensor {
                data: vec![1.0, 2.0, 3.0, 4.0],
                shape: vec![2, 2],
            };
            let capsule_obj = t.into_dlpack(py).unwrap();
            let bound = capsule_obj.into_bound(py);
            let capsule: Bound<'_, PyCapsule> = bound.cast_into().unwrap();

            let tensor = PyTensor::from_capsule(&capsule).unwrap();
            assert!(!tensor.is_read_only());
            assert_eq!(tensor.shape(), &[2, 2]);

            // The capsule was marked as consumed, so it cannot be imported twice.
            assert!(PyTensor::from_capsule(&capsule).is_err());
        });
    }

    #[test]
    fn test_from_capsule_rejects_unknown_name() {
        Python::attach(|py| {
            let dummy_ptr = Box::into_raw(Box::new(0u8));
            // SAFETY: the name is a `'static` C string; with no destructor the capsule
            // never frees `dummy_ptr`.
            let capsule_ptr = unsafe {
                pyo3::ffi::PyCapsule_New(
                    dummy_ptr as *mut c_void,
                    c"not_a_dlpack_capsule".as_ptr(),
                    None,
                )
            };
            assert!(!capsule_ptr.is_null());
            // SAFETY: `PyCapsule_New` returned a new, non-null reference.
            let capsule: Bound<'_, PyCapsule> = unsafe { Bound::from_owned_ptr(py, capsule_ptr) }
                .cast_into()
                .unwrap();

            assert!(PyTensor::from_capsule(&capsule).is_err());

            // SAFETY: the capsule owns nothing; this is the only release of the box.
            unsafe { drop(Box::from_raw(dummy_ptr)) };
        });
    }

    #[test]
    fn test_versioned_rejects_too_new_major() {
        assert_versioned_capsule_rejected(crate::ffi::DLPackVersion {
            major: DLPACK_MAJOR_VERSION + 1,
            minor: 0,
        });
    }

    #[test]
    fn test_versioned_rejects_mismatched_lower_major() {
        assert_versioned_capsule_rejected(crate::ffi::DLPackVersion {
            major: DLPACK_MAJOR_VERSION - 1,
            minor: 0,
        });
    }
}