Skip to main content

pyo3_dlpack/
managed.rs

1//! Managed tensor for importing from Python via DLPack.
2//!
3//! This module provides `PyTensor`, a wrapper around a DLPack tensor
4//! received from Python that provides safe access to tensor metadata.
5
6use crate::ffi::{
7    DLDataType, DLDevice, DLManagedTensor, DLManagedTensorVersioned, DLTensor,
8    DLPACK_FLAG_BITMASK_READ_ONLY, DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION,
9};
10use crate::{
11    DLPACK_CAPSULE_NAME, DLPACK_CAPSULE_NAME_USED, DLPACK_VERSIONED_CAPSULE_NAME,
12    DLPACK_VERSIONED_CAPSULE_NAME_USED,
13};
14use pyo3::prelude::*;
15use pyo3::types::PyCapsule;
16use std::ffi::{c_void, CStr};
17use std::ptr::NonNull;
18
19/// Which managed-tensor layout backs a [`PyTensor`].
20///
21/// The embedded `DLTensor` lives at a different offset in the unversioned vs.
22/// versioned struct, and each has its own deleter signature, so we keep the
23/// typed owning pointer and branch where layout matters.
24#[derive(Clone, Copy)]
25enum ManagedPtr {
26    Unversioned(NonNull<DLManagedTensor>),
27    Versioned(NonNull<DLManagedTensorVersioned>),
28}
29
30impl ManagedPtr {
31    /// Borrow the embedded `DLTensor`, which lives at a different offset in the
32    /// unversioned vs. versioned managed struct.
33    ///
34    /// # Safety
35    /// The pointer must still address a live managed tensor of its layout.
36    unsafe fn dl_tensor(&self) -> &DLTensor {
37        match *self {
38            ManagedPtr::Unversioned(p) => &p.as_ref().dl_tensor,
39            ManagedPtr::Versioned(p) => &p.as_ref().dl_tensor,
40        }
41    }
42
43    /// Invoke the producer's deleter (if present) at the correct struct offset.
44    ///
45    /// # Safety
46    /// Must be called at most once, when relinquishing ownership of the tensor.
47    unsafe fn run_deleter(&self) {
48        match *self {
49            ManagedPtr::Unversioned(p) => {
50                if let Some(deleter) = p.as_ref().deleter {
51                    deleter(p.as_ptr());
52                }
53            }
54            ManagedPtr::Versioned(p) => {
55                if let Some(deleter) = p.as_ref().deleter {
56                    deleter(p.as_ptr());
57                }
58            }
59        }
60    }
61}
62
63/// A tensor imported from Python via the DLPack protocol.
64///
65/// This type wraps a `DLManagedTensor` received from a Python object
66/// (typically a PyTorch, JAX, or NumPy tensor) and provides safe access
67/// to the tensor's metadata and data pointer.
68///
69/// # Lifetime
70///
71/// The tensor data is valid as long as this `PyTensor` is alive.
72/// When dropped, the tensor's deleter is called to notify the producer.
73///
74/// # Example
75///
76/// ```ignore
77/// use pyo3::prelude::*;
78/// use pyo3_dlpack::PyTensor;
79///
80/// #[pyfunction]
81/// fn process(py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult<()> {
82///     let tensor = PyTensor::from_pyany(py, obj)?;
83///
84///     println!("Shape: {:?}", tensor.shape());
85///     println!("Device: {:?}", tensor.device());
86///     println!("Dtype: {:?}", tensor.dtype());
87///
88///     if tensor.device().is_cpu() {
89///         // Safe to access data on CPU
90///         let ptr = tensor.data_ptr() as *const f32;
91///         // ...
92///     }
93///
94///     Ok(())
95/// }
96/// ```
97pub struct PyTensor {
98    managed: ManagedPtr,
99    /// We store the capsule to prevent it from being garbage collected
100    /// while we hold a reference to the managed tensor.
101    #[allow(dead_code)]
102    capsule: Py<PyCapsule>,
103}
104
105// Safety: The underlying DLManagedTensor is thread-safe to send
106// (the producer guarantees this by implementing DLPack)
107unsafe impl Send for PyTensor {}
108
109/// Reject a managed tensor whose `ndim` is negative. `ndim` is an `i32`, and a
110/// negative value would cast to a near-`usize::MAX` length in `shape()` /
111/// `strides()`, producing a slice that reads far out of bounds. Refuse it at the
112/// import boundary before any accessor can trust it.
113fn validate_ndim(ndim: i32) -> PyResult<()> {
114    if ndim < 0 {
115        return Err(pyo3::exceptions::PyValueError::new_err(format!(
116            "DLPack tensor has negative ndim: {ndim}"
117        )));
118    }
119    Ok(())
120}
121
122impl PyTensor {
123    /// Borrow the embedded `DLTensor`, which lives at a different offset in the
124    /// unversioned vs. versioned managed struct.
125    fn dl_tensor(&self) -> &DLTensor {
126        unsafe { self.managed.dl_tensor() }
127    }
128
129    /// Create a PyTensor from a Python object that supports the DLPack protocol.
130    ///
131    /// This calls `__dlpack__()` on the object to get a DLPack capsule,
132    /// then extracts the tensor information.
133    ///
134    /// # Arguments
135    ///
136    /// * `py` - Python GIL token
137    /// * `obj` - A Python object that implements `__dlpack__()` (e.g., PyTorch tensor)
138    ///
139    /// # Errors
140    ///
141    /// Returns an error if:
142    /// - The object doesn't have a `__dlpack__` method
143    /// - The returned capsule is invalid
144    /// - The capsule doesn't contain a valid DLManagedTensor
145    pub fn from_pyany(_py: Python<'_>, obj: &Bound<'_, PyAny>) -> PyResult<Self> {
146        let py = obj.py();
147
148        // Advertise versioned support via max_version. Producers whose
149        // __dlpack__ predates the kwarg raise TypeError; fall back to a no-arg
150        // call for them. The actual capsule kind is decided later by name.
151        let kwargs = pyo3::types::PyDict::new(py);
152        kwargs.set_item(
153            pyo3::intern!(py, "max_version"),
154            (DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION),
155        )?;
156
157        let capsule_obj = match obj.call_method("__dlpack__", (), Some(&kwargs)) {
158            Ok(c) => c,
159            Err(e) if e.is_instance_of::<pyo3::exceptions::PyTypeError>(py) => {
160                obj.call_method0("__dlpack__")?
161            }
162            Err(e) => return Err(e),
163        };
164
165        let capsule: Bound<'_, PyCapsule> = capsule_obj.cast_into().map_err(|e| {
166            pyo3::exceptions::PyTypeError::new_err(format!(
167                "__dlpack__ did not return a PyCapsule: {:?}",
168                e.into_inner()
169            ))
170        })?;
171        Self::from_capsule(&capsule)
172    }
173
174    /// Create a PyTensor directly from a DLPack PyCapsule.
175    ///
176    /// # Arguments
177    ///
178    /// * `capsule` - A PyCapsule containing a DLManagedTensor
179    ///
180    /// # Errors
181    ///
182    /// Returns an error if the capsule is invalid or has the wrong name.
183    pub fn from_capsule(capsule: &Bound<'_, PyCapsule>) -> PyResult<Self> {
184        // Decide which DLPack layout this capsule carries by reading its name.
185        // A producer may return a legacy capsule even when versioned was
186        // requested, so we dispatch on the actual name, never on assumptions.
187        let name_ptr = unsafe { pyo3::ffi::PyCapsule_GetName(capsule.as_ptr()) };
188        if name_ptr.is_null() {
189            return Err(pyo3::exceptions::PyValueError::new_err(
190                "DLPack capsule has no name",
191            ));
192        }
193        let name = unsafe { CStr::from_ptr(name_ptr) };
194        let name_bytes = name.to_bytes();
195
196        if name_bytes == DLPACK_CAPSULE_NAME.to_bytes() {
197            Self::from_unversioned_capsule(capsule)
198        } else if name_bytes == DLPACK_VERSIONED_CAPSULE_NAME.to_bytes() {
199            Self::from_versioned_capsule(capsule)
200        } else {
201            Err(pyo3::exceptions::PyValueError::new_err(format!(
202                "unexpected DLPack capsule name: {:?}",
203                name
204            )))
205        }
206    }
207
208    /// Consume an unversioned (`dltensor`) capsule.
209    fn from_unversioned_capsule(capsule: &Bound<'_, PyCapsule>) -> PyResult<Self> {
210        let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME))?;
211        let managed = NonNull::new(ptr.as_ptr() as *mut DLManagedTensor).ok_or_else(|| {
212            pyo3::exceptions::PyValueError::new_err("DLPack capsule contains null pointer")
213        })?;
214
215        // Reject a malformed negative ndim before renaming, so a rejected
216        // capsule is left unconsumed for the producer's destructor to free.
217        validate_ndim(unsafe { managed.as_ref().dl_tensor.ndim })?;
218
219        // Per DLPack protocol, rename to "used_dltensor" to take ownership and
220        // prevent double-consume / double-free.
221        //
222        // SAFETY: reading the pointer above and renaming the capsule here are two
223        // steps, not one atomic operation. They are sound only because the GIL
224        // serializes consumers; a future free-threaded (no-GIL) build would need
225        // an external lock to prevent a double-consume race.
226        let set_name_result = unsafe {
227            pyo3::ffi::PyCapsule_SetName(capsule.as_ptr(), DLPACK_CAPSULE_NAME_USED.as_ptr())
228        };
229        if set_name_result != 0 {
230            return Err(pyo3::exceptions::PyRuntimeError::new_err(
231                "Failed to mark DLPack capsule as consumed: PyCapsule_SetName failed",
232            ));
233        }
234
235        Ok(Self {
236            managed: ManagedPtr::Unversioned(managed),
237            capsule: capsule.clone().unbind(),
238        })
239    }
240
241    /// Consume a versioned (`dltensor_versioned`) capsule.
242    fn from_versioned_capsule(capsule: &Bound<'_, PyCapsule>) -> PyResult<Self> {
243        let ptr = capsule.pointer_checked(Some(DLPACK_VERSIONED_CAPSULE_NAME))?;
244        let managed =
245            NonNull::new(ptr.as_ptr() as *mut DLManagedTensorVersioned).ok_or_else(|| {
246                pyo3::exceptions::PyValueError::new_err("DLPack capsule contains null pointer")
247            })?;
248
249        // Reject ANY major-version mismatch, per the DLPack spec: a different
250        // major version may lay out the struct body (flags, dl_tensor)
251        // differently, so we must not read past the version field. Minor
252        // versions are ABI-compatible, so they are accepted. We return before
253        // renaming the capsule, leaving the producer's destructor to call the
254        // deleter (which lives at a stable offset across major versions).
255        let version = unsafe { managed.as_ref().version };
256        if version.major != DLPACK_MAJOR_VERSION {
257            return Err(pyo3::exceptions::PyValueError::new_err(format!(
258                "unsupported DLPack major version {}.{} (this build supports major version {})",
259                version.major, version.minor, DLPACK_MAJOR_VERSION
260            )));
261        }
262
263        // Reject a malformed negative ndim before renaming (see
264        // `from_unversioned_capsule`); the version check above guarantees the
265        // struct layout is ours, so reading `dl_tensor.ndim` is sound.
266        validate_ndim(unsafe { managed.as_ref().dl_tensor.ndim })?;
267
268        // SAFETY: as in `from_unversioned_capsule`, the read-then-rename consume
269        // is sound only under the GIL's serialization of consumers.
270        let set_name_result = unsafe {
271            pyo3::ffi::PyCapsule_SetName(
272                capsule.as_ptr(),
273                DLPACK_VERSIONED_CAPSULE_NAME_USED.as_ptr(),
274            )
275        };
276        if set_name_result != 0 {
277            return Err(pyo3::exceptions::PyRuntimeError::new_err(
278                "Failed to mark DLPack capsule as consumed: PyCapsule_SetName failed",
279            ));
280        }
281
282        Ok(Self {
283            managed: ManagedPtr::Versioned(managed),
284            capsule: capsule.clone().unbind(),
285        })
286    }
287
288    /// Get the device where the tensor data resides.
289    pub fn device(&self) -> DLDevice {
290        self.dl_tensor().device
291    }
292
293    /// Get the data type of the tensor elements.
294    pub fn dtype(&self) -> DLDataType {
295        self.dl_tensor().dtype
296    }
297
298    /// Get the number of dimensions.
299    pub fn ndim(&self) -> usize {
300        self.dl_tensor().ndim as usize
301    }
302
303    /// Get the shape as a slice.
304    ///
305    /// The length of the slice equals `ndim()`.
306    pub fn shape(&self) -> &[i64] {
307        let tensor = self.dl_tensor();
308        if tensor.shape.is_null() {
309            &[]
310        } else {
311            unsafe { std::slice::from_raw_parts(tensor.shape, tensor.ndim as usize) }
312        }
313    }
314
315    /// Get the strides as a slice, or `None` for contiguous tensors.
316    ///
317    /// Strides are in number of elements (not bytes).
318    /// If `None`, the tensor is assumed to be in compact row-major order.
319    pub fn strides(&self) -> Option<&[i64]> {
320        let tensor = self.dl_tensor();
321        if tensor.strides.is_null() {
322            None
323        } else {
324            Some(unsafe { std::slice::from_raw_parts(tensor.strides, tensor.ndim as usize) })
325        }
326    }
327
328    /// Check if the tensor is contiguous in row-major (C) order.
329    pub fn is_contiguous(&self) -> bool {
330        match self.strides() {
331            None => true,
332            Some(strides) => {
333                let shape = self.shape();
334                if shape.is_empty() {
335                    return true;
336                }
337
338                let mut expected_stride = 1i64;
339                for i in (0..shape.len()).rev() {
340                    if strides[i] != expected_stride {
341                        return false;
342                    }
343                    expected_stride *= shape[i];
344                }
345                true
346            }
347        }
348    }
349
350    /// Get the raw data pointer.
351    ///
352    /// For GPU tensors, this is a device pointer that cannot be directly
353    /// dereferenced on the CPU.
354    ///
355    /// The pointer is adjusted by `byte_offset()`.
356    pub fn data_ptr(&self) -> *mut c_void {
357        let tensor = self.dl_tensor();
358        // `wrapping_add` (not `add`): the base may be null (0-element tensor) or a
359        // non-host device pointer, where `add`'s in-bounds/provenance requirement
360        // would be undefined behavior. The numeric result is identical.
361        (tensor.data as *mut u8).wrapping_add(tensor.byte_offset as usize) as *mut c_void
362    }
363
364    /// Get the raw data pointer without byte offset adjustment.
365    pub fn data_ptr_raw(&self) -> *mut c_void {
366        self.dl_tensor().data
367    }
368
369    /// Get the byte offset from the raw data pointer.
370    pub fn byte_offset(&self) -> u64 {
371        self.dl_tensor().byte_offset
372    }
373
374    /// Get the total number of elements in the tensor.
375    pub fn numel(&self) -> usize {
376        self.shape().iter().map(|&d| d as usize).product()
377    }
378
379    /// Get the size of one element in bytes.
380    pub fn itemsize(&self) -> usize {
381        self.dtype().itemsize()
382    }
383
384    /// Get the total size of the tensor data in bytes.
385    pub fn nbytes(&self) -> usize {
386        self.numel() * self.itemsize()
387    }
388
389    /// Whether the tensor is marked read-only.
390    ///
391    /// Only versioned (DLPack 1.0) tensors can carry this flag; legacy tensors
392    /// always report `false`.
393    pub fn is_read_only(&self) -> bool {
394        match self.managed {
395            ManagedPtr::Unversioned(_) => false,
396            ManagedPtr::Versioned(p) => unsafe {
397                p.as_ref().flags & DLPACK_FLAG_BITMASK_READ_ONLY != 0
398            },
399        }
400    }
401}
402
403impl Drop for PyTensor {
404    fn drop(&mut self) {
405        // Call the producer's deleter at the correct struct offset for each layout.
406        unsafe { self.managed.run_deleter() }
407    }
408}
409
410impl std::fmt::Debug for PyTensor {
411    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
412        f.debug_struct("PyTensor")
413            .field("shape", &self.shape())
414            .field("strides", &self.strides())
415            .field("dtype", &self.dtype())
416            .field("device", &self.device())
417            .field("byte_offset", &self.byte_offset())
418            .finish()
419    }
420}
421
422#[cfg(test)]
423mod tests {
424    use super::*;
425    use crate::ffi::{cpu_device, cuda_device, dtype_f32, dtype_f64, DLTensor};
426    use pyo3::Python;
427    use std::ffi::CString;
428
429    /// Wrapper to make pointer Send for testing
430    #[repr(transparent)]
431    struct SendableTestPtr(*mut DLManagedTensor);
432    unsafe impl Send for SendableTestPtr {}
433
434    /// Helper to create a test DLManagedTensor with given parameters
435    struct TestManagedTensor {
436        managed: Box<DLManagedTensor>,
437        shape: Vec<i64>,
438        strides: Option<Vec<i64>>,
439        #[allow(dead_code)]
440        data: Vec<u8>,
441    }
442
443    impl TestManagedTensor {
444        fn new(
445            shape: Vec<i64>,
446            strides: Option<Vec<i64>>,
447            dtype: DLDataType,
448            device: DLDevice,
449        ) -> Self {
450            let numel: usize = shape.iter().map(|&d| d as usize).product();
451            let data = vec![0u8; numel.max(1) * dtype.itemsize()];
452
453            let mut result = Self {
454                managed: Box::new(DLManagedTensor {
455                    dl_tensor: DLTensor {
456                        data: std::ptr::null_mut(),
457                        device,
458                        ndim: shape.len() as i32,
459                        dtype,
460                        shape: std::ptr::null_mut(),
461                        strides: std::ptr::null_mut(),
462                        byte_offset: 0,
463                    },
464                    manager_ctx: std::ptr::null_mut(),
465                    deleter: None,
466                }),
467                shape,
468                strides,
469                data,
470            };
471
472            // Set up pointers
473            result.managed.dl_tensor.data = result.data.as_ptr() as *mut c_void;
474            result.managed.dl_tensor.shape = result.shape.as_mut_ptr();
475            if let Some(ref mut s) = result.strides {
476                result.managed.dl_tensor.strides = s.as_mut_ptr();
477            }
478
479            result
480        }
481
482        fn with_byte_offset(mut self, offset: u64) -> Self {
483            self.managed.dl_tensor.byte_offset = offset;
484            self
485        }
486
487        fn as_ptr(&self) -> *mut DLManagedTensor {
488            &*self.managed as *const _ as *mut _
489        }
490    }
491
492    // ========================================================================
493    // is_contiguous tests
494    // ========================================================================
495
496    #[test]
497    fn test_is_contiguous_no_strides() {
498        // No strides means contiguous by default
499        let tensor = TestManagedTensor::new(vec![2, 3, 4], None, dtype_f32(), cpu_device());
500
501        // Create a mock PyTensor-like check using the raw managed tensor
502        let managed = unsafe { &*tensor.as_ptr() };
503        let strides_ptr = managed.dl_tensor.strides;
504
505        // No strides pointer = contiguous
506        assert!(strides_ptr.is_null());
507    }
508
509    #[test]
510    fn test_is_contiguous_with_contiguous_strides() {
511        // Row-major contiguous strides for shape [2, 3, 4]
512        // strides should be [12, 4, 1]
513        let tensor = TestManagedTensor::new(
514            vec![2, 3, 4],
515            Some(vec![12, 4, 1]),
516            dtype_f32(),
517            cpu_device(),
518        );
519
520        let shape = &tensor.shape;
521        let strides = tensor.strides.as_ref().unwrap();
522
523        // Verify contiguity check logic
524        let mut expected_stride = 1i64;
525        let mut is_contiguous = true;
526        for i in (0..shape.len()).rev() {
527            if strides[i] != expected_stride {
528                is_contiguous = false;
529                break;
530            }
531            expected_stride *= shape[i];
532        }
533        assert!(is_contiguous);
534    }
535
536    #[test]
537    fn test_is_contiguous_with_non_contiguous_strides() {
538        // Non-contiguous strides (transposed)
539        let tensor = TestManagedTensor::new(
540            vec![2, 3, 4],
541            Some(vec![1, 2, 6]), // Column-major like strides
542            dtype_f32(),
543            cpu_device(),
544        );
545
546        let shape = &tensor.shape;
547        let strides = tensor.strides.as_ref().unwrap();
548
549        let mut expected_stride = 1i64;
550        let mut is_contiguous = true;
551        for i in (0..shape.len()).rev() {
552            if strides[i] != expected_stride {
553                is_contiguous = false;
554                break;
555            }
556            expected_stride *= shape[i];
557        }
558        assert!(!is_contiguous);
559    }
560
561    #[test]
562    fn test_is_contiguous_empty_tensor() {
563        let tensor = TestManagedTensor::new(vec![], None, dtype_f32(), cpu_device());
564        // Empty shape is contiguous
565        assert!(tensor.shape.is_empty());
566    }
567
568    #[test]
569    fn test_is_contiguous_1d() {
570        let tensor = TestManagedTensor::new(vec![10], Some(vec![1]), dtype_f32(), cpu_device());
571        let strides = tensor.strides.as_ref().unwrap();
572        assert_eq!(strides[0], 1);
573    }
574
575    // ========================================================================
576    // numel and nbytes tests
577    // ========================================================================
578
579    #[test]
580    fn test_numel_calculation() {
581        let shapes_and_expected: Vec<(Vec<i64>, usize)> = vec![
582            (vec![], 1), // Scalar (product of empty = 1)
583            (vec![5], 5),
584            (vec![2, 3], 6),
585            (vec![2, 3, 4], 24),
586            (vec![1, 1, 1, 1], 1),
587            (vec![10, 20, 30], 6000),
588        ];
589
590        for (shape, expected) in shapes_and_expected {
591            let numel: usize = if shape.is_empty() {
592                1 // Scalar case
593            } else {
594                shape.iter().map(|&d| d as usize).product()
595            };
596            assert_eq!(numel, expected, "Failed for shape {:?}", shape);
597        }
598    }
599
600    #[test]
601    fn test_nbytes_calculation() {
602        // f32 tensor [2, 3, 4] = 24 elements * 4 bytes = 96 bytes
603        let tensor = TestManagedTensor::new(vec![2, 3, 4], None, dtype_f32(), cpu_device());
604        let numel: usize = tensor.shape.iter().map(|&d| d as usize).product();
605        let itemsize = dtype_f32().itemsize();
606        assert_eq!(numel * itemsize, 96);
607
608        // f64 tensor [2, 3] = 6 elements * 8 bytes = 48 bytes
609        let tensor2 = TestManagedTensor::new(vec![2, 3], None, dtype_f64(), cpu_device());
610        let numel2: usize = tensor2.shape.iter().map(|&d| d as usize).product();
611        let itemsize2 = dtype_f64().itemsize();
612        assert_eq!(numel2 * itemsize2, 48);
613    }
614
615    // ========================================================================
616    // data_ptr tests
617    // ========================================================================
618
619    #[test]
620    fn test_data_ptr_with_offset() {
621        let tensor =
622            TestManagedTensor::new(vec![10], None, dtype_f32(), cpu_device()).with_byte_offset(16);
623
624        let managed = unsafe { &*tensor.as_ptr() };
625        let base_ptr = managed.dl_tensor.data as usize;
626        let offset = managed.dl_tensor.byte_offset as usize;
627        let adjusted_ptr = base_ptr + offset;
628
629        assert_eq!(offset, 16);
630        assert_eq!(adjusted_ptr, base_ptr + 16);
631    }
632
633    #[test]
634    fn test_data_ptr_no_offset() {
635        let tensor = TestManagedTensor::new(vec![10], None, dtype_f32(), cpu_device());
636
637        let managed = unsafe { &*tensor.as_ptr() };
638        assert_eq!(managed.dl_tensor.byte_offset, 0);
639    }
640
641    // ========================================================================
642    // Device and dtype accessor tests
643    // ========================================================================
644
645    #[test]
646    fn test_device_accessor() {
647        let cpu_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f32(), cpu_device());
648        let managed = unsafe { &*cpu_tensor.as_ptr() };
649        assert!(managed.dl_tensor.device.is_cpu());
650
651        let cuda_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f32(), cuda_device(1));
652        let managed = unsafe { &*cuda_tensor.as_ptr() };
653        assert!(managed.dl_tensor.device.is_cuda());
654        assert_eq!(managed.dl_tensor.device.device_id, 1);
655    }
656
657    #[test]
658    fn test_dtype_accessor() {
659        let f32_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f32(), cpu_device());
660        let managed = unsafe { &*f32_tensor.as_ptr() };
661        assert!(managed.dl_tensor.dtype.is_f32());
662
663        let f64_tensor = TestManagedTensor::new(vec![2, 3], None, dtype_f64(), cpu_device());
664        let managed = unsafe { &*f64_tensor.as_ptr() };
665        assert!(managed.dl_tensor.dtype.is_f64());
666    }
667
668    // ========================================================================
669    // ndim and shape tests
670    // ========================================================================
671
672    #[test]
673    fn test_ndim() {
674        let shapes: Vec<Vec<i64>> = vec![
675            vec![],
676            vec![5],
677            vec![2, 3],
678            vec![2, 3, 4],
679            vec![1, 2, 3, 4, 5],
680        ];
681
682        for shape in shapes {
683            let expected_ndim = shape.len();
684            let tensor = TestManagedTensor::new(shape.clone(), None, dtype_f32(), cpu_device());
685            let managed = unsafe { &*tensor.as_ptr() };
686            assert_eq!(managed.dl_tensor.ndim as usize, expected_ndim);
687        }
688    }
689
690    #[test]
691    fn test_shape_accessor() {
692        let shape = vec![2i64, 3, 4];
693        let tensor = TestManagedTensor::new(shape.clone(), None, dtype_f32(), cpu_device());
694        let managed = unsafe { &*tensor.as_ptr() };
695
696        let shape_slice = unsafe {
697            std::slice::from_raw_parts(managed.dl_tensor.shape, managed.dl_tensor.ndim as usize)
698        };
699        assert_eq!(shape_slice, &[2, 3, 4]);
700    }
701
702    // ========================================================================
703    // PyCapsule integration tests (require Python)
704    // ========================================================================
705
706    #[test]
707    fn test_capsule_creation_and_extraction() {
708        Python::attach(|py| {
709            // Create a test managed tensor
710            let mut shape = vec![2i64, 3];
711            let data = [0u8; 24].to_vec(); // 6 f32 elements
712
713            let managed = Box::new(DLManagedTensor {
714                dl_tensor: DLTensor {
715                    data: data.as_ptr() as *mut c_void,
716                    device: cpu_device(),
717                    ndim: 2,
718                    dtype: dtype_f32(),
719                    shape: shape.as_mut_ptr(),
720                    strides: std::ptr::null_mut(),
721                    byte_offset: 0,
722                },
723                manager_ctx: std::ptr::null_mut(),
724                deleter: None,
725            });
726
727            let managed_ptr = Box::into_raw(managed);
728            let sendable = SendableTestPtr(managed_ptr);
729            let name = CString::new("dltensor").unwrap();
730
731            // Create a PyCapsule with Send wrapper
732            let capsule =
733                PyCapsule::new(py, sendable, Some(name)).expect("Failed to create capsule");
734
735            // Verify capsule name exists
736            let capsule_name = capsule.name().expect("Failed to get name");
737            assert!(capsule_name.is_some());
738
739            // Extract the pointer back - pointer_checked returns NonNull on success
740            let _extracted = capsule
741                .pointer_checked(Some(DLPACK_CAPSULE_NAME))
742                .expect("Failed to extract pointer");
743
744            // Clean up
745            unsafe {
746                let _ = Box::from_raw(managed_ptr);
747            }
748        });
749    }
750
751    #[test]
752    fn test_capsule_wrong_name() {
753        /// Wrapper for test data
754        #[allow(dead_code)]
755        struct TestData(i32);
756        unsafe impl Send for TestData {}
757
758        Python::attach(|py| {
759            let data = TestData(42);
760            let name = CString::new("wrong_name").unwrap();
761
762            let capsule = PyCapsule::new(py, data, Some(name)).expect("Failed to create capsule");
763
764            // Should fail when extracting with wrong expected name
765            let result = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME));
766            assert!(result.is_err());
767        });
768    }
769
770    #[test]
771    fn test_pytensor_send() {
772        // Verify PyTensor implements Send
773        fn assert_send<T: Send>() {}
774        assert_send::<PyTensor>();
775    }
776
777    // ========================================================================
778    // PyTensor comprehensive tests using direct DLManagedTensor capsules
779    // ========================================================================
780
781    use std::sync::atomic::{AtomicUsize, Ordering};
782
783    static DELETER_CALL_COUNT: AtomicUsize = AtomicUsize::new(0);
784
785    /// Helper struct to hold all the data for a test tensor capsule
786    struct TestTensorContext {
787        data: Vec<f32>,
788        shape: Vec<i64>,
789        strides: Option<Vec<i64>>,
790    }
791
792    /// Create a DLPack-compatible capsule for testing PyTensor
793    fn create_test_capsule(
794        py: Python<'_>,
795        ctx: Box<TestTensorContext>,
796        device: DLDevice,
797        dtype: DLDataType,
798        byte_offset: u64,
799        with_deleter: bool,
800    ) -> PyResult<Bound<'_, PyCapsule>> {
801        let ctx_ptr = Box::into_raw(ctx);
802
803        unsafe {
804            let ctx_ref = &mut *ctx_ptr;
805
806            let managed = Box::new(DLManagedTensor {
807                dl_tensor: DLTensor {
808                    data: ctx_ref.data.as_ptr() as *mut c_void,
809                    device,
810                    ndim: ctx_ref.shape.len() as i32,
811                    dtype,
812                    shape: ctx_ref.shape.as_mut_ptr(),
813                    strides: ctx_ref
814                        .strides
815                        .as_mut()
816                        .map(|s| s.as_mut_ptr())
817                        .unwrap_or(std::ptr::null_mut()),
818                    byte_offset,
819                },
820                manager_ctx: ctx_ptr as *mut c_void,
821                deleter: if with_deleter {
822                    Some(test_deleter)
823                } else {
824                    None
825                },
826            });
827
828            let managed_ptr = Box::into_raw(managed);
829            let wrapper = SendableTestPtr(managed_ptr);
830            let name = CString::new("dltensor").unwrap();
831
832            PyCapsule::new(py, wrapper, Some(name))
833        }
834    }
835
836    /// Test deleter that increments a counter
837    unsafe extern "C" fn test_deleter(managed_ptr: *mut DLManagedTensor) {
838        if !managed_ptr.is_null() {
839            DELETER_CALL_COUNT.fetch_add(1, Ordering::SeqCst);
840            let managed = Box::from_raw(managed_ptr);
841            if !managed.manager_ctx.is_null() {
842                let _ = Box::from_raw(managed.manager_ctx as *mut TestTensorContext);
843            }
844        }
845    }
846
847    #[test]
848    fn test_pytensor_all_accessors() {
849        Python::attach(|py| {
850            let ctx = Box::new(TestTensorContext {
851                data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
852                shape: vec![2, 3],
853                strides: None,
854            });
855
856            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
857                .expect("Failed to create capsule");
858
859            // Create PyTensor - need to read the pointer from the capsule correctly
860            let ptr = capsule
861                .pointer_checked(Some(DLPACK_CAPSULE_NAME))
862                .expect("Failed to get pointer");
863            // The capsule stores SendableTestPtr, so we need to dereference to get the actual pointer
864            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
865            let managed = NonNull::new(managed_ptr).expect("Null pointer");
866
867            // Manually construct PyTensor for testing
868            let pytensor = PyTensor {
869                managed: ManagedPtr::Unversioned(managed),
870                capsule: capsule.clone().unbind(),
871            };
872
873            // Test all accessor methods
874            assert!(pytensor.device().is_cpu());
875            assert!(pytensor.dtype().is_f32());
876            assert_eq!(pytensor.ndim(), 2);
877            assert_eq!(pytensor.shape(), &[2, 3]);
878            assert!(pytensor.strides().is_none());
879            assert!(pytensor.is_contiguous());
880            assert!(!pytensor.data_ptr().is_null());
881            assert!(!pytensor.data_ptr_raw().is_null());
882            assert_eq!(pytensor.byte_offset(), 0);
883            assert_eq!(pytensor.numel(), 6);
884            assert_eq!(pytensor.itemsize(), 4);
885            assert_eq!(pytensor.nbytes(), 24);
886
887            // Test Debug
888            let debug = format!("{:?}", pytensor);
889            assert!(debug.contains("PyTensor"));
890            assert!(debug.contains("shape"));
891            assert!(debug.contains("dtype"));
892            assert!(debug.contains("device"));
893
894            // Prevent double-free by not running the deleter
895            std::mem::forget(pytensor);
896        });
897    }
898
899    #[test]
900    fn test_pytensor_with_strides_contiguous() {
901        Python::attach(|py| {
902            let ctx = Box::new(TestTensorContext {
903                data: vec![1.0; 24],
904                shape: vec![2, 3, 4],
905                strides: Some(vec![12, 4, 1]), // Row-major contiguous
906            });
907
908            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
909                .expect("Failed to create capsule");
910
911            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
912            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
913            let managed = NonNull::new(managed_ptr).unwrap();
914
915            let pytensor = PyTensor {
916                managed: ManagedPtr::Unversioned(managed),
917                capsule: capsule.clone().unbind(),
918            };
919
920            assert_eq!(pytensor.ndim(), 3);
921            assert_eq!(pytensor.shape(), &[2, 3, 4]);
922            assert_eq!(pytensor.strides(), Some(&[12i64, 4, 1][..]));
923            assert!(pytensor.is_contiguous());
924            assert_eq!(pytensor.numel(), 24);
925
926            std::mem::forget(pytensor);
927        });
928    }
929
930    #[test]
931    fn test_pytensor_non_contiguous() {
932        Python::attach(|py| {
933            let ctx = Box::new(TestTensorContext {
934                data: vec![1.0; 6],
935                shape: vec![2, 3],
936                strides: Some(vec![1, 2]), // Column-major (non-contiguous for row-major check)
937            });
938
939            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
940                .expect("Failed to create capsule");
941
942            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
943            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
944            let managed = NonNull::new(managed_ptr).unwrap();
945
946            let pytensor = PyTensor {
947                managed: ManagedPtr::Unversioned(managed),
948                capsule: capsule.clone().unbind(),
949            };
950
951            assert!(!pytensor.is_contiguous());
952            assert_eq!(pytensor.strides(), Some(&[1i64, 2][..]));
953
954            std::mem::forget(pytensor);
955        });
956    }
957
958    #[test]
959    fn test_pytensor_scalar() {
960        Python::attach(|py| {
961            let ctx = Box::new(TestTensorContext {
962                data: vec![42.0],
963                shape: vec![],
964                strides: None,
965            });
966
967            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
968                .expect("Failed to create capsule");
969
970            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
971            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
972            let managed = NonNull::new(managed_ptr).unwrap();
973
974            let pytensor = PyTensor {
975                managed: ManagedPtr::Unversioned(managed),
976                capsule: capsule.clone().unbind(),
977            };
978
979            assert_eq!(pytensor.ndim(), 0);
980            assert!(pytensor.shape().is_empty());
981            assert!(pytensor.is_contiguous());
982            assert_eq!(pytensor.numel(), 1);
983
984            std::mem::forget(pytensor);
985        });
986    }
987
988    #[test]
989    fn test_pytensor_with_byte_offset() {
990        Python::attach(|py| {
991            let ctx = Box::new(TestTensorContext {
992                data: vec![1.0; 20],
993                shape: vec![10],
994                strides: None,
995            });
996
997            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 16, false)
998                .expect("Failed to create capsule");
999
1000            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
1001            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
1002            let managed = NonNull::new(managed_ptr).unwrap();
1003
1004            let pytensor = PyTensor {
1005                managed: ManagedPtr::Unversioned(managed),
1006                capsule: capsule.clone().unbind(),
1007            };
1008
1009            assert_eq!(pytensor.byte_offset(), 16);
1010            let raw = pytensor.data_ptr_raw() as usize;
1011            let adjusted = pytensor.data_ptr() as usize;
1012            assert_eq!(adjusted, raw + 16);
1013
1014            std::mem::forget(pytensor);
1015        });
1016    }
1017
1018    #[test]
1019    fn test_pytensor_cuda_device() {
1020        Python::attach(|py| {
1021            let ctx = Box::new(TestTensorContext {
1022                data: vec![1.0; 512],
1023                shape: vec![16, 32],
1024                strides: None,
1025            });
1026
1027            let capsule = create_test_capsule(py, ctx, cuda_device(1), dtype_f32(), 0, false)
1028                .expect("Failed to create capsule");
1029
1030            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
1031            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
1032            let managed = NonNull::new(managed_ptr).unwrap();
1033
1034            let pytensor = PyTensor {
1035                managed: ManagedPtr::Unversioned(managed),
1036                capsule: capsule.clone().unbind(),
1037            };
1038
1039            assert!(pytensor.device().is_cuda());
1040            assert_eq!(pytensor.device().device_id, 1);
1041
1042            std::mem::forget(pytensor);
1043        });
1044    }
1045
1046    #[test]
1047    fn test_pytensor_f64_dtype() {
1048        Python::attach(|py| {
1049            // Use f32 data but declare f64 dtype for testing
1050            let ctx = Box::new(TestTensorContext {
1051                data: vec![1.0; 6], // 6 f32 = 24 bytes = 3 f64
1052                shape: vec![3],
1053                strides: None,
1054            });
1055
1056            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f64(), 0, false)
1057                .expect("Failed to create capsule");
1058
1059            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
1060            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
1061            let managed = NonNull::new(managed_ptr).unwrap();
1062
1063            let pytensor = PyTensor {
1064                managed: ManagedPtr::Unversioned(managed),
1065                capsule: capsule.clone().unbind(),
1066            };
1067
1068            assert!(pytensor.dtype().is_f64());
1069            assert_eq!(pytensor.itemsize(), 8);
1070            assert_eq!(pytensor.nbytes(), 24);
1071
1072            std::mem::forget(pytensor);
1073        });
1074    }
1075
1076    #[test]
1077    fn test_pytensor_empty_strides_scalar() {
1078        Python::attach(|py| {
1079            let ctx = Box::new(TestTensorContext {
1080                data: vec![1.0],
1081                shape: vec![],
1082                strides: Some(vec![]), // Empty strides for scalar
1083            });
1084
1085            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
1086                .expect("Failed to create capsule");
1087
1088            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
1089            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
1090            let managed = NonNull::new(managed_ptr).unwrap();
1091
1092            let pytensor = PyTensor {
1093                managed: ManagedPtr::Unversioned(managed),
1094                capsule: capsule.clone().unbind(),
1095            };
1096
1097            assert!(pytensor.is_contiguous());
1098            assert!(pytensor.strides().is_some());
1099            assert!(pytensor.strides().unwrap().is_empty());
1100
1101            std::mem::forget(pytensor);
1102        });
1103    }
1104
1105    #[test]
1106    fn test_pytensor_drop_calls_deleter() {
1107        DELETER_CALL_COUNT.store(0, Ordering::SeqCst);
1108
1109        Python::attach(|py| {
1110            let ctx = Box::new(TestTensorContext {
1111                data: vec![1.0, 2.0, 3.0],
1112                shape: vec![3],
1113                strides: None,
1114            });
1115
1116            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, true)
1117                .expect("Failed to create capsule");
1118
1119            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
1120            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
1121            let managed = NonNull::new(managed_ptr).unwrap();
1122
1123            {
1124                let pytensor = PyTensor {
1125                    managed: ManagedPtr::Unversioned(managed),
1126                    capsule: capsule.clone().unbind(),
1127                };
1128
1129                // PyTensor exists, deleter not called yet
1130                assert_eq!(DELETER_CALL_COUNT.load(Ordering::SeqCst), 0);
1131
1132                // Drop the PyTensor
1133                drop(pytensor);
1134            }
1135
1136            // Deleter should have been called
1137            assert_eq!(DELETER_CALL_COUNT.load(Ordering::SeqCst), 1);
1138        });
1139    }
1140
1141    #[test]
1142    fn test_pytensor_drop_no_deleter() {
1143        Python::attach(|py| {
1144            let ctx = Box::new(TestTensorContext {
1145                data: vec![1.0],
1146                shape: vec![1],
1147                strides: None,
1148            });
1149
1150            let capsule = create_test_capsule(py, ctx, cpu_device(), dtype_f32(), 0, false)
1151                .expect("Failed to create capsule");
1152
1153            let ptr = capsule.pointer_checked(Some(DLPACK_CAPSULE_NAME)).unwrap();
1154            let managed_ptr = unsafe { *(ptr.as_ptr() as *const *mut DLManagedTensor) };
1155            let managed = NonNull::new(managed_ptr).unwrap();
1156
1157            let pytensor = PyTensor {
1158                managed: ManagedPtr::Unversioned(managed),
1159                capsule: capsule.clone().unbind(),
1160            };
1161
1162            // Drop without deleter should not crash
1163            drop(pytensor);
1164
1165            // Clean up manually since no deleter
1166            unsafe {
1167                let managed = Box::from_raw(managed_ptr);
1168                if !managed.manager_ctx.is_null() {
1169                    let _ = Box::from_raw(managed.manager_ctx as *mut TestTensorContext);
1170                }
1171            }
1172        });
1173    }
1174
1175    // ========================================================================
1176    // Versioned / read-only round-trip tests
1177    // ========================================================================
1178
1179    struct RoundTripTensor {
1180        data: Vec<f32>,
1181        shape: Vec<i64>,
1182    }
1183
1184    impl crate::IntoDLPack for RoundTripTensor {
1185        fn tensor_info(&self) -> crate::TensorInfo {
1186            crate::TensorInfo::contiguous(
1187                self.data.as_ptr() as *mut c_void,
1188                cpu_device(),
1189                dtype_f32(),
1190                self.shape.clone(),
1191            )
1192        }
1193    }
1194
1195    #[test]
1196    fn test_roundtrip_versioned_readonly() {
1197        use crate::IntoDLPack;
1198        Python::attach(|py| {
1199            let t = RoundTripTensor {
1200                data: vec![1.0, 2.0, 3.0, 4.0],
1201                shape: vec![2, 2],
1202            };
1203            let capsule_obj = t.into_dlpack_readonly(py).unwrap();
1204            let bound = capsule_obj.into_bound(py);
1205            let capsule: Bound<'_, PyCapsule> = bound.cast_into().unwrap();
1206
1207            let tensor = PyTensor::from_capsule(&capsule).unwrap();
1208            assert!(tensor.is_read_only());
1209            assert_eq!(tensor.shape(), &[2, 2]);
1210            assert!(tensor.device().is_cpu());
1211            assert!(tensor.dtype().is_f32());
1212            // Dropping `tensor` runs the versioned deleter and frees the context.
1213        });
1214    }
1215
1216    #[test]
1217    fn test_roundtrip_unversioned_not_readonly() {
1218        use crate::IntoDLPack;
1219        Python::attach(|py| {
1220            let t = RoundTripTensor {
1221                data: vec![1.0, 2.0, 3.0, 4.0],
1222                shape: vec![2, 2],
1223            };
1224            let capsule_obj = t.into_dlpack(py).unwrap();
1225            let bound = capsule_obj.into_bound(py);
1226            let capsule: Bound<'_, PyCapsule> = bound.cast_into().unwrap();
1227
1228            let tensor = PyTensor::from_capsule(&capsule).unwrap();
1229            assert!(!tensor.is_read_only());
1230            assert_eq!(tensor.shape(), &[2, 2]);
1231        });
1232    }
1233
1234    #[test]
1235    fn test_from_capsule_rejects_unknown_name() {
1236        Python::attach(|py| {
1237            // A capsule whose name is neither "dltensor" nor "dltensor_versioned"
1238            // must be rejected by the name dispatcher.
1239            let dummy = Box::new(0u8);
1240            let dummy_ptr = Box::into_raw(dummy);
1241            let capsule_ptr = unsafe {
1242                pyo3::ffi::PyCapsule_New(
1243                    dummy_ptr as *mut c_void,
1244                    c"not_a_dlpack_capsule".as_ptr(),
1245                    None,
1246                )
1247            };
1248            assert!(!capsule_ptr.is_null());
1249            let capsule: Bound<'_, PyCapsule> = unsafe { Bound::from_owned_ptr(py, capsule_ptr) }
1250                .cast_into()
1251                .unwrap();
1252
1253            let result = PyTensor::from_capsule(&capsule);
1254            assert!(result.is_err());
1255
1256            // from_capsule rejected before consuming; reclaim the dummy box.
1257            unsafe {
1258                let _ = Box::from_raw(dummy_ptr);
1259            }
1260        });
1261    }
1262
1263    #[test]
1264    fn test_versioned_rejects_too_new_major() {
1265        Python::attach(|py| {
1266            // A versioned capsule claiming a major version newer than we support
1267            // must be rejected (we may misinterpret a future struct layout).
1268            let mut shape = vec![1i64];
1269            let data = vec![0.0f32];
1270            let managed = Box::new(DLManagedTensorVersioned {
1271                version: crate::ffi::DLPackVersion {
1272                    major: DLPACK_MAJOR_VERSION + 1,
1273                    minor: 0,
1274                },
1275                manager_ctx: std::ptr::null_mut(),
1276                deleter: None,
1277                flags: 0,
1278                dl_tensor: DLTensor {
1279                    data: data.as_ptr() as *mut c_void,
1280                    device: cpu_device(),
1281                    ndim: 1,
1282                    dtype: dtype_f32(),
1283                    shape: shape.as_mut_ptr(),
1284                    strides: std::ptr::null_mut(),
1285                    byte_offset: 0,
1286                },
1287            });
1288            let managed_ptr = Box::into_raw(managed);
1289            let capsule_ptr = unsafe {
1290                pyo3::ffi::PyCapsule_New(
1291                    managed_ptr as *mut c_void,
1292                    c"dltensor_versioned".as_ptr(),
1293                    None,
1294                )
1295            };
1296            assert!(!capsule_ptr.is_null());
1297            let capsule: Bound<'_, PyCapsule> = unsafe { Bound::from_owned_ptr(py, capsule_ptr) }
1298                .cast_into()
1299                .unwrap();
1300
1301            let result = PyTensor::from_capsule(&capsule);
1302            assert!(result.is_err());
1303
1304            // from_capsule rejected before consuming, so reclaim the box ourselves.
1305            unsafe {
1306                let _ = Box::from_raw(managed_ptr);
1307            }
1308            // Keep the backing arrays alive until after the pointers are done.
1309            drop(shape);
1310            drop(data);
1311        });
1312    }
1313
1314    #[test]
1315    fn test_versioned_rejects_mismatched_lower_major() {
1316        Python::attach(|py| {
1317            // A versioned capsule claiming a major version LOWER than ours
1318            // (e.g. 0) is malformed/ABI-incompatible and must be rejected too —
1319            // we must not read flags/dl_tensor at our assumed offsets.
1320            let mut shape = vec![1i64];
1321            let data = vec![0.0f32];
1322            let managed = Box::new(DLManagedTensorVersioned {
1323                version: crate::ffi::DLPackVersion {
1324                    major: DLPACK_MAJOR_VERSION - 1,
1325                    minor: 0,
1326                },
1327                manager_ctx: std::ptr::null_mut(),
1328                deleter: None,
1329                flags: 0,
1330                dl_tensor: DLTensor {
1331                    data: data.as_ptr() as *mut c_void,
1332                    device: cpu_device(),
1333                    ndim: 1,
1334                    dtype: dtype_f32(),
1335                    shape: shape.as_mut_ptr(),
1336                    strides: std::ptr::null_mut(),
1337                    byte_offset: 0,
1338                },
1339            });
1340            let managed_ptr = Box::into_raw(managed);
1341            let capsule_ptr = unsafe {
1342                pyo3::ffi::PyCapsule_New(
1343                    managed_ptr as *mut c_void,
1344                    c"dltensor_versioned".as_ptr(),
1345                    None,
1346                )
1347            };
1348            assert!(!capsule_ptr.is_null());
1349            let capsule: Bound<'_, PyCapsule> = unsafe { Bound::from_owned_ptr(py, capsule_ptr) }
1350                .cast_into()
1351                .unwrap();
1352
1353            let result = PyTensor::from_capsule(&capsule);
1354            assert!(result.is_err());
1355
1356            // from_capsule rejected before consuming, so reclaim the box ourselves.
1357            unsafe {
1358                let _ = Box::from_raw(managed_ptr);
1359            }
1360            drop(shape);
1361            drop(data);
1362        });
1363    }
1364}