//! Export tensors to Python via DLPack.
//!
//! This module provides the `IntoDLPack` trait for exporting Rust tensors
//! to Python as DLPack capsules.

use crate::ffi::{DLDataType, DLDevice, DLManagedTensor, DLTensor};
use pyo3::prelude::*;
use std::ffi::{c_char, c_void, CStr};

/// Information about a tensor for DLPack export.
///
/// This struct holds all the metadata needed to create a DLPack tensor.
/// Returned by [`IntoDLPack::tensor_info`] and consumed during capsule
/// creation.
// NOTE(review): the original docs referenced `into_dlpack_with_info`, which
// is not visible in this module — confirm whether that API exists elsewhere.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    /// Raw data pointer (device pointer for GPU tensors).
    /// Validity and lifetime of the pointed-to data are the caller's
    /// responsibility; this struct stores it as-is.
    pub data: *mut c_void,
    /// Device descriptor
    pub device: DLDevice,
    /// Data type descriptor
    pub dtype: DLDataType,
    /// Shape (dimensions)
    pub shape: Vec<i64>,
    /// Strides in elements (None for contiguous)
    pub strides: Option<Vec<i64>>,
    /// Byte offset from data pointer
    pub byte_offset: u64,
}

31impl TensorInfo {
32    /// Create tensor info for a contiguous tensor.
33    pub fn contiguous(
34        data: *mut c_void,
35        device: DLDevice,
36        dtype: DLDataType,
37        shape: Vec<i64>,
38    ) -> Self {
39        Self {
40            data,
41            device,
42            dtype,
43            shape,
44            strides: None,
45            byte_offset: 0,
46        }
47    }
48
49    /// Create tensor info with explicit strides.
50    ///
51    /// # Panics
52    ///
53    /// Panics if `strides.len() != shape.len()`. This invariant is required by
54    /// DLPack consumers which will read `strides[i]` for each dimension `i`.
55    pub fn strided(
56        data: *mut c_void,
57        device: DLDevice,
58        dtype: DLDataType,
59        shape: Vec<i64>,
60        strides: Vec<i64>,
61    ) -> Self {
62        assert_eq!(
63            strides.len(),
64            shape.len(),
65            "strides length ({}) must equal shape length ({})",
66            strides.len(),
67            shape.len()
68        );
69        Self {
70            data,
71            device,
72            dtype,
73            shape,
74            strides: Some(strides),
75            byte_offset: 0,
76        }
77    }
78
79    /// Set the byte offset.
80    pub fn with_byte_offset(mut self, offset: u64) -> Self {
81        self.byte_offset = offset;
82        self
83    }
84}
85
/// Trait for types that can be exported as DLPack tensors.
///
/// Implement this trait on your tensor type to enable export to Python
/// via the DLPack protocol.
///
/// # Example
///
/// ```ignore
/// use pyo3_dlpack::{IntoDLPack, TensorInfo, cuda_device, dtype_f32};
/// use std::ffi::c_void;
///
/// struct MyGpuTensor {
///     device_ptr: u64,
///     shape: Vec<i64>,
///     device_id: i32,
/// }
///
/// impl IntoDLPack for MyGpuTensor {
///     fn tensor_info(&self) -> TensorInfo {
///         TensorInfo::contiguous(
///             self.device_ptr as *mut c_void,
///             cuda_device(self.device_id),
///             dtype_f32(),
///             self.shape.clone(),
///         )
///     }
/// }
/// ```
pub trait IntoDLPack: Send + Sized {
    /// Get the tensor information for DLPack export.
    ///
    /// The returned metadata (including the raw data pointer) must describe
    /// data that remains valid for as long as `self` is alive — the export
    /// machinery keeps `self` alive alongside the capsule.
    fn tensor_info(&self) -> TensorInfo;

    /// Export this tensor to Python as a DLPack capsule.
    ///
    /// Consumes `self`: ownership is transferred into the capsule, and the
    /// tensor is dropped when the capsule (or its consumer) releases it.
    ///
    /// The returned `PyObject` is a PyCapsule that can be converted to
    /// a tensor in any DLPack-compatible framework using `from_dlpack()`.
    ///
    /// # Example (Python side)
    ///
    /// ```python
    /// import torch
    ///
    /// # Call your Rust function that returns a DLPack capsule
    /// capsule = my_rust_function()
    ///
    /// # Convert to PyTorch tensor (zero-copy)
    /// tensor = torch.from_dlpack(capsule)
    /// ```
    fn into_dlpack(self, py: Python<'_>) -> PyResult<Py<PyAny>> {
        let info = self.tensor_info();
        export_to_capsule(py, self, info)
    }
}

/// Internal context that owns the tensor during DLPack lifetime.
///
/// The exported `DLTensor` holds raw pointers into `shape`/`strides`, so
/// their heap buffers must stay allocated until the deleter runs — that is
/// why the Vecs are moved in here rather than borrowed from `TensorInfo`.
struct ExportContext<T> {
    /// The owned tensor (kept alive until the capsule is consumed)
    #[allow(dead_code)]
    tensor: T,
    /// Shape array (must remain valid; `DLTensor.shape` points into it)
    shape: Vec<i64>,
    /// Strides array (must remain valid; `DLTensor.strides` points into it)
    strides: Option<Vec<i64>>,
}

/// The DLPack capsule name (null-terminated for C compatibility).
/// Consumers locate the tensor via `PyCapsule_GetPointer(capsule, "dltensor")`,
/// so this exact name is mandated by the DLPack protocol.
static DLPACK_CAPSULE_NAME: &[u8] = b"dltensor\0";

/// The name for consumed DLPack capsules (per DLPack protocol).
/// A consumer renames the capsule to this after taking ownership; the
/// trailing null keeps it usable for C-string comparison.
static USED_DLTENSOR_NAME: &[u8] = b"used_dltensor\0";

158/// Export a tensor to a PyCapsule.
159fn export_to_capsule<T: IntoDLPack>(
160    py: Python<'_>,
161    tensor: T,
162    info: TensorInfo,
163) -> PyResult<Py<PyAny>> {
164    // Validate strides length matches shape length to prevent out-of-bounds reads
165    // by DLPack consumers. This catches cases where TensorInfo is constructed
166    // manually without using the strided() constructor.
167    if let Some(ref strides) = info.strides {
168        if strides.len() != info.shape.len() {
169            return Err(pyo3::exceptions::PyValueError::new_err(format!(
170                "strides length ({}) must equal shape length ({})",
171                strides.len(),
172                info.shape.len()
173            )));
174        }
175    }
176
177    // Create the context that will own the tensor
178    let ctx = Box::new(ExportContext {
179        tensor,
180        shape: info.shape,
181        strides: info.strides,
182    });
183    let ctx_ptr = Box::into_raw(ctx);
184
185    // Create the DLManagedTensor
186    // SAFETY: For scalar tensors (ndim == 0), shape and strides pointers MUST be null.
187    // Using as_mut_ptr() on an empty Vec returns a non-null dangling pointer, which
188    // violates the DLPack spec and can cause UB if consumers read the pointer.
189    let ndim = unsafe { (*ctx_ptr).shape.len() as i32 };
190    let shape_ptr = if ndim == 0 {
191        std::ptr::null_mut()
192    } else {
193        unsafe { (*ctx_ptr).shape.as_mut_ptr() }
194    };
195    let strides_ptr = if ndim == 0 {
196        std::ptr::null_mut()
197    } else {
198        unsafe {
199            (*ctx_ptr)
200                .strides
201                .as_mut()
202                .map(|s| s.as_mut_ptr())
203                .unwrap_or(std::ptr::null_mut())
204        }
205    };
206
207    let managed = Box::new(DLManagedTensor {
208        dl_tensor: DLTensor {
209            data: info.data,
210            device: info.device,
211            ndim,
212            dtype: info.dtype,
213            shape: shape_ptr,
214            strides: strides_ptr,
215            byte_offset: info.byte_offset,
216        },
217        manager_ctx: ctx_ptr as *mut c_void,
218        deleter: Some(dlpack_deleter::<T>),
219    });
220
221    let managed_ptr = Box::into_raw(managed);
222
223    // Create the PyCapsule using low-level FFI to ensure the pointer is stored directly.
224    // DLPack consumers expect PyCapsule_GetPointer to return a DLManagedTensor* directly.
225    // Use static name so it remains valid for the capsule's lifetime.
226    let capsule_ptr = unsafe {
227        pyo3::ffi::PyCapsule_New(
228            managed_ptr as *mut c_void,
229            DLPACK_CAPSULE_NAME.as_ptr() as *const i8,
230            Some(raw_capsule_destructor),
231        )
232    };
233
234    if capsule_ptr.is_null() {
235        // Clean up on failure - must free BOTH managed_ptr AND ctx_ptr
236        // to avoid memory leak. ctx_ptr owns the tensor and is stored
237        // in managed.manager_ctx, but freeing managed_ptr alone doesn't
238        // automatically free ctx_ptr since it's a raw pointer.
239        unsafe {
240            let _ = Box::from_raw(managed_ptr);
241            let _ = Box::from_raw(ctx_ptr);
242        }
243        return Err(pyo3::exceptions::PyMemoryError::new_err(
244            "Failed to create DLPack capsule",
245        ));
246    }
247
248    // Store a reference to ctx_ptr in the capsule context so the destructor
249    // can check if the capsule was consumed and clean up properly.
250    unsafe {
251        pyo3::ffi::PyCapsule_SetContext(capsule_ptr, ctx_ptr as *mut c_void);
252    }
253
254    // Convert to PyObject
255    Ok(unsafe { Py::from_owned_ptr(py, capsule_ptr) })
256}
257
/// Raw PyCapsule destructor - called by Python when garbage collecting the capsule.
///
/// Per the DLPack protocol, when a consumer takes ownership of the tensor
/// (e.g., via torch.from_dlpack), it must rename the capsule from "dltensor"
/// to "used_dltensor" and will call the deleter itself when done.
///
/// This destructor checks the capsule name to avoid double-free:
/// - If name is "dltensor": capsule was never consumed, we call the deleter
/// - If name is "used_dltensor": consumer owns it and will call deleter, skip
///
/// SAFETY: called by the CPython runtime with a valid (or null) capsule
/// pointer while the interpreter holds the GIL.
unsafe extern "C" fn raw_capsule_destructor(capsule_ptr: *mut pyo3::ffi::PyObject) {
    if capsule_ptr.is_null() {
        return;
    }

    // Check the capsule name to see if it was consumed.
    let name_ptr = pyo3::ffi::PyCapsule_GetName(capsule_ptr);
    if name_ptr.is_null() {
        // No name set - shouldn't happen with our capsules; bail out rather
        // than risk freeing something we don't own.
        return;
    }

    let name = CStr::from_ptr(name_ptr);

    // If name is "used_dltensor", the consumer has taken ownership
    // and will call the deleter when done. Don't double-free.
    // (The slice drops USED_DLTENSOR_NAME's trailing null so it compares
    // equal to CStr::to_bytes(), which excludes the terminator.)
    if name.to_bytes() == USED_DLTENSOR_NAME[..USED_DLTENSOR_NAME.len() - 1].as_ref() {
        return;
    }

    // Get the DLManagedTensor pointer from the capsule. Passing the capsule's
    // own current name guarantees the name check inside PyCapsule_GetPointer
    // succeeds even if a third party renamed the capsule to something else.
    let managed_ptr =
        pyo3::ffi::PyCapsule_GetPointer(capsule_ptr, name_ptr) as *mut DLManagedTensor;

    if managed_ptr.is_null() {
        return;
    }

    // Capsule was not consumed, call the DLPack deleter (frees the
    // DLManagedTensor and the ExportContext holding the tensor).
    let managed = &*managed_ptr;
    if let Some(deleter) = managed.deleter {
        deleter(managed_ptr);
    }
}

302/// Deleter called by the consumer when done with the tensor.
303///
304/// This is an extern "C" function that matches the DLPack deleter signature.
305unsafe extern "C" fn dlpack_deleter<T>(managed_ptr: *mut DLManagedTensor) {
306    if managed_ptr.is_null() {
307        return;
308    }
309
310    // Recover and drop the managed tensor
311    let managed = Box::from_raw(managed_ptr);
312
313    // Recover and drop the context (which owns the tensor)
314    if !managed.manager_ctx.is_null() {
315        let _ctx = Box::from_raw(managed.manager_ctx as *mut ExportContext<T>);
316        // ctx and its tensor are dropped here
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ffi::{cpu_device, cuda_device, dtype_f32, dtype_f64, dtype_i32};
    use pyo3::Python;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // ========================================================================
    // Test tensor types
    // ========================================================================

    // Minimal contiguous CPU tensor backed by a Vec<f32>.
    struct TestTensor {
        data: Vec<f32>,
        shape: Vec<i64>,
    }

    impl IntoDLPack for TestTensor {
        fn tensor_info(&self) -> TensorInfo {
            TensorInfo::contiguous(
                self.data.as_ptr() as *mut c_void,
                cpu_device(),
                dtype_f32(),
                self.shape.clone(),
            )
        }
    }

    // CPU tensor with caller-supplied (possibly non-contiguous) strides.
    struct StridedTensor {
        data: Vec<f32>,
        shape: Vec<i64>,
        strides: Vec<i64>,
    }

    impl IntoDLPack for StridedTensor {
        fn tensor_info(&self) -> TensorInfo {
            TensorInfo::strided(
                self.data.as_ptr() as *mut c_void,
                cpu_device(),
                dtype_f32(),
                self.shape.clone(),
                self.strides.clone(),
            )
        }
    }

    // Simulated GPU tensor: device_ptr is a fake address, so these tests
    // only exercise capsule creation and must never dereference the data.
    struct GpuTensor {
        device_ptr: u64,
        shape: Vec<i64>,
        device_id: i32,
    }

    impl IntoDLPack for GpuTensor {
        fn tensor_info(&self) -> TensorInfo {
            TensorInfo::contiguous(
                self.device_ptr as *mut c_void,
                cuda_device(self.device_id),
                dtype_f32(),
                self.shape.clone(),
            )
        }
    }

    // Tensor whose view begins `offset` bytes into the backing buffer.
    struct OffsetTensor {
        data: Vec<f32>,
        shape: Vec<i64>,
        offset: u64,
    }

    impl IntoDLPack for OffsetTensor {
        fn tensor_info(&self) -> TensorInfo {
            TensorInfo::contiguous(
                self.data.as_ptr() as *mut c_void,
                cpu_device(),
                dtype_f32(),
                self.shape.clone(),
            )
            .with_byte_offset(self.offset)
        }
    }

    // Track drops for testing cleanup
    static DROP_COUNT: AtomicUsize = AtomicUsize::new(0);

    struct DropTracker {
        data: Vec<f32>,
        shape: Vec<i64>,
    }

    impl Drop for DropTracker {
        fn drop(&mut self) {
            DROP_COUNT.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl IntoDLPack for DropTracker {
        fn tensor_info(&self) -> TensorInfo {
            TensorInfo::contiguous(
                self.data.as_ptr() as *mut c_void,
                cpu_device(),
                dtype_f32(),
                self.shape.clone(),
            )
        }
    }

    // ========================================================================
    // TensorInfo tests
    // ========================================================================

    #[test]
    fn test_tensor_info_contiguous() {
        let data = [1.0f32, 2.0, 3.0, 4.0].to_vec();
        let info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![2, 2],
        );

        assert!(info.strides.is_none());
        assert_eq!(info.byte_offset, 0);
        assert_eq!(info.shape, vec![2, 2]);
        assert!(info.device.is_cpu());
        assert!(info.dtype.is_f32());
    }

    #[test]
    fn test_tensor_info_strided() {
        let data = [1.0f32; 24].to_vec();
        let info = TensorInfo::strided(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![2, 3, 4],
            vec![12, 4, 1],
        );

        assert_eq!(info.strides, Some(vec![12, 4, 1]));
        assert_eq!(info.byte_offset, 0);
        assert_eq!(info.shape, vec![2, 3, 4]);
    }

    #[test]
    fn test_tensor_info_with_byte_offset() {
        let data = [1.0f32; 10].to_vec();
        let info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![10],
        )
        .with_byte_offset(16);

        assert_eq!(info.byte_offset, 16);
    }

    #[test]
    fn test_tensor_info_with_different_dtypes() {
        let data_f64 = [1.0f64; 10].to_vec();
        let info = TensorInfo::contiguous(
            data_f64.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f64(),
            vec![10],
        );
        assert!(info.dtype.is_f64());

        let data_i32 = [1i32; 10].to_vec();
        let info = TensorInfo::contiguous(
            data_i32.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_i32(),
            vec![10],
        );
        assert!(info.dtype.is_i32());
    }

    #[test]
    fn test_tensor_info_with_different_devices() {
        let data = [1.0f32; 10].to_vec();

        let cpu_info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![10],
        );
        assert!(cpu_info.device.is_cpu());

        // Fake device addresses — metadata only, never dereferenced.
        let cuda_info = TensorInfo::contiguous(
            0x12345678 as *mut c_void,
            cuda_device(0),
            dtype_f32(),
            vec![10],
        );
        assert!(cuda_info.device.is_cuda());
        assert_eq!(cuda_info.device.device_id, 0);

        let cuda1_info = TensorInfo::contiguous(
            0x12345678 as *mut c_void,
            cuda_device(1),
            dtype_f32(),
            vec![10],
        );
        assert_eq!(cuda1_info.device.device_id, 1);
    }

    #[test]
    fn test_tensor_info_debug() {
        let data = [1.0f32; 10].to_vec();
        let info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![2, 5],
        );
        let debug = format!("{:?}", info);
        assert!(debug.contains("TensorInfo"));
        assert!(debug.contains("shape"));
    }

    #[test]
    fn test_tensor_info_clone() {
        let data = [1.0f32; 10].to_vec();
        let info = TensorInfo::strided(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![2, 5],
            vec![5, 1],
        )
        .with_byte_offset(8);

        let cloned = info.clone();
        assert_eq!(cloned.shape, info.shape);
        assert_eq!(cloned.strides, info.strides);
        assert_eq!(cloned.byte_offset, info.byte_offset);
    }

    #[test]
    fn test_tensor_info_empty_shape() {
        let data = [1.0f32].to_vec();
        let info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![], // Scalar
        );
        assert!(info.shape.is_empty());
    }

    #[test]
    fn test_tensor_info_high_dimensional() {
        let data = vec![1.0f32; 120];
        let info = TensorInfo::contiguous(
            data.as_ptr() as *mut c_void,
            cpu_device(),
            dtype_f32(),
            vec![2, 3, 4, 5],
        );
        assert_eq!(info.shape.len(), 4);
    }

    // ========================================================================
    // IntoDLPack trait tests
    // ========================================================================

    #[test]
    fn test_into_dlpack_contiguous() {
        Python::attach(|py| {
            let tensor = TestTensor {
                data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
                shape: vec![2, 3],
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_into_dlpack_strided() {
        Python::attach(|py| {
            let tensor = StridedTensor {
                data: vec![1.0; 24],
                shape: vec![2, 3, 4],
                strides: vec![12, 4, 1],
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_into_dlpack_gpu_tensor() {
        Python::attach(|py| {
            let tensor = GpuTensor {
                device_ptr: 0xDEADBEEF,
                shape: vec![16, 32],
                device_id: 0,
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_into_dlpack_with_offset() {
        Python::attach(|py| {
            let tensor = OffsetTensor {
                data: vec![1.0; 20],
                shape: vec![10],
                offset: 40, // Skip first 10 f32 elements
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_into_dlpack_scalar() {
        Python::attach(|py| {
            let tensor = TestTensor {
                data: vec![42.0],
                shape: vec![], // Scalar
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_into_dlpack_1d() {
        Python::attach(|py| {
            let tensor = TestTensor {
                data: vec![1.0, 2.0, 3.0, 4.0, 5.0],
                shape: vec![5],
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    // ========================================================================
    // Cleanup and memory management tests
    // ========================================================================

    #[test]
    fn test_capsule_cleanup_on_drop() {
        DROP_COUNT.store(0, Ordering::SeqCst);

        Python::attach(|py| {
            {
                let tensor = DropTracker {
                    data: vec![1.0, 2.0, 3.0],
                    shape: vec![3],
                };

                let _capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
                // Capsule exists, tensor ownership transferred
            }
            // Force garbage collection
            py.run(c"import gc; gc.collect()", None, None).unwrap();
        });

        // The tensor should have been dropped when the capsule was cleaned up
        // Note: GC timing is not deterministic, so we check after GIL release
        // In practice, the drop may happen during GC or when Python shuts down
        // NOTE(review): no assertion here by design — asserting DROP_COUNT
        // would make the test flaky across interpreter/GC implementations.
    }

    #[test]
    fn test_deleter_null_check() {
        // Test that dlpack_deleter handles null safely
        unsafe {
            dlpack_deleter::<TestTensor>(std::ptr::null_mut());
        }
        // Should not crash
    }

    #[test]
    fn test_capsule_destructor_null_check() {
        // Test that raw_capsule_destructor handles null safely
        unsafe {
            raw_capsule_destructor(std::ptr::null_mut());
        }
        // Should not crash
    }

    // ========================================================================
    // Send trait verification
    // ========================================================================

    #[test]
    fn test_into_dlpack_requires_send() {
        // IntoDLPack requires Send, verify our test types implement it
        fn assert_send<T: Send>() {}
        assert_send::<TestTensor>();
        assert_send::<StridedTensor>();
        assert_send::<GpuTensor>();
        assert_send::<OffsetTensor>();
        assert_send::<DropTracker>();
    }

    // ========================================================================
    // Edge cases
    // ========================================================================

    #[test]
    fn test_large_shape() {
        Python::attach(|py| {
            let tensor = TestTensor {
                data: vec![1.0; 1000000],
                shape: vec![100, 100, 100],
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_non_contiguous_strides() {
        Python::attach(|py| {
            // Transposed tensor (column-major)
            let tensor = StridedTensor {
                data: vec![1.0; 6],
                shape: vec![2, 3],
                strides: vec![1, 2], // Column-major
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }

    #[test]
    fn test_zero_stride() {
        Python::attach(|py| {
            // Broadcasting-like strides
            let tensor = StridedTensor {
                data: vec![1.0; 3],
                shape: vec![2, 3],
                strides: vec![0, 1], // First dimension is broadcast
            };

            let capsule = tensor.into_dlpack(py).expect("Failed to create capsule");
            assert!(!capsule.is_none(py));
        });
    }
}