use crate::ffi::{DLDataType, DLDevice, DLManagedTensor};
use crate::DLPACK_CAPSULE_NAME;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;
use std::ffi::{c_char, c_void};
use std::ptr::NonNull;

static USED_DLTENSOR_NAME: &[u8] = b"used_dltensor\0";

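/// A tensor received from Python via the DLPack protocol.
///
/// Wraps the `DLManagedTensor` extracted from a `"dltensor"` capsule and
/// keeps the originating capsule alive while the tensor is in use. The
/// producer's `deleter` is invoked when the `PyTensor` is dropped.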
pub struct PyTensor {
    managed: NonNull<DLManagedTensor>,
    #[allow(dead_code)]
    capsule: Py<PyCapsule>,
}

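// SAFETY: `Py<PyCapsule>` is `Send`, and the raw `DLManagedTensor` pointer is
// only released through the producer-supplied deleter in `Drop`, so moving
// the wrapper between threads does not introduce aliasing on its own.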
unsafe impl Send for PyTensor {}

impl PyTensor {
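    /// Imports a tensor from any Python object that implements `__dlpack__`.
    ///
    /// A minimal usage sketch (hypothetical; assumes `torch` is installed
    /// and is not compiled as a doctest):
    ///
    /// ```ignore
    /// Python::attach(|py| -> PyResult<()> {
    ///     let torch = py.import("torch")?;
    ///     let t = torch.call_method1("ones", ((2, 3),))?;
    ///     let tensor = PyTensor::from_pyany(py, &t)?;
    ///     assert_eq!(tensor.shape(), &[2, 3]);
    ///     Ok(())
    /// })
    /// .unwrap();
    /// ```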
    pub fn from_pyany(_py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult<Self> {
        let capsule_obj = obj.call_method0("__dlpack__")?;
        let capsule: Bound<'_, PyCapsule> = capsule_obj.cast_into().map_err(|e| {
            pyo3::exceptions::PyTypeError::new_err(format!(
                "__dlpack__ did not return a PyCapsule: {:?}",
                e.into_inner()
            ))
        })?;
        Self::from_capsule(&capsule)
    }

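    /// Takes ownership of a `"dltensor"` capsule produced by `__dlpack__`.
    ///
    /// Per the DLPack protocol, the capsule is renamed to `"used_dltensor"`
    /// so the producer will not run its destructor on it; this `PyTensor`
    /// becomes responsible for calling the deleter.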
    pub fn from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyResult<Self> {
        let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME))?;
        let managed = NonNull::new(ptr.as_ptr() as *mut DLManagedTensor).ok_or_else(|| {
            pyo3::exceptions::PyValueError::new_err("DLPack capsule contains null pointer")
        })?;

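        // Rename the capsule to "used_dltensor" so the producer's capsule
        // destructor knows ownership has been transferred to the consumer.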
        let set_name_result = unsafe {
            pyo3::ffi::PyCapsule_SetName(
                capsule.as_ptr(),
                USED_DLTENSOR_NAME.as_ptr() as *const c_char,
            )
        };
        if set_name_result != 0 {
            return Err(pyo3::exceptions::PyRuntimeError::new_err(
                "Failed to mark DLPack capsule as consumed: PyCapsule_SetName failed",
            ));
        }

        Ok(Self {
            managed,
            capsule: capsule.clone().unbind(),
        })
    }

    pub fn device(&self) -> DLDevice {
        unsafe { self.managed.as_ref().dl_tensor.device }
    }

    pub fn dtype(&self) -> DLDataType {
        unsafe { self.managed.as_ref().dl_tensor.dtype }
    }

    pub fn ndim(&self) -> usize {
        unsafe { self.managed.as_ref().dl_tensor.ndim as usize }
    }

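    /// The tensor's shape as a slice of dimension sizes.
    ///
    /// Returns an empty slice both for 0-dimensional (scalar) tensors and
    /// when the producer left the shape pointer null.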
    pub fn shape(&self) -> &[i64] {
        unsafe {
            let tensor = &self.managed.as_ref().dl_tensor;
            if tensor.shape.is_null() {
                &[]
            } else {
                std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize)
            }
        }
    }

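    /// Strides in elements (not bytes), or `None` if the producer omitted
    /// them; per DLPack, absent strides mean compact row-major layout.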
    pub fn strides(&self) -> Option<&[i64]> {
        unsafe {
            let tensor = &self.managed.as_ref().dl_tensor;
            if tensor.strides.is_null() {
                None
            } else {
                Some(std::slice::from_raw_parts(
                    tensor.strides,
                    tensor.ndim as usize,
                ))
            }
        }
    }

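    /// Whether the tensor is laid out as compact row-major (C-contiguous).
    ///
    /// Missing strides imply contiguity. Otherwise the strides are compared
    /// against the expected row-major strides, built right to left: for
    /// shape `[2, 3, 4]` the expected strides are `[12, 4, 1]`.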
    pub fn is_contiguous(&self) -> bool {
        match self.strides() {
            None => true,
            Some(strides) => {
                let shape = self.shape();
                if shape.is_empty() {
                    return true;
                }

                let mut expected_stride = 1i64;
                for i in (0..shape.len()).rev() {
                    if strides[i] != expected_stride {
                        return false;
                    }
                    expected_stride *= shape[i];
                }
                true
            }
        }
    }

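    /// Pointer to the first element, i.e. the base pointer advanced by
    /// `byte_offset`; use `data_ptr_raw` for the unadjusted base pointer.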
    pub fn data_ptr(&self) -> *mut c_void {
        unsafe {
            let tensor = &self.managed.as_ref().dl_tensor;
            (tensor.data as *mut u8).add(tensor.byte_offset as usize) as *mut c_void
        }
    }

    pub fn data_ptr_raw(&self) -> *mut c_void {
        unsafe { self.managed.as_ref().dl_tensor.data }
    }

    pub fn byte_offset(&self) -> u64 {
        unsafe { self.managed.as_ref().dl_tensor.byte_offset }
    }

    pub fn numel(&self) -> usize {
        self.shape().iter().map(|&d| d as usize).product()
    }

    pub fn itemsize(&self) -> usize {
        self.dtype().itemsize()
    }

    pub fn nbytes(&self) -> usize {
        self.numel() * self.itemsize()
    }
}

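// Dropping the consumer invokes the producer's deleter (when one was
// provided), releasing the tensor exactly once.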
impl Drop for PyTensor {
    fn drop(&mut self) {
        unsafe {
            let managed = self.managed.as_ref();
            if let Some(deleter) = managed.deleter {
                deleter(self.managed.as_ptr());
            }
        }
    }
}

impl std::fmt::Debug for PyTensor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PyTensor")
            .field("shape", &self.shape())
            .field("strides", &self.strides())
            .field("dtype", &self.dtype())
            .field("device", &self.device())
            .field("byte_offset", &self.byte_offset())
            .finish()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ffi::{cpu_device, cuda_device, dtype_f32, dtype_f64, DLTensor};
    use pyo3::Python;
    use std::ffi::CString;

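    /// `PyCapsule::new` requires a `Send` payload; this transparent wrapper
    /// lets tests store a raw `DLManagedTensor` pointer inside a capsule.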
    #[repr(transparent)]
    struct SendableTestPtr(*mut DLManagedTensor);
    unsafe impl Send for SendableTestPtr {}

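    /// Test fixture owning a `DLManagedTensor` together with the buffers its
    /// raw pointers reference (`shape`, `strides`, `data`), so those pointers
    /// stay valid for the fixture's lifetime.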
    struct TestManagedTensor {
        managed: Box<DLManagedTensor>,
        shape: Vec<i64>,
        strides: Option<Vec<i64>>,
        #[allow(dead_code)]
        data: Vec<u8>,
    }

    impl TestManagedTensor {
        fn new(
            shape: Vec<i64>,
            strides: Option<Vec<i64>>,
            dtype: DLDataType,
            device: DLDevice,
        ) -> Self {
            let numel: usize = shape.iter().map(|&d| d as usize).product();
            let data = vec![0u8; numel.max(1) * dtype.itemsize()];

            let mut result = Self {
                managed: Box::new(DLManagedTensor {
                    dl_tensor: DLTensor {
                        data: std::ptr::null_mut(),
                        device,
                        ndim: shape.len() as i32,
                        dtype,
                        shape: std::ptr::null_mut(),
                        strides: std::ptr::null_mut(),
                        byte_offset: 0,
                    },
                    manager_ctx: std::ptr::null_mut(),
                    deleter: None,
                }),
                shape,
                strides,
                data,
            };

            result.managed.dl_tensor.data = result.data.as_ptr() as *mut c_void;
            result.managed.dl_tensor.shape = result.shape.as_mut_ptr();
            if let Some(ref mut s) = result.strides {
                result.managed.dl_tensor.strides = s.as_mut_ptr();
            }

            result
        }

        fn with_byte_offset(mut self, offset: u64) -> Self {
            self.managed.dl_tensor.byte_offset = offset;
            self
        }

        fn as_ptr(&self) -> *mut DLManagedTensor {
            &*self.managed as *const _ as *mut _
        }
    }

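    // The tests below validate the fixture and mirror `PyTensor`'s stride,
    // numel, and offset arithmetic directly on raw `DLTensor` fields;
    // capsule-backed `PyTensor` construction is covered further down.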
    #[test]
    fn test_is_contiguous_no_strides() {
        let tensor = TestManagedTensor::new(vec![2, 3, 4], None, dtype_f32(), cpu_device());

        let managed = unsafe { &*tensor.as_ptr() };
        let strides_ptr = managed.dl_tensor.strides;

        assert!(strides_ptr.is_null());
    }

    #[test]
    fn test_is_contiguous_with_contiguous_strides() {
        let tensor = TestManagedTensor::new(
            vec![2, 3, 4],
            Some(vec![12, 4, 1]),
            dtype_f32(),
            cpu_device(),
        );

        let shape = &tensor.shape;
        let strides = tensor.strides.as_ref().unwrap();

        let mut expected_stride = 1i64;
        let mut is_contiguous = true;
        for i in (0..shape.len()).rev() {
            if strides[i] != expected_stride {
                is_contiguous = false;
                break;
            }
            expected_stride *= shape[i];
        }
        assert!(is_contiguous);
    }

    #[test]
    fn test_is_contiguous_with_non_contiguous_strides() {
        let tensor = TestManagedTensor::new(
            vec![2, 3, 4],
            Some(vec![1, 2, 6]),
            dtype_f32(),
            cpu_device(),
        );

        let shape = &tensor.shape;
        let strides = tensor.strides.as_ref().unwrap();

        let mut expected_stride = 1i64;
        let mut is_contiguous = true;
        for i in (0..shape.len()).rev() {
            if strides[i] != expected_stride {
                is_contiguous = false;
                break;
            }
            expected_stride *= shape[i];
        }
        assert!(!is_contiguous);
    }

    #[test]
    fn test_is_contiguous_empty_tensor() {
        let tensor = TestManagedTensor::new(vec![], None, dtype_f32(), cpu_device());
        assert!(tensor.shape.is_empty());
    }

    #[test]
    fn test_is_contiguous_1d() {
        let tensor = TestManagedTensor::new(vec![10], Some(vec![1]), dtype_f32(), cpu_device());
        let strides = tensor.strides.as_ref().unwrap();
        assert_eq!(strides[0], 1);
    }

    #[test]
    fn test_numel_calculation() {
        let shapes_and_expected: Vec<(Vec<i64>, usize)> = vec![
            (vec![], 1),
            (vec![5], 5),
            (vec![2, 3], 6),
            (vec![2, 3, 4], 24),
            (vec![1, 1, 1, 1], 1),
            (vec![10, 20, 30], 6000),
        ];

        for (shape, expected) in shapes_and_expected {
            let numel: usize = if shape.is_empty() {
                1
            } else {
                shape.iter().map(|&d| d as usize).product()
            };
            assert_eq!(numel, expected, "Failed for shape {:?}", shape);
        }
    }

    #[test]
    fn test_nbytes_calculation() {
        let tensor = TestManagedTensor::new(vec![2, 3, 4], None, dtype_f32(), cpu_device());
        let numel: usize = tensor.shape.iter().map(|&d| d as usize).product();
        let itemsize = dtype_f32().itemsize();
        assert_eq!(numel * itemsize, 96);

        let tensor2 = TestManagedTensor::new(vec![2, 3], None, dtype_f64(), cpu_device());
        let numel2: usize = tensor2.shape.iter().map(|&d| d as usize).product();
        let itemsize2 = dtype_f64().itemsize();
        assert_eq!(numel2 * itemsize2, 48);
    }

    #[test]
    fn test_data_ptr_with_offset() {
        let tensor =
            TestManagedTensor::new(vec![10], None, dtype_f32(), cpu_device()).with_byte_offset(16);

        let managed = unsafe { &*tensor.as_ptr() };
        let base_ptr = managed.dl_tensor.data as usize;
        let offset = managed.dl_tensor.byte_offset as usize;
        let adjusted_ptr = base_ptr + offset;

        assert_eq!(offset, 16);
        assert_eq!(adjusted_ptr, base_ptr + 16);
    }

    #[test]
    fn test_data_ptr_no_offset() {
        let tensor = TestManagedTensor::new(vec![10], None, dtype_f32(), cpu_device());

        let managed = unsafe { &*tensor.as_ptr() };
        assert_eq!(managed.dl_tensor.byte_offset, 0);
    }

    #[test]
    fn test_device_accessor() {
        let cpu_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f32(), cpu_device());
        let managed = unsafe { &*cpu_tensor.as_ptr() };
        assert!(managed.dl_tensor.device.is_cpu());

        let cuda_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f32(), cuda_device(1));
        let managed = unsafe { &*cuda_tensor.as_ptr() };
        assert!(managed.dl_tensor.device.is_cuda());
        assert_eq!(managed.dl_tensor.device.device_id, 1);
    }

    #[test]
    fn test_dtype_accessor() {
        let f32_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f32(), cpu_device());
        let managed = unsafe { &*f32_tensor.as_ptr() };
        assert!(managed.dl_tensor.dtype.is_f32());

        let f64_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f64(), cpu_device());
        let managed = unsafe { &*f64_tensor.as_ptr() };
        assert!(managed.dl_tensor.dtype.is_f64());
    }

    #[test]
    fn test_ndim() {
        let shapes: Vec<Vec<i64>> = vec![
            vec![],
            vec![5],
            vec![2, 3],
            vec![2, 3, 4],
            vec![1, 2, 3, 4, 5],
        ];

        for shape in shapes {
            let expected_ndim = shape.len();
            let tensor = TestManagedTensor::new(shape.clone(), None, dtype_f32(), cpu_device());
            let managed = unsafe { &*tensor.as_ptr() };
            assert_eq!(managed.dl_tensor.ndim as usize, expected_ndim);
        }
    }

    #[test]
    fn test_shape_accessor() {
        let shape = vec![2i64, 3, 4];
        let tensor = TestManagedTensor::new(shape.clone(), None, dtype_f32(), cpu_device());
        let managed = unsafe { &*tensor.as_ptr() };

        let shape_slice = unsafe {
            std::slice::from_raw_parts(managed.dl_tensor.shape, managed.dl_tensor.ndim as usize)
        };
        assert_eq!(shape_slice, &[2, 3, 4]);
    }

    #[test]
    fn test_capsule_creation_and_extraction() {
        Python::attach(|py| {
            let mut shape = vec![2i64, 3];
            let data = [0u8; 24].to_vec();
            let managed = Box::new(DLManagedTensor {
                dl_tensor: DLTensor {
                    data: data.as_ptr() as *mut c_void,
                    device: cpu_device(),
                    ndim: 2,
                    dtype: dtype_f32(),
                    shape: shape.as_mut_ptr(),
                    strides: std::ptr::null_mut(),
                    byte_offset: 0,
                },
                manager_ctx: std::ptr::null_mut(),
                deleter: None,
            });

            let managed_ptr = Box::into_raw(managed);
            let sendable = SendableTestPtr(managed_ptr);
            let name = CString::new("dltensor").unwrap();

            let capsule =
                PyCapsule::new(py, sendable, Some(name)).expect("Failed to create capsule");

            let capsule_name = capsule.name().expect("Failed to get name");
            assert!(capsule_name.is_some());

            let _extracted = capsule
                .pointer_checked(Some(DLPACK_CAPSULE_NAME))
                .expect("Failed to extract pointer");

            unsafe {
                let _ = Box::from_raw(managed_ptr);
            }
        });
    }

    #[test]
    fn test_capsule_wrong_name() {
        #[allow(dead_code)]
        struct TestData(i32);
        unsafe impl Send for TestData {}

        Python::attach(|py| {
            let data = TestData(42);
            let name = CString::new("wrong_name").unwrap();

            let capsule = PyCapsule::new(py, data, Some(name)).expect("Failed to create capsule");

            let result = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME));
            assert!(result.is_err());
        });
    }

    #[test]
    fn test_pytensor_send() {
        fn assert_send<T: Send>() {}
        assert_send::<PyTensor>();
    }

    use std::sync::atomic::{AtomicUsize, Ordering};

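    /// Counts invocations of `test_deleter` so tests can assert that dropping
    /// a `PyTensor` runs the producer's deleter exactly once.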
    static DELETER_CALL_COUNT: AtomicUsize = AtomicUsize::new(0);

    struct TestTensorContext {
        data: Vec<f32>,
        shape: Vec<i64>,
        strides: Option<Vec<i64>>,
    }

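    /// Builds a `"dltensor"` capsule around a heap-allocated context, the way
    /// a real DLPack producer would; `with_deleter` controls whether the
    /// `DLManagedTensor` carries `test_deleter`.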
    fn create_test_capsule(
        py: Python<'_>,
        ctx: Box<TestTensorContext>,
        device: DLDevice,
        dtype: DLDataType,
        byte_offset: u64,
        with_deleter: bool,
    ) -> PyResult<Bound<'_, PyCapsule>> {
        let ctx_ptr = Box::into_raw(ctx);

        unsafe {
            let ctx_ref = &mut *ctx_ptr;

            let managed = Box::new(DLManagedTensor {
                dl_tensor: DLTensor {
                    data: ctx_ref.data.as_ptr() as *mut c_void,
                    device,
                    ndim: ctx_ref.shape.len() as i32,
                    dtype,
                    shape: ctx_ref.shape.as_mut_ptr(),
                    strides: ctx_ref
                        .strides
                        .as_mut()
                        .map(|s| s.as_mut_ptr())
                        .unwrap_or(std::ptr::null_mut()),
                    byte_offset,
                },
                manager_ctx: ctx_ptr as *mut c_void,
                deleter: if with_deleter {
                    Some(test_deleter)
                } else {
                    None
                },
            });

            let managed_ptr = Box::into_raw(managed);
            let wrapper = SendableTestPtr(managed_ptr);
            let name = CString::new("dltensor").unwrap();

            PyCapsule::new(py, wrapper, Some(name))
        }
    }

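    /// Deleter matching the `DLManagedTensor` contract: frees the managed
    /// struct and its context, and bumps `DELETER_CALL_COUNT`.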
    unsafe extern "C" fn test_deleter(managed_ptr: *mut DLManagedTensor) {
        if !managed_ptr.is_null() {
            DELETER_CALL_COUNT.fetch_add(1, Ordering::SeqCst);
            let managed = Box::from_raw(managed_ptr);
            if !managed.manager_ctx.is_null() {
                let _ = Box::from_raw(managed.manager_ctx as *mut TestTensorContext);
            }
        }
    }

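    // The tests below build `PyTensor` by hand rather than via `from_capsule`:
    // `PyCapsule::new` boxes its payload, so the capsule pointer points *to*
    // the stored `SendableTestPtr` and must be dereferenced once more to reach
    // the `DLManagedTensor`, unlike capsules from real producers.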
    #[test]
    fn test_pytensor_all_accessors() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                shape: vec![2, 3],
                strides: None,
            });

            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            let ptr = capsule
                .pointer_checked(Some(DLPACK_CAPSULE_NAME))
                .expect("Failed to get pointer");
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).expect("Null pointer");

            let pytensor = PyTensor {
                managed,
                capsule: capsule.clone().unbind(),
            };

            assert!(pytensor.device().is_cpu());
            assert!(pytensor.dtype().is_f32());
            assert_eq!(pytensor.ndim(), 2);
            assert_eq!(pytensor.shape(), &[2, 3]);
            assert!(pytensor.strides().is_none());
            assert!(pytensor.is_contiguous());
            assert!(!pytensor.data_ptr().is_null());
            assert!(!pytensor.data_ptr_raw().is_null());
            assert_eq!(pytensor.byte_offset(), 0);
            assert_eq!(pytensor.numel(), 6);
            assert_eq!(pytensor.itemsize(), 4);
            assert_eq!(pytensor.nbytes(), 24);

            let debug = format!("{:?}", pytensor);
            assert!(debug.contains("PyTensor"));
            assert!(debug.contains("shape"));
            assert!(debug.contains("dtype"));
            assert!(debug.contains("device"));

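            // No deleter was installed, so the hand-built managed struct and
            // context are intentionally leaked; `forget` makes that explicit.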
            std::mem::forget(pytensor);
        });
    }

    #[test]
    fn test_pytensor_with_strides_contiguous() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0; 24],
                shape: vec![2, 3, 4],
                strides: Some(vec![12, 4, 1]),
            });

            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).unwrap();

            let pytensor = PyTensor {
                managed,
                capsule: capsule.clone().unbind(),
            };

            assert_eq!(pytensor.ndim(), 3);
            assert_eq!(pytensor.shape(), &[2, 3, 4]);
            assert_eq!(pytensor.strides(), Some(&[12i64, 4, 1][..]));
            assert!(pytensor.is_contiguous());
            assert_eq!(pytensor.numel(), 24);

            std::mem::forget(pytensor);
        });
    }

    #[test]
    fn test_pytensor_non_contiguous() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0; 6],
                shape: vec![2, 3],
                strides: Some(vec![1, 2]),
            });

            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).unwrap();

            let pytensor = PyTensor {
                managed,
                capsule: capsule.clone().unbind(),
            };

            assert!(!pytensor.is_contiguous());
            assert_eq!(pytensor.strides(), Some(&[1i64, 2][..]));

            std::mem::forget(pytensor);
        });
    }

    #[test]
    fn test_pytensor_scalar() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![42.0],
                shape: vec![],
                strides: None,
            });

            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).unwrap();

            let pytensor = PyTensor {
                managed,
                capsule: capsule.clone().unbind(),
            };

            assert_eq!(pytensor.ndim(), 0);
            assert!(pytensor.shape().is_empty());
            assert!(pytensor.is_contiguous());
            assert_eq!(pytensor.numel(), 1);

            std::mem::forget(pytensor);
        });
    }

    #[test]
    fn test_pytensor_with_byte_offset() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0; 20],
                shape: vec![10],
                strides: None,
            });

            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 16, false)
                .expect("Failed to create capsule");

            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).unwrap();

            let pytensor = PyTensor {
                managed,
                capsule: capsule.clone().unbind(),
            };

            assert_eq!(pytensor.byte_offset(), 16);
            let raw = pytensor.data_ptr_raw() as usize;
            let adjusted = pytensor.data_ptr() as usize;
            assert_eq!(adjusted, raw + 16);

            std::mem::forget(pytensor);
        });
    }

    #[test]
    fn test_pytensor_cuda_device() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0; 512],
                shape: vec![16, 32],
                strides: None,
            });

            let capsule = create_test_capsule(py, ctx, cuda_device(1), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).unwrap();

            let pytensor = PyTensor {
                managed,
                capsule: capsule.clone().unbind(),
            };

            assert!(pytensor.device().is_cuda());
            assert_eq!(pytensor.device().device_id, 1);

            std::mem::forget(pytensor);
        });
    }

    #[test]
    fn test_pytensor_f64_dtype() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0; 6],
                shape: vec![3],
                strides: None,
            });

            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f64(), 0, false)
                .expect("Failed to create capsule");

            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).unwrap();

            let pytensor = PyTensor {
                managed,
                capsule: capsule.clone().unbind(),
            };

            assert!(pytensor.dtype().is_f64());
            assert_eq!(pytensor.itemsize(), 8);
            assert_eq!(pytensor.nbytes(), 24);

            std::mem::forget(pytensor);
        });
    }

    #[test]
    fn test_pytensor_empty_strides_scalar() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0],
                shape: vec![],
                strides: Some(vec![]),
            });

            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).unwrap();

            let pytensor = PyTensor {
                managed,
                capsule: capsule.clone().unbind(),
            };

            assert!(pytensor.is_contiguous());
            assert!(pytensor.strides().is_some());
            assert!(pytensor.strides().unwrap().is_empty());

            std::mem::forget(pytensor);
        });
    }

    #[test]
    fn test_pytensor_drop_calls_deleter() {
        DELETER_CALL_COUNT.store(0, Ordering::SeqCst);

        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0, 2.0, 3.0],
                shape: vec![3],
                strides: None,
            });

            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, true)
                .expect("Failed to create capsule");

            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).unwrap();

            {
                let pytensor = PyTensor {
                    managed,
                    capsule: capsule.clone().unbind(),
                };

                assert_eq!(DELETER_CALL_COUNT.load(Ordering::SeqCst), 0);

                drop(pytensor);
            }

            assert_eq!(DELETER_CALL_COUNT.load(Ordering::SeqCst), 1);
        });
    }

    #[test]
    fn test_pytensor_drop_no_deleter() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0],
                shape: vec![1],
                strides: None,
            });

            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).unwrap();

            let pytensor = PyTensor {
                managed,
                capsule: capsule.clone().unbind(),
            };

            drop(pytensor);

            unsafe {
                let managed = Box::from_raw(managed_ptr);
                if !managed.manager_ctx.is_null() {
                    let _ = Box::from_raw(managed.manager_ctx as *mut TestTensorContext);
                }
            }
        });
    }
}