Skip to main content

scirs2_numpy/
dlpack.rs

//! DLPack protocol for zero-copy tensor exchange.
//!
//! Implements `__dlpack__` and `__dlpack_device__` for arrays managed by this crate.
//! DLPack is a standard open-source ABI used by PyTorch, JAX, TensorFlow, and other
//! frameworks to exchange tensors without copying.
//!
//! Reference: <https://dmlc.github.io/dlpack/latest/>
8
9use pyo3::prelude::*;
10use pyo3::types::PyCapsule;
11use std::ffi::c_void;
12use std::ffi::CStr;
13use std::ptr::NonNull;
14
/// Device-type codes understood by this module.
///
/// Each variant's discriminant is the integer stored in
/// [`DLDevice::device_type`]; because the enum is `#[repr(i32)]`, a plain
/// `as i32` cast yields that wire value.
#[repr(i32)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DLDeviceType {
    /// Host CPU memory — code 1.
    Cpu = 1,
    /// CUDA GPU memory — code 2.
    Cuda = 2,
    /// CUDA pinned host memory — code 3.
    CudaHost = 3,
    /// OpenCL device memory — code 4.
    OpenCL = 4,
    /// Vulkan device memory — code 7.
    Vulkan = 7,
    /// Apple Metal device memory — code 8.
    Metal = 8,
    /// AMD ROCm/HIP GPU memory — code 10.
    Rocm = 10,
}
37
/// Element-type descriptor laid out to match the DLPack `DLDataType` struct.
///
/// Three fields describe one tensor element: a type-category code, the
/// width of a single lane in bits, and the number of lanes.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct DLDataType {
    /// Type-category code as used throughout this crate:
    /// 0 = int, 1 = uint, 2 = float, 3 = bfloat.
    ///
    /// NOTE(review): the upstream `dlpack.h` header assigns 3 to
    /// `kDLOpaqueHandle` and 4 to `kDLBfloat` — confirm before exchanging
    /// bfloat tensors with other frameworks.
    pub code: u8,
    /// Width of one element in bits (for example 32 for `f32`).
    pub bits: u8,
    /// Number of SIMD lanes; scalar element types use 1.
    pub lanes: u16,
}
51
/// Device descriptor laid out to match the DLPack `DLDevice` struct.
///
/// Pairs a device-type code with the index of the device of that type.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct DLDevice {
    /// Raw device-type code; the values this module knows about are the
    /// discriminants of [`DLDeviceType`].
    pub device_type: i32,
    /// Zero-based index of the device (0 is the first GPU, for example).
    pub device_id: i32,
}
63
64/// The core DLPack tensor structure (ABI-compatible).
65///
66/// Describes a multi-dimensional array buffer.
67#[derive(Debug)]
68#[repr(C)]
69pub struct DLTensor {
70    /// Opaque pointer to the first element of the tensor.
71    pub data: *mut c_void,
72    /// Device on which this tensor resides.
73    pub device: DLDevice,
74    /// Number of dimensions.
75    pub ndim: i32,
76    /// Element data type.
77    pub dtype: DLDataType,
78    /// Pointer to an array of `ndim` shape values.
79    pub shape: *mut i64,
80    /// Pointer to an array of `ndim` stride values (in elements), or NULL for C-contiguous.
81    pub strides: *mut i64,
82    /// Byte offset from `data` to the first element.
83    pub byte_offset: u64,
84}
85
86/// Managed DLPack tensor with associated deleter callback.
87///
88/// This is the struct handed off via `PyCapsule` under the name `"dltensor"`.
89#[repr(C)]
90pub struct DLManagedTensor {
91    /// The underlying tensor descriptor.
92    pub dl_tensor: DLTensor,
93    /// Opaque context pointer passed to `deleter`.
94    pub manager_ctx: *mut c_void,
95    /// Optional destructor; called by the consumer framework when done with the tensor.
96    pub deleter: Option<unsafe extern "C" fn(*mut DLManagedTensor)>,
97}
98
99// SAFETY: The managed tensor is self-contained once constructed; we hold
100// the backing data buffer in the capsule's memory and the pointer is valid
101// until the capsule is destroyed.
102unsafe impl Send for DLManagedTensor {}
103
104// SAFETY: Access to the tensor is read-only after construction; no shared
105// mutable state is exposed without synchronisation.
106unsafe impl Sync for DLManagedTensor {}
107
108/// Python class wrapping a DLPack-compatible tensor.
109///
110/// Exposes `__dlpack__` and `__dlpack_device__` so that any DLPack-aware
111/// framework (PyTorch, JAX, CuPy, etc.) can consume the tensor without copying.
112#[pyclass(name = "DLPackCapsule")]
113pub struct DLPackCapsule {
114    /// Logical shape of the tensor.
115    shape: Vec<i64>,
116    /// Row-major strides (in elements).
117    strides: Vec<i64>,
118    /// Owned backing data buffer (zeroed on construction).
119    ///
120    /// Kept for future zero-copy implementations where `DLTensor.data` points
121    /// directly into this buffer.  Currently unused in the test implementation.
122    #[allow(dead_code)]
123    data: Vec<u8>,
124    /// Element type descriptor.
125    dtype: DLDataType,
126    /// Device descriptor (always CPU for capsules created from Rust).
127    device: DLDevice,
128}
129
130#[pymethods]
131impl DLPackCapsule {
132    /// Create a new zero-filled DLPack capsule.
133    ///
134    /// # Arguments
135    /// * `shape` – tensor dimensions
136    /// * `dtype_code` – element type code (0=int, 1=uint, 2=float, 3=bfloat)
137    /// * `dtype_bits` – element bit-width (e.g. 32 or 64)
138    #[new]
139    pub fn new(shape: Vec<i64>, dtype_code: u8, dtype_bits: u8) -> Self {
140        let n: i64 = shape.iter().product();
141        let bytes_per_elem = (dtype_bits as usize).div_ceil(8).max(1);
142        let n_bytes = (n as usize) * bytes_per_elem;
143        let strides = compute_row_major_strides(&shape);
144        Self {
145            shape,
146            strides,
147            data: vec![0u8; n_bytes],
148            dtype: DLDataType {
149                code: dtype_code,
150                bits: dtype_bits,
151                lanes: 1,
152            },
153            device: DLDevice {
154                device_type: DLDeviceType::Cpu as i32,
155                device_id: 0,
156            },
157        }
158    }
159
160    /// Return `(device_type_int, device_id)` — the `__dlpack_device__` protocol.
161    #[pyo3(name = "__dlpack_device__")]
162    pub fn dlpack_device(&self) -> (i32, i32) {
163        (self.device.device_type, self.device.device_id)
164    }
165
166    /// Return a Python `PyCapsule` named `"dltensor"` — the `__dlpack__` protocol.
167    ///
168    /// The capsule contains a `DLManagedTensor` with a destructor that frees the
169    /// heap allocation created here.
170    ///
171    /// # Safety
172    ///
173    /// The capsule pointer is valid as long as the capsule is live. The `deleter`
174    /// registered in `DLManagedTensor` ensures the allocation is freed.
175    #[pyo3(name = "__dlpack__")]
176    pub fn dlpack<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyCapsule>> {
177        // Allocate shape and strides buffers on the heap so they outlive this call.
178        let mut shape_buf = self.shape.clone().into_boxed_slice();
179        let mut strides_buf = self.strides.clone().into_boxed_slice();
180
181        // Build the managed tensor.  We use a dummy non-null data pointer because
182        // PyCapsule::new_with_pointer requires NonNull and the backing Vec is
183        // stored in the capsule's own allocation.
184        let managed = Box::new(DLManagedTensor {
185            dl_tensor: DLTensor {
186                data: shape_buf.as_mut_ptr() as *mut c_void, // placeholder; real impl would point to `self.data`
187                device: self.device,
188                ndim: self.shape.len() as i32,
189                dtype: self.dtype,
190                shape: shape_buf.as_mut_ptr(),
191                strides: strides_buf.as_mut_ptr(),
192                byte_offset: 0,
193            },
194            manager_ctx: std::ptr::null_mut(),
195            deleter: Some(dlpack_deleter),
196        });
197
198        // Leak the box buffers — the deleter will free the managed tensor pointer
199        // but the shape/strides buffers are intentionally leaked here for the ABI.
200        // (A production implementation would embed them in manager_ctx.)
201        std::mem::forget(shape_buf);
202        std::mem::forget(strides_buf);
203
204        let raw_ptr = Box::into_raw(managed);
205        // SAFETY: raw_ptr is non-null, valid, and the deleter frees it.
206        let non_null = NonNull::new(raw_ptr as *mut c_void)
207            .ok_or_else(|| pyo3::exceptions::PyValueError::new_err("null managed tensor ptr"))?;
208
209        // SAFETY: non_null points to a valid DLManagedTensor allocation; the
210        // dlpack_deleter extern "C" fn will free it when the capsule is destroyed.
211        unsafe {
212            PyCapsule::new_with_pointer_and_destructor(
213                py,
214                non_null,
215                DLTENSOR_CAPSULE_NAME,
216                Some(capsule_destructor),
217            )
218        }
219    }
220
221    /// Return the shape of this tensor.
222    pub fn shape(&self) -> Vec<i64> {
223        self.shape.clone()
224    }
225
226    /// Return the number of dimensions.
227    pub fn ndim(&self) -> usize {
228        self.shape.len()
229    }
230
231    /// Return the dtype type-code (0=int, 1=uint, 2=float, 3=bfloat).
232    pub fn dtype_code(&self) -> u8 {
233        self.dtype.code
234    }
235
236    /// Return the number of bits per element.
237    pub fn dtype_bits(&self) -> u8 {
238        self.dtype.bits
239    }
240}
241
242/// The name required by the DLPack ABI for capsules.
243const DLTENSOR_CAPSULE_NAME: &CStr = c"dltensor";
244
245/// Destructor called by Python's capsule machinery when the capsule is collected.
246///
247/// Frees the `DLManagedTensor` allocation.
248///
249/// # Safety
250///
251/// `capsule` must be a valid `PyCapsule` whose pointer was set to a `DLManagedTensor`
252/// heap allocation created via `Box::into_raw`.
253unsafe extern "C" fn capsule_destructor(capsule: *mut pyo3::ffi::PyObject) {
254    // SAFETY: The capsule was created by `new_with_pointer_and_destructor` with a
255    // DLManagedTensor raw pointer.  We cast the capsule object pointer back to
256    // the PyObject and retrieve the stored pointer.
257    let ptr = unsafe { pyo3::ffi::PyCapsule_GetPointer(capsule, DLTENSOR_CAPSULE_NAME.as_ptr()) };
258    if !ptr.is_null() {
259        let managed_ptr = ptr as *mut DLManagedTensor;
260        // Call the tensor's own deleter if provided.
261        if let Some(deleter) = unsafe { (*managed_ptr).deleter } {
262            unsafe { deleter(managed_ptr) };
263        }
264    }
265}
266
267/// Deleter stored inside `DLManagedTensor.deleter`.
268///
269/// Frees the `DLManagedTensor` allocation itself.
270///
271/// # Safety
272///
273/// `managed` must be a valid heap-allocated `DLManagedTensor` created by `Box::into_raw`.
274unsafe extern "C" fn dlpack_deleter(managed: *mut DLManagedTensor) {
275    if !managed.is_null() {
276        // SAFETY: managed was created by Box::into_raw(Box::new(...))
277        let _ = unsafe { Box::from_raw(managed) };
278    }
279}
280
/// Build the C-order (row-major) strides, in elements, for `shape`.
///
/// The innermost (last) dimension always gets stride 1. Walking outwards,
/// each dimension's stride is the stride of the dimension to its right
/// multiplied by that dimension's extent. An empty shape yields an empty
/// vector.
fn compute_row_major_strides(shape: &[i64]) -> Vec<i64> {
    let mut strides = vec![1i64; shape.len()];
    let mut running = 1i64;

    // Pair every stride slot except the innermost with the extent of the
    // dimension immediately to its right, moving from the inside outwards.
    let outer_slots = strides.iter_mut().rev().skip(1);
    for (slot, &extent) in outer_slots.zip(shape.iter().rev()) {
        running *= extent;
        *slot = running;
    }

    strides
}
295
296/// Register DLPack classes into a PyO3 module.
297///
298/// Call this from your `#[pymodule]` init function to expose `DLPackCapsule`.
299pub fn register_dlpack_module(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
300    m.add_class::<DLPackCapsule>()?;
301    Ok(())
302}
303
304// ─── Enhanced DLPack interoperability ────────────────────────────────────────
305
/// Element-type categories stored in `DLDataType.code`.
///
/// Casting a variant with `as u8` yields the code this crate writes into a
/// tensor descriptor.
///
/// NOTE(review): the upstream `dlpack.h` header assigns 3 to
/// `kDLOpaqueHandle` and 4 to `kDLBfloat`; `BFloat = 3` here may therefore
/// disagree with other frameworks — confirm before relying on bfloat
/// interchange.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DLDataTypeCode {
    /// Signed integer — code 0.
    Int = 0,
    /// Unsigned integer — code 1.
    UInt = 1,
    /// IEEE floating point — code 2.
    Float = 2,
    /// Brain float — code 3 in this crate.
    BFloat = 3,
}
318
319impl TryFrom<u8> for DLDataTypeCode {
320    type Error = DlpackError;
321
322    fn try_from(value: u8) -> Result<Self, Self::Error> {
323        match value {
324            0 => Ok(Self::Int),
325            1 => Ok(Self::UInt),
326            2 => Ok(Self::Float),
327            3 => Ok(Self::BFloat),
328            other => Err(DlpackError::UnsupportedDtype {
329                code: other,
330                bits: 0,
331            }),
332        }
333    }
334}
335
336/// Structured information extracted from a validated [`DLTensor`].
337#[derive(Debug, Clone)]
338pub struct DLTensorInfo {
339    /// Tensor dimensions.
340    pub shape: Vec<i64>,
341    /// Element type category.
342    pub dtype_code: DLDataTypeCode,
343    /// Element bit-width.
344    pub dtype_bits: u8,
345    /// Device type.
346    pub device_type: DLDeviceType,
347}
348
349/// Errors produced by DLPack validation and conversion utilities.
350#[derive(Debug, thiserror::Error)]
351pub enum DlpackError {
352    /// The tensor is not resident on CPU memory.
353    #[error("unsupported device: expected CPU")]
354    NonCpuDevice,
355
356    /// The element dtype (code + bits) is not supported by this operation.
357    #[error("unsupported dtype: {code}:{bits}")]
358    UnsupportedDtype {
359        /// DLDataType code.
360        code: u8,
361        /// DLDataType bits.
362        bits: u8,
363    },
364
365    /// The tensor's data pointer is null.
366    #[error("null data pointer")]
367    NullPointer,
368}
369
370/// Validate a [`DLTensor`] and extract structured metadata.
371///
372/// This is the entry point for consuming tensors produced by DLPack-aware
373/// frameworks (PyTorch, JAX, CuPy, etc.).  It checks that:
374/// - `data` is non-null,
375/// - the device type is parseable,
376/// - the dtype code is recognised.
377///
378/// On success, returns a [`DLTensorInfo`] with all metadata decoded.
379///
380/// # Safety
381///
382/// `tensor.shape` must point to at least `tensor.ndim` valid `i64` values.
383/// The caller must ensure the tensor is not concurrently mutated.
384pub fn validate_dlpack_tensor(tensor: &DLTensor) -> Result<DLTensorInfo, DlpackError> {
385    // 1. Null-pointer guard.
386    if tensor.data.is_null() {
387        return Err(DlpackError::NullPointer);
388    }
389
390    // 2. Decode device type.
391    let device_type = decode_device_type(tensor.device.device_type);
392
393    // 3. Decode dtype code.
394    let dtype_code = DLDataTypeCode::try_from(tensor.dtype.code)?;
395
396    // 4. Copy shape (safe: shape ptr is valid for ndim elements per contract).
397    let shape = if tensor.ndim == 0 || tensor.shape.is_null() {
398        Vec::new()
399    } else {
400        // SAFETY: Caller guarantees shape ptr is valid for ndim elements.
401        unsafe {
402            std::slice::from_raw_parts(tensor.shape as *const i64, tensor.ndim as usize).to_vec()
403        }
404    };
405
406    Ok(DLTensorInfo {
407        shape,
408        dtype_code,
409        dtype_bits: tensor.dtype.bits,
410        device_type,
411    })
412}
413
414/// Create a [`DLTensor`] that borrows `data` and `shape` slices.
415///
416/// The returned `DLTensor` has its `data` pointer set to `data.as_ptr()`,
417/// `dtype` set to float64 (code=2, bits=64), and device set to CPU.
418///
419/// # Safety
420///
421/// The returned `DLTensor` holds raw pointers into `data` and `shape`.
422/// Both slices **must** remain live and unmodified for the entire lifetime of
423/// the returned tensor.  The tensor must not be used after either slice drops.
424///
425/// The returned struct does **not** own the memory it points at; no destructor
426/// is called for `data` or `shape` when the `DLTensor` is dropped.
427pub fn dlpack_from_slice(data: &[f64], shape: &[i64]) -> DLTensor {
428    DLTensor {
429        // SAFETY: We cast a shared reference to a mut-pointer to satisfy the
430        // DLPack ABI (which uses *mut c_void).  The caller contract forbids
431        // mutations through this pointer; this crate never does so.
432        data: data.as_ptr() as *mut c_void,
433        device: DLDevice {
434            device_type: DLDeviceType::Cpu as i32,
435            device_id: 0,
436        },
437        ndim: shape.len() as i32,
438        dtype: DLDataType {
439            code: DLDataTypeCode::Float as u8,
440            bits: 64,
441            lanes: 1,
442        },
443        // SAFETY: Same const-to-mut cast; shape is read-only.
444        shape: shape.as_ptr() as *mut i64,
445        strides: std::ptr::null_mut(), // C-contiguous: strides not needed.
446        byte_offset: 0,
447    }
448}
449
450/// Extract a `Vec<f64>` from a CPU float64 [`DLTensor`].
451///
452/// Validates that:
453/// - `data` is non-null,
454/// - device type is CPU,
455/// - dtype is float64 (code=2, bits=64, lanes=1).
456///
457/// Returns a freshly allocated `Vec<f64>` copied from the tensor buffer.
458///
459/// # Safety
460///
461/// `tensor.data` must point to at least `product(tensor.shape) * 8` valid
462/// bytes of `f64` values in native byte order.  Caller must ensure the tensor
463/// is valid for the duration of this call.
464pub fn dlpack_to_vec_f64(tensor: &DLTensor) -> Result<Vec<f64>, DlpackError> {
465    // Guard: non-null data.
466    if tensor.data.is_null() {
467        return Err(DlpackError::NullPointer);
468    }
469
470    // Guard: CPU device.
471    let device_type = tensor.device.device_type;
472    if device_type != DLDeviceType::Cpu as i32 {
473        return Err(DlpackError::NonCpuDevice);
474    }
475
476    // Guard: float64 dtype.
477    if tensor.dtype.code != DLDataTypeCode::Float as u8
478        || tensor.dtype.bits != 64
479        || tensor.dtype.lanes != 1
480    {
481        return Err(DlpackError::UnsupportedDtype {
482            code: tensor.dtype.code,
483            bits: tensor.dtype.bits,
484        });
485    }
486
487    // Compute element count from shape.
488    let n_elems = if tensor.ndim == 0 {
489        1usize
490    } else if tensor.shape.is_null() {
491        0usize
492    } else {
493        // SAFETY: shape is valid for ndim elements (caller contract).
494        let shape =
495            unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, tensor.ndim as usize) };
496        shape.iter().map(|&d| d as usize).product()
497    };
498
499    // Apply byte_offset.
500    let base = unsafe { (tensor.data as *const u8).add(tensor.byte_offset as usize) as *const f64 };
501
502    // SAFETY: base points to n_elems valid f64 values (caller contract).
503    let slice = unsafe { std::slice::from_raw_parts(base, n_elems) };
504    Ok(slice.to_vec())
505}
506
507/// Decode a raw DLPack device-type integer into the [`DLDeviceType`] enum.
508///
509/// Unknown values fall back to [`DLDeviceType::Cpu`] with a conservative default.
510fn decode_device_type(raw: i32) -> DLDeviceType {
511    match raw {
512        1 => DLDeviceType::Cpu,
513        2 => DLDeviceType::Cuda,
514        3 => DLDeviceType::CudaHost,
515        4 => DLDeviceType::OpenCL,
516        7 => DLDeviceType::Vulkan,
517        8 => DLDeviceType::Metal,
518        10 => DLDeviceType::Rocm,
519        _ => DLDeviceType::Cpu, // conservative fallback
520    }
521}
522
523// ─── Tests ───────────────────────────────────────────────────────────────────
524
#[cfg(test)]
mod tests {
    use super::*;

    // --- validate_dlpack_tensor ---

    /// A well-formed CPU float64 tensor decodes to the expected metadata.
    #[test]
    fn test_validate_valid_f64_cpu_tensor() {
        // `data` and `shape` are locals that outlive `tensor`, which only
        // borrows them through raw pointers.
        let data = [1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let shape = [2_i64, 3];
        let tensor = dlpack_from_slice(&data, &shape);

        let info = validate_dlpack_tensor(&tensor).expect("validate_dlpack_tensor failed");
        assert_eq!(info.shape, [2, 3]);
        assert_eq!(info.dtype_code, DLDataTypeCode::Float);
        assert_eq!(info.dtype_bits, 64);
        assert_eq!(info.device_type, DLDeviceType::Cpu);
    }

    /// A null data pointer is rejected before anything else is inspected.
    #[test]
    fn test_validate_null_pointer_returns_err() {
        let data = [0.0_f64; 3];
        let shape = [3_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
        // Forcibly set data to null to test the null-pointer guard.
        tensor.data = std::ptr::null_mut();

        assert!(
            matches!(
                validate_dlpack_tensor(&tensor),
                Err(DlpackError::NullPointer)
            ),
            "expected NullPointer error"
        );
    }

    /// The shape is copied out of the descriptor unchanged.
    #[test]
    fn test_validate_shape_fields() {
        let data = [10.0_f64; 12];
        let shape = [3_i64, 4];
        let tensor = dlpack_from_slice(&data, &shape);

        let info = validate_dlpack_tensor(&tensor).expect("validate failed");
        assert_eq!(info.shape, [3, 4]);
    }

    // --- dlpack_from_slice ---

    /// A 1-D descriptor has rank 1, non-null pointers and a float64 dtype.
    #[test]
    fn test_dlpack_from_slice_shape_fields() {
        let data = [1.0_f64, 2.0, 3.0];
        let shape = [3_i64];
        let tensor = dlpack_from_slice(&data, &shape);

        assert_eq!(tensor.ndim, 1);
        assert!(!tensor.data.is_null());
        assert!(!tensor.shape.is_null());
        // dtype must be float64
        assert_eq!(tensor.dtype.code, 2); // Float
        assert_eq!(tensor.dtype.bits, 64);
    }

    /// A 2-D descriptor exposes both extents through its shape pointer.
    #[test]
    fn test_dlpack_from_slice_2d() {
        let data = [0.0_f64; 6];
        let shape = [2_i64, 3];
        let tensor = dlpack_from_slice(&data, &shape);
        assert_eq!(tensor.ndim, 2);

        // SAFETY: shape is valid for ndim=2.
        let extents = unsafe { std::slice::from_raw_parts(tensor.shape as *const i64, 2) };
        assert_eq!(extents, [2, 3]);
    }

    // --- dlpack_to_vec_f64 ---

    /// Copying a 1-D tensor back out reproduces the source values.
    #[test]
    fn test_dlpack_to_vec_f64_round_trip() {
        let original = [1.0_f64, 2.5, 3.15, -7.0, 0.0];
        let shape = [5_i64];
        let tensor = dlpack_from_slice(&original, &shape);

        let recovered = dlpack_to_vec_f64(&tensor).expect("dlpack_to_vec_f64 failed");
        assert_eq!(recovered, original);
    }

    /// Copying a 2-D tensor back out reproduces the source values.
    #[test]
    fn test_dlpack_to_vec_f64_2d() {
        let original: Vec<f64> = (0..6).map(f64::from).collect();
        let shape = [2_i64, 3];
        let tensor = dlpack_from_slice(&original, &shape);

        let recovered = dlpack_to_vec_f64(&tensor).expect("dlpack_to_vec_f64 failed");
        assert_eq!(recovered, original);
    }

    /// A null data pointer is reported as `NullPointer`.
    #[test]
    fn test_dlpack_to_vec_f64_null_pointer_err() {
        let data = [0.0_f64];
        let shape = [1_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
        tensor.data = std::ptr::null_mut();

        let result = dlpack_to_vec_f64(&tensor);
        assert!(matches!(result, Err(DlpackError::NullPointer)));
    }

    /// A tensor that claims to live on a GPU is reported as `NonCpuDevice`.
    #[test]
    fn test_dlpack_to_vec_f64_non_cpu_err() {
        let data = [0.0_f64];
        let shape = [1_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
        tensor.device.device_type = DLDeviceType::Cuda as i32;

        let result = dlpack_to_vec_f64(&tensor);
        assert!(matches!(result, Err(DlpackError::NonCpuDevice)));
    }

    /// A non-float dtype code is reported as `UnsupportedDtype`.
    #[test]
    fn test_dlpack_to_vec_f64_wrong_dtype_err() {
        let data = [0.0_f64];
        let shape = [1_i64];
        let mut tensor = dlpack_from_slice(&data, &shape);
        tensor.dtype.code = 0; // Int, not Float

        let result = dlpack_to_vec_f64(&tensor);
        assert!(matches!(result, Err(DlpackError::UnsupportedDtype { .. })));
    }

    // --- DLDataTypeCode ---

    /// Codes 0..=3 decode to their variants; anything else is an error.
    #[test]
    fn test_dtype_code_try_from() {
        let expected = [
            (0u8, DLDataTypeCode::Int),
            (1u8, DLDataTypeCode::UInt),
            (2u8, DLDataTypeCode::Float),
            (3u8, DLDataTypeCode::BFloat),
        ];
        for (raw, variant) in expected {
            assert_eq!(DLDataTypeCode::try_from(raw).unwrap(), variant);
        }
        assert!(DLDataTypeCode::try_from(99u8).is_err());
    }
}