1//! Managed tensor for importing from Python via DLPack.
2//!
3//! This module provides `PyTensor`, a wrapper around a DLPack tensor
4//! received from Python that provides safe access to tensor metadata.
5
6use crate::ffi::{DLDataType, DLDevice, DLManagedTensor};
7use crate::DLPACK_CAPSULE_NAME;
8use pyo3::prelude::*;
9use pyo3::types::PyCapsule;
10use std::ffi::{c_char, c_void};
11use std::ptr::NonNull;
12
13/// The name for consumed DLPack capsules (per DLPack protocol).
14/// Using a static byte array with null terminator for C compatibility.
15/// This must remain valid for the lifetime of the program since PyCapsule_SetName
16/// stores the pointer directly without copying.
17static USED_DLTENSOR_NAME: &[u8] = b"used_dltensor\0";
18
19/// A tensor imported from Python via the DLPack protocol.
20///
21/// This type wraps a `DLManagedTensor` received from a Python object
22/// (typically a PyTorch, JAX, or NumPy tensor) and provides safe access
23/// to the tensor's metadata and data pointer.
24///
25/// # Lifetime
26///
27/// The tensor data is valid as long as this `PyTensor` is alive.
28/// When dropped, the tensor's deleter is called to notify the producer.
29///
30/// # Example
31///
32/// ```ignore
33/// use pyo3::prelude::*;
34/// use pyo3_dlpack::PyTensor;
35///
36/// #[pyfunction]
37/// fn process(py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult<()> {
38///     let tensor = PyTensor::from_pyany(py, obj)?;
39///
40///     println!("Shape: {:?}", tensor.shape());
41///     println!("Device: {:?}", tensor.device());
42///     println!("Dtype: {:?}", tensor.dtype());
43///
44///     if tensor.device().is_cpu() {
45///         // Safe to access data on CPU
46///         let ptr = tensor.data_ptr() as *const f32;
47///         // ...
48///     }
49///
50///     Ok(())
51/// }
52/// ```
pub struct PyTensor {
    /// Non-null pointer to the producer's `DLManagedTensor`, extracted from
    /// the DLPack capsule in `from_capsule`. Valid until the deleter runs.
    managed: NonNull<DLManagedTensor>,
    /// We store the capsule to prevent it from being garbage collected
    /// while we hold a reference to the managed tensor.
    #[allow(dead_code)]
    capsule: Py<PyCapsule>,
}
60
// Safety: `PyTensor` holds the only Rust-side handle to the DLManagedTensor,
// and the retained `Py<PyCapsule>` is itself Send.
// NOTE(review): DLPack does not formally guarantee that a producer's tensor
// metadata/deleter may be touched from a non-producing thread — confirm the
// targeted producers (NumPy/PyTorch/JAX) tolerate cross-thread drop.
unsafe impl Send for PyTensor {}
64
65impl PyTensor {
66    /// Create a PyTensor from a Python object that supports the DLPack protocol.
67    ///
68    /// This calls `__dlpack__()` on the object to get a DLPack capsule,
69    /// then extracts the tensor information.
70    ///
71    /// # Arguments
72    ///
73    /// * `py` - Python GIL token
74    /// * `obj` - A Python object that implements `__dlpack__()` (e.g., PyTorch tensor)
75    ///
76    /// # Errors
77    ///
78    /// Returns an error if:
79    /// - The object doesn't have a `__dlpack__` method
80    /// - The returned capsule is invalid
81    /// - The capsule doesn't contain a valid DLManagedTensor
82    pub fn from_pyany(_py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult<Self> {
83        // Call __dlpack__() to get the capsule
84        let capsule_obj = obj.call_method0("__dlpack__")?;
85        let capsule: Bound<'_, PyCapsule> = capsule_obj.cast_into().map_err(|e| {
86            pyo3::exceptions::PyTypeError::new_err(format!(
87                "__dlpack__ did not return a PyCapsule: {:?}",
88                e.into_inner()
89            ))
90        })?;
91        Self::from_capsule(&capsule)
92    }
93
94    /// Create a PyTensor directly from a DLPack PyCapsule.
95    ///
96    /// # Arguments
97    ///
98    /// * `capsule` - A PyCapsule containing a DLManagedTensor
99    ///
100    /// # Errors
101    ///
102    /// Returns an error if the capsule is invalid or has the wrong name.
103    pub fn from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyResult<Self> {
104        // Extract the pointer using pointer_checked which also validates the capsule name.
105        // This will fail if the capsule was already consumed (name is "used_dltensor").
106        let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME))?;
107        let managed = NonNull::new(ptr.as_ptr() as *mut DLManagedTensor).ok_or_else(|| {
108            pyo3::exceptions::PyValueError::new_err("DLPack capsule contains null pointer")
109        })?;
110
111        // Per DLPack protocol, rename the capsule to "used_dltensor" to indicate
112        // we have taken ownership. This prevents:
113        // 1. Multiple consumers from using the same capsule (second consumer
114        //    will fail the name check above)
115        // 2. The producer's capsule destructor from calling the deleter
116        //    (it checks for this name and skips the deleter call)
117        //
118        // We use a static string because PyCapsule_SetName stores the pointer
119        // directly without copying.
120        //
121        // SAFETY: We must check the return value. If PyCapsule_SetName fails:
122        // - Returns -1 and sets a Python exception
123        // - The capsule name remains "dltensor", enabling double-consume/double-free
124        let set_name_result = unsafe {
125            pyo3::ffi::PyCapsule_SetName(
126                capsule.as_ptr(),
127                USED_DLTENSOR_NAME.as_ptr() as *const c_char,
128            )
129        };
130        if set_name_result != 0 {
131            // PyCapsule_SetName failed (returns -1 on error)
132            // A Python exception is already set, convert it to PyErr
133            return Err(pyo3::exceptions::PyRuntimeError::new_err(
134                "Failed to mark DLPack capsule as consumed: PyCapsule_SetName failed",
135            ));
136        }
137
138        Ok(Self {
139            managed,
140            capsule: capsule.clone().unbind(),
141        })
142    }
143
144    /// Get the device where the tensor data resides.
145    pub fn device(&self) -> DLDevice {
146        unsafe { self.managed.as_ref().dl_tensor.device }
147    }
148
149    /// Get the data type of the tensor elements.
150    pub fn dtype(&self) -> DLDataType {
151        unsafe { self.managed.as_ref().dl_tensor.dtype }
152    }
153
154    /// Get the number of dimensions.
155    pub fn ndim(&self) -> usize {
156        unsafe { self.managed.as_ref().dl_tensor.ndim as usize }
157    }
158
159    /// Get the shape as a slice.
160    ///
161    /// The length of the slice equals `ndim()`.
162    pub fn shape(&self) -> &[i64] {
163        unsafe {
164            let tensor = &self.managed.as_ref().dl_tensor;
165            if tensor.shape.is_null() {
166                &[]
167            } else {
168                std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize)
169            }
170        }
171    }
172
173    /// Get the strides as a slice, or `None` for contiguous tensors.
174    ///
175    /// Strides are in number of elements (not bytes).
176    /// If `None`, the tensor is assumed to be in compact row-major order.
177    pub fn strides(&self) -> Option<&[i64]> {
178        unsafe {
179            let tensor = &self.managed.as_ref().dl_tensor;
180            if tensor.strides.is_null() {
181                None
182            } else {
183                Some(std::slice::from_raw_parts(
184                    tensor.strides,
185                    tensor.ndim as usize,
186                ))
187            }
188        }
189    }
190
191    /// Check if the tensor is contiguous in row-major (C) order.
192    pub fn is_contiguous(&self) -> bool {
193        match self.strides() {
194            None => true,
195            Some(strides) => {
196                let shape = self.shape();
197                if shape.is_empty() {
198                    return true;
199                }
200
201                let mut expected_stride = 1i64;
202                for i in (0..shape.len()).rev() {
203                    if strides[i] != expected_stride {
204                        return false;
205                    }
206                    expected_stride *= shape[i];
207                }
208                true
209            }
210        }
211    }
212
213    /// Get the raw data pointer.
214    ///
215    /// For GPU tensors, this is a device pointer that cannot be directly
216    /// dereferenced on the CPU.
217    ///
218    /// The pointer is adjusted by `byte_offset()`.
219    pub fn data_ptr(&self) -> *mut c_void {
220        unsafe {
221            let tensor = &self.managed.as_ref().dl_tensor;
222            (tensor.data as *mut u8).add(tensor.byte_offset as usize) as *mut c_void
223        }
224    }
225
226    /// Get the raw data pointer without byte offset adjustment.
227    pub fn data_ptr_raw(&self) -> *mut c_void {
228        unsafe { self.managed.as_ref().dl_tensor.data }
229    }
230
231    /// Get the byte offset from the raw data pointer.
232    pub fn byte_offset(&self) -> u64 {
233        unsafe { self.managed.as_ref().dl_tensor.byte_offset }
234    }
235
236    /// Get the total number of elements in the tensor.
237    pub fn numel(&self) -> usize {
238        self.shape().iter().map(|&d| d as usize).product()
239    }
240
241    /// Get the size of one element in bytes.
242    pub fn itemsize(&self) -> usize {
243        self.dtype().itemsize()
244    }
245
246    /// Get the total size of the tensor data in bytes.
247    pub fn nbytes(&self) -> usize {
248        self.numel() * self.itemsize()
249    }
250}
251
252impl Drop for PyTensor {
253    fn drop(&mut self) {
254        // Call the deleter if present
255        unsafe {
256            let managed = self.managed.as_ref();
257            if let Some(deleter) = managed.deleter {
258                deleter(self.managed.as_ptr());
259            }
260        }
261    }
262}
263
264impl std::fmt::Debug for PyTensor {
265    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
266        f.debug_struct("PyTensor")
267            .field("shape", &self.shape())
268            .field("strides", &self.strides())
269            .field("dtype", &self.dtype())
270            .field("device", &self.device())
271            .field("byte_offset", &self.byte_offset())
272            .finish()
273    }
274}
275
276#[cfg(test)]
277mod tests {
278    use super::*;
279    use crate::ffi::{cpu_device, cuda_device, dtype_f32, dtype_f64, DLTensor};
280    use pyo3::Python;
281    use std::ffi::CString;
282
    /// Wrapper to make a raw `DLManagedTensor` pointer `Send` for testing,
    /// since `PyCapsule::new` requires its payload to be `Send`.
    #[repr(transparent)]
    struct SendableTestPtr(*mut DLManagedTensor);
    // SAFETY: test-only; the pointee is never accessed from multiple threads.
    unsafe impl Send for SendableTestPtr {}
287
    /// Helper to create a test DLManagedTensor with given parameters.
    ///
    /// The `shape`, `strides`, and `data` vectors must outlive `managed`,
    /// whose internal pointers alias their heap buffers.
    struct TestManagedTensor {
        managed: Box<DLManagedTensor>,
        shape: Vec<i64>,
        strides: Option<Vec<i64>>,
        // Kept alive solely so `managed.dl_tensor.data` stays valid.
        #[allow(dead_code)]
        data: Vec<u8>,
    }
296
    impl TestManagedTensor {
        /// Build a managed tensor over freshly zeroed data.
        ///
        /// The `DLTensor` pointers are wired up after the struct is
        /// constructed because they alias the vectors it owns; the heap
        /// buffers stay put when the struct itself is later moved.
        fn new(
            shape: Vec<i64>,
            strides: Option<Vec<i64>>,
            dtype: DLDataType,
            device: DLDevice,
        ) -> Self {
            let numel: usize = shape.iter().map(|&d| d as usize).product();
            // `.max(1)` so even a scalar (empty shape) gets a non-empty buffer.
            let data = vec![0u8; numel.max(1) * dtype.itemsize()];

            let mut result = Self {
                managed: Box::new(DLManagedTensor {
                    dl_tensor: DLTensor {
                        data: std::ptr::null_mut(),
                        device,
                        ndim: shape.len() as i32,
                        dtype,
                        shape: std::ptr::null_mut(),
                        strides: std::ptr::null_mut(),
                        byte_offset: 0,
                    },
                    manager_ctx: std::ptr::null_mut(),
                    deleter: None,
                }),
                shape,
                strides,
                data,
            };

            // Set up pointers into the vectors now owned by `result`.
            result.managed.dl_tensor.data = result.data.as_ptr() as *mut c_void;
            result.managed.dl_tensor.shape = result.shape.as_mut_ptr();
            if let Some(ref mut s) = result.strides {
                result.managed.dl_tensor.strides = s.as_mut_ptr();
            }

            result
        }

        /// Builder-style setter for `byte_offset`.
        fn with_byte_offset(mut self, offset: u64) -> Self {
            self.managed.dl_tensor.byte_offset = offset;
            self
        }

        /// Raw pointer to the managed tensor (const-cast for FFI-style use).
        fn as_ptr(&self) -> *mut DLManagedTensor {
            &*self.managed as *const _ as *mut _
        }
    }
345
346    // ========================================================================
347    // is_contiguous tests
348    // ========================================================================
349
    #[test]
    fn test_is_contiguous_no_strides() {
        // Per DLPack, a null strides pointer means compact row-major layout,
        // i.e. contiguous by default.
        let tensor = TestManagedTensor::new(vec![2, 3, 4], None, dtype_f32(), cpu_device());

        // Create a mock PyTensor-like check using the raw managed tensor
        let managed = unsafe { &*tensor.as_ptr() };
        let strides_ptr = managed.dl_tensor.strides;

        // No strides pointer = contiguous
        assert!(strides_ptr.is_null());
    }
362
363    #[test]
364    fn test_is_contiguous_with_contiguous_strides() {
365        // Row-major contiguous strides for shape [2, 3, 4]
366        // strides should be [12, 4, 1]
367        let tensor = TestManagedTensor::new(
368            vec![2, 3, 4],
369            Some(vec![12, 4, 1]),
370            dtype_f32(),
371            cpu_device(),
372        );
373
374        let shape = &tensor.shape;
375        let strides = tensor.strides.as_ref().unwrap();
376
377        // Verify contiguity check logic
378        let mut expected_stride = 1i64;
379        let mut is_contiguous = true;
380        for i in (0..shape.len()).rev() {
381            if strides[i] != expected_stride {
382                is_contiguous = false;
383                break;
384            }
385            expected_stride *= shape[i];
386        }
387        assert!(is_contiguous);
388    }
389
390    #[test]
391    fn test_is_contiguous_with_non_contiguous_strides() {
392        // Non-contiguous strides (transposed)
393        let tensor = TestManagedTensor::new(
394            vec![2, 3, 4],
395            Some(vec![1, 2, 6]), // Column-major like strides
396            dtype_f32(),
397            cpu_device(),
398        );
399
400        let shape = &tensor.shape;
401        let strides = tensor.strides.as_ref().unwrap();
402
403        let mut expected_stride = 1i64;
404        let mut is_contiguous = true;
405        for i in (0..shape.len()).rev() {
406            if strides[i] != expected_stride {
407                is_contiguous = false;
408                break;
409            }
410            expected_stride *= shape[i];
411        }
412        assert!(!is_contiguous);
413    }
414
    #[test]
    fn test_is_contiguous_empty_tensor() {
        let tensor = TestManagedTensor::new(vec![], None, dtype_f32(), cpu_device());
        // Empty shape (a scalar) is contiguous by definition.
        assert!(tensor.shape.is_empty());
    }

    #[test]
    fn test_is_contiguous_1d() {
        // A 1-D tensor with unit stride is contiguous.
        let tensor = TestManagedTensor::new(vec![10], Some(vec![1]), dtype_f32(), cpu_device());
        let strides = tensor.strides.as_ref().unwrap();
        assert_eq!(strides[0], 1);
    }
428
429    // ========================================================================
430    // numel and nbytes tests
431    // ========================================================================
432
433    #[test]
434    fn test_numel_calculation() {
435        let shapes_and_expected: Vec<(Vec<i64>, usize)> = vec![
436            (vec![], 1), // Scalar (product of empty = 1)
437            (vec![5], 5),
438            (vec![2, 3], 6),
439            (vec![2, 3, 4], 24),
440            (vec![1, 1, 1, 1], 1),
441            (vec![10, 20, 30], 6000),
442        ];
443
444        for (shape, expected) in shapes_and_expected {
445            let numel: usize = if shape.is_empty() {
446                1 // Scalar case
447            } else {
448                shape.iter().map(|&d| d as usize).product()
449            };
450            assert_eq!(numel, expected, "Failed for shape {:?}", shape);
451        }
452    }
453
454    #[test]
455    fn test_nbytes_calculation() {
456        // f32 tensor [2, 3, 4] = 24 elements * 4 bytes = 96 bytes
457        let tensor = TestManagedTensor::new(vec![2, 3, 4], None, dtype_f32(), cpu_device());
458        let numel: usize = tensor.shape.iter().map(|&d| d as usize).product();
459        let itemsize = dtype_f32().itemsize();
460        assert_eq!(numel * itemsize, 96);
461
462        // f64 tensor [2, 3] = 6 elements * 8 bytes = 48 bytes
463        let tensor2 = TestManagedTensor::new(vec![2, 3], None, dtype_f64(), cpu_device());
464        let numel2: usize = tensor2.shape.iter().map(|&d| d as usize).product();
465        let itemsize2 = dtype_f64().itemsize();
466        assert_eq!(numel2 * itemsize2, 48);
467    }
468
469    // ========================================================================
470    // data_ptr tests
471    // ========================================================================
472
    #[test]
    fn test_data_ptr_with_offset() {
        // `data_ptr()` must add `byte_offset` to the raw base pointer.
        let tensor =
            TestManagedTensor::new(vec![10], None, dtype_f32(), cpu_device()).with_byte_offset(16);

        let managed = unsafe { &*tensor.as_ptr() };
        let base_ptr = managed.dl_tensor.data as usize;
        let offset = managed.dl_tensor.byte_offset as usize;
        let adjusted_ptr = base_ptr + offset;

        assert_eq!(offset, 16);
        assert_eq!(adjusted_ptr, base_ptr + 16);
    }

    #[test]
    fn test_data_ptr_no_offset() {
        // Default construction leaves byte_offset at 0.
        let tensor = TestManagedTensor::new(vec![10], None, dtype_f32(), cpu_device());

        let managed = unsafe { &*tensor.as_ptr() };
        assert_eq!(managed.dl_tensor.byte_offset, 0);
    }
494
495    // ========================================================================
496    // Device and dtype accessor tests
497    // ========================================================================
498
    #[test]
    fn test_device_accessor() {
        // CPU and CUDA devices (with device id) round-trip through DLTensor.
        let cpu_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f32(), cpu_device());
        let managed = unsafe { &*cpu_tensor.as_ptr() };
        assert!(managed.dl_tensor.device.is_cpu());

        let cuda_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f32(), cuda_device(1));
        let managed = unsafe { &*cuda_tensor.as_ptr() };
        assert!(managed.dl_tensor.device.is_cuda());
        assert_eq!(managed.dl_tensor.device.device_id, 1);
    }

    #[test]
    fn test_dtype_accessor() {
        // f32 and f64 dtypes round-trip through DLTensor.
        let f32_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f32(), cpu_device());
        let managed = unsafe { &*f32_tensor.as_ptr() };
        assert!(managed.dl_tensor.dtype.is_f32());

        let f64_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f64(), cpu_device());
        let managed = unsafe { &*f64_tensor.as_ptr() };
        assert!(managed.dl_tensor.dtype.is_f64());
    }
521
522    // ========================================================================
523    // ndim and shape tests
524    // ========================================================================
525
    #[test]
    fn test_ndim() {
        // ndim must equal the shape vector's length, 0-D through 5-D.
        let shapes: Vec<Vec<i64>> = vec![
            vec![],
            vec![5],
            vec![2, 3],
            vec![2, 3, 4],
            vec![1, 2, 3, 4, 5],
        ];

        for shape in shapes {
            let expected_ndim = shape.len();
            let tensor = TestManagedTensor::new(shape.clone(), None, dtype_f32(), cpu_device());
            let managed = unsafe { &*tensor.as_ptr() };
            assert_eq!(managed.dl_tensor.ndim as usize, expected_ndim);
        }
    }

    #[test]
    fn test_shape_accessor() {
        let shape = vec![2i64, 3, 4];
        let tensor = TestManagedTensor::new(shape.clone(), None, dtype_f32(), cpu_device());
        let managed = unsafe { &*tensor.as_ptr() };

        // Read the shape back through the raw pointer, as `PyTensor::shape`
        // does internally.
        let shape_slice = unsafe {
            std::slice::from_raw_parts(managed.dl_tensor.shape, managed.dl_tensor.ndim as usize)
        };
        assert_eq!(shape_slice, &[2, 3, 4]);
    }
555
556    // ========================================================================
557    // PyCapsule integration tests (require Python)
558    // ========================================================================
559
    #[test]
    fn test_capsule_creation_and_extraction() {
        Python::attach(|py| {
            // Create a test managed tensor
            let mut shape = vec![2i64, 3];
            let data = [0u8; 24].to_vec(); // 6 f32 elements

            let managed = Box::new(DLManagedTensor {
                dl_tensor: DLTensor {
                    data: data.as_ptr() as *mut c_void,
                    device: cpu_device(),
                    ndim: 2,
                    dtype: dtype_f32(),
                    shape: shape.as_mut_ptr(),
                    strides: std::ptr::null_mut(),
                    byte_offset: 0,
                },
                manager_ctx: std::ptr::null_mut(),
                deleter: None,
            });

            let managed_ptr = Box::into_raw(managed);
            let sendable = SendableTestPtr(managed_ptr);
            let name = CString::new("dltensor").unwrap();

            // Create a PyCapsule with Send wrapper.
            // NOTE(review): `PyCapsule::new` stores `sendable` by value, so
            // the capsule's pointer refers to the boxed wrapper, not to the
            // DLManagedTensor itself — this test only checks that extraction
            // with the expected name succeeds.
            let capsule =
                PyCapsule::new(py, sendable, Some(name)).expect("Failed to create capsule");

            // Verify capsule name exists
            let capsule_name = capsule.name().expect("Failed to get name");
            assert!(capsule_name.is_some());

            // Extract the pointer back - pointer_checked returns NonNull on success
            let _extracted = capsule
                .pointer_checked(Some(DLPACK_CAPSULE_NAME))
                .expect("Failed to extract pointer");

            // Clean up the tensor we leaked via Box::into_raw above.
            unsafe {
                let _ = Box::from_raw(managed_ptr);
            }
        });
    }

    #[test]
    fn test_capsule_wrong_name() {
        /// Wrapper for test data
        #[allow(dead_code)]
        struct TestData(i32);
        // NOTE(review): i32 is already Send; this impl is redundant but harmless.
        unsafe impl Send for TestData {}

        Python::attach(|py| {
            let data = TestData(42);
            let name = CString::new("wrong_name").unwrap();

            let capsule = PyCapsule::new(py, data, Some(name)).expect("Failed to create capsule");

            // Should fail when extracting with wrong expected name
            let result = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME));
            assert!(result.is_err());
        });
    }

    #[test]
    fn test_pytensor_send() {
        // Verify PyTensor implements Send (compile-time check).
        fn assert_send<T: Send>() {}
        assert_send::<PyTensor>();
    }
630
631    // ========================================================================
632    // PyTensor comprehensive tests using direct DLManagedTensor capsules
633    // ========================================================================
634
635    use std::sync::atomic::{AtomicUsize, Ordering};
636
    /// Counts invocations of `test_deleter`. Shared across tests (tests run
    /// in parallel), hence atomic; tests must not assume an absolute value.
    static DELETER_CALL_COUNT: AtomicUsize = AtomicUsize::new(0);
638
    /// Helper struct to hold all the data for a test tensor capsule.
    /// Owned via the `manager_ctx` pointer and freed by `test_deleter`.
    struct TestTensorContext {
        data: Vec<f32>,
        shape: Vec<i64>,
        strides: Option<Vec<i64>>,
    }
645
    /// Create a DLPack-compatible capsule for testing PyTensor.
    ///
    /// Leaks `ctx` and the `DLManagedTensor` via `Box::into_raw`; when
    /// `with_deleter` is true they are reclaimed by `test_deleter`, otherwise
    /// the caller is responsible for the leak (tests `mem::forget` anyway).
    fn create_test_capsule(
        py: Python<'_>,
        ctx: Box<TestTensorContext>,
        device: DLDevice,
        dtype: DLDataType,
        byte_offset: u64,
        with_deleter: bool,
    ) -> PyResult<Bound<'_, PyCapsule>> {
        let ctx_ptr = Box::into_raw(ctx);

        // SAFETY: `ctx_ptr` was just produced by Box::into_raw, so it is
        // valid and uniquely referenced here; the DLTensor pointers alias
        // the context's heap buffers, which live until the deleter runs.
        unsafe {
            let ctx_ref = &mut *ctx_ptr;

            let managed = Box::new(DLManagedTensor {
                dl_tensor: DLTensor {
                    data: ctx_ref.data.as_ptr() as *mut c_void,
                    device,
                    ndim: ctx_ref.shape.len() as i32,
                    dtype,
                    shape: ctx_ref.shape.as_mut_ptr(),
                    strides: ctx_ref
                        .strides
                        .as_mut()
                        .map(|s| s.as_mut_ptr())
                        .unwrap_or(std::ptr::null_mut()),
                    byte_offset,
                },
                manager_ctx: ctx_ptr as *mut c_void,
                deleter: if with_deleter {
                    Some(test_deleter)
                } else {
                    None
                },
            });

            let managed_ptr = Box::into_raw(managed);
            let wrapper = SendableTestPtr(managed_ptr);
            let name = CString::new("dltensor").unwrap();

            PyCapsule::new(py, wrapper, Some(name))
        }
    }
689
    /// Test deleter that increments a counter, then reclaims the
    /// `DLManagedTensor` box and its `TestTensorContext` (via `manager_ctx`).
    unsafe extern "C" fn test_deleter(managed_ptr: *mut DLManagedTensor) {
        if !managed_ptr.is_null() {
            DELETER_CALL_COUNT.fetch_add(1, Ordering::SeqCst);
            // Re-box the allocations leaked in `create_test_capsule` so they
            // are dropped here, exactly once.
            let managed = Box::from_raw(managed_ptr);
            if !managed.manager_ctx.is_null() {
                let _ = Box::from_raw(managed.manager_ctx as *mut TestTensorContext);
            }
        }
    }
700
    #[test]
    fn test_pytensor_all_accessors() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                shape: vec![2, 3],
                strides: None,
            });

            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            // Create PyTensor - need to read the pointer from the capsule correctly
            let ptr = capsule
                .pointer_checked(Some(DLPACK_CAPSULE_NAME))
                .expect("Failed to get pointer");
            // The capsule stores SendableTestPtr by value, so we must
            // dereference once to get the actual DLManagedTensor pointer.
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).expect("Null pointer");

            // Manually construct PyTensor for testing (bypasses from_capsule,
            // which expects the capsule pointer to BE the tensor pointer).
            let pytensor = PyTensor {
                managed,
                capsule: capsule.clone().unbind(),
            };

            // Test all accessor methods
            assert!(pytensor.device().is_cpu());
            assert!(pytensor.dtype().is_f32());
            assert_eq!(pytensor.ndim(), 2);
            assert_eq!(pytensor.shape(), &[2, 3]);
            assert!(pytensor.strides().is_none());
            assert!(pytensor.is_contiguous());
            assert!(!pytensor.data_ptr().is_null());
            assert!(!pytensor.data_ptr_raw().is_null());
            assert_eq!(pytensor.byte_offset(), 0);
            assert_eq!(pytensor.numel(), 6);
            assert_eq!(pytensor.itemsize(), 4);
            assert_eq!(pytensor.nbytes(), 24);

            // Test Debug
            let debug = format!("{:?}", pytensor);
            assert!(debug.contains("PyTensor"));
            assert!(debug.contains("shape"));
            assert!(debug.contains("dtype"));
            assert!(debug.contains("device"));

            // Prevent double-free by not running the deleter
            std::mem::forget(pytensor);
        });
    }
752
    #[test]
    fn test_pytensor_with_strides_contiguous() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0; 24],
                shape: vec![2, 3, 4],
                strides: Some(vec![12, 4, 1]), // Row-major contiguous
            });

            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            // Capsule payload is SendableTestPtr; deref once for the tensor.
            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).unwrap();

            let pytensor = PyTensor {
                managed,
                capsule: capsule.clone().unbind(),
            };

            assert_eq!(pytensor.ndim(), 3);
            assert_eq!(pytensor.shape(), &[2, 3, 4]);
            assert_eq!(pytensor.strides(), Some(&[12i64, 4, 1][..]));
            assert!(pytensor.is_contiguous());
            assert_eq!(pytensor.numel(), 24);

            // No deleter was installed; forget to sidestep Drop entirely.
            std::mem::forget(pytensor);
        });
    }

    #[test]
    fn test_pytensor_non_contiguous() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0; 6],
                shape: vec![2, 3],
                strides: Some(vec![1, 2]), // Column-major (non-contiguous for row-major check)
            });

            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).unwrap();

            let pytensor = PyTensor {
                managed,
                capsule: capsule.clone().unbind(),
            };

            assert!(!pytensor.is_contiguous());
            assert_eq!(pytensor.strides(), Some(&[1i64, 2][..]));

            std::mem::forget(pytensor);
        });
    }

    #[test]
    fn test_pytensor_scalar() {
        Python::attach(|py| {
            // Empty shape = 0-D scalar; numel must still be 1.
            let ctx = Box::new(TestTensorContext {
                data: vec![42.0],
                shape: vec![],
                strides: None,
            });

            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).unwrap();

            let pytensor = PyTensor {
                managed,
                capsule: capsule.clone().unbind(),
            };

            assert_eq!(pytensor.ndim(), 0);
            assert!(pytensor.shape().is_empty());
            assert!(pytensor.is_contiguous());
            assert_eq!(pytensor.numel(), 1);

            std::mem::forget(pytensor);
        });
    }

    #[test]
    fn test_pytensor_with_byte_offset() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0; 20],
                shape: vec![10],
                strides: None,
            });

            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 16, false)
                .expect("Failed to create capsule");

            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).unwrap();

            let pytensor = PyTensor {
                managed,
                capsule: capsule.clone().unbind(),
            };

            // data_ptr() must be exactly byte_offset past data_ptr_raw().
            assert_eq!(pytensor.byte_offset(), 16);
            let raw = pytensor.data_ptr_raw() as usize;
            let adjusted = pytensor.data_ptr() as usize;
            assert_eq!(adjusted, raw + 16);

            std::mem::forget(pytensor);
        });
    }

    #[test]
    fn test_pytensor_cuda_device() {
        Python::attach(|py| {
            // Device metadata only; the data buffer is host memory and is
            // never dereferenced as a device pointer here.
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0; 512],
                shape: vec![16, 32],
                strides: None,
            });

            let capsule = create_test_capsule(py, ctx, cuda_device(1), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).unwrap();

            let pytensor = PyTensor {
                managed,
                capsule: capsule.clone().unbind(),
            };

            assert!(pytensor.device().is_cuda());
            assert_eq!(pytensor.device().device_id, 1);

            std::mem::forget(pytensor);
        });
    }
899
900    #[test]
901    fn test_pytensor_f64_dtype() {
902        Python::attach(|py| {
903            // Use f32 data but declare f64 dtype for testing
904            let ctx = Box::new(TestTensorContext {
905                data: vec![1.0; 6], // 6 f32 = 24 bytes = 3 f64
906                shape: vec![3],
907                strides: None,
908            });
909
910            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f64(), 0, false)
911                .expect("Failed to create capsule");
912
913            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
914            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
915            let managed = NonNull::new(managed_ptr).unwrap();
916
917            let pytensor = PyTensor {
918                managed,
919                capsule: capsule.clone().unbind(),
920            };
921
922            assert!(pytensor.dtype().is_f64());
923            assert_eq!(pytensor.itemsize(), 8);
924            assert_eq!(pytensor.nbytes(), 24);
925
926            std::mem::forget(pytensor);
927        });
928    }
929
930    #[test]
931    fn test_pytensor_empty_strides_scalar() {
932        Python::attach(|py| {
933            let ctx = Box::new(TestTensorContext {
934                data: vec![1.0],
935                shape: vec![],
936                strides: Some(vec![]), // Empty strides for scalar
937            });
938
939            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
940                .expect("Failed to create capsule");
941
942            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
943            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
944            let managed = NonNull::new(managed_ptr).unwrap();
945
946            let pytensor = PyTensor {
947                managed,
948                capsule: capsule.clone().unbind(),
949            };
950
951            assert!(pytensor.is_contiguous());
952            assert!(pytensor.strides().is_some());
953            assert!(pytensor.strides().unwrap().is_empty());
954
955            std::mem::forget(pytensor);
956        });
957    }
958
959    #[test]
960    fn test_pytensor_drop_calls_deleter() {
961        DELETER_CALL_COUNT.store(0, Ordering::SeqCst);
962
963        Python::attach(|py| {
964            let ctx = Box::new(TestTensorContext {
965                data: vec![1.0, 2.0, 3.0],
966                shape: vec![3],
967                strides: None,
968            });
969
970            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, true)
971                .expect("Failed to create capsule");
972
973            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
974            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
975            let managed = NonNull::new(managed_ptr).unwrap();
976
977            {
978                let pytensor = PyTensor {
979                    managed,
980                    capsule: capsule.clone().unbind(),
981                };
982
983                // PyTensor exists, deleter not called yet
984                assert_eq!(DELETER_CALL_COUNT.load(Ordering::SeqCst), 0);
985
986                // Drop the PyTensor
987                drop(pytensor);
988            }
989
990            // Deleter should have been called
991            assert_eq!(DELETER_CALL_COUNT.load(Ordering::SeqCst), 1);
992        });
993    }
994
    #[test]
    fn test_pytensor_drop_no_deleter() {
        Python::attach(|py| {
            let ctx = Box::new(TestTensorContext {
                data: vec![1.0],
                shape: vec![1],
                strides: None,
            });

            // `false`: the test helper installs no deleter on the managed tensor,
            // so PyTensor::drop has nothing to call back into.
            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
                .expect("Failed to create capsule");

            // Extract the DLManagedTensor pointer stored inside the capsule.
            // Keep the raw pointer around: it is needed for manual cleanup below.
            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
            let managed = NonNull::new(managed_ptr).unwrap();

            let pytensor = PyTensor {
                managed,
                capsule: capsule.clone().unbind(),
            };

            // Drop without deleter should not crash
            drop(pytensor);

            // Clean up manually since no deleter
            // SAFETY: the capsule was built by create_test_capsule, which
            // (presumably — confirm against the helper) boxed both the
            // DLManagedTensor and its TestTensorContext; with no deleter
            // installed, nothing else has freed them, so reclaiming the Boxes
            // here is the unique release of each allocation.
            unsafe {
                let managed = Box::from_raw(managed_ptr);
                if !managed.manager_ctx.is_null() {
                    let _ = Box::from_raw(managed.manager_ctx as *mut TestTensorContext);
                }
            }
        });
    }
1028}